In [None]:
"""
GSoC 2025 Internship Application Task - 1
Author: Dhruv Srivastava
"""

"""Import dependencies"""
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet18
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt

In [None]:
"""Define Dataset Class"""
class MyDataset(Dataset):
    """In-memory dataset of ``.npy`` images laid out as ``<data_dir>/<class_name>/*.npy``.

    Every ``.npy`` file is loaded eagerly in ``__init__`` and stored as a
    ``float32`` tensor; its label is the index of its class directory in
    ``class_names``. 2-D arrays and 3-D arrays whose leading axis has size 1
    are replicated to three channels along axis 0; any other shape is stored
    unchanged.

    Args:
        data_dir: Root directory containing one sub-directory per class.
        transform: Optional callable applied to each image tensor in
            ``__getitem__``.
        class_names: Sub-directory names in label order. Defaults to
            ``['no', 'sphere', 'vort']`` (labels 0, 1, 2).
        verbose: If True, print per-file loading details. Per-class and
            summary messages are always printed.
    """

    DEFAULT_CLASS_NAMES = ('no', 'sphere', 'vort')

    def __init__(self, data_dir, transform=None, class_names=None, verbose=False):
        self.data = []
        self.labels = []
        self.class_names = (list(class_names) if class_names is not None
                            else list(self.DEFAULT_CLASS_NAMES))
        self.transform = transform

        print(f"Loading dataset from: {data_dir}")
        print(f"Looking for classes: {self.class_names}")

        for idx, class_name in enumerate(self.class_names):
            class_dir = os.path.join(data_dir, class_name)
            print(f"Processing class: {class_name} (index: {idx})")

            # A missing class directory is reported and skipped (best effort).
            # isdir (not exists) also guards against a same-named regular file.
            if not os.path.isdir(class_dir):
                print(f"[ERROR] Directory not found: {class_dir}")
                continue

            # os.listdir order is arbitrary; sort for a reproducible sample order.
            files = sorted(os.listdir(class_dir))
            print(f"Found {len(files)} files in {class_name} directory")

            for file_name in files:
                if not file_name.endswith('.npy'):
                    continue
                file_path = os.path.join(class_dir, file_name)
                image = self._to_three_channels(np.load(file_path), file_name, verbose)
                self.data.append(torch.tensor(image, dtype=torch.float32))
                self.labels.append(idx)

        print(f"Total images loaded: {len(self.data)}")
        print(f"Distribution of classes: {np.unique(self.labels, return_counts=True)}")

    @staticmethod
    def _to_three_channels(image, file_name, verbose=False):
        """Replicate a 2-D or ``(1, H, W)`` array to ``(3, H, W)``; other shapes pass through."""
        if verbose:
            print(f"Loading image: {file_name}")
            print(f"Image shape: {image.shape}")

        if image.ndim == 2:
            image = np.stack([image] * 3, axis=0)
            if verbose:
                print("Converted 2D image to 3-channel")
        elif image.ndim == 3 and image.shape[0] == 1:
            image = np.repeat(image, 3, axis=0)
            if verbose:
                print("Converted single-channel image to 3-channel")
        return image

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label)``, applying ``transform`` to the image if set."""
        image = self.data[idx]
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

In [None]:
# Hyperparameters
batch_size = 32
learning_rate = 0.001
num_epochs = 50
seed = 42  # makes shuffling, dropout and random weight init reproducible

# Seed the RNGs used downstream. torch.manual_seed covers CPU and all CUDA
# devices; cuDNN kernels may still be non-deterministic on GPU.
np.random.seed(seed)
torch.manual_seed(seed)

# Data Directories
train_dir = '../dataset/dataset/train'
val_dir = '../dataset/dataset/val'

print(f"Training Directory: {train_dir}")
print(f"Validation Directory: {val_dir}")

# Create Datasets and Dataloaders
train_dataset = MyDataset(train_dir)
val_dataset = MyDataset(val_dir)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

print(f"Batch Size: {batch_size}")
print(f"Number of Training Batches: {len(train_loader)}")
print(f"Number of Validation Batches: {len(val_loader)}")

In [None]:
# Modified ResNet18 for Lens Classification
class Net(nn.Module):
    """ResNet18 backbone with a dropout-regularised two-layer head.

    Args:
        num_classes: Number of output logits produced by the head.
        in_channels: Number of input channels. When it matches the stock
            ResNet18 stem (3), the stem convolution -- including any
            pretrained weights -- is kept; otherwise it is replaced with a
            freshly initialised convolution of the same geometry.
        pretrained: Forwarded to ``torchvision.models.resnet18``.
        verbose: If True, print the full module tree after construction.
    """

    def __init__(self, num_classes=3, in_channels=3, pretrained=True, verbose=False):
        super(Net, self).__init__()

        print("Initializing Modified ResNet18")

        resnet = resnet18(pretrained=pretrained)

        # Only swap the stem when the channel count actually differs: replacing
        # it with an identically shaped layer would just discard the pretrained
        # stem weights.
        stem = resnet.conv1
        if in_channels != stem.in_channels:
            resnet.conv1 = nn.Conv2d(in_channels, stem.out_channels, kernel_size=7,
                                     stride=2, padding=3, bias=False)

        # Replace the classification layer with an MLP head
        num_features = resnet.fc.in_features
        resnet.fc = nn.Sequential(
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

        self.model = resnet

        if verbose:
            print(f"Model architecture: {self.model}")

    def forward(self, x):
        """Return raw class logits for a batch ``x``."""
        return self.model(x)

In [None]:
"""Training and Evaluation"""
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=50,
                checkpoint_path='lens_classifier_model.pth', device=None, verbose=False):
    """Train ``model`` and checkpoint the weights with the best validation accuracy.

    Each epoch runs one optimisation pass over ``train_loader`` followed by a
    no-grad evaluation pass over ``val_loader``. Whenever validation accuracy
    beats the best seen so far (the first epoch always does), the model's
    ``state_dict`` is saved to ``checkpoint_path``.

    Args:
        model: Module to train; moved to ``device`` in place.
        train_loader, val_loader: Iterables of ``(inputs, labels)`` batches with
            a ``.dataset`` attribute supporting ``len``.
        criterion: Loss taking ``(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Number of epochs; 0 performs no training.
        checkpoint_path: File the best ``state_dict`` is written to.
        device: Target device; defaults to CUDA when available, else CPU.
        verbose: If True, print per-batch diagnostics.

    Returns:
        ``(all_preds, all_labels)``: per-sample softmax probability vectors
        (numpy arrays) and labels for the validation set, taken from the epoch
        whose weights were saved to ``checkpoint_path``. Both are empty lists
        when ``num_epochs`` is 0.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Training on device: {device}")

    model.to(device)

    # -inf guarantees a checkpoint (and returned predictions) after epoch 1,
    # even if validation accuracy never rises above zero.
    best_val_accuracy = float('-inf')
    best_preds, best_labels = [], []

    for epoch in range(num_epochs):
        print(f"\n===== Epoch {epoch+1}/{num_epochs} =====")

        train_loss, train_correct = _run_train_epoch(
            model, train_loader, criterion, optimizer, device, verbose)
        val_loss, val_correct, all_preds, all_labels = _run_eval_epoch(
            model, val_loader, criterion, device, verbose)

        # max(..., 1) avoids ZeroDivisionError on empty datasets/loaders
        train_accuracy = train_correct / max(len(train_loader.dataset), 1)
        val_accuracy = val_correct / max(len(val_loader.dataset), 1)

        # Epoch-level metrics
        print(f'\n[SUMMARY] Epoch {epoch+1}/{num_epochs}:')
        print(f'Train Loss: {train_loss/max(len(train_loader), 1):.4f}, Train Accuracy: {train_accuracy:.4f}')
        print(f'Val Loss: {val_loss/max(len(val_loader), 1):.4f}, Val Accuracy: {val_accuracy:.4f}')

        # Save best model, and keep the predictions that describe it
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            best_preds, best_labels = all_preds, all_labels
            torch.save(model.state_dict(), checkpoint_path)
            print(f"New best model saved with validation accuracy: {best_val_accuracy:.4f}")

    print("\nTraining Complete!")
    return best_preds, best_labels


def _run_train_epoch(model, loader, criterion, optimizer, device, verbose=False):
    """Run one optimisation pass; return (sum of batch losses, number of correct predictions)."""
    model.train()
    total_loss = 0.0
    correct = 0

    for batch_idx, (images, labels) in enumerate(loader):
        images, labels = images.to(device), labels.to(device)

        if verbose:
            print(f"Training Batch {batch_idx+1}/{len(loader)}")
            print(f"Batch images shape: {images.shape}")
            print(f"Batch labels: {labels}")

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        predicted = outputs.detach().argmax(dim=1)
        batch_correct = (predicted == labels).sum().item()
        correct += batch_correct

        if verbose:
            print(f"Batch Loss: {loss.item():.4f}, Batch Accuracy: {batch_correct / labels.size(0):.4f}")

    return total_loss, correct


def _run_eval_epoch(model, loader, criterion, device, verbose=False):
    """Evaluate without gradients; return (loss sum, correct count, per-sample probabilities, labels)."""
    model.eval()
    total_loss = 0.0
    correct = 0
    probs_out, labels_out = [], []

    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(loader):
            images, labels = images.to(device), labels.to(device)

            if verbose:
                print(f"Validation Batch {batch_idx+1}/{len(loader)}")
                print(f"Batch images shape: {images.shape}")
                print(f"Batch labels: {labels}")

            outputs = model(images)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            batch_correct = (predicted == labels).sum().item()
            correct += batch_correct

            # Softmax probabilities are kept for ROC analysis downstream
            probs_out.extend(torch.softmax(outputs, dim=1).cpu().numpy())
            labels_out.extend(labels.cpu().numpy())

            if verbose:
                print(f"Validation Batch Loss: {loss.item():.4f}, "
                      f"Validation Batch Accuracy: {batch_correct / labels.size(0):.4f}")

    return total_loss, correct, probs_out, labels_out

In [None]:
# Model, loss and optimiser for the 3-way lens classification task.
model = Net(num_classes=3)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
print("Optimizer: Adam")
print(f"Learning Rate: {learning_rate}")

# Run training; the returned validation probabilities/labels feed the ROC cell.
all_preds, all_labels = train_model(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    num_epochs=num_epochs,
)

In [None]:
""" ROC Curve Plotting Function"""
def plot_roc_curve(all_preds, all_labels, class_names=None, save_path='roc_curve.png'):
    """Plot one-vs-rest ROC curves for every class and save the figure.

    Args:
        all_preds: Per-sample score vectors that stack to shape
            ``(n_samples, n_classes)``; column ``i`` scores class ``i``.
        all_labels: Integer class label per sample.
        class_names: Legend name per class index. Defaults to the three
            substructure classes; indices beyond the list use ``"Class i"``.
        save_path: Path the PNG figure is written to.

    Returns:
        dict mapping class index to its ROC AUC.

    Raises:
        ValueError: if ``all_preds`` does not stack to a non-empty 2-D array,
            or its length differs from ``all_labels``.
    """
    print("Generating ROC Curve")

    # Convert predictions and labels to numpy arrays
    all_preds = np.asarray(all_preds)
    all_labels = np.asarray(all_labels)
    if all_preds.ndim != 2 or all_preds.shape[0] == 0:
        raise ValueError(
            f"all_preds must stack to a non-empty (n_samples, n_classes) array, got shape {all_preds.shape}")
    if all_preds.shape[0] != all_labels.shape[0]:
        raise ValueError(
            f"all_preds has {all_preds.shape[0]} rows but all_labels has {all_labels.shape[0]} entries")

    if class_names is None:
        class_names = ['No Substructure', 'Sphere Substructure', 'Vortex Substructure']

    # Derive the class count from the score matrix instead of hard-coding it
    n_classes = all_preds.shape[1]
    fpr, tpr, roc_auc = {}, {}, {}
    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve((all_labels == i).astype(int), all_preds[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])
        print(f"Class {i} ROC AUC: {roc_auc[i]:.4f}")

    # Plot ROC curves; extra classes fall back to matplotlib's colour cycle
    base_colors = ['blue', 'red', 'green']
    fig, ax = plt.subplots(figsize=(10, 8))
    for i in range(n_classes):
        style = {'color': base_colors[i]} if i < len(base_colors) else {}
        name = class_names[i] if i < len(class_names) else f'Class {i}'
        ax.plot(fpr[i], tpr[i], label=f'{name} (AUC = {roc_auc[i]:.2f})', **style)

    ax.plot([0, 1], [0, 1], 'k--')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Receiver Operating Characteristic (ROC) Curve')
    ax.legend(loc="lower right")
    fig.savefig(save_path)
    plt.close(fig)

    print(f"ROC Curve saved as {save_path}")
    return roc_auc


# Render and save the ROC curves for the validation predictions returned by training.
plot_roc_curve(all_preds=all_preds, all_labels=all_labels)

print("Training and Evaluation Complete!")