# Imports and Reproducibility Setup

In [15]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, models
import torch.optim as optim

from medmnist import INFO, ChestMNIST
import medmnist

# Seed PyTorch's global RNG (torch.manual_seed seeds all devices) so weight
# initialization of new layers and DataLoader shuffling are repeatable.
# The trailing `;` suppresses the Generator repr Jupyter would otherwise display.
SEED = 42
torch.manual_seed(SEED);

<torch._C.Generator at 0x10c0d7cb0>

# Load ChestMNIST Dataset + Preprocessing

In [16]:
# Get dataset info
info = INFO['chestmnist']
n_classes = len(info['label'])  # one output per label in the dataset metadata

# ImageNet channel statistics. The ResNet-18 loaded below uses ImageNet-pretrained
# weights, which were trained on inputs normalized with these values; feeding raw
# [0, 1] tensors shifts the input distribution away from what those filters expect.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
BATCH_SIZE = 32

# Data transformations
data_transform = transforms.Compose([
    transforms.Resize((224, 224)),                # ResNet's standard 224x224 input size
    transforms.Grayscale(num_output_channels=3),  # replicate gray into 3 channels for the RGB stem
    transforms.ToTensor(),                        # float tensor scaled to [0, 1]
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Load training and test datasets
train_dataset = ChestMNIST(split='train', transform=data_transform, download=True)
test_dataset = ChestMNIST(split='test', transform=data_transform, download=True)

# Create dataloaders (only the training set is shuffled)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Define and Modify the Model (ResNet-18)

In [17]:
# Load ResNet-18 with pretrained weights.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases in favor of
# `weights=models.ResNet18_Weights.DEFAULT`; kept here for compatibility with older versions.
model = models.resnet18(pretrained=True)

# Replace the pretrained classification head with one output per ChestMNIST label.
# ChestMNIST is multi-label (several findings can be present at once), so each output
# is an independent probability rather than one softmax over classes.
# The Sigmoid must stay for as long as the loss is nn.BCELoss, which expects probabilities.
model.fc = nn.Sequential(
    # size the head from the dataset metadata instead of a hard-coded 14
    nn.Linear(model.fc.in_features, n_classes),
    nn.Sigmoid(),
)

# Pick the fastest available backend: CUDA, then Apple-silicon MPS, then CPU.
# getattr guards older torch builds that have no torch.backends.mps module.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
model = model.to(device)

# Loss Function, Optimizer & Training Loop Setup

In [18]:
# Binary cross-entropy on per-label probabilities (the model head ends in Sigmoid),
# i.e. an independent yes/no loss for every label.
criterion = nn.BCELoss()

# Optimizer
LEARNING_RATE = 1e-3
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Training loop
num_epochs = 5   # you can increase later
LOG_EVERY = 200  # batches between progress lines, so long epochs show they are advancing

for epoch in range(num_epochs):
    # Enter training mode every epoch so BatchNorm uses batch statistics even if an
    # evaluation pass (model.eval()) is later added inside this loop.
    model.train()
    running_loss = 0.0
    seen = 0

    for step, (images, labels) in enumerate(train_loader, start=1):
        images = images.to(device)
        labels = labels.to(device).float()  # BCELoss requires float targets

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward + Optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # `loss` is the mean over this batch; weight it by batch size so the epoch
        # figure is an exact per-sample mean even when the last batch is smaller.
        batch_size = images.size(0)
        running_loss += loss.item() * batch_size
        seen += batch_size

        if step % LOG_EVERY == 0:
            print(f"  epoch {epoch+1} step {step}/{len(train_loader)}: "
                  f"mean loss so far {running_loss / seen:.4f}")

    # Report the mean loss rather than a sum over batches, so values are comparable
    # across epochs, batch sizes and dataset sizes.
    epoch_loss = running_loss / max(seen, 1)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

KeyboardInterrupt: 