# New Section: Single-label ChestMNIST classification with ResNet-18 (code below was updated using ChatGPT)

In [None]:
# Mount Google Drive at /content/drive so files written under
# /content/drive/MyDrive (the training loop's best checkpoint) land on Drive.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
!pip install medmnist

Collecting medmnist
  Downloading medmnist-3.0.2-py3-none-any.whl.metadata (14 kB)
Collecting fire (from medmnist)
  Downloading fire-0.7.0.tar.gz (87 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/87.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m87.2/87.2 kB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Downloading medmnist-3.0.2-py3-none-any.whl (25 kB)
Building wheels for collected packages: fire
  Building wheel for fire (setup.py) ... [?25l[?25hdone
  Created wheel for fire: filename=fire-0.7.0-py3-none-any.whl size=114249 sha256=5db8d0430426bc6ee99dc224f608936cfd19749c9da3a2ae200cef991e4c48cb
  Stored in directory: /root/.cache/pip/wheels/19/39/2f/2d3cadc408a8804103f1c34ddd4b9f6a93497b11fa96fe738e
Successfully built fire
Installing collected packages: fire, medmnist
Successfully installed fire-0.7.0 medmnist-3.0.2


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.models import resnet18
import medmnist
from medmnist import INFO, Evaluator
from sklearn.metrics import roc_auc_score
import numpy as np

In [None]:
# Step 1: Seed the random number generators and define the device (GPU if available)
import random

# Seed every RNG the notebook draws from (Python, NumPy, PyTorch on all
# devices). Weight initialization of replaced layers, DataLoader shuffling and
# the random augmentations then repeat across runs. Some cuDNN kernels can
# still introduce small run-to-run differences on GPU.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Step 2: Training-time preprocessing with data augmentation.
# Random rotation within +/-10 degrees and a horizontal flip with probability
# 0.5, then conversion to a float tensor and a (x - 0.5) / 0.5 rescale.
# The 0.5/0.5 normalization values are a generic choice, not statistics
# measured on ChestMNIST.
transform = transforms.Compose(
    [
        transforms.RandomRotation(degrees=10),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

In [None]:
from PIL import Image

class SingleLabelChestMNIST(Dataset):
    """ChestMNIST restricted to images with exactly one positive finding.

    ChestMNIST labels are multi-hot vectors (one column per finding). This
    wrapper keeps only the rows whose label vector sums to exactly 1. That
    drops both images with several findings AND images with no positive
    finding. The index of the single positive column is returned as a scalar
    class label, turning the task into ordinary single-label classification.

    Args:
        split: MedMNIST split name (e.g. 'train' or 'test').
        transform: optional callable applied to the PIL image.
        size: MedMNIST image-size variant to download/load.
    """

    def __init__(self, split, transform=None, size=224):
        info = INFO['chestmnist']
        DataClass = getattr(medmnist, info['python_class'])
        # The raw `imgs` array is indexed directly below; MedMNIST's own
        # __getitem__ is never used. An `as_rgb` flag therefore has no effect
        # here, and images stay single-channel grayscale, which is what the
        # 1-channel model stem expects.
        self.data = DataClass(split=split, download=True, size=size)
        self.imgs = self.data.imgs
        self.labels = self.data.labels
        self.transform = transform

        # Keep only rows with exactly one positive label. Rows with zero
        # positives and rows with several positives are both excluded.
        self.single_label_indices = self._single_label_indices(self.labels)

    @staticmethod
    def _single_label_indices(labels):
        """Return, as a list, the row indices whose label vector sums to exactly 1."""
        return np.flatnonzero(np.asarray(labels).sum(axis=1) == 1).tolist()

    def __len__(self):
        return len(self.single_label_indices)

    def __getitem__(self, idx):
        """Return (image, label) for the idx-th kept sample.

        The stored array is wrapped as a PIL image and passed through
        `transform` when one was given. The label is the column index of the
        single positive entry in the label vector.
        """
        source_idx = self.single_label_indices[idx]
        img = Image.fromarray(self.imgs[source_idx])
        label = self.labels[source_idx].argmax()  # one-hot row -> scalar class index

        if self.transform:
            img = self.transform(img)

        return img, label

In [None]:
# Load data with the specified size
batch_size = 64
IMAGE_SIZE = 224

# Evaluation preprocessing. It uses the same tensor conversion and
# normalization as the training `transform`, but without the random
# rotation/flip, so evaluation metrics are deterministic and measured on
# unaltered images.
eval_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

train_dataset = SingleLabelChestMNIST(split='train', transform=transform, size=IMAGE_SIZE)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = SingleLabelChestMNIST(split='test', transform=eval_transform, size=IMAGE_SIZE)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

Downloading https://zenodo.org/records/10519652/files/chestmnist_224.npz?download=1 to /root/.medmnist/chestmnist_224.npz


100%|██████████| 3.89G/3.89G [00:41<00:00, 93.2MB/s]


Using downloaded and verified file: /root/.medmnist/chestmnist_224.npz


In [None]:
# Step 5: Define the ResNet18 model
class ChestMNISTModel(nn.Module):
    """ResNet-18 classifier adapted to single-channel (grayscale) input.

    Args:
        num_classes: number of logits produced by the final linear layer.
        pretrained: if True (default), start from torchvision's pretrained
            ResNet-18 weights and keep the pretrained stem filters by folding
            their RGB kernels into the new 1-channel conv.
    """

    def __init__(self, num_classes=14, pretrained=True):
        super().__init__()
        self.model = resnet18(pretrained=pretrained)

        # The stock stem conv takes 3 input channels. Swap in a 1-channel conv
        # with identical geometry so the network accepts grayscale tensors.
        rgb_conv = self.model.conv1
        gray_conv = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        if pretrained:
            # Do not throw away the pretrained low-level filters. Summing the
            # kernels over the RGB axis gives exactly the response the original
            # conv would produce for an image replicated into R, G and B.
            with torch.no_grad():
                gray_conv.weight.copy_(rgb_conv.weight.sum(dim=1, keepdim=True))
        self.model.conv1 = gray_conv

        # Replace the classification head with a num_classes-way linear layer.
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, x):
        """Return raw (unnormalized) class logits for a batch of images."""
        return self.model(x)

In [None]:
# Model, loss, optimizer and learning-rate schedule.
model = ChestMNISTModel(num_classes=14)
model = model.to(device)

# Integer class targets scored against raw logits.
criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# The schedule is stepped with validation AUC (higher is better, hence
# mode='max'). It halves the LR once AUC has not improved for more than
# 5 consecutive epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=0.5,
    patience=5,
    verbose=True,
)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 117MB/s]


In [None]:
# Step 6: Training function
def train(model, train_loader, criterion, optimizer, device):
    """Run one training epoch and return (mean loss, accuracy in percent).

    Args:
        model: classifier returning logits with classes along dim 1.
        train_loader: iterable of (images, integer class labels) batches.
        criterion: loss function; assumed to average over the batch
            (e.g. nn.CrossEntropyLoss with default reduction).
        optimizer: optimizer over the model's parameters, stepped once per batch.
        device: device each batch is moved to.

    Returns:
        (epoch_loss, epoch_acc): loss averaged over all samples, and top-1
        accuracy in percent.

    Raises:
        ValueError: if train_loader yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        n_batch = labels.size(0)
        # The criterion returns a per-batch mean. Weight it by the batch size
        # so a smaller final batch does not skew the epoch average.
        loss_sum += loss.item() * n_batch
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += n_batch

    if total == 0:
        raise ValueError("train_loader yielded no samples")

    epoch_loss = loss_sum / total
    epoch_acc = 100 * correct / total
    return epoch_loss, epoch_acc

In [None]:
# Step 7: Validation function
def validate(model, test_loader, criterion, device):
    """Evaluate the model on a loader and return (mean loss, accuracy %, AUC).

    Args:
        model: classifier returning logits with classes along dim 1.
        test_loader: iterable of (images, integer class labels) batches.
        criterion: loss function; assumed to average over the batch.
        device: device each batch is moved to.

    Returns:
        (epoch_loss, epoch_acc, auc):
        - epoch_loss is averaged over all samples.
        - epoch_acc is top-1 accuracy in percent.
        - auc is the one-vs-rest ROC AUC computed from softmax probabilities.

    Raises:
        ValueError: if test_loader yields no samples (roc_auc_score may also
            raise ValueError, e.g. when a class is absent from the labels).
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    label_batches = []
    prob_batches = []

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            n_batch = labels.size(0)
            # Weight the per-batch mean by batch size for a true per-sample mean.
            loss_sum += loss.item() * n_batch
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += n_batch

            # Collect one array per batch and concatenate once at the end.
            # Building a tensor from a list of per-sample ndarrays is slow.
            label_batches.append(labels.cpu().numpy())
            prob_batches.append(torch.softmax(outputs, dim=1).cpu().numpy())

    if total == 0:
        raise ValueError("test_loader yielded no samples")

    epoch_loss = loss_sum / total
    epoch_acc = 100 * correct / total

    all_labels = np.concatenate(label_batches)
    all_probs = np.concatenate(prob_batches)
    auc = roc_auc_score(all_labels, all_probs, multi_class='ovr')

    return epoch_loss, epoch_acc, auc

In [None]:
# Step 8: Training loop with best model saving
#
# LR scheduling and checkpoint selection are driven by a held-out validation
# split. If the test split were used for these decisions, test information
# would leak into training and the reported test AUC would be optimistically
# biased. The test split is evaluated once, after training, with the best
# checkpoint.
num_epochs = 100
best_auc = 0.0  # Initialize best AUC score
best_model_path = '/content/drive/MyDrive/best_model.pth'

# Validation data: the 'val' split with no random augmentation, so the
# per-epoch AUC is deterministic and comparable across epochs.
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])
val_dataset = SingleLabelChestMNIST(split='val', transform=val_transform, size=224)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")

    # Train for one epoch, then measure on the held-out validation split.
    train_loss, train_acc = train(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc, val_auc = validate(model, val_loader, criterion, device)

    # Print metrics for the current epoch
    print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.2f}%")
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.2f}%, Validation AUC: {val_auc:.4f}")

    # Update the scheduler based on validation AUC
    scheduler.step(val_auc)

    # Save model if validation AUC improves
    if val_auc > best_auc:
        best_auc = val_auc
        torch.save(model.state_dict(), best_model_path)
        print(f"New best model saved with AUC: {best_auc:.4f}")

# One-time final evaluation of the best checkpoint on the untouched test split.
model.load_state_dict(torch.load(best_model_path, map_location=device))
test_loss, test_acc, test_auc = validate(model, test_loader, criterion, device)
print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.2f}%, Test AUC: {test_auc:.4f}")

Epoch 1/100


  all_preds = torch.softmax(torch.tensor(all_preds), dim=1).numpy()


Train Loss: 2.1301, Train Accuracy: 31.68%
Validation Loss: 2.1157, Validation Accuracy: 32.98%, Validation AUC: 0.6661
New best model saved with AUC: 0.6661
Epoch 2/100
Train Loss: 2.0143, Train Accuracy: 35.19%
Validation Loss: 2.1708, Validation Accuracy: 35.05%, Validation AUC: 0.6766
New best model saved with AUC: 0.6766
Epoch 3/100
Train Loss: 1.9407, Train Accuracy: 37.48%
Validation Loss: 2.0330, Validation Accuracy: 35.90%, Validation AUC: 0.7093
New best model saved with AUC: 0.7093
Epoch 4/100
Train Loss: 1.8890, Train Accuracy: 39.10%
Validation Loss: 1.9489, Validation Accuracy: 38.07%, Validation AUC: 0.7401
New best model saved with AUC: 0.7401
Epoch 5/100
Train Loss: 1.8558, Train Accuracy: 40.26%
Validation Loss: 1.9083, Validation Accuracy: 38.65%, Validation AUC: 0.7437
New best model saved with AUC: 0.7437
Epoch 6/100
Train Loss: 1.8172, Train Accuracy: 41.49%
Validation Loss: 1.8955, Validation Accuracy: 39.81%, Validation AUC: 0.7485
New best model saved with AUC:

KeyboardInterrupt: 