<a href="https://colab.research.google.com/github/alima-parveen/diabetic_retinopathy/blob/main/Modular_Vision_Based_Multi_task_Learning_for_Eye_Disease_Diagnosis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import random
from tqdm import tqdm
import csv
import matplotlib.pyplot as plt
import seaborn as sns
from google.colab import drive

In [None]:
# Mount Google Drive at /content/drive so the dataset and results paths under
# /content/drive/MyDrive used throughout this notebook are accessible (Colab only).
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Set random seed for reproducibility: Python, NumPy and PyTorch (CPU, plus the current
# CUDA device when a GPU is present) all start from seed 42.
# NOTE(review): seeding alone may not make GPU runs bit-identical (cuDNN kernel selection
# can be nondeterministic); torch.backends.cudnn.deterministic / benchmark would need to be
# set if exact repeatability is required — confirm.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

In [None]:
# Device configuration: use CUDA when available, otherwise fall back to the CPU.
# The training and visualization functions in this notebook read this module-level `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## IDRiD Grading Dataset Class for the Single-Task Baseline

In [None]:
# Dataset class for DR grading
class IDRiDGradingDataset(Dataset):
    """IDRiD disease-grading dataset for single-task DR grade classification.

    Each item is ``(image, grade, image_name)``:

    * ``image`` - the RGB fundus image, passed through ``transform`` when given;
    * ``grade`` - a ``torch.long`` scalar from the ``"Retinopathy grade"`` column of
      ``grading_csv`` (0, with a printed warning, when the image has no row);
    * ``image_name`` - the file name, returned for visualisation.

    Args:
        image_dir: Directory holding the fundus images (``<prefix>_<number>.jpg``).
        grading_csv: CSV with ``"Image name"`` and ``"Retinopathy grade"`` columns.
        transform: Optional callable applied to the PIL image.
        mode: ``"train"`` keeps images numbered up to ``TRAIN_MAX_ID``; any other
            value keeps every ``.jpg`` image in ``image_dir``.

    Raises:
        ValueError: If no image survives the filtering.
    """

    # Highest image number of the IDRiD grading training split (IDRiD_001 ... IDRiD_413).
    TRAIN_MAX_ID = 413

    def __init__(self, image_dir, grading_csv, transform=None, mode="train"):
        self.image_dir = image_dir
        self.transform = transform
        self.mode = mode
        self.grading_df = pd.read_csv(grading_csv)

        # Build the image-id -> grade lookup once instead of filtering the DataFrame on
        # every __getitem__ call; if a name appears twice, its first row wins.
        first_rows = self.grading_df.drop_duplicates(subset="Image name", keep="first")
        self.grade_by_id = dict(zip(first_rows["Image name"], first_rows["Retinopathy grade"]))

        # Only JPEG files are considered, so stray files (CSV, checkpoints, ...) are ignored.
        image_names = sorted(f for f in os.listdir(image_dir) if f.lower().endswith(".jpg"))

        if mode == "train":
            image_names = [name for name in image_names if self._in_train_split(name)]
        # Any other mode keeps every image: the testing folder numbers its images
        # independently of the training split (the logged run found 103 of them), so a
        # ">= 414" filter would discard all of them.
        self.image_names = image_names

        print(f"Mode: {mode}, Found {len(self.image_names)} images in {image_dir}")
        if len(self.image_names) == 0:
            raise ValueError(f"No images found in {image_dir} for mode {mode}")

    @classmethod
    def _in_train_split(cls, name):
        """Return True if ``name`` is ``<prefix>_<number>.<ext>`` with number <= TRAIN_MAX_ID."""
        try:
            number = int(name.split("_")[1].split(".")[0])
        except (IndexError, ValueError):
            # Names that do not follow the IDRiD pattern are not part of the split.
            return False
        return number <= cls.TRAIN_MAX_ID

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        img_name = self.image_names[idx]
        img_path = os.path.join(self.image_dir, img_name)
        # The context manager closes the file handle once the RGB copy has been made.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")

        # Load grading label (CSV ids have no file extension, e.g. "IDRiD_001").
        img_id = img_name.split(".")[0]
        grade = self.grade_by_id.get(img_id)
        if grade is None:
            print(f"Warning: No grading label for {img_id}, using default 0")
            grade = 0
        grading_label = torch.tensor(grade, dtype=torch.long)

        if self.transform:
            img = self.transform(img)

        return img, grading_label, img_name  # Return img_name for visualization

In [None]:
# Classification model
class GradingModel(nn.Module):
    """DR grade classifier built on ImageNet-pretrained VGG16 features.

    Pipeline: VGG16 convolutional features -> global average pooling (512-d vector)
    -> Linear(512, 256) -> ReLU -> Dropout(0.5) -> Linear(256, num_classes) logits.

    Args:
        num_classes: Number of output logits (5 DR grades by default).
    """

    def __init__(self, num_classes=5):
        super().__init__()
        # ``weights=`` replaces the deprecated ``pretrained=True`` flag; IMAGENET1K_V1 is the
        # same checkpoint (vgg16-397923af.pth) that ``pretrained=True`` downloads.
        vgg16 = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
        self.encoder = vgg16.features
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        """Return ``(batch, num_classes)`` logits for a batch of images."""
        features = self.pool(self.encoder(x))
        return self.classifier(torch.flatten(features, 1))

In [None]:
# Training and evaluation
def _train_grading_epoch(model, loader, criterion, optimizer, desc):
    """Run one optimisation pass over ``loader`` and return the mean batch loss."""
    model.train()
    total_loss = 0.0
    loop = tqdm(loader, desc=desc, leave=False)
    for images, grades, _ in loop:
        images, grades = images.to(device), grades.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), grades)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        loop.set_postfix(loss=loss.item())
    return total_loss / len(loader)


def _validate_grading_epoch(model, loader, criterion, desc):
    """Evaluate ``model`` on ``loader``; return ``(mean loss, predictions, labels)``."""
    model.eval()
    total_loss = 0.0
    preds, labels = [], []
    loop = tqdm(loader, desc=desc, leave=False)
    with torch.no_grad():
        for images, grades, _ in loop:
            images, grades = images.to(device), grades.to(device)
            outputs = model(images)
            loss = criterion(outputs, grades)
            total_loss += loss.item()
            preds.extend(torch.argmax(outputs, dim=1).cpu().numpy())
            labels.extend(grades.cpu().numpy())
            loop.set_postfix(loss=loss.item())
    return total_loss / len(loader), preds, labels


def _save_grading_report(val_labels, val_preds, final_accuracy, log_file, results_dir):
    """Write the classification report and the confusion matrix (CSV + heatmap PNG)."""
    class_names = ["No DR", "Mild", "Moderate", "Severe", "Proliferative"]
    # A fixed label set keeps the report and the 5x5 matrix aligned with ``class_names``
    # even when some grade never occurs in this validation epoch.
    class_ids = list(range(len(class_names)))
    report = classification_report(val_labels, val_preds, labels=class_ids,
                                   target_names=class_names, digits=4, zero_division=0)
    # Rows = true grade, columns = predicted grade.
    conf_matrix = confusion_matrix(val_labels, val_preds, labels=class_ids)

    with open(log_file, "a") as f:
        f.write(f"\nFinal Validation Accuracy: {final_accuracy:.4f}\n")
        f.write("\nClassification Report:\n")
        f.write(report)

    with open(os.path.join(results_dir, "classification_report.txt"), "w") as f:
        f.write(report)

    np.savetxt(os.path.join(results_dir, "confusion_matrix.csv"), conf_matrix, delimiter=",", fmt="%d")

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(conf_matrix, annot=True, fmt="d", cmap="Blues",
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_title("Confusion Matrix")
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    fig.savefig(os.path.join(results_dir, "confusion_matrix.png"))
    plt.close(fig)


def train_grading_model(model, train_loader, val_loader, num_epochs=10, results_dir="results"):
    """Train a DR grading classifier with early stopping and save logs and reports.

    Cross-entropy loss is optimised with Adam (lr 1e-5). ``ReduceLROnPlateau`` (patience 3,
    factor 0.5) lowers the learning rate when validation loss plateaus. Training stops after
    5 consecutive epochs without validation-loss improvement. The weights with the lowest
    validation loss are saved to ``best_grading_model.pth`` in ``results_dir``.

    Args:
        model: Classifier producing logits over the 5 DR grades.
        train_loader: DataLoader yielding ``(images, grades, names)`` for training.
        val_loader: DataLoader yielding ``(images, grades, names)`` for validation.
        num_epochs: Maximum number of epochs (must be >= 1).
        results_dir: Output directory for ``training_log.txt``, ``metrics.csv``,
            ``classification_report.txt`` and ``confusion_matrix.{csv,png}``.

    Returns:
        ``(final_accuracy, val_preds, val_labels)`` of the LAST epoch that ran. This is
        not necessarily the epoch whose weights were saved as the best checkpoint.

    Raises:
        ValueError: If ``num_epochs < 1`` or either DataLoader is empty.
    """
    if num_epochs < 1:
        raise ValueError("num_epochs must be at least 1.")
    if len(train_loader) == 0:
        raise ValueError("Training DataLoader is empty. Check dataset configuration.")
    if len(val_loader) == 0:
        raise ValueError("Validation DataLoader is empty. Check dataset configuration.")

    os.makedirs(results_dir, exist_ok=True)
    log_file = os.path.join(results_dir, "training_log.txt")
    metrics_file = os.path.join(results_dir, "metrics.csv")
    checkpoint_file = os.path.join(results_dir, "best_grading_model.pth")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, factor=0.5)

    best_val_loss = float("inf")
    early_stop_count = 0

    # Start a fresh metrics CSV; per-epoch rows are appended below.
    with open(metrics_file, "w", newline="") as f:
        csv.writer(f).writerow(["Epoch", "Train Loss", "Val Loss", "Accuracy"])

    for epoch in range(num_epochs):
        prefix = f"Epoch {epoch+1}/{num_epochs}"
        train_loss = _train_grading_epoch(model, train_loader, criterion, optimizer, f"{prefix} [Train]")
        val_loss, val_preds, val_labels = _validate_grading_epoch(model, val_loader, criterion, f"{prefix} [Val]")
        accuracy = accuracy_score(val_labels, val_preds)

        # Log to console and file
        log_message = (f"{prefix}, Train Loss: {train_loss:.4f}, "
                       f"Val Loss: {val_loss:.4f}, Accuracy: {accuracy:.4f}")
        print(log_message)
        with open(log_file, "a") as f:
            f.write(log_message + "\n")
        with open(metrics_file, "a", newline="") as f:
            csv.writer(f).writerow([epoch + 1, train_loss, val_loss, accuracy])

        # Early stopping on validation loss; checkpoint every improvement.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            early_stop_count = 0
            torch.save(model.state_dict(), checkpoint_file)
        else:
            early_stop_count += 1
            if early_stop_count >= 5:
                print("Early stopping triggered")
                break

        scheduler.step(val_loss)

    # Final evaluation uses the predictions of the last epoch that ran.
    final_accuracy = accuracy_score(val_labels, val_preds)
    _save_grading_report(val_labels, val_preds, final_accuracy, log_file, results_dir)

    return final_accuracy, val_preds, val_labels

In [None]:
# Visualize sample results
def visualize_samples(model, val_loader, results_dir, num_samples=2):
    """Save up to ``num_samples`` validation images annotated with predicted and true grade.

    Figures are written to ``results_dir`` as ``sample_1.png``, ``sample_2.png``, ... in the
    loader's iteration order. Images are de-normalised with the ImageNet mean/std used by
    the preprocessing transform before plotting.
    """
    labels = ["No DR", "Mild", "Moderate", "Severe", "Proliferative"]
    std = np.array([0.229, 0.224, 0.225])
    mean = np.array([0.485, 0.456, 0.406])

    model.eval()
    saved = 0
    os.makedirs(results_dir, exist_ok=True)

    with torch.no_grad():
        for batch, targets, names in val_loader:
            batch = batch.to(device)
            predicted = torch.argmax(model(batch), dim=1).cpu().numpy()

            for image, pred_idx, true_idx, name in zip(batch, predicted, targets, names):
                if saved >= num_samples:
                    break
                # Undo the normalisation and move channels last for matplotlib.
                pixels = np.clip(image.cpu().permute(1, 2, 0).numpy() * std + mean, 0, 1)

                plt.figure(figsize=(6, 6))
                plt.imshow(pixels)
                plt.title(f"Image: {name}\nPred: {labels[pred_idx]}\nTrue: {labels[true_idx]}")
                plt.axis("off")
                plt.savefig(os.path.join(results_dir, f"sample_{saved+1}.png"))
                plt.close()

                saved += 1

            if saved >= num_samples:
                break

In [None]:
# Dataset paths for the IDRiD "B. Disease Grading" subset on Google Drive.
# Training and testing images live in separate folders, each with its own label CSV.
grading_train_img_dir = "/content/drive/MyDrive/IDRiD Dataset/B. Disease Grading/1. Original Images/a. Training Set"
grading_test_img_dir = "/content/drive/MyDrive/IDRiD Dataset/B. Disease Grading/1. Original Images/b. Testing Set"
grading_train_csv = "/content/drive/MyDrive/IDRiD Dataset/B. Disease Grading/2. Groundtruths/a. IDRiD_Disease Grading_Training Labels.csv"
grading_test_csv = "/content/drive/MyDrive/IDRiD Dataset/B. Disease Grading/2. Groundtruths/b. IDRiD_Disease Grading_Testing Labels.csv"
# Output folder for the grading baseline's logs, metrics, checkpoint and plots.
results_dir = "/content/drive/MyDrive/IDRiD Dataset/Grading"

In [None]:
# Transforms: resize to 256x256 and normalize with the ImageNet mean/std expected by the
# pretrained VGG16 encoder. There is no augmentation; the same transform is used for both
# the training and validation grading datasets.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [None]:
# Datasets: "train" mode keeps the training-folder images numbered <= 413. The testing
# folder, loaded in "test" mode, serves as the validation set.
train_dataset = IDRiDGradingDataset(
    image_dir=grading_train_img_dir,
    grading_csv=grading_train_csv,
    transform=transform,
    mode="train"
)
val_dataset = IDRiDGradingDataset(
    image_dir=grading_test_img_dir,
    grading_csv=grading_test_csv,
    transform=transform,
    mode="test"
)

Mode: train, Found 413 images in /content/drive/MyDrive/IDRiD Dataset/B. Disease Grading/1. Original Images/a. Training Set
Mode: test, Found 103 images in /content/drive/MyDrive/IDRiD Dataset/B. Disease Grading/1. Original Images/b. Testing Set


In [None]:
# DataLoaders: batch size 4 with two worker processes. Only the training loader shuffles,
# so validation batches (and the saved sample visualizations) follow the dataset order.
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=2)

In [None]:
# Debug dataset sizes: sanity-check that the split filters found the expected images.
print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")

Train dataset size: 413
Validation dataset size: 103


In [None]:
# Train model: build the VGG16-based classifier on `device` and train for up to 10 epochs.
# The lowest-validation-loss weights are saved to results_dir/best_grading_model.pth.
model = GradingModel().to(device)
final_accuracy, val_preds, val_labels = train_grading_model(model, train_loader, val_loader, num_epochs=10, results_dir=results_dir)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:03<00:00, 162MB/s]


Epoch 1/10, Train Loss: 1.4239, Val Loss: 1.2309, Accuracy: 0.4660




Epoch 2/10, Train Loss: 1.1867, Val Loss: 1.1510, Accuracy: 0.5243




Epoch 3/10, Train Loss: 1.0701, Val Loss: 1.2190, Accuracy: 0.4175




Epoch 4/10, Train Loss: 1.0145, Val Loss: 1.1845, Accuracy: 0.5534




Epoch 5/10, Train Loss: 0.9272, Val Loss: 1.0886, Accuracy: 0.5922




Epoch 6/10, Train Loss: 0.8464, Val Loss: 1.0753, Accuracy: 0.6117




Epoch 7/10, Train Loss: 0.7999, Val Loss: 1.1163, Accuracy: 0.5728




Epoch 8/10, Train Loss: 0.7370, Val Loss: 1.0755, Accuracy: 0.5631




Epoch 9/10, Train Loss: 0.6600, Val Loss: 1.1007, Accuracy: 0.5534


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 10/10, Train Loss: 0.6111, Val Loss: 1.2101, Accuracy: 0.6117


In [None]:
# Visualize samples: save two validation images annotated with predicted vs. true grade.
# NOTE: `model` still holds the weights of the last epoch trained (the best checkpoint is not
# reloaded), and final_accuracy is that last epoch's validation accuracy.
visualize_samples(model, val_loader, results_dir, num_samples=2)

print(f"Final Validation Accuracy: {final_accuracy:.4f}")

Final Validation Accuracy: 0.6117


## IDRiD Lesion Segmentation Dataset Class for the Single-Task Baseline

In [None]:
from torchvision.models import VGG16_Weights
import cv2
import logging
from pathlib import Path

In [None]:
# Set up logging for the segmentation baseline.
# force=True removes any handlers already attached to the root logger (e.g. by the notebook
# environment). Without it, basicConfig silently does nothing when handlers already exist,
# and no log file would be written.
logging.basicConfig(filename='/content/drive/MyDrive/IDRiD Dataset/segmentation_training.log', level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s', force=True)

In [None]:
# Dataset class for single-task binary segmentation
class IDRiDSegmentationDataset(Dataset):
    """IDRiD segmentation dataset yielding one combined binary mask per image.

    The ground-truth masks of all annotation folders in ``lesion_suffixes``
    (microaneurysms, haemorrhages, hard/soft exudates and the optic disc) are OR-ed
    into a single float32 mask of shape ``(mask_size, mask_size)``. Each item is
    ``(image, mask, valid_mask, image_name)``. ``valid_mask`` is True only if at least
    one non-empty mask file was loaded for the image.

    Args:
        image_dir: Directory of ``<prefix>_<number>.jpg`` fundus images.
        seg_dir: Directory with one sub-folder per key of ``lesion_suffixes``, each
            holding ``<image id>_<suffix>.tif`` masks.
        transform: Optional callable applied to the PIL image ONLY. The mask is built
            separately, so random geometric operations here would misalign image and mask.
        mode: ``"train"`` keeps images numbered <= ``TRAIN_MAX_ID``; any other value
            keeps images numbered above it.
        mask_size: Side length of the square mask, and of the placeholder image returned
            when an image file is missing. It should match the transform's output size.

    Raises:
        ValueError: If no image survives the filtering.
    """

    # IDRiD segmentation split: images 1-54 are training, 55-81 testing.
    TRAIN_MAX_ID = 54

    def __init__(self, image_dir, seg_dir, transform=None, mode="train", mask_size=256):
        self.image_dir = image_dir
        self.seg_dir = seg_dir
        self.transform = transform
        self.mode = mode
        self.mask_size = mask_size
        self.lesion_suffixes = {
            '1. Microaneurysms': 'MA',
            '2. Haemorrhages': 'HE',
            '3. Hard Exudates': 'EX',
            '4. Soft Exudates': 'SE',
            '5. Optic Disc': 'OD'
        }

        # Load image names; names that do not follow the numbered pattern are skipped.
        numbered = [
            (name, self._image_number(name))
            for name in sorted(f for f in os.listdir(image_dir) if f.lower().endswith('.jpg'))
        ]
        if mode == "train":
            self.image_names = [n for n, num in numbered if num is not None and num <= self.TRAIN_MAX_ID]
        else:
            self.image_names = [n for n, num in numbered if num is not None and num > self.TRAIN_MAX_ID]

        logging.info(f"Mode: {mode}, Found {len(self.image_names)} images in {image_dir}")
        print(f"Mode: {mode}, Found {len(self.image_names)} images in {image_dir}")
        if len(self.image_names) == 0:
            raise ValueError(f"No images found in {image_dir} for mode {mode}")

    @staticmethod
    def _image_number(name):
        """Return the number in ``<prefix>_<number>.<ext>`` (e.g. ``IDRiD_07.jpg`` -> 7), or None."""
        try:
            return int(name.split("_")[1].split(".")[0])
        except (IndexError, ValueError):
            return None

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        img_name = self.image_names[idx]
        img_path = os.path.join(self.image_dir, img_name)
        size = self.mask_size
        try:
            # The context manager closes the file handle once the RGB copy has been made.
            with Image.open(img_path) as raw:
                img = raw.convert("RGB")
            logging.debug("Loaded image: %s", img_path)
        except FileNotFoundError:
            logging.error(f"File not found: {img_path}")
            print(f"File not found: {img_path}")
            # Placeholder sample; valid_mask=False keeps it out of the loss.
            return torch.zeros(3, size, size), torch.zeros(size, size), False, img_name

        # Load and combine masks for all lesion types into a single binary mask
        img_id = img_name.split(".")[0]
        seg_mask = torch.zeros(size, size)
        valid_mask = False
        for lesion_dir, suffix in self.lesion_suffixes.items():
            seg_path = os.path.join(self.seg_dir, lesion_dir, f"{img_id}_{suffix}.tif")
            if not os.path.exists(seg_path):
                # A missing mask file contributes nothing and is only logged at DEBUG level.
                logging.debug("Mask not found for %s", seg_path)
                continue
            try:
                mask = cv2.imread(seg_path, cv2.IMREAD_GRAYSCALE)
                if mask is None:
                    logging.warning(f"Failed to load mask at {seg_path}, using zero mask")
                    continue
                # NOTE(review): default (bilinear) downsampling followed by ">0" can drop
                # tiny lesions such as microaneurysms; cv2.INTER_AREA may preserve them
                # better — confirm before changing.
                mask = cv2.resize(mask, (size, size))
                mask_binary = torch.from_numpy(mask > 0).float()
            except Exception as e:
                logging.warning(f"Error loading mask {seg_path}: {str(e)}")
                continue

            # Element-wise OR of 0/1 float masks.
            seg_mask = torch.maximum(seg_mask, mask_binary)
            n_pixels = int(mask_binary.sum())
            if n_pixels > 0:
                logging.debug("Loaded non-empty mask for %s in %s: %d pixels", lesion_dir, img_name, n_pixels)
                valid_mask = True
            else:
                logging.debug("Loaded empty mask for %s in %s", lesion_dir, img_name)

        if not valid_mask:
            logging.warning(f"No valid masks found for {img_name}")

        if self.transform:
            img = self.transform(img)

        return img, seg_mask, valid_mask, img_name

In [None]:
# Segmentation model for single binary mask
class SegmentationModel(nn.Module):
    """Single-channel lesion segmentation network.

    Encoder: ImageNet-pretrained VGG16 convolutional features (512 output channels).
    Decoder: five ``Conv3x3 -> ReLU -> 2x bilinear upsample`` stages, then a final
    ``Conv3x3`` down to one logit channel. There are no encoder-decoder skip connections.
    The five 2x upsamplings give a 32x spatial upscale, so a 256x256 input yields
    256x256 logits (apply a sigmoid for probabilities).
    """

    # Output channels of the five decoder stages.
    DECODER_WIDTHS = (512, 256, 128, 64, 32)

    def __init__(self):
        super().__init__()
        self.backbone = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features
        self.in_channels = 512  # VGG16 final feature map has 512 channels

        # Layer order (conv, relu, upsample per stage, then the final conv) keeps the
        # ModuleList indices decoder.0 ... decoder.15, so saved state_dicts still load.
        layers = []
        channels = self.in_channels
        for width in self.DECODER_WIDTHS:
            layers += [
                nn.Conv2d(channels, width, 3, padding=1),
                nn.ReLU(),
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            ]
            channels = width
        layers.append(nn.Conv2d(channels, 1, 3, padding=1))  # Single channel for binary segmentation
        self.decoder = nn.ModuleList(layers)

        # Kaiming-initialise ONLY the freshly created decoder convolutions. Iterating over
        # self.modules() would also overwrite the pretrained VGG16 backbone weights.
        for m in self.decoder.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return ``(batch, 1, H, W)`` segmentation logits."""
        x = self.backbone(x)
        for layer in self.decoder:
            x = layer(x)
        return x

In [None]:
# Dice score for binary segmentation
def dice_score(pred, target, threshold=0.1):
    """Return the Dice coefficient between thresholded predictions and a binary target.

    Args:
        pred: Raw logits; ``sigmoid(pred) > threshold`` gives the predicted mask.
        target: Ground-truth mask of the same shape with values in {0, 1}.
        threshold: Probability cut-off applied after the sigmoid.

    Returns:
        Dice as a Python float, pooled over every element of the inputs (a whole batch
        counts as one mask). The ``smooth`` term makes two empty masks score 1.0.
    """
    pred = (torch.sigmoid(pred) > threshold).float()
    smooth = 1e-5
    intersection = (pred * target).sum()
    pred_sum = pred.sum()
    target_sum = target.sum()
    dice = (2. * intersection + smooth) / (pred_sum + target_sum + smooth)
    # Lazy %-formatting: the tensors are only converted (forcing a device sync) when DEBUG
    # logging is actually enabled.
    logging.debug("Dice score: %.4f, Intersection: %.4f, Pred sum: %.4f, Target sum: %.4f",
                  dice, intersection, pred_sum, target_sum)
    return dice.item()

In [None]:
# Training and evaluation
def train_segmentation_model(model, train_loader, val_loader, num_epochs=20, results_dir="/content/drive/MyDrive/IDRiD Dataset/Segmentation"):
    """Train the binary segmentation model and keep the checkpoint with the best validation Dice.

    The loss is BCE-with-logits with ``pos_weight=10``. It is computed only on samples whose
    ``valid_mask`` flag is True; batches with no valid sample are skipped (no update). The
    learning rate follows a ``triangular2`` CyclicLR schedule, stepped once per optimiser step.

    Args:
        model: Segmentation network returning one logit channel per pixel.
        train_loader: DataLoader yielding ``(images, masks, valid_masks, names)``.
        val_loader: DataLoader with the same item layout, used for loss and Dice.
        num_epochs: Maximum number of epochs.
        results_dir: Directory where ``best_segmentation_model.pth`` is written.

    Returns:
        The best mean validation Dice observed; the matching weights are saved to
        ``best_segmentation_model.pth``. ``model`` itself keeps the weights of the last
        epoch that ran.

    Raises:
        ValueError: If either DataLoader is empty.
    """
    if len(train_loader) == 0:
        raise ValueError("Training DataLoader is empty. Check dataset configuration.")
    if len(val_loader) == 0:
        raise ValueError("Validation DataLoader is empty. Check dataset configuration.")

    os.makedirs(results_dir, exist_ok=True)
    model_file = os.path.join(results_dir, "best_segmentation_model.pth")

    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(10.0).to(device))
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    # Stepped once per batch, so each rising (and falling) half-cycle spans 3 epochs.
    scheduler = optim.lr_scheduler.CyclicLR(optimizer, base_lr=1e-4, max_lr=2e-3, step_size_up=len(train_loader)*3, mode='triangular2')

    best_dice = 0.0
    patience_counter = 0
    patience = 5

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        valid_mask_batches = 0
        train_loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]", leave=True)
        for images, masks, valid_masks, _ in train_loop:
            valid = valid_masks.to(device).bool()
            if not valid.any():
                # Nothing to supervise: skip the update instead of calling backward() on a plain 0.
                train_loop.set_postfix({'Seg Loss': '0.0', 'Valid Masks': valid_mask_batches})
                continue
            images, masks = images.to(device), masks.to(device)

            optimizer.zero_grad()
            outputs = model(images).squeeze(1)
            seg_loss = criterion(outputs[valid], masks[valid].float())
            seg_loss.backward()
            optimizer.step()
            scheduler.step()

            valid_mask_batches += 1
            train_loss += seg_loss.item()
            train_loop.set_postfix({'Seg Loss': f'{seg_loss.item():.4f}', 'Valid Masks': valid_mask_batches})

        avg_train_loss = train_loss / valid_mask_batches if valid_mask_batches > 0 else 0.0

        model.eval()
        val_loss = 0.0
        val_dice_scores = []
        valid_mask_batches = 0
        val_loop = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]", leave=False)
        with torch.no_grad():
            for images, masks, valid_masks, _ in val_loop:
                valid = valid_masks.to(device).bool()
                loss_text = '0.0'
                if valid.any():
                    images, masks = images.to(device), masks.to(device)
                    outputs = model(images).squeeze(1)[valid]
                    targets = masks[valid].float()
                    seg_loss = criterion(outputs, targets)
                    val_loss += seg_loss.item()
                    val_dice_scores.append(dice_score(outputs, targets))
                    valid_mask_batches += 1
                    loss_text = f'{seg_loss.item():.4f}'
                val_loop.set_postfix({'Seg Loss': loss_text, 'Valid Masks': valid_mask_batches})

        avg_val_loss = val_loss / valid_mask_batches if valid_mask_batches > 0 else 0.0
        avg_dice = np.mean(val_dice_scores) if val_dice_scores else 0.0

        log_message = f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Dice: {avg_dice:.4f}"
        logging.info(log_message)
        print(log_message)

        # Early stopping based on Dice score
        if avg_dice > best_dice:
            best_dice = avg_dice
            torch.save(model.state_dict(), model_file)
            logging.info("Saved best model based on Dice score")
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                logging.info(f"Early stopping triggered after {patience} epochs without improvement")
                print(f"Early stopping at epoch {epoch+1}")
                break
        # No per-epoch scheduler.step(): CyclicLR is already advanced once per batch above.

    return best_dice

In [None]:
# Visualization
def visualize_segmentation_results(model, val_loader, results_dir, num_samples=5, threshold=0.1):
    """Save side-by-side image / predicted mask / ground-truth mask figures.

    Up to ``num_samples`` figures are written to ``results_dir`` as
    ``segmentation_result_<image name>.png``. Samples whose ``valid_mask`` flag is False are
    skipped without counting toward ``num_samples``.

    Args:
        model: Segmentation network returning one logit channel per pixel.
        val_loader: DataLoader yielding ``(images, masks, valid_masks, names)``.
        results_dir: Output directory (created if needed).
        num_samples: Maximum number of figures to write.
        threshold: Probability cut-off for the predicted mask (0.1, the same default as
            ``dice_score``).
    """
    model.eval()
    samples_processed = 0
    os.makedirs(results_dir, exist_ok=True)
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    with torch.no_grad():
        for images, masks, valid_masks, img_names in val_loader:
            images = images.to(device)
            seg_preds = torch.sigmoid(model(images).squeeze(1)) > threshold

            # Scan the whole batch so that skipped (invalid) samples do not use up the budget.
            for i in range(images.size(0)):
                if samples_processed >= num_samples:
                    break
                if not valid_masks[i]:
                    continue
                # Undo the ImageNet normalisation for display.
                img = (images[i].cpu().permute(1, 2, 0).numpy() * std + mean).clip(0, 1)
                panels = [
                    (f"Image: {img_names[i]}", img, None),
                    ("Predicted Mask", seg_preds[i].cpu().numpy(), "gray"),
                    ("Ground Truth Mask", masks[i].cpu().numpy(), "gray"),
                ]

                fig = plt.figure(figsize=(15, 5))
                for position, (title, data, cmap) in enumerate(panels, start=1):
                    plt.subplot(1, 3, position)
                    plt.title(title)
                    plt.imshow(data, cmap=cmap)
                    plt.axis("off")
                fig.savefig(os.path.join(results_dir, f"segmentation_result_{img_names[i]}.png"))
                plt.close(fig)
                logging.info(f"Generated visualization for {img_names[i]}")
                samples_processed += 1

            if samples_processed >= num_samples:
                break

In [None]:
# Paths for the IDRiD "A. Segmentation" subset: fundus images and the per-lesion
# ground-truth folders, each split into training and testing sets.
seg_train_img_dir = "/content/drive/MyDrive/IDRiD Dataset/A. Segmentation/1. Original Images/a. Training Set"
seg_test_img_dir = "/content/drive/MyDrive/IDRiD Dataset/A. Segmentation/1. Original Images/b. Testing Set"
seg_train_mask_dir = "/content/drive/MyDrive/IDRiD Dataset/A. Segmentation/2. All Segmentation Groundtruths/a. Training Set"
seg_test_mask_dir = "/content/drive/MyDrive/IDRiD Dataset/A. Segmentation/2. All Segmentation Groundtruths/b. Testing Set"
# Output folder for the segmentation baseline (rebinds the `results_dir` used by the grading cells).
results_dir = "/content/drive/MyDrive/IDRiD Dataset/Segmentation"

In [None]:
# Training-time preprocessing for segmentation.
# Only photometric (colour) augmentation is used. IDRiDSegmentationDataset applies this
# transform to the image alone and builds the mask separately, so any random geometric op
# (flip / rotation / affine) would move the image without moving its mask and corrupt the
# supervision. To restore geometric augmentation, apply identical random parameters to image
# and mask together (e.g. torchvision.transforms.v2 with a tv_tensors.Mask) inside the dataset.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [None]:
# Validation preprocessing: resize + ImageNet normalization only (no random augmentation).
val_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Segmentation datasets: "train" keeps images numbered 1-54 and uses the augmenting
# `transform`; "test" keeps images numbered >= 55 and serves as the validation set.
train_dataset = IDRiDSegmentationDataset(
    image_dir=seg_train_img_dir,
    seg_dir=seg_train_mask_dir,
    transform=transform,
    mode="train"
)
val_dataset = IDRiDSegmentationDataset(
    image_dir=seg_test_img_dir,
    seg_dir=seg_test_mask_dir,
    transform=val_transform,
    mode="test"
)

Mode: train, Found 54 images in /content/drive/MyDrive/IDRiD Dataset/A. Segmentation/1. Original Images/a. Training Set
Mode: test, Found 27 images in /content/drive/MyDrive/IDRiD Dataset/A. Segmentation/1. Original Images/b. Testing Set


In [None]:
# DataLoaders: batch size 4 with two workers; only the training set is shuffled.
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=2)

In [None]:
# Record the split sizes in the log file and on the console as a sanity check.
logging.info(f"Train dataset size: {len(train_dataset)}")
logging.info(f"Validation dataset size: {len(val_dataset)}")
print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")

Train dataset size: 54
Validation dataset size: 27


In [None]:
# Build the VGG16-encoder segmentation model on `device` and train for up to 10 epochs.
# final_dice is the best validation Dice reached; its weights are saved in results_dir.
model = SegmentationModel().to(device)
final_dice = train_segmentation_model(model, train_loader, val_loader, num_epochs=10, results_dir=results_dir)

Epoch 1/10 [Train]: 100%|██████████| 14/14 [02:45<00:00, 11.81s/it, Seg Loss=0.8190, Valid Masks=14]


Epoch 1/10, Train Loss: 1.4809, Val Loss: 0.9227, Dice: 0.0821


Epoch 2/10 [Train]: 100%|██████████| 14/14 [00:15<00:00,  1.07s/it, Seg Loss=1.0386, Valid Masks=14]


Epoch 2/10, Train Loss: 0.9321, Val Loss: 0.8727, Dice: 0.0821


Epoch 3/10 [Train]: 100%|██████████| 14/14 [00:13<00:00,  1.04it/s, Seg Loss=0.9088, Valid Masks=14]


Epoch 3/10, Train Loss: 0.8755, Val Loss: 0.9138, Dice: 0.0821


Epoch 4/10 [Train]: 100%|██████████| 14/14 [00:13<00:00,  1.02it/s, Seg Loss=0.6587, Valid Masks=14]


Epoch 4/10, Train Loss: 0.8412, Val Loss: 0.8581, Dice: 0.0821


Epoch 5/10 [Train]: 100%|██████████| 14/14 [00:13<00:00,  1.04it/s, Seg Loss=0.6942, Valid Masks=14]


Epoch 5/10, Train Loss: 0.8099, Val Loss: 0.8048, Dice: 0.0879


Epoch 6/10 [Train]: 100%|██████████| 14/14 [00:15<00:00,  1.09s/it, Seg Loss=1.9781, Valid Masks=14]


Epoch 6/10, Train Loss: 0.7969, Val Loss: 0.7273, Dice: 0.1009


Epoch 7/10 [Train]: 100%|██████████| 14/14 [00:15<00:00,  1.10s/it, Seg Loss=0.8089, Valid Masks=14]


Epoch 7/10, Train Loss: 0.7936, Val Loss: 0.8028, Dice: 0.0879


Epoch 8/10 [Train]: 100%|██████████| 14/14 [00:13<00:00,  1.05it/s, Seg Loss=0.6603, Valid Masks=14]


Epoch 8/10, Train Loss: 0.7519, Val Loss: 0.8235, Dice: 0.1269


Epoch 9/10 [Train]: 100%|██████████| 14/14 [00:15<00:00,  1.08s/it, Seg Loss=1.3287, Valid Masks=14]


Epoch 9/10, Train Loss: 0.7382, Val Loss: 0.7558, Dice: 0.1024


Epoch 10/10 [Train]: 100%|██████████| 14/14 [00:13<00:00,  1.04it/s, Seg Loss=0.6378, Valid Masks=14]
                                                                                                

Epoch 10/10, Train Loss: 0.7089, Val Loss: 0.7206, Dice: 0.1107




In [None]:
# Save image / predicted / ground-truth figures for up to 5 validation images flagged valid.
# NOTE: `model` still holds the last-epoch weights, not the best-Dice checkpoint that
# final_dice refers to.
visualize_segmentation_results(model, val_loader, results_dir)

logging.info(f"Final Dice Score: {final_dice:.4f}")
print(f"Final Dice Score: {final_dice:.4f}")

Final Dice Score: 0.1269


## IDRiD Multitask Learning (Working Version)

In [None]:
import logging
from pathlib import Path
from torch.optim.lr_scheduler import CyclicLR

In [None]:
# Set up logging for the multitask experiment.
# force=True is required here: an earlier basicConfig call in this notebook already attached a
# file handler to the root logger. Without force this call is a silent no-op, and logs would
# keep going to the previous log file instead of training_2.log.
logging.basicConfig(filename='/content/drive/MyDrive/IDRiD Dataset/training_2.log', level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s', force=True)

In [None]:
# Custom Dataset for IDRiD
class IDRiDDataset(Dataset):
    """IDRiD dataset serving DR-grading labels and binary lesion masks together.

    Image names are gathered from the segmentation image folder and/or from the
    grading CSV, depending on ``mode`` ('segmentation', 'grading' or 'both').
    Each item is ``(image, grading_label, seg_mask, valid_mask, img_name)``:

    * ``grading_label`` -- long tensor holding the CSV 'Retinopathy grade', or
      -1 when the image has no row in the CSV.
    * ``seg_mask`` -- float32 tensor of shape (512, 512). It is the pixel-wise
      union of every lesion mask found for the image (1.0 = any lesion).
    * ``valid_mask`` -- True when at least one mask file exists for the image.

    Args:
        image_dir_seg: folder with the segmentation-set fundus images.
        image_dir_grad: folder with the grading-set fundus images.
        grading_csv: path to a CSV with 'Image name' and 'Retinopathy grade'
            columns, or None.
        seg_dir: folder containing one sub-folder per lesion type (the keys of
            ``lesion_suffixes``) with .tif masks, or None.
        mode: 'segmentation', 'grading' or 'both'.
        transform: callable applied to the PIL image only. The mask is NOT
            passed through it, so geometric augmentations here misalign image
            and mask.
        seg_upsample_target: in 'both' mode, segmentation images are repeated
            until there are this many. Pass None (or 0) to disable, e.g. for
            evaluation data, where duplicates skew the metrics.
    """

    def __init__(self, image_dir_seg, image_dir_grad, grading_csv=None, seg_dir=None, mode='both', transform=None,
                 seg_upsample_target=54):
        self.image_dir_seg = image_dir_seg
        self.image_dir_grad = image_dir_grad
        self.grading_csv = pd.read_csv(grading_csv) if grading_csv else None
        self.seg_dir = seg_dir
        self.mode = mode
        self.transform = transform

        # Lesion sub-folder of seg_dir -> suffix used in its mask filenames
        self.lesion_suffixes = {
            '1. Microaneurysms': 'MA',
            '2. Haemorrhages': 'HE',
            '3. Hard Exudates': 'EX',
            '4. Soft Exudates': 'SE',
            '5. Optic Disc': 'OD'
        }

        # 'Image name' -> 'Retinopathy grade', built once so __getitem__ does not
        # filter the whole DataFrame for every sample. The first row wins on
        # duplicate names.
        self._grades = {}
        if self.grading_csv is not None:
            for name, grade in zip(self.grading_csv['Image name'], self.grading_csv['Retinopathy grade']):
                self._grades.setdefault(name, grade)

        # List the segmentation image folder once instead of once per image.
        seg_listing = set(os.listdir(image_dir_seg))

        # Load image names
        self.images = []
        self.is_segmentation = {}
        self.mask_files = {}
        if mode in ['segmentation', 'both'] and seg_dir:
            seg_images = [f for f in sorted(seg_listing) if f.lower().endswith('.jpg')]
            self.images.extend(seg_images)
            for img in seg_images:
                self.is_segmentation[img] = True
                self.mask_files[img] = {}
            logging.info(f"Loaded {len(seg_images)} segmentation images from {image_dir_seg}")
        if mode in ['grading', 'both'] and self.grading_csv is not None:
            grad_images = [f"{name}.jpg" for name in self.grading_csv['Image name']]
            for img in grad_images:
                self.is_segmentation[img] = img in seg_listing
            self.images.extend(grad_images)
            logging.info(f"Loaded {len(grad_images)} grading images from CSV")
        # Sorted rather than raw set order, so the sample order (and therefore
        # every unshuffled loader) is reproducible from run to run.
        self.images = sorted(set(self.images))
        logging.info(f"Total unique images: {len(self.images)}")

        # Filter images and detect mask filenames
        valid_images = []
        mask_counts = {lesion: 0 for lesion in self.lesion_suffixes}
        mask_exists = {lesion: 0 for lesion in self.lesion_suffixes}
        for img in self.images:
            img_dir = image_dir_seg if img in seg_listing else image_dir_grad
            img_path = os.path.join(img_dir, img)
            if not os.path.exists(img_path):
                logging.warning(f"Image not found: {img_path}")
                continue
            valid_images.append(img)
            if self.is_segmentation.get(img, False) and mode in ['segmentation', 'both'] and seg_dir:
                img_base = img.split('.')[0]
                num_part = img_base.split('_')[1]
                for lesion, suffix in self.lesion_suffixes.items():
                    # Candidate mask filenames, tried in order
                    patterns = [
                        f"{img_base}_{suffix}.tif",
                        f"{img_base}_{suffix}.TIF",
                        f"{img_base}.tif",
                        f"IDRiD_{num_part:0>3}_{suffix}.tif",
                        f"IDRiD_{num_part:0>3}.tif"
                    ]
                    mask_path = None
                    for pattern in patterns:
                        path = os.path.join(seg_dir, lesion, pattern)
                        if os.path.exists(path):
                            mask_path = path
                            self.mask_files[img][lesion] = pattern
                            mask_exists[lesion] += 1
                            try:
                                mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
                                if mask is not None and np.sum(mask > 0) > 0:
                                    mask_counts[lesion] += 1
                            except Exception as e:
                                logging.warning(f"Failed to load mask {mask_path}: {str(e)}")
                            break
                    if not mask_path:
                        logging.debug(f"No mask found for {img} in {lesion}")
        self.images = valid_images
        logging.info(f"Valid images after filtering: {len(self.images)}")
        logging.info(f"Mask files exist: {mask_exists}")
        logging.info(f"Non-empty mask counts: {mask_counts}")
        print(f"Mask files exist: {mask_exists}")
        print(f"Non-empty mask counts: {mask_counts}")

        # Upsample segmentation images only (the 0 < guard avoids a division by
        # zero when no segmentation image is present).
        if mode == 'both' and seg_upsample_target:
            seg_images = [img for img in self.images if self.is_segmentation.get(img, False)]
            if 0 < len(seg_images) < seg_upsample_target:
                repeats = seg_upsample_target // len(seg_images) + 1
                seg_images = (seg_images * repeats)[:seg_upsample_target]
                self.images = seg_images + [img for img in self.images if not self.is_segmentation.get(img, False)]
                logging.info(f"Upsampled segmentation images to {len(seg_images)}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return (image, grading_label, seg_mask, valid_mask, img_name) for ``idx``."""
        img_name = self.images[idx]
        img_path = os.path.join(self.image_dir_seg, img_name)
        if not os.path.exists(img_path):
            img_path = os.path.join(self.image_dir_grad, img_name)

        try:
            # The context manager closes the underlying file once converted.
            with Image.open(img_path) as raw_image:
                image = raw_image.convert('RGB')
            logging.debug(f"Loaded image: {img_path}")
        except FileNotFoundError:
            logging.error(f"File not found: {img_path}")
            print(f"File not found: {img_path}")
            return (torch.zeros(3, 512, 512), torch.tensor(-1, dtype=torch.long),
                    torch.zeros(512, 512), False, img_name)

        grading_label = torch.tensor(-1, dtype=torch.long)
        seg_mask = torch.zeros(512, 512)
        valid_mask = False

        if self.mode in ['grading', 'both'] and self.grading_csv is not None:
            img_base = img_name.split('.')[0]
            if img_base in self._grades:
                grading_label = torch.tensor(int(self._grades[img_base]), dtype=torch.long)
                logging.debug(f"Loaded grading label for {img_name}: {grading_label}")

        if self.mode in ['segmentation', 'both'] and self.seg_dir and self.is_segmentation.get(img_name, False):
            mask_found = False
            for lesion in self.lesion_suffixes:
                mask_filename = self.mask_files.get(img_name, {}).get(lesion)
                if not mask_filename:
                    logging.debug(f"No mask filename for {img_name} in {lesion}")
                    continue
                mask_path = os.path.join(self.seg_dir, lesion, mask_filename)
                if os.path.exists(mask_path):
                    mask_found = True
                    try:
                        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
                        if mask is None:
                            logging.error(f"Failed to read mask {mask_path}")
                            continue
                        mask = cv2.resize(mask, (512, 512))
                        mask_binary = torch.from_numpy((mask > 0).astype(np.float32))
                        # Union of the lesion masks, kept float32 so every sample
                        # (with or without masks) has the same mask dtype.
                        seg_mask = torch.maximum(seg_mask, mask_binary)
                        if mask_binary.sum() > 0:
                            logging.debug(f"Loaded non-empty mask for {lesion} in {img_name}: {mask_binary.sum()} pixels")
                        else:
                            logging.debug(f"Loaded empty mask for {lesion} in {img_name}")
                    except Exception as e:
                        logging.error(f"Error loading mask {mask_path}: {str(e)}")
                else:
                    logging.debug(f"Mask not found: {mask_path}")
            valid_mask = mask_found
            if not mask_found:
                logging.warning(f"No mask files found for {img_name} (segmentation image)")
            elif seg_mask.sum() == 0:
                logging.debug(f"All masks empty for {img_name} (segmentation image)")

        if self.transform:
            image = self.transform(image)

        return image, grading_label, seg_mask, valid_mask, img_name

In [None]:
# Data augmentation.
# IDRiDDataset applies `transform` to the fundus image only. The lesion mask is
# loaded separately and never transformed. Random flips, rotations and affine
# shifts here would therefore move the image but not its mask and corrupt the
# segmentation targets. Training augmentation is restricted to photometric
# changes, which leave pixel positions untouched.
# TODO: reintroduce geometric augmentation as a joint image+mask transform.
train_transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Evaluation: deterministic resize + ImageNet normalisation only.
val_transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [None]:
# Load datasets
# Both splits use mode='both': segmentation images from the split's
# "A. Segmentation" folder plus grading images listed in the split's CSV.
# Training images get train_transform, test images val_transform.
# NOTE(review): mode='both' also makes IDRiDDataset repeat the segmentation
# images when there are fewer than 54 of them, so the test split contains
# duplicated segmentation samples too.
base_path = '/content/drive/MyDrive/IDRiD Dataset'
try:
    train_dataset = IDRiDDataset(
        image_dir_seg=f'{base_path}/A. Segmentation/1. Original Images/a. Training Set',
        image_dir_grad=f'{base_path}/B. Disease Grading/1. Original Images/a. Training Set',
        grading_csv=f'{base_path}/B. Disease Grading/2. Groundtruths/a. IDRiD_Disease Grading_Training Labels.csv',
        seg_dir=f'{base_path}/A. Segmentation/2. All Segmentation Groundtruths/a. Training Set',
        mode='both',
        transform=train_transform
    )
    test_dataset = IDRiDDataset(
        image_dir_seg=f'{base_path}/A. Segmentation/1. Original Images/b. Testing Set',
        image_dir_grad=f'{base_path}/B. Disease Grading/1. Original Images/b. Testing Set',
        grading_csv=f'{base_path}/B. Disease Grading/2. Groundtruths/b. IDRiD_Disease Grading_Testing Labels.csv',
        seg_dir=f'{base_path}/A. Segmentation/2. All Segmentation Groundtruths/b. Testing Set',
        mode='both',
        transform=val_transform
    )
    logging.info("Datasets loaded successfully")
    print("Datasets loaded successfully")
except Exception as e:
    # Record the failure in the log and the notebook output, then re-raise so
    # later cells do not run with missing datasets.
    logging.error(f"Error loading datasets: {str(e)}")
    print(f"Error loading datasets: {str(e)}")
    raise

Mask files exist: {'1. Microaneurysms': 54, '2. Haemorrhages': 53, '3. Hard Exudates': 54, '4. Soft Exudates': 26, '5. Optic Disc': 54}
Non-empty mask counts: {'1. Microaneurysms': 54, '2. Haemorrhages': 53, '3. Hard Exudates': 54, '4. Soft Exudates': 26, '5. Optic Disc': 54}
Mask files exist: {'1. Microaneurysms': 27, '2. Haemorrhages': 27, '3. Hard Exudates': 27, '4. Soft Exudates': 14, '5. Optic Disc': 27}
Non-empty mask counts: {'1. Microaneurysms': 27, '2. Haemorrhages': 27, '3. Hard Exudates': 27, '4. Soft Exudates': 14, '5. Optic Disc': 27}
Datasets loaded successfully


In [None]:
# Reshuffle training batches every epoch; keep the evaluation order fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=8, shuffle=False)

In [None]:
# Compute class weights for grading: inverse class frequency, rescaled so the
# weights of the grades present sum to the number of those grades.
grading_csv = pd.read_csv(f'{base_path}/B. Disease Grading/2. Groundtruths/a. IDRiD_Disease Grading_Training Labels.csv')
# Exactly one entry per grade 0-4, because the 5-way classifier head and its
# weighted CrossEntropyLoss need 5 weights. A grade missing from the CSV gets
# count 0 instead of silently shortening the weight vector.
class_counts = grading_csv['Retinopathy grade'].value_counts().reindex(range(5), fill_value=0)
# Absent grades become NaN here (instead of 1/0 = inf) and end with weight 0.
class_weights = 1.0 / class_counts.where(class_counts > 0)
class_weights = (class_weights / class_weights.sum() * class_weights.count()).fillna(0.0)
class_weights = torch.tensor(class_weights.values, dtype=torch.float32).to(device)
logging.info(f"Class weights for grading: {class_weights.tolist()}")

In [None]:
# Modular MTL Model with Dynamic Routing
class MultiTaskModel(nn.Module):
    """ImageNet-pretrained MobileNetV2 features shared by two task experts.

    * Grading expert: global average pool + dropout + linear -> 5 logits.
    * Segmentation expert: five (conv, ReLU, 2x bilinear upsample) stages and a
      final 1-channel conv, i.e. logits upsampled 2**5 = 32x relative to the
      backbone feature map.
    * Gate: softmax over 2 weights computed from the pooled backbone features.
      Weight 0 scales the grading logits and weight 1 scales the segmentation
      logits ("dynamic routing").

    ``forward(x, task)`` returns a dict containing 'grading' and/or
    'segmentation', depending on ``task`` ('grading', 'segmentation' or 'both').
    """

    def __init__(self):
        super(MultiTaskModel, self).__init__()
        self.backbone = models.mobilenet_v2(weights='IMAGENET1K_V1').features
        self.in_channels = 1280  # channel count of the backbone's last feature map

        # Classification expert
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(self.in_channels, 5)
        )

        # Segmentation expert
        self.decoder = nn.ModuleList([
            nn.Conv2d(self.in_channels, 512, 3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(512, 256, 3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(256, 128, 3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(64, 32, 3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(32, 1, 3, padding=1)
        ])

        # Dynamic routing gate
        self.gate = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(self.in_channels, 2),
            nn.Softmax(dim=1)
        )

        # Initialise only the newly created heads. Walking self.modules() would
        # also reach every Conv2d inside the pretrained backbone and overwrite the
        # ImageNet weights loaded above with random Kaiming values.
        for head in (self.classifier, self.decoder, self.gate):
            for m in head.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.Linear):
                    nn.init.xavier_normal_(m.weight)
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)

    def forward(self, x, task='both'):
        """Run the backbone once and the expert(s) selected by ``task``."""
        features = self.backbone(x)
        gate_weights = self.gate(features)  # [batch_size, 2]

        outputs = {}
        if task in ['grading', 'both']:
            class_out = self.classifier(features)
            outputs['grading'] = gate_weights[:, 0].unsqueeze(1) * class_out

        if task in ['segmentation', 'both']:
            seg_out = features
            for layer in self.decoder:
                seg_out = layer(seg_out)
            # Broadcast the per-sample gate weight over (channel, H, W)
            outputs['segmentation'] = gate_weights[:, 1].unsqueeze(1).unsqueeze(2).unsqueeze(3) * seg_out

        return outputs

In [None]:
# Focal loss for segmentation
def focal_loss(pred, target, alpha=0.75, gamma=1.0):
    """Alpha-balanced binary focal loss on logits (Lin et al., 2017), mean-reduced.

    Args:
        pred: raw logits (any shape).
        target: same shape as ``pred`` with values in {0, 1}. Bool masks are
            accepted and cast to ``pred``'s dtype (BCE rejects bool targets).
        alpha: weight of positive pixels; negative pixels get ``1 - alpha``.
            Previously alpha multiplied every pixel alike, which is only a
            constant rescaling of the loss.
        gamma: focusing exponent; larger values down-weight easy pixels.

    Returns:
        Scalar tensor.
    """
    target = target.to(pred.dtype)
    bce = nn.functional.binary_cross_entropy_with_logits(pred, target, reduction='none')
    pt = torch.exp(-bce)  # predicted probability of the true class
    alpha_t = alpha * target + (1 - alpha) * (1 - target)
    focal = alpha_t * (1 - pt) ** gamma * bce
    return focal.mean()

In [None]:
# Dice score for binary segmentation
def dice_score(pred, target, threshold=0.1):
    """Dice coefficient of thresholded sigmoid(pred) against ``target``.

    All pixels of the batch are pooled into one score (not averaged per
    sample). A 1e-5 smoothing term keeps the result at 1.0 when both the
    prediction and the target are empty.

    Returns:
        The Dice value as a Python float.
    """
    eps = 1e-5
    binary_pred = torch.sigmoid(pred).gt(threshold).float()
    overlap = torch.sum(binary_pred * target)
    p_total = torch.sum(binary_pred)
    t_total = torch.sum(target)
    score = (overlap * 2.0 + eps) / (p_total + t_total + eps)
    logging.debug(f"Dice score: {score:.4f}, Intersection: {overlap:.4f}, Pred sum: {p_total:.4f}, Target sum: {t_total:.4f}")
    return score.item()

In [None]:
# Training function with early stopping
def _evaluate_multitask(model, loader):
    """Return (grading accuracy, mean per-batch Dice) of ``model`` on ``loader``.

    Accuracy ignores samples labelled -1. Dice is averaged over batches holding
    at least one valid mask. Each metric is 0 when nothing qualifies.
    """
    model.eval()
    class_preds, class_true = [], []
    seg_dice_scores = []
    with torch.no_grad():
        val_bar = tqdm(loader, desc="Validation", leave=False)
        for images, grading_labels, seg_masks, valid_masks, img_names in val_bar:
            images, grading_labels, seg_masks = images.to(device), grading_labels.to(device), seg_masks.to(device)
            valid_masks = valid_masks.to(device)
            outputs = model(images, task='both')

            if 'grading' in outputs:
                class_preds.extend(torch.argmax(outputs['grading'], dim=1).cpu().numpy())
                class_true.extend(grading_labels.cpu().numpy())
            if 'segmentation' in outputs and valid_masks.any():
                valid_indices = valid_masks.nonzero(as_tuple=True)[0]
                seg_dice_scores.append(
                    dice_score(outputs['segmentation'].squeeze(1)[valid_indices], seg_masks[valid_indices])
                )

    # Guard against an evaluation set without any graded sample.
    labelled = [(t, p) for t, p in zip(class_true, class_preds) if t != -1]
    class_acc = accuracy_score([t for t, _ in labelled], [p for _, p in labelled]) if labelled else 0
    avg_dice = np.mean(seg_dice_scores) if seg_dice_scores else 0.0
    return class_acc, avg_dice


def train_model(model, train_loader, test_loader, epochs=50, patience=5):
    """Jointly train the grading and segmentation experts with early stopping.

    Per batch the loss is ``0.7 * CE + 0.3 * focal`` when the batch holds at
    least one valid mask, otherwise the CE term alone. After each epoch the
    model is evaluated on ``test_loader``. When grading accuracy improves, the
    weights are saved to '/content/drive/MyDrive/IDRiD Dataset/best_model_2.pth'.
    Training stops after ``patience`` epochs without improvement. The in-memory
    model keeps the weights of the last epoch, not the best ones.

    NOTE(review): ``test_loader`` is used for model selection, so its accuracy is
    not an unbiased held-out estimate.

    Relies on the notebook globals ``device`` and ``class_weights``.

    Returns:
        (class_acc, avg_dice) of the last completed evaluation, or (0.0, 0.0)
        if training is interrupted before the first one.
    """
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=5e-4)
    # CyclicLR is a per-batch schedule, so it is stepped once per optimiser step
    # (the former extra per-epoch step skewed the cycle).
    scheduler = CyclicLR(optimizer, base_lr=1e-4, max_lr=2e-3, step_size_up=len(train_loader)*3, mode='triangular2')
    class_criterion = nn.CrossEntropyLoss(weight=class_weights, ignore_index=-1)

    best_acc = 0.0
    patience_counter = 0
    # Defined up front so an early KeyboardInterrupt can still return them.
    class_acc, avg_dice = 0.0, 0.0
    try:
        for epoch in range(epochs):
            model.train()
            train_class_loss, train_seg_loss = 0.0, 0.0
            valid_mask_batches = 0
            train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=True)
            for images, grading_labels, seg_masks, valid_masks, img_names in train_bar:
                images, grading_labels, seg_masks = images.to(device), grading_labels.to(device), seg_masks.to(device)
                valid_masks = valid_masks.to(device)
                has_masks = bool(valid_masks.any())

                optimizer.zero_grad()
                outputs = model(images, task='both')

                # Weighted CE over a batch whose labels are all -1 is 0/0 = NaN,
                # which would poison the weights; skip the term in that case.
                class_loss = 0
                if 'grading' in outputs and (grading_labels != -1).any():
                    class_loss = class_criterion(outputs['grading'], grading_labels)
                seg_loss = 0
                if 'segmentation' in outputs and has_masks:
                    valid_indices = valid_masks.nonzero(as_tuple=True)[0]
                    # .float(): BCE-based losses reject bool mask targets
                    seg_loss = focal_loss(outputs['segmentation'].squeeze(1)[valid_indices],
                                          seg_masks[valid_indices].float())

                class_weight = 0.7 if has_masks else 1.0
                seg_weight = 0.3 if has_masks else 0.0
                loss = class_weight * class_loss + seg_weight * seg_loss
                # Both terms can be skipped (no labels and no masks): nothing to learn.
                if torch.is_tensor(loss):
                    loss.backward()
                    optimizer.step()
                    scheduler.step()

                if torch.is_tensor(class_loss):
                    train_class_loss += class_loss.item()
                if torch.is_tensor(seg_loss):
                    train_seg_loss += seg_loss.item()
                if has_masks:
                    valid_mask_batches += 1

                train_bar.set_postfix({
                    'Class Loss': f'{class_loss.item():.4f}' if torch.is_tensor(class_loss) else 'N/A',
                    'Seg Loss': f'{seg_loss.item():.4f}' if torch.is_tensor(seg_loss) else '0.0',
                    'Valid Masks': valid_mask_batches
                })

            avg_class_loss = train_class_loss / len(train_loader)
            avg_seg_loss = train_seg_loss / valid_mask_batches if valid_mask_batches > 0 else 0.0
            logging.info(f"Epoch {epoch+1}/{epochs}, Class Loss: {avg_class_loss:.4f}, Seg Loss: {avg_seg_loss:.4f}, Valid Mask Batches: {valid_mask_batches}")
            print(f"Epoch {epoch+1}/{epochs} - Class Loss: {avg_class_loss:.4f}, Seg Loss: {avg_seg_loss:.4f}, Valid Mask Batches: {valid_mask_batches}")

            # Evaluation
            class_acc, avg_dice = _evaluate_multitask(model, test_loader)
            logging.info(f"Validation - Accuracy: {class_acc:.4f}, Dice: {avg_dice:.4f}")
            print(f"Validation - Accuracy: {class_acc:.4f}, Dice: {avg_dice:.4f}")

            # Early stopping
            if class_acc > best_acc:
                best_acc = class_acc
                torch.save(model.state_dict(), '/content/drive/MyDrive/IDRiD Dataset/best_model_2.pth')
                logging.info("Saved best model based on classification accuracy")
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    logging.info(f"Early stopping triggered after {patience} epochs without improvement")
                    print(f"Early stopping at epoch {epoch+1}")
                    break

    except KeyboardInterrupt:
        logging.warning("Training interrupted by user")
        print("Training interrupted. Saving current model state...")
        torch.save(model.state_dict(), '/content/drive/MyDrive/IDRiD Dataset/interrupted_model_2.pth')
        logging.info("Saved interrupted model state to 'interrupted_model_2.pth'")
        return class_acc, avg_dice

    return class_acc, avg_dice

In [None]:
# Visualization
def visualize_results(model, test_loader, num_samples=5, threshold=0.1):
    """Save image / predicted-mask / ground-truth-mask figures for a few samples.

    Walks ``test_loader`` in order and skips samples without a valid mask.
    Writes at most ``num_samples`` PNGs named
    '/content/drive/MyDrive/IDRiD Dataset/result_<img_name>_2.png'. Previously
    only the first ``num_samples`` items of each batch were considered and the
    cap could be exceeded. Any exception is logged and printed rather than
    raised, so a plotting failure does not abort the notebook.

    Args:
        model: called as ``model(images, task='segmentation')``; must return a
            dict whose 'segmentation' logits have a squeezable channel dim 1.
        test_loader: yields (images, labels, seg_masks, valid_masks, img_names).
        num_samples: maximum number of figures to write.
        threshold: probability cut-off for the predicted mask (0.1, the value
            previously hard-coded here).
    """
    model.eval()
    samples = []
    # ImageNet statistics used by the transforms, needed to undo Normalize
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    try:
        with torch.no_grad():
            for images, _, seg_masks, valid_masks, img_names in test_loader:
                images = images.to(device)
                outputs = model(images, task='segmentation')
                seg_preds = torch.sigmoid(outputs['segmentation'].squeeze(1)) > threshold

                for i in range(images.size(0)):
                    if len(samples) >= num_samples:
                        break
                    if not valid_masks[i]:
                        continue
                    img = images[i].cpu().permute(1, 2, 0).numpy()
                    img = (img * std + mean).clip(0, 1)
                    pred_mask = seg_preds[i].cpu().numpy()
                    true_mask = seg_masks[i].cpu().numpy()

                    fig = plt.figure(figsize=(15, 5))
                    try:
                        panels = ((f"Image: {img_names[i]}", img, None),
                                  ("Predicted Mask", pred_mask, 'gray'),
                                  ("Ground Truth Mask", true_mask, 'gray'))
                        for position, (title, data, cmap) in enumerate(panels, start=1):
                            plt.subplot(1, 3, position)
                            plt.title(title)
                            plt.imshow(data, cmap=cmap)
                            plt.axis('off')
                        plt.savefig(f"/content/drive/MyDrive/IDRiD Dataset/result_{img_names[i]}_2.png")
                    finally:
                        # Release the figure even if plotting/saving fails
                        plt.close(fig)
                    samples.append(img_names[i])
                    logging.info(f"Generated visualization for {img_names[i]}")
                if len(samples) >= num_samples:
                    break
    except Exception as e:
        logging.error(f"Error during visualization: {str(e)}")
        print(f"Error during visualization: {str(e)}")

In [None]:
print("Starting training process...")
logging.info("Starting training process")

model = MultiTaskModel()
# Backbone parameters already require gradients by default; this loop only
# makes the intent (full fine-tuning, nothing frozen) explicit.
for param in model.backbone.parameters():
    param.requires_grad = True

print("Training MultiTaskModel...")
logging.info("Training MultiTaskModel")
# The Dice result is bound to `mtl_dice`: assigning it to `dice_score` would
# replace the dice_score() metric function with a float and break any later
# call to train_model().
class_acc, mtl_dice = train_model(model, train_loader, test_loader, epochs=50, patience=5)

# NOTE: `model` holds the last-epoch weights; the best checkpoint is the file
# best_model_2.pth written during training.
print("Generating visualizations...")
logging.info("Generating visualizations")
visualize_results(model, test_loader)

logging.info(f"Final MTL Results - Classification Accuracy: {class_acc:.4f}, Segmentation Dice: {mtl_dice:.4f}")
print(f"Final MTL Results - Classification Accuracy: {class_acc:.4f}, Segmentation Dice: {mtl_dice:.4f}")
print("Training completed. Check 'training_2.log' for detailed logs and 'result_*_2.png' for visualizations.")
logging.info("Training completed")

Starting training process...


Downloading: "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v2-b0353104.pth
100%|██████████| 13.6M/13.6M [00:00<00:00, 181MB/s]


Training MultiTaskModel...


Epoch 1/50: 100%|██████████| 59/59 [01:35<00:00,  1.62s/it, Class Loss=0.8755, Seg Loss=0.1384, Valid Masks=37]


Epoch 1/50 - Class Loss: 1.7896, Seg Loss: 1.2798, Valid Mask Batches: 37




Validation - Accuracy: 0.3010, Dice: 0.0829


Epoch 2/50: 100%|██████████| 59/59 [01:42<00:00,  1.74s/it, Class Loss=2.0226, Seg Loss=0.0, Valid Masks=38]


Epoch 2/50 - Class Loss: 2.2298, Seg Loss: 0.0941, Valid Mask Batches: 38




Validation - Accuracy: 0.1845, Dice: 0.0869


Epoch 3/50: 100%|██████████| 59/59 [01:43<00:00,  1.75s/it, Class Loss=1.8490, Seg Loss=0.0, Valid Masks=38]


Epoch 3/50 - Class Loss: 1.7545, Seg Loss: 0.2253, Valid Mask Batches: 38




Validation - Accuracy: 0.2233, Dice: 0.0820


Epoch 4/50: 100%|██████████| 59/59 [01:43<00:00,  1.75s/it, Class Loss=1.6643, Seg Loss=0.0, Valid Masks=41]


Epoch 4/50 - Class Loss: 1.6411, Seg Loss: 0.1435, Valid Mask Batches: 41




Validation - Accuracy: 0.3010, Dice: 0.1037


Epoch 5/50: 100%|██████████| 59/59 [01:40<00:00,  1.70s/it, Class Loss=1.3749, Seg Loss=0.0, Valid Masks=40]


Epoch 5/50 - Class Loss: 1.5993, Seg Loss: 0.1066, Valid Mask Batches: 40




Validation - Accuracy: 0.3107, Dice: 0.1094


Epoch 6/50: 100%|██████████| 59/59 [01:41<00:00,  1.72s/it, Class Loss=1.6561, Seg Loss=0.0690, Valid Masks=37]


Epoch 6/50 - Class Loss: 1.5521, Seg Loss: 0.0959, Valid Mask Batches: 37




Validation - Accuracy: 0.3204, Dice: 0.0950


Epoch 7/50: 100%|██████████| 59/59 [01:41<00:00,  1.72s/it, Class Loss=1.5436, Seg Loss=0.0, Valid Masks=36]


Epoch 7/50 - Class Loss: 1.5245, Seg Loss: 0.0828, Valid Mask Batches: 36




Validation - Accuracy: 0.2718, Dice: 0.1062


Epoch 8/50: 100%|██████████| 59/59 [01:41<00:00,  1.73s/it, Class Loss=1.3593, Seg Loss=0.0, Valid Masks=37]


Epoch 8/50 - Class Loss: 1.5331, Seg Loss: 0.0803, Valid Mask Batches: 37




Validation - Accuracy: 0.3204, Dice: 0.1006


Epoch 9/50: 100%|██████████| 59/59 [01:41<00:00,  1.72s/it, Class Loss=1.5837, Seg Loss=0.0, Valid Masks=37]


Epoch 9/50 - Class Loss: 1.5266, Seg Loss: 0.0782, Valid Mask Batches: 37




Validation - Accuracy: 0.2816, Dice: 0.1182


Epoch 10/50: 100%|██████████| 59/59 [01:39<00:00,  1.69s/it, Class Loss=1.3770, Seg Loss=0.0, Valid Masks=37]


Epoch 10/50 - Class Loss: 1.5891, Seg Loss: 0.0772, Valid Mask Batches: 37




Validation - Accuracy: 0.2913, Dice: 0.1185


Epoch 11/50: 100%|██████████| 59/59 [01:39<00:00,  1.69s/it, Class Loss=1.9096, Seg Loss=0.0, Valid Masks=33]


Epoch 11/50 - Class Loss: 1.5134, Seg Loss: 0.0663, Valid Mask Batches: 33




Validation - Accuracy: 0.2816, Dice: 0.1086
Early stopping at epoch 11
Generating visualizations...
Final MTL Results - Classification Accuracy: 0.2816, Segmentation Dice: 0.1086
Training completed. Check 'training_2.log' for detailed logs and 'result_*_2.png' for visualizations.
