# CIFAR-100 Image Classification Model Retraining

###  Train EfficientNet-B0 on CIFAR-100 using PyTorch

In [1]:
# Import required:
import os
import torch
import platform
import torchvision
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torchvision.models import efficientnet_b0

In [2]:
# Training hyperparameters (tune these before running the cells below)
LEARNING_RATE = 5e-4  # step size passed to the optimizer
epochs = 40           # number of passes the training loop makes over the train loader

In [3]:
# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(f"Using device: {device}")

# Report details of GPU 0 (only meaningful when CUDA was selected)
if use_cuda:
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    print(f"CUDA Version: {torch.version.cuda}")
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU Memory: {total_gb:.2f} GB")

# Report host CPU and operating system
cpu_name = platform.processor()
os_name, os_release = platform.system(), platform.release()
print(f"CPU: {cpu_name}")
print(f"System: {os_name} {os_release}")

Using device: cuda
GPU Name: Orin
CUDA Version: 12.6
GPU Memory: 7.99 GB
CPU: aarch64
System: Linux 5.15.148-tegra


In [4]:
# Preprocessing applied to every image (the same pipeline is used for train and test)
preprocessing_steps = [
    transforms.Resize(size=(32, 32)),                # force a 32x32 spatial size
    transforms.ToTensor(),                           # PIL image -> float tensor
    transforms.Normalize(mean=(0.5,), std=(0.5,)),   # (x - 0.5) / 0.5; the single value is broadcast over channels
]
transform = transforms.Compose(preprocessing_steps)

In [5]:
# CIFAR-100 train split and its shuffled loader
trainset = torchvision.datasets.CIFAR100(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=64, shuffle=True)

# CIFAR-100 test split; order is kept fixed for evaluation
testset = torchvision.datasets.CIFAR100(
    root="./data",
    train=False,
    download=True,
    transform=transform,
)
testloader = torch.utils.data.DataLoader(dataset=testset, batch_size=64, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
# Load EfficientNet-B0 with ImageNet-pretrained weights
model = efficientnet_b0(weights="IMAGENET1K_V1")

# Replace the final classification layer with a 100-way head for CIFAR-100.
# The input width is read from the layer being replaced rather than hard-coded,
# so the head always matches the backbone's feature size.
model.classifier[1] = nn.Linear(model.classifier[1].in_features, 100)

# Move the model to the selected device. Assigning the result (instead of leaving
# `model.to(device)` as the cell's last expression) keeps the notebook from
# echoing the entire module repr into the cell output.
model = model.to(device)

In [7]:
# Optimizer over every model parameter, plus the classification loss
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

In [None]:
# Per-epoch history, appended to by the training loop and read by the plotting cell
train_losses, train_accuracies = [], []  # metrics on the training set
val_losses, val_accuracies = [], []      # metrics on the held-out (test) set

In [8]:
# Training loop: one optimisation pass over `trainloader` followed by one
# evaluation pass over `testloader` per epoch. Per-epoch loss and accuracy for
# both passes are appended to the history lists and printed.
for epoch in range(epochs):
    # ---- Training pass ----
    model.train()
    running_loss = 0.0  # sum of per-sample losses seen this epoch
    correct = 0
    total = 0

    for images, labels in trainloader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # `criterion` yields the mean loss over the batch (default reduction).
        # Weight it by the batch size so the smaller final batch is not counted
        # as a full one when the epoch average is computed.
        running_loss += loss.item() * labels.size(0)

        # Compute accuracy
        _, predicted = torch.max(outputs, 1)  # Get class with highest score
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

    epoch_loss = running_loss / total  # true per-sample mean over the epoch
    epoch_accuracy = 100 * correct / total
    train_losses.append(epoch_loss)
    train_accuracies.append(epoch_accuracy)

    # ---- Validation pass ----
    model.eval()  # Set to evaluation mode
    val_loss = 0.0  # sum of per-sample losses over the validation data
    val_correct = 0
    val_total = 0
    with torch.no_grad():  # No gradients during validation
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            # Same batch-size weighting as in the training pass
            val_loss += loss.item() * labels.size(0)
            _, predicted = torch.max(outputs, 1)
            val_correct += (predicted == labels).sum().item()
            val_total += labels.size(0)

    val_epoch_loss = val_loss / val_total
    val_epoch_accuracy = 100 * val_correct / val_total
    val_losses.append(val_epoch_loss)
    val_accuracies.append(val_epoch_accuracy)

    print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%, Val Loss: {val_epoch_loss:.4f}, Val Acc: {val_epoch_accuracy:.2f}%")


Epoch 1/40, Loss: 2.9804, Accuracy: 26.73%
Epoch 2/40, Loss: 1.9802, Accuracy: 46.03%
Epoch 3/40, Loss: 1.6106, Accuracy: 54.69%
Epoch 4/40, Loss: 1.4016, Accuracy: 59.75%
Epoch 5/40, Loss: 1.2055, Accuracy: 64.54%
Epoch 6/40, Loss: 1.0404, Accuracy: 68.87%
Epoch 7/40, Loss: 0.9161, Accuracy: 72.03%
Epoch 8/40, Loss: 0.7870, Accuracy: 75.69%
Epoch 9/40, Loss: 0.7073, Accuracy: 78.00%
Epoch 10/40, Loss: 0.6374, Accuracy: 79.81%
Epoch 11/40, Loss: 0.5889, Accuracy: 81.32%
Epoch 12/40, Loss: 0.5694, Accuracy: 82.07%
Epoch 13/40, Loss: 0.5253, Accuracy: 83.33%
Epoch 14/40, Loss: 0.4385, Accuracy: 85.97%
Epoch 15/40, Loss: 0.4285, Accuracy: 86.14%
Epoch 16/40, Loss: 0.3841, Accuracy: 87.60%
Epoch 17/40, Loss: 0.3771, Accuracy: 87.83%
Epoch 18/40, Loss: 0.3741, Accuracy: 88.05%
Epoch 19/40, Loss: 0.3851, Accuracy: 87.68%
Epoch 20/40, Loss: 0.3289, Accuracy: 89.40%
Epoch 21/40, Loss: 0.2887, Accuracy: 90.65%
Epoch 22/40, Loss: 0.3076, Accuracy: 90.28%
Epoch 23/40, Loss: 0.2998, Accuracy: 90.3

In [9]:

# Path stem (no extension) for the saved weights; get_model_path_save appends
# the epoch count and, if that file already exists, a numeric suffix.
# NOTE(review): this is an absolute, machine-specific location — the save cell
# will fail on any machine where this directory does not exist.
base_path = "/home/aman-nvidia/My_files/cv_projects/image_classification_webGUI/efficientnet_cifar100"
# Extension appended to the generated checkpoint filename
file_extension = ".pth"

# Build a checkpoint filename that never overwrites an existing file
def get_model_path_save(base_path, file_extension, num_epochs):
    """Return a not-yet-existing path for saving a checkpoint.

    The first candidate is ``{base_path}_{num_epochs}{file_extension}``. If that
    file already exists, ``_1``, ``_2``, ... are appended after the epoch count
    until a free name is found.

    Args:
        base_path: Path stem without extension.
        file_extension: Extension including the leading dot (e.g. ".pth").
        num_epochs: Epoch count embedded in the filename.

    Returns:
        The first candidate path for which ``os.path.exists`` is false.
    """
    stem = "{}_{}".format(base_path, num_epochs)
    candidate = "{}{}".format(stem, file_extension)

    suffix = 0
    while os.path.exists(candidate):
        suffix += 1
        candidate = "{}_{}{}".format(stem, suffix, file_extension)

    return candidate

# Save the model
model_path = get_model_path_save(base_path, file_extension, epochs)
torch.save(model.state_dict(), model_path)
print(f"Model saved successfully as: {model_path}")

Model saved successfully as: /home/aman-nvidia/My_files/cv_projects/image_classification_webGUI/efficientnet_cifar100.pth


In [None]:
# Plot training history: loss on the left panel, accuracy on the right.
# The x-axis of each curve is derived from that curve's own length rather than
# from `epochs`, so the plot still renders when training was stopped early (or
# the training cell was run more than once) and the history lists do not hold
# exactly `epochs` entries.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(10, 5))

panels = [
    (loss_ax, "Loss", [("Train Loss", train_losses), ("Validation Loss", val_losses)]),
    (acc_ax, "Accuracy", [("Train Accuracy", train_accuracies), ("Validation Accuracy", val_accuracies)]),
]

for ax, metric, curves in panels:
    for label, history in curves:
        # Epochs are numbered from 1 to match the training log
        ax.plot(range(1, len(history) + 1), history, label=label)
    ax.set_xlabel("Epochs")
    ax.set_ylabel(metric)
    ax.set_title(f"Training/Validation {metric}")
    ax.legend()

fig.tight_layout()  # Adjust subplot params so that subplots fit in to the figure area.
plt.show()
