In [None]:
class MAML:
    """Model-Agnostic Meta-Learning trainer around a single ``nn.Module``.

    ``inner_loop`` adapts a functional copy of the model's parameters on a
    task's support set with plain gradient descent. ``outer_loop`` evaluates
    the adapted parameters on each task's query set, averages the losses and
    back-propagates through the adaptation into the original parameters,
    which are then updated by an Adam meta-optimizer.
    """

    def __init__(self, model, inner_lr=0.01, outer_lr=0.001, inner_steps=1):
        """Store the model and the meta-learning hyper-parameters.

        Args:
            model: Module to meta-train; it is moved to the global ``device``.
            inner_lr: Step size of the per-task gradient-descent update.
            outer_lr: Learning rate of the Adam meta-optimizer.
            inner_steps: Number of gradient steps taken per task adaptation.
        """
        self.model = model.to(device)
        self.meta_optimizer = torch.optim.Adam(self.model.parameters(), lr=outer_lr)
        self.inner_lr = inner_lr
        self.inner_steps = inner_steps
        self.loss_fn = nn.BCEWithLogitsLoss()

    def clone_model(self):
        """Return an independent deep copy of the wrapped model."""
        return copy.deepcopy(self.model)

    def inner_loop(self, support_images, support_labels, create_graph=True):
        """Adapt the model parameters to one task and return them as a dict.

        The module itself is never modified: updates are applied to a
        ``{name: tensor}`` mapping that is fed back through
        ``functional_call``.

        Args:
            support_images: Input batch of the task's support set.
            support_labels: Targets passed to ``self.loss_fn`` together with
                the model output for ``support_images``.
            create_graph: Forwarded to ``torch.autograd.grad``. ``True`` (the
                default) keeps the gradient computation differentiable so the
                outer loop can back-propagate through the update. ``False``
                treats the gradients as constants, which saves memory when
                the result is only used for evaluation.

        Returns:
            dict mapping every parameter name to its adapted tensor.
            Parameters with ``requires_grad=False`` are returned unchanged.
        """
        # NOTE(review): this switches the module to training mode even when
        # the caller selected eval() beforehand; confirm that is intended for
        # validation/evaluation callers.
        self.model.train()
        params = dict(self.model.named_parameters())  # initial params

        for _ in range(self.inner_steps):
            # autograd.grad raises if any requested input does not require
            # grad, so frozen parameters are left out of the differentiation
            # and carried over untouched.
            trainable = {name: p for name, p in params.items() if p.requires_grad}
            if not trainable:
                break

            preds = functional_call(self.model, params, support_images)
            loss = self.loss_fn(preds, support_labels)
            grads = torch.autograd.grad(
                loss, list(trainable.values()), create_graph=create_graph
            )

            # Differentiable update: the new tensors stay connected to the
            # original parameters through the autograd graph.
            updated = {
                name: param - self.inner_lr * grad
                for (name, param), grad in zip(trainable.items(), grads)
            }
            params = {**params, **updated}

        return params

    def outer_loop(self, tasks):
        """Run one meta-update over a batch of tasks and return the meta-loss.

        Args:
            tasks: Sized collection of
                ``((support_images, support_labels), (query_images, query_labels))``
                pairs.

        Returns:
            The mean query loss over ``tasks`` as a Python float.

        Raises:
            ValueError: If ``tasks`` is empty (there is nothing to average or
                back-propagate).
        """
        if len(tasks) == 0:
            raise ValueError("outer_loop requires at least one task")

        self.model.train()
        meta_loss = 0.0

        for support_set, query_set in tasks:
            support_images, support_labels = support_set
            query_images, query_labels = query_set

            support_images = support_images.to(device)
            support_labels = support_labels.to(device)
            query_images = query_images.to(device)
            query_labels = query_labels.to(device)

            # Get adapted parameters
            adapted_params = self.inner_loop(support_images, support_labels)

            # Evaluate with adapted weights using functional_call
            query_preds = functional_call(self.model, adapted_params, query_images)
            loss = self.loss_fn(query_preds, query_labels)
            meta_loss += loss

        meta_loss /= len(tasks)
        self.meta_optimizer.zero_grad()
        meta_loss.backward()
        self.meta_optimizer.step()

        return meta_loss.item()

In [None]:
import random

def get_binary_tasks(dataset, index_map, disease, k_shot=5, q_query=15, num_tasks=4):
    """Sample balanced support/query episodes for one disease.

    For every episode, ``k_shot + q_query`` positive and the same number of
    negative indices are drawn without replacement from
    ``index_map[disease]``. The first ``k_shot`` of each class form the
    support set and the remainder form the query set.

    Args:
        dataset: Indexable collection whose items are ``(image, label)`` pairs
            of tensors.
        index_map: Mapping ``disease -> {"pos": [...], "neg": [...]}`` of
            dataset indices.
        disease: Key into ``index_map``.
        k_shot: Support samples per class.
        q_query: Query samples per class.
        num_tasks: Number of episodes to build.

    Returns:
        List of ``((support_images, support_labels), (query_images, query_labels))``
        tuples whose tensors have been moved to the global ``device``.
    """
    positives = index_map[disease]["pos"]
    negatives = index_map[disease]["neg"]
    per_class = k_shot + q_query

    def _collate(samples):
        # Stack images first, then labels, each moved to the target device.
        images = torch.stack([image for image, _ in samples]).to(device)
        labels = torch.stack([label for _, label in samples]).to(device)
        return images, labels

    episodes = []
    for _ in range(num_tasks):
        # Draw both index sets before touching the dataset.
        pos_pick = random.sample(positives, per_class)
        neg_pick = random.sample(negatives, per_class)

        pos_items = [dataset[idx] for idx in pos_pick]
        neg_items = [dataset[idx] for idx in neg_pick]

        support_pair = _collate(pos_items[:k_shot] + neg_items[:k_shot])
        query_pair = _collate(pos_items[k_shot:] + neg_items[k_shot:])
        episodes.append((support_pair, query_pair))

    return episodes


In [None]:
def build_label_index_map(dataset, disease_names):
    """Group dataset indices by disease into positive and negative lists.

    Args:
        dataset: Sized, indexable collection of ``(sample, label)`` pairs,
            where ``label[j]`` is the indicator for ``disease_names[j]``.
        disease_names: Disease names, in label-column order.

    Returns:
        ``defaultdict`` mapping each disease name to
        ``{"pos": [indices], "neg": [indices]}``. An index lands in ``"pos"``
        when the corresponding label entry equals 1, otherwise in ``"neg"``.
    """
    buckets = defaultdict(lambda: {"pos": [], "neg": []})

    for sample_idx in range(len(dataset)):
        _sample, label = dataset[sample_idx]
        for column, name in enumerate(disease_names):
            polarity = "pos" if label[column] == 1 else "neg"
            buckets[name][polarity].append(sample_idx)

    return buckets

# Index the training split once: for every disease, the dataset indices of its
# positive and negative samples. The meta-training loop samples tasks from it.
index_map = build_label_index_map(train_dataset, disease_names)

In [None]:
# training loop
# Meta-train on every disease except the held-out ones, and every 10 epochs
# report the adapted query loss on a fixed set of held-out validation tasks.
model = ResNet50_MultiLabel(15).to(device)
maml = MAML(model, inner_lr=0.01, outer_lr=0.001, inner_steps=1)

val_diseases = ["Nodule"]  # target disease
val_tasks = []

# Validation tasks are sampled once so the reported loss is comparable
# across epochs.
for disease in val_diseases:
    tasks = get_binary_tasks(test_dataset, test_index_map, disease=disease, k_shot=5, q_query=15, num_tasks=4)
    val_tasks.extend(tasks)

# Built once instead of once per validation task.
loss_fn = nn.BCEWithLogitsLoss()

start_time = time.time()
num_epochs = 100
for epoch in range(num_epochs):
    epoch_losses = []
    for disease in disease_names:
        # Leave out the target diseases; driven by val_diseases so the two
        # lists cannot drift apart.
        if disease in val_diseases:
            continue
        tasks = get_binary_tasks(train_dataset, index_map, disease=disease, k_shot=5, q_query=15, num_tasks=8)
        loss = maml.outer_loop(tasks)
        epoch_losses.append(loss)
    if epoch % 10 == 0:
        # Report the mean meta-loss over this epoch's diseases rather than
        # only the last disease's outer step.
        train_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else float("nan")

        # NOTE(review): MAML.inner_loop calls self.model.train(), so the
        # adaptation and query pass below run in training mode despite this
        # eval() call. Confirm which mode is intended for validation.
        maml.model.eval()
        val_losses = []
        for (support_images, support_labels), (query_images, query_labels) in val_tasks:
            adapted_params = maml.inner_loop(support_images.to(device), support_labels.to(device))
            with torch.no_grad():
                preds = functional_call(maml.model, adapted_params, query_images.to(device))
                val_loss = loss_fn(preds, query_labels.to(device))
                val_losses.append(val_loss.item())

        avg_val_loss = sum(val_losses) / len(val_losses)
        elapsed = (time.time() - start_time) / 60
        print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Val Loss = {avg_val_loss:.4f}, Time: {elapsed:.2f} min")
        maml.model.train()


torch.save(maml.model.state_dict(), "maml_resnet.pth")
print("✅ MAML model saved as maml_resnet.pth")

In [None]:
def evaluate_few_shot(maml, dataset, index_map, disease, disease_idx, k_shot=5, q_query=15, num_tasks=10):
    """Measure few-shot performance of an adapted model on one disease.

    For each of ``num_tasks`` episodes a balanced support/query split is
    sampled, the model is adapted on the support set with
    ``maml.inner_loop``, and the adapted parameters are scored on the query
    set for label column ``disease_idx``. Accuracy (mean over episodes), AUC
    and F1 (over all pooled query predictions) are printed.

    Args:
        maml: Object exposing ``inner_loop(images, labels)`` and ``model``.
        dataset: Indexable collection of ``(image, label)`` tensor pairs.
        index_map: Mapping ``disease -> {"pos": [...], "neg": [...]}`` of
            dataset indices.
        disease: Key into ``index_map``; also used in the printed report.
        disease_idx: Column of the label/prediction tensors to evaluate.
        k_shot: Support samples per class.
        q_query: Query samples per class.
        num_tasks: Number of episodes to average over; must be at least 1.

    Returns:
        Mean query accuracy over all episodes.

    Raises:
        ValueError: If ``num_tasks`` is smaller than 1, or (from
            ``random.sample``) if a class has fewer than ``k_shot + q_query``
            indices.
    """
    if num_tasks < 1:
        # Without any episode the averages below would divide by zero.
        raise ValueError("num_tasks must be at least 1")

    pos_indices = index_map[disease]["pos"]
    neg_indices = index_map[disease]["neg"]
    per_class = k_shot + q_query

    accs = []
    all_true = []
    all_probs = []
    all_preds = []

    for _ in range(num_tasks):
        pos_batch_idxs = random.sample(pos_indices, per_class)
        neg_batch_idxs = random.sample(neg_indices, per_class)

        support_idxs = pos_batch_idxs[:k_shot] + neg_batch_idxs[:k_shot]
        query_idxs = pos_batch_idxs[k_shot:] + neg_batch_idxs[k_shot:]

        support_samples = [dataset[i] for i in support_idxs]
        query_samples = [dataset[i] for i in query_idxs]

        support_images = torch.stack([x for x, _ in support_samples]).to(device)
        support_labels = torch.stack([y for _, y in support_samples]).to(device)
        query_images = torch.stack([x for x, _ in query_samples]).to(device)
        query_labels = torch.stack([y for _, y in query_samples]).to(device)

        # Adaptation needs gradients; scoring the query set does not, so the
        # query forward pass is kept out of the autograd graph.
        adapted_params = maml.inner_loop(support_images, support_labels)
        with torch.no_grad():
            preds = functional_call(maml.model, adapted_params, query_images)
            probs = torch.sigmoid(preds[:, disease_idx])
        preds_binary = (probs > 0.5).float()
        true_binary = query_labels[:, disease_idx]

        acc = (preds_binary == true_binary).sum().item() / len(preds_binary)
        accs.append(acc)

        all_true.extend(true_binary.cpu().numpy())
        all_probs.extend(probs.cpu().numpy())
        all_preds.extend(preds_binary.cpu().numpy())

    final_acc = sum(accs) / len(accs)
    print(f"[MAML] Few-shot Accuracy on '{disease}': {final_acc:.2f}")

    try:
        auc = roc_auc_score(all_true, all_probs)
        print(f"[MAML] Few-shot AUC on '{disease}': {auc:.4f}")
    except ValueError:
        print(f"[MAML] Few-shot AUC on '{disease}': N/A (only one class present)")

    f1 = f1_score(all_true, all_preds, zero_division=0)
    print(f"[MAML] Few-shot F1 Score on '{disease}': {f1:.4f}")

    return final_acc

In [None]:
# Run the few-shot evaluation once per disease label, with the same episode
# settings for every target.
few_shot_settings = {"k_shot": 5, "q_query": 15, "num_tasks": 10}

for target_disease in disease_names:
    # The label column is looked up by name.
    target_disease_idx = disease_names.index(target_disease)
    evaluate_few_shot(
        maml,
        test_dataset,
        test_index_map,
        disease=target_disease,
        disease_idx=target_disease_idx,
        **few_shot_settings,
    )

# Fine-tuning ResNet on small sample and compare with MAML
Given that MAML is able to produce good results on small samples, we want to compare it with the traditional approach of fine-tuning a model on small samples. The experiment proceeds in the following steps:

Separate the samples into two sets: (1) all other diseases except the target disease "Nodule", and (2) a small sample of the target disease (20 samples).

Train the ResNet model on dataset (1) as a pretrained model, and evaluate its performance.
Fine-tune the pretrained model on dataset (2) and evaluate again.

In [None]:
# Split the filtered dataset for the fine-tuning baseline:
#   * pretrain set: every sample whose Nodule label is 0
#   * nodule train/test sets: 20 + 20 samples whose Nodule label is 1
nodule_idx = disease_names.index("Nodule")

# Collect both index lists in a single pass so the dataset is iterated once
# instead of once per list. Samples whose Nodule label is neither 1.0 nor 0.0
# are left out of both lists, as before.
nodule_indices = []
non_nodule_indices = []
for sample_idx, (_, sample_label) in enumerate(filtered_dataset):
    if sample_label[nodule_idx] == 1.0:
        nodule_indices.append(sample_idx)
    elif sample_label[nodule_idx] == 0.0:
        non_nodule_indices.append(sample_idx)

# Get small nodule dataset (20 train + 20 test)
import random
random.seed(42)  # fixed seed so the split is reproducible
random.shuffle(nodule_indices)

# NOTE(review): if fewer than 40 Nodule-positive samples exist, these slices
# are silently shorter than 20 — confirm the dataset size.
nodule_train_indices = nodule_indices[:20]
nodule_test_indices = nodule_indices[20:40]

pretrain_indices = non_nodule_indices

nodule_train_dataset = Subset(filtered_dataset, nodule_train_indices)
nodule_test_dataset = Subset(filtered_dataset, nodule_test_indices)
pretrain_dataset = Subset(filtered_dataset, pretrain_indices)

In [None]:
class BCEWithLogitsLossExceptNodule(nn.Module):
    """BCE-with-logits loss computed over every label column but one.

    The column at ``nodule_index`` is cut out of both the logits and the
    targets before they are handed to ``nn.BCEWithLogitsLoss``.
    """

    def __init__(self, nodule_index):
        super().__init__()
        self.nodule_index = nodule_index
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, input, target):
        """Return the BCE loss of ``input`` vs ``target`` without the excluded column."""
        before = slice(None, self.nodule_index)
        after = slice(self.nodule_index + 1, None)

        def _drop_column(tensor):
            # Glue the columns left and right of the excluded one together.
            return torch.cat([tensor[:, before], tensor[:, after]], dim=1)

        return self.bce(_drop_column(input), _drop_column(target))

# Dataloader
pretrain_loader = DataLoader(pretrain_dataset, batch_size=32, shuffle=True)

# Training loop (no Nodule)
# Pre-train a fresh ResNet on the non-Nodule samples; the Nodule output
# column is excluded from the loss.
resnet_nodule = ResNet50_MultiLabel(15).to(device)
loss_fn = BCEWithLogitsLossExceptNodule(nodule_idx)

criterion = loss_fn
optimizer = torch.optim.Adam(resnet_nodule.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5, 10, 15], gamma=0.5)

start_time = time.time()
train_loss_resnet_plot = []
val_loss_resnet_plot = []


for epoch in range(0, 5):
    resnet_nodule.train()
    running_loss = 0.0

    for images, labels in pretrain_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = resnet_nodule(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Advance the LR schedule once per epoch; without this call the
    # MultiStepLR milestones are never reached and the scheduler is a no-op.
    # With 5 epochs the first milestone (5) is only hit after the final
    # epoch, so the learning rate used for training here is unchanged.
    scheduler.step()

    avg_loss = running_loss / len(pretrain_loader)
    train_loss_resnet_plot.append(avg_loss)
    elapsed = (time.time() - start_time) / 60
    print(f"\nEpoch {epoch}, Train Loss: {avg_loss:.4f}, Time: {elapsed:.2f} min")


torch.save(resnet_nodule.state_dict(), "pretrained_resnet_no_nodule.pth")

In [None]:
# Evaluate the pre-trained (not yet fine-tuned) checkpoint on a fresh model.
pre_resnet_nodule = ResNet50_MultiLabel(15).to(device)
# map_location lets the checkpoint load even when it was saved on a different
# device than the one in use now (e.g. saved on GPU, reloaded on CPU).
pre_resnet_nodule.load_state_dict(
    torch.load("pretrained_resnet_no_nodule.pth", map_location=device)
)
evaluate_resnet(pre_resnet_nodule, val_loader)

In [None]:
# Load pretrained weights
# map_location lets the checkpoint load even when it was saved on a different
# device than the one in use now (e.g. saved on GPU, reloaded on CPU).
resnet_nodule.load_state_dict(
    torch.load("pretrained_resnet_no_nodule.pth", map_location=device)
)

# Modify the loss to focus on Nodule only
class BCEWithLogitsLossForNodule(nn.Module):
    """BCE-with-logits loss computed on a single label column.

    Only the column at ``nodule_index`` of the logits and targets is passed
    to ``nn.BCEWithLogitsLoss``; all other columns are ignored.
    """

    def __init__(self, nodule_index):
        super().__init__()
        self.nodule_index = nodule_index
        self.bce = nn.BCEWithLogitsLoss()

    def _column(self, tensor):
        # Pick the selected column and restore the column axis at dim 1.
        return torch.unsqueeze(tensor.select(1, self.nodule_index), 1)

    def forward(self, input, target):
        """Return the BCE loss between the selected column of ``input`` and ``target``."""
        return self.bce(self._column(input), self._column(target))


nodule_loss_fn = BCEWithLogitsLossForNodule(nodule_idx)

# Fine-tune
nodule_train_loader = DataLoader(nodule_train_dataset, batch_size=4, shuffle=True, drop_last=True)

criterion = nodule_loss_fn
# freeze weights
for name, param in resnet_nodule.backbone.named_parameters():
    if not name.startswith('fc'):
        param.requires_grad = False

# Only train classifier layer
optimizer = torch.optim.Adam(resnet_nodule.backbone.fc.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5, 10, 15], gamma=0.5)

start_time = time.time()

# Training loop
for epoch in range(20):
    resnet_nodule.train()
    # requires_grad=False does not stop BatchNorm layers from updating their
    # running statistics in training mode. Keep the normalisation layers
    # outside the classifier head in eval mode so the frozen feature
    # extractor really stays fixed during fine-tuning.
    for module_name, module in resnet_nodule.backbone.named_modules():
        if module_name.startswith('fc'):
            continue
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            module.eval()

    running_loss = 0.0
    for images, labels in nodule_train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = resnet_nodule(images)
        loss = criterion(outputs, labels)

        # Skip the update if nothing in the forward pass is trainable.
        if outputs.requires_grad:
            loss.backward()
            optimizer.step()

        running_loss += loss.item()

    # Advance the LR schedule once per epoch; without this call the
    # milestones at epochs 5, 10 and 15 never take effect.
    scheduler.step()

    avg_loss = running_loss / len(nodule_train_loader)
    elapsed = (time.time() - start_time) / 60
    print(f"\n[Fine-tune] Epoch {epoch}, Train Loss: {avg_loss:.4f}, Time: {elapsed:.2f} min")


In [None]:
# Loader over the held-out Nodule samples; shuffle=False keeps the batch
# order fixed between evaluation runs.
nodule_test_loader = DataLoader(nodule_test_dataset, batch_size=4, shuffle=False)
def evaluate_nodule(net, loader, nodule_idx, threshold=0.5):
    """Print and return binary classification metrics for one label column.

    The network is put in eval mode and run over ``loader`` without
    gradients. Sigmoid probabilities of output column ``nodule_idx`` are
    thresholded and compared with the same column of the labels.

    Args:
        net: Model whose output is a batch of per-label logits.
        loader: Iterable yielding ``(images, labels)`` batches.
        nodule_idx: Column of the outputs/labels to evaluate.
        threshold: Probability above which a sample is predicted positive.

    Returns:
        dict with the keys ``"accuracy"``, ``"precision"``, ``"recall"`` and
        ``"f1"``.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    net.eval()
    all_probs = []
    all_labels = []
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = torch.sigmoid(net(images))
            all_probs.append(outputs[:, nodule_idx])
            all_labels.append(labels[:, nodule_idx])

    if not all_probs:
        # torch.cat on an empty list fails with a much less helpful message.
        raise ValueError("loader yielded no batches to evaluate")

    all_probs = torch.cat(all_probs).cpu()
    all_labels = torch.cat(all_labels).cpu()

    preds_binary = (all_probs > threshold).float()

    # Classification metrics
    acc = (preds_binary == all_labels).sum().item() / len(all_labels)
    precision = precision_score(all_labels, preds_binary, zero_division=0)
    recall = recall_score(all_labels, preds_binary, zero_division=0)
    f1 = f1_score(all_labels, preds_binary, zero_division=0)

    print("\n📊 Evaluation on 'Nodule':")
    print(f"Accuracy:  {acc:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall:    {recall:.4f}")
    print(f"F1 Score:  {f1:.4f}")

    return {"accuracy": acc, "precision": precision, "recall": recall, "f1": f1}

# Score resnet_nodule on the held-out Nodule samples.
evaluate_nodule(resnet_nodule, nodule_test_loader, nodule_idx)

Finally, we evaluate the fine-tuned model again. We can see that the AUROC score for Nodule improves from

Nodule: AUC = 0.4754

to

Nodule: AUC = 0.5209

This shows the effect of fine-tuning. However, it is still significantly lower than the MAML result of 0.6630.

This shows that a traditional CNN learns well when given a large number of samples, but performs poorly when the sample size is small, in comparison with the MAML approach.