In [48]:
import torch
import torchvision
import torchvision.models as models
import torch.optim as optim
import torchvision.transforms as transforms
import statistics as stat
import torch.nn as nn
import torch.nn.functional as F
from functools import wraps
from contextlib import contextmanager, _GeneratorContextManager
from copy import deepcopy

In [42]:
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        base = models.resnet18(pretrained=True)
        self.base = nn.Sequential(*list(base.children())[:-1])
        in_features = base.fc.in_features
        self.drop = nn.Dropout()
        self.final = nn.Linear(in_features,10)
    
    def forward(self,x):
        x = self.base(x)
        x = self.drop(x.view(-1,self.final.in_features))
        return self.final(x)
    
model = Model().cuda()

In [43]:
#Loading everything
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 64

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,shuffle=False, num_workers=2)

seed = 0

Files already downloaded and verified
Files already downloaded and verified


In [37]:
@contextmanager
def _perturbed_model(
  model,
  sigma: float = 1,
  rng = torch.Generator(),
  magnitude_eps = None
):
  device = next(model.parameters()).device
  if magnitude_eps is not None:
    noise = [torch.normal(0,sigma**2 * torch.abs(p) ** 2 + magnitude_eps ** 2, generator=rng) for p in model.parameters()]
  else:
    noise = [torch.normal(0,sigma**2,p.shape, generator=rng).to(device) for p in model.parameters()]
  model = deepcopy(model)
  try:
    [p.add_(n) for p,n in zip(model.parameters(), noise)]
    yield model
  finally:
    [p.sub_(n) for p,n in zip(model.parameters(), noise)]
    del model

In [57]:
def _pacbayes_sigma(
  model,
  dataloader,
  accuracy: float,
  seed: int,
  magnitude_eps = None,
  search_depth: int = 15,
  montecarlo_samples: int = 10,
  accuracy_displacement: float = 0.1,
  displacement_tolerance: float = 1e-2,
) -> float:
  lower, upper = 0, 2
  sigma = 1

  BIG_NUMBER = 10348628753
  device = next(model.parameters()).device
  rng = torch.Generator(device=device) if magnitude_eps is not None else torch.Generator()
  rng.manual_seed(BIG_NUMBER + seed)

  for _ in range(search_depth):
    sigma = (lower + upper) / 2
    accuracy_samples = []
    for _ in range(montecarlo_samples):
      with _perturbed_model(model, sigma, rng, magnitude_eps) as p_model:
        loss_estimate = 0
        for data, target in dataloader:
          logits = p_model(data.cuda())
          pred = logits.data.max(1, keepdim=True)[1]  # get the index of the max logits
          batch_correct = pred.eq(target.data.view_as(pred)).type(torch.FloatTensor).cpu()
          loss_estimate += batch_correct.sum()
        loss_estimate /= len(dataloader.dataset)
        accuracy_samples.append(loss_estimate)
    displacement = abs(np.mean(accuracy_samples) - accuracy)
    if abs(displacement - accuracy_displacement) < displacement_tolerance:
      break
    elif displacement > accuracy_displacement:
      # Too much perturbation
      upper = sigma
    else:
      # Not perturbed enough to reach target displacement
      lower = sigma
  return sigma

In [58]:
#Main:
def main():
    model.eval()

    #Calculates Model Accuracy
    cross_entropy_loss = 0
    num_correct = 0

    data_loader = [trainloader, testloader][0] #0 for Train, 1 for Test
    num_to_evaluate_on = len(data_loader.dataset)

    for data, target in data_loader:
        data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
        logits = model(data)
        cross_entropy = F.cross_entropy(logits, target, reduction='sum')
        cross_entropy_loss += cross_entropy.item()  # sum up batch loss
        
        pred = logits.data.max(1, keepdim=True)[1]  # get the index of the max logits
        batch_correct = pred.eq(target.data.view_as(pred)).type(torch.FloatTensor).cpu()
        num_correct += batch_correct.sum()

    cross_entropy_loss /= num_to_evaluate_on
    acc = num_correct.item() / num_to_evaluate_on #Get acc for sigma

    return cross_entropy_loss, acc, num_correct

In [59]:
@torch.no_grad()
def run():
    cross_entropy_loss, acc, num_correct = main()
    sigma = _pacbayes_sigma(model, testloader, acc, seed)
    flattness = torch.tensor(1 / sigma ** 2)
    print(flatness)

In [60]:
run()

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!