# CNN Training Notebook
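This notebook trains a convolutional neural network (CNN) on colorectal cancer (CRC) histology images — NCT-CRC-HE-100K for training and CRC-VAL-HE-7K for validation — and saves the checkpoint with the highest validation accuracy.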

In [11]:
import os
from glob import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader


In [12]:
# Configuration
# Dataset roots; each is read as a class-per-subfolder image directory.
train_dir = "./data/NCT-CRC-HE-100K/NCT-CRC-HE-100K"
val_dir = "./data/CRC-VAL-HE-7K/CRC-VAL-HE-7K"

# Seed PyTorch's global RNG so shuffling, augmentation and weight init
# start from the same state on every Restart & Run All.
seed = 42
torch.manual_seed(seed)

epochs = 20          # upper bound on training epochs (early stopping may end sooner)
batch_size = 32      # samples per batch for both loaders
lr = 1e-3            # optimizer learning rate
patience = 3         # epochs without validation-accuracy improvement tolerated
log_interval = 100   # print the running training loss every N batches
num_workers = 4      # DataLoader worker processes


In [13]:
# Select the compute device: GPU when CUDA is usable, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Device:', device)


Device: cuda


In [14]:
# Input pipeline: transforms, ImageFolder datasets and DataLoaders.
# Training images get light augmentation (random flips, mild colour jitter);
# validation images are only resized and normalised.
train_tf = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
val_tf = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_ds = datasets.ImageFolder(root=train_dir, transform=train_tf)
val_ds = datasets.ImageFolder(root=val_dir, transform=val_tf)

# Only the training loader shuffles; memory is pinned only when running on CUDA.
train_loader = DataLoader(
    dataset=train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=(device.type == 'cuda'),
)
val_loader = DataLoader(
    dataset=val_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=(device.type == 'cuda'),
)
print(f'Train samples: {len(train_ds)}, Val samples: {len(val_ds)}')


Train samples: 100000, Val samples: 7180


In [15]:
# Model Definition
class CNNModel(nn.Module):
    """Small CNN classifier: three conv blocks followed by a two-layer head.

    Each block is Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU -> 2x2 max-pool,
    so the spatial size is halved three times (224x224 -> 28x28). The head is
    Linear(128*28*28, 512) -> ReLU -> Dropout -> Linear(512, num_classes).

    Inputs whose feature map is not 28x28 after the conv blocks (i.e. images
    that are not 224x224) are adaptively average-pooled to 28x28 so the fixed
    size of ``fc1`` still matches; 224x224 inputs take exactly the same path as
    before. No parameters are added, so existing checkpoints load unchanged.

    Args:
        num_classes: Number of output logits.
    """

    # Spatial size that ``fc1`` expects after the three pooling stages.
    _FEATURE_HW = (28, 28)

    def __init__(self, num_classes=9):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.bn1   = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2   = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64,128,3, padding=1)
        self.bn3   = nn.BatchNorm2d(128)
        self.pool  = nn.MaxPool2d(2,2)
        self.dropout_conv = nn.Dropout2d(0.25)
        self.fc1   = nn.Linear(128 * self._FEATURE_HW[0] * self._FEATURE_HW[1], 512)
        self.dropout_fc = nn.Dropout(0.5)
        self.fc2   = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for a 3-channel image batch."""
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        # Only resample when the input was not 224x224; otherwise leave the
        # feature map untouched so results match the fixed-size behaviour.
        if x.shape[-2:] != self._FEATURE_HW:
            x = F.adaptive_avg_pool2d(x, self._FEATURE_HW)
        x = self.dropout_conv(x)
        # flatten (rather than view) also works for non-contiguous tensors.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout_fc(x)
        return self.fc2(x)


In [16]:
# Training and Evaluation Functions
def train_one_epoch(model, loader, criterion, optimizer, device, log_interval=100):
    """Run one optimisation pass over `loader` and return the mean training loss.

    Args:
        model: Module to train; switched to training mode.
        loader: Iterable of (inputs, labels) batches.
        criterion: Loss callable taking (outputs, labels); assumed to return a
            per-batch mean, so it is re-weighted by batch size when averaging.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.
        log_interval: Print the running loss every this many batches. The
            running window is reset after each print. ``0`` or ``None``
            disables logging.

    Returns:
        Sample-weighted mean loss over every batch of the epoch (including the
        batches after the last log line), or ``nan`` if `loader` was empty.
    """
    model.train()
    running_loss, seen = 0.0, 0          # current logging window
    epoch_loss, epoch_seen = 0.0, 0      # whole epoch
    for i, (imgs, labels) in enumerate(loader, 1):
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad(set_to_none=True)
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        n = labels.size(0)
        batch_loss = loss.item() * n
        running_loss += batch_loss
        seen += n
        epoch_loss += batch_loss
        epoch_seen += n

        # Guard against log_interval == 0 / None, which would otherwise make
        # the modulo raise.
        if log_interval and i % log_interval == 0:
            print(f" Train [{i}/{len(loader)}] loss: {running_loss/seen:.6f}")
            running_loss, seen = 0.0, 0

    return epoch_loss / epoch_seen if epoch_seen else float("nan")

def validate(model, loader, criterion, device):
    """Evaluate `model` on `loader` without gradient tracking.

    The loss is averaged per sample: each batch loss is weighted by its batch
    size, so a shorter final batch does not skew the result. This assumes
    `criterion` returns a per-batch mean (the default reduction of
    ``nn.CrossEntropyLoss``).

    Args:
        model: Module to evaluate; switched to eval mode.
        loader: Iterable of (inputs, labels) batches.
        criterion: Loss callable taking (outputs, labels).
        device: Device the batches are moved to.

    Returns:
        Tuple ``(mean_loss, accuracy_percent)``.

    Raises:
        ZeroDivisionError: If `loader` yields no samples.
    """
    model.eval()
    loss_sum, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            n = labels.size(0)
            # Undo the per-batch mean so batches of different size are
            # weighted correctly.
            loss_sum += criterion(outputs, labels).item() * n
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += n
    return loss_sum / total, 100.0 * correct / total

def test(model, loader, criterion, device):
    """Run `validate` on `loader`, print the resulting loss and accuracy, and return them.

    Returns:
        Tuple ``(loss, accuracy_percent)`` exactly as produced by `validate`.
    """
    print("\n==> Testing on validation set")
    metrics = validate(model, loader, criterion, device)
    loss, acc = metrics
    print(f" Test Loss: {loss:.4f}, Accuracy: {acc:.2f}%\n")
    return loss, acc


In [17]:
# Build the network (one logit per training class), the loss and the optimizer.
model = CNNModel(num_classes=len(train_ds.classes))
model = model.to(device)  # Module.to moves parameters in place and returns the module
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=lr, weight_decay=1e-4)


In [18]:
# Training Loop with Accuracy-based Checkpointing
# After each epoch the model is validated; the weights are written to
# 'best_model.pth' whenever validation accuracy improves. Training stops early
# once `patience` consecutive epochs bring no improvement.
best_val_acc = 0.0
patience_cnt = 0
for epoch in range(1, epochs+1):
    print(f"\nEpoch {epoch}/{epochs}")
    train_one_epoch(model, train_loader, criterion, optimizer, device, log_interval)
    val_loss, val_acc = validate(model, val_loader, criterion, device)
    print(f" Validation Loss: {val_loss:.4f}, Accuracy: {val_acc:.2f}%")
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        patience_cnt = 0
        torch.save(model.state_dict(), 'best_model.pth')
        print("  -> Best model saved.")
    else:
        patience_cnt += 1
        print(f"  -> No improvement. Patience {patience_cnt}/{patience}")
        # `>=` so that training stops as soon as the counter reaches the
        # budget (e.g. at "Patience 3/3"), not one epoch later.
        if patience_cnt >= patience:
            print("  -> Early stopping.")
            break



Epoch 1/20
 Train [100/3125] loss: 8.859043
 Train [200/3125] loss: 1.651279
 Train [300/3125] loss: 1.549303
 Train [400/3125] loss: 1.563558
 Train [500/3125] loss: 1.511907
 Train [600/3125] loss: 1.499458
 Train [700/3125] loss: 1.475703
 Train [800/3125] loss: 1.457913
 Train [900/3125] loss: 1.434758
 Train [1000/3125] loss: 1.472282
 Train [1100/3125] loss: 1.449750
 Train [1200/3125] loss: 1.395694
 Train [1300/3125] loss: 1.403654
 Train [1400/3125] loss: 1.472466
 Train [1500/3125] loss: 1.407295
 Train [1600/3125] loss: 1.388527
 Train [1700/3125] loss: 1.392254
 Train [1800/3125] loss: 1.378481
 Train [1900/3125] loss: 1.403178
 Train [2000/3125] loss: 1.389321
 Train [2100/3125] loss: 1.452788
 Train [2200/3125] loss: 1.425283
 Train [2300/3125] loss: 1.455182
 Train [2400/3125] loss: 1.365283
 Train [2500/3125] loss: 1.402053
 Train [2600/3125] loss: 1.379417
 Train [2700/3125] loss: 1.333708
 Train [2800/3125] loss: 1.376579
 Train [2900/3125] loss: 1.365307
 Train [300