# ResNet-50 Training Notebook
This notebook trains a ResNet-50 on CRC histology images, saving the best model based on validation accuracy.

In [9]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader


In [10]:
# Configuration
train_dir = './data/NCT-CRC-HE-100K/NCT-CRC-HE-100K'
val_dir   = './data/CRC-VAL-HE-7K/CRC-VAL-HE-7K'
epochs      = 20
batch_size  = 32
lr          = 1e-3
patience    = 3
log_interval = 100
num_workers = 4


In [11]:
# Set Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)


Device: cuda


In [12]:
# Data Transforms and DataLoaders
train_tf = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ColorJitter(0.05,0.05,0.05,0.05),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])
val_tf = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])
train_ds = datasets.ImageFolder(train_dir, transform=train_tf)
val_ds   = datasets.ImageFolder(val_dir,   transform=val_tf)
train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=(device.type=='cuda')
)
val_loader = DataLoader(
    val_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=(device.type=='cuda')
)
print(f'Train samples: {len(train_ds)}, Val samples: {len(val_ds)}')


Train samples: 100000, Val samples: 7180


In [13]:
# Model Definition: ResNet-50
model = models.resnet50(pretrained=True)
# Replace final fully connected layer
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, len(train_ds.classes))
model = model.to(device)




In [14]:
# Training and Evaluation Functions
def train_one_epoch(model, loader, criterion, optimizer, device, log_interval=100):
    model.train()
    running_loss = 0.0
    for i, (imgs, labels) in enumerate(loader, 1):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % log_interval == 0:
            print(f" Train [{i}/{len(loader)}]  loss: {running_loss/log_interval:.4f}")
            running_loss = 0.0

def validate(model, loader, criterion, device):
    model.eval()
    val_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            val_loss += criterion(outputs, labels).item()
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return val_loss / len(loader), 100.0 * correct / total

def test(model, loader, criterion, device):
    print("\n==> Testing on validation set")
    loss, acc = validate(model, loader, criterion, device)
    print(f" Test Loss: {loss:.4f}, Accuracy: {acc:.2f}%\n")
    return loss, acc


In [15]:
# Initialize Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)


In [16]:
# Training Loop with Accuracy-based Checkpointing
best_val_acc = 0.0
patience_cnt = 0
for epoch in range(1, epochs+1):
    print(f"\nEpoch {epoch}/{epochs}")
    train_one_epoch(model, train_loader, criterion, optimizer, device, log_interval)
    val_loss, val_acc = validate(model, val_loader, criterion, device)
    print(f" Validation Loss: {val_loss:.4f}, Accuracy: {val_acc:.2f}%")
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        patience_cnt = 0
        torch.save(model.state_dict(), 'best_resnet50.pth')
        best_model = model
        print("  -> Best model saved.")
    else:
        patience_cnt += 1
        print(f"  -> No improvement. Patience {patience_cnt}/{patience}")
        if patience_cnt > patience:
            print("  -> Early stopping.")
            break



Epoch 1/20
 Train [100/3125]  loss: 0.9024
 Train [200/3125]  loss: 0.5285
 Train [300/3125]  loss: 0.4378
 Train [400/3125]  loss: 0.3450
 Train [500/3125]  loss: 0.2835
 Train [600/3125]  loss: 0.2860
 Train [700/3125]  loss: 0.2768
 Train [800/3125]  loss: 0.2365
 Train [900/3125]  loss: 0.2315
 Train [1000/3125]  loss: 0.2606
 Train [1100/3125]  loss: 0.2471
 Train [1200/3125]  loss: 0.2257
 Train [1300/3125]  loss: 0.2099
 Train [1400/3125]  loss: 0.1951
 Train [1500/3125]  loss: 0.2224
 Train [1600/3125]  loss: 0.1875
 Train [1700/3125]  loss: 0.1658
 Train [1800/3125]  loss: 0.1960
 Train [1900/3125]  loss: 0.2182
 Train [2000/3125]  loss: 0.1628
 Train [2100/3125]  loss: 0.1960
 Train [2200/3125]  loss: 0.2267
 Train [2300/3125]  loss: 0.1940
 Train [2400/3125]  loss: 0.1918
 Train [2500/3125]  loss: 0.1742
 Train [2600/3125]  loss: 0.1674
 Train [2700/3125]  loss: 0.1759
 Train [2800/3125]  loss: 0.2011
 Train [2900/3125]  loss: 0.1480
 Train [3000/3125]  loss: 0.1713
 Train 

In [17]:
# Final Test and Save Final Model
test(best_model, val_loader, criterion, device)
torch.save(best_model.state_dict(), 'final_resnet50.pth')
print(f"Final model saved as 'final_resnet50.pth'. Best val accuracy: {best_val_acc:.2f}%")



==> Testing on validation set
 Test Loss: 0.3378, Accuracy: 88.54%

Final model saved as 'final_resnet50.pth'. Best val accuracy: 93.05%
