In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from torch.utils.data import Dataset, DataLoader
from medmnist import INFO, ChestMNIST

In [13]:
from sklearn.metrics import f1_score, roc_auc_score, precision_score, recall_score
import numpy as np
import random

In [14]:
# --- Global configuration shared by every experiment in this notebook ---
# Episode sampling (random.sample) and weight initialisation are stochastic, so
# seed every RNG up front to make Restart & Run All repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
threshold = 0.25  # consistent threshold across experiments
lr = 1e-3  # learning rate passed to Adam when training the prototypical network

In [16]:
import torchvision.transforms as transforms
from medmnist import INFO, ChestMNIST

# Look up the ChestMNIST entry in MedMNIST's INFO table and alias the dataset class.
# `info` is not referenced again in the cells visible in this notebook.
info = INFO['chestmnist']
DataClass = ChestMNIST

# Define standard transform (as per MedMNIST recommendation):
# convert to a tensor, then normalise the single channel with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5], std=[.5])
])

# Build the train and test splits, both using the same transform.
# download=True is passed so the data can be fetched if needed — presumably a no-op when already cached.
train_dataset = DataClass(split='train', download=True, transform=transform)
test_dataset  = DataClass(split='test',  download=True, transform=transform)


In [17]:
# Sanity check: report how many samples each loaded split contains
print(f"Train set size: {len(train_dataset)}")
print(f"Test set size: {len(test_dataset)}")

Train set size: 78468
Test set size: 22433


In [18]:
from collections import defaultdict

# Cache of {id(dataset): (dataset, {class index: [sample indices]})}.
# The dataset object itself is stored next to its index so that an id() can never
# be recycled by a different object while the entry is alive.
_EPISODE_INDEX_CACHE = {}


def _class_index_for(dataset):
    """Return {class index: [indices of samples whose label for that class is 1]}.

    The dataset is scanned only the first time it is seen; later calls reuse the
    cached mapping. The cache assumes the dataset's labels do not change between calls.
    """
    cached = _EPISODE_INDEX_CACHE.get(id(dataset))
    if cached is not None and cached[0] is dataset:
        return cached[1]

    class_to_indices = defaultdict(list)
    for idx, (_, labels) in enumerate(dataset):
        for cls in range(labels.shape[0]):
            if labels[cls] == 1:
                class_to_indices[cls].append(idx)

    # Store a plain dict so later lookups cannot silently create empty classes.
    class_to_indices = dict(class_to_indices)
    _EPISODE_INDEX_CACHE[id(dataset)] = (dataset, class_to_indices)
    return class_to_indices


def create_episode_multilabel(dataset, n_way=5, k_shot=5, q_query=15):
    """Sample one few-shot episode from a multi-label dataset.

    Args:
        dataset: indexable dataset yielding ``(image_tensor, label_vector)`` pairs, where
            ``label_vector`` has a ``.shape`` and supports indexing with a list of class ids.
        n_way: number of classes drawn for the episode.
        k_shot: support samples drawn per selected class.
        q_query: query samples drawn per selected class.

    Returns:
        ``(support_x, support_y, query_x, query_y)`` as stacked tensors. The ``*_y`` tensors
        hold, for every sample, its labels restricted to the selected classes (in the order
        the classes were selected), so they have ``n_way`` columns.

    Raises:
        ValueError: if fewer than ``n_way`` classes have at least ``k_shot + q_query`` samples.

    Note:
        Samples are drawn independently per class, so a multi-label image can be drawn for
        more than one selected class within the same episode.
    """
    # Scanning the dataset (and running its transforms) is by far the most expensive
    # step, so it is done once per dataset rather than once per episode.
    class_to_indices = _class_index_for(dataset)
    per_class = k_shot + q_query

    # Only classes with enough samples can supply a full support + query set.
    eligible = [cls for cls, members in class_to_indices.items() if len(members) >= per_class]
    if len(eligible) < n_way:
        raise ValueError(
            f"Need {n_way} classes with at least {per_class} samples each, "
            f"but only {len(eligible)} qualify."
        )
    selected_classes = random.sample(eligible, n_way)

    support_x, support_y, query_x, query_y = [], [], [], []

    for cls in selected_classes:
        indices = random.sample(class_to_indices[cls], per_class)
        for position, sample_idx in enumerate(indices):
            img, labels = dataset[sample_idx]
            target = torch.tensor(labels[selected_classes])
            # The first k_shot draws form the support set, the rest the query set.
            if position < k_shot:
                support_x.append(img)
                support_y.append(target)
            else:
                query_x.append(img)
                query_y.append(target)

    return (torch.stack(support_x), torch.stack(support_y),
            torch.stack(query_x), torch.stack(query_y))

In [19]:
class ProtoNet(nn.Module):
    """Small CNN encoder that maps a single-channel image to a 128-dimensional embedding.

    Two conv/ReLU/max-pool stages are followed by a flatten and a linear layer.
    The linear layer expects 64 * 7 * 7 input features, which corresponds to a
    28x28 input after the two 2x2 poolings.
    """

    def __init__(self):
        super().__init__()
        first_stage = [nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)]
        second_stage = [nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)]
        projection = [nn.Flatten(), nn.Linear(64 * 7 * 7, 128), nn.ReLU()]
        # Layers are unpacked in the same order, so Sequential indices stay 0..8.
        self.encoder = nn.Sequential(*first_stage, *second_stage, *projection)

    def forward(self, x):
        """Return the embedding produced by the encoder for batch ``x``."""
        embedding = self.encoder(x)
        return embedding

In [20]:
def train_protonet(model, dataset, n_way=5, k_shot=5, q_query=15, num_episodes=300):
    """Episodically train ``model`` as a multi-label prototypical network.

    Each episode samples support/query sets with ``create_episode_multilabel``, builds one
    prototype per selected class from the support embeddings, and minimises binary
    cross-entropy between the query labels and the negated query-to-prototype distances.

    Args:
        model: embedding network; moved to ``device`` and put in training mode.
        dataset: dataset handed to ``create_episode_multilabel``.
        n_way, k_shot, q_query: episode shape, forwarded to the episode sampler.
        num_episodes: number of optimisation steps (one per episode).

    Uses the notebook-level ``device`` and ``lr``. Prints the loss every 50 episodes.
    """
    model.to(device)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    def class_prototype(embeddings, targets, column):
        # Mean embedding of the support samples labelled positive for this class;
        # a zero vector stands in when no support sample carries the label.
        members = (targets[:, column] == 1)
        if members.sum() == 0:
            return torch.zeros_like(embeddings[0])
        return embeddings[members].mean(dim=0)

    for episode in range(num_episodes):
        episode_tensors = create_episode_multilabel(dataset, n_way, k_shot, q_query)
        support_x, support_y, query_x, query_y = (t.to(device) for t in episode_tensors)

        support_embeddings = model(support_x)
        query_embeddings = model(query_x)

        prototypes = torch.stack(
            [class_prototype(support_embeddings, support_y, column) for column in range(n_way)]
        )

        # Distances are negated so that a closer prototype gives a larger logit.
        logits = -torch.cdist(query_embeddings, prototypes)
        loss = F.binary_cross_entropy_with_logits(logits, query_y.float())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if episode % 50 == 0:
            print(f"[{episode}/{num_episodes}] Loss: {loss.item():.4f}")

In [21]:
def evaluate_selected_metrics(model, dataset, n_way=5, k_shot=5, q_query=15, test_episodes=300, threshold=0.25):
    """Evaluate a prototypical network on randomly sampled multi-label episodes.

    For every episode, prototypes are built from the support embeddings, query
    probabilities are ``sigmoid(-distance to prototype)``, and predictions are
    ``probability > threshold``. Micro-averaged F1, precision, recall and AUROC are
    computed per episode and then averaged over episodes.

    Args:
        model: embedding network; moved to ``device`` and put in eval mode.
        dataset: dataset handed to ``create_episode_multilabel``.
        n_way, k_shot, q_query: episode shape, forwarded to the episode sampler.
        test_episodes: number of episodes to sample.
        threshold: probability cut-off for a positive prediction. Because the logits are
            negated distances (never positive), probabilities never exceed 0.5.

    Returns:
        dict with keys ``avg_f1_score``, ``avg_auroc``, ``avg_precision`` and
        ``avg_recall``. A value is NaN when no episode contributed to that metric.

    Episodes whose query targets contain no positive label are skipped entirely.
    Episodes for which AUROC is undefined are left out of the AUROC average instead
    of being counted as 0, which would drag the mean down.
    """
    # Inputs are moved to `device` below, so the model has to live there as well.
    model.to(device)
    model.eval()
    f1_scores, aurocs, precisions, recalls = [], [], [], []

    with torch.no_grad():
        for _ in range(test_episodes):
            support_x, support_y, query_x, query_y = create_episode_multilabel(dataset, n_way, k_shot, q_query)
            support_x, support_y = support_x.to(device), support_y.to(device)
            query_x, query_y = query_x.to(device), query_y.to(device)

            support_embeddings = model(support_x)
            query_embeddings = model(query_x)

            prototypes = []
            for i in range(n_way):
                cls_mask = (support_y[:, i] == 1)
                if cls_mask.sum() == 0:
                    # No support sample carries this label: fall back to a zero prototype.
                    prototypes.append(torch.zeros_like(support_embeddings[0]))
                else:
                    prototypes.append(support_embeddings[cls_mask].mean(dim=0))
            prototypes = torch.stack(prototypes)

            logits = -torch.cdist(query_embeddings, prototypes)
            probs = torch.sigmoid(logits)
            preds = (probs > threshold).float()

            preds_np = preds.cpu().numpy()
            probs_np = probs.cpu().numpy()
            targets_np = query_y.cpu().numpy()

            if np.sum(targets_np) == 0:
                continue

            f1_scores.append(f1_score(targets_np, preds_np, average='micro', zero_division=0))
            precisions.append(precision_score(targets_np, preds_np, average='micro', zero_division=0))
            recalls.append(recall_score(targets_np, preds_np, average='micro', zero_division=0))

            # AUROC cannot be scored when the targets contain a single class. Catch only
            # that failure (not every exception) and treat a NaN result the same way.
            try:
                auroc = roc_auc_score(targets_np, probs_np, average='micro')
            except ValueError:
                auroc = np.nan
            if not np.isnan(auroc):
                aurocs.append(auroc)

    def _mean_or_nan(values):
        # np.mean([]) emits a RuntimeWarning; return NaN quietly instead.
        return np.mean(values) if values else np.nan

    return {
        'avg_f1_score': _mean_or_nan(f1_scores),
        'avg_auroc': _mean_or_nan(aurocs),
        'avg_precision': _mean_or_nan(precisions),
        'avg_recall': _mean_or_nan(recalls)
    }

def print_selected_metrics(name, metrics):
    """Print the averaged F1, AUROC, precision and recall under a banner for ``name``.

    ``metrics`` is the dict returned by ``evaluate_selected_metrics``; each value is
    shown with four decimals, followed by a 30-character ``=`` rule.
    """
    def report_lines():
        # Lines are produced lazily, one per print call, in display order.
        yield f"\n==== {name} Metrics ===="
        yield f"F1-Score   : {metrics['avg_f1_score']:.4f}"
        yield f"AUROC      : {metrics['avg_auroc']:.4f}"
        yield f"Precision  : {metrics['avg_precision']:.4f}"
        yield f"Recall     : {metrics['avg_recall']:.4f}"
        yield "="*30

    for line in report_lines():
        print(line)

In [22]:
# One-Shot (k=1): train a fresh ProtoNet on 5-way episodes with a single support
# sample per class for 300 episodes (q_query is left at the function's default).
model_1shot = ProtoNet()
train_protonet(model_1shot, train_dataset, n_way=5, k_shot=1, num_episodes=300)

[0/300] Loss: 0.6342
[50/300] Loss: 0.5727
[100/300] Loss: 0.6049
[150/300] Loss: 0.6242
[200/300] Loss: 0.5702
[250/300] Loss: 0.5445


In [24]:
# Few-Shot (k=5): train a second, independently initialised ProtoNet on 5-way episodes
# with five support samples per class for 300 episodes (q_query left at its default).
model_5shot = ProtoNet()
train_protonet(model_5shot, train_dataset, n_way=5, k_shot=5, num_episodes=300)

[0/300] Loss: 0.6509
[50/300] Loss: 0.5902
[100/300] Loss: 0.5378
[150/300] Loss: 0.5417
[200/300] Loss: 0.6053
[250/300] Loss: 0.5482


In [27]:
# One-Shot (k=1) test: 300 episodes from the test split, scored with the shared
# `threshold` from the configuration cell so every experiment uses the same cut-off.
metrics_1shot = evaluate_selected_metrics(model_1shot, test_dataset, n_way=5, k_shot=1, test_episodes=300, threshold=threshold)
print_selected_metrics("One-Shot", metrics_1shot)

# Few-Shot (k=5) test: same protocol with five support samples per class.
metrics_5shot = evaluate_selected_metrics(model_5shot, test_dataset, n_way=5, k_shot=5, test_episodes=300, threshold=threshold)
print_selected_metrics("Few-Shot 5-Shot", metrics_5shot)


==== One-Shot Metrics ====
F1-Score   : 0.3780
AUROC      : 0.5454
Precision  : 0.2877
Recall     : 0.5630

==== Few-Shot 5-Shot Metrics ====
F1-Score   : 0.4172
AUROC      : 0.5800
Precision  : 0.3001
Recall     : 0.6902


### Self-Supervised Pretraining using RotNet

In [28]:
from torchvision.transforms import functional as TF

# Self-supervised dataset class for RotNet task
class RotNetDataset(torch.utils.data.Dataset):
    """Wrap a dataset so each image appears four times, once per rotation angle.

    Item ``idx`` is base image ``idx // 4`` rotated by ``rotations[idx % 4]`` degrees.
    The returned label is the rotation index (0-3); the base dataset's own label
    is discarded.
    """

    def __init__(self, base_dataset):
        self.rotations = [0, 90, 180, 270]
        self.base_dataset = base_dataset

    def __len__(self):
        base_size = len(self.base_dataset)
        return base_size * 4

    def __getitem__(self, idx):
        # Quotient selects the base image, remainder selects the rotation.
        base_idx, rot_label = divmod(idx, 4)
        img, _ = self.base_dataset[base_idx]  # ignore true label
        angle = self.rotations[rot_label]
        rotated_img = TF.rotate(img, angle)
        return rotated_img, rot_label

In [29]:
class RotNet(nn.Module):
    """CNN encoder plus a linear head that predicts one of four rotation classes.

    ``encoder`` has the same layer layout as ``ProtoNet.encoder`` (two conv/ReLU/pool
    stages, flatten, linear to 128 features), so it can be reused as an embedding
    network after pretraining. ``classifier`` maps the 128 features to 4 logits.
    """

    def __init__(self):
        super().__init__()
        first_stage = [nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)]
        second_stage = [nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)]
        projection = [nn.Flatten(), nn.Linear(64 * 7 * 7, 128), nn.ReLU()]
        self.encoder = nn.Sequential(*first_stage, *second_stage, *projection)
        self.classifier = nn.Linear(128, 4)  # 4 rotation classes

    def forward(self, x):
        """Return rotation logits for batch ``x``."""
        return self.classifier(self.encoder(x))

In [30]:
# Build the RotNet pretext dataset on top of the ChestMNIST train split
# (each base image yields four rotated samples) and batch it in shuffled groups of 64.
rotnet_train = RotNetDataset(train_dataset)
rotnet_loader = DataLoader(rotnet_train, batch_size=64, shuffle=True)

def train_rotnet(model, dataloader, num_epochs=10):
    """Train ``model`` on the rotation-prediction task with Adam and cross-entropy.

    Args:
        model: network returning rotation logits; moved to the notebook-level ``device``.
        dataloader: iterable of ``(images, rotation_labels)`` batches.
        num_epochs: number of full passes over ``dataloader``.

    After each epoch the loss printed is the SUM of the per-batch losses for that
    epoch, not their mean.
    """
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for images, rotation_labels in dataloader:
            images = images.to(device)
            rotation_labels = rotation_labels.to(device)

            loss = criterion(model(images), rotation_labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
        print(f"[RotNet Epoch {epoch+1}] Loss: {total_loss:.4f}")

In [31]:
# Train RotNet on the rotation pretext task for 10 epochs
rotnet_model = RotNet()
train_rotnet(rotnet_model, rotnet_loader, num_epochs=10)

# Keep a handle on the pretrained encoder for the few-shot model below.
# This is a reference to rotnet_model.encoder, not a copy.
pretrained_encoder = rotnet_model.encoder

[RotNet Epoch 1] Loss: 70.6259
[RotNet Epoch 2] Loss: 33.4759
[RotNet Epoch 3] Loss: 27.9486
[RotNet Epoch 4] Loss: 24.1288
[RotNet Epoch 5] Loss: 21.3974
[RotNet Epoch 6] Loss: 19.1500
[RotNet Epoch 7] Loss: 16.2586
[RotNet Epoch 8] Loss: 14.2985
[RotNet Epoch 9] Loss: 12.6736
[RotNet Epoch 10] Loss: 10.7863


In [32]:
class ProtoNetWithRotNet(nn.Module):
    """Prototypical-network embedding model initialised from a pretrained encoder.

    Args:
        pretrained_encoder: module mapping an image batch to embeddings. It is
            deep-copied, so fine-tuning this model leaves the module passed in
            untouched and re-running the construction cell always starts from the
            same pretrained weights.
    """

    def __init__(self, pretrained_encoder):
        super().__init__()
        import copy  # local import keeps the cell runnable on its own

        # Own a private copy instead of aliasing the caller's module.
        self.encoder = copy.deepcopy(pretrained_encoder)

    def forward(self, x):
        """Return the encoder's embedding for batch ``x``."""
        return self.encoder(x)

In [33]:
def train_protonet(model, dataset, n_way=5, k_shot=5, q_query=15, num_episodes=300):
    """Episodically train ``model`` as a multi-label prototypical network.

    NOTE: this cell redefines ``train_protonet``, which is also defined in an earlier
    cell; once this cell has run, later calls resolve to this version. It differs from
    the earlier one only in the progress-line format ("[Episode N] Loss: ...").

    Args:
        model: embedding network; moved to ``device`` and put in training mode.
        dataset: dataset handed to ``create_episode_multilabel``.
        n_way, k_shot, q_query: episode shape, forwarded to the episode sampler.
        num_episodes: number of optimisation steps (one per episode).

    Uses the notebook-level ``device`` and ``lr``. Prints the loss every 50 episodes.
    """
    model.to(device)
    model.train()
    # Use the shared learning-rate constant rather than a second hard-coded 1e-3,
    # so both training paths follow the configuration cell.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Built once here instead of being re-instantiated on every episode.
    criterion = nn.BCEWithLogitsLoss()

    for episode in range(num_episodes):
        support_x, support_y, query_x, query_y = create_episode_multilabel(dataset, n_way, k_shot, q_query)
        support_x, support_y = support_x.to(device), support_y.to(device)
        query_x, query_y = query_x.to(device), query_y.to(device)

        support_embeddings = model(support_x)
        query_embeddings = model(query_x)

        prototypes = []
        for i in range(n_way):
            cls_mask = (support_y[:, i] == 1)
            if cls_mask.sum() == 0:
                # No support sample carries this label: fall back to a zero prototype.
                prototypes.append(torch.zeros_like(support_embeddings[0]))
            else:
                prototypes.append(support_embeddings[cls_mask].mean(dim=0))
        prototypes = torch.stack(prototypes)

        # Distances are negated so that a closer prototype gives a larger logit.
        logits = -torch.cdist(query_embeddings, prototypes)
        loss = criterion(logits, query_y.float())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if episode % 50 == 0:
            print(f"[Episode {episode}] Loss: {loss.item():.4f}")


In [34]:
# Build the few-shot model around the RotNet-pretrained encoder and place it on `device`
model_rotnet_proto = ProtoNetWithRotNet(pretrained_encoder=pretrained_encoder).to(device)

In [35]:
# Fine-tune the RotNet-initialised model episodically on the train split:
# 5-way, 5-shot, 15 queries per class, 300 episodes
train_protonet(model_rotnet_proto, train_dataset, n_way=5, k_shot=5, q_query=15, num_episodes=300)

[Episode 0] Loss: 2.5259
[Episode 50] Loss: 0.5617
[Episode 100] Loss: 0.6101
[Episode 150] Loss: 0.5767
[Episode 200] Loss: 0.6191
[Episode 250] Loss: 0.6367


In [36]:
# Evaluate the fine-tuned model on 300 test episodes, using the shared `threshold`
# from the configuration cell rather than a repeated literal.
metrics_rotnet_proto = evaluate_selected_metrics(model_rotnet_proto, test_dataset, n_way=5, k_shot=5, test_episodes=300, threshold=threshold)

In [37]:
# Report the averaged metrics for the RotNet-initialised 5-shot model
print_selected_metrics("Few-Shot + RotNet", metrics_rotnet_proto)


==== Few-Shot + RotNet Metrics ====
F1-Score   : 0.4031
AUROC      : 0.5419
Precision  : 0.2818
Recall     : 0.7128
