In [11]:
import os
from itertools import islice

import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.models import DenseNet121_Weights, ResNet101_Weights, ResNet18_Weights

from tqdm.notebook import tqdm

from utils import evaluate

# Seed torch's global RNG before any model or DataLoader is created, so the randomly
# initialised classification heads and DataLoader shuffling repeat across
# Restart & Run All (GPU kernels may still introduce some nondeterminism).
SEED = 0
torch.manual_seed(SEED)

# Prefer the GPU when one is visible to PyTorch.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


### Load and Evaluate Default Models

Note that ResNet and DenseNet have both been studied in previous work examining the structure of loss functions. Pretrained weights for both are available through PyTorch (`torchvision.models`). However, these weights were trained on ImageNet-1K, which is extremely large (~1.2M training images). It is still to be decided whether to load a small portion of ImageNet-1K or to fine-tune both models on CIFAR-10. The cells below take the fine-tuning route, retraining only the final classification layer.

In [7]:
# CIFAR-10 train/test splits, prepared for ImageNet-pretrained backbones:
# images are upsampled to 224px and normalised with the ImageNet channel statistics.
transform = transforms.Compose(
    [
        transforms.Resize(224),  # needed for ImageNet-pretrained models
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Same constructor for both splits; only the `train` flag differs (train split built first).
trainset, testset = (
    torchvision.datasets.CIFAR10(root="./data", train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Only the training loader shuffles; evaluation order stays fixed.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)

In [None]:
# ImageNet-pretrained backbones. They are all constructed before any head is replaced,
# so the random head initialisations draw from the RNG in a fixed order.
densenet121 = models.densenet121(weights=DenseNet121_Weights.IMAGENET1K_V1)
resnet101 = models.resnet101(weights=ResNet101_Weights.IMAGENET1K_V2)
resnet18 = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)

# Swap each 1000-way ImageNet head for a fresh 10-way CIFAR-10 head.
# DenseNet names its head `classifier`; ResNet names it `fc`.
densenet121.classifier = torch.nn.Linear(in_features=densenet121.classifier.in_features, out_features=10)
resnet101.fc = torch.nn.Linear(in_features=resnet101.fc.in_features, out_features=10)
resnet18.fc = torch.nn.Linear(in_features=resnet18.fc.in_features, out_features=10)

densenet121, resnet101, resnet18 = (net.to(device) for net in (densenet121, resnet101, resnet18))

loss_fn = torch.nn.CrossEntropyLoss()


In [None]:
def finetune(model, loader, device, num_epochs=100, type=None, loss_fn=None, log_every=10):
    """Fine-tune only the final classification layer of a pretrained model.

    Every parameter outside the head (``model.classifier`` for DenseNet,
    ``model.fc`` for ResNet) is frozen. The rest of the model is left in eval
    mode, so layers such as BatchNorm do not update their running statistics.

    Args:
        model: Model exposing a ``classifier`` (DenseNet) or ``fc`` (ResNet)
            head attribute. Assumed to already live on ``device``.
        loader: Iterable of ``(images, labels)`` mini-batches.
        device: Device each batch is moved to before the forward pass.
        num_epochs: Despite the name, the number of optimisation steps
            (mini-batches). Training also stops at the end of a single pass
            over ``loader`` if that comes first.
        type: ``"densenet"`` or ``"resnet"``; selects which head is trained.
        loss_fn: Loss criterion. Defaults to ``torch.nn.CrossEntropyLoss()``.
        log_every: Print the mean loss of the last ``log_every`` steps at this
            interval. ``0`` disables logging.

    Returns:
        List of per-step training losses (Python floats).

    Raises:
        ValueError: If ``type`` is not ``"densenet"`` or ``"resnet"``.
    """
    head_attrs = {"densenet": "classifier", "resnet": "fc"}
    # Validate before touching the model so a bad `type` leaves it unmodified.
    if type not in head_attrs:
        raise ValueError(f"type must be one of {sorted(head_attrs)}, got {type!r}")
    head_name = head_attrs[type]

    criterion = loss_fn if loss_fn is not None else torch.nn.CrossEntropyLoss()

    # Freeze everything except the head. Matching on "<head>." avoids also
    # unfreezing parameters that merely share the prefix (e.g. "fc2.weight").
    for name, param in model.named_parameters():
        param.requires_grad = name.startswith(head_name + ".")

    # Keep the backbone in eval mode and put only the head in train mode.
    model.eval()
    getattr(model, head_name).train()

    # Build the optimizer after freezing, so it only tracks the trainable head parameters.
    optimizer = torch.optim.AdamW(p for p in model.parameters() if p.requires_grad)

    num_steps = max(num_epochs, 0)
    try:
        total = min(num_steps, len(loader))
    except TypeError:  # loaders over iterable-style datasets may not define a length
        total = num_steps

    step_losses = []
    window_loss = 0.0
    # islice yields exactly `num_steps` batches (fewer if the loader is shorter).
    for step, (images, labels) in enumerate(tqdm(islice(loader, num_steps), total=total), start=1):
        images, labels = images.to(device), labels.to(device)
        loss = criterion(model(images), labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        step_losses.append(loss.item())
        window_loss += step_losses[-1]
        if log_every and step % log_every == 0:
            print(f"Step {step}: mean loss {window_loss / log_every:.4f}")
            window_loss = 0.0

    return step_losses

In [None]:
# Fine-tune ResNet-18's head on CIFAR-10, evaluate it on the test set, and save the weights.
# `num_epochs` counts mini-batches, and finetune stops at the end of one pass over
# `trainloader` (~391 batches of 128), so 500 here amounts to a single epoch.
finetune(resnet18, trainloader, device, num_epochs=500, type="resnet")
resnet18_acc, resnet18_loss = evaluate(model=resnet18, loader=testloader, loss_fn=loss_fn, device=device)
print(resnet18_acc, resnet18_loss)
os.makedirs("models", exist_ok=True)  # torch.save does not create missing parent directories
torch.save(resnet18.state_dict(), 'models/resnet18_cifar10.pt')

  0%|          | 0/157 [00:00<?, ?it/s]

(0.8537, 0.007023197868466377)


In [None]:
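# # For the remaining models (DenseNet-121, ResNet-101), finetune, evaluate, and save
# # NOTE(review): names updated to the variables defined above (`densenet121`, `resnet101`),
# # and `evaluate` is called with keywords and unpacked into (acc, loss), as in the resnet18 cell.
# # `resnet` presumably meant resnet101; it is saved under its own filename so resnet18's weights are not overwritten.
# finetune(densenet121, trainloader, device, num_epochs=500, type="densenet")
# densenet121_acc, densenet121_loss = evaluate(model=densenet121, loader=testloader, loss_fn=loss_fn, device=device)
# print(densenet121_acc, densenet121_loss)
# torch.save(densenet121.state_dict(), 'models/densenet121_cifar10.pt')

# finetune(resnet101, trainloader, device, num_epochs=500, type="resnet")
# resnet101_acc, resnet101_loss = evaluate(model=resnet101, loader=testloader, loss_fn=loss_fn, device=device)
# print(resnet101_acc, resnet101_loss)
# torch.save(resnet101.state_dict(), 'models/resnet101_cifar10.pt')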