In [None]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np

# DATASET HANDLING

In [None]:
class CustomDataset(Dataset):
    """Paired image / mask dataset read from a fixed directory layout.

    Images are read from ``<root_dir>/Tissues`` and masks from
    ``<root_dir>/Masks/binary_mask``. Both listings are sorted and paired by
    position, so the two folders are expected to use file names that sort
    in the same order.

    Args:
        root_dir: Directory containing the ``Tissues`` and ``Masks`` folders.
        transform: Optional callable applied to both the image and the mask.

    Raises:
        AssertionError: If the two folders hold a different number of files.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        self.image_dir = os.path.join(root_dir, "Tissues")
        self.mask_dir = os.path.join(root_dir, "Masks", "binary_mask")

        # Sorted so that position i in one list corresponds to position i in the other.
        self.image_files = self._list_files(self.image_dir)
        self.mask_files = self._list_files(self.mask_dir)

        # Check if the number of images and masks match
        assert len(self.image_files) == len(self.mask_files), "Number of images and masks do not match!"

    @staticmethod
    def _list_files(directory):
        """Return the sorted names of the regular files directly inside ``directory``.

        Sub-directories are skipped: they cannot be opened as images and would
        otherwise be counted as samples.
        """
        return sorted(
            name for name in os.listdir(directory)
            if os.path.isfile(os.path.join(directory, name))
        )

    def _load(self, path):
        """Open the file at ``path`` with PIL and return the image object."""
        return Image.open(path)

    def __len__(self):
        """Number of image/mask pairs."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair at position ``idx``, transformed if a transform is set."""
        image = self._load(os.path.join(self.image_dir, self.image_files[idx]))
        mask = self._load(os.path.join(self.mask_dir, self.mask_files[idx]))

        if self.transform:
            # Replay the same torch RNG stream for the mask as for the image, so
            # transforms that draw their parameters from torch's global generator
            # pick identical parameters for both and the pair stays aligned.
            # Deterministic transforms are unaffected.
            rng_state = torch.get_rng_state()
            image = self.transform(image)
            torch.set_rng_state(rng_state)
            mask = self.transform(mask)

        return image, mask

# Base conversion applied to every sample (image and mask): PIL image -> tensor.
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Create an instance of the custom dataset.
# Raw string: "\d" is not a valid escape sequence, so the non-raw literal only
# works by accident and triggers a SyntaxWarning on recent Python versions.
# The resulting path value is unchanged.
dataset = CustomDataset(root_dir=r"H:\down_scaled_level6_train_processed", transform=transform)

In [None]:
from torch.utils.data import random_split

# Fraction of the samples assigned to the training split; the rest is validation.
TRAIN_FRACTION = 0.8
# Fixed seed so the split membership is identical on every Restart & Run All.
SPLIT_SEED = 42

# Split dataset into train and validation sets
total_samples = len(dataset)
train_size = int(TRAIN_FRACTION * total_samples)
valid_size = total_samples - train_size

train_dataset, valid_dataset = random_split(
    dataset,
    [train_size, valid_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
def display_samples(dataset, num_samples=5):
    """Show the first ``num_samples`` image/mask pairs of ``dataset`` side by side.

    Args:
        dataset: Indexable dataset whose items are ``(image, mask)`` tensor pairs.
            The image is displayed after moving its first axis last; the first
            channel of the mask is displayed in grayscale.
        num_samples: Maximum number of pairs to show. Clamped to ``len(dataset)``
            so a small split does not raise ``IndexError``.
    """
    count = min(num_samples, len(dataset))
    for i in range(count):
        image, mask = dataset[i]

        # One figure per sample, with exactly one panel per thing shown.
        fig, (image_ax, mask_ax) = plt.subplots(1, 2, figsize=(8, 4))

        image_ax.imshow(image.permute(1, 2, 0).numpy())  # Assuming the image is in the shape (3, 512, 512)
        image_ax.set_title("Original Image")

        mask_ax.imshow(mask[0], cmap='gray')  # Assuming the mask is in the shape (1, 512, 512)
        mask_ax.set_title("Ground Truth Mask")

        plt.show()
        # Release the figure so looping over many samples does not pile up open figures.
        plt.close(fig)


# Example usage (disabled). Kept as comments rather than a bare string literal,
# which the notebook would echo as the cell's output.
#
# print("Train dataset")
# display_samples(train_dataset, num_samples=10)
#
# print("Valid dataset")
# display_samples(valid_dataset, num_samples=10)

In [None]:
# Geometric augmentations: must hit an image and its mask with the SAME random
# parameters, otherwise the pair no longer lines up.
train_geometric_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(degrees=(-45, 45)),
])

# Photometric augmentation: image only. Jittering a mask would change its label values.
train_color_transform = transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.1, hue=0.1)

# Image-only pipeline kept under its original name for reference. It is NOT
# applied to (image, mask) pairs; AugmentedSubset below handles pairs.
train_transform = transforms.Compose([
    train_geometric_transform,
    train_color_transform,
    transforms.ToTensor(),
])

valid_transform = transforms.Compose([
    transforms.ToTensor(),
])


class AugmentedSubset(Dataset):
    """Wrap a dataset of ``(image, mask)`` tensor pairs and augment the pairs consistently.

    Args:
        base_subset: Dataset whose items are ``(image, mask)`` tensors with
            matching spatial size (channels first).
        geometric: Optional transform applied once to image and mask stacked
            along the channel axis, so both receive the same random parameters.
        photometric: Optional transform applied to the image only.
    """

    def __init__(self, base_subset, geometric=None, photometric=None):
        self.base_subset = base_subset
        self.geometric = geometric
        self.photometric = photometric

    def __len__(self):
        return len(self.base_subset)

    def __getitem__(self, idx):
        image, mask = self.base_subset[idx]

        if self.geometric is not None:
            image_channels = image.shape[0]
            # Stack -> transform once -> split back: one random draw for both tensors.
            stacked = self.geometric(torch.cat([image, mask], dim=0))
            image, mask = stacked[:image_channels], stacked[image_channels:]

        if self.photometric is not None:
            image = self.photometric(image)

        return image, mask


# random_split returns two Subset views over ONE shared dataset object, so
# assigning `.dataset.transform` per split does not work: the second assignment
# overwrites the first. The shared dataset therefore only converts to tensors,
# and training-time augmentation is layered on top of the training subset.
# Unwrap first so that re-running this cell does not stack wrappers.
train_subset = getattr(train_dataset, "base_subset", train_dataset)
train_subset.dataset.transform = valid_transform
valid_dataset.dataset.transform = valid_transform

train_dataset = AugmentedSubset(
    train_subset,
    geometric=train_geometric_transform,
    photometric=train_color_transform,
)

In [None]:
# Mini-batch size shared by the training and validation loaders.
bs = 4

# Training batches are drawn in shuffled order; validation keeps a fixed order.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=bs,
    shuffle=True,
)
valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=bs,
    shuffle=False,
)

# MODEL BUILDING

In [None]:
import torch.nn as nn
import torch.optim as optim

class AttentionBlock(nn.Module):
    """Gate a feature map with a single-channel sigmoid map computed from itself.

    Args:
        in_channels: Number of channels of the incoming feature map.
    """

    def __init__(self, in_channels):
        super().__init__()

        # 1x1 convolution that collapses all input channels into one map of gate logits.
        self.convolution = nn.Conv2d(in_channels, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled element-wise by the sigmoid of its 1x1 convolution."""
        gate_logits = self.convolution(x)
        gate = self.sigmoid(gate_logits)
        # The gate has a single channel and broadcasts across the channels of x.
        return x * gate

class UNetWithAttention(nn.Module):
    """Encoder -> attention gate -> decoder network mapping 3 input channels to 1 output channel.

    The encoder applies two 3x3 conv + ReLU layers and a 2x2 max-pool; the
    decoder applies two 3x3 conv + ReLU layers and a stride-2 transposed
    convolution. The stages are chained sequentially in ``forward``.
    """

    def __init__(self):
        super().__init__()

        # Encoder: 3 -> 64 channels, then halve the spatial size.
        self.encoder = nn.Sequential(
            *self._conv_relu_pair(3, 64),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # Attention gate over the 64 encoder channels.
        self.attention = AttentionBlock(64)

        # Decoder: 64 -> 64 channels, then a stride-2 transposed conv down to 1 channel.
        self.decoder = nn.Sequential(
            *self._conv_relu_pair(64, 64),
            nn.ConvTranspose2d(64, 1, kernel_size=2, stride=2),
        )

    @staticmethod
    def _conv_relu_pair(in_channels, out_channels):
        """Return ``[conv3x3, ReLU, conv3x3, ReLU]`` as a list of fresh layers."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]

    def forward(self, x):
        """Run ``x`` through the encoder, the attention gate and the decoder, in that order."""
        encoded = self.encoder(x)
        gated = self.attention(encoded)
        return self.decoder(gated)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

# Build the network and move its parameters to the selected device.
model = UNetWithAttention()
model.to(device)
print(model)

# Mean-squared error between the raw network output and the mask.
criterion = nn.MSELoss()
print(criterion)

# Adam with a small L2 penalty (weight decay) on all model parameters.
optimizer = optim.Adam(
    model.parameters(),
    lr=0.001,
    weight_decay=0.0001,
)
print(optimizer)

# MAIN TRAINING AND VALIDATION LOOP

In [None]:
from tqdm import tqdm

# Total number of passes over the training data.
num_epochs = 40

# Per-epoch history, one value appended per epoch by the training loop; used for plotting
# and for choosing which checkpoint to reload.
train_losses, train_dice_scores = [], []
valid_losses, valid_dice_scores = [], []

def dice_coefficient(y_pred, y_true):
    """Return ``2 * sum(y_true * y_pred) / (sum(y_true) + sum(y_pred) + 1e-8)`` as a scalar tensor.

    The sums run over every element of the tensors, so a batch yields a single
    value. The small constant in the denominator avoids division by zero when
    both tensors sum to zero (the result is then 0).
    """
    overlap = (y_true * y_pred).sum()
    total = y_true.sum() + y_pred.sum()
    return (2.0 * overlap) / (total + 1e-8)

# Each epoch: one optimisation pass over the training loader, one scoring pass
# over the validation loader, then a checkpoint and a summary printout.
for epoch in range(1, num_epochs + 1):

    # ---------------- training pass ----------------
    model.train()
    epoch_train_loss = 0.0
    total_train_dice = 0.0
    for images, masks in tqdm(train_dataloader, desc=f"Epoch {epoch}/{num_epochs}", unit="batch"):
        images = images.to(device)
        masks = masks.to(device)

        # Forward pass, loss and (monitoring-only) Dice on the raw outputs.
        outputs = model(images)
        loss = criterion(outputs, masks)
        train_batch_dice = dice_coefficient(outputs, masks)

        # Clear stale gradients, back-propagate, update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_train_loss += loss.item()
        total_train_dice += train_batch_dice.item()

    # Per-batch means over the epoch.
    avg_train_loss = epoch_train_loss / len(train_dataloader)
    avg_train_dice = total_train_dice / len(train_dataloader)

    train_losses.append(avg_train_loss)
    train_dice_scores.append(avg_train_dice)

    # ---------------- validation pass ----------------
    model.eval()
    epoch_valid_loss = 0.0
    total_valid_dice = 0.0
    with torch.no_grad():  # no gradients are needed for scoring
        for images, masks in tqdm(valid_dataloader, desc=f"Epoch {epoch}/{num_epochs}", unit="batch"):
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)

            valid_loss = criterion(outputs, masks)
            val_batch_dice = dice_coefficient(outputs, masks)

            epoch_valid_loss += valid_loss.item()
            total_valid_dice += val_batch_dice.item()

    avg_valid_loss = epoch_valid_loss / len(valid_dataloader)
    avg_valid_dice = total_valid_dice / len(valid_dataloader)

    valid_losses.append(avg_valid_loss)
    valid_dice_scores.append(avg_valid_dice)

    # One checkpoint per epoch, named by the 1-based epoch number.
    torch.save(model.state_dict(), f"unet_model_{epoch}.pth")

    # Epoch summary: losses first, then Dice scores.
    print(f"Epoch {epoch}/{num_epochs}, Train Loss: {avg_train_loss}")
    print(f"Epoch {epoch}/{num_epochs}, Valid Loss: {avg_valid_loss}")

    print(f"Epoch {epoch}/{num_epochs}, Train Dice Coeff: {avg_train_dice}")
    print(f"Epoch {epoch}/{num_epochs}, Valid Dice Coeff: {avg_valid_dice}")

In [None]:
def _plot_history(ax, curves, ylabel):
    """Draw each ``(label, values)`` curve on ``ax`` against 1-based epoch numbers.

    The history lists hold one value per epoch in order, so position ``i``
    belongs to epoch ``i + 1``. Plotting against that number keeps the x-axis
    consistent with the printed logs and the checkpoint file names.
    """
    for label, values in curves:
        ax.plot(range(1, len(values) + 1), values, label=label)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend()


# Plotting loss and Dice coefficient curves
fig, (loss_ax, dice_ax) = plt.subplots(1, 2, figsize=(12, 4))

_plot_history(
    loss_ax,
    [('Training Loss', train_losses), ('Validation Loss', valid_losses)],
    'Loss',
)
_plot_history(
    dice_ax,
    [('Training Dice Coefficient', train_dice_scores), ('Validation Dice Coefficient', valid_dice_scores)],
    'Dice Coefficient',
)

plt.show()

In [None]:
# Pick for dice

# np.argmax returns a 0-based position in the history list, but checkpoints are
# saved under the 1-based epoch number (unet_model_1.pth is the first epoch).
# Without the +1 the checkpoint of the previous epoch would be loaded, or the
# non-existent unet_model_0.pth when the first epoch is the best one.
best_epoch = int(np.argmax(valid_dice_scores)) + 1
print(best_epoch)
saved_model_path = f"unet_model_{best_epoch}.pth"
# map_location keeps loading working when the checkpoint was written on another device.
model.load_state_dict(torch.load(saved_model_path, map_location=device))

In [None]:
# Or pick for loss

# np.argmin returns a 0-based position in the history list, but checkpoints are
# saved under the 1-based epoch number (unet_model_1.pth is the first epoch).
# Without the +1 the checkpoint of the previous epoch would be loaded, or the
# non-existent unet_model_0.pth when the first epoch is the best one.
best_epoch = int(np.argmin(valid_losses)) + 1
print(best_epoch)
saved_model_path = f"unet_model_{best_epoch}.pth"
# map_location keeps loading working when the checkpoint was written on another device.
model.load_state_dict(torch.load(saved_model_path, map_location=device))

# TEST CASE 

In [None]:
# Load the test dataset.
# Raw string: "\d" is not a valid escape sequence, so the non-raw literal only
# works by accident and triggers a SyntaxWarning on recent Python versions.
# The resulting path value is unchanged.
test_dataset = CustomDataset(root_dir=r"H:\down_scaled_level6_test_resized_equalized", transform=transform)
# One sample per batch, in dataset order, so prediction i corresponds to test_dataset[i].
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [None]:
from sklearn.metrics import jaccard_score

# Per-batch NumPy arrays collected during inference: raw model outputs and the
# matching ground-truth masks, in loader order.
all_predictions_equalized = []
all_ground_truths = []

# Inference only: evaluation mode, and no gradient tracking.
model.eval()
with torch.no_grad():
    for images, masks in tqdm(test_dataloader):
        images = images.to(device)
        masks = masks.to(device)

        predictions = model(images)

        # Bring both tensors back to host memory before converting to NumPy.
        all_predictions_equalized.append(predictions.cpu().numpy())
        all_ground_truths.append(masks.cpu().numpy())

print("Done")

In [None]:
# Candidate binarisation thresholds for the raw model outputs.
thresholds = np.linspace(0.4, 0.7, 10)  # Adjust the range and step according to your needs

# Stack the per-batch predictions into one array BEFORE thresholding: a Python
# list cannot be compared with `>`. The isinstance guard makes the cell safe to
# re-run, since concatenating an already-stacked array would drop its first axis.
if isinstance(all_predictions_equalized, list):
    all_predictions_equalized = np.concatenate(all_predictions_equalized, axis=0)

# Flat boolean ground truth, built the same way as the flat predictions below.
# NOTE(review): assumes mask values are in [0, 1] with foreground > 0.5 - confirm
# against the mask files.
y_true_binary = (np.concatenate(all_ground_truths, axis=0) > 0.5).flatten()

best_jaccard_acc = 0.0
best_threshold = 0.0

for threshold in thresholds:
    # Apply the threshold to get binary predictions
    y_pred_binary = (all_predictions_equalized > threshold).flatten()

    # Jaccard index (intersection over union) of the foreground class over all test pixels.
    jaccard_acc = jaccard_score(y_true_binary, y_pred_binary, average='binary')
    print("Threshold: ", threshold)
    print("Accuracy: ", jaccard_acc)
    print("------------------")

    # Keep the threshold with the highest score seen so far.
    if jaccard_acc > best_jaccard_acc:
        best_jaccard_acc = jaccard_acc
        best_threshold = threshold

print(f"Best Jaccard Accuracy: {best_jaccard_acc} at Threshold: {best_threshold}")

In [None]:
# One three-panel figure per test sample: input image, ground-truth mask and
# the model output binarised at the selected threshold.
num_samples_to_visualize = len(test_dataset)
for i in range(num_samples_to_visualize):
    # Each dataset item is an (image, mask) pair of tensors.
    sample = test_dataset[i]
    image = sample[0].numpy()
    mask = sample[1].numpy()

    # First channel of the i-th stacked prediction, cut at the chosen threshold.
    model_prediction = all_predictions_equalized[i, 0]
    binary_prediction = (model_prediction > best_threshold).astype(np.uint8)

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))

    axes[0].imshow(image.transpose((1, 2, 0)))  # Assuming the image is in the shape (3, 512, 512)
    axes[0].set_title("Original Image")

    axes[1].imshow(mask[0], cmap='gray')  # Assuming the mask is in the shape (1, 512, 512)
    axes[1].set_title("Ground Truth Mask")

    axes[2].imshow(binary_prediction, cmap='gray')  # Display binary prediction
    axes[2].set_title("Binary Model Prediction")

    plt.show()