In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [None]:
# Seed PyTorch's RNG up front so weight initialisation and DataLoader shuffling
# start from the same state on every fresh "Restart & Run All"
# (GPU kernels may still introduce small run-to-run differences).
SEED = 42
torch.manual_seed(SEED)

# Using the GPU if it exists
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
# CIFAR images are RGB, so Normalize gets one mean/std per channel instead of
# relying on a length-1 tuple being broadcast over all three channels.
# (x - 0.5) / 0.5 rescales ToTensor's [0, 1] output to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Both splits share the same preprocessing and are stored under ./data.
train_dataset = datasets.CIFAR100(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.CIFAR100(root='./data', train=False, transform=transform, download=True)

# Only the training split is shuffled; the test split is read in order.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000)

In [None]:
# Fully connected baseline expressed as an nn.Sequential pipeline.
# Assumes image batches of 32x32 RGB inputs, i.e. 3 * 32 * 32 = 3072 features
# once flattened.
model_normal = nn.Sequential(
    nn.Flatten(),                 # keep the batch dim, collapse the rest
    nn.Linear(3 * 32 * 32, 512),  # hidden layer 1
    nn.ReLU(),
    nn.Linear(512, 256),          # hidden layer 2
    nn.ReLU(),
    nn.Linear(256, 100),          # 100 output logits
)

In [None]:
# Training loop
def train(model, loader, optimizer, loss_fn, epochs=5):
    """Fit ``model`` on the batches from ``loader`` for ``epochs`` passes.

    Every batch is moved to the notebook-level ``device``, pushed through the
    model, scored with ``loss_fn`` and followed by one optimizer step. At the
    end of each epoch the *sum* of the per-batch loss values is printed.

    Args:
        model: Module to optimise; switched to training mode first.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        optimizer: Optimizer already bound to ``model``'s parameters.
        loss_fn: Callable ``loss_fn(outputs, targets)`` returning a scalar tensor.
        epochs: Number of passes over ``loader``.
    """
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Clear gradients left over from the previous step before backprop.
            optimizer.zero_grad()
            batch_loss = loss_fn(model(inputs), targets)
            batch_loss.backward()
            optimizer.step()

            total_loss += batch_loss.item()
        print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")

# Testing loop
def test(model, loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            pred = logits.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    print(f"Accuracy: {100 * correct / total:.2f}%")


In [None]:
print("\nCifar-100 Model")

# The nn.Sequential baseline, moved to the selected device.
cifar_model = model_normal.to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer_seq = optim.Adam(cifar_model.parameters(), lr=1e-3)

# Train with the default number of epochs, then score on the held-out split.
train(model=cifar_model, loader=train_loader, optimizer=optimizer_seq, loss_fn=loss_fn)
test(model=cifar_model, loader=test_loader)

In [None]:
class CIFARNet(nn.Module):
    """Three-layer MLP written with ``torch.nn.functional`` activations.

    Layer sizes are 3072 -> 512 -> 256 -> 100, with a ReLU after each of the
    two hidden layers. Inputs are reshaped to rows of 3 * 32 * 32 values.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3 * 32 * 32, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 100)

    def forward(self, x):
        """Return the raw (un-softmaxed) logits for ``x``."""
        # Collapse each image into one 3072-long row.
        flat = x.view(-1, self.fc1.in_features)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [None]:
print("\nUsing nn.functional")

# Same MLP layout as the Sequential baseline, built from the CIFARNet class.
model = CIFARNet().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Train with the default number of epochs, then score on the held-out split.
train(model=model, loader=train_loader, optimizer=optimizer, loss_fn=loss_fn)
test(model=model, loader=test_loader)

In [None]:
class ConvNet(nn.Module):
    """Small CNN: two 3x3 convolutions, one 2x2 max-pool, two linear layers.

    Channel progression is 3 -> 32 -> 64. The classifier head expects the
    pooled feature map to hold 64 * 16 * 16 values per image, which matches
    32x32 inputs.
    """

    def __init__(self):
        super().__init__()
        # A 3x3 kernel with padding=1 preserves the spatial size.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        # 2x2 window with stride 2 halves height and width.
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 16 * 16, 256)
        self.fc2 = nn.Linear(256, 100)

    def forward(self, x):
        """Return the raw class logits for a batch of images."""
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        features = self.pool(features)
        # Flatten the pooled feature map into one row per image.
        flat = features.view(-1, self.fc1.in_features)
        return self.fc2(F.relu(self.fc1(flat)))

In [None]:
print("\nCIFAR-100 ConvNet")

# Convolutional model, moved to the selected device.
conv_model = ConvNet().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(conv_model.parameters(), lr=1e-3)

# Train with the default number of epochs, then score on the held-out split.
train(model=conv_model, loader=train_loader, optimizer=optimizer, loss_fn=loss_fn)
test(model=conv_model, loader=test_loader)

In [None]:
def visualize(model, loader, n=5, mean=0.5, std=0.5):
    """Plot the first ``n`` images of one batch with true and predicted labels.

    Args:
        model: Classifier returning per-class logits.
        loader: Iterable of ``(images, labels)`` batches; only the first
            batch is used.
        n: Maximum number of images to draw (capped at the batch size).
        mean: Mean used when the images were normalised; a scalar or one
            value per channel. Defaults to 0.5.
        std: Standard deviation used when the images were normalised; a
            scalar or one value per channel. Defaults to 0.5.

    Each title reads ``T:<true class index> P:<predicted class index>``.
    """
    model.eval()
    x, y = next(iter(loader))
    x, y = x.to(device), y.to(device)
    # Inference only: skip building an autograd graph for the whole batch.
    with torch.no_grad():
        preds = model(x).argmax(dim=1)

    # Never index past the end of the batch.
    n = min(n, x.size(0))

    # Undo the (x - mean) / std normalisation so pixels are back in [0, 1];
    # imshow clips float RGB data outside that range, which distorts the image.
    images = x[:n].detach().cpu()
    mean_t = torch.as_tensor(mean, dtype=images.dtype).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=images.dtype).view(-1, 1, 1)
    images = (images * std_t + mean_t).clamp(0, 1)

    plt.figure(figsize=(10, 2))
    for i in range(n):
        plt.subplot(1, n, i+1)
        # Channels-first (C, H, W) -> channels-last (H, W, C) for matplotlib.
        plt.imshow(images[i].permute(1, 2, 0).numpy())
        plt.title(f"T:{y[i].item()} P:{preds[i].item()}")
        plt.axis('off')
    plt.show()

visualize(conv_model, test_loader)
# To inspect the nn.Sequential baseline instead, run:
# visualize(cifar_model, test_loader)