In [1]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader, Subset, random_split
import torchmetrics
import torchvision.models as models
import torchvision.transforms as T
from PIL import Image
# from albumentations import Compose, ... # Import augmentations from albumentations
# from albumentations.pytorch import ToTensorV2

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# --- 1. Classification Dataset ---
class ClassificationDataset(Dataset):
    """Image-folder classification dataset laid out as ``root_dir/<class>/<file>``.

    Each non-hidden sub-directory of ``root_dir`` is a class. Class indices
    follow sorted directory-name order. Every non-hidden regular file inside a
    class directory is one sample.

    Args:
        root_dir: Directory whose sub-directories are the class folders.
        transform: Optional callable applied to the RGB PIL image in
            ``__getitem__`` (e.g. a torchvision ``Compose``).
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        # Only visible directories are classes. Otherwise stray files
        # (``.DS_Store``) or Jupyter's ``.ipynb_checkpoints`` folder would
        # become bogus classes or make the os.listdir below fail.
        self.classes = sorted(
            name for name in os.listdir(root_dir)
            if not name.startswith('.') and os.path.isdir(os.path.join(root_dir, name))
        )
        self.class_to_idx = {c: i for i, c in enumerate(self.classes)}
        self.image_paths = []
        self.labels = []
        for class_name in self.classes:
            class_path = os.path.join(root_dir, class_name)
            # Sort so sample order (and hence any seeded split) does not
            # depend on the filesystem's listing order.
            for img_name in sorted(os.listdir(class_path)):
                img_path = os.path.join(class_path, img_name)
                if img_name.startswith('.') or not os.path.isfile(img_path):
                    continue
                self.image_paths.append(img_path)
                self.labels.append(self.class_to_idx[class_name])
        self.transform = transform

    def __len__(self):
        """Number of image samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)``; ``image`` is transformed if a transform was given."""
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        # The context manager closes the file handle. convert() returns a new,
        # fully loaded image that stays valid after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert('RGB')  # Ensure consistent color channels
        if self.transform:
            image = self.transform(image)
        return image, label

# --- 2. Segmentation Dataset ---
class SegmentationDataset(Dataset):
    """Paired image/mask dataset for semantic segmentation.

    Images (``.png``/``.jpg``) in ``img_dir`` and masks (``.png``/``.tif``) in
    ``mask_dir`` are matched case-insensitively by extension, then each list
    is sorted by file name. Pairs are formed by position, so the i-th image
    must correspond to the i-th mask.

    Args:
        img_dir: Directory holding the input images.
        mask_dir: Directory holding the masks (read as single-channel 'L').
        transform: Optional albumentations-style callable, invoked as
            ``transform(image=np_array, mask=np_array)``. It must return a
            dict with ``'image'`` and ``'mask'`` entries. When ``None``, the
            image is returned as a float ``(3, H, W)`` tensor in ``[0, 1]``
            and the mask as an ``(H, W)`` tensor.
        normalize_mask: If True (default), the mask is divided by 255, which
            suits 0/255 binary masks. Set False for masks that store integer
            class indices; the mask is then returned as an int64 tensor.

    Raises:
        ValueError: If the number of images and masks differ.
    """

    def __init__(self, img_dir, mask_dir, transform=None, normalize_mask=True):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.img_files = sorted(f for f in os.listdir(img_dir) if f.lower().endswith(('.png', '.jpg')))
        self.mask_files = sorted(f for f in os.listdir(mask_dir) if f.lower().endswith(('.png', '.tif')))  # Adjust mask extensions
        # Pairing is purely positional, so a count mismatch would silently
        # pair images with the wrong masks.
        if len(self.img_files) != len(self.mask_files):
            raise ValueError(
                f"Found {len(self.img_files)} images in {img_dir!r} but "
                f"{len(self.mask_files)} masks in {mask_dir!r}"
            )
        self.transform = transform
        self.normalize_mask = normalize_mask

    def __len__(self):
        """Number of image/mask pairs."""
        return len(self.img_files)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for the idx-th pair."""
        img_path = os.path.join(self.img_dir, self.img_files[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_files[idx])
        # Context managers close the file handles after the pixels are copied out.
        with Image.open(img_path) as img:
            image = np.array(img.convert('RGB'))
        # NOTE(review): convert('L') maps palette ('P') masks to luminance,
        # which would scramble palette-encoded class ids. Confirm the mask format.
        with Image.open(mask_path) as msk:
            mask = np.array(msk.convert('L'))

        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image = transformed['image']
            mask = transformed['mask']
        else:
            # Without a transform, return tensors so DataLoader can collate
            # the batch and the mask can be scaled below.
            image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
            mask = torch.from_numpy(mask)

        if self.normalize_mask:
            return image, mask / 255.0  # 0/255 binary masks -> [0, 1]
        return image, torch.as_tensor(mask).long()

# --- 3. Classification Model Training ---
def train_classification_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device):
    """Train a classifier, reporting train loss, validation loss and accuracy per epoch.

    Args:
        model: Module whose output has class scores on dim 1 (argmax over dim 1
            gives the predicted class).
        train_loader: DataLoader yielding ``(inputs, labels)`` for training.
        val_loader: DataLoader yielding ``(inputs, labels)`` for validation.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over the model's parameters.
        scheduler: ``ReduceLROnPlateau`` (stepped with the epoch's validation
            loss), any other LR scheduler (stepped once per epoch), or ``None``.
        num_epochs: Number of passes over ``train_loader``.
        device: Device each batch is moved to. The model is expected to
            already live there.

    Returns:
        The same ``model`` object, trained in place.
    """
    for epoch in range(num_epochs):
        # --- training pass ---
        model.train()
        train_loss_sum, train_count = 0.0, 0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss_sum += loss.item() * images.size(0)
            train_count += images.size(0)
        train_loss = train_loss_sum / max(train_count, 1)

        # --- validation pass ---
        model.eval()
        val_loss_sum, val_correct, val_count = 0.0, 0, 0
        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                val_loss_sum += criterion(outputs, labels).item() * images.size(0)
                preds = outputs.argmax(dim=1)
                val_correct += (preds == labels).sum().item()
                val_count += images.size(0)
        # max(..., 1) guards an empty validation set against division by zero.
        val_loss = val_loss_sum / max(val_count, 1)
        val_acc = val_correct / max(val_count, 1)
        print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

        # ReduceLROnPlateau needs this epoch's validation loss, so step only
        # once validation has run.
        if isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(val_loss)
        elif scheduler is not None:
            scheduler.step()
    return model

# --- 4. Segmentation Model Fine-tuning ---
def train_segmentation_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device):
    for epoch in range(num_epochs):
        model.train()
        for images, masks in train_loader:
            images = images.to(device)
            masks = masks.to(device).unsqueeze(1) # Assuming binary masks, adjust for multi-class
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
        scheduler.step(val_loss) # For ReduceLROnPlateau

        model.eval()
        val_loss = 0.0
        dice_metric = torchmetrics.Dice(num_classes=1 if num_classes == 2 else num_classes, average='macro').to(device)
        with torch.no_grad():
            for images, masks in val_loader:
                images = images.to(device)
                masks = masks.to(device).unsqueeze(1)
                outputs = model(images)
                loss = criterion(outputs, masks)
                val_loss += loss.item() * images.size(0)
                # Assuming outputs are logits, apply sigmoid if using BCEWithLogitsLoss
                if isinstance(criterion, nn.BCEWithLogitsLoss):
                    preds = torch.sigmoid(outputs) > 0.5
                else:
                    preds = torch.argmax(outputs, dim=1) # For CrossEntropy or similar

                # Ensure masks are in the correct format for metrics
                dice_metric.update(preds, masks.squeeze(1).long()) # Adjust dimensions as needed
                iou_metric.update(preds, masks.squeeze(1).long())

        val_loss /= len(val_loader.dataset)
        val_dice = dice_metric.compute()
        val_iou = iou_metric.compute()
        print(f'Epoch {epoch+1}, Train Loss: {epoch_loss:.4f}, Val Loss: {val_loss:.4f}, Val Dice: {val_dice:.4f}, Val IoU: {val_iou:.4f}')
        if isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(val_loss)
        elif scheduler is not None:
            scheduler.step()
    return model

if __name__ == '__main__':
    # segmentation_models_pytorch supplies smp.Unet used below; without this
    # import the segmentation stage fails with NameError.
    import segmentation_models_pytorch as smp

    # --- Hyperparameters and Setup ---
    classification_data_dir = 'RetinalOCT_Dataset' # <--- CHANGE THIS
    segmentation_img_dir = 'Segementation data/images'   # <--- CHANGE THIS
    segmentation_mask_dir = 'Segementation data/masks'     # <--- CHANGE THIS
    batch_size = 32
    learning_rate_classification = 1e-3
    learning_rate_segmentation = 1e-4
    num_epochs_classification = 10
    num_epochs_segmentation = 20
    weight_decay = 1e-4
    num_segmentation_classes = 8 + 1 # Number of disease classes + background (if applicable) <--- CHANGE THIS
    seed = 42  # Seeds weight init, data splits and shuffling for reproducible runs

    torch.manual_seed(seed)

    # Check for CUDA availability
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # --- 1. Classification DataLoaders ---
    classification_transform_train = T.Compose([
        T.RandomResizedCrop(224),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet stats
    ])
    classification_transform_val = T.Compose([
        T.Resize(256),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    classification_train_dir = os.path.join(classification_data_dir, 'train')
    # Two views over the same folder. random_split on a single dataset would
    # make the validation subset inherit the random training augmentations;
    # here validation uses the deterministic val transform instead.
    dataset = ClassificationDataset(root_dir=classification_train_dir, transform=classification_transform_train)
    dataset_val_view = ClassificationDataset(root_dir=classification_train_dir, transform=classification_transform_val)
    assert dataset.image_paths == dataset_val_view.image_paths, "both views must index the same files"
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_idx, val_idx = random_split(range(len(dataset)), [train_size, val_size], generator=torch.Generator().manual_seed(seed))
    train_dataset = Subset(dataset, train_idx.indices)
    val_dataset = Subset(dataset_val_view, val_idx.indices)
    train_loader_classification = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader_classification = DataLoader(val_dataset, batch_size=batch_size, num_workers=4)

    # --- 2. Initialize Classification Model and Train ---
    # `weights=` replaces the deprecated `pretrained=True` flag (torchvision >= 0.13);
    # IMAGENET1K_V1 is the checkpoint `pretrained=True` used to load.
    classification_model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1) # Or any other chosen architecture
    num_classes = len(dataset.classes)
    classification_model.fc = nn.Linear(classification_model.fc.in_features, num_classes)
    classification_model.to(device)

    criterion_classification = nn.CrossEntropyLoss()
    optimizer_classification = optim.AdamW(classification_model.parameters(), lr=learning_rate_classification, weight_decay=weight_decay)
    scheduler_classification = ReduceLROnPlateau(optimizer_classification, mode='min', patience=3, factor=0.1)

    print("Training classification model...")

    trained_classification_model = train_classification_model(
        classification_model, train_loader_classification, val_loader_classification,
        criterion_classification, optimizer_classification, scheduler_classification,
        num_epochs_classification, device
    )

    # Save the trained classification model weights
    torch.save(trained_classification_model.state_dict(), 'pretrained_classification_model.pth')

    print("Classification model training complete.")

    # --- 3. Segmentation DataLoaders ---
    segmentation_transform_train = None # Define your segmentation augmentations using torchvision or albumentations <--- CHANGE THIS
    segmentation_transform_val = None   # Define your segmentation validation augmentations <--- CHANGE THIS

    print("Preparing segmentation data...")

    # NOTE(review): SegmentationDataset scales masks by 1/255. That suits 0/255
    # binary masks, but CrossEntropyLoss with integer class-index masks needs
    # the labels unscaled. TODO confirm the mask encoding.
    # As for classification: one view per transform, same seeded split.
    masked_dataset = SegmentationDataset(
        img_dir=segmentation_img_dir,
        mask_dir=segmentation_mask_dir,
        transform=segmentation_transform_train # Apply augmentations here if using albumentations
    )
    masked_dataset_val_view = SegmentationDataset(
        img_dir=segmentation_img_dir,
        mask_dir=segmentation_mask_dir,
        transform=segmentation_transform_val
    )

    train_size_segmentation = int(0.8 * len(masked_dataset))
    val_size_segmentation = len(masked_dataset) - train_size_segmentation
    seg_train_idx, seg_val_idx = random_split(range(len(masked_dataset)), [train_size_segmentation, val_size_segmentation], generator=torch.Generator().manual_seed(seed))
    train_dataset_segmentation = Subset(masked_dataset, seg_train_idx.indices)
    val_dataset_segmentation = Subset(masked_dataset_val_view, seg_val_idx.indices)
    train_loader_segmentation = DataLoader(train_dataset_segmentation, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader_segmentation = DataLoader(val_dataset_segmentation, batch_size=batch_size, num_workers=4)

    # --- 4. Initialize Segmentation Model and Fine-tune ---
    # Example using segmentation_models_pytorch U-Net with ResNet50 encoder
    segmentation_model = smp.Unet(
        encoder_name="resnet50",
        encoder_weights=None, # We will load our pretrained weights
        in_channels=3,
        classes=num_segmentation_classes
    )

    print("Loading pretrained classification model weights...")

    # Load pretrained weights into the encoder. map_location lets a
    # GPU-saved checkpoint load anywhere; weights_only refuses arbitrary
    # pickled objects (torch >= 1.13).
    pretrained_dict = torch.load('pretrained_classification_model.pth', map_location='cpu', weights_only=True)
    model_dict = segmentation_model.encoder.state_dict()

    # Filter out unnecessary keys (e.g., fc layer weights)
    pretrained_dict_filtered = {k: v for k, v in pretrained_dict.items() if k in model_dict and v.size() == model_dict[k].size()}
    print(f"Transferred {len(pretrained_dict_filtered)}/{len(model_dict)} encoder tensors")
    model_dict.update(pretrained_dict_filtered)
    segmentation_model.encoder.load_state_dict(model_dict, strict=False)

    # Freeze encoder layers (optional, can also fine-tune with a smaller lr)
    for name, param in segmentation_model.encoder.named_parameters():
        param.requires_grad = False

    segmentation_model.to(device)

    # Define segmentation loss and optimizer
    criterion_segmentation = nn.CrossEntropyLoss() if num_segmentation_classes > 1 else nn.BCEWithLogitsLoss() # <--- ADJUST LOSS BASED ON TASK
    optimizer_segmentation = optim.AdamW(segmentation_model.parameters(), lr=learning_rate_segmentation, weight_decay=weight_decay)
    scheduler_segmentation = ReduceLROnPlateau(optimizer_segmentation, mode='min', patience=5, factor=0.1)

    print("Training segmentation model...")

    # Train the segmentation model
    trained_segmentation_model = train_segmentation_model(
        segmentation_model, train_loader_segmentation, val_loader_segmentation,
        criterion_segmentation, optimizer_segmentation, scheduler_segmentation,
        num_epochs_segmentation, device, num_segmentation_classes
    )

    # Save the trained segmentation model
    torch.save(trained_segmentation_model.state_dict(), 'trained_segmentation_model.pth')

    print("Training completed!")

Using device: cuda




Training classification model...
