# Imports and Random Seed Setup

In [3]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, models

from medmnist import INFO, ChestMNIST
import medmnist

# Fix the global torch RNG so DataLoader shuffling and freshly initialised
# layers (e.g. the new classifier head) are reproducible across runs.
torch.random.manual_seed(42)

<torch._C.Generator at 0x10c0d7cb0>

# Load ChestMNIST Dataset + Preprocessing

In [4]:
# Dataset metadata shipped with medmnist (label names, channel count, ...).
info = INFO['chestmnist']
n_classes = len(info['label'])

BATCH_SIZE = 32
# Channel statistics the torchvision ImageNet-pretrained weights were trained with.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# The pretrained ResNet-18 used below has a 3-channel first conv layer, so
# single-channel images must be replicated to 3 channels; otherwise the first
# forward pass fails with a channel-count mismatch.
grayscale_to_rgb = (
    [transforms.Grayscale(num_output_channels=3)] if info['n_channels'] == 1 else []
)

# Data transformations:
# - Resize: upsample to the 224x224 resolution ResNet was designed for.
# - Grayscale(3) (only for single-channel data): match ResNet's input channels.
# - Normalize: match the input distribution of the pretrained weights.
data_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    *grayscale_to_rgb,
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Load training and test datasets (downloaded on first use).
train_dataset = ChestMNIST(split='train', transform=data_transform, download=True)
test_dataset = ChestMNIST(split='test', transform=data_transform, download=True)

# Create dataloaders: shuffle only the training data; keep test order fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

100%|██████████| 82.8M/82.8M [00:08<00:00, 9.38MB/s]


# Define and Modify the Model (ResNet-18)

In [5]:
# Load ResNet-18 with ImageNet-pretrained weights.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases;
# there, prefer `weights=models.ResNet18_Weights.DEFAULT` (same checkpoint).
model = models.resnet18(pretrained=True)

# Replace the final fully connected layer with a multi-label head.
# ChestMNIST is multi-label (one independent binary output per label), so the
# output width comes from the dataset metadata instead of a hard-coded literal.
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, n_classes),
    nn.Sigmoid()  # independent per-label probabilities in [0, 1]
)
# NOTE(review): because the head already applies Sigmoid, the loss must be
# nn.BCELoss. Dropping the Sigmoid and using nn.BCEWithLogitsLoss is more
# numerically stable, if the training/eval code is updated to match.

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /Users/arrryyy/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:05<00:00, 8.48MB/s]
