In [27]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import matplotlib.pyplot as plt
from tqdm import tqdm
import math
import pickle
from argparse import Namespace
import os
import time
import model

torch.manual_seed(1)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  # Random crop of size 32x32 with padding of 4 pixels
    transforms.RandomHorizontalFlip(),  # Randomly flip the image horizontally
    transforms.ToTensor()  # Convert the image to a tensor
])

trainset = torchvision.datasets.CIFAR10(root="./data/CIFAR10", train=True, download=True, transform=transform)

testset = torchvision.datasets.CIFAR10(root="./data/CIFAR10", train=False, download=True, transform=transform)


def train(args):
    subdir = os.path.join(f"./observations/{args.name}", args.folder)
    if not os.path.exists(subdir):
        os.makedirs(subdir)

    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=args.batch_size, shuffle=True
    )
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=args.batch_size, shuffle=False
    )

    net = getattr(model, args.name)(args.hidden_dim, args.depth)
    net = net.to(device)

    print("Number of parameters:", sum([p.numel() for p in net.parameters()]))

    criterion = nn.CrossEntropyLoss()
    optimizer = getattr(optim, args.optimizer)(net.parameters(), lr=args.lr, weight_decay=0.0001)
    scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor=1, end_factor=0, total_iters=len(trainloader)*args.epochs) # total_updates = (trainset / batch_size) * num_epochs
    train_scores = []
    test_scores = []
    avg_test_losses = []
    average_gradients = []
    average_parameters = []
    num_test_batches = math.ceil(10000 / args.batch_size)
    time_start = time.time()
    for epoch in range(args.epochs):
        for data in tqdm(trainloader):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            outputs = net(inputs)
            loss = criterion(outputs, labels)
            train_scores.append(loss.item())
            loss.backward()
            optimizer.step()
            scheduler.step()
            avg_grad_magnitude = 0.0
            avg_param_magnitude = 0.0
            total_parameters = 0
            for name, param in net.named_parameters():
                avg_param_magnitude += param.data.abs().sum().item()
                avg_grad_magnitude += param.grad.abs().sum().item()
                total_parameters += param.numel()
            avg_grad_magnitude /= total_parameters
            avg_param_magnitude /= total_parameters
            average_gradients.append(avg_grad_magnitude)
            average_parameters.append(avg_param_magnitude)
        with torch.no_grad():
            for data in testloader:
                images, labels = data
                images = images.to(device)
                labels = labels.to(device)
                outputs = net(images)
                l_test = criterion(outputs, labels)
                test_scores.append(l_test.item())

        avg_test_loss = sum(test_scores[-num_test_batches:]) / len(
            test_scores[-num_test_batches:]
        )

        avg_test_losses.append(avg_test_loss)
        print(avg_test_loss)

    time_end = time.time()
    fig, (ax1, ax2) = plt.subplots(1, 2)
    ax1.plot(average_gradients)
    ax2.plot(average_parameters)
    ax1.set_title("Average Gradient Magnitude")
    ax2.set_title("Average Parameter Magnitude")
    fig.savefig(f"observations/{args.name}/{args.folder}/{args.optimizer}{args.depth}.png")   # save the figure to file
    plt.close(fig)    # close the figure window

    running_sum = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images = images.to(device)
            labels = labels.to(device)
            outputs = torch.nn.functional.log_softmax(net(images), dim=1)
            running_sum += (outputs.argmax(dim=1) != labels).sum().item()

    print("Error %: ", (running_sum / len(testset)))

    if args.save:
        with open(
            f"observations/{args.name}/{args.folder}/train_scores_b{args.batch_size}dr{args.dropout}lr{args.lr}d{args.depth}w{args.hidden_dim}",
            "wb",
        ) as f:
            pickle.dump(train_scores, f)
            f.close()
        with open(
            f"observations/{args.name}/{args.folder}/test_scores_b{args.batch_size}dr{args.dropout}lr{args.lr}d{args.depth}w{args.hidden_dim}",
            "wb",
        ) as f:
            pickle.dump(test_scores, f)
            f.close()
    if args.log:
        with open(f"observations/{args.name}/{args.folder}/analytics.txt", "a") as f:
            f.write(
                f"batch size: {args.batch_size}, lr: {args.lr}, hidden dim: {args.hidden_dim}, depth: {args.depth}, params: {sum([p.numel() for p in net.parameters()])}, dropout: {args.dropout}, loss: {min(avg_test_losses)}, error %: {running_sum / len(testset)}, time: {time_end - time_start}, epochs: {args.epochs}\n"
            )
            f.close()


for optimizer in ["Adam", "AdamW"]:
    for depths in [5, 9]:
        train(
            Namespace(
                name="ResNet",
                epochs=15,
                batch_size=128,
                lr=0.01,
                optimizer=optimizer,
                hidden_dim=16,
                depth=depths,
                dropout=0,
                save=True,
                log=True,
                folder="joshgradient",
            )
        )

Files already downloaded and verified
Files already downloaded and verified
Number of parameters: 470058


100%|██████████| 391/391 [00:40<00:00,  9.60it/s]


1.495748074748848


100%|██████████| 391/391 [00:49<00:00,  7.89it/s]


1.3187243847907344


100%|██████████| 391/391 [00:43<00:00,  8.99it/s]


1.2139592781851563


100%|██████████| 391/391 [00:45<00:00,  8.68it/s]


1.148245943498008


100%|██████████| 391/391 [00:41<00:00,  9.51it/s]


1.0838414211816425


100%|██████████| 391/391 [00:43<00:00,  9.05it/s]


1.0099486210678197


100%|██████████| 391/391 [00:42<00:00,  9.15it/s]


0.9609569484674478


100%|██████████| 391/391 [00:43<00:00,  9.02it/s]


0.9706235701524759


100%|██████████| 391/391 [00:44<00:00,  8.88it/s]


0.9374853010419049


100%|██████████| 391/391 [00:42<00:00,  9.19it/s]


0.8705701307405399


100%|██████████| 391/391 [00:44<00:00,  8.80it/s]


0.8629575779166403


100%|██████████| 391/391 [00:46<00:00,  8.40it/s]


0.8115251592442959


100%|██████████| 391/391 [00:48<00:00,  8.14it/s]


0.7944017678876466


100%|██████████| 391/391 [00:48<00:00,  8.07it/s]


0.7681537376174444


100%|██████████| 391/391 [00:49<00:00,  7.95it/s]


0.7555655391910409
Error %:  0.2569
Number of parameters: 859818


100%|██████████| 391/391 [00:59<00:00,  6.59it/s]


1.6515271241151834


100%|██████████| 391/391 [01:04<00:00,  6.06it/s]


1.416055929811695


100%|██████████| 391/391 [01:03<00:00,  6.14it/s]


1.2459617430650736


100%|██████████| 391/391 [01:00<00:00,  6.46it/s]


1.166567654549321


100%|██████████| 391/391 [01:03<00:00,  6.19it/s]


1.0896788714807244


100%|██████████| 391/391 [01:02<00:00,  6.30it/s]


1.0611165263984776


100%|██████████| 391/391 [01:00<00:00,  6.44it/s]


1.0293714479555058


100%|██████████| 391/391 [01:03<00:00,  6.19it/s]


0.9712739051142826


100%|██████████| 391/391 [01:03<00:00,  6.19it/s]


0.9330684497386594


100%|██████████| 391/391 [01:11<00:00,  5.44it/s]


0.906228703034075


100%|██████████| 391/391 [01:08<00:00,  5.70it/s]


0.8788764325878288


100%|██████████| 391/391 [01:11<00:00,  5.48it/s]


0.845771459839012


100%|██████████| 391/391 [01:23<00:00,  4.70it/s]


0.8043798261050936


100%|██████████| 391/391 [01:20<00:00,  4.88it/s]


0.770236040217967


100%|██████████| 391/391 [01:18<00:00,  5.01it/s]


0.750797238530992
Error %:  0.2558
Number of parameters: 470058


100%|██████████| 391/391 [00:43<00:00,  8.92it/s]


1.5001700573329684


100%|██████████| 391/391 [00:42<00:00,  9.25it/s]


1.3129295231420783


100%|██████████| 391/391 [00:42<00:00,  9.27it/s]


1.158948670459699


100%|██████████| 391/391 [00:42<00:00,  9.13it/s]


1.0763975540293922


100%|██████████| 391/391 [00:42<00:00,  9.16it/s]


1.055736378023896


100%|██████████| 391/391 [00:41<00:00,  9.34it/s]


0.9586033926734442


100%|██████████| 391/391 [00:44<00:00,  8.87it/s]


0.9240797260139562


100%|██████████| 391/391 [00:42<00:00,  9.22it/s]


0.8929239488855193


100%|██████████| 391/391 [00:41<00:00,  9.38it/s]


0.8917643888087212


100%|██████████| 391/391 [00:41<00:00,  9.44it/s]


0.8517005594470833


100%|██████████| 391/391 [00:40<00:00,  9.56it/s]


0.8268860887877548


100%|██████████| 391/391 [00:41<00:00,  9.42it/s]


0.796318042127392


100%|██████████| 391/391 [00:41<00:00,  9.51it/s]


0.7609626243386087


100%|██████████| 391/391 [00:41<00:00,  9.38it/s]


0.7716523303261286


100%|██████████| 391/391 [00:41<00:00,  9.36it/s]


0.7378110074544255
Error %:  0.2567
Number of parameters: 859818


100%|██████████| 391/391 [01:02<00:00,  6.29it/s]


1.6525506520573097


100%|██████████| 391/391 [01:07<00:00,  5.78it/s]


1.487200110773497


100%|██████████| 391/391 [01:23<00:00,  4.66it/s]


1.4063940395282795


100%|██████████| 391/391 [01:12<00:00,  5.40it/s]


1.3121332065968574


100%|██████████| 391/391 [01:06<00:00,  5.84it/s]


1.2077163970923122


100%|██████████| 391/391 [01:05<00:00,  5.95it/s]


1.157616428936584


100%|██████████| 391/391 [01:06<00:00,  5.85it/s]


1.0986458417735523


100%|██████████| 391/391 [01:01<00:00,  6.39it/s]


1.056782824329183


100%|██████████| 391/391 [01:04<00:00,  6.10it/s]


1.0087277398833745


100%|██████████| 391/391 [01:04<00:00,  6.08it/s]


0.9744593048397499


100%|██████████| 391/391 [01:07<00:00,  5.81it/s]


0.9471317620217046


100%|██████████| 391/391 [01:07<00:00,  5.83it/s]


0.9066041308113292


100%|██████████| 391/391 [01:01<00:00,  6.33it/s]


0.8797875884213026


100%|██████████| 391/391 [01:03<00:00,  6.16it/s]


0.8570018146611467


100%|██████████| 391/391 [01:07<00:00,  5.79it/s]


0.8401700809032102
Error %:  0.2904


In [None]:
# After loss.backward, the gradients are stored in the .grad attribute of each layer