In [26]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import timm  # For Vision Transformer
import numpy as np

In [27]:
# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

Using device: cuda


In [28]:
# Models: two pretrained backbones, each given a 10-way classification head.

NUM_CLASSES = 10  # CIFAR-10 (the dataset loaded later in this notebook) has 10 classes

# CNN: ResNet-50 with its final fully connected layer swapped for a NUM_CLASSES-way one.
cnn_model = torchvision.models.resnet50(pretrained=True)
head_in_features = cnn_model.fc.in_features
cnn_model.fc = nn.Linear(head_in_features, NUM_CLASSES)
cnn_model = cnn_model.to(device)

# ViT: small Vision Transformer from timm, created directly with a NUM_CLASSES-way head.
vit_model = timm.create_model(
    "vit_small_patch16_224",
    pretrained=True,
    num_classes=NUM_CLASSES,
).to(device)



In [29]:
# Base dataset: CIFAR-10 train/test splits, both run through the same preprocessing.

IMAGE_SIZE = (224, 224)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
BATCH_SIZE = 64

transform = transforms.Compose([
    # Upscale CIFAR-10 images (32x32) to 224x224 so both models get the same input size.
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    # ImageNet normalization
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Arguments shared by both splits; only `train` differs between them.
cifar10_kwargs = dict(root="./data", download=True, transform=transform)
train_dataset = torchvision.datasets.CIFAR10(train=True, **cifar10_kwargs)
test_dataset = torchvision.datasets.CIFAR10(train=False, **cifar10_kwargs)

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True)

In [30]:
# Sanity check: pull one batch from the training loader and show its tensor shapes.
X, y = next(iter(train_loader))
for batch_label, batch_tensor in (("Batch images shape:", X), ("Batch labels shape:", y)):
    print(batch_label, batch_tensor.shape)

Batch images shape: torch.Size([64, 3, 224, 224])
Batch labels shape: torch.Size([64])


In [31]:
#Model functions
def accuracy_fn(y_true, y_pred):
    """Return the percentage (0-100) of positions where ``y_pred`` equals ``y_true``.

    Both arguments are compared element-wise; the number of matches is divided
    by ``len(y_true)`` and scaled to a percentage.
    """
    matches = y_true == y_pred
    hit_count = matches.sum().item()
    return hit_count / len(y_true) * 100

# One shared classification loss, and one optimizer per model (Adam for the CNN,
# AdamW for the ViT), both using the same learning rate.
LEARNING_RATE = 1e-4

loss_fn = nn.CrossEntropyLoss()
optimizer_cnn = optim.Adam(params=cnn_model.parameters(), lr=LEARNING_RATE)
optimizer_vit = optim.AdamW(params=vit_model.parameters(), lr=LEARNING_RATE)

In [32]:
from tqdm.auto import tqdm

def train_step(model, data_loader, loss_fn, optimizer, accuracy_fn, device=device):
    """Train ``model`` for one pass over ``data_loader`` and return epoch metrics.

    Metrics are averaged per *sample*, not per batch: each batch's loss and
    accuracy are weighted by the batch size, so a smaller final batch no longer
    counts as much as a full one. When every batch has the same size the result
    matches the plain mean of the per-batch values (up to float rounding).

    Args:
        model: Module to optimise; it is switched to training mode here.
        data_loader: Iterable yielding ``(X, y)`` batches. It does not need to
            support ``len()``.
        loss_fn: Callable ``loss_fn(y_pred, y)`` returning a scalar tensor.
            Assumed to be a per-batch mean (the ``nn.CrossEntropyLoss``
            default), because the value is re-weighted by batch size.
        optimizer: Optimizer that updates ``model``'s parameters.
        accuracy_fn: Callable ``accuracy_fn(y_true=..., y_pred=...)`` returning
            the accuracy of one batch as a number.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Tuple ``(mean_loss, mean_accuracy)`` over all samples seen.

    Raises:
        ZeroDivisionError: If ``data_loader`` yields no samples.
    """
    model.train()
    loss_sum, acc_sum, n_seen = 0.0, 0.0, 0

    # Create a progress bar for the DataLoader
    progress_bar = tqdm(data_loader, desc="Training")

    for X, y in progress_bar:
        X, y = X.to(device), y.to(device)
        batch_size = len(y)

        # Forward pass
        y_pred = model(X)

        # Loss and accuracy for this batch, each computed exactly once
        loss = loss_fn(y_pred, y)
        batch_loss = loss.item()
        batch_acc = accuracy_fn(y_true=y, y_pred=y_pred.argmax(dim=1))

        # Weight by batch size so the epoch figures are true per-sample means
        loss_sum += batch_loss * batch_size
        acc_sum += batch_acc * batch_size
        n_seen += batch_size

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update the progress bar with the current batch's loss and accuracy
        progress_bar.set_postfix(loss=batch_loss, accuracy=batch_acc)

    return loss_sum / n_seen, acc_sum / n_seen

def test_step(model, data_loader, loss_fn, accuracy_fn, device=device):
    """Evaluate ``model`` on ``data_loader`` and return dataset-level metrics.

    Metrics are averaged per *sample*, not per batch: each batch's loss and
    accuracy are weighted by the batch size, so a smaller final batch no longer
    counts as much as a full one. When every batch has the same size the result
    matches the plain mean of the per-batch values (up to float rounding).

    Args:
        model: Module to evaluate; it is switched to eval mode here.
        data_loader: Iterable yielding ``(X, y)`` batches. It does not need to
            support ``len()``.
        loss_fn: Callable ``loss_fn(y_pred, y)`` returning a scalar tensor.
            Assumed to be a per-batch mean (the ``nn.CrossEntropyLoss``
            default), because the value is re-weighted by batch size.
        accuracy_fn: Callable ``accuracy_fn(y_true=..., y_pred=...)`` returning
            the accuracy of one batch as a number.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Tuple ``(mean_loss, mean_accuracy)`` over all samples seen.

    Raises:
        ZeroDivisionError: If ``data_loader`` yields no samples.
    """
    model.eval()
    loss_sum, acc_sum, n_seen = 0.0, 0.0, 0

    with torch.inference_mode():
        # Create a progress bar for the DataLoader
        progress_bar = tqdm(data_loader, desc="Testing")
        for X, y in progress_bar:
            X, y = X.to(device), y.to(device)
            batch_size = len(y)
            y_pred = model(X)

            # Loss and accuracy for this batch, each computed exactly once
            batch_loss = loss_fn(y_pred, y).item()
            batch_acc = accuracy_fn(y_true=y, y_pred=y_pred.argmax(dim=1))

            # Weight by batch size so the totals are true per-sample means
            loss_sum += batch_loss * batch_size
            acc_sum += batch_acc * batch_size
            n_seen += batch_size

            # Update the progress bar with the current batch's loss and accuracy
            progress_bar.set_postfix(loss=batch_loss, accuracy=batch_acc)

    return loss_sum / n_seen, acc_sum / n_seen

# def train_step(model, data_loader, loss_fn, optimizer, accuracy_fn, device=device):
#     model.train()
#     train_loss, train_acc = 0, 0

#     for X, y in data_loader:
#         X, y = X.to(device), y.to(device)

#         # Forward pass
#         y_pred = model(X)

#         # Loss
#         loss = loss_fn(y_pred, y)
#         train_loss += loss.item()

#         # Accuracy
#         train_acc += accuracy_fn(y_true=y, y_pred=y_pred.argmax(dim=1))

#         # Backpropagation
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#         progress_bar.set_postfix(loss=loss.item(), accuracy=accuracy_fn(y_true=y, y_pred=y_pred.argmax(dim=1)))

#     return train_loss / len(data_loader), train_acc / len(data_loader)


# def test_step(model, data_loader, loss_fn, accuracy_fn, device=device):
#     model.eval()
#     test_loss, test_acc = 0, 0

#     with torch.inference_mode():
#         for X, y in data_loader:
#             X, y = X.to(device), y.to(device)
#             y_pred = model(X)

#             # Loss + accuracy
#             test_loss += loss_fn(y_pred, y).item()
#             test_acc += accuracy_fn(y_true=y, y_pred=y_pred.argmax(dim=1))

#     return test_loss / len(data_loader), test_acc / len(data_loader)

In [33]:
# Simple test: one training pass followed by one evaluation pass for each model,
# CNN first and then ViT, printing that model's metrics after its two passes.

for run_tag, run_model, run_optimizer in (
    ("CNN", cnn_model, optimizer_cnn),
    ("ViT", vit_model, optimizer_vit),
):
    train_loss, train_acc = train_step(run_model, train_loader, loss_fn, run_optimizer, accuracy_fn)
    test_loss, test_acc = test_step(run_model, test_loader, loss_fn, accuracy_fn)

    print(f"[{run_tag}] Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f"[{run_tag}] Test  Loss: {test_loss:.4f}, Test  Acc: {test_acc:.2f}%")

Training: 100%|██████████| 782/782 [30:34<00:00,  2.35s/it, accuracy=93.8, loss=0.136] 
Testing: 100%|██████████| 157/157 [00:37<00:00,  4.20it/s, accuracy=93.8, loss=0.368] 


[CNN] Train Loss: 0.2925, Train Acc: 90.32%
[CNN] Test  Loss: 0.1955, Test  Acc: 93.28%


Training: 100%|██████████| 782/782 [12:41<00:00,  1.03it/s, accuracy=100, loss=0.00133]
Testing: 100%|██████████| 157/157 [00:52<00:00,  2.98it/s, accuracy=93.8, loss=0.0774]

[ViT] Train Loss: 0.1450, Train Acc: 95.32%
[ViT] Test  Loss: 0.1024, Test  Acc: 96.55%





In [38]:
# Re-evaluate both trained models on the test loader (no further training).
for eval_tag, eval_model in (("CNN", cnn_model), ("ViT", vit_model)):
    test_loss, test_acc = test_step(eval_model, test_loader, loss_fn, accuracy_fn)
    print(f"[{eval_tag}] Test  Loss: {test_loss:.4f}, Test  Acc: {test_acc:.2f}%")

Testing: 100%|██████████| 157/157 [01:03<00:00,  2.47it/s, accuracy=93.8, loss=0.368] 


[CNN] Test  Loss: 0.1955, Test  Acc: 93.28%


Testing: 100%|██████████| 157/157 [00:56<00:00,  2.80it/s, accuracy=93.8, loss=0.0774]

[ViT] Test  Loss: 0.1024, Test  Acc: 96.55%





In [39]:
# Checkpoint destinations (relative to the notebook's working directory).
cnn_path = "resnet50_cifar10_new.pth"
vit_path = "vit_small_cifar10_new.pth"

print("Saving models...")

# Persist only each model's state_dict, CNN first and then ViT.
for trained_model, checkpoint_path in ((cnn_model, cnn_path), (vit_model, vit_path)):
    torch.save(trained_model.state_dict(), checkpoint_path)

print("Models saved successfully!")

Saving models...
Models saved successfully!
