In [1]:
from resnet import ResNet18
import torchvision
import torch
from torch.utils.data import DataLoader
import argparse
import time
from lab2 import get_cifar10_dataloaders, train

In [None]:
# Mini-batch sizes handed to get_cifar10_dataloaders for the train and test
# loaders, plus the local directory intended for the CIFAR-10 data
# (presumably used as its download_path -- confirm against lab2).
TRAIN_BATCH_SIZE, TEST_BATCH_SIZE = 100, 128
DOWNLOAD_PATH = './data'

In [None]:
# Run configuration. Inside a notebook, sys.argv belongs to the Jupyter kernel,
# so argparse.parse_args() cannot be used here; build the namespace directly.
# NOTE(review): these values are placeholders -- confirm they match the lab2
# command-line defaults before comparing benchmark numbers.
args = argparse.Namespace(
    data_download_path=DOWNLOAD_PATH,
    optimizer='sgd',        # one of: sgd, adam, adagrad, rmsprop, adadelta
    lr=0.1,
    momentum=0.9,           # used by sgd and rmsprop
    weight_decay=5e-4,
    nesterov=False,         # sgd only; torch requires momentum > 0 for nesterov
    epsilon=1e-8,           # used by adam, adagrad, rmsprop, adadelta
    epochs=5,
    enable_torch_profiling=False,
    verbose=False,
)

# Fall back to CPU when no GPU is present so the cell still runs.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

num_workers = 0

train_loader, test_loader = get_cifar10_dataloaders(
    TRAIN_BATCH_SIZE, num_workers, TEST_BATCH_SIZE, num_workers,
    download_path=args.data_download_path,
)

mod = ResNet18(include_batch_norm_layers=True).to(device)


def build_optimizer(params, args):
    """Construct the torch optimizer selected by ``args.optimizer``.

    Args:
        params: Iterable of parameters to optimize (e.g. ``mod.parameters()``).
        args: Namespace providing ``optimizer``, ``lr``, ``momentum``,
            ``weight_decay``, ``nesterov`` and ``epsilon``.

    Returns:
        A ``torch.optim.Optimizer`` instance.

    Raises:
        ValueError: If ``args.optimizer`` is not one of the supported names.
    """
    name = args.optimizer
    if name == 'sgd':
        return torch.optim.SGD(params, lr=args.lr, momentum=args.momentum,
                               weight_decay=args.weight_decay, nesterov=args.nesterov)
    if name == 'adam':
        return torch.optim.Adam(params, lr=args.lr, eps=args.epsilon,
                                weight_decay=args.weight_decay)
    if name == 'adagrad':
        return torch.optim.Adagrad(params, lr=args.lr, eps=args.epsilon,
                                   weight_decay=args.weight_decay)
    if name == 'rmsprop':
        return torch.optim.RMSprop(params, lr=args.lr, eps=args.epsilon,
                                   momentum=args.momentum, weight_decay=args.weight_decay)
    if name == 'adadelta':
        return torch.optim.Adadelta(params, lr=args.lr, eps=args.epsilon,
                                    weight_decay=args.weight_decay)
    # Fail loudly instead of silently falling back to Adadelta on a typo.
    raise ValueError(
        f"Unknown optimizer {name!r}; expected one of "
        "'sgd', 'adam', 'adagrad', 'rmsprop', 'adadelta'"
    )


optim = build_optimizer(mod.parameters(), args)
loss_func = torch.nn.CrossEntropyLoss()

dl_times = []
train_times = []
metrics_times = []
epoch_times = []

for epoch in range(args.epochs):
    epoch_start_time = time.perf_counter()
    # train() is expected to return this 6-tuple in this order.
    (epoch_loss, epoch_total, epoch_correct,
     epoch_dl_time, epoch_train_time, epoch_metrics_time) = train(
        train_loader, epoch, mod, optim, loss_func, device,
        args.enable_torch_profiling, verbose=args.verbose,
    )
    dl_times.append(epoch_dl_time)
    train_times.append(epoch_train_time)
    metrics_times.append(epoch_metrics_time)
    epoch_time = time.perf_counter() - epoch_start_time
    epoch_times.append(epoch_time)

    # Guard against an empty epoch so the summary never divides by zero.
    top1_accuracy = epoch_correct / epoch_total if epoch_total else 0.0
    print(f'[EPOCH {epoch+1} SUMMARY] DL Time: {epoch_dl_time}, Train Time: {epoch_train_time}, Metrics Time: {epoch_metrics_time}, Total Running Time: {epoch_time}, Training Loss: {epoch_loss}, Top-1 Accuracy: {top1_accuracy}')
    print()

print(f'[BENCHMARKING SUMMARY] Total DL Time: {sum(dl_times)}, Total Train Time: {sum(train_times)}, Total Metrics Time: {sum(metrics_times)}, Total Runtime Across All Epochs: {sum(epoch_times)}')