In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torchsummary import summary

# Use the GPU when PyTorch reports one as available, otherwise the CPU
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')

# Preprocessing pipelines: convert each image to a tensor, then apply
# Normalize with mean 0.5 and std 0.5 on each of the three channels.
# Training and test data currently get the same steps, but each split keeps
# its own pipeline object.
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Load both CIFAR-100 splits from ./data (download=True), each with its own
# preprocessing pipeline
cifar100_train = torchvision.datasets.CIFAR100(
    root='./data',
    train=True,
    download=True,
    transform=train_transform,
)
cifar100_test = torchvision.datasets.CIFAR100(
    root='./data',
    train=False,
    download=True,
    transform=test_transform,
)

# Batches of 32 with two worker processes; only the training split is shuffled
trainloader = torch.utils.data.DataLoader(
    cifar100_train,
    batch_size=32,
    shuffle=True,
    num_workers=2,
)
testloader = torch.utils.data.DataLoader(
    cifar100_test,
    batch_size=32,
    shuffle=False,
    num_workers=2,
)

# Look at the first training sample to get the per-image tensor shape, which
# is later handed to the model summary
X0, y0 = cifar100_train[0]
input_shape = X0.shape
print(f'input shape {input_shape}')

# Start from a ResNet-18 created with pretrained weights
# NOTE(review): the inputs above are normalised with mean/std 0.5 — confirm
# that this is the intended preprocessing for the pretrained weights.
resnet_model = torchvision.models.resnet18(pretrained=True)

# Replace the final fully connected layer with a fresh 100-output classifier
# (CIFAR-100 has 100 classes); its input width is taken from the old layer
fc_inputs = resnet_model.fc.in_features
resnet_model.fc = nn.Linear(in_features=fc_inputs, out_features=100)

# Move the whole model to the selected device
resnet_model = resnet_model.to(device)

# Print the torchsummary report for an input of shape `input_shape`
summary(resnet_model, input_shape)

# Cross-entropy loss, optimised with SGD plus momentum
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    resnet_model.parameters(),
    lr=1e-4,
    momentum=0.9,
)


# Factory for gradient hooks that periodically print a layer's gradient norm
def print_grad_norm(layer_name, counter, print_freq):
    """Build a tensor hook that reports the gradient norm of one layer.

    Args:
        layer_name: Label printed with every report.
        counter: Iteration number the returned hook starts counting from.
        print_freq: A report is printed whenever the hook's own iteration
            count is a multiple of this value.

    Returns:
        A callable taking the gradient tensor. It returns None, so the
        gradient itself is left unmodified.
    """
    calls_seen = counter  # private to this hook; the caller's variable is untouched

    def _report(grad):
        nonlocal calls_seen
        step = calls_seen
        if not step % print_freq:
            magnitude = grad.norm()
            print(f'Iteration {step}, Layer: {layer_name}, Gradient Norm: {magnitude.item()}')
        # Advance only after a (possible) report, so numbering starts at `counter`
        calls_seen = step + 1

    return _report

# Starting iteration number handed to every hook. print_grad_norm receives it
# by value, so each hook counts its own calls and this variable stays at 0.
counter = 0

# Register one gradient-norm hook per trainable parameter. The handles returned
# by register_hook are kept so the printing can be switched off later with
# `for handle in grad_hook_handles: handle.remove()`.
grad_hook_handles = []
for name, param in resnet_model.named_parameters():
    if param.requires_grad:
        hook = print_grad_norm(name, counter, print_freq=100)
        grad_hook_handles.append(param.register_hook(hook))

# Training the model
num_epochs = 10

# Read the batch size from the loader itself so the logging cadence below
# cannot drift from the DataLoader's configuration.
batch_size = trainloader.batch_size
# Number of samples between two loss printouts (i.e. every 10 batches)
print_freq = 10 * batch_size

for epoch in range(num_epochs):
    # Select training mode explicitly instead of relying on the model's current
    # state (the evaluation code switches the model to eval()).
    resnet_model.train()
    running_loss = 0.0

    progress = tqdm(trainloader, desc=f'Epoch {epoch + 1}/{num_epochs}', unit='batch')
    for i, (inputs, labels) in enumerate(progress):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Standard optimisation step: clear old gradients, forward pass,
        # compute the loss, backpropagate, update the weights.
        optimizer.zero_grad()
        outputs = resnet_model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Convert the loss to a Python float once and reuse it for both the
        # periodic printout and the running total.
        loss_value = loss.item()
        if (i * batch_size) % print_freq == 0:
            print(f'loss {loss_value}')

        running_loss += loss_value

    # Mean per-batch loss over the epoch
    print(f'Epoch {epoch + 1}, Loss: {running_loss / len(trainloader)}')


In [None]:
# Measure top-1 accuracy on the test loader
resnet_model.eval()  # switch the model to evaluation mode
correct, total = 0, 0

# No gradients are needed for evaluation
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = resnet_model(images)
        # Predicted class = index of the largest output along dimension 1
        predicted = torch.max(outputs, 1).indices
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy on the test set: {100 * correct / total}%')