In [None]:
# Install necessary libraries
!pip install torch torchvision medmnist matplotlib


In [None]:
import os
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from medmnist import PathMNIST

# Ensure the root directory exists (exist_ok avoids a check-then-create race)
root_dir = './data'
os.makedirs(root_dir, exist_ok=True)

# Convert images to [0, 1] tensors, then rescale to [-1, 1].
# The single-value mean/std broadcasts across all image channels.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load the train/val/test splits; download=True fetches the data into root_dir
# when it is not already there.
train_set = PathMNIST(root=root_dir, split='train', transform=transform, download=True)
val_set = PathMNIST(root=root_dir, split='val', transform=transform, download=True)
test_set = PathMNIST(root=root_dir, split='test', transform=transform, download=True)

# DataLoaders. The training loader gets its own seeded generator so the shuffle
# order is reproducible across Restart & Run All.
batch_size = 64
train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(42),
)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

print("Dataset loaded successfully!")


In [None]:
import matplotlib.pyplot as plt

# Display a 3x3 grid of training images with their labels.
# conv1 expects 3 input channels, so the images are RGB: show all channels
# instead of only the first one in grayscale. Undo Normalize(mean=0.5, std=0.5)
# first so pixel values are back in [0, 1] for imshow.
fig, axs = plt.subplots(3, 3, figsize=(8, 8))
for i, ax in enumerate(axs.flat):
    img, label = train_set[i]
    rgb = (img * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0)  # (C, H, W) -> (H, W, C)
    ax.imshow(rgb.numpy())
    # The label is presumably a 1-element array (the training cell squeezes a
    # trailing dim), so index it to show a plain integer rather than "[k]".
    ax.set_title(f"Label: {int(label[0])}")
    ax.axis('off')
plt.tight_layout()
plt.show()


In [None]:
import torch.nn as nn

class CNNModel(nn.Module):
    """Small two-convolution CNN classifier for 28x28 images.

    Architecture: [conv3x3 -> ReLU -> maxpool2] x 2 -> FC(128) -> ReLU -> FC(num_classes).
    The two 2x poolings reduce a 28x28 input to 7x7, which is the spatial size
    ``fc1``'s ``64 * 7 * 7`` input assumes; other resolutions fail at ``fc1``.

    Args:
        in_channels: Number of input image channels (3 for RGB PathMNIST).
        num_classes: Number of output logits. Defaults to 9, the number of
            PathMNIST classes (a hard-coded 10 produced a logit for a class
            that does not exist in the dataset).
    """

    def __init__(self, in_channels: int = 3, num_classes: int = 9):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a (batch, in_channels, 28, 28) tensor to raw (batch, num_classes) logits."""
        x = torch.max_pool2d(torch.relu(self.conv1(x)), 2)  # -> (B, 32, 14, 14)
        x = torch.max_pool2d(torch.relu(self.conv2(x)), 2)  # -> (B, 64, 7, 7)
        x = torch.flatten(x, 1)                             # -> (B, 64 * 7 * 7)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)


In [None]:
# Training loop
# Build the model, loss and optimizer here: they were not defined anywhere else
# in the notebook, so this cell previously raised NameError on a fresh kernel.
torch.manual_seed(42)  # reproducible weight initialisation
learning_rate = 1e-3   # Adam's default learning rate, made explicit
model = CNNModel()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        # Labels carry an extra singleton dimension (hence squeeze(1));
        # CrossEntropyLoss needs 1-D int64 class indices.
        labels = labels.squeeze(1).long()
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Validation accuracy after each epoch
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in val_loader:
            predicted = model(images).argmax(dim=1)
            labels = labels.squeeze(1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Print epoch loss and accuracy
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}, Accuracy: {100 * correct / total:.2f}%")

    # Save a checkpoint at the end of each epoch (optional)
    torch.save(model.state_dict(), f"model_epoch_{epoch + 1}.pth")

# Save final model after all epochs
torch.save(model.state_dict(), "bscs22115_final_model.pth")


In [None]:
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import matplotlib.pyplot as plt

# Class names: use the dataset's own label map when it exposes one (medmnist
# datasets carry an `info["label"]` dict of {"0": name, ...}); otherwise fall
# back to generic names for the 9 PathMNIST classes.
label_map = getattr(test_set, "info", {}).get("label") or {}
num_classes = len(label_map) or 9
class_names = [label_map.get(str(i), f"Class {i}") for i in range(num_classes)]

# Collect predictions on the test set
model.eval()
y_true, y_pred = [], []
with torch.no_grad():
    for images, labels in test_loader:
        predicted = model(images).argmax(dim=1)
        # squeeze(1), not squeeze(): a final batch of size 1 would otherwise
        # collapse to a 0-d tensor whose .tolist() is a bare int and breaks extend().
        y_true.extend(labels.squeeze(1).tolist())
        y_pred.extend(predicted.tolist())

# Pin the label set so the matrix/report always line up with class_names, even if
# a class is absent from the data or the model emits an index outside it.
label_ids = list(range(num_classes))
cm = confusion_matrix(y_true, y_pred, labels=label_ids)
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
ax.set_title('Confusion Matrix')
plt.show()

print(classification_report(y_true, y_pred, labels=label_ids,
                            target_names=class_names, zero_division=0))
