# Imports

In [1]:
import sys
import os
from PIL import Image
from glob import glob
import math
import numpy as np
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from tqdm import tqdm
from timm import create_model
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, roc_auc_score
from sklearn.preprocessing import label_binarize
from itertools import cycle
import seaborn as sns

# DataLoader

In [2]:
class NPYClassificationDataset(Dataset):
    def __init__(self, file_paths, labels, transform=None):
        """
        Args:
            file_paths (list): List of file paths to .npy files.
            labels (list): Corresponding labels for classification.
            transform (callable, optional): Optional transform to be applied on an image.
        """
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        label = self.labels[idx]

        img = np.load(self.file_paths[idx], allow_pickle=True)
        if label == 1:
            img = img[0]
        img = Image.fromarray(np.uint8(img * 255))

        # Apply transformations if provided
        if self.transform:
            img = self.transform(img)

        return img, torch.tensor(label, dtype=torch.long)

In [3]:
train_transforms = transforms.Compose([
    # transforms.CenterCrop(100),
    transforms.Resize(150, Image.LANCZOS),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(degrees=30),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

val_transforms = transforms.Compose([
        transforms.Resize(150, Image.LANCZOS),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
])

In [4]:
dataset_root = "/kaggle/input/deeplense/SpecificTest_06_A/Dataset"
axion_files = sorted(glob(os.path.join(dataset_root, "axion", "*.npy")))
no_sub_files = sorted(glob(os.path.join(dataset_root, "no_sub", "*.npy")))
cdm_files = sorted(glob(os.path.join(dataset_root, "cdm", "*.npy")))

all_files = no_sub_files + axion_files + cdm_files
labels = [0] * len(no_sub_files) + [1] * len(axion_files) + [2] * len(cdm_files)

# First split: 90% train, 10% val (stratified)
train_files, val_files, train_labels, val_labels = train_test_split(
    all_files, labels, test_size=0.1, stratify=labels, random_state=42
)

# Train MAE only on no_sub_train_files
batch_size=512
train_dataset = NPYClassificationDataset(train_files, train_labels, train_transforms)
train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=4, shuffle=True)

# Validation set for MAE (later used in classification also)
val_dataset = NPYClassificationDataset(val_files, val_labels, val_transforms)
val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=4, shuffle=False)

## Model Architecture

In [5]:
def print_trainable_parameters(model):
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}")

In [6]:
class EncoderViT(nn.Module):
    def __init__(self, base="tiny", p=0.25):
        super(EncoderViT, self).__init__()

        modelss = create_model(f"vit_{base}_patch16_224", pretrained=True)
        modelss.patch_embed = nn.Identity()
        modelss.head = nn.Identity() # now output shape is embed_dim (tiny: 192, base: 768)
        self.set_dropout(modelss, p)

        # self.model.pos_embed = nn.Identity()  # Bypass position encoding in timm
        # Override `_pos_embed()` so ViT doesn’t add its own position embedding
        # def forward_pos_embed(x):
        #     return x 
        # self.model._pos_embed = forward_pos_embed

        self.encoder_blocks = modelss.blocks
        self.norm = modelss.norm

    def set_dropout(self, model, p):
        """Recursively set dropout probability in a model."""
        for name, module in model.named_modules():
            if isinstance(module, nn.Dropout):
                module.p = p

    def forward(self, x):
        for block in self.encoder_blocks:
            x = block(x)

        x = self.norm(x)

        return x

class AttentionPooling(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.query = nn.Parameter(torch.randn(embed_dim))  # Learnable query vector
        self.attn_weights = nn.Linear(embed_dim, 1)  # Linear layer to compute scores

    def forward(self, x):
        """
        x: (batch_size, num_patches, embed_dim)
        Returns: (batch_size, embed_dim) - aggregated representation
        """
        scores = self.attn_weights(x).squeeze(-1)  # (batch_size, num_patches)
        attn_weights = torch.softmax(scores, dim=1)  # Normalize
        pooled = torch.sum(attn_weights.unsqueeze(-1) * x, dim=1)
        return pooled

# Define the Encoder (ViT model)
class ClassifierViT(nn.Module):
    def __init__(self, base="tiny", input_dim=256, embed_dim=192, num_patches=196, p=0.25, num_classes=3):
        super().__init__()
        self.embed_dim = embed_dim
        self.embedInput = nn.Linear(input_dim, self.embed_dim)
        self.encoder = EncoderViT(base=base, p=p)
        self.num_patches = num_patches
        # Load saved components
        checkpoint = torch.load("/kaggle/input/specific_task_06/pytorch/default/2/encoder_embedInput.pth", weights_only=True)

        self.embedInput.load_state_dict(checkpoint["embedInput"])
        self.encoder.load_state_dict(checkpoint["encoder"])
        self.attnetion_pool = AttentionPooling(embed_dim)
        self.num_patches = num_patches
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(p=p),
            nn.Linear(128, num_classes)
        )

        self.register_buffer("full_position_encoding", self.sinusoidal_position_encoding(num_patches, self.embed_dim).unsqueeze(0))

    def forward(self, x):
        batch_size = x.shape[0]
        
        full_pos_encoding = self.full_position_encoding.expand(batch_size, -1, -1)
        x = self.embedInput(x) + full_pos_encoding # (bs, visible_patches, embed_dim)
        
        x = self.encoder(x)
        x = self.attnetion_pool(x)
        x = self.classifier(x)
        return x

    def sinusoidal_position_encoding(self, num_patches, embed_dim):
        position = torch.arange(num_patches).unsqueeze(1)  # Shape: (num_patches, 1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))

        pe = torch.zeros(num_patches, embed_dim)
        pe[:, 1::2] = torch.sin(position * div_term)
        pe[:, 0::2] = torch.cos(position * div_term)

        return pe  # Shape: (num_patches, embed_dim)

## Training

In [7]:
torch.cuda.empty_cache()

In [8]:
def image_to_patches(img_tensor, patch_size=16):
    _, C, H, W = img_tensor.shape
    num_patches = (H // patch_size) * (W // patch_size)

    # Split into patches
    img_patches = img_tensor.unfold(2, patch_size, patch_size) # [1, 1, 224, 224] -> [1, 1, 14, 224, 16]
    img_patches = img_patches.unfold(3, patch_size, patch_size)  # [1, 1, 14, 224, 16] -> [1, 1, 14, 14, 16, 16]
    img_patches = img_patches.contiguous().view(-1, num_patches, patch_size * patch_size)  # Flatten patches [1, 1, 14, 14, 16, 16] -> [1, 196, 256]

    return img_patches

In [None]:
patch_size = 10
input_dim = patch_size**2
num_patches = int(150/patch_size)**2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ClassifierViT(base="tiny", embed_dim = 192, input_dim=input_dim, num_patches=num_patches)
model = nn.DataParallel(model.to(device))
print("masked patches: ", int(0.75*225))
print("visible patches: ", num_patches - int(0.75*225))
print_trainable_parameters(model)

# Optimizer & Loss
optimizer = optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-6) #  , weight_decay=2e-4 , weight_decay=1e-4
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
criterion = nn.CrossEntropyLoss()

# Track metrics
train_losses, val_losses = [], []
train_accuracies, train_aucs = [], []
val_accuracies, val_aucs = [], []

# Training Loop
num_epochs = 100
best_val_auc = 0.0

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    total, correct = 0, 0
    all_probs = []
    all_labels = []

    for images, batch_labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
        images = images.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        images = image_to_patches(images, patch_size)
        outputs = model(images)
        loss = criterion(outputs, batch_labels)

        loss.backward()
        optimizer.step()

        # Statistics
        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == batch_labels).sum().item()
        total += batch_labels.size(0)
        probs = torch.softmax(outputs, dim=1) # [:, 1]  # Probability for class 1
        all_probs.extend(probs.cpu().detach().numpy())
        all_labels.extend(batch_labels.cpu().numpy())


    train_acc = correct / total
    train_loss = running_loss / len(train_loader)

    try:
        train_auc = roc_auc_score(all_labels, all_probs, multi_class='ovr')
    except ValueError:
        train_auc = float('nan')

    train_losses.append(train_loss)
    train_accuracies.append(train_acc)
    train_aucs.append(train_auc)

    # Validation Step
    model.eval()
    val_correct, val_total = 0, 0
    val_loss = 0.0
    all_probs_test = []
    all_labels_test = []

    with torch.no_grad():
        for batch_data, batch_labels in tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
            batch_data, batch_labels = batch_data.to(device), batch_labels.to(device)
            batch_data = image_to_patches(batch_data, patch_size)

            outputs = model(batch_data)
            loss = criterion(outputs, batch_labels)

            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            val_correct += (predicted == batch_labels).sum().item()
            val_total += batch_labels.size(0)

            # Collect probabilities and labels for ROC/AUC
            probs = torch.softmax(outputs, dim=1) # [:, 1]  # Probability for class 1
            all_probs_test.extend(probs.cpu().detach().numpy())
            all_labels_test.extend(batch_labels.cpu().numpy())
    
    val_acc = val_correct / val_total
    val_loss = val_loss / len(val_loader)
    scheduler.step(val_loss)

    try:
        val_auc = roc_auc_score(all_labels_test, all_probs_test, multi_class='ovr')
    except ValueError:
        val_auc = float('nan') 

    val_losses.append(val_loss)
    val_accuracies.append(val_acc)
    val_aucs.append(val_auc)

    print(f"Epoch [{epoch+1}/{num_epochs}] | Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f},  Train AUC: {train_auc:.4f} | Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val AUC: {val_auc:.4f}")

    # Save Best Model
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        torch.save(model.state_dict(), "best_fine_tuned_vit_model_clas.pth")
        print("Model Saved (Best Validation AUC)")


masked patches:  168
visible patches:  57
trainable params: 5383876 || all params: 5383876 || trainable%: 100.0


  return F.linear(input, self.weight, self.bias)
Epoch 1/100: 100%|██████████| 157/157 [03:45<00:00,  1.44s/it]
Epoch 1/100: 100%|██████████| 18/18 [00:09<00:00,  2.00it/s]


Epoch [1/100] | Train Loss: 0.9760, Train Acc: 0.4739,  Train AUC: 0.6808 | Val Loss: 3.9772, Val Acc: 0.3976, Val AUC: 0.7928
Model Saved (Best Validation AUC)


Epoch 2/100: 100%|██████████| 157/157 [03:48<00:00,  1.46s/it]
Epoch 2/100: 100%|██████████| 18/18 [00:09<00:00,  1.99it/s]


Epoch [2/100] | Train Loss: 0.7118, Train Acc: 0.6522,  Train AUC: 0.8295 | Val Loss: 0.9011, Val Acc: 0.5792, Val AUC: 0.9393
Model Saved (Best Validation AUC)


Epoch 3/100: 100%|██████████| 157/157 [03:44<00:00,  1.43s/it]
Epoch 3/100: 100%|██████████| 18/18 [00:09<00:00,  1.96it/s]


Epoch [3/100] | Train Loss: 0.5121, Train Acc: 0.7849,  Train AUC: 0.9208 | Val Loss: 3.0710, Val Acc: 0.5419, Val AUC: 0.9100


Epoch 4/100:  90%|████████▉ | 141/157 [03:22<00:22,  1.42s/it]

## Plottings

In [None]:
sns.set(style="whitegrid", font_scale=1.2)
epochs = range(1, num_epochs + 1)
save_dir = "/kaggle/working/"

# Plot Loss
plt.figure(figsize=(10, 6))
plt.plot(epochs, train_losses, label='Train Loss', color='blue')
plt.plot(epochs, val_losses, label='Val Loss', color='red')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Loss over Epochs')
plt.legend()
plt.savefig(os.path.join(save_dir, 'Losses.png'))
plt.show()

# Plot Accuracy
plt.figure(figsize=(10, 6))
plt.plot(epochs, train_accuracies, label='Train Accuracy', color='blue')
plt.plot(epochs, val_accuracies, label='Val Accuracy', color='red')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Accuracy over Epochs')
plt.legend()
plt.savefig(os.path.join(save_dir, 'Accuracies.png'))
plt.show()

# Plot AUC
plt.figure(figsize=(10, 6))
plt.plot(epochs, train_aucs, label='Train AUC', color='blue')
plt.plot(epochs, val_aucs, label='Val AUC', color='red')
plt.xlabel('Epoch')
plt.ylabel('AUC')
plt.title('AUC over Epochs')
plt.legend()
plt.savefig(os.path.join(save_dir, 'AUC.png'))
plt.show()

## Evaluation

In [None]:
state_dict = torch.load("/kaggle/working/best_fine_tuned_vit_model_clas.pth", map_location=device, weights_only=True)
model.load_state_dict(state_dict)

all_probs_test = []
all_labels_test = []
val_correct, val_total = 0, 0

with torch.no_grad():
    for batch_data, batch_labels in val_loader:
        batch_data, batch_labels = batch_data.to(device), batch_labels.to(device)
        batch_data = image_to_patches(batch_data, patch_size)

        outputs = model(batch_data)

        _, predicted = torch.max(outputs, 1)
        val_correct += (predicted == batch_labels).sum().item()
        val_total += batch_labels.size(0)

        probs = torch.softmax(outputs, dim=1)
        all_probs_test.extend(probs.cpu().detach().numpy())
        all_labels_test.extend(batch_labels.cpu().numpy())

val_acc = val_correct / val_total
print(f"Accuracy: {(val_acc*100):.2f}%")

# Step 1: Binarize the labels
all_labels_test = np.array(all_labels_test)
all_probs_test = np.array(all_probs_test)

n_classes = len(np.unique(all_labels_test))
all_labels_test_bin = label_binarize(all_labels_test, classes=np.arange(n_classes))

# Step 2: Compute ROC curve and AUC for each class
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(all_labels_test_bin[:, i], all_probs_test[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Step 3: Compute micro-average ROC curve and AUC
fpr["micro"], tpr["micro"], _ = roc_curve(all_labels_test_bin.ravel(), all_probs_test.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

# Step 4: Plot ROC curves
label_map = {0:"no", 1:"sphere", 2: "vort"}
plt.figure(figsize=(8, 6))
colors = cycle(["aqua", "darkorange", "cornflowerblue"])
for i, color in zip(range(n_classes), colors):
    plt.plot(fpr[i], tpr[i], color=color, lw=2,
             label=f"ROC curve of class {label_map[i]} (AUC = {roc_auc[i]:.2f})")

plt.plot(fpr["micro"], tpr["micro"], color="deeppink", linestyle=":", linewidth=4,
         label=f"Micro-average ROC curve (AUC = {roc_auc['micro']:.2f})")

plt.plot([0, 1], [0, 1], "k--", lw=2)
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Multiclass ROC Curve")
plt.legend(loc="lower right")
plt.savefig(os.path.join(save_dir, 'ROC_curve.png'))
plt.show()

# Step 5: Compute AUC using roc_auc_score with multi_class='ovr'
val_auc = roc_auc_score(all_labels_test, all_probs_test, multi_class='ovr')
print(f"AUC (One-vs-Rest, macro-average): {val_auc:.2f}")

# Optional: Compute AUC with different averaging methods
val_auc_micro = roc_auc_score(all_labels_test, all_probs_test, multi_class='ovr', average='micro')
val_auc_weighted = roc_auc_score(all_labels_test, all_probs_test, multi_class='ovr', average='weighted')
print(f"AUC (One-vs-Rest, micro-average): {val_auc_micro:.2f}")
print(f"AUC (One-vs-Rest, weighted-average): {val_auc_weighted:.2f}")