In [None]:
# Mount Google Drive at /content/drive; later cells read the dataset directory and the
# pickled loaders from /content/drive/MyDrive and write the checkpoint / ROC figure there.
# This import only resolves inside a Google Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:

"""
GSoC 2025 Internship Application Task - 1
Author: Dhruv Srivastava
"""

"""Import dependencies"""
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet18
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt

In [None]:
"""Define Dataset Class"""
class MyDataset(Dataset):
    """Eagerly load every ``.npy`` image under ``data_dir/<class_name>/`` as a float32 tensor.

    Labels are the positions in ``self.class_names`` (``axion`` -> 0, ``cdm`` -> 1,
    ``no_sub`` -> 2). For ``axion`` files the image is element 0 of the loaded array;
    for the other classes the loaded array itself is the image. 2D images and
    single-channel ``(1, H, W)`` images are replicated to three channels; any other
    shape is stored unchanged.

    Args:
        data_dir: Directory that contains one sub-directory per class name.
        transform: Optional callable applied to the image tensor in ``__getitem__``.
        verbose: If True, print per-file shape diagnostics while loading. Defaults to
            False so that loading a large directory does not emit several lines per file.

    A class directory that does not exist is reported and skipped, so the resulting
    dataset may be missing classes (or be empty).
    """

    def __init__(self, data_dir, transform=None, verbose=False):
        self.data = []
        self.labels = []
        self.class_names = ['axion', 'cdm', 'no_sub']
        self.transform = transform

        print(f"Loading dataset from: {data_dir}")
        print(f"Looking for classes: {self.class_names}")

        for idx, class_name in enumerate(self.class_names):
            class_dir = os.path.join(data_dir, class_name)
            print(f"Processing class: {class_name} (index: {idx})")

            # Best-effort: a missing class directory is logged, not fatal.
            if not os.path.isdir(class_dir):
                print(f"[ERROR] Directory not found: {class_dir}")
                continue

            # os.listdir() returns entries in arbitrary order; sort so that sample
            # indices (and any seeded split built on them) are reproducible.
            files = sorted(os.listdir(class_dir))
            print(f"Found {len(files)} files in {class_name} directory")

            for file_name in files:
                if not file_name.endswith('.npy'):
                    continue
                file_path = os.path.join(class_dir, file_name)

                # NOTE(review): allow_pickle=True can execute arbitrary code from a
                # crafted file; only point this at data you trust.
                loaded_data = np.load(file_path, allow_pickle=True)

                # axion files keep the image in slot 0 of the stored array.
                image = loaded_data[0] if class_name == 'axion' else loaded_data
                if verbose:
                    print(f"Loading image: {file_name}")
                    if class_name == 'axion':
                        print(f"Original loaded shape for axion: {loaded_data.shape}, Extracted image shape: {image.shape}")
                    else:
                        print(f"Image shape for {class_name}: {image.shape}")

                image = self._to_three_channels(image, verbose)

                self.data.append(torch.tensor(image, dtype=torch.float32))
                self.labels.append(idx)

        print(f"Total images loaded: {len(self.data)}")
        print(f"Distribution of classes: {np.unique(self.labels, return_counts=True)}")

    @staticmethod
    def _to_three_channels(image, verbose=False):
        """Replicate a 2D or ``(1, H, W)`` array to ``(3, H, W)``; return other shapes as-is."""
        if image.ndim == 2:
            image = np.stack([image] * 3, axis=0)
            if verbose:
                print("Converted 2D image to 3-channel")
        elif image.ndim == 3 and image.shape[0] == 1:
            image = np.repeat(image, 3, axis=0)
            if verbose:
                print("Converted single-channel image to 3-channel")
        return image

    def __len__(self):
        """Return the number of images that were loaded."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``, applying ``self.transform`` if one is set."""
        image = self.data[idx]
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

In [None]:
"""Define Dataset Class for Vision Transformer with Debugging"""
class MyDatasetViT(Dataset):
    """Eagerly load every ``.npy`` image under ``data_dir/<class_name>/`` as a ``[1, H, W]`` tensor.

    Labels are the positions in ``self.class_names`` (``axion`` -> 0, ``cdm`` -> 1,
    ``no_sub`` -> 2). For ``axion`` files the image is element 0 of the loaded array;
    for the other classes the loaded array itself is the image. Size-1 axes are
    squeezed away and a single leading channel axis is added.

    Args:
        data_dir: Directory that contains one sub-directory per class name.
        transform: Optional callable applied to the image tensor in ``__getitem__``.
        verbose: If True, print the raw and stored shape of every file. Defaults to
            False so that loading a large directory does not emit two lines per file.

    Raises:
        ValueError: If an image is not 2D even after squeezing size-1 axes.

    A class directory that does not exist is reported and skipped.
    """

    def __init__(self, data_dir, transform=None, verbose=False):
        self.data = []
        self.labels = []
        self.class_names = ['axion', 'cdm', 'no_sub']
        self.transform = transform

        print(f"Loading dataset from: {data_dir}")
        print(f"Looking for classes: {self.class_names}")

        for idx, class_name in enumerate(self.class_names):
            class_dir = os.path.join(data_dir, class_name)
            print(f"--- Processing class: {class_name} ---")

            # Best-effort: a missing class directory is logged, not fatal.
            if not os.path.isdir(class_dir):
                print(f"[ERROR] Directory not found: {class_dir}")
                continue

            # Sorted so that sample indices are reproducible across runs/filesystems.
            for file_name in sorted(os.listdir(class_dir)):
                if not file_name.endswith('.npy'):
                    continue
                file_path = os.path.join(class_dir, file_name)

                # NOTE(review): allow_pickle=True can execute arbitrary code from a
                # crafted file; only point this at data you trust.
                loaded_data = np.load(file_path, allow_pickle=True)

                # axion files keep the image in slot 0 of the stored array.
                image = loaded_data[0] if class_name == 'axion' else loaded_data

                if verbose:
                    print(f"  [DEBUG] Loaded '{file_name}'. Raw numpy shape: {image.shape}")

                # Drop size-1 axes, then insist on (H, W): squeeze alone cannot turn
                # e.g. a (3, H, W) array into 2D, and storing it would only fail later
                # and further from the offending file.
                if image.ndim != 2:
                    image = np.squeeze(image)
                if image.ndim != 2:
                    raise ValueError(
                        f"Expected a 2D image in '{file_path}', got shape {image.shape} after squeeze"
                    )

                # Float tensor with an explicit channel axis -> [1, H, W]
                image_tensor = torch.tensor(image, dtype=torch.float32).unsqueeze(0)

                if verbose:
                    print(f"  [DEBUG] Storing tensor with final shape: {image_tensor.shape}\n")

                self.data.append(image_tensor)
                self.labels.append(idx)

        print("\n--- Dataset Loading Complete ---")
        print(f"Total images loaded: {len(self.data)}")

    def __len__(self):
        """Return the number of images that were loaded."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``, applying ``self.transform`` if one is set."""
        image = self.data[idx]
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

In [None]:
# --- Hyperparameters -------------------------------------------------------
# NOTE(review): batch_size is only reported in this cell; the restored loaders
# carry whatever batch size they were pickled with — confirm the two agree.
batch_size, learning_rate, num_epochs = 32, 0.001, 500

# --- Data locations (Google Drive must already be mounted) ------------------
train_dir = '/content/drive/MyDrive/Model_III_dataset/Model_III'
print(f"Training Directory: {train_dir}")

# --- DataLoaders ------------------------------------------------------------
# The train / validation loaders are restored from pickled DataLoader objects
# instead of being rebuilt from `train_dir` in this cell.
# NOTE(review): weights_only=False unpickles arbitrary Python objects; only load
# files that you created yourself or otherwise trust.
train_loader, val_loader = (
    torch.load(loader_path, weights_only=False)
    for loader_path in (
        '/content/drive/MyDrive/train_loader.pth',
        '/content/drive/MyDrive/val_loader.pth',
    )
)

print(
    f"Batch Size: {batch_size}",
    f"Number of Training Batches: {len(train_loader)}",
    f"Number of Validation Batches: {len(val_loader)}",
    sep="\n",
)

Training Directory: /content/drive/MyDrive/Model_III_dataset/Model_III
Batch Size: 32
Number of Training Batches: 2096
Number of Validation Batches: 420


In [None]:
# Modified ResNet18 for Lens Classification
# Modified ResNet18 for Lens Classification
class Net(nn.Module):
    """ResNet18 backbone with a two-layer classification head.

    The final ``fc`` layer of torchvision's ResNet18 is replaced by
    ``Linear(num_features, 512) -> ReLU -> Dropout(0.5) -> Linear(512, num_classes)``.

    Args:
        num_classes: Number of output logits.
        in_channels: Number of channels in the input images. With the default of 3
            the stock first convolution is kept, so its pretrained weights are
            preserved (replacing it with an identically shaped fresh layer would
            only throw those weights away). For any other value a new 7x7 stem
            convolution is created.
        pretrained: If True, start from the ImageNet (``IMAGENET1K_V1``) weights;
            pass False to skip the download, e.g. when a checkpoint is loaded
            right afterwards.
    """

    def __init__(self, num_classes=3, in_channels=3, pretrained=True):
        super().__init__()

        print("Initializing Modified ResNet18")

        # `weights=` replaces the deprecated `pretrained=True`; "IMAGENET1K_V1" is
        # the weight set that `pretrained=True` used to select.
        resnet = resnet18(weights="IMAGENET1K_V1" if pretrained else None)

        # Only rebuild the stem when the channel count actually differs from RGB.
        if in_channels != 3:
            stem = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
            if pretrained:
                # Seed the new stem with the channel-averaged pretrained filters so
                # it does not start from random weights.
                with torch.no_grad():
                    mean_filter = resnet.conv1.weight.mean(dim=1, keepdim=True)
                    stem.weight.copy_(mean_filter.repeat(1, in_channels, 1, 1))
            resnet.conv1 = stem

        # Replace the last layer
        num_features = resnet.fc.in_features
        resnet.fc = nn.Sequential(
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

        self.model = resnet

        print(f"Model architecture: {self.model}")

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for image batch ``x``."""
        return self.model(x)

In [None]:
"""Training and Evaluation"""
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=50,
                save_path='/content/drive/MyDrive/lens_classifier_model.pth',
                step_scheduler_per_batch=False):
    """Train ``model``, validate after every epoch and checkpoint the best epoch.

    Args:
        model: Network to train; moved to CUDA when available, otherwise CPU.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Learning-rate scheduler, or None to keep the rate fixed.
        num_epochs: Number of passes over ``train_loader``.
        save_path: Where the best ``state_dict`` is written.
        step_scheduler_per_batch: Set True for schedulers whose horizon is counted
            in optimizer steps. The default steps once per epoch, which is what a
            scheduler configured with an epoch count (e.g. ``T_max=num_epochs``)
            expects; stepping such a scheduler every batch makes the learning rate
            run through its whole schedule many times per epoch.

    Returns:
        ``(all_preds, all_labels)``: per-sample softmax probabilities and true labels
        (lists of numpy values) from the validation pass of the epoch whose weights
        were saved to ``save_path``, so that downstream metrics describe the saved
        checkpoint. Both lists are empty when ``num_epochs`` is 0.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Training on device: {device}")

    model.to(device)

    # -inf (rather than 0.0) guarantees the first epoch always yields a checkpoint.
    best_val_accuracy = float('-inf')
    best_preds, best_labels = [], []

    for epoch in range(num_epochs):
        print(f"\n===== Epoch {epoch+1}/{num_epochs} =====")

        # Training Phase
        model.train()
        train_loss = 0.0
        train_correct = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            if scheduler is not None and step_scheduler_per_batch:
                scheduler.step()

            train_loss += loss.item()
            train_correct += (outputs.argmax(dim=1) == labels).sum().item()

        # Validation Phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        epoch_preds = []
        epoch_labels = []

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)

                outputs = model(images)
                val_loss += criterion(outputs, labels).item()
                val_correct += (outputs.argmax(dim=1) == labels).sum().item()

                # Keep class probabilities (not hard predictions) for ROC analysis.
                probs = torch.softmax(outputs, dim=1)
                epoch_preds.extend(probs.cpu().numpy())
                epoch_labels.extend(labels.cpu().numpy())

        # Epoch-based schedulers advance once per epoch, after the optimizer steps.
        if scheduler is not None and not step_scheduler_per_batch:
            scheduler.step()

        train_accuracy = train_correct / len(train_loader.dataset)
        val_accuracy = val_correct / len(val_loader.dataset)

        # Epoch-level metrics (losses are averaged over batches)
        print(f'\n[SUMMARY] Epoch {epoch+1}/{num_epochs}:')
        print(f'Train Loss: {train_loss/len(train_loader):.4f}, Train Accuracy: {train_accuracy:.4f}')
        print(f'Val Loss: {val_loss/len(val_loader):.4f}, Val Accuracy: {val_accuracy:.4f}')

        # Save best model, and remember the predictions that belong to it
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            best_preds, best_labels = epoch_preds, epoch_labels
            torch.save(model.state_dict(), save_path)
            print(f"New best model saved with validation accuracy: {best_val_accuracy:.4f}")

    print("\nTraining Complete!")
    return best_preds, best_labels

In [None]:
from torch.optim.lr_scheduler import CosineAnnealingLR

# Initialize Model
model = Net(num_classes=3)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)
# NOTE(review): T_max comes from the global `num_epochs`, not from the 27 epochs
# actually trained below — confirm the annealing horizon matches the number of
# times train_model calls scheduler.step() over the run.
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)

# Report the optimizer's real class name so the log cannot drift from the code
# (this previously printed "Adam" while AdamW was in use).
print(f"Optimizer: {type(optimizer).__name__}")
print(f"Learning Rate: {learning_rate}")

# Train Model (epoch count passed by keyword so the bare 27 is self-describing)
all_preds, all_labels = train_model(
    model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=27
)

Initializing Modified ResNet18
Model architecture: ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_runn

In [None]:
""" ROC Curve Plotting Function"""
def plot_roc_curve(all_preds, all_labels, class_names=None,
                   save_path='/content/drive/MyDrive/roc_curve.png'):
    """Plot one-vs-rest ROC curves per class and save the figure.

    Args:
        all_preds: Per-sample class scores, array-like of shape ``(n_samples, n_classes)``.
        all_labels: Integer class index per sample, array-like of length ``n_samples``.
        class_names: Legend name for each class index. When omitted and there are
            three classes, the names follow the dataset classes' index order in this
            file (``axion``, ``cdm``, ``no_sub``); otherwise ``Class <i>`` is used.
        save_path: File the figure is written to.

    Returns:
        dict mapping class index to its ROC AUC.

    Raises:
        ValueError: If the inputs are not ``(n_samples, n_classes)`` / ``(n_samples,)``
            or ``class_names`` has the wrong length.
    """
    print("Generating ROC Curve")

    # The scores are supplied by the caller, so no model needs to be built or loaded here.
    all_preds = np.asarray(all_preds)
    all_labels = np.asarray(all_labels)
    if all_preds.ndim != 2 or len(all_preds) != len(all_labels):
        raise ValueError(
            f"Expected all_preds of shape (n_samples, n_classes) and matching labels, "
            f"got {all_preds.shape} and {all_labels.shape}"
        )
    n_classes = all_preds.shape[1]

    if class_names is None:
        # Index order must mirror the dataset's class order (axion=0, cdm=1, no_sub=2),
        # otherwise the legend attributes each curve to the wrong class.
        default_names = ['Axion', 'CDM', 'No Substructure']
        if n_classes == len(default_names):
            class_names = default_names
        else:
            class_names = [f'Class {i}' for i in range(n_classes)]
    elif len(class_names) != n_classes:
        raise ValueError(f"Got {len(class_names)} class names for {n_classes} classes")

    fpr, tpr, roc_auc = {}, {}, {}
    for i in range(n_classes):
        # One-vs-rest: class i is the positive label, everything else negative.
        fpr[i], tpr[i], _ = roc_curve((all_labels == i).astype(int), all_preds[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

        print(f"Class {i} ROC AUC: {roc_auc[i]:.4f}")

    # Plot ROC curves
    colors = ['blue', 'red', 'green']
    fig, ax = plt.subplots(figsize=(10, 8))
    for i in range(n_classes):
        # Beyond the three fixed colours, fall back to matplotlib's colour cycle.
        ax.plot(fpr[i], tpr[i], color=colors[i] if i < len(colors) else None,
                label=f'{class_names[i]} (AUC = {roc_auc[i]:.2f})')

    ax.plot([0, 1], [0, 1], 'k--')  # chance diagonal
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Receiver Operating Characteristic (ROC) Curve')
    ax.legend(loc="lower right")
    fig.savefig(save_path)
    plt.close(fig)

    print(f"ROC Curve saved as {save_path}")
    return roc_auc


# Draw the ROC curves from the validation probabilities / labels returned by train_model.
plot_roc_curve(all_preds=all_preds, all_labels=all_labels)

print("Training and Evaluation Complete!")

Generating ROC Curve
Initializing Modified ResNet18




Model architecture: ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): R