## Step 1: Install Required Packages

In [None]:
!pip install torch torchvision numpy matplotlib seaborn scikit-learn Pillow tqdm -q

## Step 2: Import Libraries

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision import transforms, models
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
import os
import time
import copy
from tqdm import tqdm

# Pick the compute device once; later cells move the model and batches onto it
_use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if _use_cuda else 'cpu')
print(f"Using device: {device}")

## Step 3: Configuration

In [None]:
# Dataset path. Set the AMD_DATASET_ROOT environment variable to point elsewhere;
# the literal below is only the original author's local default.
DATASET_ROOT = os.environ.get(
    "AMD_DATASET_ROOT",
    r"D:\AMDNet23 Fundus Image Dataset for  Age-Related Macular Degeneration Disease Detection\AMDNet23 Fundus Image Dataset for  Age-Related Macular Degeneration Disease Detection\AMDNet23 Dataset",
)
TRAIN_DIR = os.path.join(DATASET_ROOT, "train")
VALID_DIR = os.path.join(DATASET_ROOT, "valid")

# Hyperparameters
IMG_SIZE = 224
BATCH_SIZE = 32
EPOCHS = 30
LEARNING_RATE = 0.0001
# Class folder names; a name's index in this list is the integer label used for training
CLASS_NAMES = ['amd', 'cataract', 'diabetes', 'normal']
NUM_CLASSES = len(CLASS_NAMES)  # derived so it can never disagree with CLASS_NAMES

# Seed the RNGs behind loader shuffling, augmentation and new-layer initialisation
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

print(f"Train directory: {TRAIN_DIR}")
print(f"Valid directory: {VALID_DIR}")

## Step 4: Create Dataset Class

In [None]:
class AMDDataset(Dataset):
    """Image-folder dataset laid out as ``root_dir/<class_name>/<image file>``.

    A class name's position in ``class_names`` is used as its integer label.

    Args:
        root_dir: Directory containing one sub-directory per class.
        transform: Optional callable applied to each PIL image in ``__getitem__``.
        class_names: Ordered class folder names. When omitted, the notebook-level
            ``CLASS_NAMES`` list is used (original behaviour).
    """

    IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp')

    def __init__(self, root_dir, transform=None, class_names=None):
        self.root_dir = root_dir
        self.transform = transform
        # CLASS_NAMES is only looked up when no explicit list is given
        self.class_names = list(CLASS_NAMES if class_names is None else class_names)
        self.images = []
        self.labels = []
        self.class_to_idx = {name: idx for idx, name in enumerate(self.class_names)}

        for class_name in self.class_names:
            class_dir = os.path.join(root_dir, class_name)
            if not os.path.isdir(class_dir):
                # Surface a missing folder instead of silently producing an empty class
                print(f"Warning: class directory not found: {class_dir}")
                continue
            # os.listdir order is arbitrary; sort so the index -> file mapping is reproducible
            for img_name in sorted(os.listdir(class_dir)):
                if img_name.lower().endswith(self.IMG_EXTENSIONS):
                    self.images.append(os.path.join(class_dir, img_name))
                    self.labels.append(self.class_to_idx[class_name])

        print(f"Loaded {len(self.images)} images from {root_dir}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)``; the image is transformed if a transform was given."""
        # The context manager closes the file handle; convert() returns a loaded copy
        with Image.open(self.images[idx]) as img:
            image = img.convert('RGB')
        label = self.labels[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, label

## Step 5: Define Transforms and Create DataLoaders

In [None]:
# ImageNet channel statistics expected by the pretrained ResNet50 backbone
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: fixed-size resize, then random flips / rotation / colour jitter
train_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Evaluation pipeline: deterministic resize + normalisation, no augmentation
val_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

train_dataset = AMDDataset(TRAIN_DIR, transform=train_transform)
val_dataset = AMDDataset(VALID_DIR, transform=val_transform)

# Shared loader settings; num_workers=0 loads batches in the notebook's own process
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=0)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

print(f"\nTraining samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

## Step 6: Visualize Sample Images

In [None]:
# Grab one shuffled, augmented batch from the training loader for visual inspection
first_batch = next(iter(train_loader))
images, labels = first_batch

# Denormalize for visualization
def denormalize(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo ``transforms.Normalize`` so an image tensor can be displayed.

    Args:
        tensor: Normalised image tensor whose channel axis is third from the end,
            e.g. ``(3, H, W)`` or a batch ``(B, 3, H, W)``; any device/dtype.
        mean: Per-channel means used for normalisation (ImageNet by default).
        std: Per-channel standard deviations used for normalisation.

    Returns:
        ``tensor * std + mean``, on the same device and dtype as ``tensor``.
    """
    # Build the constants where the input lives so CUDA tensors work without .cpu()
    mean_t = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    return tensor * std_t + mean_t

# Plot up to 8 images from the batch with their ground-truth class names
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
for i, ax in enumerate(axes.flat):
    # Turn every frame off first so unused subplots (batch < 8 images) stay blank
    ax.axis('off')
    if i >= len(images):
        continue
    img = np.clip(denormalize(images[i]).permute(1, 2, 0).numpy(), 0, 1)
    ax.imshow(img)
    ax.set_title(CLASS_NAMES[labels[i].item()])
fig.suptitle('Sample training images (augmented)')
plt.tight_layout()
plt.show()

## Step 7: Create the Model

In [None]:
# ImageNet-pretrained ResNet50 backbone
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

# Swap the original classifier for a dropout-regularised two-layer head
num_features = model.fc.in_features
classifier_head = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(num_features, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, NUM_CLASSES),
)
model.fc = classifier_head
model = model.to(device)

# Parameter counts (this cell freezes nothing, so the two numbers should match)
all_params = list(model.parameters())
total_params = sum(p.numel() for p in all_params)
trainable_params = sum(p.numel() for p in all_params if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## Step 8: Define Loss Function and Optimizer

In [None]:
# Multi-class cross-entropy on raw logits
criterion = nn.CrossEntropyLoss()
# Adam over every parameter: the pretrained backbone is fine-tuned along with the head
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
# Halve the LR once validation loss has failed to improve for more than patience=3 epochs
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

## Step 9: Training Function

In [None]:
def train_epoch(model, loader, criterion, optimizer, device=None):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train; switched to ``train()`` mode.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device batches are moved to. Defaults to the device of the model's
            first parameter, so the function does not rely on a global ``device``.

    Returns:
        ``(mean_loss, accuracy)`` averaged over every sample seen.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    if device is None:
        device = next(model.parameters()).device
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in tqdm(loader, desc="Training"):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch figure is a per-sample mean
        # (assumes criterion uses mean reduction, the CrossEntropyLoss default)
        running_loss += loss.item() * images.size(0)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("train_epoch: loader yielded no samples")
    return running_loss / total, correct / total

def validate(model, loader, criterion, device=None):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: Network to evaluate; switched to ``eval()`` mode.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device batches are moved to. Defaults to the device of the model's
            first parameter, so the function does not rely on a global ``device``.

    Returns:
        ``(mean_loss, accuracy, all_preds, all_labels)``; the last two are lists of
        per-sample predicted / true class indices in loader order.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Validating"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Per-sample weighting (assumes mean-reduced criterion)
            running_loss += loss.item() * images.size(0)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if total == 0:
        raise ValueError("validate: loader yielded no samples")
    return running_loss / total, correct / total, all_preds, all_labels

## Step 10: Train the Model

In [None]:
# Training history (one entry per epoch)
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
best_acc = 0.0
# Start from a copy of the initial weights so Step 12's load_state_dict always has a
# valid state dict, even if validation accuracy never rises above 0.0.
best_model_wts = copy.deepcopy(model.state_dict())
# Checkpoints are written whenever validation accuracy improves
os.makedirs('outputs/models', exist_ok=True)

print("Starting training...\n")

for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")
    print("-" * 40)

    # Train
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer)

    # Validate
    val_loss, val_acc, _, _ = validate(model, val_loader, criterion)

    # Scheduler was configured with mode='min', so it is driven by validation loss
    scheduler.step(val_loss)

    # Save history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

    # Keep the best weights in memory (deepcopy: state_dict() aliases live tensors) and
    # on disk, in the same checkpoint format as Step 14, so an interrupted run is not lost.
    if val_acc > best_acc:
        best_acc = val_acc
        best_model_wts = copy.deepcopy(model.state_dict())
        torch.save({
            'model_state_dict': best_model_wts,
            'class_names': CLASS_NAMES,
            'val_acc': best_acc
        }, 'outputs/models/amd_model.pth')
        print(f"✓ New best model saved! Accuracy: {val_acc:.4f}")

    print()

print(f"Training complete! Best validation accuracy: {best_acc:.4f}")

In [None]:
# Continue training from the saved checkpoint (written by Step 14 or an earlier run of this cell).
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
checkpoint = torch.load('outputs/models/amd_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
best_acc = checkpoint['val_acc']
# deepcopy: state_dict() returns references to the live parameters, so without a copy
# these "best" weights would silently follow every subsequent optimizer step.
best_model_wts = copy.deepcopy(model.state_dict())

# NOTE(review): only model weights are restored from the checkpoint; the optimizer and
# scheduler keep whatever state they currently hold in this kernel.

# Initialize history for additional epochs (Step 11 will plot only these epochs)
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
ADDITIONAL_EPOCHS = 30  # Train for 30 more epochs

print(f"Continuing training from validation accuracy: {best_acc:.4f}")
print(f"Training for {ADDITIONAL_EPOCHS} more epochs...\n")

for epoch in range(ADDITIONAL_EPOCHS):
    print(f"Epoch {epoch+1}/{ADDITIONAL_EPOCHS}")
    print("-" * 40)

    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer)
    val_loss, val_acc, _, _ = validate(model, val_loader, criterion)
    scheduler.step(val_loss)

    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

    if val_acc > best_acc:
        best_acc = val_acc
        best_model_wts = copy.deepcopy(model.state_dict())
        torch.save({
            'model_state_dict': best_model_wts,
            'class_names': CLASS_NAMES,
            'val_acc': best_acc
        }, 'outputs/models/amd_model.pth')
        print(f"✓ New best model saved! Accuracy: {val_acc:.4f}")

    print()

print(f"Continued training complete! Best validation accuracy: {best_acc:.4f}")

## Step 10B: Continue Training (Optional)

**The cell above continues training from a saved model for more epochs. Run it only after a checkpoint exists at `outputs/models/amd_model.pth`.**

## Step 11: Plot Training History

In [None]:
# Epochs are numbered from 1 so the x-axis matches the "Epoch N/M" training log
epoch_numbers = range(1, len(history['train_loss']) + 1)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# (axis, history-key suffix, legend label, axis/title label)
panels = [
    (ax1, 'loss', 'Loss', 'Loss'),
    (ax2, 'acc', 'Acc', 'Accuracy'),
]
for ax, key, short_label, long_label in panels:
    ax.plot(epoch_numbers, history[f'train_{key}'], label=f'Train {short_label}', linewidth=2)
    ax.plot(epoch_numbers, history[f'val_{key}'], label=f'Val {short_label}', linewidth=2)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(long_label)
    ax.set_title(f'Training and Validation {long_label}')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Step 12: Evaluate Best Model

In [None]:
# Load best model (weights kept by the training / loading cells)
model.load_state_dict(best_model_wts)

# Final evaluation
val_loss, val_acc, all_preds, all_labels = validate(model, val_loader, criterion)

print(f"\nFinal Validation Accuracy: {val_acc:.4f}")
print(f"Final Validation Loss: {val_loss:.4f}")

# Pin the label set to every class so target_names always lines up, even when a class
# is missing from both labels and predictions; zero_division=0 reports undefined
# precision/recall as 0 without emitting warnings.
print("\nClassification Report:")
print(classification_report(
    all_labels,
    all_preds,
    labels=list(range(len(CLASS_NAMES))),
    target_names=CLASS_NAMES,
    zero_division=0,
))

## Step 13: Confusion Matrix

In [None]:
# Fix the label set so the matrix is always len(CLASS_NAMES) x len(CLASS_NAMES) and its
# rows/columns line up with the tick labels, even if a class never occurs.
cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(CLASS_NAMES))))

# Plot (rows = true class, columns = predicted class)
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES, ax=ax)
ax.set_title('Confusion Matrix', fontsize=14, fontweight='bold')
ax.set_ylabel('True Label')
ax.set_xlabel('Predicted Label')
fig.tight_layout()
plt.show()

## Step 14: Save the Model

In [None]:
# Persist the weights currently loaded in `model`, plus the metadata needed to reuse them
os.makedirs('outputs/models', exist_ok=True)

checkpoint_payload = {
    'model_state_dict': model.state_dict(),
    'class_names': CLASS_NAMES,
    'val_acc': best_acc,
}
torch.save(checkpoint_payload, 'outputs/models/amd_model.pth')

print("Model saved to outputs/models/amd_model.pth")

In [None]:
# Load the saved model (skip Step 10 if using this)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
checkpoint = torch.load('outputs/models/amd_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
best_acc = checkpoint['val_acc']
# deepcopy: state_dict() aliases the live parameters; a plain reference would change
# if the model is trained further before Step 12 reloads these "best" weights.
best_model_wts = copy.deepcopy(model.state_dict())

print(f"✓ Model loaded successfully!")
print(f"Validation accuracy: {best_acc:.4f}")
print("\nYou can now skip to Step 12 for evaluation or Step 15 for predictions")

## Alternative: Load Previously Trained Model

**The cell above loads a previously saved model. Run it instead of Step 10 if you already have a trained model saved.**

## Step 15: Test Prediction on Single Image

In [None]:
def predict_image(image_path):
    """Classify one image with the notebook's ``model``, display it and print probabilities.

    Uses the notebook-level ``model``, ``val_transform``, ``device`` and ``CLASS_NAMES``.

    Args:
        image_path: Path to an image file readable by PIL.

    Returns:
        ``(predicted_class, confidence)``, where confidence is a percentage in [0, 100].
    """
    # Load and transform image; the context manager closes the file handle and
    # convert() returns an in-memory copy that stays valid afterwards.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    image_tensor = val_transform(image).unsqueeze(0).to(device)

    # Predict (single image -> take row 0 and bring it to the CPU once)
    model.eval()
    with torch.no_grad():
        probs = torch.softmax(model(image_tensor), dim=1)[0].cpu()
    conf, pred = probs.max(dim=0)

    predicted_class = CLASS_NAMES[pred.item()]
    confidence = conf.item() * 100

    # Display
    plt.figure(figsize=(8, 6))
    plt.imshow(image)
    plt.title(f"Prediction: {predicted_class.upper()}\nConfidence: {confidence:.2f}%", fontsize=14)
    plt.axis('off')
    plt.show()

    # Show all probabilities
    print("\nClass Probabilities:")
    for name, prob in zip(CLASS_NAMES, probs.tolist()):
        print(f"  {name}: {prob*100:.2f}%")

    return predicted_class, confidence

# Example usage (uncomment and modify path):
# predict_image(r"path\to\your\image.jpg")

## Step 16: Visualize Predictions on Validation Set

In [None]:
# Predict on one validation batch with the current model (green title = correct)
model.eval()
images, labels = next(iter(val_loader))

with torch.no_grad():
    # Only the forward pass runs on `device`; predictions come back to the CPU once
    preds = model(images.to(device)).argmax(dim=1).cpu()

# Plot up to 12 images
fig, axes = plt.subplots(3, 4, figsize=(14, 10))
for i, ax in enumerate(axes.flat):
    # Turn every frame off first so unused subplots (batch < 12 images) stay blank
    ax.axis('off')
    if i >= len(images):
        continue
    img = np.clip(denormalize(images[i]).permute(1, 2, 0).numpy(), 0, 1)
    ax.imshow(img)

    true_idx, pred_idx = labels[i].item(), preds[i].item()
    color = 'green' if true_idx == pred_idx else 'red'
    ax.set_title(f"True: {CLASS_NAMES[true_idx]}\nPred: {CLASS_NAMES[pred_idx]}",
                 color=color, fontweight='bold')

plt.tight_layout()
plt.show()