<a href="https://colab.research.google.com/github/SuwethaV/Medmnist/blob/main/Med_Mnist_Multi_scale_%2B_Dual_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
hf!pip install medmnist

Collecting medmnist
  Downloading medmnist-3.0.2-py3-none-any.whl.metadata (14 kB)
Collecting fire (from medmnist)
  Downloading fire-0.7.1-py3-none-any.whl.metadata (5.8 kB)
Downloading medmnist-3.0.2-py3-none-any.whl (25 kB)
Downloading fire-0.7.1-py3-none-any.whl (115 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 115.9/115.9 kB 1.0 MB/s eta 0:00:00
Installing collected packages: fire, medmnist
Successfully installed fire-0.7.1 medmnist-3.0.2


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
import medmnist
import os
from medmnist import INFO

In [None]:
# Reproducibility (optional): seed PyTorch's global RNG so weight initialisation
# and DataLoader shuffling repeat across runs.
# The trailing ';' stops Jupyter from echoing the returned torch.Generator object.
SEED = 42
torch.manual_seed(SEED);

<torch._C.Generator at 0x7882699ba370>

In [None]:
# Hyperparameters / settings -- edit these to change the run.
# Dataset flag, e.g. 'pathmnist', 'bloodmnist', 'dermamnist', 'octmnist',
# 'pneumoniamnist'. The loss cell uses CrossEntropyLoss (one label per image).
dataset_flag = 'pathmnist'
img_size = 64              # images are resized to img_size x img_size
batch_size = 128
epochs = 10
learning_rate = 1e-3       # Adam learning rate
save_dir = 'simple_out'    # directory for the best checkpoint

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
has_gpu = torch.cuda.is_available()
device = torch.device('cuda') if has_gpu else torch.device('cpu')
print(f"[Step 1] Using device: {device}")

[Step 1] Using device: cpu


In [None]:
# ----- 2) Load dataset & dataloaders -----
print(f"[Step 2] Preparing MedMNIST dataset: {dataset_flag}")
info = INFO[dataset_flag]
num_classes = len(info['label'])
DataClass = getattr(medmnist, info['python_class'])

# The training cell uses CrossEntropyLoss, which expects exactly one class index
# per image, so multi-label MedMNIST tasks are rejected up front with a clear error.
if 'multi-label' in info['task']:
    raise ValueError(
        f"{dataset_flag!r} is a multi-label task ({info['task']}); "
        "this notebook only supports single-label datasets"
    )

# Transforms: the MedMNIST dataset classes already hand `transform` a PIL image,
# so there is deliberately no ToPILImage() step (ToPILImage accepts only
# tensors/ndarrays and raises TypeError on a PIL image).
# Pipeline: resize -> tensor -> normalize to roughly [-1, 1].
common_tf = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    # 3-element stats match the 3 channels guaranteed by as_rgb=True below.
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])


# as_rgb=True ensures 3 channels for all datasets (the model's first convs expect 3)
train_dataset = DataClass(split='train', transform=common_tf, download=True, as_rgb=True)
val_dataset = DataClass(split='val', transform=common_tf, download=True, as_rgb=True)
test_dataset = DataClass(split='test', transform=common_tf, download=True, as_rgb=True)


train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)


print(f" Train batches: {len(train_loader)} | Val batches: {len(val_loader)} | Test batches: {len(test_loader)}")
print(f" Classes: {num_classes} | Image size: {img_size}x{img_size}")

[Step 2] Preparing MedMNIST dataset: pathmnist


 29%|██▉       | 59.2M/206M [02:00<04:25, 551kB/s]

In [None]:
# Step 2b: Data Preparation (Fixed) -- rebuilds the datasets/loaders using the
# settings from the config cell (dataset_flag, img_size, batch_size) rather than
# hard-coded values, so the data always matches the model's class count.
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import medmnist
from medmnist import INFO

data_flag = dataset_flag  # kept under its old name; follows the config cell
download = True

info = INFO[data_flag]
task = info['task']
n_classes = len(info['label'])
# The model cell builds its classifier from `num_classes`; keep it in sync with
# the dataset actually loaded here.
num_classes = n_classes

# ✅ Fixed transforms (removed ToPILImage: the dataset already yields PIL images)
data_transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    # A single-element mean/std is broadcast over all channels.
    transforms.Normalize(mean=[.5], std=[.5])
])

# Load datasets. as_rgb=True: the model's first convs take 3 input channels, so
# grayscale datasets must be expanded to RGB.
DataClass = getattr(medmnist, info['python_class'])
train_dataset = DataClass(split='train', transform=data_transform, download=download, as_rgb=True)
val_dataset   = DataClass(split='val', transform=data_transform, download=download, as_rgb=True)
test_dataset  = DataClass(split='test', transform=data_transform, download=download, as_rgb=True)

# Dataloaders
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)
test_loader  = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

print("Data loaded successfully!")

In [None]:
# ----- 3) Define Dual Attention blocks (small + readable) -----
class ChannelAttention(nn.Module):
  """Channel attention gate (CBAM-style).

  Summarises ``x`` per channel with global average- and max-pooling, passes
  each summary through the same bottleneck MLP (1x1 conv -> ReLU -> 1x1 conv,
  width ``max(4, in_planes // ratio)``), adds the two results and applies a
  sigmoid. Only the gate, shaped (B, in_planes, 1, 1), is returned; the caller
  multiplies it into ``x``.
  """
  def __init__(self, in_planes, ratio=8):
    super().__init__()
    self.avg_pool = nn.AdaptiveAvgPool2d(1)
    self.max_pool = nn.AdaptiveMaxPool2d(1)
    bottleneck = max(4, in_planes // ratio)  # never narrower than 4 channels
    self.fc1 = nn.Conv2d(in_planes, bottleneck, 1, bias=False)
    self.relu = nn.ReLU(inplace=True)
    self.fc2 = nn.Conv2d(bottleneck, in_planes, 1, bias=False)
    self.sigmoid = nn.Sigmoid()

  def _shared_mlp(self, descriptor):
    # Both pooled descriptors go through the same fc1 -> ReLU -> fc2 weights.
    return self.fc2(self.relu(self.fc1(descriptor)))

  def forward(self, x):
    avg_desc = self.avg_pool(x)
    max_desc = self.max_pool(x)
    return self.sigmoid(self._shared_mlp(avg_desc) + self._shared_mlp(max_desc))


class SpatialAttention(nn.Module):
  """Spatial attention gate: rescales every pixel of ``x`` by a learned score.

  The channel-wise mean and max maps are stacked into 2 channels, passed
  through one ``kernel_size`` x ``kernel_size`` conv and a sigmoid, and the
  resulting 1-channel map multiplies ``x`` (broadcast over channels). Unlike
  ChannelAttention, the gated tensor itself is returned.

  Args:
    kernel_size: odd int; ``kernel_size // 2`` padding then keeps the attention
      map the same height/width as ``x``.

  Raises:
    ValueError: if ``kernel_size`` is even. The map would be one pixel larger
      than ``x`` and the multiplication in ``forward`` would always fail.
  """
  def __init__(self, kernel_size=7):
    super().__init__()
    if kernel_size % 2 == 0:
      raise ValueError(f"kernel_size must be odd, got {kernel_size}")
    padding = kernel_size // 2
    self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    # Collapse the channel axis two ways, giving two 1-channel maps.
    avg_map = torch.mean(x, dim=1, keepdim=True)
    max_map = torch.max(x, dim=1, keepdim=True).values
    attn = self.sigmoid(self.conv(torch.cat([avg_map, max_map], dim=1)))
    return x * attn  # broadcast the single-channel gate over all channels


# Tiny multi-scale head: three different kernel sizes, then fuse and attend
class MultiScaleDualAttentionCNN(nn.Module):
  """Single-stage multi-scale CNN with channel + spatial attention.

  Three parallel convolutions (3x3, 5x5, 7x7, all with 'same' padding) read the
  same input. Their outputs are concatenated, batch-normalised and ReLU'd, then
  gated by ChannelAttention and SpatialAttention. Global average pooling and a
  linear layer follow.

  Args:
    num_classes: number of output logits.
    in_channels: channels of the input images (default 3, matching the
      as_rgb=True datasets used in this notebook).
    branch_channels: feature maps per conv branch (default 32); the fused
      tensor has ``3 * branch_channels`` channels.
  """
  def __init__(self, num_classes, in_channels=3, branch_channels=32):
    super().__init__()
    fused = 3 * branch_channels
    # Multi-scale convs (3,5,7) from the same input; padding keeps H and W unchanged.
    self.conv3 = nn.Conv2d(in_channels, branch_channels, 3, padding=1)
    self.conv5 = nn.Conv2d(in_channels, branch_channels, 5, padding=2)
    self.conv7 = nn.Conv2d(in_channels, branch_channels, 7, padding=3)
    self.bn = nn.BatchNorm2d(fused)
    self.relu = nn.ReLU(inplace=True)
    self.ca = ChannelAttention(fused)
    self.sa = SpatialAttention(7)
    self.pool = nn.AdaptiveAvgPool2d(1)
    self.fc = nn.Linear(fused, num_classes)

  def forward(self, x):
    """Map a batch of images (B, in_channels, H, W) to logits (B, num_classes)."""
    x = torch.cat([self.conv3(x), self.conv5(x), self.conv7(x)], dim=1)  # fuse multi-scale
    x = self.relu(self.bn(x))  # norm + nonlinearity
    x = x * self.ca(x)  # channel attention: per-channel multiplicative gate
    x = self.sa(x)  # spatial attention: per-pixel multiplicative gate (no residual/skip path)
    x = self.pool(x).flatten(1)  # global average pooling -> (B, fused)
    return self.fc(x)


# Build the network for this dataset's class count, then move it to the chosen device.
model = MultiScaleDualAttentionCNN(num_classes=num_classes)
model = model.to(device)

In [None]:
# ----- 4) Loss & Optimizer -----
# Adam over all model parameters; cross-entropy on raw logits (single-label classification).
print("[Step 4] Setting up optimizer & loss")
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

In [None]:
print("[Step 5] Training…")
os.makedirs(save_dir, exist_ok=True)
# -inf so the first epoch always writes a checkpoint (even at 0% val accuracy).
best_val_acc = float('-inf')
# Checkpoint for the best-so-far model by validation accuracy.
best_path = os.path.join(save_dir, f"{dataset_flag}_best.pth")

for epoch in range(1, epochs + 1):
    # ---- Train ----
    model.train()
    running_loss = 0.0
    seen = 0  # samples processed this epoch; the last batch may be smaller than batch_size
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs} [train]", ncols=0)
    for images, labels in pbar:
        images = images.to(device)
        # Labels are assumed to be (B, 1): one class index per image.
        # reshape(-1) gives (B,) for CrossEntropyLoss. Unlike squeeze(), it keeps
        # a size-1 batch as shape (1,) instead of collapsing it to a 0-d scalar.
        labels = labels.reshape(-1).long().to(device)

        optimizer.zero_grad(set_to_none=True)
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean, so weight it by the batch size.
        batch_n = images.size(0)
        running_loss += loss.item() * batch_n
        seen += batch_n
        pbar.set_postfix(loss=f"{running_loss / seen:.4f}")

    train_loss = running_loss / max(1, seen)

    # ---- Validate ----
    model.eval()
    val_correct = 0
    val_total = 0
    val_running = 0.0
    pbar_val = tqdm(val_loader, desc=f"Epoch {epoch}/{epochs} [val] ", ncols=0)
    with torch.no_grad():
        for images, labels in pbar_val:
            images = images.to(device)
            labels = labels.reshape(-1).long().to(device)

            logits = model(images)
            loss = criterion(logits, labels)

            val_running += loss.item() * images.size(0)
            preds = logits.argmax(1)
            val_correct += (preds == labels).sum().item()
            val_total += labels.size(0)

            pbar_val.set_postfix(acc=f"{(val_correct/val_total)*100:.2f}%")

    val_loss = val_running / max(1, val_total)
    val_acc = (val_correct / max(1, val_total)) * 100.0
    print(f"[Epoch {epoch}] train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | val_acc={val_acc:.2f}%")

    # Save best (only on a strict improvement in validation accuracy)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), best_path)
        print(f" Saved best model to: {best_path}")

[Step 5] Training…


Epoch 1/10 [train]: 100% 704/704 [47:53<00:00,  4.08s/it, loss=1.0809]
Epoch 1/10 [val] : 100% 79/79 [01:47<00:00,  1.36s/it, acc=66.96%]


[Epoch 1] train_loss=1.0823 | val_loss=0.8670 | val_acc=66.96%
 Saved best model to: simple_out/pathmnist_best.pth


Epoch 2/10 [train]: 100% 704/704 [44:32<00:00,  3.80s/it, loss=0.7146]
Epoch 2/10 [val] : 100% 79/79 [01:47<00:00,  1.36s/it, acc=74.18%]


[Epoch 2] train_loss=0.7155 | val_loss=0.7147 | val_acc=74.18%
 Saved best model to: simple_out/pathmnist_best.pth


Epoch 3/10 [train]: 100% 704/704 [45:01<00:00,  3.84s/it, loss=0.6162]
Epoch 3/10 [val] : 100% 79/79 [01:48<00:00,  1.37s/it, acc=78.78%]


[Epoch 3] train_loss=0.6170 | val_loss=0.6032 | val_acc=78.78%
 Saved best model to: simple_out/pathmnist_best.pth


Epoch 4/10 [train]: 100% 704/704 [44:59<00:00,  3.83s/it, loss=0.5589]
Epoch 4/10 [val] : 100% 79/79 [01:46<00:00,  1.35s/it, acc=79.19%]


[Epoch 4] train_loss=0.5597 | val_loss=0.5751 | val_acc=79.19%
 Saved best model to: simple_out/pathmnist_best.pth


Epoch 5/10 [train]: 100% 704/704 [45:02<00:00,  3.84s/it, loss=0.5307]
Epoch 5/10 [val] : 100% 79/79 [01:46<00:00,  1.35s/it, acc=83.72%]


[Epoch 5] train_loss=0.5313 | val_loss=0.4871 | val_acc=83.72%
 Saved best model to: simple_out/pathmnist_best.pth


Epoch 6/10 [train]: 100% 704/704 [45:17<00:00,  3.86s/it, loss=0.4927]
Epoch 6/10 [val] : 100% 79/79 [02:00<00:00,  1.52s/it, acc=74.59%]


[Epoch 6] train_loss=0.4934 | val_loss=0.6923 | val_acc=74.59%


Epoch 7/10 [train]: 100% 704/704 [51:01<00:00,  4.35s/it, loss=0.4686]
Epoch 7/10 [val] : 100% 79/79 [02:13<00:00,  1.69s/it, acc=80.42%]


[Epoch 7] train_loss=0.4692 | val_loss=0.5460 | val_acc=80.42%


Epoch 8/10 [train]: 100% 704/704 [51:53<00:00,  4.42s/it, loss=0.4434]
Epoch 8/10 [val] : 100% 79/79 [01:49<00:00,  1.38s/it, acc=84.76%]


[Epoch 8] train_loss=0.4439 | val_loss=0.4591 | val_acc=84.76%
 Saved best model to: simple_out/pathmnist_best.pth


Epoch 9/10 [train]: 100% 704/704 [50:49<00:00,  4.33s/it, loss=0.4270]
Epoch 9/10 [val] : 100% 79/79 [02:00<00:00,  1.53s/it, acc=85.07%]


[Epoch 9] train_loss=0.4275 | val_loss=0.4340 | val_acc=85.07%
 Saved best model to: simple_out/pathmnist_best.pth


Epoch 10/10 [train]: 100% 704/704 [50:43<00:00,  4.32s/it, loss=0.4093]
Epoch 10/10 [val] : 100% 79/79 [01:50<00:00,  1.40s/it, acc=84.65%]

[Epoch 10] train_loss=0.4099 | val_loss=0.4238 | val_acc=84.65%





In [None]:
# ----- 6) Test evaluation -----
print("[Step 6] Testing best model on test set…")
# Load the best checkpoint written during training, if there is one.
best_path = os.path.join(save_dir, f"{dataset_flag}_best.pth")
if os.path.exists(best_path):
    # The file holds a plain state_dict, so weights_only=True is sufficient and
    # refuses to unpickle arbitrary objects.
    model.load_state_dict(torch.load(best_path, map_location=device, weights_only=True))
else:
    print(f" No checkpoint found at {best_path}; evaluating the current in-memory weights.")


model.eval()
correct = 0
count = 0
pbar_test = tqdm(test_loader, desc="[test]", ncols=0)
with torch.no_grad():
    for images, labels in pbar_test:
        images = images.to(device)
        # reshape(-1) instead of squeeze(): a size-1 batch stays shape (1,).
        labels = labels.reshape(-1).long().to(device)
        logits = model(images)
        preds = logits.argmax(1)
        correct += (preds == labels).sum().item()
        count += labels.size(0)
        pbar_test.set_postfix(acc=f"{(correct/max(1,count))*100:.2f}%")


test_acc = (correct / max(1, count)) * 100.0
print(f"Final Test Accuracy: {test_acc:.2f}%")

[Step 6] Testing best model on test set…


[test]: 100% 57/57 [01:19<00:00,  1.39s/it, acc=78.70%]

Final Test Accuracy: 78.70%



