<a href="https://colab.research.google.com/github/aamish007/23-CS-006-CS318-DL-Lab/blob/main/23_CS_006_Experiment_3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader


In [15]:
# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
device


device(type='cpu')

In [16]:
# Per-channel (3 values each) mean / std handed to Normalize.
# NOTE(review): presumably the CIFAR-10 training-set statistics — confirm.
CIFAR10_MEAN = [0.4914, 0.4822, 0.4465]
CIFAR10_STD = [0.247, 0.243, 0.261]
DATA_ROOT = './data'
BATCH_SIZE = 64

# Convert images to tensors, then standardise each channel.
to_tensor = transforms.ToTensor()
standardise = transforms.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD)
transform = transforms.Compose([to_tensor, standardise])

# Both splits are downloaded into DATA_ROOT if missing and share one transform.
train_data = datasets.CIFAR10(
    root=DATA_ROOT,
    train=True,
    download=True,
    transform=transform,
)
test_data = datasets.CIFAR10(
    root=DATA_ROOT,
    train=False,
    download=True,
    transform=transform,
)

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE)


In [17]:
class SimpleCNN(nn.Module):
    """Small CNN: two conv + max-pool stages, then a two-layer classifier.

    The first ``Linear`` layer expects ``64 * 8 * 8`` flattened features and
    the last one emits 10 values per sample.
    """

    def __init__(self, activation='relu'):
        """Assemble the layer stack.

        Args:
            activation: ``'relu'`` selects ReLU and ``'tanh'`` selects Tanh;
                every other value selects LeakyReLU with slope 0.01.
        """
        super().__init__()

        # A single activation module is created and reused at all three
        # positions in the stack.
        nonlinearity = (
            nn.ReLU() if activation == 'relu'
            else nn.Tanh() if activation == 'tanh'
            else nn.LeakyReLU(0.01)
        )

        # Convolutional feature extractor: each stage halves height and width.
        feature_layers = [
            nn.Conv2d(3, 32, 3, padding=1),
            nonlinearity,
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nonlinearity,
            nn.MaxPool2d(2),
        ]

        # Classifier head operating on the flattened feature maps.
        head_layers = [
            nn.Flatten(),
            nn.Linear(64 * 8 * 8, 128),
            nonlinearity,
            nn.Linear(128, 10),
        ]

        self.model = nn.Sequential(*feature_layers, *head_layers)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return its output."""
        return self.model(x)


In [18]:
def init_weights(model, method):
    """Re-initialise the weights of every Conv2d / Linear layer in ``model``.

    Only ``weight`` tensors are touched; biases are left as they are.

    Args:
        model: module whose sub-modules are scanned.
        method: ``'xavier'`` applies Xavier-uniform, ``'kaiming'`` applies
            Kaiming-normal, and any other value draws from N(0, 0.01^2).
    """
    def _small_normal(weight):
        # Fallback used for every unrecognised method name.
        nn.init.normal_(weight, mean=0, std=0.01)

    schemes = {
        'xavier': nn.init.xavier_uniform_,
        'kaiming': nn.init.kaiming_normal_,
    }
    initialise = schemes.get(method, _small_normal)

    for layer in model.modules():
        if isinstance(layer, (nn.Conv2d, nn.Linear)):
            initialise(layer.weight)


In [19]:
def get_optimizer(name, model):
    """Build an optimizer over all parameters of ``model``.

    Args:
        name: ``'sgd'`` gives SGD (lr 0.01) and ``'adam'`` gives Adam
            (lr 0.001); any other value gives RMSprop (lr 0.001).
        model: module whose ``parameters()`` are optimised.

    Returns:
        The constructed ``torch.optim`` optimizer.
    """
    # name -> (optimizer class, learning rate)
    recipes = {
        'sgd': (optim.SGD, 0.01),
        'adam': (optim.Adam, 0.001),
    }
    optimizer_cls, learning_rate = recipes.get(name, (optim.RMSprop, 0.001))
    return optimizer_cls(model.parameters(), lr=learning_rate)


In [22]:
def train_model(activation, init_method, optimizer_name, epochs=5):
    """Train a freshly initialised SimpleCNN on ``train_loader`` and return it.

    Args:
        activation: activation name forwarded to ``SimpleCNN``.
        init_method: initialisation name forwarded to ``init_weights``.
        optimizer_name: optimizer name forwarded to ``get_optimizer``.
        epochs: number of full passes over ``train_loader``. Defaults to 5,
            the value that used to be hard-coded.

    Returns:
        The trained model (still in training mode), so it can be evaluated
        or compared afterwards instead of being thrown away.
    """
    model = SimpleCNN(activation).to(device)
    init_weights(model, init_method)

    optimizer = get_optimizer(optimizer_name, model)
    criterion = nn.CrossEntropyLoss()

    model.train()
    for _ in range(epochs):  # kept short on purpose (lab exercise)
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # Standard step: clear old gradients, forward, backward, update.
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

    print(f"Completed → Act:{activation}, Init:{init_method}, Opt:{optimizer_name}")
    return model


In [23]:
# (activation, init method, optimizer) combinations to train, in order.
experiments = [
    ('relu', 'xavier', 'adam'),
    ('tanh', 'kaiming', 'sgd'),
    ('leaky', 'random', 'rmsprop'),
]
for experiment in experiments:
    train_model(*experiment)


Completed → Act:relu, Init:xavier, Opt:adam
Completed → Act:tanh, Init:kaiming, Opt:sgd
Completed → Act:leaky, Init:random, Opt:rmsprop


In [24]:
# Load ResNet-18 with its ImageNet weights. `weights=` replaces the deprecated
# `pretrained=True` flag; IMAGENET1K_V1 is the checkpoint that flag selected.
resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
# Replace the final fully connected layer with a new, untrained 10-output layer.
resnet.fc = nn.Linear(resnet.fc.in_features, 10)
resnet = resnet.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(resnet.parameters(), lr=0.0001)

resnet.train()
for images, labels in train_loader:
    images, labels = images.to(device), labels.to(device)
    optimizer.zero_grad()
    outputs = resnet(images)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    break  # 1 batch only (demo)


In [25]:
def evaluate(model):
    """Measure classification accuracy of ``model`` on ``test_loader``.

    The model is switched to eval mode (and left there) and gradients are
    disabled while predicting. The predicted class of each sample is the
    index of its largest output.

    Args:
        model: module mapping a batch of images to per-class scores.

    Returns:
        Fraction of correctly classified samples as a float in [0, 1];
        0.0 when ``test_loader`` yields no samples. The same value is printed.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # An empty loader would otherwise raise ZeroDivisionError.
    accuracy = correct / total if total else 0.0
    print("Accuracy:", accuracy)
    return accuracy

# Measure the ResNet-18 on the test loader; it has only taken a single
# optimizer step (one batch) at this point, so expect near-chance accuracy.
evaluate(resnet)


Accuracy: 0.133
