# New Section (Below is updated code from ChatGPT)

**Summary of Fixes**
* Data Augmentation: Increased the range of transformations (rotation, vertical flip, color jitter, affine).
* Dropout: Added a dropout layer to the model to help regularize it.
* Early Stopping: Stop training when the validation AUC has not improved for 10 consecutive epochs (`patience = 10`), to prevent overfitting.
* L2 Regularization: Added weight decay (`weight_decay=1e-4`) to the Adam optimizer to penalize large weights.
* Learning Rate Schedule: Kept the initial learning rate at 0.001 and added a `ReduceLROnPlateau` scheduler (`factor=0.5`, `patience=5`) that halves it when validation AUC plateaus, so training converges more smoothly.

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install medmnist

Collecting medmnist
  Downloading medmnist-3.0.2-py3-none-any.whl.metadata (14 kB)
Collecting fire (from medmnist)
  Downloading fire-0.7.0.tar.gz (87 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/87.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m87.2/87.2 kB[0m [31m8.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Downloading medmnist-3.0.2-py3-none-any.whl (25 kB)
Building wheels for collected packages: fire
  Building wheel for fire (setup.py) ... [?25l[?25hdone
  Created wheel for fire: filename=fire-0.7.0-py3-none-any.whl size=114249 sha256=861fd717752879eb67780707c9158049f568ef935a2967b0a9756e21c7fead78
  Stored in directory: /root/.cache/pip/wheels/19/39/2f/2d3cadc408a8804103f1c34ddd4b9f6a93497b11fa96fe738e
Successfully built fire
Installing collected packages: fire, medmnist
Successfully installed fire-0.7.0 medmnist-3.0.2


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.models import resnet18
import medmnist
from medmnist import INFO, Evaluator
from sklearn.metrics import roc_auc_score
import numpy as np

In [None]:
# Step 1: Define device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Overfitting occurs when the model performs well on the training data but poorly on the validation data. This usually happens when the model becomes too complex relative to the amount of data, or when the training data is not sufficiently diverse.
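A practical way to spot this is to compare the per-epoch training and validation losses: the typical signature is validation loss rising while training loss keeps falling. The helper below is a hypothetical sketch (the name `overfitting_epochs` is not from any library) that flags those epochs; a single flagged epoch can be noise, so look for a sustained run.

```python
from typing import List, Sequence


def overfitting_epochs(train_losses: Sequence[float], val_losses: Sequence[float]) -> List[int]:
    """Return 1-based epochs where validation loss rose while training loss fell.

    Hypothetical helper: a falling training loss paired with a rising validation
    loss suggests the model has started to memorize the training set.
    """
    if len(train_losses) != len(val_losses):
        raise ValueError("train_losses and val_losses must have the same length")
    flagged = []
    for i in range(1, len(train_losses)):
        train_improved = train_losses[i] < train_losses[i - 1]
        val_worsened = val_losses[i] > val_losses[i - 1]
        if train_improved and val_worsened:
            flagged.append(i + 1)  # 1-based, matching the "Epoch N/..." training log
    return flagged


# Illustrative toy numbers, not results from this notebook
print(overfitting_epochs([1.0, 0.8, 0.6, 0.5], [1.1, 0.9, 0.95, 1.0]))  # [3, 4]
```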

To reduce overfitting, here are several strategies you can try:

* Increase Data Augmentation: Apply more diverse transformations to the training images, which helps the model generalize better to unseen data.

* Use Dropout: Dropout layers regularize the model by randomly zeroing ("dropping") activations during training, forcing it not to rely too heavily on any single feature.

* Reduce Model Complexity: ResNet18 is a relatively deep model (about 11.7 million parameters). Note that ResNet34 is deeper and larger, not smaller; if overfitting is severe, try a simpler architecture instead.

* Early Stopping: Monitor the validation loss or AUC and stop training once it stops improving. This avoids training for too long, which can cause overfitting.

* Regularization (L2 weight decay): Adding L2 regularization (weight decay) to the optimizer penalizes large weights and keeps the model from fitting the training data too tightly.

* Reduce Learning Rate: If the learning rate is too high, large updates make training unstable and can overshoot good solutions; a lower rate, or one that is lowered when validation metrics plateau, lets the model converge more smoothly.
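For the L2 weight-decay strategy, it helps to see what the `weight_decay` argument does inside `optim.Adam`. As described in the PyTorch documentation, Adam adds λθ to the gradient before computing its moment estimates, which is exactly the gradient of an L2 penalty added to the loss (in this notebook λ = 1e-4):

```latex
\tilde{\mathcal{L}}(\theta) = \mathcal{L}(\theta) + \frac{\lambda}{2}\,\lVert \theta \rVert_2^2
\qquad\Longrightarrow\qquad
g_t = \nabla_\theta \mathcal{L}(\theta_{t-1}) + \lambda\,\theta_{t-1}
```

Because Adam then rescales this combined gradient per parameter, weights with a large gradient history are decayed less than plain L2 would suggest. `optim.AdamW` instead applies the decay as a separate step, θ ← θ − ηλθ, outside the adaptive scaling; whether that works better on this dataset would have to be tested.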

Here's how you can apply these techniques to your code:
* Increase Data Augmentation
Add more augmentations to introduce more variability in the training set:

In [None]:
transform = transforms.Compose([
    transforms.RandomRotation(15),  # Increased rotation range: up to 15 degrees either way
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),  # Add vertical flip (upside-down X-rays are unrealistic; consider removing if it hurts)
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Only brightness/contrast change grayscale images
    transforms.RandomAffine(10, shear=5),  # Small random rotation (up to 10 degrees) plus shear
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])  # Map pixel values from [0, 1] to [-1, 1]
])

In [None]:
from PIL import Image

class SingleLabelChestMNIST(Dataset):
    def __init__(self, split, transform=None, size=224):
        info = INFO['chestmnist']
        DataClass = getattr(medmnist, info['python_class'])
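        # as_rgb only affects medmnist's own __getitem__; indexing .imgs directly (as below)
        # yields single-channel (grayscale) arrays, matching the 1-channel conv1 in the model
        self.data = DataClass(split=split, download=True, as_rgb=True, size=size)
        self.imgs = self.data.imgs
        self.labels = self.data.labels
        self.transform = transform

        # Keep only images with exactly one positive label. This drops multi-label images
        # as well as all-zero ("no finding") images.
        self.single_label_indices = [i for i in range(len(self.labels)) if self.labels[i].sum() == 1]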

    def __len__(self):
        return len(self.single_label_indices)

    def __getitem__(self, idx):
        img = self.imgs[self.single_label_indices[idx]]
        label = self.labels[self.single_label_indices[idx]].argmax()  # Convert one-hot to scalar label

        # Convert NumPy array to PIL image
        img = Image.fromarray(img)

        if self.transform:
            img = self.transform(img)

        return img, label
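The list comprehension in `__init__` loops over every label row in Python. A vectorized numpy equivalent gives the same indices and also makes it easy to check class balance after filtering, which is worth doing because the filter removes every multi-label and every all-zero row. This is a sketch; `single_label_indices` and `class_counts` are hypothetical helper names, and the toy matrix has 4 columns instead of ChestMNIST's 14.

```python
import numpy as np


def single_label_indices(labels: np.ndarray) -> np.ndarray:
    """Indices of rows in a multi-hot label matrix with exactly one positive label."""
    return np.flatnonzero(labels.sum(axis=1) == 1)


def class_counts(labels: np.ndarray) -> np.ndarray:
    """Per-class sample counts among single-label rows (argmax recovers the class index)."""
    idx = single_label_indices(labels)
    classes = labels[idx].argmax(axis=1)
    return np.bincount(classes, minlength=labels.shape[1])


toy = np.array([
    [0, 1, 0, 0],  # single label -> kept, class 1
    [1, 1, 0, 0],  # two labels   -> dropped
    [0, 0, 0, 0],  # no finding   -> dropped
    [0, 1, 0, 0],  # single label -> kept, class 1
    [0, 0, 0, 1],  # single label -> kept, class 3
])
print(single_label_indices(toy))  # [0 3 4]
print(class_counts(toy))          # [0 2 0 1]
```

Inside the dataset class, `single_label_indices(self.labels)` could replace the list comprehension, and `class_counts(self.labels)` shows how imbalanced the 14 remaining classes are.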

In [None]:
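# Load data with the specified size
batch_size = 64
train_dataset = SingleLabelChestMNIST(split='train', transform=transform, size=224)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Evaluation must be deterministic: no random augmentation, only tensor conversion + normalization
eval_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),  # Same normalization as training
])
# Note: the best checkpoint is selected on this split; using split='val' for model selection
# and keeping 'test' for a final evaluation avoids an optimistic test score.
test_dataset = SingleLabelChestMNIST(split='test', transform=eval_transform, size=224)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)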

Downloading https://zenodo.org/records/10519652/files/chestmnist_224.npz?download=1 to /root/.medmnist/chestmnist_224.npz


100%|██████████| 3.89G/3.89G [05:23<00:00, 12.0MB/s]


Using downloaded and verified file: /root/.medmnist/chestmnist_224.npz


* Add Dropout Layers to the Model
You can modify the ResNet model by adding dropout layers after each block, or between the fully connected layers. Here's how to add dropout in the modified ChestMNISTModel class:
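It helps to know what `nn.Dropout` actually computes. PyTorch uses "inverted" dropout: in training mode each element is zeroed with probability p and the survivors are scaled by 1/(1 − p), and in eval mode (`model.eval()`) it is the identity. Applied to the final logits it would zero whole class scores, which is why dropout is usually placed on the features that feed the classifier. A numpy sketch of the rule (the `dropout` function here is illustrative, not PyTorch's implementation):

```python
import numpy as np


def dropout(x: np.ndarray, p: float, training: bool, rng: np.random.Generator) -> np.ndarray:
    """Inverted dropout, the scheme nn.Dropout follows.

    Training: zero each element with probability p, scale survivors by 1 / (1 - p)
    so the expected activation is unchanged. Eval: return the input untouched.
    """
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    if not training or p == 0.0:
        return x
    keep = rng.random(x.shape) >= p  # True with probability 1 - p
    return x * keep / (1.0 - p)


rng = np.random.default_rng(0)
features = np.ones(8)
print(dropout(features, p=0.5, training=True, rng=rng))   # every entry is 0.0 or 2.0
print(dropout(features, p=0.5, training=False, rng=rng))  # unchanged: all 1.0
```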

In [None]:
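from torchvision.models import ResNet18_Weights

class ChestMNISTModel(nn.Module):
    def __init__(self, num_classes=14, dropout_p=0.5):
        super().__init__()
        # `weights=` replaces the deprecated `pretrained=True` argument (torchvision >= 0.13)
        self.model = resnet18(weights=ResNet18_Weights.DEFAULT)

        # Modify for single-channel input. Summing the pretrained RGB filters over the
        # channel axis keeps the ImageNet initialization instead of a random conv1.
        rgb_weight = self.model.conv1.weight.detach().sum(dim=1, keepdim=True)
        self.model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        with torch.no_grad():
            self.model.conv1.weight.copy_(rgb_weight)

        # Apply dropout to the pooled features *before* the classifier;
        # dropout on the output logits would randomly zero whole class scores.
        self.model.fc = nn.Sequential(
            nn.Dropout(p=dropout_p),  # 50% dropout by default
            nn.Linear(self.model.fc.in_features, num_classes),  # Output for 14 classes
        )

    def forward(self, x):
        return self.model(x)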

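* L2 Regularization (Weight Decay)
You can add L2 regularization by passing a `weight_decay` argument to the Adam optimizer.

* Reduce Learning Rate
If the model seems to be overfitting or training is unstable, a lower learning rate helps it converge more smoothly. The code below keeps the initial rate at 0.001 and adds a `ReduceLROnPlateau` scheduler that halves it when validation AUC stops improving: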

In [None]:
model = ChestMNISTModel(num_classes=14).to(device)
criterion = nn.CrossEntropyLoss()
# Adam with L2 regularization (weight decay)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
# Halve the learning rate after more than 5 epochs without a validation AUC improvement (mode='max').
# `verbose` is omitted because it is deprecated in recent PyTorch releases.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 142MB/s]


In [None]:
# Step 6: Training function
def train(model, train_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Calculate accuracy
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        running_loss += loss.item()

    epoch_loss = running_loss / len(train_loader)
    epoch_acc = 100 * correct / total

    return epoch_loss, epoch_acc

In [None]:
# Step 7: Validation function
def validate(model, test_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for i, (images, labels) in enumerate(test_loader):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            # Calculate accuracy
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            running_loss += loss.item()

            # For AUC calculation: store softmax probabilities batch by batch
            all_labels.append(labels.cpu().numpy())
            all_preds.append(torch.softmax(outputs, dim=1).cpu().numpy())

    epoch_loss = running_loss / len(test_loader)
    epoch_acc = 100 * correct / total

    # Calculate one-vs-rest AUC over the concatenated batches
    all_labels = np.concatenate(all_labels)
    all_preds = np.concatenate(all_preds)
    auc = roc_auc_score(all_labels, all_preds, multi_class='ovr')

    return epoch_loss, epoch_acc, auc
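The validation AUC comes from `roc_auc_score(..., multi_class='ovr')`, which, with scikit-learn's default `average='macro'`, scores each class against all the others and takes the unweighted mean of the per-class AUCs. For multiclass input it expects probability rows that sum to 1, which is why the logits go through softmax first. A small numpy re-implementation for intuition only (the function names are hypothetical; it compares every positive/negative pair, so use scikit-learn for real evaluation):

```python
import numpy as np


def binary_auc(y: np.ndarray, scores: np.ndarray) -> float:
    """P(random positive outscores random negative); ties count as 1/2."""
    pos, neg = scores[y == 1], scores[y == 0]
    if len(pos) == 0 or len(neg) == 0:
        raise ValueError("AUC is undefined without both positives and negatives")
    diff = pos[:, None] - neg[None, :]  # every positive/negative pair
    return float(((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size)


def ovr_macro_auc(labels: np.ndarray, probs: np.ndarray) -> float:
    """One-vs-rest AUC for each class, then an unweighted mean over classes."""
    per_class = [binary_auc((labels == c).astype(int), probs[:, c]) for c in range(probs.shape[1])]
    return float(np.mean(per_class))


labels = np.array([0, 1, 2, 1])
probs = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.2, 0.2, 0.6],
    [0.3, 0.4, 0.3],
])
print(ovr_macro_auc(labels, probs))  # 1.0: each class's positives outscore its negatives
```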

* Use Early Stopping
You can implement early stopping to stop training if the validation AUC doesn't improve for a specified number of epochs. Here's an example of how to add early stopping based on validation AUC:
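The rule is simple enough to factor into a small reusable helper. The sketch below uses a hypothetical `EarlyStopping` class (not part of PyTorch) that applies the same test as a manual `epochs_without_improvement >= patience` counter, with `mode='max'` for AUC and `mode='min'` for loss.

```python
from typing import Optional


class EarlyStopping:
    """Signal a stop when a monitored metric has not improved for `patience` consecutive epochs."""

    def __init__(self, patience: int = 10, mode: str = "max", min_delta: float = 0.0):
        if mode not in ("max", "min"):
            raise ValueError("mode must be 'max' or 'min'")
        self.patience = patience
        self.mode = mode
        self.min_delta = min_delta  # minimum change that counts as an improvement
        self.best: Optional[float] = None
        self.bad_epochs = 0

    def step(self, value: float) -> bool:
        """Record one epoch's metric; return True if training should stop."""
        if self.best is None:
            improved = True
        elif self.mode == "max":
            improved = value > self.best + self.min_delta
        else:
            improved = value < self.best - self.min_delta
        if improved:
            self.best = value
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopping(patience=2, mode="max")
for epoch, auc in enumerate([0.65, 0.67, 0.66, 0.66], start=1):
    if stopper.step(auc):
        print(f"Stopping after epoch {epoch}, best AUC {stopper.best:.2f}")  # epoch 4, best 0.67
        break
```

In a training loop, `if stopper.step(val_auc): break` could stand in for the manual counter.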

In [None]:
import os
import torch

# Define the path to the saved model
best_model_path = '/content/drive/MyDrive/best_model_chtgptR18.pth'

# Step 1: Load the model if it exists, or initialize a new one if not
def load_model(model, optimizer, scheduler, model_path):
    if os.path.exists(model_path):
        print(f"Loading saved model from {model_path}...")
        checkpoint = torch.load(model_path, map_location=device)  # Load the checkpoint dictionary onto the current device
        model.load_state_dict(checkpoint['model_state_dict'])  # Load model weights using correct key

        # Restore optimizer and scheduler state so training resumes where it left off
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        # Resume after the epoch at which the best checkpoint was saved
        epoch_start = checkpoint.get('epoch', 0)
        best_auc = checkpoint.get('best_auc', 0.0)

        print("Model loaded successfully. Continuing training...")
    else:
        print("No saved model found, initializing a new model...")
        epoch_start = 0
        best_auc = 0.0

    return model, optimizer, scheduler, epoch_start, best_auc


# Step 2: Save a checkpoint (called whenever validation AUC improves)
def save_model(model, optimizer, scheduler, epoch, auc, model_path):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
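        'best_auc': float(auc)  # Plain float: numpy scalars may be rejected by torch.load(weights_only=True), the default since PyTorch 2.6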
    }, model_path)
    print(f"Model saved at epoch {epoch}, AUC: {auc:.4f}")

In [None]:
# Load the model (if exists) and continue from there
model, optimizer, scheduler, epoch_start, best_auc = load_model(model, optimizer, scheduler, best_model_path)

# Training loop
num_epochs = 1000
patience = 10  # Early stopping patience
epochs_without_improvement = 0

for epoch in range(epoch_start, num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")

    # Train for one epoch
    train_loss, train_acc = train(model, train_loader, criterion, optimizer, device)
    # Validate the model
    val_loss, val_acc, val_auc = validate(model, test_loader, criterion, device)

    # Print metrics for the current epoch
    print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.2f}%")
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.2f}%, Validation AUC: {val_auc:.4f}")

    # Update the scheduler based on validation AUC
    scheduler.step(val_auc)

    # Save model if AUC improves
    if val_auc > best_auc:
        best_auc = val_auc
        save_model(model, optimizer, scheduler, epoch + 1, best_auc, best_model_path)  # Save model at improved AUC
        epochs_without_improvement = 0  # Reset counter
    else:
        epochs_without_improvement += 1

    # Early stopping condition
    if epochs_without_improvement >= patience:
        print("Early stopping triggered!")
        break

No saved model found, initializing a new model...
Epoch 1/1000
Train Loss: 2.3659, Train Accuracy: 23.41%
Validation Loss: 2.1719, Validation Accuracy: 30.82%, Validation AUC: 0.6530
Model saved at epoch 1, AUC: 0.6530
Epoch 2/1000
Train Loss: 2.3664, Train Accuracy: 23.30%
Validation Loss: 2.1745, Validation Accuracy: 33.06%, Validation AUC: 0.6720
Model saved at epoch 2, AUC: 0.6720
Epoch 3/1000
Train Loss: 2.3519, Train Accuracy: 23.68%
Validation Loss: 2.1866, Validation Accuracy: 33.10%, Validation AUC: 0.6550
Epoch 4/1000
Train Loss: 2.3295, Train Accuracy: 24.42%
Validation Loss: 2.1699, Validation Accuracy: 33.92%, Validation AUC: 0.6840
Model saved at epoch 4, AUC: 0.6840
Epoch 5/1000
Train Loss: 2.3166, Train Accuracy: 24.64%
Validation Loss: 2.1371, Validation Accuracy: 32.64%, Validation AUC: 0.6920
Model saved at epoch 5, AUC: 0.6920
Epoch 6/1000
Train Loss: 2.3116, Train Accuracy: 25.00%
Validation Loss: 2.1273, Validation Accuracy: 33.55%, Validation AUC: 0.7019
Model sa