In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import timm  # For Vision Transformer
import numpy as np

In [10]:
# Select the compute device: the CUDA GPU when PyTorch can see one, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

Using device: cuda


In [11]:
# Models: pretrained backbones, each given a classification head sized for CIFAR-10.

NUM_CLASSES = 10  # CIFAR-10 has 10 classes

# CNN model: torchvision ResNet-50 whose final fully connected layer is replaced
# by a freshly initialised NUM_CLASSES-way linear classifier.
# NOTE(review): `pretrained=True` is deprecated since torchvision 0.13 in favour of
# `weights=ResNet50_Weights.IMAGENET1K_V1`; the legacy flag is kept so that older
# torchvision installs keep working.
cnn_model = torchvision.models.resnet50(pretrained=True)
cnn_model.fc = nn.Linear(cnn_model.fc.in_features, NUM_CLASSES)
cnn_model = cnn_model.to(device)


# ViT model: timm "vit_small_patch16_224"; `num_classes` asks timm to build the
# classifier head with NUM_CLASSES outputs.
vit_model = timm.create_model("vit_small_patch16_224", pretrained=True, num_classes=NUM_CLASSES)
vit_model = vit_model.to(device)




In [12]:
# Base dataset: CIFAR-10, loaded once per model because each model gets its own
# input resolution.

IMAGENET_MEAN = (0.485, 0.456, 0.406)  # ImageNet normalization
IMAGENET_STD = (0.229, 0.224, 0.225)
BATCH_SIZE = 64


def _build_transform(image_size):
    """Return the preprocessing pipeline that yields square ``image_size`` inputs.

    The pipeline resizes to ``(image_size, image_size)``, converts to a tensor and
    normalizes with the ImageNet channel statistics.
    """
    return transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])


def _build_cifar10_loaders(transform, batch_size=BATCH_SIZE, root="./data"):
    """Build the CIFAR-10 train/test datasets and their loaders for one transform.

    Args:
        transform: Preprocessing applied to every image of both splits.
        batch_size: Number of samples per batch for both loaders.
        root: Directory the dataset is downloaded to / read from.

    Returns:
        Tuple ``(train_dataset, test_dataset, train_loader, test_loader)``. Only
        the training loader shuffles.
    """
    train_dataset = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=transform)
    test_dataset = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)
    return train_dataset, test_dataset, train_loader, test_loader


# Upscale CIFAR-10 (32x32) to 96x96 for the ResNet-50.
cnn_transform = _build_transform(96)

# Upscale CIFAR-10 (32x32) to 224x224, the input size named by "vit_small_patch16_224".
vit_transform = _build_transform(224)

(cnn_train_dataset,
 cnn_test_dataset,
 cnn_train_loader,
 cnn_test_loader) = _build_cifar10_loaders(cnn_transform)

(vit_train_dataset,
 vit_test_dataset,
 vit_train_loader,
 vit_test_loader) = _build_cifar10_loaders(vit_transform)

In [13]:
# Sanity check: draw one batch from the CNN training loader and report its tensor shapes.
X, y = next(iter(cnn_train_loader))
print("Batch images shape:", X.size())
print("Batch labels shape:", y.size())

Batch images shape: torch.Size([64, 3, 96, 96])
Batch labels shape: torch.Size([64])


In [14]:
#Model functions
def accuracy_fn(y_true, y_pred):
    """Return the percentage of positions where ``y_pred`` equals ``y_true``.

    Args:
        y_true: Tensor of reference labels.
        y_pred: Tensor of predicted labels, compared element-wise with ``y_true``.

    Returns:
        Accuracy as a float percentage: matches divided by ``len(y_true)``, times 100.
    """
    matches = (y_true == y_pred)
    n_correct = matches.sum().item()
    fraction_correct = n_correct / len(y_true)
    return fraction_correct * 100

# Shared classification objective, plus one optimizer per model (Adam for the
# CNN, AdamW for the ViT), both using a learning rate of 1e-4.
loss_fn = nn.CrossEntropyLoss()
optimizer_cnn = optim.Adam(params=cnn_model.parameters(), lr=1e-4)
optimizer_vit = optim.AdamW(params=vit_model.parameters(), lr=1e-4)

In [15]:
from tqdm.auto import tqdm

def train_step(model, data_loader, loss_fn, optimizer, accuracy_fn, device=device):
    """Train ``model`` for one pass over ``data_loader`` and return epoch metrics.

    Args:
        model: Module to optimise; put into training mode with ``model.train()``.
        data_loader: Iterable yielding ``(inputs, targets)`` batches.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a scalar
            tensor. Assumed to be the *mean* loss over the batch (the default
            reduction of ``nn.CrossEntropyLoss``).
        optimizer: Optimizer that updates ``model``'s parameters.
        accuracy_fn: Callable ``accuracy_fn(y_true=..., y_pred=...)`` returning
            the accuracy of one batch (a percentage for this file's helper).
        device: Device every batch is moved to before the forward pass.

    Returns:
        Tuple ``(mean_loss, mean_accuracy)`` averaged over *samples*. Each batch
        is weighted by its size, so a smaller final batch does not skew the
        epoch metrics the way a plain mean of per-batch values would.

    Raises:
        ZeroDivisionError: If ``data_loader`` yields no samples.
    """
    model.train()
    loss_sum, acc_sum, n_samples = 0.0, 0.0, 0

    # Create a progress bar for the DataLoader
    progress_bar = tqdm(data_loader, desc="Training")

    for X, y in progress_bar:
        X, y = X.to(device), y.to(device)

        # Forward pass
        y_pred = model(X)
        loss = loss_fn(y_pred, y)

        # Compute each per-batch metric exactly once; they are reused for both
        # the running totals and the progress bar.
        batch_size = len(y)
        batch_loss = loss.item()
        batch_acc = accuracy_fn(y_true=y, y_pred=y_pred.argmax(dim=1))

        # Sample-weighted running totals
        loss_sum += batch_loss * batch_size
        acc_sum += batch_acc * batch_size
        n_samples += batch_size

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update the progress bar with current loss and accuracy
        progress_bar.set_postfix(loss=batch_loss, accuracy=batch_acc)

    return loss_sum / n_samples, acc_sum / n_samples

def test_step(model, data_loader, loss_fn, accuracy_fn, device=device):
    """Evaluate ``model`` on ``data_loader`` without gradient tracking.

    Args:
        model: Module to evaluate; put into eval mode with ``model.eval()``.
        data_loader: Iterable yielding ``(inputs, targets)`` batches.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a scalar
            tensor. Assumed to be the *mean* loss over the batch (the default
            reduction of ``nn.CrossEntropyLoss``).
        accuracy_fn: Callable ``accuracy_fn(y_true=..., y_pred=...)`` returning
            the accuracy of one batch (a percentage for this file's helper).
        device: Device every batch is moved to before the forward pass.

    Returns:
        Tuple ``(mean_loss, mean_accuracy)`` averaged over *samples*. Each batch
        is weighted by its size, so a smaller final batch does not skew the
        reported metrics the way a plain mean of per-batch values would.

    Raises:
        ZeroDivisionError: If ``data_loader`` yields no samples.
    """
    model.eval()
    loss_sum, acc_sum, n_samples = 0.0, 0.0, 0

    with torch.inference_mode():
        # Create a progress bar for the DataLoader
        progress_bar = tqdm(data_loader, desc="Testing")
        for X, y in progress_bar:
            X, y = X.to(device), y.to(device)
            y_pred = model(X)

            # Compute each per-batch metric exactly once; they are reused for
            # both the running totals and the progress bar.
            batch_size = len(y)
            batch_loss = loss_fn(y_pred, y).item()
            batch_acc = accuracy_fn(y_true=y, y_pred=y_pred.argmax(dim=1))

            # Sample-weighted running totals
            loss_sum += batch_loss * batch_size
            acc_sum += batch_acc * batch_size
            n_samples += batch_size

            # Update the progress bar with current loss and accuracy
            progress_bar.set_postfix(loss=batch_loss, accuracy=batch_acc)

    return loss_sum / n_samples, acc_sum / n_samples

# def train_step(model, data_loader, loss_fn, optimizer, accuracy_fn, device=device):
#     model.train()
#     train_loss, train_acc = 0, 0

#     for X, y in data_loader:
#         X, y = X.to(device), y.to(device)

#         # Forward pass
#         y_pred = model(X)

#         # Loss
#         loss = loss_fn(y_pred, y)
#         train_loss += loss.item()

#         # Accuracy
#         train_acc += accuracy_fn(y_true=y, y_pred=y_pred.argmax(dim=1))

#         # Backpropagation
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#         progress_bar.set_postfix(loss=loss.item(), accuracy=accuracy_fn(y_true=y, y_pred=y_pred.argmax(dim=1)))

#     return train_loss / len(data_loader), train_acc / len(data_loader)


# def test_step(model, data_loader, loss_fn, accuracy_fn, device=device):
#     model.eval()
#     test_loss, test_acc = 0, 0

#     with torch.inference_mode():
#         for X, y in data_loader:
#             X, y = X.to(device), y.to(device)
#             y_pred = model(X)

#             # Loss + accuracy
#             test_loss += loss_fn(y_pred, y).item()
#             test_acc += accuracy_fn(y_true=y, y_pred=y_pred.argmax(dim=1))

#     return test_loss / len(data_loader), test_acc / len(data_loader)

In [16]:
# Smoke run: one training epoch followed by one evaluation pass for each model.

# ResNet-50
cnn_train_loss, cnn_train_acc = train_step(
    model=cnn_model,
    data_loader=cnn_train_loader,
    loss_fn=loss_fn,
    optimizer=optimizer_cnn,
    accuracy_fn=accuracy_fn,
)
cnn_test_loss, cnn_test_acc = test_step(
    model=cnn_model,
    data_loader=cnn_test_loader,
    loss_fn=loss_fn,
    accuracy_fn=accuracy_fn,
)

print(f"[CNN] Train Loss: {cnn_train_loss:.4f}, Train Acc: {cnn_train_acc:.2f}%")
print(f"[CNN] Test  Loss: {cnn_test_loss:.4f}, Test  Acc: {cnn_test_acc:.2f}%")

# ViT-Small
vit_train_loss, vit_train_acc = train_step(
    model=vit_model,
    data_loader=vit_train_loader,
    loss_fn=loss_fn,
    optimizer=optimizer_vit,
    accuracy_fn=accuracy_fn,
)
vit_test_loss, vit_test_acc = test_step(
    model=vit_model,
    data_loader=vit_test_loader,
    loss_fn=loss_fn,
    accuracy_fn=accuracy_fn,
)

print(f"[ViT] Train Loss: {vit_train_loss:.4f}, Train Acc: {vit_train_acc:.2f}%")
print(f"[ViT] Test  Loss: {vit_test_loss:.4f}, Test  Acc: {vit_test_acc:.2f}%")

Training: 100%|██████████| 782/782 [01:45<00:00,  7.44it/s, accuracy=100, loss=0.0529] 
Testing: 100%|██████████| 157/157 [00:09<00:00, 15.99it/s, accuracy=93.8, loss=0.407] 


[CNN] Train Loss: 0.3184, Train Acc: 89.42%
[CNN] Test  Loss: 0.1943, Test  Acc: 93.32%


Training: 100%|██████████| 782/782 [07:21<00:00,  1.77it/s, accuracy=100, loss=0.0161] 
Testing: 100%|██████████| 157/157 [00:35<00:00,  4.44it/s, accuracy=93.8, loss=0.135] 

[ViT] Train Loss: 0.1439, Train Acc: 95.33%
[ViT] Test  Loss: 0.1033, Test  Acc: 96.60%





In [19]:
# Re-evaluate both fine-tuned models on their test loaders. `test_loss` and
# `test_acc` are reused, so after this cell they hold the ViT results.
test_loss, test_acc = test_step(
    model=cnn_model,
    data_loader=cnn_test_loader,
    loss_fn=loss_fn,
    accuracy_fn=accuracy_fn,
)
print(f"[CNN] Test  Loss: {test_loss:.4f}, Test  Acc: {test_acc:.2f}%")

test_loss, test_acc = test_step(
    model=vit_model,
    data_loader=vit_test_loader,
    loss_fn=loss_fn,
    accuracy_fn=accuracy_fn,
)
print(f"[ViT] Test  Loss: {test_loss:.4f}, Test  Acc: {test_acc:.2f}%")

Testing: 100%|██████████| 157/157 [00:10<00:00, 15.30it/s, accuracy=93.8, loss=0.407] 


[CNN] Test  Loss: 0.1943, Test  Acc: 93.32%


Testing: 100%|██████████| 157/157 [00:32<00:00,  4.84it/s, accuracy=93.8, loss=0.135] 

[ViT] Test  Loss: 0.1033, Test  Acc: 96.60%





In [20]:
# Checkpoint destinations (relative to the notebook's working directory).
cnn_path = "resnet50_cifar10_new.pth"
vit_path = "vit_small_cifar10_new.pth"

print("Saving models...")

# Persist only the weights (state_dict) of each model, not the full module object.
torch.save(obj=cnn_model.state_dict(), f=cnn_path)
torch.save(obj=vit_model.state_dict(), f=vit_path)

print("Models saved successfully!")

Saving models...
Models saved successfully!
