In [20]:
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models

from utils import evaluate

from torch.nn.utils import parameters_to_vector, vector_to_parameters

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

import copy

cuda


In [2]:
# Load the CIFAR dataset

transform = transforms.Compose([
    transforms.Resize(224),  # needed for ImageNet-pretrained models
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

trainset = torchvision.datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)

testset = torchvision.datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transform,
)

trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2
)

testloader = torch.utils.data.DataLoader(
    testset, batch_size=64, shuffle=False, num_workers=2
)

testloader_sampling = torch.utils.data.DataLoader(
    testset, batch_size=64, shuffle=True, num_workers=2
)

### Sample from the Loss Function

In [16]:
# First figure out an appropriate scale of pertubation based on scale of parameters (heuristic)

loss_fn = torch.nn.CrossEntropyLoss()

densenet = models.densenet121()
densenet.classifier = torch.nn.Linear(densenet.classifier.in_features, 10)
densenet.load_state_dict(torch.load("./models/densenet_cifar10.pt", weights_only=True))
densenet.to(device)

weights_vec = parameters_to_vector(densenet.parameters())

norm = torch.linalg.vector_norm(weights_vec, ord=2).item()

n = weights_vec.shape[0]

epsilon = norm * 1e-3 / n

print(epsilon)

1.2337409745894137e-08


In [4]:
# Get a baseline
baseline_acc, baseline_loss = evaluate(
    model=densenet, 
    loader=testloader_sampling, 
    loss_fn=loss_fn, 
    device=device, 
    num_epochs=20
)
print(baseline_acc)
print(baseline_loss)

  0%|          | 0/157 [00:00<?, ?it/s]

0.87421875
0.006091292761266231


### Attempt 1: Add Gaussian noise to all (or a subset of) parameters

In [28]:
# Create new model
densenet_temp = copy.deepcopy(densenet)
densenet_temp.to(device)

# Load in perturbed weights 
perturbed_weights = weights_vec.clone()
perturbed_weights += torch.randn(size = (n, ), device=device) * 1e-2

vector_to_parameters(perturbed_weights, densenet_temp.parameters())

# Assess the new model against the old
perturbed_acc, perturbed_loss = evaluate(
    model=densenet_temp, 
    loader=testloader_sampling, 
    loss_fn=loss_fn,
    device=device,
    num_epochs=20
)
print(perturbed_acc)
print(perturbed_loss)


  0%|          | 0/157 [00:00<?, ?it/s]

0.26640625
0.04571987185627222


In [None]:
print(perturbed_weights)

In [None]:
print(weights_vec)

### Attempt 2: Rescale parameters relative to initial size

In [14]:
# Set up new model
perturbed_densenet = models.densenet121()
perturbed_densenet.classifier = torch.nn.Linear(perturbed_densenet.classifier.in_features, 10)
perturbed_densenet.to(device)

# Load in perturbed weights 
k = int(n * 0.00001)
idx = torch.randperm(n)[:k].to(device)
perturbed_weights = weights_vec.clone()
# perturbed_weights[idx] *= 1 + torch.randn(size = (k, ), device=device) * 1e-8 * (1/(n-idx))

vector_to_model(perturbed_weights, perturbed_densenet)

# Assess the new model against the old
perturbed_acc, perturbed_loss = evaluate(
    model=perturbed_densenet, 
    loader=testloader_sampling, 
    loss_fn=loss_fn,
    device=device, 
    num_epochs=20
)
print(perturbed_acc)
print(perturbed_loss)

  0%|          | 0/157 [00:00<?, ?it/s]

0.08828125
0.03890266362577677
