In [None]:
!pip install medmnist
!pip install torch torchvision torchaudio
!pip install transformers datasets

Collecting medmnist
  Downloading medmnist-3.0.2-py3-none-any.whl.metadata (14 kB)
Collecting fire (from medmnist)
  Downloading fire-0.7.1-py3-none-any.whl.metadata (5.8 kB)
Downloading medmnist-3.0.2-py3-none-any.whl (25 kB)
Downloading fire-0.7.1-py3-none-any.whl (115 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.9/115.9 kB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fire, medmnist
Successfully installed fire-0.7.1 medmnist-3.0.2


In [None]:
from medmnist import PathMNIST
from medmnist import INFO
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from torchvision import models, transforms
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import medmnist
import torchvision

In [None]:
# Look up the metadata record that medmnist publishes for PathMNIST
info = INFO['pathmnist']
print(info.keys())

# The class count is the number of entries in the label map
num_classes = len(info["label"])
print("Number of classes:", num_classes)
print("num_classes =", type(num_classes))


# Preprocessing: resize to 224x224 for ResNet, convert to a tensor, then
# normalize with the ImageNet channel statistics
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Resolve the dataset class by name, then build the splits in the order
# train -> val -> test (each one may trigger the download)
DatasetClass = getattr(medmnist, info['python_class'])
train_dataset, val_dataset, test_dataset = (
    DatasetClass(split=split_name, transform=transform, download=True)
    for split_name in ('train', 'val', 'test')
)

# Only the training loader shuffles; val/test keep the dataset order
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader, test_loader = (
    DataLoader(split_dataset, batch_size=64, shuffle=False)
    for split_dataset in (val_dataset, test_dataset)
)

dict_keys(['python_class', 'description', 'url', 'MD5', 'url_64', 'MD5_64', 'url_128', 'MD5_128', 'url_224', 'MD5_224', 'task', 'label', 'n_channels', 'n_samples', 'license'])
Number of classes: 9
num_classes = <class 'int'>


100%|██████████| 206M/206M [00:18<00:00, 11.3MB/s]


In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ResNet-18 initialised with torchvision's default pretrained weights
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Swap the final fully connected layer for one with `num_classes` outputs,
# keeping the same input width as the layer it replaces
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=num_classes)
model = model.to(device)

# Cross-entropy loss, with Adam updating every parameter of the network
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 101MB/s]


In [None]:
# Fine-tune the model for a fixed number of passes over the training loader.
epochs = 3

for epoch in range(epochs):
    model.train()
    # Running sum of the per-batch losses for this epoch (a sum, not a mean)
    total_loss = 0

    for images, labels in train_loader:
        images = images.to(device)
        # Flatten the labels to 1-D. reshape(-1) is used instead of squeeze()
        # because squeeze() would also drop the batch dimension of a
        # single-sample batch, leaving a 0-d target the loss cannot match
        # against the (batch, classes) logits.
        labels = labels.reshape(-1).to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Report inside the epoch loop so that every epoch gets its own line,
    # not just the final one.
    print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss:.4f}")

KeyboardInterrupt: 

In [None]:
# Measure top-1 accuracy on the validation loader.
model.eval()

correct = 0
total = 0

with torch.no_grad():
    for images, labels in val_loader:
        images = images.to(device)
        # Flatten the labels to 1-D. reshape(-1) keeps the batch dimension
        # even for a single-sample batch, which squeeze() would drop.
        labels = labels.reshape(-1).to(device)

        outputs = model(images)
        # Predicted class = index of the largest logit. No `.data` is needed:
        # gradients are already disabled by the no_grad() context.
        pred = outputs.argmax(dim=1)

        total += labels.size(0)
        correct += (pred == labels).sum().item()

print(f"Validation Accuracy: {100 * correct/total:.2f}%")

In [None]:
# Collect predictions over the test loader and plot the confusion matrix.
model.eval()
all_labels = []
all_preds = []

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)

        outputs = model(images)
        pred = outputs.argmax(dim=1)

        # Labels are only needed for the host-side metric, so they are not
        # moved to the device. reshape(-1) flattens them to 1-D while keeping
        # a single-sample batch iterable (squeeze() would make it 0-d and
        # break extend()).
        all_labels.extend(labels.reshape(-1).tolist())
        all_preds.extend(pred.cpu().tolist())

# Pass the full label set so the matrix is always num_classes x num_classes,
# even if some class never occurs in the labels or the predictions.
cm = confusion_matrix(all_labels, all_preds, labels=list(range(num_classes)))
disp = ConfusionMatrixDisplay(cm)
disp.plot(xticks_rotation='vertical')
disp.ax_.set_title("Test set confusion matrix")
plt.show()