In [1]:
import os

import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from torch.nn.utils import parameters_to_vector, vector_to_parameters

from local_models import SmallCNN, MiniResNet
from utils import evaluate

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Seed PyTorch's global RNGs (CPU and every CUDA device) so that the shuffled
# test batches and the Gaussian weight perturbations drawn below are
# reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

import copy

import pickle

cuda


In [2]:
# Load CIFAR-10 (train and test splits). Every image is upsampled to 224 px and
# normalised with the ImageNet per-channel means/stds.
# NOTE(review): this same transform is also fed to mini_cnn / mini_resnet —
# confirm those models were trained on 224-px, ImageNet-normalised inputs.
transform = transforms.Compose([
    transforms.Resize(224),  # needed for ImageNet-pretrained models
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Built in order: the train split first, then the test split.
trainset, testset = (
    torchvision.datasets.CIFAR10(root="./data", train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

# Test split in a fixed order.
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)

# Same test split in random order; this is the loader the loss-sampling cells evaluate with.
testloader_sampling = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True, num_workers=2)

In [13]:
def load_model(name, checkpoint_dir="./models"):
    """Build one of this notebook's CIFAR-10 classifiers and load its saved weights.

    The torchvision backbones (resnet18, resnet101, densenet121) are constructed
    with default arguments and get their final layer replaced by a 10-way
    ``torch.nn.Linear`` before the checkpoint is loaded. The project-local models
    (mini_cnn, mini_resnet) are loaded exactly as constructed.

    Args:
        name: One of "resnet18", "resnet101", "densenet121", "mini_cnn",
            "mini_resnet".
        checkpoint_dir: Directory containing the ``<name>_cifar10.pt`` state
            dicts. Defaults to "./models".

    Returns:
        The model with its state dict loaded. It is not moved to any device;
        callers do that with ``.to(device)``.

    Raises:
        ValueError: If ``name`` is not one of the supported models.
    """

    def _with_cifar_head(constructor, head_attr):
        # Swap the named classification layer for a 10-class linear head.
        model = constructor()
        head = getattr(model, head_attr)
        setattr(model, head_attr, torch.nn.Linear(head.in_features, 10))
        return model

    # Lambdas defer construction until the requested model is known.
    builders = {
        "resnet18": lambda: _with_cifar_head(models.resnet18, "fc"),
        "resnet101": lambda: _with_cifar_head(models.resnet101, "fc"),
        "densenet121": lambda: _with_cifar_head(models.densenet121, "classifier"),
        "mini_cnn": lambda: SmallCNN(),
        "mini_resnet": lambda: MiniResNet(),
    }
    if name not in builders:
        raise ValueError(f"Unknown model name {name!r}; expected one of {sorted(builders)}")

    model = builders[name]()

    # map_location="cpu": a checkpoint saved from a GPU run would otherwise be
    # materialised on that GPU first (and fail outright on a CPU-only machine),
    # even though load_state_dict copies it into the model's own parameters.
    state_dict = torch.load(
        f"{checkpoint_dir}/{name}_cifar10.pt", map_location="cpu", weights_only=True
    )
    model.load_state_dict(state_dict)

    return model

### Sample from the Loss Function

In [4]:
# First, derive a heuristic perturbation scale from the size of the weights:
#     epsilon = 1e-3 * ||w||_2 / n,   n = number of parameters.
# NOTE(review): epsilon is only printed here; the perturbation cells below
# hard-code a noise std of 5e-2 instead.

loss_fn = torch.nn.CrossEntropyLoss()

model = load_model("mini_cnn")
model.to(device)

# Flatten all parameters into one vector. detach(): only the values are needed,
# so keep this vector (and every clone of it made later) out of autograd.
weights_vec = parameters_to_vector(model.parameters()).detach()

norm = torch.linalg.vector_norm(weights_vec, ord=2).item()

n = weights_vec.shape[0]

epsilon = norm * 1e-3 / n

print(epsilon)

3.741587389611089e-07


In [5]:
# Baseline: accuracy and loss of the unperturbed model, evaluated with exactly
# the same evaluate() arguments that the perturbed models use below.
# NOTE(review): testloader_sampling shuffles without a fixed seed, so if
# evaluate() stops after num_epochs batches (presumably — confirm), the baseline
# and each perturbed run are scored on different test images.
baseline_acc, baseline_loss = evaluate(
    model=model, loader=testloader_sampling, loss_fn=loss_fn, device=device, num_epochs=20
)
print(baseline_acc, baseline_loss, sep="\n")

  0%|          | 0/157 [00:00<?, ?it/s]

0.26953125
0.030811649560928345


### Attempt 1: Add Gaussian noise to all (or a subset of) parameters

In [6]:
# Perturb every parameter of a copy of the trained model with i.i.d. Gaussian
# noise of standard deviation `noise_std`, then score it the same way as the baseline.
noise_std = 5e-2

# Create new model; only its parameters are overwritten below.
model_new = copy.deepcopy(model)
model_new.to(device)

# Load in perturbed weights. detach(): only the values are needed, so the
# perturbed vector stays out of autograd. randn_like matches the device/dtype
# of the weights automatically.
perturbed_weights = weights_vec.detach() + torch.randn_like(weights_vec) * noise_std
vector_to_parameters(perturbed_weights, model_new.parameters())

# Assess the new model against the old
perturbed_acc, perturbed_loss = evaluate(
    model=model_new,
    loader=testloader_sampling,
    loss_fn=loss_fn,
    device=device,
    num_epochs=20
)
print(perturbed_acc)
print(perturbed_loss)


  0%|          | 0/157 [00:00<?, ?it/s]

0.2078125
0.03365680696442723


### Scale Up Experimentation for Both Models (run the cell below once per model by setting `model_name`)

In [None]:
# Sample the loss surface around a trained model. The unperturbed weights are
# recorded first; then the full weight vector is repeatedly perturbed with
# Gaussian noise and evaluated until `results` holds `num_samples` entries.
# Each entry is (weights on the CPU, accuracy, loss).
model_name = "mini_resnet"
num_samples = 100  # total entries in `results`, *including* the unperturbed baseline
noise_std = 5e-2   # std of the Gaussian noise added to every parameter

results = []

model = load_model(model_name)
model.to(device)

# detach(): only the values are needed; without it this vector (and every
# perturbed copy built from it) would be tracked by autograd.
weights_vec = parameters_to_vector(model.parameters()).detach()
n = weights_vec.shape[0]

baseline_acc, baseline_loss = evaluate(
    model=model,
    loader=testloader_sampling,
    loss_fn=loss_fn,
    device=device,
    num_epochs=20
)

# Store every weight vector on the CPU, the baseline included, so the pickle has
# a uniform layout and can be loaded on a machine without a GPU.
results.append((weights_vec.cpu(), baseline_acc, baseline_loss))

# One reusable copy of the model; only its parameter values change per sample.
model_new = copy.deepcopy(model)
model_new.to(device)

while len(results) < num_samples:
    # Each sample perturbs the *trained* weights (not the previous sample), so
    # the draws are independent and centred on the trained model.
    new_weights_vec = weights_vec + torch.randn_like(weights_vec) * noise_std
    vector_to_parameters(new_weights_vec, model_new.parameters())

    # Assess the new model
    acc, loss = evaluate(
        model=model_new,
        loader=testloader_sampling,
        loss_fn=loss_fn,
        device=device,
        num_epochs=20
    )

    results.append((new_weights_vec.cpu(), acc, loss))
    print(len(results))

# Create the per-model output directory; otherwise the write below fails only
# after all the (expensive) sampling has finished.
out_dir = f"./data/loss_samples/{model_name}"
os.makedirs(out_dir, exist_ok=True)
with open(f"{out_dir}/loss_samples.pkl", 'wb') as file:
    pickle.dump(results, file)

In [16]:
# Re-write the collected samples to the same path the sampling cell uses.
# The directory is created first so the write cannot fail on a fresh checkout.
# Weight vectors are normalised to detached CPU tensors, so the file loads on
# machines without a GPU regardless of how `results` was built.
out_dir = f"./data/loss_samples/{model_name}"
os.makedirs(out_dir, exist_ok=True)
with open(f"{out_dir}/loss_samples.pkl", 'wb') as file:
    pickle.dump([(w.detach().cpu(), acc, loss) for w, acc, loss in results], file)