This version of the code gives you detailed feedback on the model’s performance during training, reporting loss, accuracy, and AUC. You can visually assess how well the model is learning, and whether there are any signs of overfitting or underfitting, from the plots of training and validation metrics.

Here’s an updated version of the code with these improvements:


In [None]:
# Mount Google Drive at /content/drive (Colab only). The trained checkpoint is
# later written to a path under this mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install PyTorch, torchvision and MedMNIST from PyPI.
!pip install torch torchvision medmnist
# NOTE(review): this installs MedMNIST a second time, from the GitHub repository,
# on top of the PyPI package installed above -- presumably to pick up the latest
# version; confirm which of the two is actually intended.
!pip install git+https://github.com/MedMNIST/MedMNIST.git

In [None]:
# scikit-learn supplies roc_auc_score, which the training and evaluation loops
# use to compute the AUC metric.
!pip install scikit-learn
from sklearn.metrics import roc_auc_score

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, models
import medmnist
from medmnist import INFO, PathMNIST
from tqdm import tqdm
import matplotlib.pyplot as plt

# 1. Dataset preparation: PathMNIST at 224x224, normalized for a pretrained backbone
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        # Channel statistics expected by the pretrained model
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

info = INFO['pathmnist']
DataClass = getattr(medmnist, info['python_class'])

# Build the three splits with identical settings; only `split` differs.
train_dataset, val_dataset, test_dataset = (
    PathMNIST(split=split_key, download=True, transform=transform, as_rgb=True, size=224)
    for split_key in ('train', 'val', 'test')
)

# Batched loaders: only the training split is shuffled.
train_loader, val_loader, test_loader = (
    DataLoader(split_data, batch_size=32, shuffle=reshuffle)
    for split_data, reshuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)

In [None]:
import medmnist

# Look up the PathMNIST metadata entry in medmnist's INFO registry
info = medmnist.INFO['pathmnist']

# Print the label mapping stored under the 'label' key
# (a dict keyed by the label id as a string, with the class name as value)
print("Text labels for PathMNIST:", info['label'])

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import medmnist

def analyze_split(dataset, split_name):
    """Plot the label distribution of one PathMNIST split and print image stats.

    Args:
        dataset: Dataset exposing a ``labels`` array, ``len()`` and integer
            indexing that returns ``(image, label)`` pairs.
        split_name: Human-readable split name used in the plot title and in
            the printed summary lines.
    """
    labels = dataset.labels

    # Count the samples for every label id that actually occurs in this split.
    unique_labels, label_counts = np.unique(labels, return_counts=True)

    # medmnist.INFO maps label ids (as string keys) to their text names.
    label_text = medmnist.INFO['pathmnist']['label']

    # Name each bar after the label id it represents. Looking names up by id
    # (instead of taking every name in order) keeps bars and tick labels
    # aligned even when a class is missing from this split; an id without a
    # known name falls back to the id itself.
    tick_labels = [
        label_text.get(str(int(label_id)), str(int(label_id)))
        for label_id in unique_labels
    ]
    positions = range(len(unique_labels))

    # Plot label distribution using label text on x-axis
    plt.figure(figsize=(10, 5))
    plt.bar(positions, label_counts, color='skyblue')
    plt.title(f'Label Distribution in PathMNIST {split_name} Set')
    plt.xlabel('Labels')
    plt.ylabel('Frequency')
    plt.xticks(positions, tick_labels, rotation=45, ha='right')

    plt.tight_layout()
    plt.show()

    # Image statistics; the shape is read from the first sample, so skip it
    # when the split is empty instead of raising on dataset[0].
    if len(dataset) > 0:
        print(f"Image shape in {split_name} set:", dataset[0][0].shape)
    print(f"Total number of images in {split_name} set:", len(dataset))

# Run the same analysis over every split
for split_dataset, split_title in (
    (train_dataset, "Train"),
    (val_dataset, "Validation"),
    (test_dataset, "Test"),
):
    analyze_split(split_dataset, split_title)

In [None]:
# Peek at the first training sample: x is the transformed image, y its label
x, y = train_dataset[0]
print(x.shape, y.shape)
# Expected output: torch.Size([3, 224, 224]) (1,)
# Build a montage of training images; kept as the last expression of the cell
# so the notebook displays the returned object.
# NOTE(review): length=3 presumably yields a 3x3 grid -- confirm in the medmnist docs.
train_dataset.montage(length=3)

In [None]:
# Imports for the model definition, training loop, and evaluation
from sklearn.metrics import roc_auc_score
import numpy as np
import matplotlib.pyplot as plt
import medmnist
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import models
import torchvision.transforms as transforms
from tqdm import tqdm

# Define the Model with ResNet18 pretrained
class PathMNISTClassifier(nn.Module):
    """ResNet-18 classifier whose final layer is resized for PathMNIST.

    The backbone is loaded with pretrained weights and its fully connected
    head is replaced by a fresh ``nn.Linear`` with ``num_classes`` outputs.

    Args:
        num_classes: Number of output classes (PathMNIST has 9).
    """

    def __init__(self, num_classes=9):  # PathMNIST has 9 classes
        super().__init__()

        # torchvision >= 0.13 selects pretrained weights through the ``weights``
        # argument and deprecates ``pretrained=True``. IMAGENET1K_V1 is the
        # weight set the legacy flag loads. Fall back to the legacy flag on
        # older releases that do not ship the ResNet18_Weights enum.
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is not None:
            self.model = models.resnet18(weights=weights_enum.IMAGENET1K_V1)
        else:
            self.model = models.resnet18(pretrained=True)

        # Modify the final fully connected layer to match the number of classes in PathMNIST
        in_features = self.model.fc.in_features
        self.model.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return the backbone's class logits for the image batch ``x``."""
        return self.model(x)

# Training loop with additional metrics (AUC, ACC)
def train_model(model, train_loader, criterion, optimizer, device, num_epochs=10):
    """Train ``model`` and record loss, accuracy and AUC for every epoch.

    Args:
        model: Network to optimize; moved to ``device`` before training.
        train_loader: Iterable of ``(images, labels)`` batches, where ``labels``
            has a trailing dimension of size 1 that is squeezed away.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepping the model's parameters.
        device: Device the model and every batch are moved to.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        Tuple ``(train_losses, train_accuracies, train_auc_scores)`` of lists
        with one entry per epoch: the mean of the per-batch loss values, the
        accuracy in percent, and the macro-averaged one-vs-rest ROC AUC.
    """
    model.to(device)

    # Track metrics
    train_losses = []
    train_accuracies = []
    train_auc_scores = []

    for epoch in range(num_epochs):
        # Re-assert training mode at the start of every epoch so the loop stays
        # correct even if something switches the model to eval() in between.
        model.train()

        running_loss = 0.0
        correct_preds = 0
        total_preds = 0
        batch_labels = []
        batch_probs = []

        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            images = images.to(device)
            # Drop the extra label dimension and cast to int64: class-index
            # targets for CrossEntropyLoss must be long, and collated labels
            # can arrive as a narrower integer type.
            labels = labels.to(device).squeeze(1).long()

            optimizer.zero_grad()

            outputs = model(images)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # Metrics need no gradients; detach once and reuse.
            logits = outputs.detach()
            _, predicted = torch.max(logits, 1)
            correct_preds += (predicted == labels).sum().item()
            total_preds += labels.size(0)

            # Keep one array per batch; they are concatenated once per epoch
            # for the AUC computation.
            batch_labels.append(labels.cpu().numpy())
            batch_probs.append(torch.softmax(logits, dim=1).cpu().numpy())

        avg_loss = running_loss / len(train_loader)
        accuracy = (correct_preds / total_preds) * 100
        auc = roc_auc_score(
            np.concatenate(batch_labels),
            np.concatenate(batch_probs),
            multi_class='ovr',
            average='macro',
        )

        train_losses.append(avg_loss)
        train_accuracies.append(accuracy)
        train_auc_scores.append(auc)

        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_loss:.4f}, Train Accuracy: {accuracy:.2f}%, Train AUC: {auc:.4f}")

    return train_losses, train_accuracies, train_auc_scores

# Validation loop with AUC and Accuracy
def evaluate_model(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without updating its parameters.

    Args:
        model: Network to evaluate; moved to ``device`` and put in eval mode.
        val_loader: Iterable of ``(images, labels)`` batches, where ``labels``
            has a trailing dimension of size 1 that is squeezed away.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the model and every batch are moved to.

    Returns:
        Tuple ``(val_losses, val_accuracies, val_auc_scores)``. Each is a list
        holding a single value for this evaluation pass: the mean of the
        per-batch loss values, the accuracy in percent, and the macro-averaged
        one-vs-rest ROC AUC.
    """
    model.to(device)
    model.eval()

    # Track metrics
    val_losses = []
    val_accuracies = []
    val_auc_scores = []

    running_loss = 0.0
    correct_preds = 0
    total_preds = 0
    batch_labels = []
    batch_probs = []

    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc="Evaluating"):
            images = images.to(device)
            # Drop the extra label dimension and cast to int64: class-index
            # targets for CrossEntropyLoss must be long, and collated labels
            # can arrive as a narrower integer type.
            labels = labels.to(device).squeeze(1).long()

            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct_preds += (predicted == labels).sum().item()
            total_preds += labels.size(0)

            # Keep one array per batch; they are concatenated once at the end
            # for the AUC computation.
            batch_labels.append(labels.cpu().numpy())
            batch_probs.append(torch.softmax(outputs, dim=1).cpu().numpy())

    avg_loss = running_loss / len(val_loader)
    accuracy = (correct_preds / total_preds) * 100
    auc = roc_auc_score(
        np.concatenate(batch_labels),
        np.concatenate(batch_probs),
        multi_class='ovr',
        average='macro',
    )

    val_losses.append(avg_loss)
    val_accuracies.append(accuracy)
    val_auc_scores.append(auc)

    print(f"Validation Loss: {avg_loss:.4f}, Validation Accuracy: {accuracy:.2f}%, Validation AUC: {auc:.4f}")

    return val_losses, val_accuracies, val_auc_scores

# Initialize Model, Loss, Optimizer
import copy
import os

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = PathMNISTClassifier(num_classes=9)  # PathMNIST has 9 categories
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Track the best AUC
best_auc = 0.0
best_model_wts = None

# Train for all epochs, then evaluate once on the validation split.
# NOTE: validation runs a single time, after the last epoch, so the
# val_* lists each hold one value and the "best AUC" check below only
# compares the final model against the initial 0.0.
num_epochs = 15
train_losses, train_accuracies, train_auc_scores = train_model(model, train_loader, criterion, optimizer, device, num_epochs=num_epochs)
val_losses, val_accuracies, val_auc_scores = evaluate_model(model, val_loader, criterion, device)

# Save the model if best AUC is achieved
if val_auc_scores[-1] > best_auc:
    best_auc = val_auc_scores[-1]
    # state_dict() hands back references to the live parameter tensors; deep-copy
    # them so best_model_wts stays a snapshot even if the model is trained further.
    best_model_wts = copy.deepcopy(model.state_dict())

    checkpoint_path = '/content/drive/MyDrive/ColabNotebooks/pathmnist_224_best_auc_model.pth'
    # torch.save does not create missing folders; make sure the target exists
    # so a finished training run is not lost to a missing directory.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    torch.save(best_model_wts, checkpoint_path)


In [None]:
# 7. Plot the tracked metrics
def plot_metrics(train_losses, train_accuracies, train_auc_scores, val_losses, val_accuracies, val_auc_scores):
    """Plot training and validation loss, accuracy and AUC against the epoch.

    Two figures are drawn: one with loss and accuracy side by side, and one
    with AUC.

    Args:
        train_losses: Per-epoch training loss; its length defines the epochs.
        train_accuracies: Per-epoch training accuracy in percent.
        train_auc_scores: Per-epoch training AUC.
        val_losses: Validation loss, either one value per epoch or a single
            value that is drawn as a constant line across all epochs.
        val_accuracies: Validation accuracy, same length rules as ``val_losses``.
        val_auc_scores: Validation AUC, same length rules as ``val_losses``.

    Raises:
        ValueError: If a validation series has neither one entry nor one
            entry per epoch.
    """
    num_epochs = len(train_losses)
    epochs = range(1, num_epochs + 1)

    def _per_epoch(values, name):
        # Return `values` with one entry per epoch, repeating a single value.
        values = list(values)
        if len(values) == num_epochs:
            return values
        if len(values) == 1:
            return values * num_epochs
        raise ValueError(
            f"{name} has {len(values)} entries; expected 1 or {num_epochs}"
        )

    # A single validation measurement is repeated to match the epoch axis;
    # series that already have one value per epoch are used unchanged.
    val_losses = _per_epoch(val_losses, "val_losses")
    val_accuracies = _per_epoch(val_accuracies, "val_accuracies")
    val_auc_scores = _per_epoch(val_auc_scores, "val_auc_scores")

    # Plot Train and Validation Loss
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_losses, label='Train Loss', color='blue')
    plt.plot(epochs, val_losses, label='Validation Loss', color='red')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Train and Validation Loss')
    plt.legend()

    # Plot Train and Validation Accuracy
    plt.subplot(1, 2, 2)
    plt.plot(epochs, train_accuracies, label='Train Accuracy', color='blue')
    plt.plot(epochs, val_accuracies, label='Validation Accuracy', color='red')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy (%)')
    plt.title('Train and Validation Accuracy')
    plt.legend()

    # tight_layout only affects the current figure, so apply it here before
    # the AUC figure is created.
    plt.tight_layout()

    # Plot Train and Validation AUC
    plt.figure(figsize=(12, 6))
    plt.plot(epochs, train_auc_scores, label='Train AUC', color='blue')
    plt.plot(epochs, val_auc_scores, label='Validation AUC', color='red')
    plt.xlabel('Epochs')
    plt.ylabel('AUC')
    plt.title('Train and Validation AUC')
    plt.legend()

    plt.tight_layout()
    plt.show()

In [None]:
# Draw the loss, accuracy and AUC curves from the tracked metrics
plot_metrics(
    train_losses=train_losses,
    train_accuracies=train_accuracies,
    train_auc_scores=train_auc_scores,
    val_losses=val_losses,
    val_accuracies=val_accuracies,
    val_auc_scores=val_auc_scores,
)

# Optional: Visualize some test images with their labels
def plot_images_with_labels(loader, class_names, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Show up to six images from the first batch of ``loader`` with their labels.

    Args:
        loader: Iterable yielding ``(images, labels)`` batches, with images as
            a channels-first tensor batch.
        class_names: Mapping from the label id as a string to its display name.
        mean: Per-channel mean used to undo input normalization before display.
            The default is the mean used by the ``Normalize`` transform in this
            notebook. Pass ``None`` to show the tensors as they are.
        std: Per-channel standard deviation, counterpart of ``mean``.
    """
    images, labels = next(iter(loader))

    # Normalized tensors fall outside the [0, 1] range imshow expects for
    # floats, which clips and distorts colours. Undo the normalization when
    # the statistics match the channel count, then clamp into range.
    if mean is not None and std is not None and images.shape[1] == len(mean) == len(std):
        channel_mean = images.new_tensor(mean).view(1, -1, 1, 1)
        channel_std = images.new_tensor(std).view(1, -1, 1, 1)
        images = (images * channel_std + channel_mean).clamp(0.0, 1.0)

    images = images.numpy().transpose((0, 2, 3, 1))  # Convert to HWC format

    fig, axes = plt.subplots(2, 3, figsize=(12, 6))
    axes = axes.ravel()  # Flatten axes to iterate through them

    # The batch may hold fewer images than there are subplots.
    num_shown = min(len(axes), len(images))
    for i, ax in enumerate(axes):
        ax.axis('off')
        if i >= num_shown:
            continue
        ax.imshow(images[i])
        # Convert the label to a string to match the keys in class_names
        label = class_names[str(labels[i].item())]
        ax.set_title(f"Label: {label}")
    plt.show()

# Class names for PathMNIST dataset (corresponding to the 9 classes)
info = medmnist.INFO['pathmnist']
class_names = info['label']  # Mapping from label id (string key) to class name

# Show images from the first test-set batch together with their class names
plot_images_with_labels(test_loader, class_names)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import medmnist  # Make sure you have imported medmnist

def show_two_images_per_label(dataset, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Display up to two example images for every PathMNIST class.

    Args:
        dataset: Dataset supporting ``len()`` and integer indexing that returns
            ``(image, label)``, with ``image`` a channels-first tensor and
            ``label`` exposing ``.item()``.
        mean: Per-channel mean used to undo input normalization before display.
            The default is the mean used by the ``Normalize`` transform in this
            notebook. Pass ``None`` to show the tensors as they are.
        std: Per-channel standard deviation, counterpart of ``mean``.
    """
    images_per_label = 2

    # Get class names from medmnist.INFO
    class_names = medmnist.INFO['pathmnist']['label']
    num_classes = len(class_names)

    # Collect at most `images_per_label` images per label and stop scanning as
    # soon as every class is covered, so the whole dataset is never kept in
    # memory.
    images_by_label = {}
    for i in range(len(dataset)):
        image, label = dataset[i]
        label = label.item()  # Get the label as an integer

        bucket = images_by_label.setdefault(label, [])
        if len(bucket) >= images_per_label:
            continue
        bucket.append(image)

        if len(bucket) == images_per_label and all(
            len(images_by_label.get(class_id, ())) >= images_per_label
            for class_id in range(num_classes)
        ):
            break

    def _to_displayable(image):
        # Undo normalization (when applicable) and convert CHW -> HWC.
        if mean is not None and std is not None and image.shape[0] == len(mean) == len(std):
            channel_mean = image.new_tensor(mean).view(-1, 1, 1)
            channel_std = image.new_tensor(std).view(-1, 1, 1)
            # Clamp into the [0, 1] range imshow expects for float images.
            image = (image * channel_std + channel_mean).clamp(0.0, 1.0)
        return image.numpy().transpose((1, 2, 0))

    # squeeze=False keeps `axes` two-dimensional even for a single class.
    fig, axes = plt.subplots(
        num_classes, images_per_label, figsize=(12, 4 * num_classes), squeeze=False
    )

    for label in range(num_classes):
        examples = images_by_label.get(label, [])
        if not examples:
            print(f"No images found for label: {class_names[str(label)]}")

        for col in range(images_per_label):
            ax = axes[label, col]
            ax.axis('off')  # Also hides slots that have no image
            if col < len(examples):
                ax.imshow(_to_displayable(examples[col]))
                ax.set_title(f"Label: {class_names[str(label)]}")

    plt.tight_layout()
    plt.show()

# Render the per-class example grid for the training split
show_two_images_per_label(dataset=train_dataset)