<a href="https://colab.research.google.com/github/Moinuddin-Hasan/resnet50/blob/master/resnet50.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install tqdm for progress bars
!pip install -q tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import os
import time
from tqdm import tqdm

# For downloading the dataset
import requests
import tarfile

In [None]:
# --- Configuration ---
# Every path and hyperparameter for this run, looked up by key in later cells.
config = dict(
    DATA_PATH="/content/imagenette2-320",
    MODEL_NAME="resnet50",
    NUM_CLASSES=10,          # Imagenette ships 10 classes
    BATCH_SIZE=64,           # lower this if the Colab GPU runs out of memory
    NUM_EPOCHS=10,           # length of this test run
    LR=0.01,                 # starting learning rate
    MOMENTUM=0.9,
    WEIGHT_DECAY=1e-4,
    LR_STEP_SIZE=5,          # epochs between learning-rate decays
    LR_GAMMA=0.1,            # multiplicative decay applied at each step
    DEVICE="cuda" if torch.cuda.is_available() else "cpu",
    CHECKPOINT_PATH="outputs/checkpoints",
)

# The training loop writes checkpoints here, so create the folder up front.
os.makedirs(config["CHECKPOINT_PATH"], exist_ok=True)

print("Using device:", config["DEVICE"])

In [None]:
# --- Dataset Preparation ---
def download_and_extract_imagenette():
    """Download the Imagenette (320px) archive and extract it under /content.

    Does nothing if ``config["DATA_PATH"]`` already exists. The archive is
    streamed to disk in chunks rather than buffered in memory. HTTP errors
    raise ``requests.HTTPError`` immediately instead of leaving a corrupt
    ``.tgz`` that only fails later inside ``tarfile``.
    """
    url = "https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz"
    target_path = '/content/imagenette2-320.tgz'
    extract_path = '/content/'

    if os.path.exists(config["DATA_PATH"]):
        print("Dataset already downloaded and extracted.")
        return

    print("Downloading Imagenette...")
    # The context manager releases the connection when done. raise_for_status
    # stops an error page from being written out as the archive.
    with requests.get(url, stream=True, timeout=60) as response:
        response.raise_for_status()
        with open(target_path, "wb") as f:
            # Unlike response.raw, iter_content also undoes any
            # Content-Encoding the server applied. 1 MiB chunks keep memory flat.
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)

    print("Extracting dataset...")
    with tarfile.open(target_path, "r:gz") as tar:
        # The archive comes from the network. The "data" filter rejects
        # absolute paths, ".." traversal and special files. It exists on
        # Python 3.12+ and on recent 3.8-3.11 patch releases; older
        # interpreters fall back to the unfiltered behaviour.
        if hasattr(tarfile, "data_filter"):
            tar.extractall(path=extract_path, filter="data")
        else:
            tar.extractall(path=extract_path)

    print("Dataset ready.")


download_and_extract_imagenette()

In [None]:
# --- Data Augmentation and Loaders ---

# Input size for ResNet is typically 224x224
input_size = 224

# Per-channel mean/std of ImageNet (Imagenette is a subset of ImageNet)
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# Training: random scale/aspect-ratio crop plus horizontal flip for augmentation
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(input_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

# Validation: deterministic resize, center crop, and normalize
val_transform = transforms.Compose([
    transforms.Resize(input_size + 32),  # shorter side -> 256
    transforms.CenterCrop(input_size),   # center crop to 224
    transforms.ToTensor(),
    normalize,
])

# Create datasets
train_dataset = ImageFolder(os.path.join(config["DATA_PATH"], 'train'), transform=train_transform)
val_dataset = ImageFolder(os.path.join(config["DATA_PATH"], 'val'), transform=val_transform)

# Fail fast on dataset/model mismatches. ImageFolder derives labels from each
# split's sub-directory names, so both splits must map folders to the same
# indices. The class count must also equal the model head size
# (config["NUM_CLASSES"]); otherwise training dies later with an opaque
# index error or CUDA device-side assert.
if train_dataset.class_to_idx != val_dataset.class_to_idx:
    raise ValueError("Train and validation splits have different class folders.")
if len(train_dataset.classes) != config["NUM_CLASSES"]:
    raise ValueError(
        f"Dataset has {len(train_dataset.classes)} classes but "
        f"config['NUM_CLASSES'] is {config['NUM_CLASSES']}."
    )

# Page-locked host memory only helps host-to-GPU copies; skip it on CPU.
pin_memory = config["DEVICE"] == "cuda"

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=config["BATCH_SIZE"], shuffle=True, num_workers=2, pin_memory=pin_memory)
val_loader = DataLoader(val_dataset, batch_size=config["BATCH_SIZE"], shuffle=False, num_workers=2, pin_memory=pin_memory)

print(f"Found {len(train_dataset)} training images in {len(train_dataset.classes)} classes.")
print(f"Found {len(val_dataset)} validation images in {len(val_dataset.classes)} classes.")

In [None]:
# --- Model Definition ---

# ResNet-50 trained from scratch. weights=None gives random initialisation,
# and num_classes sizes the final fully-connected layer to the dataset.
model = torchvision.models.resnet50(
    weights=None,
    num_classes=config["NUM_CLASSES"],
).to(config["DEVICE"])  # place all parameters on the configured device

In [None]:
# --- Training Components ---

# Cross-entropy over raw logits for single-label multi-class classification.
criterion = nn.CrossEntropyLoss()

# SGD with momentum; weight_decay applies an L2 penalty to every parameter.
optimizer = optim.SGD(
    model.parameters(),
    lr=config["LR"],
    weight_decay=config["WEIGHT_DECAY"],
    momentum=config["MOMENTUM"],
)

# Step decay: the learning rate is multiplied by LR_GAMMA after every
# LR_STEP_SIZE calls to scheduler.step() (the training loop calls it once per epoch).
scheduler = StepLR(optimizer, gamma=config["LR_GAMMA"], step_size=config["LR_STEP_SIZE"])

In [None]:
# --- Training and Validation Logic ---

def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, num_epochs=None):
    """Run one optimisation pass over ``data_loader``.

    Args:
        model: network to train; it is switched to training mode.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters, stepped once per batch.
        data_loader: iterable of ``(inputs, labels)`` batches.
        device: device every batch is moved to before the forward pass.
        epoch: zero-based epoch index, used only for the progress-bar label.
        num_epochs: total epoch count shown in the progress bar. Defaults to
            ``config["NUM_EPOCHS"]``, so existing callers are unaffected.

    Returns:
        ``(epoch_loss, epoch_acc)``: the per-sample mean loss and the
        accuracy as a fraction in [0, 1].

    Raises:
        ValueError: if ``data_loader`` yields no samples (previously a
            ZeroDivisionError).
    """
    if num_epochs is None:
        num_epochs = config["NUM_EPOCHS"]

    model.train()
    total_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    # Wrap data_loader with tqdm for a progress bar
    progress_bar = tqdm(data_loader, desc=f"Epoch {epoch+1}/{num_epochs} [T]")
    for inputs, labels in progress_bar:
        # non_blocking lets host-to-GPU copies overlap compute when the loader
        # pins memory; it is a no-op otherwise.
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Assumes criterion returns a batch mean (CrossEntropyLoss's default
        # reduction), so scale by batch size to get a per-sample epoch mean.
        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        total_samples += batch_size
        correct_predictions += (outputs.argmax(dim=1) == labels).sum().item()

        progress_bar.set_postfix(loss=total_loss/total_samples, acc=f"{(100*correct_predictions/total_samples):.2f}%")

    if total_samples == 0:
        raise ValueError("data_loader yielded no samples; cannot compute training metrics.")

    return total_loss / total_samples, correct_predictions / total_samples


def validate(model, criterion, data_loader, device):
    """Evaluate ``model`` on ``data_loader`` without updating its weights.

    Args:
        model: network to evaluate; it is switched to eval mode and left there.
        criterion: loss called as ``criterion(outputs, labels)``.
        data_loader: iterable of ``(inputs, labels)`` batches.
        device: device every batch is moved to before the forward pass.

    Returns:
        ``(epoch_loss, epoch_acc)``: the per-sample mean loss and the
        accuracy as a fraction in [0, 1].

    Raises:
        ValueError: if ``data_loader`` yields no samples (previously a
            ZeroDivisionError).
    """
    model.eval()
    total_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    with torch.no_grad():  # no autograd graph is needed for evaluation
        progress_bar = tqdm(data_loader, desc="Validating")
        for inputs, labels in progress_bar:
            # non_blocking overlaps the copy with compute when memory is pinned.
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Batch-mean loss scaled by batch size, giving a per-sample epoch mean.
            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size
            correct_predictions += (outputs.argmax(dim=1) == labels).sum().item()

            progress_bar.set_postfix(loss=total_loss/total_samples, acc=f"{(100*correct_predictions/total_samples):.2f}%")

    if total_samples == 0:
        raise ValueError("data_loader yielded no samples; cannot compute validation metrics.")

    return total_loss / total_samples, correct_predictions / total_samples

In [None]:
# --- Main Training Loop ---

best_val_acc = 0.0
# The checkpoint path never changes, so build it once. The file is
# overwritten each time validation accuracy improves.
best_model_path = os.path.join(config["CHECKPOINT_PATH"], "best_model.pth")

print("Starting training...")
# perf_counter is monotonic, so the reported duration is not skewed by
# system-clock adjustments during a long run (time.time can jump).
start_time = time.perf_counter()

for epoch in range(config["NUM_EPOCHS"]):
    # --- Train ---
    train_loss, train_acc = train_one_epoch(model, criterion, optimizer, train_loader, config["DEVICE"], epoch)

    # --- Validate ---
    val_loss, val_acc = validate(model, criterion, val_loader, config["DEVICE"])

    # --- Log Results ---
    # Read the LR used this epoch before the scheduler advances it, so the
    # StepLR decay is visible in the log.
    current_lr = optimizer.param_groups[0]["lr"]
    print(
        f"Epoch {epoch+1}/{config['NUM_EPOCHS']} | "
        f"LR: {current_lr:.2e} | "
        f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
        f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}"
    )

    # --- Step the scheduler (once per epoch) ---
    scheduler.step()

    # --- Save the best model (weights only, via state_dict) ---
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), best_model_path)
        print(f"New best model saved to {best_model_path} with accuracy: {val_acc:.4f}")


end_time = time.perf_counter()
print(f"Training finished in {(end_time - start_time)/60:.2f} minutes.")
print(f"Best validation accuracy: {best_val_acc:.4f}")