# Computing cross-entropy for several models

In [9]:
%load_ext autoreload
%autoreload 2

import numpy as np
from matplotlib import pyplot as plt
import torch

from tqdm.notebook import tqdm, trange

from main import *

# Plot styling: render all figure text through LaTeX with a sans-serif face.
# The generic family name must be spelled 'sans-serif' (hyphenated); the
# unhyphenated spelling is not one of matplotlib's generic families, so the
# requested sans-serif face would not be applied.
params = {
    'text.usetex': True,
    'font.family': 'sans-serif',
}
# `plt.rcParams` is the same global settings object as `matplotlib.rcParams`;
# going through `plt` (imported explicitly above) avoids depending on the
# wildcard import from `main` to bring `matplotlib` into scope.
plt.rcParams.update(params)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Load energy models (on ImageNet color)

In [None]:
def load_exp(name, step="last", log=True, dataloaders=False):
    """Load the experiment stored in ``models/<name>`` and freeze its model for evaluation.

    Args:
        name: Name of the experiment directory under ``models/``. It must contain
            an ``args.json`` file with the keyword arguments of ``TrainingContext``.
        step: Which checkpoint to retrieve: an integer, "best", or "last" (default).
        log: If true, print the retrieved step and the test loss of the model.
        dataloaders: Forwarded as-is to ``TrainingContext``.

    Returns:
        The ``TrainingContext``, whose model is in eval mode with
        ``requires_grad`` disabled on all of its parameters.
    """
    exp_dir = Path("models") / name

    with open(exp_dir / "args.json") as f:
        args_dict = json.load(f)

    ctx = TrainingContext(**args_dict, step=step, key_remap=None, seed=None, dataloaders=dataloaders, writer=False)
    if log:
        print(f"{name}: retrieved model at step {ctx.step} and test loss {ctx.test_perf.loss:.2e}")

    # Disable DataParallel (needed for Hessian computation). Only unwrap when the
    # network actually exposes a wrapped `.module`, so that a network that was
    # never wrapped is kept as-is instead of raising AttributeError.
    network = ctx.model.network
    ctx.model.network = getattr(network, "module", network)

    # Put in eval mode and disable gradients with respect to all parameters.
    ctx.model.eval()
    for p in ctx.model.parameters():
        p.requires_grad = False

    return ctx

# Registry of the loaded experiments, keyed by a short label.
ctxs = dict(
    color=load_exp("finalclean_imagenet64_color_lr0.0005_1Msteps10decays", dataloaders=True),
    # Add other models here as needed.
)

# The "color" experiment provides the dataloaders and the dataset metadata
# used in the rest of the notebook.
default_ctx = ctxs["color"]
dataset_info = default_ctx.dataset_info
d = dataset_info.dimension

finalclean_imagenet64_color_lr0.0005_1Msteps10decays: retrieved model at step 1000000 and test loss 4.22e-01


## Compute cross-entropy/NLL

In [11]:
# Compare the energies (negative log-probabilities up to a constant) that the loaded models assign to clean images

def samples_energy(dataloader, t=0):
    """Evaluate the energy of every model in ``ctxs`` on all batches of ``dataloader``.

    Args:
        dataloader: Iterable of batches; the first element of each batch is
            taken as the batch of inputs.
        t: Noise level forwarded to ``model_input`` (default 0).

    Returns:
        A pair ``(xs, es)``: ``xs`` is the CPU tensor of all inputs concatenated
        along dimension 0, and ``es`` maps each key of ``ctxs`` to the CPU tensor
        of the corresponding energies, concatenated along dimension 0.
    """
    default_ctx.time_tracker.reset()

    xs = []
    # One list of per-batch energies for every loaded model.
    es = {k: [] for k in ctxs}
    for batch in tqdm(dataloader):
        x = batch[0]
        xs.append(x.cpu())
        clean = x.cuda()
        model_in = model_input(clean, noise_level=t)

        default_ctx.time_tracker.switch("forward")
        for k, c in ctxs.items():
            output = c.model(model_in, compute_scores=False, create_graph=False)  # (B,)
            # Detach before storing so that the accumulated energies never keep
            # an autograd graph (and its GPU buffers) alive across batches.
            es[k].append(output.energy.detach().cpu())

    print(default_ctx.time_tracker.pretty_print())

    xs = torch.cat(xs, dim=0)
    es = {k: torch.cat(e, dim=0) for k, e in es.items()}
    return xs, es

# Energies are in nats and summed over dimensions. Evaluate them on a
# 50,000-sample subset of the training set (we don't need the full dataset here)
# and on the test set.
train_dataloader = default_ctx.new_dataloader(train=True, batch_size=2048, num_samples=50_000)
imgs_train, energies_train = samples_energy(train_dataloader)
test_dataloader = default_ctx.new_dataloader(train=False, batch_size=2048)
imgs_test, energies_test = samples_energy(test_dataloader)

# Normalization constant of each model, from its average energy at a large noise level:
# the offset that makes the mean energy equal to d/2 * log(2 * pi * e * t), the entropy
# of a d-dimensional isotropic Gaussian of variance t.
# NOTE(review): this assumes `t` is a noise variance and that the noisy data is
# essentially Gaussian at this level -- confirm against `model_input`.
t = 1e3
test_dataloader = default_ctx.new_dataloader(train=False, batch_size=2048)
_, energies_t = samples_energy(test_dataloader, t=t)
gaussian_entropy = d/2 * np.log(2 * np.pi * np.e * t)
constants = {key: gaussian_entropy - energy.mean() for key, energy in energies_t.items()}

# Dividing by d * log(2) converts nats (summed over dimensions) to bits per dimension.
# NOTE(review): the offset of 8 presumably accounts for the 8-bit quantization of
# pixel values -- confirm.
nats_to_bits_per_dim = d * np.log(2)
for key in ctxs:
    mean_train = energies_train[key].mean()
    mean_test = energies_test[key].mean()
    nll_train = 8 + (mean_train + constants[key]) / nats_to_bits_per_dim
    nll_test = 8 + (mean_test + constants[key]) / nats_to_bits_per_dim
    print(f"{key}: train NLL {nll_train:.2f}, test NLL {nll_test:.2f}")

  0%|          | 0/25 [00:00<?, ?it/s]

- forward: 18s660ms


  0%|          | 0/25 [00:00<?, ?it/s]

- forward: 18s682ms


  0%|          | 0/25 [00:00<?, ?it/s]

- forward: 18s498ms
clean_lr0.0005: train NLL 3.33, test NLL 3.36


## Releasing GPU memory

In [None]:
# NOTE(review): this deliberately raises ZeroDivisionError. Presumably it either
# stops "Run All" before the cleanup cell below, or replaces the traceback that
# the kernel keeps from a previous failure (whose frames can keep GPU tensors
# alive) with a trivial one -- confirm the intent before removing or changing it.
1/0

ZeroDivisionError: division by zero

In [None]:
import gc
# Run a full garbage collection first, so that unreachable Python objects
# (and any tensors they own) are actually freed.
gc.collect()
# Then release the unused memory cached by PyTorch's CUDA allocator, so that it
# becomes available to other GPU applications.
torch.cuda.empty_cache()