In [None]:
"""
File: CNN.ipynb
------------------
A beginner notebook on training and validating a CNN model on the MIMIC-CXR dataset.
"""

In [2]:
from tqdm import tqdm

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader
from torch.nn import Conv2d, Linear, ReLU, MaxPool2d, Dropout, BatchNorm2d

from lib.utils import seed_everything

# Prefer the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# torch.cuda.get_device_name() only accepts CUDA devices, so it must not be
# called for the CPU fallback (doing so raises instead of printing a name).
device_name = torch.cuda.get_device_name(device) if device.type == "cuda" else "CPU"
print(f'Device set to: {device_name}')

Device set to: NVIDIA GeForce RTX 4060 Laptop GPU


In [None]:
class config:
    """Run settings and hyperparameters read by the other cells of this notebook."""

    # Reproducibility
    seed: int = 23

    # Batch sizes used for the training and validation data loaders
    train_batch_size: int = 32
    valid_batch_size: int = 128

    # Optimisation settings
    learning_rate: float = 1e-3
    num_epochs: int = 10

    # Number of outputs requested from the model
    num_classes: int = 2

# Seed the run with the project helper imported from lib.utils, using the seed
# stored on `config`. NOTE(review): which random number generators the helper
# actually seeds is not visible from this notebook — confirm in lib/utils.
seed_everything(config.seed)

In [2]:
class CNN(torch.nn.Module):
    """
    Basic CNN model for chest xray classification.

    The network is four Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d blocks
    (32, 64, 128 and 256 channels) followed by a two-layer fully connected
    head with dropout between the layers. It returns raw logits.

    Args:
        input_shape: Spatial size of the input images. Either a single int
            (square images) or a sequence whose last two entries are
            (height, width).
        in_channels: Number of channels of the input images.
        output_size: Number of logits produced by the final layer.

    Raises:
        ValueError: If the height or width is too small to survive the four
            2x2 max-pooling steps (i.e. smaller than 16).
    """

    # Number of conv blocks; each one ends with a 2x2, stride-2 max-pool.
    _NUM_POOLS = 4

    def __init__(self, input_shape, in_channels=1, output_size=2):
        super().__init__()

        # Accept either a scalar side length or an (..., H, W) sequence.
        try:
            height, width = tuple(input_shape)[-2:]
        except TypeError:
            height = width = input_shape
        height, width = int(height), int(width)

        # Every pooling step floor-divides each spatial dimension by 2, so the
        # feature map reaching the classifier is (H // 16, W // 16), not (H, W).
        downsample = 2 ** self._NUM_POOLS
        if min(height, width) < downsample:
            raise ValueError(
                f"input_shape must be at least {downsample} in each spatial "
                f"dimension, got ({height}, {width})"
            )
        feat_height = height // downsample
        feat_width = width // downsample

        # First convolutional layer
        self.conv1 = Conv2d(in_channels=in_channels, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.bn1 = BatchNorm2d(32)

        # Second convolutional layer
        self.conv2 = Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.bn2 = BatchNorm2d(64)

        # Third convolutional layer
        self.conv3 = Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.bn3 = BatchNorm2d(128)

        # Fourth convolutional layer
        self.conv4 = Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.bn4 = BatchNorm2d(256)

        # Fully connected layers; fc1 consumes the flattened, pooled feature map
        self.fc1 = Linear(256 * feat_height * feat_width, 512)
        self.fc2 = Linear(512, output_size)

        # Activation, pooling and dropout layers (stateless, so shared by all blocks)
        self.relu = ReLU()
        self.maxpool = MaxPool2d(kernel_size=2, stride=2)
        self.dropout = Dropout(p=0.5)

    def forward(self, x):
        """Map a batch of images (N, in_channels, H, W) to logits (N, output_size)."""
        conv_blocks = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
        )

        # Pass through the four convolutional blocks
        for conv, bn in conv_blocks:
            x = self.maxpool(self.relu(bn(conv(x))))

        # Flatten everything except the batch dimension
        x = torch.flatten(x, start_dim=1)

        # Pass through first fully connected block
        x = self.dropout(self.relu(self.fc1(x)))

        # Pass through second fully connected block
        return self.fc2(x)

In [None]:
# Code to get dataset
# TODO: these are placeholders. Before running the cells below, replace them
# with the real training and validation datasets (they are handed to
# DataLoader) and with the value to pass as CNN(input_shape=...).
train_dataset = None
val_dataset = None
input_shape = None

In [None]:
# Define train and validation dataloaders.
# DataLoader has no seed keyword; reproducible shuffling is obtained by giving
# it a torch.Generator that has been seeded explicitly.
shuffle_generator = torch.Generator().manual_seed(config.seed)
train_dataloader = DataLoader(train_dataset, batch_size=config.train_batch_size,
                              shuffle=True, generator=shuffle_generator)
val_dataloader = DataLoader(val_dataset, batch_size=config.valid_batch_size, shuffle=False)

# Create an instance of the CNN model on the selected device
model = CNN(input_shape=input_shape, output_size=config.num_classes).to(device)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)

def _run_epoch(model, dataloader, criterion, device, desc, unit, optimizer=None):
    """Run one full pass over `dataloader` and return (mean loss, accuracy).

    When `optimizer` is given the model is put in training mode and updated
    after every batch; otherwise the model is put in eval mode and no
    gradients are tracked.

    Both returned values are per-sample averages, so a smaller final batch does
    not skew the loss. This assumes `criterion` returns the mean loss of the
    batch, which is the default reduction of nn.CrossEntropyLoss.

    Raises:
        ValueError: If the dataloader yields no samples.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum = 0.0
    num_correct = 0
    num_samples = 0

    with torch.set_grad_enabled(training):
        # tqdm has no `len` keyword; it infers the bar length from len(dataloader).
        for inputs, labels in tqdm(dataloader, unit=unit, desc=desc):
            inputs = inputs.to(device)
            labels = labels.to(device)

            if training:
                optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            if training:
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            # Undo the batch-mean so batches of different sizes are weighted correctly.
            loss_sum += loss.item() * batch_size
            num_correct += (outputs.argmax(dim=1) == labels).sum().item()
            num_samples += batch_size

    if num_samples == 0:
        raise ValueError("dataloader produced no samples; cannot compute epoch statistics")

    return loss_sum / num_samples, num_correct / num_samples


# Training loop
for epoch in range(config.num_epochs):

    # Training
    train_loss, train_accuracy = _run_epoch(
        model, train_dataloader, criterion, device,
        desc=f"Training epoch {epoch+1}/{config.num_epochs}...",
        unit="train batch",
        optimizer=optimizer,
    )

    # Validation
    val_loss, val_accuracy = _run_epoch(
        model, val_dataloader, criterion, device,
        desc=f"Validating epoch {epoch+1}/{config.num_epochs}...",
        unit="val batch",
    )

    # Print epoch statistics
    print(f"\nEpoch {epoch+1}/{config.num_epochs}:")
    print(f"Train Loss: {train_loss:.4f} | Train Accuracy: {train_accuracy:.4f}")
    print(f"Val Loss: {val_loss:.4f} | Val Accuracy: {val_accuracy:.4f}")