In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
from PIL import Image
import pandas as pd
import os
import time
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
import numpy as np
import seaborn as sns

class RetinalDataset(Dataset):
    def __init__(self, csv_file, img_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            img_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.data_frame = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform if transform is not None else transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.img_dir, self.data_frame.iloc[idx]['id_code'] + '.png')
        image = Image.open(img_name).convert('RGB')
        diagnosis = self.data_frame.iloc[idx]['diagnosis']
        
        if self.transform:
            image = self.transform(image)
            
        return image, torch.tensor(diagnosis, dtype=torch.long)

class RetinalEfficientNetClassifier:
    def __init__(self, num_classes=5, model_name='efficientnet_b0'):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        
        # Load pre-trained EfficientNet
        if model_name == 'efficientnet_b0':
            self.model = models.efficientnet_b0(weights='DEFAULT')
        elif model_name == 'efficientnet_b1':
            self.model = models.efficientnet_b1(weights='DEFAULT')
        elif model_name == 'efficientnet_b2':
            self.model = models.efficientnet_b2(weights='DEFAULT')
        else:
            raise ValueError(f"Unsupported model name: {model_name}")
        
        # Replace the final fully connected layer
        num_features = self.model.classifier[1].in_features
        self.model.classifier[1] = nn.Linear(num_features, num_classes)
        
        # Move model to device
        self.model = self.model.to(self.device)
        
        # Define transformation pipeline
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])
        ])
        
        # Track metrics
        self.train_losses = []
        self.train_accuracies = []
        self.val_losses = []
        self.val_accuracies = []

    def train(self, train_loader, val_loader=None, num_epochs=10, learning_rate=0.0001):
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=learning_rate, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
        
        best_val_accuracy = 0.0
        start_time = time.time()
        
        for epoch in range(num_epochs):
            epoch_start = time.time()
            # Training phase
            self.model.train()
            running_loss = 0.0
            correct = 0
            total = 0
            
            for inputs, labels in train_loader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                
                optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                
                running_loss += loss.item()
                
                # Calculate accuracy
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
            
            epoch_loss = running_loss / len(train_loader)
            epoch_accuracy = 100 * correct / total
            self.train_losses.append(epoch_loss)
            self.train_accuracies.append(epoch_accuracy)
            
            # Validation phase
            if val_loader:
                val_loss, val_accuracy = self.evaluate(val_loader, criterion)
                self.val_losses.append(val_loss)
                self.val_accuracies.append(val_accuracy)
                
                # Save best model
                if val_accuracy > best_val_accuracy:
                    best_val_accuracy = val_accuracy
                    torch.save(self.model.state_dict(), 'best_efficientnet_model.pth')
                
                epoch_time = time.time() - epoch_start
                print(f'Epoch {epoch+1}/{num_epochs}, '
                      f'Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_accuracy:.2f}%, '
                      f'Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2f}%, '
                      f'Time: {epoch_time:.2f}s')
            else:
                epoch_time = time.time() - epoch_start
                print(f'Epoch {epoch+1}/{num_epochs}, '
                      f'Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_accuracy:.2f}%, '
                      f'Time: {epoch_time:.2f}s')
            
            scheduler.step()
        
        total_time = time.time() - start_time
        print(f'Training completed in {total_time:.2f} seconds')
        
        # Plot training curves
        if len(self.train_losses) > 1:
            self.plot_training_curves()
            
        
        return best_val_accuracy

    def evaluate(self, test_loader, criterion=None):
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0
        
        if criterion is None:
            criterion = nn.CrossEntropyLoss()
        
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                outputs = self.model(inputs)
                
                loss = criterion(outputs, labels)
                running_loss += loss.item()
                
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        test_loss = running_loss / len(test_loader)
        test_accuracy = 100 * correct / total
        
        return test_loss, test_accuracy
    
    def get_all_predictions(self, data_loader):
        self.model.eval()
        all_predictions = []
        all_labels = []
        
        with torch.no_grad():
            for inputs, labels in data_loader:
                inputs = inputs.to(self.device)
                outputs = self.model(inputs)
                _, predicted = torch.max(outputs, 1)
                
                all_predictions.extend(predicted.cpu().numpy())
                all_labels.extend(labels.numpy())
        
        return np.array(all_predictions), np.array(all_labels)
    
    def plot_confusion_matrix(self, test_loader, class_names=None):
        predictions, true_labels = self.get_all_predictions(test_loader)
        cm = confusion_matrix(true_labels, predictions)
        
        if class_names is None:
            class_names = [str(i) for i in range(cm.shape[0])]
        
        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.title('Confusion Matrix')
        plt.savefig('confusion_matrix.png')
        plt.close()
        
        print(classification_report(true_labels, predictions, target_names=class_names))

    def plot_training_curves(self):
        plt.figure(figsize=(12, 5))
        
        plt.subplot(1, 2, 1)
        plt.plot(self.train_losses, label='Train Loss')
        if self.val_losses:
            plt.plot(self.val_losses, label='Validation Loss')
        plt.title('Loss Curves')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        
        plt.subplot(1, 2, 2)
        plt.plot(self.train_accuracies, label='Train Accuracy')
        if self.val_accuracies:
            plt.plot(self.val_accuracies, label='Validation Accuracy')
        plt.title('Accuracy Curves')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy (%)')
        plt.legend()
        
        plt.tight_layout()
        plt.savefig('training_curves.png')
        plt.close()

    def predict(self, image_path):
        self.model.eval()
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image).unsqueeze(0).to(self.device)
        
        with torch.no_grad():
            outputs = self.model(image)
            _, predicted = torch.max(outputs, 1)
            
        return predicted.item()

# Run experiment
if __name__ == "__main__":
    # Define data augmentation for training
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    # Simple transforms for validation/test
    test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    # Load dataset
    full_dataset = RetinalDataset(
        csv_file='train.csv', 
        img_dir='train_images',
        transform=None  # We'll apply transforms separately to each split
    )

    # Split dataset (70% train, 15% validation, 15% test)
    train_size = int(0.7 * len(full_dataset))
    val_size = int(0.15 * len(full_dataset))
    test_size = len(full_dataset) - train_size - val_size
    
    train_dataset, val_dataset, test_dataset = random_split(
        full_dataset, [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(42)  # For reproducibility
    )
    
    # Apply appropriate transforms to each dataset
    train_dataset.dataset.transform = train_transform
    val_dataset.dataset.transform = test_transform
    test_dataset.dataset.transform = test_transform

    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

    # Set class names
    class_names = ['No DR', 'Mild DR', 'Moderate DR', 'Severe DR', 'Proliferative DR']

    # Initialize and train the classifier
    classifier = RetinalEfficientNetClassifier(num_classes=5, model_name='efficientnet_b0')
    
    # Train the model
    best_val_accuracy = classifier.train(train_loader, val_loader, num_epochs=15, learning_rate=3e-4)
    
    # Load best model
    classifier.model.load_state_dict(torch.load('best_efficientnet_model.pth'))
    
    # Evaluate on test set
    test_loss, test_accuracy = classifier.evaluate(test_loader)
    print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")
    
    # Generate confusion matrix
    classifier.plot_confusion_matrix(test_loader, class_names)
    
    print(f"Best validation accuracy: {best_val_accuracy:.2f}%")
    print(f"Final test accuracy: {test_accuracy:.2f}%")
    print("Model saved as 'best_efficientnet_model.pth'")
    print("Training curves saved as 'training_curves.png'")
    print("Confusion matrix saved as 'confusion_matrix.png'")

Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth" to C:\Users\saadu/.cache\torch\hub\checkpoints\efficientnet_b0_rwightman-7f5810bc.pth


Using device: cuda


100%|██████████| 20.5M/20.5M [00:00<00:00, 80.7MB/s]


Epoch 1/15, Train Loss: 0.7048, Train Acc: 74.37%, Val Loss: 0.5401, Val Acc: 79.96%, Time: 188.95s
Epoch 2/15, Train Loss: 0.4690, Train Acc: 82.40%, Val Loss: 0.5011, Val Acc: 80.69%, Time: 189.17s
Epoch 3/15, Train Loss: 0.3639, Train Acc: 86.58%, Val Loss: 0.5146, Val Acc: 81.06%, Time: 189.20s
Epoch 4/15, Train Loss: 0.2607, Train Acc: 90.75%, Val Loss: 0.5765, Val Acc: 78.51%, Time: 189.32s
Epoch 5/15, Train Loss: 0.1872, Train Acc: 93.84%, Val Loss: 0.6645, Val Acc: 79.60%, Time: 190.38s
Epoch 6/15, Train Loss: 0.1396, Train Acc: 95.16%, Val Loss: 0.7104, Val Acc: 81.06%, Time: 189.21s
Epoch 7/15, Train Loss: 0.0867, Train Acc: 97.35%, Val Loss: 0.7513, Val Acc: 80.33%, Time: 189.20s
Epoch 8/15, Train Loss: 0.0717, Train Acc: 98.05%, Val Loss: 0.7590, Val Acc: 83.61%, Time: 189.54s
Epoch 9/15, Train Loss: 0.0697, Train Acc: 98.48%, Val Loss: 0.7291, Val Acc: 82.70%, Time: 189.85s
Epoch 10/15, Train Loss: 0.0466, Train Acc: 98.44%, Val Loss: 0.8032, Val Acc: 81.60%, Time: 189.72s

  classifier.model.load_state_dict(torch.load('best_efficientnet_model.pth'))


Test Loss: 0.7255, Test Accuracy: 84.00%
                  precision    recall  f1-score   support

           No DR       0.97      1.00      0.99       274
         Mild DR       0.78      0.51      0.61        69
     Moderate DR       0.70      0.87      0.78       142
       Severe DR       0.54      0.56      0.55        27
Proliferative DR       0.71      0.39      0.51        38

        accuracy                           0.84       550
       macro avg       0.74      0.67      0.69       550
    weighted avg       0.84      0.84      0.83       550

Best validation accuracy: 84.15%
Final test accuracy: 84.00%
Model saved as 'best_efficientnet_model.pth'
Training curves saved as 'training_curves.png'
Confusion matrix saved as 'confusion_matrix.png'
