In [5]:
from pathlib import Path

import torch
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.models import DenseNet121_Weights, ResNet101_Weights
from tqdm.notebook import tqdm

# Seed PyTorch's RNG up front: the training loader shuffles and the replacement
# classifier heads are randomly initialised, so without a seed the results
# differ on every Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


### Load and Evaluate Default Models

Note that ResNet and DenseNet are both explored in previous works examining the structure of loss functions. Pretrained models are available for both through PyTorch (torchvision). However, these are trained on the ImageNet-1K dataset, which is extremely large (1.2M images). It is still to be decided whether to evaluate on a small portion of ImageNet-1K or to fine-tune both models on CIFAR-10.

In [6]:
# Load the CIFAR dataset

# Preprocessing shared by both splits: upscale to 224 (needed for
# ImageNet-pretrained models), convert to a tensor, then normalise each channel.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Arguments common to the train and test splits; only `train` differs.
_cifar_kwargs = dict(root="./data", download=True, transform=transform)

trainset = torchvision.datasets.CIFAR10(train=True, **_cifar_kwargs)
testset = torchvision.datasets.CIFAR10(train=False, **_cifar_kwargs)

# Training batches are reshuffled on every pass; test batches keep a fixed order.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=64,
    shuffle=False,
    num_workers=2,
)

In [7]:
# Load pretrained models
densenet = models.densenet121(weights=DenseNet121_Weights.IMAGENET1K_V1)
resnet = models.resnet101(weights=ResNet101_Weights.IMAGENET1K_V2)

# Edit the output size: keep each head's input width, but emit 10 outputs.
densenet_in = densenet.classifier.in_features
resnet_in = resnet.fc.in_features
densenet.classifier = torch.nn.Linear(in_features=densenet_in, out_features=10)
resnet.fc = torch.nn.Linear(in_features=resnet_in, out_features=10)

# nn.Module.to moves the parameters in place and returns the same module,
# so the existing names keep pointing at the moved networks.
for net in (densenet, resnet):
    net.to(device)


In [8]:
def _take_batches(loader, limit):
    """Yield at most `limit` batches from `loader`, restarting it when it runs out.

    A fresh iterator is requested for every pass instead of caching batches
    (as itertools.cycle would), so no batch is kept alive between passes.
    Iteration stops early if a pass yields nothing, so an empty loader cannot
    spin forever, and no batch beyond `limit` is ever fetched.
    """
    served = 0
    while served < limit:
        pass_was_empty = True
        for batch in loader:
            pass_was_empty = False
            yield batch
            served += 1
            if served >= limit:
                return
        if pass_was_empty:
            return


def finetune(model, loader, device, num_epochs=100, type=None):
    """Train only the final classifier layer to align the model with CIFAR10.

    Every parameter outside the head is frozen and the model is kept in eval
    mode; only the head module is switched to train mode and optimised with
    AdamW (default hyper-parameters) against a cross-entropy loss.

    Args:
        model: network exposing its head as ``model.classifier``
            (``type="densenet"``) or ``model.fc`` (``type="resnet"``).
        loader: iterable yielding ``(images, labels)`` batches.
        device: device each batch is moved to before the forward pass.
        num_epochs: number of optimisation steps (mini-batches) to run.
            Despite the name this counts steps, not passes over the data;
            the loader is re-iterated as often as needed to reach the count.
        type: ``"densenet"`` or ``"resnet"``; selects which attribute holds
            the head. (Shadows the builtin; the name is kept for callers.)

    Raises:
        ValueError: if ``type`` is neither ``"densenet"`` nor ``"resnet"``.
    """
    # Validate first so a bad `type` fails before any model state is touched.
    if type == "densenet":
        head_attr = "classifier"
    elif type == "resnet":
        head_attr = "fc"
    else:
        raise ValueError(
            f"Unknown model type {type!r}; expected 'densenet' or 'resnet'"
        )

    # Set minority of weights to be trainable (linear maps)
    model.eval()
    for name, param in model.named_parameters():
        param.requires_grad = name.startswith(head_attr)
    getattr(model, head_attr).train()

    # Build the optimizer after freezing and give it only the trainable
    # parameters, so frozen weights can never be touched by weight decay.
    trainable = [param for param in model.parameters() if param.requires_grad]
    optimizer = torch.optim.AdamW(trainable)
    loss_fn = torch.nn.CrossEntropyLoss()

    window_loss = 0.0
    batches = _take_batches(loader, num_epochs)

    # `step` counts completed optimisation steps, starting at 1.
    for step, (images, labels) in enumerate(tqdm(batches, total=num_epochs), start=1):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        window_loss += loss.item()

        # Report the summed loss of each full window of 10 steps
        # ("Epoch" in the message is the step count).
        if step % 10 == 0:
            print(f"Epoch {step}: Loss {window_loss}")
            window_loss = 0.0


def evaluate(model, loader, device):
    """Return the fraction of samples in `loader` that `model` classifies correctly.

    The model is put in eval mode and run without gradient tracking. For each
    batch the predicted class is the argmax of the model output along dim 1,
    which is compared element-wise with the labels.

    Args:
        model: network to score; left in eval mode afterwards.
        loader: iterable yielding ``(images, labels)`` batches.
        device: device each batch is moved to before the forward pass.

    Returns:
        Number of correct predictions divided by the number of labels seen.
    """
    model.eval()

    n_hits = 0
    n_seen = 0

    with torch.no_grad():
        for batch_images, batch_labels in tqdm(loader):
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            predicted = torch.argmax(model(batch_images), dim=1)

            n_hits += predicted.eq(batch_labels).sum().item()
            n_seen += batch_labels.size(0)

    return n_hits / n_seen

In [None]:
# For both models, finetune, evaluate, and save

# torch.save does not create missing parent directories, so make sure the
# checkpoint folder exists before the (long) fine-tuning runs start.
CHECKPOINT_DIR = Path("models")
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

finetune(densenet, trainloader, device, num_epochs=500, type="densenet")
densenet_acc = evaluate(densenet, testloader, device)
print(densenet_acc)
torch.save(densenet.state_dict(), CHECKPOINT_DIR / "densenet_cifar10.pt")

finetune(resnet, trainloader, device, num_epochs=500, type="resnet")
resnet_acc = evaluate(resnet, testloader, device)
print(resnet_acc)
torch.save(resnet.state_dict(), CHECKPOINT_DIR / "resnet_cifar10.pt")

### Sample from the Loss Function

In [None]:
# Select random directions in weight space to select weights

# Select new weights from a normal distribution centered at the actual weights

# Evaluate the loss function at these points

# Save this set of points (we can use this for both modeling approaches)



### Train Geometric AutoEncoder to Approximate Loss Function

### Map Loss with Standard and Geometric AutoEncoder Approaches