In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
train_dir = "C:\\Users\\DTSC302\\Desktop\\Anannya\\gsoc\\dataset\\train"
val_dir = "C:\\Users\\DTSC302\\Desktop\\Anannya\\gsoc\\dataset\\val"

In [4]:
# Define classes
classes = ['no', 'sphere', 'vort']

# Data transforms
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485], std=[0.229])  # grayscale normalization
])

val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485], std=[0.229])
])


In [5]:
class LensDataset(Dataset):
    def __init__(self, data_dir, transform):
        self.samples = []
        self.transform = transform
        for label_idx, label in enumerate(classes):
            path = os.path.join(data_dir, label)
            for fname in os.listdir(path):
                if fname.endswith('.npy'):
                    self.samples.append((os.path.join(path, fname), label_idx))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        arr = np.load(path)[0]  # (150, 150) grayscale image
        img = Image.fromarray((arr * 255).astype(np.uint8))  # convert back to 0-255 range for PIL
        img = self.transform(img)
        return img, torch.tensor(label, dtype=torch.long)

In [6]:
batch_size = 64
train_dataset = LensDataset(train_dir, train_transform)
val_dataset = LensDataset(val_dir, val_transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [7]:
class TransferResNet(nn.Module):
    def __init__(self):
        super(TransferResNet, self).__init__()
        self.model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.model.fc = nn.Linear(self.model.fc.in_features, 3)

    def forward(self, x):
        return self.model(x)

model = TransferResNet().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

In [None]:
epochs = 10
for epoch in range(epochs):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)

        outputs = model(imgs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * imgs.size(0)  # correct loss scaling
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

    avg_loss = total_loss / total
    accuracy = correct / total

    model.eval()
    val_loss = 0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * imgs.size(0)
            _, predicted = torch.max(outputs, 1)
            val_correct += (predicted == labels).sum().item()
            val_total += labels.size(0)

    val_avg_loss = val_loss / val_total
    val_accuracy = val_correct / val_total

    scheduler.step()

    print(f"Epoch [{epoch+1}/{epochs}]")
    print(f"  Train Loss: {avg_loss:.4f}, Train Acc: {accuracy:.4f}")
    print(f"  Val Loss:   {val_avg_loss:.4f}, Val Acc:   {val_accuracy:.4f}")

Epoch [1/10]
  Train Loss: 0.2835, Train Acc: 0.8913
  Val Loss:   0.2636, Val Acc:   0.9056
Epoch [2/10]
  Train Loss: 0.2680, Train Acc: 0.9011
  Val Loss:   0.3315, Val Acc:   0.8908
Epoch [3/10]
  Train Loss: 0.2655, Train Acc: 0.9013
  Val Loss:   0.3296, Val Acc:   0.8929
Epoch [4/10]
  Train Loss: 0.2541, Train Acc: 0.9054
  Val Loss:   0.2824, Val Acc:   0.8992
Epoch [5/10]
  Train Loss: 0.2357, Train Acc: 0.9122
  Val Loss:   0.2387, Val Acc:   0.9128
Epoch [6/10]
  Train Loss: 0.1948, Train Acc: 0.9290
  Val Loss:   0.1973, Val Acc:   0.9327
Epoch [7/10]
  Train Loss: 0.1776, Train Acc: 0.9349
  Val Loss:   0.1733, Val Acc:   0.9395
Epoch [8/10]
  Train Loss: 0.1793, Train Acc: 0.9353
  Val Loss:   0.1931, Val Acc:   0.9336
Epoch [9/10]
  Train Loss: 0.1687, Train Acc: 0.9384
  Val Loss:   0.1631, Val Acc:   0.9412


In [None]:
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_curve, auc, RocCurveDisplay

# Get number of classes
num_classes = len(classes)

# Evaluation
model.eval()
y_true, y_pred = [], []
y_scores = []

with torch.no_grad():
    for imgs, labels in val_loader:
        imgs = imgs.to(device)
        outputs = model(imgs)
        probs = torch.softmax(outputs, dim=1)

        _, predicted = torch.max(probs, 1)

        y_pred.extend(predicted.cpu().numpy())
        y_true.extend(labels.numpy())
        y_scores.extend(probs.cpu().numpy())

# Classification report
print("\nClassification Report:\n")
print(classification_report(y_true, y_pred, target_names=classes))

# Confusion matrix
cm = confusion_matrix(y_true, y_pred)
plt.figure(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.tight_layout()
plt.show()

# ROC Curve
# Binarize the true labels for multi-class ROC
y_true_bin = label_binarize(y_true, classes=list(range(num_classes)))
y_scores = np.array(y_scores)

plt.figure(figsize=(8, 6))
for i in range(num_classes):
    fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_scores[:, i])
    roc_auc = auc(fpr, tpr)

    plt.plot(fpr, tpr, lw=2, label=f'{classes[i]} (AUC = {roc_auc:.2f})')

plt.plot([0, 1], [0, 1], 'k--', lw=2)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('One-vs-Rest ROC Curve')
plt.legend(loc='lower right')
plt.tight_layout()
plt.show()
