In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models

from local_models import SmallCNN, MiniResNet

from tqdm.notebook import tqdm

from utils import evaluate

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


### Load and Evaluate Default Models

Note that ResNet and DenseNet are both explored in previous works examining the structure of loss functions. Pretrained models are available for both through pytorch. However, these are trained in the ImageNet1K dataset which is extremely large (1.2M images). It is tbd whether I load a small portion of ImageNet1K or fine-tune both models on CIFAR

In [2]:
# Load the CIFAR dataset

transform = transforms.Compose([
    transforms.Resize(224),  # needed for ImageNet-pretrained models
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

trainset = torchvision.datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)

testset = torchvision.datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transform,
)

trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2
)

testloader = torch.utils.data.DataLoader(
    testset, batch_size=64, shuffle=False, num_workers=2
)

In [3]:
# Load pretrained models
mini_cnn = SmallCNN()
mini_resnet = MiniResNet()

mini_cnn.to(device)
mini_resnet.to(device)

loss_fn = torch.nn.CrossEntropyLoss()


In [4]:
def train(model, loader, device, num_epochs=100):
    """ Train the models on the CIFAR10 dataset """

    optimizer = torch.optim.AdamW(model.parameters())

    # Set minority of weights to be trainable (linear maps)

    model.train()

    total_loss = 0

    # Fine-tune on the training set for a given number of epochs
    for epoch, (image, labels) in tqdm(enumerate(loader)):
        if epoch > num_epochs: break
        images, labels = image.to(device), labels.to(device)
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item() 

        if epoch % 10 == 0 and epoch != 0:
            print(f"Epoch {epoch}: Loss {total_loss}")
            total_loss = 0

In [5]:
# For both models, finetune, evaluate, and save
train(mini_cnn, trainloader, device, num_epochs=500)
mini_cnn_acc, mini_cnn_loss = evaluate(model=mini_cnn, loader=testloader, loss_fn=loss_fn, device=device)
print(mini_cnn_acc, mini_cnn_loss)
torch.save(mini_cnn.state_dict(), 'models/mini_cnn_cifar10.pt')

train(mini_resnet, trainloader, device, num_epochs=500)
mini_resnet_acc, mini_resnet_loss = evaluate(model=mini_resnet, loader=testloader, loss_fn=loss_fn, device=device)
print(mini_cnn_acc, mini_resnet_loss)
torch.save(mini_resnet.state_dict(), 'models/mini_resnet_cifar10.pt')

0it [00:00, ?it/s]

Epoch 10: Loss 25.139097213745117
Epoch 20: Loss 22.388288974761963
Epoch 30: Loss 21.888100147247314
Epoch 40: Loss 21.26663589477539
Epoch 50: Loss 21.004374980926514
Epoch 60: Loss 20.92983341217041
Epoch 70: Loss 20.80238914489746
Epoch 80: Loss 20.85324239730835
Epoch 90: Loss 20.848198890686035
Epoch 100: Loss 20.522931456565857
Epoch 110: Loss 20.42978525161743
Epoch 120: Loss 20.727654933929443
Epoch 130: Loss 20.437235832214355
Epoch 140: Loss 20.62094759941101
Epoch 150: Loss 20.524761080741882
Epoch 160: Loss 20.43184721469879
Epoch 170: Loss 20.329602479934692
Epoch 180: Loss 20.123048424720764
Epoch 190: Loss 20.05565905570984
Epoch 200: Loss 20.329110980033875
Epoch 210: Loss 20.355631709098816
Epoch 220: Loss 19.933977007865906
Epoch 230: Loss 20.16918122768402
Epoch 240: Loss 20.375797033309937
Epoch 250: Loss 19.726994514465332
Epoch 260: Loss 20.051657676696777
Epoch 270: Loss 19.936481833457947
Epoch 280: Loss 20.089476227760315
Epoch 290: Loss 20.18186628818512
Epoc

  0%|          | 0/157 [00:00<?, ?it/s]

0.2691 0.03112415976524353


0it [00:00, ?it/s]

Epoch 10: Loss 25.449487924575806
Epoch 20: Loss 22.923524141311646
Epoch 30: Loss 22.808921813964844
Epoch 40: Loss 22.746204614639282
Epoch 50: Loss 22.503289699554443
Epoch 60: Loss 21.9004008769989
Epoch 70: Loss 21.364984035491943
Epoch 80: Loss 21.068109035491943
Epoch 90: Loss 20.855353355407715
Epoch 100: Loss 21.000168085098267
Epoch 110: Loss 20.919428944587708
Epoch 120: Loss 20.678924083709717
Epoch 130: Loss 20.54681372642517
Epoch 140: Loss 20.364158868789673
Epoch 150: Loss 19.93267321586609
Epoch 160: Loss 20.4872328042984
Epoch 170: Loss 20.283712029457092
Epoch 180: Loss 20.318406462669373
Epoch 190: Loss 20.38431179523468
Epoch 200: Loss 20.507373332977295
Epoch 210: Loss 20.326703310012817
Epoch 220: Loss 19.98502016067505
Epoch 230: Loss 20.311936736106873
Epoch 240: Loss 20.28835391998291
Epoch 250: Loss 20.273168444633484
Epoch 260: Loss 20.25411605834961
Epoch 270: Loss 20.123975038528442
Epoch 280: Loss 20.153377532958984
Epoch 290: Loss 20.05859923362732
Epoch

  0%|          | 0/157 [00:00<?, ?it/s]

0.2691 0.031153733611106873
