# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os
from pathlib import Path

class DoubleConv(nn.Module):
    """Two back-to-back (3x3 convolution -> batch norm -> ReLU) stages.

    Both convolutions use ``padding=1`` with a 3x3 kernel, so the spatial
    size of the input is preserved; only the channel count changes from
    ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # The first stage maps in_channels -> out_channels; the second keeps
        # the width at out_channels.
        for stage_in in (in_channels, out_channels):
            stages.append(nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv/norm/activation stages."""
        out = self.double_conv(x)
        return out

class UNet(nn.Module):
    """U-Net encoder/decoder that produces per-pixel class logits.

    The encoder applies four 2x max-pool stages (64 -> 128 -> 256 -> 512 ->
    1024 channels); the decoder mirrors it with transposed convolutions and
    concatenates the matching encoder feature map (skip connection) at each
    level.

    Args:
        n_channels: Number of channels in the input image.
        n_classes: Number of output classes. The default of 4 corresponds to
            background, haemorrhages, hard exudates and microaneurysms.

    The output has shape ``(batch, n_classes, H, W)`` for an input of shape
    ``(batch, n_channels, H, W)``. ``H`` and ``W`` do not need to be
    multiples of 16: upsampled maps are zero-padded to the size of their skip
    connection. Each side must still be at least 16 pixels so that the four
    poolings do not reduce it to zero.
    """

    def __init__(self, n_channels=3, n_classes=4):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        # Encoder
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(64, 128)
        )
        self.down2 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(128, 256)
        )
        self.down3 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(256, 512)
        )
        self.down4 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(512, 1024)
        )

        # Decoder: each transposed conv halves the channels and doubles the
        # spatial size; the following DoubleConv consumes the concatenation
        # of that result with the skip connection (hence the doubled input
        # channel count).
        self.up1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.up_conv1 = DoubleConv(1024, 512)

        self.up2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.up_conv2 = DoubleConv(512, 256)

        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.up_conv3 = DoubleConv(256, 128)

        self.up4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.up_conv4 = DoubleConv(128, 64)

        # 1x1 convolution mapping the 64 decoder features to class logits.
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

    @staticmethod
    def _pad_to_skip(upsampled, skip):
        """Zero-pad ``upsampled`` so its height/width equal those of ``skip``.

        MaxPool2d(2) floors odd sizes, so doubling in the decoder can come up
        one pixel short of the encoder map. Without padding, ``torch.cat``
        would fail on such inputs. When the sizes already match the tensor is
        returned untouched.
        """
        pad_h = skip.size(-2) - upsampled.size(-2)
        pad_w = skip.size(-1) - upsampled.size(-1)
        if pad_h == 0 and pad_w == 0:
            return upsampled
        # F.pad takes (left, right, top, bottom) for the last two dimensions.
        return F.pad(
            upsampled,
            [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],
        )

    def forward(self, x):
        """Return logits of shape ``(batch, n_classes, H, W)`` for input ``x``."""
        # Encoder
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        # Decoder: upsample, align with the skip connection, concatenate
        # along the channel dimension, then refine.
        x = self._pad_to_skip(self.up1(x5), x4)
        x = self.up_conv1(torch.cat([x4, x], dim=1))

        x = self._pad_to_skip(self.up2(x), x3)
        x = self.up_conv2(torch.cat([x3, x], dim=1))

        x = self._pad_to_skip(self.up3(x), x2)
        x = self.up_conv3(torch.cat([x2, x], dim=1))

        x = self._pad_to_skip(self.up4(x), x1)
        x = self.up_conv4(torch.cat([x1, x], dim=1))

        logits = self.outc(x)
        return logits

class RetinalMultiClassDataset(Dataset):
    """Retinal images paired with one multi-class lesion mask per image.

    Three binary mask directories (haemorrhages, hard exudates,
    microaneurysms) are merged into a single label map with values:

        0: background
        1: haemorrhages
        2: hard exudates
        3: microaneurysms

    Mask files are looked up under the same filename as the image.

    Args:
        image_dir: Directory containing the input images (.png/.jpg/.jpeg,
            matched case-insensitively).
        haemorrhages_mask_dir: Directory with haemorrhage masks.
        hard_exudates_mask_dir: Directory with hard-exudate masks.
        microaneurysm_mask_dir: Directory with microaneurysm masks.
        transform: Optional callable applied to the RGB PIL image *instead
            of* the default resize + ``ToTensor``. It must return a tensor
            whose last two dimensions equal ``image_size``. Masks never go
            through it, so it must not apply random geometric changes.
        image_size: ``(height, width)`` that images and masks are resized to.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self,
                 image_dir,
                 haemorrhages_mask_dir,
                 hard_exudates_mask_dir,
                 microaneurysm_mask_dir,
                 transform=None,
                 image_size=(256, 256)):
        self.image_dir = Path(image_dir)
        self.haemorrhages_mask_dir = Path(haemorrhages_mask_dir)
        self.hard_exudates_mask_dir = Path(hard_exudates_mask_dir)
        self.microaneurysm_mask_dir = Path(microaneurysm_mask_dir)
        self.transform = transform
        self.image_size = tuple(image_size)

        # Sorted for a deterministic index -> file mapping. The extension
        # check is case-insensitive so files such as "IMG.JPG" are kept.
        self.images = sorted(
            entry.name
            for entry in self.image_dir.iterdir()
            if entry.is_file()
            and entry.name.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _load(path, mode):
        """Open ``path`` and return a decoded copy converted to ``mode``."""
        # convert() yields an independent in-memory image, so the underlying
        # file can be closed as soon as the with-block exits.
        with Image.open(path) as img:
            return img.convert(mode)

    def __getitem__(self, idx):
        """Return ``(image, mask)``; ``mask`` is a long tensor of shape ``image_size``."""
        img_name = self.images[idx]

        resize_to_tensor = transforms.Compose([
            transforms.Resize(self.image_size),
            transforms.ToTensor(),
        ])

        rgb = self._load(self.image_dir / img_name, 'RGB')
        if self.transform is None:
            image = resize_to_tensor(rgb)
        else:
            image = self.transform(rgb)
            if tuple(image.shape[-2:]) != self.image_size:
                raise ValueError(
                    f"transform produced spatial size {tuple(image.shape[-2:])}, "
                    f"expected {self.image_size} to match the masks"
                )

        multi_class_mask = torch.zeros(self.image_size, dtype=torch.long)

        # Later entries overwrite earlier ones, so where lesions overlap the
        # priority is microaneurysm > hard exudate > haemorrhage.
        mask_dirs = (
            self.haemorrhages_mask_dir,
            self.hard_exudates_mask_dir,
            self.microaneurysm_mask_dir,
        )
        for class_id, mask_dir in enumerate(mask_dirs, start=1):
            # ToTensor scales grayscale values to [0, 1]; threshold to binary.
            lesion = resize_to_tensor(self._load(mask_dir / img_name, 'L')) > 0.5
            multi_class_mask[lesion.squeeze(0)] = class_id

        return image, multi_class_mask

class RetinalSegmentation:
    """Train a UNet for multi-class retinal lesion segmentation and inspect it.

    Args:
        n_classes: Number of output classes of the underlying UNet.

    Attributes are ``device`` (CUDA if available, else MPS, else CPU),
    ``model`` (the UNet moved to ``device``) and ``n_classes``.
    """

    def __init__(self, n_classes=4):
        self.device = self._select_device()
        self.model = UNet(n_channels=3, n_classes=n_classes).to(self.device)
        self.n_classes = n_classes

    @staticmethod
    def _select_device():
        """Pick CUDA, then MPS, then CPU."""
        if torch.cuda.is_available():
            return torch.device('cuda')
        # torch.backends.mps is missing on older PyTorch builds; look it up
        # defensively instead of raising AttributeError.
        mps = getattr(torch.backends, 'mps', None)
        if mps is not None and mps.is_available():
            return torch.device('mps')
        return torch.device('cpu')

    def train(self, train_loader, num_epochs=10, visualize_every=2):
        """Train the model with cross-entropy loss and Adam (lr=0.001).

        Args:
            train_loader: Iterable yielding ``(images, masks)`` batches, where
                ``masks`` holds integer class indices per pixel.
            num_epochs: Number of passes over ``train_loader``.
            visualize_every: Plot a sample prediction every this many epochs,
                starting with the first. ``0`` or ``None`` disables plotting
                (useful on machines without a display).

        Returns:
            List with the mean batch loss of each epoch.

        Raises:
            ValueError: If ``train_loader`` yields no batches.
        """
        # Use Cross Entropy Loss for multi-class segmentation
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)

        print(f"Training on {self.device}")

        history = []
        for epoch in range(num_epochs):
            # predict() leaves the model in eval mode, so re-enable training
            # behaviour at the start of every epoch.
            self.model.train()
            running_loss = 0.0
            num_batches = 0

            for images, masks in train_loader:
                images = images.to(self.device)
                masks = masks.to(self.device)

                optimizer.zero_grad()
                outputs = self.model(images)  # [B, C, H, W]

                loss = criterion(outputs, masks)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                num_batches += 1

            # Counting batches (rather than len(train_loader)) also supports
            # loaders without __len__ and lets us reject empty input clearly.
            if num_batches == 0:
                raise ValueError("train_loader yielded no batches; nothing to train on")

            epoch_loss = running_loss / num_batches
            history.append(epoch_loss)
            print(f'Epoch {epoch+1}, Loss: {epoch_loss:.4f}')

            # Visual sanity check on a training sample (not a held-out
            # validation set) every `visualize_every` epochs.
            if visualize_every and epoch % visualize_every == 0:
                self.evaluate_sample(train_loader)

        return history

    def predict(self, image):
        """Return the predicted class map for a single ``(C, H, W)`` image.

        The model is switched to eval mode. The result is a NumPy array of
        shape ``(H, W)`` containing class indices.
        """
        self.model.eval()
        with torch.no_grad():
            image = image.to(self.device)
            output = self.model(image.unsqueeze(0))
            # Softmax is monotonic, so argmax over the raw logits selects the
            # same class without the extra computation.
            predicted_mask = torch.argmax(output, dim=1)
            # Drop only the batch dimension so H or W equal to 1 survives.
            return predicted_mask.squeeze(0).cpu().numpy()

    def evaluate_sample(self, dataloader):
        """Plot image, ground truth and prediction for the first sample of a batch."""
        # Get a sample
        images, masks = next(iter(dataloader))
        image = images[0].to(self.device)
        mask = masks[0].cpu().numpy()

        # Generate prediction
        pred_mask = self.predict(image)

        # Visualize
        fig, axs = plt.subplots(1, 3, figsize=(15, 5))

        # Original image (matplotlib expects channels last)
        axs[0].imshow(image.permute(1, 2, 0).cpu().numpy())
        axs[0].set_title('Original Image')
        axs[0].axis('off')

        # One colour per class index 0..3; with more than four classes the
        # higher indices are clipped to the last colour.
        colors = ['black', 'red', 'yellow', 'green']
        cmap = plt.matplotlib.colors.ListedColormap(colors)
        axs[1].imshow(mask, cmap=cmap, vmin=0, vmax=3)
        axs[1].set_title('Ground Truth')
        axs[1].axis('off')

        # Predicted mask
        axs[2].imshow(pred_mask, cmap=cmap, vmin=0, vmax=3)
        axs[2].set_title('Prediction')
        axs[2].axis('off')

        plt.tight_layout()
        plt.show()
        # Release the figure so repeated calls during training do not
        # accumulate open figures.
        plt.close(fig)

# Example usage
def main():
    """Build the dataset and loader, train for 20 epochs, then save the weights."""
    # Directory arguments in order: images, haemorrhage masks, hard-exudate
    # masks, microaneurysm masks.
    retinal_data = RetinalMultiClassDataset(
        'path/to/images',
        'path/to/haemorrhages_masks',
        'path/to/hard_exudates_masks',
        'path/to/microaneurysm_masks',
    )

    loader_options = {'batch_size': 4, 'shuffle': True, 'num_workers': 4}
    batches = DataLoader(retinal_data, **loader_options)

    # Initialize and train the model
    trainer = RetinalSegmentation(n_classes=4)
    trainer.train(batches, num_epochs=20)

    # Persist only the learned parameters, not the whole module object.
    weights = trainer.model.state_dict()
    torch.save(weights, 'retinal_segmentation_model.pth')
    
# Run the example pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()