In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os
from pathlib import Path
import sys
import time

# Print CUDA information at the start
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device count: {torch.cuda.device_count()}")
    print(f"CUDA device name: {torch.cuda.get_device_name(0)}")
    print(f"Current CUDA device: {torch.cuda.current_device()}")

# -------------------------
# Swin Transformer Block
# -------------------------
# A simplified transformer block that mimics a Swin block
# (Note: This version uses global attention rather than window-based attention.)
class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, num_heads):
        super(SwinTransformerBlock, self).__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )
        
    def forward(self, x):
        # x: (B, C, H, W) --> reshape to (B, N, C) where N = H*W
        B, C, H, W = x.shape
        x_flat = x.view(B, C, H * W).permute(0, 2, 1)  # (B, N, C)
        x_norm = self.norm1(x_flat)
        attn_out, _ = self.attn(x_norm, x_norm, x_norm)
        x_flat = x_flat + attn_out  # residual connection
        x_norm2 = self.norm2(x_flat)
        mlp_out = self.mlp(x_norm2)
        x_flat = x_flat + mlp_out   # residual connection
        # Reshape back to (B, C, H, W)
        x_out = x_flat.permute(0, 2, 1).view(B, C, H, W)
        return x_out

# -------------------------
# SwinNet Architecture
# -------------------------
# A U-Net style architecture where the encoder uses a patch embedding
# and several SwinTransformerBlocks, with a decoder that upsamples and
# uses skip connections.
class SwinUnet(nn.Module):
    def __init__(self, img_size=256, patch_size=4, in_chans=3, num_classes=4, embed_dim=96):
        super(SwinUnet, self).__init__()
        # Patch embedding (downsamples by patch_size)
        self.patch_embed = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        # Encoder stage 1
        self.encoder1 = nn.Sequential(
            SwinTransformerBlock(embed_dim, num_heads=3),
            SwinTransformerBlock(embed_dim, num_heads=3)
        )
        self.down1 = nn.Conv2d(embed_dim, embed_dim * 2, kernel_size=2, stride=2)
        # Encoder stage 2
        self.encoder2 = nn.Sequential(
            SwinTransformerBlock(embed_dim * 2, num_heads=6),
            SwinTransformerBlock(embed_dim * 2, num_heads=6)
        )
        self.down2 = nn.Conv2d(embed_dim * 2, embed_dim * 4, kernel_size=2, stride=2)
        # Encoder stage 3
        self.encoder3 = nn.Sequential(
            SwinTransformerBlock(embed_dim * 4, num_heads=12),
            SwinTransformerBlock(embed_dim * 4, num_heads=12)
        )
        self.down3 = nn.Conv2d(embed_dim * 4, embed_dim * 8, kernel_size=2, stride=2)
        # Bottleneck (Encoder stage 4)
        self.encoder4 = nn.Sequential(
            SwinTransformerBlock(embed_dim * 8, num_heads=24),
            SwinTransformerBlock(embed_dim * 8, num_heads=24)
        )
        # Decoder stage 3
        self.up3 = nn.ConvTranspose2d(embed_dim * 8, embed_dim * 4, kernel_size=2, stride=2)
        self.decoder3 = nn.Sequential(
            nn.Conv2d(embed_dim * 8, embed_dim * 4, kernel_size=3, padding=1),
            nn.BatchNorm2d(embed_dim * 4),
            nn.ReLU(inplace=True)
        )
        # Decoder stage 2
        self.up2 = nn.ConvTranspose2d(embed_dim * 4, embed_dim * 2, kernel_size=2, stride=2)
        self.decoder2 = nn.Sequential(
            nn.Conv2d(embed_dim * 4, embed_dim * 2, kernel_size=3, padding=1),
            nn.BatchNorm2d(embed_dim * 2),
            nn.ReLU(inplace=True)
        )
        # Decoder stage 1
        self.up1 = nn.ConvTranspose2d(embed_dim * 2, embed_dim, kernel_size=2, stride=2)
        self.decoder1 = nn.Sequential(
            nn.Conv2d(embed_dim * 2, embed_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(embed_dim),
            nn.ReLU(inplace=True)
        )
        self.out_conv = nn.Conv2d(embed_dim, num_classes, kernel_size=1)
        
    def forward(self, x):
        # Encoder
        x0 = self.patch_embed(x)  # [B, embed_dim, 256/patch_size, 256/patch_size]
        e1 = self.encoder1(x0)      # skip connection 1
        x1 = self.down1(e1)         # downsample to [B, embed_dim*2, ...]
        e2 = self.encoder2(x1)      # skip connection 2
        x2 = self.down2(e2)         # downsample to [B, embed_dim*4, ...]
        e3 = self.encoder3(x2)      # skip connection 3
        x3 = self.down3(e3)         # downsample to [B, embed_dim*8, ...]
        e4 = self.encoder4(x3)      # bottleneck
        
        # Decoder
        d3 = self.up3(e4)           # upsample to match e3 spatial size
        d3 = torch.cat([d3, e3], dim=1)
        d3 = self.decoder3(d3)
        d2 = self.up2(d3)           # upsample to match e2 spatial size
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.decoder2(d2)
        d1 = self.up1(d2)           # upsample to match e1 spatial size
        d1 = torch.cat([d1, e1], dim=1)
        d1 = self.decoder1(d1)
        out = self.out_conv(d1)
        return out

# -------------------------
# Dataset Definition
# -------------------------
class RetinalMultiClassDataset(Dataset):
    def __init__(self, 
                 image_dir, 
                 haemorrhages_mask_dir, 
                 hard_exudates_mask_dir, 
                 microaneurysm_mask_dir,
                 transform=None):
        self.image_dir = Path(image_dir)
        self.haemorrhages_mask_dir = Path(haemorrhages_mask_dir)
        self.hard_exudates_mask_dir = Path(hard_exudates_mask_dir)
        self.microaneurysm_mask_dir = Path(microaneurysm_mask_dir)
        self.transform = transform
        
        for dir_path in [self.image_dir, self.haemorrhages_mask_dir, 
                         self.hard_exudates_mask_dir, self.microaneurysm_mask_dir]:
            if not dir_path.exists():
                print(f"WARNING: Directory does not exist: {dir_path}")
                print(f"Current working directory: {os.getcwd()}")
                print(f"Available directories: {os.listdir('.')}")
        
        try:
            self.images = sorted([f for f in os.listdir(image_dir) 
                                  if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
            print(f"Found {len(self.images)} images in {image_dir}")
            if len(self.images) == 0:
                print(f"WARNING: No images found in {image_dir}")
        except FileNotFoundError:
            print(f"ERROR: Directory not found: {image_dir}")
            self.images = []
        
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        try:
            img_name = self.images[idx]
            img_path = str(self.image_dir / img_name)
            if not os.path.exists(img_path):
                print(f"WARNING: Image not found: {img_path}")
                return torch.zeros((3, 256, 256)), torch.zeros((256, 256), dtype=torch.long)
            
            image = Image.open(img_path).convert('RGB')
            # Load masks (assumes filenames match)
            haemorrhage_mask_path = str(self.haemorrhages_mask_dir / img_name)
            hard_exudate_mask_path = str(self.hard_exudates_mask_dir / img_name)
            microaneurysm_mask_path = str(self.microaneurysm_mask_dir / img_name)
            
            if not os.path.exists(haemorrhage_mask_path):
                print(f"WARNING: Haemorrhage mask not found: {haemorrhage_mask_path}")
                haemorrhage_mask = Image.new('L', image.size, 0)
            else:
                haemorrhage_mask = Image.open(haemorrhage_mask_path).convert('L')
            
            if not os.path.exists(hard_exudate_mask_path):
                print(f"WARNING: Hard exudate mask not found: {hard_exudate_mask_path}")
                hard_exudate_mask = Image.new('L', image.size, 0)
            else:
                hard_exudate_mask = Image.open(hard_exudate_mask_path).convert('L')
            
            if not os.path.exists(microaneurysm_mask_path):
                print(f"WARNING: Microaneurysm mask not found: {microaneurysm_mask_path}")
                microaneurysm_mask = Image.new('L', image.size, 0)
            else:
                microaneurysm_mask = Image.open(microaneurysm_mask_path).convert('L')
            
            transform = transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
            ])
            image = transform(image)
            haemorrhage_mask = transform(haemorrhage_mask) > 0.5
            hard_exudate_mask = transform(hard_exudate_mask) > 0.5
            microaneurysm_mask = transform(microaneurysm_mask) > 0.5
            
            multi_class_mask = torch.zeros((1, 256, 256), dtype=torch.long)
            multi_class_mask[haemorrhage_mask] = 1
            multi_class_mask[hard_exudate_mask] = 2
            multi_class_mask[microaneurysm_mask] = 3
            
            return image, multi_class_mask.squeeze(0)
        
        except Exception as e:
            print(f"ERROR in __getitem__ for index {idx}, image {self.images[idx] if idx < len(self.images) else 'invalid index'}: {str(e)}")
            return torch.zeros((3, 256, 256)), torch.zeros((256, 256), dtype=torch.long)

# -------------------------
# Segmentation Training Class
# -------------------------
# Uses CrossEntropyLoss for multi-class segmentation.
class RetinalSegmentation:
    def __init__(self, n_classes=4):
        if torch.cuda.is_available():
            try:
                self.device = torch.device('cuda')
                test_tensor = torch.zeros(1, device=self.device)
                del test_tensor
                print("Successfully initialized CUDA device")
            except RuntimeError as e:
                print(f"CUDA error: {e}")
                print("Falling back to CPU")
                self.device = torch.device('cpu')
        elif torch.backends.mps.is_available():
            try:
                self.device = torch.device('mps')
                print("Using MPS (Metal Performance Shaders) device")
            except Exception:
                print("MPS initialization failed, falling back to CPU")
                self.device = torch.device('cpu')
        else:
            self.device = torch.device('cpu')
            print("Using CPU device")
        
        try:
            # Initialize the SwinUnet model
            self.model = SwinUnet(img_size=256, patch_size=4, in_chans=3, num_classes=n_classes, embed_dim=96).to(self.device)
            print(f"Model initialized and moved to {self.device}")
        except Exception as e:
            print(f"Error initializing model: {e}")
            sys.exit(1)
            
        self.n_classes = n_classes
        
    def train(self, train_loader, num_epochs=10):
        if len(train_loader) == 0:
            print("ERROR: DataLoader is empty. Cannot train on empty dataset.")
            return

        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
        
        print(f"Training on {self.device}")
        print(f"Training with {len(train_loader)} batches per epoch")
        
        epoch_losses = []
        for epoch in range(num_epochs):
            self.model.train()
            running_loss = 0.0
            start_time = time.time()
            print(f"Starting epoch {epoch+1}/{num_epochs}")
            
            for batch_idx, (images, masks) in enumerate(train_loader):
                if batch_idx % 5 == 0:
                    print(f"  Processing batch {batch_idx+1}/{len(train_loader)}")
                try:
                    images = images.to(self.device)
                    masks = masks.to(self.device)
                    
                    if torch.isnan(images).any() or torch.isnan(masks).any():
                        print(f"WARNING: NaN values detected in input data (batch {batch_idx+1})")
                        continue
                    
                    optimizer.zero_grad()
                    outputs = self.model(images)
                    if batch_idx == 0 and epoch == 0:
                        print(f"Input shape: {images.shape}")
                        print(f"Output shape: {outputs.shape}")
                        print(f"Target shape: {masks.shape}")
                    
                    loss = criterion(outputs, masks)
                    if torch.isnan(loss):
                        print(f"WARNING: NaN loss detected in batch {batch_idx+1}")
                        continue
                    
                    loss.backward()
                    optimizer.step()
                    running_loss += loss.item()
                    
                    if self.device.type == 'cuda':
                        torch.cuda.synchronize()
                except Exception as e:
                    print(f"ERROR in training loop (batch {batch_idx+1}): {str(e)}")
                    continue
            
            avg_loss = running_loss / len(train_loader)
            epoch_losses.append(avg_loss)
            epoch_time = time.time() - start_time
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}, Time: {epoch_time:.2f}s")
            
            if epoch % 2 == 0:
                try:
                    self.evaluate_sample(train_loader)
                except Exception as e:
                    print(f"Error in evaluation: {e}")
        
        try:
            plt.figure()
            plt.plot(range(1, num_epochs + 1), epoch_losses, marker='o')
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.title('Training Loss Over Epochs')
            plt.grid(True)
            plt.savefig('training_loss.png')
            print("Training loss plot saved to 'training_loss.png'")
        except Exception as e:
            print(f"Error plotting training loss: {e}")
    
    def predict(self, image):
        self.model.eval()
        with torch.no_grad():
            try:
                image = image.to(self.device)
                output = self.model(image.unsqueeze(0))
                probabilities = F.softmax(output, dim=1)
                predicted_mask = torch.argmax(probabilities, dim=1)
                return predicted_mask.squeeze().cpu().numpy()
            except Exception as e:
                print(f"Error in prediction: {e}")
                return np.zeros((256, 256), dtype=np.int64)
    
    def evaluate_sample(self, dataloader):
        try:
            images, masks = next(iter(dataloader))
            image = images[0].to(self.device)
            mask = masks[0].cpu().numpy()
            pred_mask = self.predict(image)
            
            fig, axs = plt.subplots(1, 3, figsize=(15, 5))
            axs[0].imshow(image.permute(1, 2, 0).cpu().numpy())
            axs[0].set_title('Original Image')
            axs[0].axis('off')
            colors = ['black', 'red', 'yellow', 'green']
            cmap = plt.matplotlib.colors.ListedColormap(colors)
            axs[1].imshow(mask, cmap=cmap, vmin=0, vmax=3)
            axs[1].set_title('Ground Truth')
            axs[1].axis('off')
            axs[2].imshow(pred_mask, cmap=cmap, vmin=0, vmax=3)
            axs[2].set_title('Prediction')
            axs[2].axis('off')
            plt.tight_layout()
            plt.savefig(f"sample_evaluation_{time.strftime('%Y%m%d_%H%M%S')}.png")
            print("Evaluation sample saved as image file")
        except Exception as e:
            print(f"Error in evaluation: {e}")

def main():
    print("\n" + "="*50)
    print("Starting retinal segmentation training script using SwinNet architecture")
    print("="*50 + "\n")
    
    # Update these directory paths to match your actual data locations.
    image_dir = 'C:/Second_Sem/490/train_images'
    haemorrhages_mask_dir = 'C:/Second_Sem/490/APTOS 2019 Blindness Detection Segmented/Haemorrhages/train_images'
    hard_exudates_mask_dir = 'C:/Second_Sem/490/APTOS 2019 Blindness Detection Segmented/Hard Exudates/train_images'
    microaneurysm_mask_dir = 'C:/Second_Sem/490/APTOS 2019 Blindness Detection Segmented/Microaneurysm/train_images'
    
    for dir_path in [image_dir, haemorrhages_mask_dir, hard_exudates_mask_dir, microaneurysm_mask_dir]:
        if not os.path.exists(dir_path):
            print(f"WARNING: Directory does not exist: {dir_path}")
    
    print("Creating dataset...")
    dataset = RetinalMultiClassDataset(
        image_dir=image_dir,
        haemorrhages_mask_dir=haemorrhages_mask_dir,
        hard_exudates_mask_dir=hard_exudates_mask_dir,
        microaneurysm_mask_dir=microaneurysm_mask_dir
    )
    
    if len(dataset) == 0:
        print("ERROR: Dataset is empty. Please check your directory paths and image files.")
        return
    
    print(f"Dataset created with {len(dataset)} samples")
    print("Verifying dataset by loading first sample...")
    try:
        sample_img, sample_mask = dataset[0]
        print(f"Sample image shape: {sample_img.shape}")
        print(f"Sample mask shape: {sample_mask.shape}")
    except Exception as e:
        print(f"ERROR: Failed to load sample from dataset: {e}")
    
    print("Creating data loader...")
    train_loader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        num_workers=0,  # Set to >0 for production runs
        pin_memory=True if torch.cuda.is_available() else False
    )
    
    print("Initializing segmentation model...")
    segmentation = RetinalSegmentation(n_classes=4)
    
    print("Starting training...")
    segmentation.train(train_loader, num_epochs=20)
    
    try:
        torch.save(segmentation.model.state_dict(), 'retinal_segmentation_model_swin.pth')
        print("Model saved successfully to 'retinal_segmentation_model_swin.pth'")
    except Exception as e:
        print(f"ERROR: Failed to save model: {e}")
    
    print("\n" + "="*50)
    print("Training completed")
    print("="*50 + "\n")
    
if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        print(f"Unhandled exception in main: {e}")
        import traceback
        traceback.print_exc()


PyTorch version: 2.5.1+cu121
CUDA available: True
CUDA device count: 1
CUDA device name: NVIDIA GeForce RTX 4070 Ti SUPER
Current CUDA device: 0

Starting retinal segmentation training script using SwinNet architecture

Creating dataset...
Found 3662 images in C:/Second_Sem/490/train_images
Dataset created with 3662 samples
Verifying dataset by loading first sample...
Sample image shape: torch.Size([3, 256, 256])
Sample mask shape: torch.Size([256, 256])
Creating data loader...
Initializing segmentation model...
Successfully initialized CUDA device
Model initialized and moved to cuda
Starting training...
Training on cuda
Training with 916 batches per epoch
Starting epoch 1/20
  Processing batch 1/916
Input shape: torch.Size([4, 3, 256, 256])
Output shape: torch.Size([4, 4, 64, 64])
Target shape: torch.Size([4, 256, 256])
ERROR in training loop (batch 1): input and target batch or spatial sizes don't match: target [4, 256, 256], input [4, 4, 64, 64]
ERROR in training loop (batch 2): inp

KeyboardInterrupt: 