In [None]:
from google.colab import drive

# Mount Google Drive so the /content/drive/MyDrive/... dataset and checkpoint paths used below resolve.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
import timm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from tqdm import tqdm
import os
import copy
import matplotlib.pyplot as plt
import numpy as np


In [None]:
import os
import shutil

def remove_unwanted_folders(root, unwanted_folders=("Br35H-Mask-RCNN",)):
    """Recursively delete every directory under ``root`` whose *name* is in ``unwanted_folders``.

    Args:
        root: Directory tree to scan. A non-existent path is a no-op.
        unwanted_folders: Directory names (not paths) to remove wherever they occur.
            A single string is treated as one name.
    """
    if isinstance(unwanted_folders, str):
        # A bare string would otherwise be iterated character by character.
        unwanted_folders = [unwanted_folders]
    unwanted = set(unwanted_folders)
    # Single walk for all names; only directories are matched, never files.
    for subdir, dirs, _files in os.walk(root):
        for name in [d for d in dirs if d in unwanted]:
            path_to_remove = os.path.join(subdir, name)
            shutil.rmtree(path_to_remove)
            print(f"Removed {path_to_remove}")
            # Prune so os.walk does not try to descend into the deleted directory.
            dirs.remove(name)

# Apply to BR35H
br_root = "/content/drive/MyDrive/Data/Br35H"
remove_unwanted_folders(br_root)


In [None]:
# Print every directory under the Br35H root together with its immediate sub-directories.
for current_dir, subdirs, _ in os.walk(br_root):
    print(current_dir, subdirs)


/content/drive/MyDrive/Data/Br35H ['yes', 'no', 'pred', 'train', 'test']
/content/drive/MyDrive/Data/Br35H/yes []
/content/drive/MyDrive/Data/Br35H/no []
/content/drive/MyDrive/Data/Br35H/pred []
/content/drive/MyDrive/Data/Br35H/train ['tumor', 'no_tumor']
/content/drive/MyDrive/Data/Br35H/train/tumor []
/content/drive/MyDrive/Data/Br35H/train/no_tumor []
/content/drive/MyDrive/Data/Br35H/test ['tumor', 'no_tumor']
/content/drive/MyDrive/Data/Br35H/test/tumor []
/content/drive/MyDrive/Data/Br35H/test/no_tumor []


In [None]:



import os

br_root = "/content/drive/MyDrive/Data/Br35H"

# Ensure the ImageFolder layout <root>/<split>/<class>/ exists before images are moved in.
split_class_dirs = [
    os.path.join(br_root, split_name, class_name)
    for split_name in ("train", "test")
    for class_name in ("tumor", "no_tumor")
]
for target_dir in split_class_dirs:
    os.makedirs(target_dir, exist_ok=True)

print("Folder structure created!")


Folder structure created!


In [None]:
import glob
import os
import shutil
import random

# Fixed seed so the Br35H train/test split is reproducible across notebook runs.
SPLIT_SEED = 42
TRAIN_FRACTION = 0.8


def split_class_folder(src_name, dst_class, train_fraction=TRAIN_FRACTION, seed=SPLIT_SEED):
    """Move images from <br_root>/<src_name> into <br_root>/{train,test}/<dst_class>.

    The file list is sorted before a seeded shuffle, so the split depends only on the
    file names and ``seed``, not on filesystem listing order. Once the source folder has
    been emptied, re-running is a no-op.

    Returns:
        (n_train, n_test) number of files moved into each split.
    """
    images = sorted(glob.glob(os.path.join(br_root, src_name, "*.*")))
    random.Random(seed).shuffle(images)
    n_train = int(train_fraction * len(images))
    for subset, split in ((images[:n_train], "train"), (images[n_train:], "test")):
        dst_dir = os.path.join(br_root, split, dst_class)
        for img in subset:
            shutil.move(img, dst_dir)
    return n_train, len(images) - n_train


split_class_folder("yes", "tumor")    # yes -> tumor
split_class_folder("no", "no_tumor")  # no  -> no_tumor


In [None]:
# Spot-check: show the first few files that landed in each split/class folder.
for split_name in ("train", "test"):
    for class_name in ("tumor", "no_tumor"):
        class_dir = os.path.join(br_root, split_name, class_name)
        first_files = os.listdir(class_dir)[:5]
        print(class_dir, ":", first_files)


/content/drive/MyDrive/Data/Br35H/train/tumor : ['y1214.jpg', 'y1215.jpg', 'y1197.jpg', 'y1190.jpg', 'y1211.jpg']
/content/drive/MyDrive/Data/Br35H/train/no_tumor : ['no1227.jpg', 'no1243.jpg', 'no1309.jpg', 'no1255.jpg', 'no1284.jpg']
/content/drive/MyDrive/Data/Br35H/test/tumor : ['y1035.jpg', 'y1106.jpg', 'y1052.jpg', 'y1022.jpg', 'y1040.jpg']
/content/drive/MyDrive/Data/Br35H/test/no_tumor : ['no1034.jpg', 'No17.jpg', 'no1013.jpg', 'no1002.jpg', 'no1001.jpg']


In [None]:
from torchvision import transforms

img_size = 224

# ImageNet channel statistics, shared by the train and test pipelines.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Train on BR35H with strong augmentations
train_tfms = transforms.Compose(
    [
        transforms.Resize((img_size, img_size)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomAffine(0, translate=(0.1, 0.1)),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ]
)

# Deterministic preprocessing for evaluation.
test_tfms = transforms.Compose(
    [
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ]
)


from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader


br_root = "/content/drive/MyDrive/Data/Br35H"
btd_root = "/content/drive/MyDrive/Data/Brain Tumor Data"

# Source domain (BR35H train) and target domain (BTD test).
br_train_ds = ImageFolder(f"{br_root}/train", transform=train_tfms)
btd_test_ds = ImageFolder(f"{btd_root}/test", transform=test_tfms)

loader_kwargs = dict(batch_size=32, num_workers=2)
train_loader = DataLoader(br_train_ds, shuffle=True, **loader_kwargs)
test_loader = DataLoader(btd_test_ds, shuffle=False, **loader_kwargs)

num_classes = len(br_train_ds.classes)
print("Classes:", br_train_ds.classes)
print("Training images:", len(br_train_ds))
print("BTD test images:", len(btd_test_ds))


Classes: ['no_tumor', 'tumor']
Training images: 2400
BTD test images: 1311


In [None]:
import torch
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix

def evaluate_model(model, test_loader, device=None, average="macro"):
    """Run ``model`` over ``test_loader`` and compute classification metrics.

    Args:
        model: ``torch.nn.Module`` returning class logits of shape (batch, n_classes).
        test_loader: iterable of ``(images, labels)`` batches.
        device: device to run on. Defaults to the device of the model's first parameter
            (CPU for a parameter-less model), so the function no longer depends on a
            notebook-global ``device``.
        average: sklearn averaging mode applied to precision, recall AND F1. The old code
            mixed "macro" precision/recall with "weighted" F1, so the three were not comparable.

    Returns:
        dict with accuracy, precision, recall, f1, confusion_matrix, y_true, y_pred.

    Raises:
        ValueError: ``test_loader`` yielded no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    y_true, y_pred = [], []

    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images.to(device))
            preds = torch.argmax(outputs, dim=1)
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())

    if not y_true:
        raise ValueError("test_loader yielded no samples")

    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred, average=average, zero_division=0),
        "recall": recall_score(y_true, y_pred, average=average, zero_division=0),
        "f1": f1_score(y_true, y_pred, average=average, zero_division=0),
        "confusion_matrix": confusion_matrix(y_true, y_pred),
        "y_true": y_true,
        "y_pred": y_pred,
    }


In [None]:
from torch.utils.data import Subset

FEW_SHOT_FRACTION = 0.5
FEW_SHOT_SEED = 42

btd_train_ds = ImageFolder(f"{btd_root}/train", transform=train_tfms)
# ImageFolder lists samples grouped by class, so `range(N)` would take the first N
# samples in class order and under-represent (or drop) the later classes. A seeded random
# permutation keeps the subset representative of every class and reproducible.
_few_shot_gen = torch.Generator().manual_seed(FEW_SHOT_SEED)
_n_few_shot = int(FEW_SHOT_FRACTION * len(btd_train_ds))
few_shot_indices = sorted(
    torch.randperm(len(btd_train_ds), generator=_few_shot_gen)[:_n_few_shot].tolist()
)
btd_few_shot_ds = Subset(btd_train_ds, few_shot_indices)

few_shot_loader = DataLoader(btd_few_shot_ds, batch_size=32, shuffle=True, num_workers=2)
print("Few-shot BTD samples for fine-tuning:", len(btd_few_shot_ds))



Few-shot BTD samples for fine-tuning: 2856


In [None]:
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

def fine_tune_model(model, few_shot_loader, epochs=10, lr=1e-5):
    """Fine-tune all parameters of ``model`` on ``few_shot_loader`` (Adam + cross-entropy).

    The model is put in train mode and moved to the notebook-level ``device``. After each
    epoch the mean batch loss is printed. Returns the trained model.
    """
    model.train()
    model = model.to(device)
    loss_fn = nn.CrossEntropyLoss()
    opt = optim.Adam(model.parameters(), lr=lr)
    n_batches = len(few_shot_loader)

    for ep in range(1, epochs + 1):
        running = 0
        progress = tqdm(few_shot_loader, desc=f"Fine-tune Epoch {ep}/{epochs}")
        for batch_x, batch_y in progress:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            opt.zero_grad()
            batch_loss = loss_fn(model(batch_x), batch_y)
            batch_loss.backward()
            opt.step()
            running += batch_loss.item()
        print(f"Epoch {ep}, Loss: {running/n_batches:.4f}")

    return model


In [None]:

class MixStyle(nn.Module):
    """MixStyle: mix per-channel instance statistics across samples in a batch.

    During training, with probability ``p`` each sample of an NCHW tensor is normalised by
    its own spatial mean/std and re-styled with a convex combination of its statistics and
    those of another sample. The mixing weight is ``lam ~ Beta(alpha, alpha)``. In eval mode,
    or when not triggered, the input is returned unchanged.

    Changes vs. the previous version:
      * ``alpha`` is actually used; it was ignored and the styles were fully swapped.
      * one permutation is shared by mean and std; two independent permutations paired the
        mean of one sample with the std of another.
      * statistics are detached and spatial reductions avoid ``view``, so non-contiguous
        inputs work.

    Args:
        p: probability of applying the mixing on a training forward pass.
        alpha: Beta concentration for the mixing weight.
        eps: added to the variance for numerical stability.
    """

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6):
        super().__init__()
        self.p = p
        self.alpha = alpha
        self.eps = eps

    def forward(self, x):
        if not self.training or torch.rand(1).item() > self.p:
            return x
        batch = x.size(0)

        # Instance statistics over the spatial dims, detached so no gradient flows
        # through the style statistics.
        mu = x.mean(dim=[2, 3], keepdim=True).detach()
        sig = (x.var(dim=[2, 3], keepdim=True) + self.eps).sqrt().detach()
        x_normed = (x - mu) / sig

        lam = torch.distributions.Beta(self.alpha, self.alpha).sample((batch, 1, 1, 1))
        lam = lam.to(device=x.device, dtype=x.dtype)
        # A single permutation so each sample borrows a coherent (mean, std) pair.
        perm = torch.randperm(batch, device=x.device)
        mu_mix = mu * lam + mu[perm] * (1 - lam)
        sig_mix = sig * lam + sig[perm] * (1 - lam)
        return x_normed * sig_mix + mu_mix


In [None]:


def _insert_mixstyle(model, name, mixstyle):
    """Wrap one early module of ``model`` as ``Sequential(module, mixstyle)``; return True on success."""
    if "resnet" in name and hasattr(model, "layer1"):
        model.layer1[0].bn1 = nn.Sequential(model.layer1[0].bn1, mixstyle)
        return True
    if "convnext" in name and hasattr(model, "stages") and hasattr(model.stages[0], "blocks"):
        # NOTE(review): depending on the timm version, the ConvNeXt block norm may operate on
        # channels-last (NHWC) activations, in which case MixStyle's NCHW statistics are taken
        # over the wrong axes — confirm for the installed timm.
        first_block = model.stages[0].blocks[0]
        first_block.norm = nn.Sequential(first_block.norm, mixstyle)
        return True
    if "mobilenetv2" in name:
        if hasattr(model, "features"):  # torchvision-style layout
            model.features[0][0] = nn.Sequential(model.features[0][0], mixstyle)
            return True
        if hasattr(model, "bn1"):  # timm EfficientNet-style layout: conv_stem -> bn1
            # NOTE: changes state_dict keys (bn1.* -> bn1.0.*); checkpoints saved by the old
            # code, which silently inserted nothing here, will not load strictly.
            model.bn1 = nn.Sequential(model.bn1, mixstyle)
            return True
    if "densenet" in name and hasattr(model, "features"):
        model.features.norm0 = nn.Sequential(model.features.norm0, mixstyle)
        return True
    if "swin" in name and hasattr(model, "patch_embed"):
        model.patch_embed.proj = nn.Sequential(model.patch_embed.proj, mixstyle)
        return True
    return False


def get_model(name, num_classes=4, use_mixstyle=False, mixstyle_p=0.5, mixstyle_alpha=0.1):
    """Create an ImageNet-pretrained timm classifier, optionally with a MixStyle layer.

    Args:
        name: one of the supported timm names below, or "medvit". "medvit" is built as timm
            ``vit_small_patch16_224``, not a dedicated MedViT architecture.
        num_classes: size of the new classification head.
        use_mixstyle: insert a ``MixStyle`` module after an early normalisation/stem layer.
            A warning is emitted when no insertion point is found; previously this failed
            silently, e.g. for timm ``mobilenetv2_100``, which has no ``.features``.
        mixstyle_p, mixstyle_alpha: forwarded to ``MixStyle``.

    Returns:
        The ``torch.nn.Module``.

    Raises:
        ValueError: ``name`` is not supported.
    """
    import warnings

    # NOTE(review): "efficientnet_v2_s" is torchvision's spelling; timm may reject it.
    timm_names = {"resnet50", "mobilenetv2_100", "convnext_tiny",
                  "swin_tiny_patch4_window7_224", "densenet121", "efficientnet_v2_s"}

    if name == "medvit":
        model = timm.create_model("vit_small_patch16_224", pretrained=True, num_classes=num_classes)
        if use_mixstyle:
            warnings.warn("use_mixstyle=True is not supported for 'medvit'; ignored")
        return model
    if name not in timm_names:
        raise ValueError(f"Unknown model: {name}")

    model = timm.create_model(name, pretrained=True, num_classes=num_classes)
    if use_mixstyle:
        mixstyle = MixStyle(p=mixstyle_p, alpha=mixstyle_alpha)
        if not _insert_mixstyle(model, name, mixstyle):
            warnings.warn(f"MixStyle requested but no insertion point found for '{name}'; "
                          "returning the model without MixStyle")
    return model


In [None]:
def _classifier_head(model):
    """Return the classification-head module of ``model`` (timm ``get_classifier`` first), or None."""
    if hasattr(model, "get_classifier"):
        head = model.get_classifier()
        if isinstance(head, nn.Module):
            return head
    for attr in ("fc", "classifier", "head"):
        head = getattr(model, attr, None)
        if isinstance(head, nn.Module):
            return head
    return None


def _collect_predictions(model, loader):
    """Return (y_true, y_pred) lists for ``model`` over ``loader`` in eval mode without grad."""
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for images, labels in loader:
            outputs = model(images.to(device))
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(torch.argmax(outputs, dim=1).cpu().numpy())
    return y_true, y_pred


def train_model(model, train_loader, test_loader, model_name, use_mixstyle=False,
                epochs=45, patience=7, fine_tune=False,
                lr=1e-5, checkpoint_dir="/content/drive/MyDrive/checkpoints",
                scheduler_patience=None):
    """Train ``model`` with early stopping, restore the best weights, and evaluate on ``test_loader``.

    Checkpoint selection and early stopping use the mean *training* loss; no validation set
    is involved. The best weights are saved to ``{checkpoint_dir}/{model_name}_best.pth``.

    Args:
        use_mixstyle: unused here (MixStyle is inserted by ``get_model``); kept for API compatibility.
        patience: epochs without training-loss improvement before stopping early.
        fine_tune: freeze everything except the classifier head. The head is found via timm's
            ``get_classifier()``, then ``fc``/``classifier``/``head``. The old fc/classifier-only
            lookup left ConvNeXt/Swin/ViT with no trainable parameters.
        scheduler_patience: ReduceLROnPlateau patience. It defaults to ``max(1, patience // 2)``.
            With the old equal value the scheduler needed more bad epochs than early stopping
            allowed, so the LR was essentially never reduced.

    Returns:
        dict with accuracy, precision/recall/f1 (all macro-averaged), confusion_matrix,
        y_true and y_pred.

    Raises:
        ValueError: empty ``train_loader`` or no classifier head found in fine-tune mode.
    """
    if len(train_loader) == 0:
        raise ValueError("train_loader is empty")
    os.makedirs(checkpoint_dir, exist_ok=True)
    model = model.to(device)

    if fine_tune:
        for param in model.parameters():
            param.requires_grad = False
        head = _classifier_head(model)
        if head is None:
            raise ValueError(f"{model_name}: could not locate a classifier head to fine-tune")
        for param in head.parameters():
            param.requires_grad = True
        optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=lr)
    else:
        for param in model.parameters():
            param.requires_grad = True
        optimizer = optim.Adam(model.parameters(), lr=lr)

    if scheduler_patience is None:
        scheduler_patience = max(1, patience // 2)
    criterion = nn.CrossEntropyLoss()
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5,
                                                     patience=scheduler_patience)

    best_loss = float('inf')
    best_model_wts = copy.deepcopy(model.state_dict())
    trigger_times = 0
    save_path = f"{checkpoint_dir}/{model_name}_best.pth"

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        for images, labels in tqdm(train_loader, desc=f"{model_name} Epoch {epoch+1}/{epochs}"):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")
        scheduler.step(avg_loss)

        if avg_loss < best_loss:
            best_loss = avg_loss
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(model.state_dict(), save_path)
            print(f"✅ Saved best checkpoint: {save_path}")
            trigger_times = 0
        else:
            trigger_times += 1
            if trigger_times >= patience:
                print(f"⏹ Early stopping at epoch {epoch+1}")
                break

    model.load_state_dict(best_model_wts)

    # Raw predictions are returned too; downstream cells read metrics["y_true"] / ["y_pred"].
    y_true, y_pred = _collect_predictions(model, test_loader)
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred, average="macro", zero_division=0),
        "recall": recall_score(y_true, y_pred, average="macro", zero_division=0),
        "f1": f1_score(y_true, y_pred, average="macro", zero_division=0),
        "confusion_matrix": confusion_matrix(y_true, y_pred),
        "y_true": y_true,
        "y_pred": y_pred,
    }


In [None]:

def _cam_from_activation(act, grad):
    """Combine one sample's activation and gradient into a non-negative CAM scaled to [0, 1].

    ``act``/``grad`` are either CNN feature maps ``(C, H, W)`` or transformer tokens ``(N, D)``.
    For tokens, leading prefix tokens are dropped so that the remaining ``side * side`` tokens
    can be laid out as a square grid.
    """
    import math

    if act.ndim == 2:
        n_tokens, dim = act.shape
        side = math.isqrt(n_tokens)
        n_prefix = n_tokens - side * side
        act = act[n_prefix:].T.reshape(dim, side, side)
        grad = grad[n_prefix:].T.reshape(dim, side, side)
    # Channel weights = spatially averaged gradients; CAM = ReLU(weighted channel sum).
    weights = grad.mean(axis=(1, 2))
    cam = np.maximum(np.tensordot(weights, act, axes=1), 0)
    peak = cam.max()
    # Guard the all-zero map (the old code divided by zero and produced NaNs).
    return cam / peak if peak > 0 else cam


def grad_cam(model, image, target_class=None, layer_name="blocks.11.norm1", show=True):
    """Compute (and by default display) a Grad-CAM heat-map for a single image.

    Supports CNN feature maps and ViT-style token activations; the default layer name
    targets a timm ViT block. The old code crashed on token activations. The forward hook
    is removed after each call, so repeated calls no longer stack hooks on the layer.

    Args:
        model: classifier returning logits of shape ``(1, n_classes)``.
        image: one preprocessed image tensor ``(C, H, W)``.
        target_class: class to explain; defaults to the predicted class.
        layer_name: dotted module name as listed by ``model.named_modules()``.
        show: plot the heat-map with matplotlib (original behaviour).

    Returns:
        2-D numpy array in [0, 1] at the hooked layer's spatial resolution.

    Raises:
        KeyError: ``layer_name`` is not a module of ``model``.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    dev = first_param.device if first_param is not None else torch.device("cpu")
    # Requiring grad on the input guarantees an autograd graph even for a frozen model.
    x = image.detach().unsqueeze(0).to(dev).clone().requires_grad_(True)
    layer = dict(model.named_modules())[layer_name]
    store = {}

    def _capture(module, inputs, output):
        store["act"] = output
        # Tensor hook instead of the deprecated Module.register_backward_hook.
        output.register_hook(lambda g: store.__setitem__("grad", g))

    handle = layer.register_forward_hook(_capture)
    try:
        outputs = model(x)
        if target_class is None:
            target_class = outputs.argmax(dim=1).item()
        model.zero_grad()
        outputs[0, target_class].backward()
    finally:
        handle.remove()

    act = store["act"][0].detach().cpu().numpy()
    grad = store["grad"][0].detach().cpu().numpy()
    cam = _cam_from_activation(act, grad)

    if show:
        plt.imshow(cam, cmap='jet', alpha=0.5)
        plt.axis('off')
        plt.show()
    return cam


In [None]:


import pandas as pd
from torch.utils.data import Subset, DataLoader, Dataset

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Backbones to compare; MixStyle variants are only built for the CNN subset.
models = [
    "resnet50",
    "mobilenetv2_100",
    "convnext_tiny",
    "swin_tiny_patch4_window7_224",
    "densenet121",
    "medvit",
]
cnn_mixstyle_models = [
    "resnet50",
    "mobilenetv2_100",
    "convnext_tiny",
    "densenet121",
]

# Output locations on Drive.
results_dir = "/content/drive/MyDrive/Domain_Generalization"
checkpoint_dir = "/content/drive/MyDrive/checkpoints1"
os.makedirs(results_dir, exist_ok=True)
results = {}

# Collapse the BTD classes to the binary BR35H convention (0 = no tumor, 1 = tumor).
# Derived from the BTD folder names (ImageFolder class order) rather than hard-coded
# indices; assumes BTD train and test share the same class folders. Falls back to the
# original mapping (index 2 = no tumor) if no unique "no..." folder exists.
_no_tumor_idx = [i for i, c in enumerate(btd_test_ds.classes) if c.lower().startswith("no")]
if len(_no_tumor_idx) == 1:
    btd_label_map = {i: 0 if i == _no_tumor_idx[0] else 1 for i in range(len(btd_test_ds.classes))}
else:
    print(f"WARNING: no unique no-tumor folder in {btd_test_ds.classes}; using hard-coded map")
    btd_label_map = {0:1, 1:1, 2:0, 3:1}

class BinaryBTDDataset(Dataset):
    """Wrap a dataset, keeping its images but substituting precomputed labels.

    ``mapped_labels[i]`` replaces the label of ``subset[i]``; the wrapped label is discarded.
    """

    def __init__(self, subset, mapped_labels):
        self.subset, self.labels = subset, mapped_labels

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        image, _original_label = self.subset[idx]
        binary_label = self.labels[idx]
        return image, binary_label

btd_train_ds = ImageFolder(f"{btd_root}/train", transform=train_tfms)

# Few-shot budget: a seeded random 50% of the BTD train set (the old comment said 10%).
# ImageFolder lists samples grouped by class, so the old `range(N)` slice took the first
# N samples in class order and could leave later classes under- or un-represented.
FEW_SHOT_FRACTION = 0.5
FEW_SHOT_SEED = 42
_few_shot_gen = torch.Generator().manual_seed(FEW_SHOT_SEED)
_n_few_shot = int(FEW_SHOT_FRACTION * len(btd_train_ds))
few_shot_indices = sorted(
    torch.randperm(len(btd_train_ds), generator=_few_shot_gen)[:_n_few_shot].tolist()
)
few_shot_labels = [btd_label_map[btd_train_ds.imgs[i][1]] for i in few_shot_indices]
btd_few_shot_ds = Subset(btd_train_ds, few_shot_indices)

few_shot_loader_binary = DataLoader(
    BinaryBTDDataset(btd_few_shot_ds, few_shot_labels),
    batch_size=32, shuffle=True, num_workers=2
)


btd_test_labels = [btd_label_map[label] for _, label in btd_test_ds.imgs]
test_loader_binary = DataLoader(
    BinaryBTDDataset(btd_test_ds, btd_test_labels),
    batch_size=32, shuffle=False, num_workers=2
)


# One epoch budget for every variant. Previously baseline got 45 epochs and MixStyle 10,
# which made the baseline-vs-MixStyle comparison unfair.
FINETUNE_EPOCHS = 45
FINETUNE_LR = 1e-5


def _finetune_variant(arch, variant, use_mixstyle):
    """Build ``arch``, load its ``{arch}_{variant}_best.pth`` checkpoint if present, fine-tune
    on the few-shot binary BTD loader, evaluate on the binary BTD test loader, and record
    the metrics in ``results``. Returns the fine-tuned model."""
    model = get_model(arch, num_classes=2, use_mixstyle=use_mixstyle).to(device)

    ckpt = f"{checkpoint_dir}/{arch}_{variant}_best.pth"
    if os.path.exists(ckpt):
        model.load_state_dict(torch.load(ckpt, map_location=device))
    else:
        # Previously skipped silently, which turned a cross-dataset run into plain ImageNet fine-tuning.
        print(f"WARNING: {ckpt} not found; fine-tuning from ImageNet weights")

    print(f"\n🔹 Fine-tuning {arch}_{variant} on few-shot BTD")
    model = fine_tune_model(model, few_shot_loader_binary, epochs=FINETUNE_EPOCHS, lr=FINETUNE_LR)

    metrics = evaluate_model(model, test_loader_binary)
    key = f"{arch}_{variant}_finetune"
    results[key] = metrics
    print(f"{key}: Acc={metrics['accuracy']:.6f}, F1={metrics['f1']:.6f}")
    return model


for m in models:
    model = _finetune_variant(m, "baseline", use_mixstyle=False)
    if m in cnn_mixstyle_models:
        model = _finetune_variant(m, "mixstyle", use_mixstyle=True)

# Keep the CSV readable and round-trippable: drop raw per-sample predictions and store
# confusion matrices as nested Python lists (parsable with ast.literal_eval).
csv_records = {
    key: {k: (np.asarray(v).tolist() if k == "confusion_matrix" else v)
          for k, v in vals.items() if k not in ("y_true", "y_pred")}
    for key, vals in results.items()
}
df_results = pd.DataFrame(csv_records).T
csv_path = f"{results_dir}/all_model_evaluation_finetune.csv"
df_results.to_csv(csv_path)
print(f"\n✅ Evaluation results saved to {csv_path}")

df_results



üîπ Fine-tuning resnet50_baseline on few-shot BTD


Fine-tune Epoch 1/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:45<00:00,  1.97it/s]


Epoch 1, Loss: 0.6513


Fine-tune Epoch 2/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:36<00:00,  2.44it/s]


Epoch 2, Loss: 0.5169


Fine-tune Epoch 3/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:35<00:00,  2.57it/s]


Epoch 3, Loss: 0.4087


Fine-tune Epoch 4/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:36<00:00,  2.47it/s]


Epoch 4, Loss: 0.3240


Fine-tune Epoch 5/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:35<00:00,  2.53it/s]


Epoch 5, Loss: 0.2689


Fine-tune Epoch 6/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:37<00:00,  2.40it/s]


Epoch 6, Loss: 0.2343


Fine-tune Epoch 7/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:35<00:00,  2.53it/s]


Epoch 7, Loss: 0.2022


Fine-tune Epoch 8/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:36<00:00,  2.46it/s]


Epoch 8, Loss: 0.1863


Fine-tune Epoch 9/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:35<00:00,  2.51it/s]


Epoch 9, Loss: 0.1667


Fine-tune Epoch 10/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:35<00:00,  2.52it/s]


Epoch 10, Loss: 0.1513


Fine-tune Epoch 11/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:35<00:00,  2.50it/s]


Epoch 11, Loss: 0.1396


Fine-tune Epoch 12/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:35<00:00,  2.56it/s]


Epoch 12, Loss: 0.1324


Fine-tune Epoch 13/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:36<00:00,  2.44it/s]


Epoch 13, Loss: 0.1281


Fine-tune Epoch 14/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:42<00:00,  2.14it/s]


Epoch 14, Loss: 0.1183


Fine-tune Epoch 15/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:35<00:00,  2.51it/s]


Epoch 15, Loss: 0.1123


Fine-tune Epoch 16/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:36<00:00,  2.46it/s]


Epoch 16, Loss: 0.1100


Fine-tune Epoch 17/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:36<00:00,  2.49it/s]


Epoch 17, Loss: 0.1070


Fine-tune Epoch 18/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:36<00:00,  2.50it/s]


Epoch 18, Loss: 0.0997


Fine-tune Epoch 19/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:36<00:00,  2.50it/s]


Epoch 19, Loss: 0.0961


Fine-tune Epoch 20/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:36<00:00,  2.46it/s]


Epoch 20, Loss: 0.0997


Fine-tune Epoch 21/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:36<00:00,  2.48it/s]


Epoch 21, Loss: 0.0902


Fine-tune Epoch 22/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:34<00:00,  2.57it/s]


Epoch 22, Loss: 0.0879


Fine-tune Epoch 23/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:36<00:00,  2.46it/s]


Epoch 23, Loss: 0.0838


Fine-tune Epoch 24/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:35<00:00,  2.57it/s]


Epoch 24, Loss: 0.0815


Fine-tune Epoch 25/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:36<00:00,  2.48it/s]


Epoch 25, Loss: 0.0781


Fine-tune Epoch 26/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:34<00:00,  2.60it/s]


Epoch 26, Loss: 0.0700


Fine-tune Epoch 27/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:36<00:00,  2.47it/s]


Epoch 27, Loss: 0.0700


Fine-tune Epoch 28/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:35<00:00,  2.57it/s]


Epoch 28, Loss: 0.0668


Fine-tune Epoch 29/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:36<00:00,  2.50it/s]


Epoch 29, Loss: 0.0627


Fine-tune Epoch 30/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:35<00:00,  2.56it/s]


Epoch 30, Loss: 0.0612


Fine-tune Epoch 31/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:35<00:00,  2.52it/s]


Epoch 31, Loss: 0.0555


Fine-tune Epoch 32/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:35<00:00,  2.54it/s]


Epoch 32, Loss: 0.0554


Fine-tune Epoch 33/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:35<00:00,  2.53it/s]


Epoch 33, Loss: 0.0518


Fine-tune Epoch 34/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:35<00:00,  2.52it/s]


Epoch 34, Loss: 0.0546


Fine-tune Epoch 35/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:35<00:00,  2.54it/s]


Epoch 35, Loss: 0.0466


Fine-tune Epoch 36/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:36<00:00,  2.46it/s]


Epoch 36, Loss: 0.0474


Fine-tune Epoch 37/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:35<00:00,  2.55it/s]


Epoch 37, Loss: 0.0435


Fine-tune Epoch 38/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:36<00:00,  2.47it/s]


Epoch 38, Loss: 0.0402


Fine-tune Epoch 39/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:35<00:00,  2.55it/s]


Epoch 39, Loss: 0.0394


Fine-tune Epoch 40/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:36<00:00,  2.46it/s]


Epoch 40, Loss: 0.0373


Fine-tune Epoch 41/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:35<00:00,  2.52it/s]


Epoch 41, Loss: 0.0372


Fine-tune Epoch 42/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:36<00:00,  2.48it/s]


Epoch 42, Loss: 0.0318


Fine-tune Epoch 43/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:34<00:00,  2.59it/s]


Epoch 43, Loss: 0.0347


Fine-tune Epoch 44/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:36<00:00,  2.49it/s]


Epoch 44, Loss: 0.0310


Fine-tune Epoch 45/45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:35<00:00,  2.53it/s]


Epoch 45, Loss: 0.0291
resnet50_baseline_finetune: Acc=0.929062, F1=0.929832

üîπ Fine-tuning resnet50_mixstyle on few-shot BTD


Fine-tune Epoch 1/10:  60%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà    | 54/90 [00:22<00:15,  2.36it/s]


KeyboardInterrupt: 

In [None]:
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt


def summarize_metrics(y_true, y_pred, model_name, class_names=None):
    """Print per-class metrics for one model and plot its confusion matrix.

    Args:
        y_true, y_pred: sequences of integer class labels.
        model_name: used in the printed header and the plot title.
        class_names: optional axis tick labels for the heat-map (in sorted label order).

    Returns:
        dict with model_name, accuracy, per-class precision/recall/F1 arrays and the
        confusion matrix.
    """
    acc = accuracy_score(y_true, y_pred)
    # zero_division=0: a never-predicted class scores 0 without an UndefinedMetricWarning.
    precision = precision_score(y_true, y_pred, average=None, zero_division=0)
    recall = recall_score(y_true, y_pred, average=None, zero_division=0)
    f1 = f1_score(y_true, y_pred, average=None, zero_division=0)
    cm = confusion_matrix(y_true, y_pred)

    print(f"\n--- {model_name} ---")
    print("Accuracy:", acc)
    print("Per-class Precision:", precision)
    print("Per-class Recall:", recall)
    print("Per-class F1:", f1)
    print("Confusion Matrix:\n", cm)

    tick_labels = class_names if class_names is not None else "auto"
    fig, ax = plt.subplots(figsize=(5, 4))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=tick_labels, yticklabels=tick_labels, ax=ax)
    ax.set_title(f"{model_name} Confusion Matrix")
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    plt.show()
    plt.close(fig)

    return {
        "model_name": model_name,
        "accuracy": acc,
        "per_class_precision": precision,
        "per_class_recall": recall,
        "per_class_f1": f1,
        "confusion_matrix": cm
    }

# Train a fresh 2-class ResNet-50 on BR35H and summarise it on the *binary* BTD test set.
# The old call reused whatever `model` the fine-tuning loop left behind (a MixStyle
# ResNet) and scored its 2-class predictions against the 4-class `test_loader` labels.
summary_model = get_model("resnet50", num_classes=2, use_mixstyle=False)
metrics = train_model(summary_model, train_loader, test_loader_binary, "resnet50_baseline")

if "y_true" not in metrics or "y_pred" not in metrics:
    # train_model may return only aggregate scores; collect raw predictions separately.
    _eval = evaluate_model(summary_model, test_loader_binary)
    metrics["y_true"], metrics["y_pred"] = _eval["y_true"], _eval["y_pred"]

y_true = metrics["y_true"]
y_pred = metrics["y_pred"]

# Summarize metrics
metrics_summary = summarize_metrics(y_true, y_pred, "resnet50_baseline")


In [None]:
import ast
import math
import os
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

df = pd.read_csv("/content/drive/MyDrive/checkpoints/all_model_evaluation_with_confusion.csv", index_col=0)

labels = ["No Tumor", "Tumor"]

save_dir = "/content/drive/MyDrive/checkpoints/confusion_matrices"
os.makedirs(save_dir, exist_ok=True)

# Pull the numbers out of the stored matrix text. This handles both Python list literals
# ("[[1, 2], [3, 4]]") and NumPy's repr ("[[1 2]\n [3 4]]"); ast.literal_eval rejects the latter.
_NUMBER_RE = re.compile(r"-?\d+(?:\.\d+)?")

# Loop variable is `model_name`: the old `model` clobbered the notebook's torch model.
for model_name in df.index:
    cm_str = str(df.loc[model_name, "confusion_matrix"])
    values = [int(float(v)) for v in _NUMBER_RE.findall(cm_str)]
    side = math.isqrt(len(values))
    if side == 0 or side * side != len(values):
        raise ValueError(f"{model_name}: confusion_matrix is not square: {cm_str!r}")
    cm = np.array(values).reshape(side, side)
    tick_labels = labels if len(labels) == side else "auto"

    fig, ax = plt.subplots(figsize=(4, 3))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=tick_labels, yticklabels=tick_labels, ax=ax)
    ax.set_title(f"Confusion Matrix - {model_name}")
    ax.set_xlabel("Predicted")
    ax.set_ylabel("Actual")

    save_path = os.path.join(save_dir, f"{model_name}_confusion_matrix.png")
    fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)

print(f"✅ All confusion matrices saved in: {save_dir}")


In [None]:
import ast
import math
import os
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.backends.backend_pdf import PdfPages


df = pd.read_csv("/content/drive/MyDrive/checkpoints/all_model_evaluation_with_confusion.csv", index_col=0)


labels = ["No Tumor", "Tumor"]

save_dir = "/content/drive/MyDrive/checkpoints/confusion_matrices"
os.makedirs(save_dir, exist_ok=True)

pdf_path = os.path.join(save_dir, "all_confusion_matrices.pdf")

# Accepts both list-literal and NumPy-repr matrix text (see ast.literal_eval limitation).
_NUMBER_RE = re.compile(r"-?\d+(?:\.\d+)?")

# Context manager closes the PDF even if a row fails to parse or plot.
with PdfPages(pdf_path) as pdf:
    # `model_name`, not `model`, so the notebook's torch model is not overwritten.
    for model_name in df.index:
        cm_str = str(df.loc[model_name, "confusion_matrix"])
        values = [int(float(v)) for v in _NUMBER_RE.findall(cm_str)]
        side = math.isqrt(len(values))
        if side == 0 or side * side != len(values):
            raise ValueError(f"{model_name}: confusion_matrix is not square: {cm_str!r}")
        cm = np.array(values).reshape(side, side)
        tick_labels = labels if len(labels) == side else "auto"

        fig, ax = plt.subplots(figsize=(4, 3))
        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                    xticklabels=tick_labels, yticklabels=tick_labels, ax=ax)
        ax.set_title(f"Confusion Matrix - {model_name}")
        ax.set_xlabel("Predicted")
        ax.set_ylabel("Actual")

        save_path = os.path.join(save_dir, f"{model_name}_confusion_matrix.png")
        fig.savefig(save_path, bbox_inches="tight")
        pdf.savefig(fig)

        plt.show()
        plt.close(fig)

print(f"✅ Confusion matrices saved as PNGs in: {save_dir}")
print(f"✅ Combined PDF saved at: {pdf_path}")


In [None]:

def _gradcam_for_layer(model, layer, image, target_class=None):
    """Return (CAM in [0, 1] at ``layer``'s grid resolution, explained class) for a (1, C, H, W) input."""
    import math

    store = {}

    def _capture(module, inputs, output):
        store["act"] = output
        output.register_hook(lambda g: store.__setitem__("grad", g))

    handle = layer.register_forward_hook(_capture)
    try:
        output = model(image)
        cls = output.argmax(dim=1).item() if target_class is None else target_class
        model.zero_grad()
        output[0, cls].backward()
    finally:
        # Remove the hook so successive images do not stack hooks on the same layer.
        handle.remove()

    act = store["act"][0].detach().cpu().numpy()
    grad = store["grad"][0].detach().cpu().numpy()
    if act.ndim == 2:
        # Token activations (N, D): drop leading prefix tokens so the rest forms a square grid.
        n_tokens, dim = act.shape
        side = math.isqrt(n_tokens)
        skip = n_tokens - side * side
        act = act[skip:].T.reshape(dim, side, side)
        grad = grad[skip:].T.reshape(dim, side, side)
    weights = grad.mean(axis=(1, 2))
    cam = np.maximum(np.tensordot(weights, act, axes=1), 0)
    peak = cam.max()
    return (cam / peak if peak > 0 else cam), cls


def visualize_grad_cam(model, loader, target_class=None, layer_name="blocks.11.norm1", num_images=3,
                       mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Overlay Grad-CAM heat-maps on up to ``num_images`` images from the first batch of ``loader``.

    Args:
        model: classifier returning logits of shape ``(1, n_classes)`` for a single image.
        loader: yields ``(images, labels)`` batches of normalised images.
        target_class: class to explain for every image. If None, each image uses its own
            predicted class; previously the first image's prediction leaked into the rest.
        layer_name: dotted module name from ``model.named_modules()``.
        num_images: number of images to show, capped at the batch size.
        mean, std: per-channel normalisation to undo for display. The old code applied channel
            0's values to all channels. Defaults match the ImageNet stats used in this notebook.

    Raises:
        KeyError: ``layer_name`` is not a module of ``model``.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    dev = first_param.device if first_param is not None else torch.device("cpu")
    layer = dict(model.named_modules())[layer_name]
    images, labels = next(iter(loader))
    mean_t = torch.tensor(mean).view(-1, 1, 1)
    std_t = torch.tensor(std).view(-1, 1, 1)

    for i in range(min(num_images, images.size(0))):
        image = images[i].unsqueeze(0).to(dev).clone().requires_grad_(True)
        label = labels[i].item()
        cam, explained = _gradcam_for_layer(model, layer, image, target_class)

        # Upsample the coarse CAM to the image size so the overlay lines up with the image.
        h, w = image.shape[-2:]
        cam_up = torch.nn.functional.interpolate(
            torch.from_numpy(cam).float()[None, None], size=(h, w),
            mode="bilinear", align_corners=False)[0, 0].numpy()
        disp = (images[i].detach().cpu() * std_t + mean_t).clamp(0, 1).permute(1, 2, 0).numpy()

        plt.imshow(disp)
        plt.imshow(cam_up, cmap='jet', alpha=0.5)
        plt.title(f"Label: {label}, Pred: {explained}")
        plt.axis('off')
        plt.show()



In [None]:
import numpy as np
import pandas as pd

# Keep only scalar metrics. The confusion matrix and the raw per-sample y_true / y_pred
# lists (also returned by evaluate_model) were previously dumped into the CSV.
flat_results = {
    model_name: {k: v for k, v in model_metrics.items() if np.isscalar(v)}
    for model_name, model_metrics in results.items()
}

df_results = pd.DataFrame(flat_results).T
df_results.to_csv("/content/drive/MyDrive/cross_dataset_ablation_results.csv")
print("Saved flattened results to CSV")
df_results


In [None]:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Same absolute path the previous cell writes; a bare relative name only resolves if the
# working directory happens to be MyDrive.
ablation_csv = "/content/drive/MyDrive/cross_dataset_ablation_results.csv"
df = pd.read_csv(ablation_csv, index_col=0)

cnn_models = ["resnet50", "mobilenetv2_100", "convnext_tiny", "densenet121"]

metrics = ["accuracy", "f1"]


def variant_score(frame, arch, variant, metric):
    """Return ``metric`` for row "<arch>_<variant>" or "<arch>_<variant>_finetune"; NaN if absent."""
    for row in (f"{arch}_{variant}", f"{arch}_{variant}_finetune"):
        if row in frame.index:
            return frame.loc[row, metric]
    return np.nan


for metric in metrics:
    # Use the metric being plotted (the old loop plotted accuracy for every metric).
    baseline_vals = [variant_score(df, m, "baseline", metric) for m in cnn_models]
    mixstyle_vals = [variant_score(df, m, "mixstyle", metric) for m in cnn_models]

    x = np.arange(len(cnn_models))
    width = 0.35

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.bar(x - width / 2, baseline_vals, width, label='Baseline')
    ax.bar(x + width / 2, mixstyle_vals, width, label='MixStyle')
    ax.set_ylabel(metric.capitalize())
    ax.set_xlabel("CNN Models")
    ax.set_title(f"{metric.capitalize()} Comparison: Baseline vs MixStyle")
    ax.set_xticks(x)
    ax.set_xticklabels(cnn_models)
    ax.legend()
    plt.show()


In [None]:
from collections import Counter

# Class balance of ground truth vs. predictions (a model collapsing to one class shows up here).
for series_name, series in (("y_true", y_true), ("y_pred", y_pred)):
    print(f"{series_name} counts:", Counter(series))
