In [None]:
import os
import numpy as np
from PIL import Image, ImageEnhance
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch

# Define paths (relative to the notebook's working directory)
image_folder = os.path.join("data", "Sliced_Images")
mask_folder = os.path.join("data", "Sliced_masks")


def _list_visible_files(folder):
    """Return the sorted names of non-hidden regular files in ``folder``.

    Hidden entries such as ``.DS_Store`` or ``.ipynb_checkpoints`` would otherwise be
    paired with real images/masks purely by sort position and break ``Image.open``.
    """
    return sorted(
        name for name in os.listdir(folder)
        if not name.startswith(".") and os.path.isfile(os.path.join(folder, name))
    )


# Images and masks are paired by position in these sorted lists, so both folders must
# use matching naming schemes.
image_files = _list_visible_files(image_folder)
mask_files = _list_visible_files(mask_folder)

# Ensure matching number of images and masks
assert len(image_files) == len(mask_files), "Mismatch between image and mask files."

def crop_black_borders(image, threshold=30):
    """Crop near-black borders from an image.

    A border row/column is removed while its mean intensity is <= ``threshold``. Rows are
    evaluated over the full image width and columns over the full image height (i.e. the
    column test does not use the row-cropped region).

    Args:
        image: PIL image, colour (channels averaged) or single-channel.
        threshold: Mean-intensity cutoff at or below which a border row/column counts as black.

    Returns:
        The cropped image. If no row or no column is brighter than ``threshold``, ``image``
        is returned unchanged so callers never receive a zero-sized image.
    """
    img_array = np.asarray(image, dtype=np.float64)
    # Average channels for colour images; a 2-D array is already grayscale.
    gray_img = img_array.mean(axis=2) if img_array.ndim == 3 else img_array

    bright_rows = np.flatnonzero(gray_img.mean(axis=1) > threshold)
    bright_cols = np.flatnonzero(gray_img.mean(axis=0) > threshold)

    # Everything is dark along one axis: cropping would produce an empty image.
    if bright_rows.size == 0 or bright_cols.size == 0:
        return image

    top, bottom = int(bright_rows[0]), int(bright_rows[-1]) + 1
    left, right = int(bright_cols[0]), int(bright_cols[-1]) + 1
    return image.crop((left, top, right, bottom))

def preprocess_image(image_path, contrast_factor=10):
    """Load an image as RGB, crop its black borders and boost its contrast.

    Args:
        image_path: Path to the image file.
        contrast_factor: Factor passed to ``ImageEnhance.Contrast.enhance`` (1.0 = unchanged).
            The default of 10 keeps the notebook's original behaviour.

    Returns:
        The preprocessed RGB PIL image.
    """
    # The context manager closes the file handle promptly (matters with many DataLoader workers);
    # convert() returns an independent, fully loaded image.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    cropped_image = crop_black_borders(image)
    return ImageEnhance.Contrast(cropped_image).enhance(contrast_factor)

def preprocess_mask(mask_path, threshold=0):
    """Load a mask and binarize it to a 0/255 grayscale PIL image.

    Pixels whose grayscale value is > ``threshold`` become 255 (intended to be tissue),
    all others become 0 (background).

    Args:
        mask_path: Path to the mask file.
        threshold: Grayscale cutoff; the default of 0 keeps the original behaviour.

    Returns:
        A mode "L" PIL image with values 0 or 255.
    """
    # Read the pixels inside the context manager so the file handle is closed promptly.
    with Image.open(mask_path) as raw_mask:
        mask_array = np.asarray(raw_mask.convert("L"))
    binary_mask = (mask_array > threshold).astype(np.uint8) * 255
    return Image.fromarray(binary_mask)

class TissueDataset(Dataset):
    """Paired image/mask dataset for tissue segmentation.

    Item ``idx`` loads ``image_files[idx]`` from ``image_folder`` and ``mask_files[idx]``
    from ``mask_folder``, so the two lists must be aligned (e.g. both sorted). Each pair is
    run through ``preprocess_image`` / ``preprocess_mask`` and then the optional transforms.
    """

    def __init__(self, image_files, mask_files, image_folder, mask_folder, transform=None, mask_transform=None):
        # File names relative to their folders, paired by position.
        self.image_files, self.mask_files = image_files, mask_files
        self.image_folder, self.mask_folder = image_folder, mask_folder
        # Optional callables applied to the preprocessed PIL image and mask respectively.
        self.transform, self.mask_transform = transform, mask_transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_file = os.path.join(self.image_folder, self.image_files[idx])
        mask_file = os.path.join(self.mask_folder, self.mask_files[idx])

        image, mask = preprocess_image(img_file), preprocess_mask(mask_file)

        image = self.transform(image) if self.transform else image
        mask = self.mask_transform(mask) if self.mask_transform else mask
        return image, mask

# Image transform: resize to 128x128 and convert to a float tensor in [0, 1].
image_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

# Mask transform: nearest-neighbour resizing keeps masks strictly binary, and ToTensor maps
# 0/255 to 0.0/1.0. torchvision expects an InterpolationMode here (PIL integer constants are deprecated).
mask_transform = transforms.Compose([
    transforms.Resize((128, 128), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.ToTensor(),
])

# Create the full dataset
dataset = TissueDataset(
    image_files, mask_files,
    image_folder, mask_folder,
    transform=image_transform,
    mask_transform=mask_transform
)

# Split 80/20 into training and validation sets. A seeded generator keeps the split
# identical across kernel restarts, so validation metrics stay comparable.
SPLIT_SEED = 42
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=torch.Generator().manual_seed(SPLIT_SEED)
)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

In [None]:
import matplotlib.pyplot as plt

def visualize_batch_from_loader(loader, num_batches=1):
    """Plot each image next to its mask for the first ``num_batches`` batches of ``loader``.

    Args:
        loader: DataLoader yielding ``(images, masks)`` with images shaped (B, C, H, W) and
            masks shaped (B, 1, H, W).
        num_batches: Maximum number of batches to show; stops early if the loader runs out.
    """
    sample_number = 0  # running counter so titles stay correct even for a short final batch

    # zip stops cleanly once either side is exhausted (no StopIteration from next()).
    for _, (images, masks) in zip(range(num_batches), loader):
        images_np = images.permute(0, 2, 3, 1).cpu().numpy()
        masks_np = masks.squeeze(1).cpu().numpy()  # drop the mask channel dimension

        batch_size = images_np.shape[0]
        # squeeze=False keeps a 2-D axes grid even when batch_size == 1.
        fig, axes = plt.subplots(batch_size, 2, figsize=(10, 5 * batch_size), squeeze=False)

        for row, (image, mask) in enumerate(zip(images_np, masks_np)):
            sample_number += 1
            axes[row, 0].imshow(image)
            axes[row, 0].set_title(f"Image {sample_number}")
            axes[row, 0].axis("off")

            axes[row, 1].imshow(mask, cmap="gray")
            axes[row, 1].set_title(f"Mask {sample_number}")
            axes[row, 1].axis("off")

        plt.show()
        plt.close(fig)  # release the (potentially very tall) figure

# Preview seven training batches: each image shown next to its binary mask.
visualize_batch_from_loader(loader=train_loader, num_batches=7)


In [None]:
import matplotlib.pyplot as plt







def verify_data_alignment(dataset, num_samples=5):
    """Show the original image, preprocessed image and processed mask for the first samples.

    Args:
        dataset: A ``TissueDataset`` or a ``torch.utils.data.Subset`` of one (as returned by
            ``random_split``); items must be ``(image_tensor, mask_tensor)`` pairs.
        num_samples: Number of samples to show; clamped to ``len(dataset)``.
    """
    # For a Subset, item i is base_dataset[indices[i]]. Resolve that so the "original"
    # panel shows the same file as the processed panels.
    base_dataset = getattr(dataset, "dataset", dataset)
    subset_indices = getattr(dataset, "indices", None)
    # Fall back to the notebook-level file list if the dataset does not expose its own.
    source_folder = getattr(base_dataset, "image_folder", image_folder)
    source_files = getattr(base_dataset, "image_files", image_files)
    to_pil = transforms.ToPILImage()

    for i in range(min(num_samples, len(dataset))):
        preprocessed_image, processed_mask = dataset[i]
        preprocessed_image_pil = to_pil(preprocessed_image)
        processed_mask_pil = to_pil(processed_mask)

        file_index = subset_indices[i] if subset_indices is not None else i
        with Image.open(os.path.join(source_folder, source_files[file_index])) as raw_image:
            original_image = raw_image.convert("RGB")

        fig, axs = plt.subplots(1, 3, figsize=(15, 5))
        panels = (
            (original_image, f"Original Image {i + 1}", None),
            (preprocessed_image_pil, f"Preprocessed Image {i + 1}", None),
            (processed_mask_pil, f"Processed Mask {i + 1}", "gray"),
        )
        for ax, (picture, title, cmap) in zip(axs, panels):
            ax.imshow(picture, cmap=cmap)
            ax.set_title(title)
            ax.axis("off")

        plt.show()
        plt.close(fig)

# Spot-check preprocessing on the first ten samples of the full dataset.
verify_data_alignment(dataset=dataset, num_samples=10)

In [None]:
from transformers import Mask2FormerForUniversalSegmentation, Mask2FormerImageProcessor
import torch
import gc

def clear_model_and_cache():
    """Delete any global ``model`` / ``optimizer`` left from a previous run and free GPU memory.

    Safe on a fresh kernel: names that do not exist are simply skipped.
    """
    namespace = globals()
    if 'model' in namespace:
        print("Deleting existing model...")
        del namespace['model']
    # The optimizer references the model's parameters (and holds its own state), so drop it too.
    namespace.pop('optimizer', None)
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Clear any existing model and cache
clear_model_and_cache()

CHECKPOINT = "facebook/mask2former-swin-base-IN21k-ade-semantic"

# Load the image processor; the torchvision transforms already resize and scale images.
image_processor = Mask2FormerImageProcessor.from_pretrained(
    CHECKPOINT,
    do_rescale=False,   # ToTensor already produced [0, 1] values, so skip the 1/255 rescale
    do_normalize=True,  # apply the checkpoint's mean/std normalization
    do_resize=False     # resizing is done by the torchvision transforms
)

# Load Mask2Former with a 2-label head (background and tissue).
model = Mask2FormerForUniversalSegmentation.from_pretrained(
    CHECKPOINT,
    num_labels=2,                     # Binary segmentation (background and tissue)
    ignore_mismatched_sizes=True      # re-initialize checkpoint weights whose shapes do not match
)

# Parameters to fine-tune; everything else stays frozen.
# The training loss is computed on `outputs.masks_queries_logits`. In the Hugging Face
# implementation these come from the decoder's `mask_predictor`; `class_predictor` only
# produces `class_queries_logits`. Unfreezing `class_predictor` alone therefore leaves no
# gradient path from the loss to any trainable weight, and training changes nothing.
TRAINABLE_KEYWORDS = ("class_predictor", "mask_predictor")

for name, param in model.named_parameters():
    param.requires_grad = any(keyword in name for keyword in TRAINABLE_KEYWORDS)

trainable_layer_names = [name for name, param in model.named_parameters() if param.requires_grad]
assert any("mask_predictor" in name for name in trainable_layer_names), (
    "No parameter name contains 'mask_predictor'; inspect model.named_parameters() and update "
    "TRAINABLE_KEYWORDS so the layers producing masks_queries_logits are trainable."
)

# Display the trainable layers for confirmation
print("Trainable layers:")
for name in trainable_layer_names:
    print(f"{name} is trainable")



In [None]:
# `named_modules` is a method; call it to get (name, module) pairs. The names help decide
# which sub-modules to unfreeze for fine-tuning.
module_names = [name for name, _ in model.named_modules()]
len(module_names), module_names[:40]






In [None]:
import torch.nn as nn

class DiceLoss(nn.Module):
    """Soft Dice loss pooled over the whole batch.

    ``forward(inputs, targets, smooth=1)`` treats ``inputs`` as raw logits (a sigmoid is always
    applied) and ``targets`` as binary masks with the same number of elements. All elements in
    the batch form a single Dice score; the scalar ``1 - dice`` is returned.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        # Inputs are always treated as logits: the sigmoid is applied unconditionally.
        probs = torch.sigmoid(inputs)

        # reshape instead of view so non-contiguous tensors (slices, transposes) are accepted;
        # cast targets so integer/bool masks combine cleanly with float probabilities.
        probs = probs.reshape(-1)
        targets = targets.reshape(-1).to(probs.dtype)

        intersection = (probs * targets).sum()
        dice = (2. * intersection + smooth) / (probs.sum() + targets.sum() + smooth)
        return 1 - dice



In [None]:
import torch.nn as nn
# Combined loss: weighted sum of a per-sample soft Dice term and binary cross-entropy on logits.
class CombinedDiceBCELoss(nn.Module):
    """Weighted sum of a per-sample soft Dice loss and BCE-with-logits.

    Args:
        dice_weight: Weight of the Dice term.
        bce_weight: Weight of the BCE term.
        smooth: Constant added to the Dice numerator and denominator to avoid 0/0.

    ``forward(logits, targets)`` takes raw logits and binary targets of the same shape
    ``(B, ...)``. Dice is computed per sample over every non-batch dimension, then averaged.
    """

    def __init__(self, dice_weight=0.5, bce_weight=0.5, smooth=1e-6):
        super().__init__()
        self.dice_weight = dice_weight
        self.bce_weight = bce_weight
        self.smooth = smooth
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, logits, targets):
        # BCEWithLogitsLoss needs floating-point targets; integer/bool masks are cast here.
        targets = targets.to(logits.dtype)

        # Reduce over all non-batch dims so (B, H, W) and (B, 1, H, W) inputs give the same
        # per-sample Dice. Summing only dims 1 and 2 would mix C with H and leave W unreduced.
        sample_dims = tuple(range(1, logits.dim()))
        probs = torch.sigmoid(logits)
        intersection = (probs * targets).sum(dim=sample_dims)
        denominator = probs.sum(dim=sample_dims) + targets.sum(dim=sample_dims)
        dice_loss = (1 - (2. * intersection + self.smooth) / (denominator + self.smooth)).mean()

        bce_loss = self.bce(logits, targets)

        return self.dice_weight * dice_loss + self.bce_weight * bce_loss

In [None]:
# The former "temporary fix" re-wrapped `tissue_logits` with
# `torch.autograd.Variable(..., requires_grad=True)`. That builds a new leaf tensor, so gradients
# stop there instead of reaching the model. It also raised NameError on a fresh kernel, because
# `tissue_logits` does not exist at this point. Instead, check which parameters will be trained:
trainable_parameter_names = [name for name, param in model.named_parameters() if param.requires_grad]
print(f"{len(trainable_parameter_names)} trainable parameter tensors")
trainable_parameter_names[:20]

In [None]:
cuda_available = torch.cuda.is_available()
print(cuda_available)  # True when PyTorch can see a CUDA device

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from PIL import Image
import numpy as np

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)  # confirm which device training will run on

# AdamW over the parameters that are trainable right now, with a small fine-tuning learning rate.
# NOTE: parameters unfrozen after this line are not registered with this optimizer.
optimizer = optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-4)

# Dice loss on the selected mask logits (DiceLoss is defined in an earlier cell).
criterion = DiceLoss()

# Helper functions
def resize_mask(mask, target_shape=(512, 512)):
    """Resize a mask to ``target_shape`` using nearest-neighbour interpolation.

    Args:
        mask: Array-like mask. Singleton dimensions are squeezed and values are cast to uint8,
            so a 0/1 float mask stays 0/1.
        target_shape: Desired output array shape ``(height, width)``, e.g. a tensor's
            ``shape[-2:]``. PIL's ``resize`` takes ``(width, height)``, so the pair is swapped.

    Returns:
        A uint8 NumPy array whose shape equals ``target_shape``.
    """
    height, width = (int(dim) for dim in target_shape)
    mask_image = Image.fromarray(np.squeeze(mask).astype(np.uint8)).convert("L")
    resized_mask = mask_image.resize((width, height), Image.NEAREST)
    return np.array(resized_mask)

def invert_mask(mask):
    """Flip a binary mask: zeros become 1 and every non-zero value becomes 0 (uint8 result, same shape)."""
    is_background = mask == 0
    inverted = np.where(is_background, np.uint8(1), np.uint8(0))
    return inverted.astype(np.uint8)

# Metric functions
def iou_score(pred, target):
    """Intersection-over-union of two masks, counting any positive value as foreground.

    Returns 1.0 when both masks are empty (zero union).
    """
    pred_fg = (pred > 0).astype(np.uint8)
    target_fg = (target > 0).astype(np.uint8)
    overlap = np.logical_and(pred_fg, target_fg)
    union_size = np.sum(np.logical_or(pred_fg, target_fg))
    if union_size == 0:
        return 1.0
    return np.sum(overlap) / union_size

def dice_score(pred, target):
    """Dice coefficient of two masks (positive values are foreground); 1.0 when both are empty."""
    pred_fg = (pred > 0).astype(np.uint8)
    target_fg = (target > 0).astype(np.uint8)
    overlap = np.sum(pred_fg * target_fg)
    total_area = np.sum(pred_fg) + np.sum(target_fg)
    if not total_area > 0:
        return 1.0
    return (2. * overlap) / total_area

# Keep the class predictor trainable. This is redundant when it was unfrozen before the
# optimizer was built. Parameters unfrozen only here would get gradients but would NOT be
# updated, because `optimizer` captured its parameter list when it was created.
for name, param in model.named_parameters():
    param.requires_grad = param.requires_grad or ('class_predictor' in name)

# Move the model to the training device.
model.to(device)



def train(model, train_loader, criterion, optimizer, num_epochs=5):
    """Fine-tune ``model`` for ``num_epochs`` and print per-epoch loss, IoU and Dice.

    The prediction is mask query 1 of ``outputs.masks_queries_logits``. Targets are the loader
    masks resized to the logits' spatial size and inverted with ``invert_mask``. Reported metrics
    are per-batch means averaged over the epoch.

    Raises:
        RuntimeError: If the selected mask logits do not require gradients, i.e. no trainable
            parameter feeds ``masks_queries_logits`` and ``optimizer.step()`` could not change
            the model.
    """
    model.train()
    num_batches = len(train_loader)
    if num_batches == 0:
        print("Training loader is empty; nothing to train.")
        return

    for epoch in range(num_epochs):
        running_loss = 0.0
        total_iou, total_dice = 0.0, 0.0
        for images, masks in tqdm(train_loader, desc=f"Epoch [{epoch+1}/{num_epochs}] Training"):
            # The image processor converts its inputs to NumPy, so pass CPU tensors directly
            # instead of copying them to `device` and back.
            inputs = image_processor(images=[img.cpu() for img in images], return_tensors="pt").to(device)
            outputs = model(**inputs)
            tissue_logits = outputs.masks_queries_logits[:, 1]

            # Do not force `requires_grad_()` here. On an output that does not already require
            # grad it creates a fresh leaf: backward() stops at that leaf and no weight changes.
            if not tissue_logits.requires_grad:
                raise RuntimeError(
                    "masks_queries_logits does not require grad: no trainable parameter feeds the "
                    "mask logits, so training would not change the model. Unfreeze the layers that "
                    "produce them (e.g. parameters whose name contains 'mask_predictor') before "
                    "creating the optimizer."
                )

            target_shape = tissue_logits.shape[-2:]  # (height, width) of the logits

            # Resize masks to the logits' resolution and invert them (training convention).
            masks_resized = torch.stack([
                torch.tensor(invert_mask(resize_mask(mask.cpu().numpy(), target_shape)),
                             dtype=torch.float32, device=device)
                for mask in masks
            ])

            loss = criterion(tissue_logits, masks_resized)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            # Metrics need no autograd graph.
            with torch.no_grad():
                pred_mask = (torch.sigmoid(tissue_logits.detach()) > 0.5).float()
                pred_area = pred_mask.sum((1, 2))
                target_area = masks_resized.sum((1, 2))
                intersection = (pred_mask * masks_resized).sum((1, 2))
                union = pred_area + target_area - intersection
                total_iou += (intersection / (union + 1e-6)).mean().item()
                total_dice += (2 * intersection / (pred_area + target_area + 1e-6)).mean().item()

        avg_loss = running_loss / num_batches
        avg_iou = total_iou / num_batches
        avg_dice = total_dice / num_batches
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Avg IoU: {avg_iou:.4f}, Avg Dice: {avg_dice:.4f}")

# Validation loop: reports loss, IoU and Dice on the validation loader without updating the model
def validate(model, val_loader, criterion):
    """Evaluate ``model`` on ``val_loader`` without updating it.

    Uses the same prediction and targets as ``train``: mask query 1 compared against the
    inverted loader masks, resized to the logits' resolution. Prints the mean loss, IoU and
    Dice over batches.

    Returns:
        ``(avg_val_loss, avg_iou, avg_dice)``, or ``None`` if the loader is empty.
    """
    model.eval()
    num_batches = len(val_loader)
    if num_batches == 0:
        print("Validation loader is empty; nothing to evaluate.")
        return None

    val_loss = 0.0
    total_iou, total_dice = 0.0, 0.0
    with torch.no_grad():
        for images, masks in tqdm(val_loader, desc="Validation"):
            # The image processor converts its inputs to NumPy, so pass CPU tensors directly.
            inputs = image_processor(images=[img.cpu() for img in images], return_tensors="pt").to(device)
            outputs = model(**inputs)
            tissue_logits = outputs.masks_queries_logits[:, 1]

            target_shape = tissue_logits.shape[-2:]  # (height, width) of the logits

            # Resize masks to the logits' resolution and invert them (training convention).
            masks_resized = torch.stack([
                torch.tensor(invert_mask(resize_mask(mask.cpu().numpy(), target_shape)),
                             dtype=torch.float32, device=device)
                for mask in masks
            ])

            val_loss += criterion(tissue_logits, masks_resized).item()

            pred_mask = (torch.sigmoid(tissue_logits) > 0.5).float()
            pred_area = pred_mask.sum((1, 2))
            target_area = masks_resized.sum((1, 2))
            intersection = (pred_mask * masks_resized).sum((1, 2))
            union = pred_area + target_area - intersection
            total_iou += (intersection / (union + 1e-6)).mean().item()
            total_dice += (2 * intersection / (pred_area + target_area + 1e-6)).mean().item()

    avg_val_loss = val_loss / num_batches
    avg_iou = total_iou / num_batches
    avg_dice = total_dice / num_batches
    print(f"Validation Loss: {avg_val_loss:.4f}, Avg IoU: {avg_iou:.4f}, Avg Dice: {avg_dice:.4f}")
    return avg_val_loss, avg_iou, avg_dice

# Fine-tune, then report metrics on the held-out validation split.
NUM_EPOCHS = 10
train(model, train_loader, criterion, optimizer, num_epochs=NUM_EPOCHS)
validate(model, val_loader, criterion)


In [None]:
OUTPUT_DIR = 'fine-tuned-mask2former'  # weights and preprocessing config are saved side by side
model.save_pretrained(OUTPUT_DIR)
image_processor.save_pretrained(OUTPUT_DIR)

In [None]:
# Load the fine-tuned model and its image processor back from disk.
# NOTE(review): model_test is not moved to `device`; do so before running inference with it.
FINETUNED_DIR = 'fine-tuned-mask2former'
model_test = Mask2FormerForUniversalSegmentation.from_pretrained(FINETUNED_DIR)
image_processor_test = Mask2FormerImageProcessor.from_pretrained(FINETUNED_DIR)


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import torch

# Device configuration: use the GPU whenever PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Resize and normalize mask
def resize_and_normalize_mask(mask, target_shape=(512, 512)):
    """Resize a mask to ``target_shape`` (nearest neighbour) and divide its values by 255.

    Args:
        mask: Array-like mask; singleton dimensions are squeezed and values cast to uint8.
        target_shape: Output array shape ``(height, width)``. It is swapped to PIL's
            ``(width, height)`` order for ``resize``.

    Returns:
        A float array of shape ``target_shape``; a 0/255 mask becomes 0.0/1.0.
    """
    height, width = (int(dim) for dim in target_shape)
    mask_image = Image.fromarray(np.squeeze(mask).astype(np.uint8))
    resized_mask = mask_image.resize((width, height), Image.NEAREST)
    return np.array(resized_mask) / 255

# Invert mask (0 becomes 1, and 1 becomes 0)
def invert_mask(mask):
    """Squeeze singleton dims, cast to uint8, then flip: 0 becomes 1, anything else becomes 0."""
    as_uint8 = np.squeeze(mask).astype(np.uint8)
    return (as_uint8 == 0).astype(np.uint8)

# Calculate IoU
def calculate_iou(pred, target):
    """IoU of two binary masks (non-zero = foreground); 1.0 when the union is empty."""
    overlap_count = np.sum(np.logical_and(pred, target))
    union_count = np.sum(np.logical_or(pred, target))
    if union_count == 0:
        return 1.0
    return overlap_count / union_count

# Calculate Dice score
def calculate_dice(pred, target):
    """Dice coefficient of two 0/1 masks; 1.0 when both masks are empty."""
    overlap = np.sum(pred * target)
    total_area = np.sum(pred) + np.sum(target)
    if not total_area > 0:
        return 1.0
    return (2. * overlap) / total_area

# Run inference in batches
def run_inference(model, image_processor, dataloader, device, threshold=0.5):
    """Run ``model`` over ``dataloader`` and collect predictions, targets and images.

    Args:
        model: Mask2Former model; the binary prediction is taken from mask query 1.
        image_processor: Processor used to build the model inputs.
        dataloader: Yields ``(images, masks)`` batches.
        device: Device the model lives on.
        threshold: Cutoff applied to ``sigmoid(logits)``; defaults to the original 0.5.

    Returns:
        ``(predicted_masks, ground_truth_masks, images_list)``: per-sample uint8 predictions at
        the logits' resolution, the loader's masks as NumPy arrays, and HWC images.
    """
    model.eval()
    predicted_masks, ground_truth_masks, images_list = [], [], []

    with torch.no_grad():
        for images, masks in tqdm(dataloader, desc="Inference"):
            # The processor converts inputs to NumPy anyway, so keep images on the CPU rather
            # than copying them to `device` and back.
            images_cpu = torch.stack([img.cpu() for img in images])
            inputs = image_processor(images=list(images_cpu), return_tensors="pt").to(device)
            outputs = model(**inputs)
            tissue_logits = outputs.masks_queries_logits[:, 1]  # mask query 1, as used in training

            predicted_binary_masks = (torch.sigmoid(tissue_logits) > threshold).cpu().numpy().astype(np.uint8)
            predicted_masks.extend(predicted_binary_masks)
            ground_truth_masks.extend(masks.cpu().numpy())
            images_list.extend(images_cpu.numpy().transpose(0, 2, 3, 1))  # NCHW -> per-sample HWC

    return predicted_masks, ground_truth_masks, images_list

# Evaluate segmentation with IoU and Dice calculations
def evaluate_segmentation(predicted_masks, ground_truth_masks, target_shape=(512, 512)):
    """Compute per-sample IoU and Dice for the ground-truth foreground and print their means.

    The training loop fits mask query 1 to ``invert_mask(ground_truth)``. So each prediction
    is inverted back and compared with the binarized ground truth as loaded (foreground = 1).
    The previous code inverted both sides, which scored the predicted foreground against the
    ground-truth background.

    Args:
        predicted_masks: Per-sample binary predictions (singleton dims allowed).
        ground_truth_masks: Per-sample ground-truth masks (values > 0 are foreground).
        target_shape: Common ``(height, width)`` both masks are resized to (nearest neighbour).

    Returns:
        ``(dice_scores, iou_scores)``: lists with one entry per sample.
    """
    height, width = target_shape
    pil_size = (width, height)  # PIL's resize takes (width, height)
    iou_scores, dice_scores = [], []

    for pred, target in zip(predicted_masks, ground_truth_masks):
        pred_foreground = invert_mask(pred)
        target_foreground = (np.squeeze(target) > 0).astype(np.uint8)

        pred_resized = np.array(Image.fromarray(pred_foreground).resize(pil_size, Image.NEAREST))
        target_resized = np.array(Image.fromarray(target_foreground).resize(pil_size, Image.NEAREST))

        iou_scores.append(calculate_iou(pred_resized, target_resized))
        dice_scores.append(calculate_dice(pred_resized, target_resized))

    if not iou_scores:
        print("No samples to evaluate.")
        return dice_scores, iou_scores

    print(f"Average IoU: {np.mean(iou_scores):.4f}")
    print(f"Average Dice Coefficient: {np.mean(dice_scores):.4f}")
    return dice_scores, iou_scores

# Visualization to verify alignment
def verify_alignment(images, predicted_masks, ground_truth_masks, num_samples=5, target_shape=(512, 512)):
    """Plot each image next to its predicted and ground-truth foreground masks.

    Predictions come from a model trained on ``invert_mask(ground_truth)`` targets, so they are
    inverted back before display. The ground truth is shown as loaded (binarized), so both mask
    panels use the same foreground convention.

    Args:
        images: Sequence of HWC images.
        predicted_masks: Per-sample binary predictions (singleton dims allowed).
        ground_truth_masks: Per-sample ground-truth masks (values > 0 are foreground).
        num_samples: Maximum number of samples to plot.
        target_shape: ``(height, width)`` both masks are resized to for display.
    """
    height, width = target_shape
    pil_size = (width, height)  # PIL's resize takes (width, height)

    for i in range(min(num_samples, len(predicted_masks))):
        pred_foreground = invert_mask(predicted_masks[i])
        gt_foreground = (np.squeeze(ground_truth_masks[i]) > 0).astype(np.uint8)
        pred_resized = np.array(Image.fromarray(pred_foreground).resize(pil_size, Image.NEAREST))
        gt_resized = np.array(Image.fromarray(gt_foreground).resize(pil_size, Image.NEAREST))

        fig, axs = plt.subplots(1, 3, figsize=(15, 5))

        axs[0].imshow(images[i])
        axs[0].set_title(f"Image {i + 1}")
        axs[0].axis('off')

        axs[1].imshow(pred_resized, cmap='gray')
        axs[1].set_title(f"Predicted Mask {i + 1}")
        axs[1].axis('off')

        axs[2].imshow(gt_resized, cmap='gray')
        axs[2].set_title(f"Ground Truth Mask {i + 1}")
        axs[2].axis('off')

        plt.show()
        plt.close(fig)

# Example usage
# Run inference on the held-out validation split. Scoring train_loader would measure images the
# model was fitted to, and with shuffle=True the sample order would change on every run.
predicted_masks, ground_truth_masks, images_list = run_inference(model, image_processor, val_loader, device)

# Evaluate segmentation
dice_scores, iou_scores = evaluate_segmentation(predicted_masks, ground_truth_masks)

# Verify alignment visually (first 50 validation samples)
verify_alignment(images_list, predicted_masks, ground_truth_masks, num_samples=50)



