In [None]:
import os
from PIL import Image
import matplotlib.pyplot as plt
import cv2
import numpy as np

from tqdm import tqdm

In [None]:
def histogram_equalization(image, clip_limit=2.0, tile_grid_size=(8, 8)):
    """Apply CLAHE to the lightness channel of an image and return it as RGB.

    Only the L channel of the Lab representation is equalized, so local contrast
    is enhanced while the colour (a, b) channels are left unchanged.

    Args:
        image: PIL image. It is converted to RGB first, so RGBA / grayscale /
            palette inputs are accepted.
        clip_limit: contrast clip limit passed to ``cv2.createCLAHE``.
        tile_grid_size: tile grid size passed to ``cv2.createCLAHE``.

    Returns:
        A new PIL image in RGB mode.
    """
    # Pillow's Image.convert has very limited LAB support (RGB <-> LAB is not a
    # supported conversion in general), so do the colour-space round trip with
    # OpenCV, which is already used for CLAHE here.
    rgb = np.asarray(image.convert('RGB'))
    lab = cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB)
    l_channel, a_channel, b_channel = cv2.split(lab)

    # Contrast-limited adaptive histogram equalization on lightness only
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    equalized_l_channel = clahe.apply(l_channel)

    # Re-assemble Lab with the original colour channels and go back to RGB
    equalized_lab = cv2.merge((equalized_l_channel, a_channel, b_channel))
    equalized_rgb = cv2.cvtColor(equalized_lab, cv2.COLOR_LAB2RGB)

    return Image.fromarray(equalized_rgb)

In [None]:
root_dir = r'H:\train_patches'

# Collect (tissue patch, mask) path pairs for every slide folder under root_dir.
# Each slide folder holds a "tissues" and a "masks" sub-folder; a mask's file
# name is the patch file name with its first 5 characters replaced by "mask".
# Listings are sorted so the pair order is deterministic and does not depend
# on the file system's (unspecified) os.listdir order.
train_slide_mask_pairs = []
for slide_name in tqdm(sorted(os.listdir(root_dir)), desc="Processing slides"):
    slide_dir = os.path.join(root_dir, slide_name)
    if not os.path.isdir(slide_dir):
        continue  # skip stray files (e.g. Thumbs.db) next to the slide folders

    tissue_dir = os.path.join(slide_dir, "tissues")
    mask_dir = os.path.join(slide_dir, "masks")

    for patch_name in sorted(os.listdir(tissue_dir)):
        mask_name = "mask" + patch_name[5:]

        patch_path = os.path.join(tissue_dir, patch_name)
        mask_path = os.path.join(mask_dir, mask_name)

        train_slide_mask_pairs.append((patch_path, mask_path))

In [None]:
root_dir = r'H:\test_patches'

# Collect (tissue patch, mask) path pairs for every slide folder under root_dir.
# Each slide folder holds a "tissues" and a "masks" sub-folder; a mask's file
# name is the patch file name with its first 5 characters replaced by "mask".
# Listings are sorted so the pair order is deterministic and does not depend
# on the file system's (unspecified) os.listdir order.
test_slide_mask_pairs = []
for slide_name in tqdm(sorted(os.listdir(root_dir)), desc="Processing slides"):
    slide_dir = os.path.join(root_dir, slide_name)
    if not os.path.isdir(slide_dir):
        continue  # skip stray files (e.g. Thumbs.db) next to the slide folders

    tissue_dir = os.path.join(slide_dir, "tissues")
    mask_dir = os.path.join(slide_dir, "masks")

    for patch_name in sorted(os.listdir(tissue_dir)):
        mask_name = "mask" + patch_name[5:]

        patch_path = os.path.join(tissue_dir, patch_name)
        mask_path = os.path.join(mask_dir, mask_name)

        test_slide_mask_pairs.append((patch_path, mask_path))

In [None]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

import torch  # used to keep random transforms in sync between image and mask


class CustomDataset(Dataset):
    """Dataset of (tissue patch, mask) pairs read from disk.

    Args:
        pairs: sequence of ``(tissue_path, mask_path)`` tuples.
        transform: optional callable applied to the tissue image and, in a
            separate call, to the mask. The torch RNG state is restored between
            the two calls, so random torchvision transforms (flips, rotations)
            draw the same parameters for both and image/mask stay aligned.
            NOTE: photometric transforms (e.g. ColorJitter) in ``transform``
            would still be applied to the mask as well.
    """

    def __init__(self, pairs, transform):
        self.pairs = pairs
        self.transform = transform

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        tissue_path, mask_path = self.pairs[idx]

        # Read eagerly inside ``with`` so the file handles are closed right away.
        # The network's first conv expects 3 input channels, so force RGB
        # (RGBA or palette files would otherwise give 4 / 1 channels).
        with Image.open(tissue_path) as tissue_file:
            image = tissue_file.convert('RGB')
        with Image.open(mask_path) as mask_file:
            mask = mask_file.copy()  # keep the mask's original mode and values

        if self.transform:
            # Replay the same random draws for the mask as for the image.
            rng_state = torch.get_rng_state()
            image = self.transform(image)
            torch.set_rng_state(rng_state)
            mask = self.transform(mask)

        return image, mask

# Define transformations
transform = transforms.Compose([
    transforms.ToTensor(),
])

In [None]:
# Wrap the training-set pairs in a single dataset; it is split into
# train/validation subsets in the next cell (the test set is built separately).
total_dataset = CustomDataset(train_slide_mask_pairs, transform)

In [None]:
from torch.utils.data import random_split

import torch

# Fixed seed so the train/validation split - and therefore the validation
# metrics and the "best epoch" picked later - is reproducible across runs.
SPLIT_SEED = 42

# Split dataset into train and validation sets (80 % / 20 %)
total_samples = len(total_dataset)
train_size = int(0.8 * total_samples)
valid_size = total_samples - train_size

train_dataset, valid_dataset = random_split(
    total_dataset,
    [train_size, valid_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
def display_samples(dataset, num_samples=5):
    """Show image / ground-truth mask side by side for the first samples of ``dataset``.

    Args:
        dataset: indexable dataset returning ``(image, mask)`` tensors. The image
            is assumed to be (C, H, W) and the mask (1, H, W) - TODO confirm.
        num_samples: number of leading samples to show; capped at ``len(dataset)``
            so small datasets do not raise IndexError.
    """
    for i in range(min(num_samples, len(dataset))):
        image, mask = dataset[i]

        fig, (image_ax, mask_ax) = plt.subplots(1, 2, figsize=(12, 4))

        image_ax.imshow(image.permute(1, 2, 0).numpy())  # CHW tensor -> HWC for matplotlib
        image_ax.set_title("Original Image")

        mask_ax.imshow(mask[0], cmap='gray')  # first (only) mask channel
        mask_ax.set_title("Ground Truth Mask")

    plt.show()


# Display samples row by row (disabled; uncomment to inspect the splits).
# Kept as comments rather than a bare string so the cell produces no output.
# print("Train dataset")
# display_samples(train_dataset, num_samples=10)
#
# print("Valid dataset")
# display_samples(valid_dataset, num_samples=10)

In [None]:
import torch

# Augmentation for the training split only.
# Geometric transforms must move a tissue patch and its mask together, while
# photometric ones (ColorJitter) must only touch the tissue image, never the label.
train_geometric_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(degrees=(-45, 45)),
])
train_photometric_transform = transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.1, hue=0.1)

valid_transform = transforms.Compose([
    transforms.ToTensor(),
])


class JointAugmentDataset(Dataset):
    """Wrap a dataset of ``(image, mask)`` tensors and augment each pair jointly.

    The geometric transform is applied once to the channel-wise concatenation of
    image and mask, so both get the same flip/rotation; the photometric transform
    is then applied to the image channels only.
    """

    def __init__(self, base, geometric, photometric):
        self.base = base
        self.geometric = geometric
        self.photometric = photometric

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        image, mask = self.base[idx]
        n_image_channels = image.shape[0]
        stacked = self.geometric(torch.cat([image, mask], dim=0))
        image, mask = stacked[:n_image_channels], stacked[n_image_channels:]
        return self.photometric(image), mask


# train_dataset and valid_dataset are Subsets of the SAME underlying dataset, so
# setting `.dataset.transform` on each one in turn left only the last
# (validation) transform active and silently disabled augmentation. Instead the
# shared dataset yields plain tensors and only the training split is wrapped.
valid_dataset.dataset.transform = valid_transform
# Unwrap first so re-running this cell does not stack augmentation wrappers.
train_dataset = JointAugmentDataset(
    getattr(train_dataset, "base", train_dataset),
    train_geometric_transform,
    train_photometric_transform,
)

In [None]:
# Batch size shared by both loaders.
bs = 16

# Training batches are reshuffled every epoch; validation batches keep a fixed
# order so validation metrics are deterministic for a given model.
train_dataloader, valid_dataloader = (
    DataLoader(train_dataset, batch_size=bs, shuffle=True),
    DataLoader(valid_dataset, batch_size=bs, shuffle=False),
)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# Define the U-Net architecture
class UNet(nn.Module):
    """Small encoder-decoder segmentation network.

    NOTE(review): despite the name this is not a full U-Net - it has a single
    down-sampling stage and no skip connections between encoder and decoder.

    Expects 3 input channels and returns a single-channel map of raw values
    (no final activation). For even H and W the output keeps the input's
    spatial size (max-pool halves it, the stride-2 transposed conv restores it).
    """

    def __init__(self):
        super().__init__()

        def conv_relu(in_channels, out_channels):
            # 3x3 convolution with padding 1 (same spatial size) + in-place ReLU
            return [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
            ]

        # Encoder: two conv blocks, then halve the spatial resolution.
        self.encoder = nn.Sequential(
            *conv_relu(3, 64),
            *conv_relu(64, 64),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # Decoder: two conv blocks, then a stride-2 transposed conv that doubles
        # the resolution and projects down to one output channel.
        self.decoder = nn.Sequential(
            *conv_relu(64, 64),
            *conv_relu(64, 64),
            nn.ConvTranspose2d(64, 1, kernel_size=2, stride=2),
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))


# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

# Build the network and move its parameters to the chosen device
# (nn.Module.to moves the module in place and returns it).
model = UNet().to(device)
print(model)

# Pixel-wise mean squared error between the raw output map and the mask.
criterion = nn.MSELoss()
print(criterion)

# Adam with a small L2 penalty via weight_decay.
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
print(optimizer)

In [None]:
from tqdm import tqdm

# Number of passes over the training data.
num_epochs = 10

# Per-epoch history (one entry per completed epoch), used for the curves and
# for choosing the best checkpoint afterwards.
train_losses, train_dice_scores = [], []
valid_losses, valid_dice_scores = [], []

def dice_coefficient(y_pred, y_true, eps=1e-8):
    """Soft Dice coefficient between predictions and ground-truth masks.

    The model has no final activation, so its raw outputs can fall outside
    [0, 1], which would push the score outside [0, 1] as well. Predictions are
    therefore clamped to [0, 1] first. All elements (the whole batch) are pooled
    into one score.

    Args:
        y_pred: model output tensor.
        y_true: ground-truth mask tensor, same shape as ``y_pred`` (expected in [0, 1]).
        eps: smoothing term added to numerator and denominator; it avoids division
            by zero and makes a perfect prediction of an empty mask score 1, not 0.

    Returns:
        0-dim tensor holding the Dice score.
    """
    y_pred = y_pred.clamp(0.0, 1.0)
    intersection = torch.sum(y_true * y_pred)
    cardinality = torch.sum(y_true) + torch.sum(y_pred)  # |A| + |B|, not the union
    return (2.0 * intersection + eps) / (cardinality + eps)

# Training loop: per epoch, one optimisation pass over train_dataloader and one
# no-grad pass over valid_dataloader. Batch-averaged loss and Dice are appended
# to the history lists, and a checkpoint unet_model_{epoch}.pth (1-based epoch)
# is written after every epoch.
for epoch in range(1, num_epochs+1):

    model.train()
    epoch_train_loss = 0.0
    total_train_dice = 0.0
    for images, masks in tqdm(train_dataloader, desc=f"Epoch {epoch}/{num_epochs}", unit="batch"):
        images, masks = images.to(device), masks.to(device)

        # Forward pass
        outputs = model(images)

        # Calculate the loss
        loss = criterion(outputs, masks)
        # Dice is only a metric: compute it on detached outputs so it does not
        # add to the autograd graph (or keep extra tensors alive until the next batch).
        train_batch_dice = dice_coefficient(outputs.detach(), masks)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_dice += train_batch_dice.item()
        epoch_train_loss += loss.item()

    # Calculate average training loss for the epoch (mean over batches)
    avg_train_loss = epoch_train_loss / len(train_dataloader)
    avg_train_dice = total_train_dice / len(train_dataloader)

    train_losses.append(avg_train_loss)
    train_dice_scores.append(avg_train_dice)

    # Validation

    model.eval()
    epoch_valid_loss = 0.0
    total_valid_dice = 0.0
    with torch.no_grad():
        for images, masks in tqdm(valid_dataloader, desc=f"Epoch {epoch}/{num_epochs}", unit="batch"):
            images, masks = images.to(device), masks.to(device)

            # Forward pass
            outputs = model(images)

            # Calculate the validation loss
            valid_loss = criterion(outputs, masks)
            epoch_valid_loss += valid_loss.item()

            # Calculate the Dice coefficient for validation
            val_batch_dice = dice_coefficient(outputs, masks)
            total_valid_dice += val_batch_dice.item()

    # Calculate average validation loss for the epoch (mean over batches)
    avg_valid_loss = epoch_valid_loss / len(valid_dataloader)
    avg_valid_dice = total_valid_dice / len(valid_dataloader)

    valid_losses.append(avg_valid_loss)
    valid_dice_scores.append(avg_valid_dice)

    # Save the trained model; file names use the 1-based epoch number
    torch.save(model.state_dict(), f"unet_model_{epoch}.pth")

    # Print losses
    print(f"Epoch {epoch}/{num_epochs}, Train Loss: {avg_train_loss}")
    print(f"Epoch {epoch}/{num_epochs}, Valid Loss: {avg_valid_loss}")

    print(f"Epoch {epoch}/{num_epochs}, Train Dice Coeff: {avg_train_dice}")
    print(f"Epoch {epoch}/{num_epochs}, Valid Dice Coeff: {avg_valid_dice}")

In [None]:
# Plotting loss and Dice coefficient curves.
# Epochs are numbered from 1 in the training loop and in the checkpoint file
# names, so the x axis uses the same numbering (plt.plot(values) alone would
# start at 0).
def _epoch_axis(values):
    """1-based epoch numbers matching the length of ``values``."""
    return range(1, len(values) + 1)


fig, (loss_ax, dice_ax) = plt.subplots(1, 2, figsize=(12, 4))

loss_ax.plot(_epoch_axis(train_losses), train_losses, label='Training Loss')
loss_ax.plot(_epoch_axis(valid_losses), valid_losses, label='Validation Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.legend()

dice_ax.plot(_epoch_axis(train_dice_scores), train_dice_scores, label='Training Dice Coefficient')
dice_ax.plot(_epoch_axis(valid_dice_scores), valid_dice_scores, label='Validation Dice Coefficient')
dice_ax.set_xlabel('Epoch')
dice_ax.set_ylabel('Dice Coefficient')
dice_ax.legend()

plt.show()

In [None]:
# Pick the checkpoint with the highest validation Dice.
# np.argmax returns a 0-based list index, whereas checkpoints are saved as
# unet_model_{epoch}.pth with 1-based epochs - hence the +1 (without it the
# wrong epoch is loaded, and index 0 points at a file that does not exist).
best_epoch = int(np.argmax(valid_dice_scores)) + 1
print(best_epoch)
saved_model_path = f"unet_model_{best_epoch}.pth"
model.load_state_dict(torch.load(saved_model_path, map_location=device))

In [None]:
# Or pick the checkpoint with the lowest validation loss.
# np.argmin returns a 0-based list index, whereas checkpoints are saved as
# unet_model_{epoch}.pth with 1-based epochs - hence the +1 (without it the
# wrong epoch is loaded, and index 0 points at a file that does not exist).
best_epoch = int(np.argmin(valid_losses)) + 1
print(best_epoch)
saved_model_path = f"unet_model_{best_epoch}.pth"
model.load_state_dict(torch.load(saved_model_path, map_location=device))

In [None]:
# Test split: no augmentation, only the tensor conversion in `transform`.
# batch_size=1 and no shuffling give one prediction per patch, in the order
# of test_slide_mask_pairs.
test_dataset = CustomDataset(pairs=test_slide_mask_pairs, transform=transform)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

In [None]:
from sklearn.metrics import jaccard_score

# Collect raw model outputs and ground-truth masks for the whole test set as
# NumPy arrays (one entry per test batch).
# NOTE(review): jaccard_score is imported above but no IoU is computed yet.
all_predictions = []
all_ground_truths = []

# Perform inference
model.eval()  # Set the model to evaluation mode
with torch.no_grad():
    for images, masks in tqdm(test_dataloader):
        # Only the inputs need to be on the model's device; the masks are just
        # stored, so copying them to the GPU and straight back was wasted work.
        predictions = model(images.to(device))

        # Store predictions and ground truth mask
        all_predictions.append(predictions.cpu().numpy())
        all_ground_truths.append(masks.numpy())

print("Done")