In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision import datasets, transforms, models
from torch.utils.data import random_split, DataLoader, Subset
import random
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import os


In [None]:
# Set seed for reproducibility
random.seed(42)
torch.manual_seed(42)
np.random.seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# === Data loading ===
# Augmentation is applied to the training split only. Validation and test
# images go through a deterministic pipeline so that reported accuracy is
# measured on the real images and does not vary between runs.
NORM_MEAN, NORM_STD = (0.5,), (0.5,)  # one value, applied to every RGB channel

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])
eval_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])
transform = train_transform  # kept under the old name for any later cell that refers to it

# Two views over the same 50k training images: augmented (for training) and
# plain (for validation). The first call downloads, the second reuses the files.
full_train = datasets.CIFAR100(root='./data', train=True, download=True, transform=train_transform)
full_train_eval = datasets.CIFAR100(root='./data', train=True, download=False, transform=eval_transform)
test_set = datasets.CIFAR100(root='./data', train=False, download=True, transform=eval_transform)

val_ratio = 0.1
val_size = int(len(full_train) * val_ratio)
train_size = len(full_train) - val_size

# random_split draws from the torch RNG seeded above, so the split is
# reproducible; its indices are reused to build the non-augmented val subset.
train_set, val_split = random_split(full_train, [train_size, val_size])
val_set = Subset(full_train_eval, val_split.indices)

train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
val_loader = DataLoader(val_set, batch_size=128, shuffle=False, num_workers=2)
test_loader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=2)


In [None]:
# === Model ===
model = models.resnet18(weights='DEFAULT')  # Load ImageNet-pretrained weights as a fallback starting point
model.fc = nn.Linear(model.fc.in_features, 100)  # CIFAR-100 = 100 classes
model = model.to(device)

# === Load CIFAR-10-trained weights, excluding classifier ===
cifar10_ckpt_path = "/content/drive/MyDrive/rec_model/resnet18/best_resnet18.pth"
# map_location lets a checkpoint written on a GPU runtime load on a CPU-only one.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
state_dict = torch.load(cifar10_ckpt_path, map_location=device)

# Drop the 10-class classifier head; its shape cannot match the 100-class head.
filtered_dict = {k: v for k, v in state_dict.items() if not k.startswith("fc.")}

# strict=False silently skips keys whose names do not match (e.g. a 'module.'
# prefix from DataParallel). Fail loudly if nothing matched, and report any
# partial mismatch instead of hiding it.
if not set(filtered_dict) & set(model.state_dict()):
    raise RuntimeError(
        f"No checkpoint key matches the ResNet-18 backbone; first keys: {list(filtered_dict)[:5]}"
    )
incompatible = model.load_state_dict(filtered_dict, strict=False)
backbone_missing = [k for k in incompatible.missing_keys if not k.startswith("fc.")]
if backbone_missing or incompatible.unexpected_keys:
    print(f"WARNING: checkpoint/backbone mismatch -- missing: {backbone_missing}, "
          f"unexpected: {incompatible.unexpected_keys}")


# === Loss, Optimizer, Scheduler ===
loss_fn = nn.CrossEntropyLoss()
# NOTE(review): lr=0.1 is a from-scratch recipe; a smaller initial LR may suit fine-tuning better.
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
# Step on validation accuracy (mode='max'). `verbose` is deprecated in recent
# PyTorch; the training loop already prints the LR every epoch.
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1,
                patience=5, threshold=1e-3, threshold_mode='rel')


# === Training Config ===
epochs = 200
early_stop_patience = 15
best_val_acc = 0.0
no_improvement = 0
save_path = "/content/drive/MyDrive/cifar100model/best_resnet18.pth"
# torch.save does not create missing directories; make sure the target exists.
os.makedirs(os.path.dirname(save_path), exist_ok=True)

train_losses = []
train_accuracies = []
val_accuracies = []
lr_history = []

In [None]:
# === Training Loop ===
# Each epoch: one pass over train_loader, one over val_loader; the LR scheduler
# and early stopping are both driven by validation accuracy.
for epoch in range(epochs):
    model.train()
    epoch_loss = 0.0  # sum of per-sample losses over the epoch
    correct = 0
    total = 0

    for imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # loss_fn returns the batch mean; weight by batch size so the smaller
        # final batch does not skew the epoch average.
        epoch_loss += loss.item() * batch_size
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += batch_size

    avg_loss = epoch_loss / total
    train_acc = correct / total

    # === Validation ===
    model.eval()
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            preds = outputs.argmax(dim=1)
            val_correct += (preds == labels).sum().item()
            val_total += labels.size(0)
    val_acc = val_correct / val_total

    # Logging (LR recorded is the one used during this epoch, before scheduler.step)
    train_losses.append(avg_loss)
    train_accuracies.append(train_acc)
    val_accuracies.append(val_acc)
    lr_history.append(optimizer.param_groups[0]['lr'])

    print(f"[Epoch {epoch+1}] Loss: {avg_loss:.4f} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")
    print(f"[Epoch {epoch+1}] LR: {optimizer.param_groups[0]['lr']:.6f}")

    scheduler.step(val_acc)

    # Early Stopping: checkpoint on strict improvement, stop after
    # `early_stop_patience` consecutive epochs without one.
    if val_acc > best_val_acc:
        print(f"New best val acc! Saving to {save_path}")
        best_val_acc = val_acc
        no_improvement = 0
        torch.save(model.state_dict(), save_path)
    else:
        no_improvement += 1
        print(f"No improvement for {no_improvement} epoch(s).")

    if no_improvement >= early_stop_patience:
        print("Early stopping triggered.")
        break



In [None]:
# === Plot Accuracy ===
# Epochs are numbered from 1 to match the printed training log and the CSV.
epoch_axis = range(1, len(train_accuracies) + 1)
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(epoch_axis, train_accuracies, label="Train Accuracy")
ax.plot(epoch_axis, val_accuracies, label="Validation Accuracy")
ax.set_xlabel("Epoch")
ax.set_ylabel("Accuracy")
ax.set_title("ResNet-18 Accuracy on CIFAR-100")
ax.legend()
ax.grid(True)
fig.savefig("/content/drive/MyDrive/cifar100model/resnet18_accuracy_curve.png")
plt.show()

In [None]:
# === Plot Training Loss ===
# Epochs are numbered from 1 to match the printed training log and the CSV.
fig, ax = plt.subplots()
ax.plot(range(1, len(train_losses) + 1), train_losses, label="Training Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Training Loss Curve")
ax.legend()
ax.grid(True)
plt.show()


In [None]:
import pandas as pd

# Persist the per-epoch history as a table: one row per completed epoch,
# epochs counted from 1 as in the printed log.
log_columns = {
    "epoch": [i + 1 for i in range(len(train_losses))],
    "train_loss": train_losses,
    "train_acc": train_accuracies,
    "val_acc": val_accuracies,
    "lr": lr_history,
}
log_df = pd.DataFrame(log_columns)

log_csv_path = "/content/drive/MyDrive/cifar100model/resnet18_training_log.csv"
log_df.to_csv(log_csv_path, index=False)
log_df.head()

In [None]:
# === Evaluation Function ===
def evaluate(model, data_loader, device=None):
    """Return the top-1 accuracy of ``model`` over ``data_loader``.

    Args:
        model: classifier; ``model(inputs).argmax(dim=1)`` is taken as the
            predicted class. It is left in eval mode on return.
        data_loader: iterable of ``(inputs, labels)`` batches.
        device: device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        Fraction of samples whose predicted class equals the label, in [0, 1].

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()  # inference behaviour for layers such as BatchNorm / Dropout
    correct = 0
    total = 0
    with torch.no_grad():  # no gradients needed for evaluation
        for imgs, labels in data_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            preds = model(imgs).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    if total == 0:
        raise ValueError("evaluate() received a data_loader with no samples")
    return correct / total

# Reload the checkpoint with the best validation accuracy before testing.
# map_location keeps this working if the checkpoint was written on another device.
model.load_state_dict(torch.load(save_path, map_location=device))
test_acc = evaluate(model, test_loader)
print(f"Final Test Accuracy: {test_acc:.4f}")