In [1]:
import os
os.chdir("/nfs/homedirs/ayle/guided-research/SNIP-it/glow")

In [2]:
from tqdm import tqdm
import numpy as np
from PIL import Image
from math import log, sqrt, pi

import argparse

import torch
from torch import nn, optim
from torch.autograd import Variable, grad
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils

from model import Glow
from glow.Johnit import Johnit
from glow.SNIPit import SNIPit
from glow.SNIP import SNIP
from glow.criterions.StructuredEFGit import StructuredEFGit
from glow.criterions.SNAP import SNAP
from glow.train import get_celeba_loaders

from copy import deepcopy
from utils.metrics import calculate_aupr, calculate_auroc

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch = 32
n_flow = 32
n_block = 4
no_lu = False
affine = False
n_bits = 5
lr = 1e-5
img_size = 32
channels = 3
temp = 0.7
n_sample = 20
iterations = 1000

pruning_limit = 0.3
local_pruning = False

In [4]:
def sample_data(path, batch_size, image_size):
    transform = transforms.Compose(
        [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    )

    dataset = datasets.ImageFolder(path, transform=transform)
    loader = DataLoader(dataset, shuffle=True, batch_size=batch_size, num_workers=4)
    loader = iter(loader)

    while True:
        try:
            yield next(loader)

        except StopIteration:
            loader = DataLoader(
                dataset, shuffle=True, batch_size=batch_size, num_workers=4
            )
            loader = iter(loader)
            yield next(loader)


def calc_z_shapes(n_channel, input_size, n_flow, n_block):
    z_shapes = []

    for i in range(n_block - 1):
        input_size //= 2
        n_channel *= 2

        z_shapes.append((n_channel, input_size, input_size))

    input_size //= 2
    z_shapes.append((n_channel * 4, input_size, input_size))

    return z_shapes


def calc_loss(log_p, logdet, image_size, n_bins, channels):
    # log_p = calc_log_p([z_list])
    n_pixel = image_size * image_size * channels

    loss = -log(n_bins) * n_pixel
    loss = loss + logdet + log_p

    return (
        (-loss / (log(2) * n_pixel)).mean(),
        (log_p / (log(2) * n_pixel)).mean(),
        (logdet / (log(2) * n_pixel)).mean(),
    )

In [37]:
model_single = Glow(
    channels, n_flow, n_block, affine=affine, conv_lu=not no_lu
)
model = nn.DataParallel(model_single)
model = model.to(device)

# model.load_state_dict(torch.load("checkpoint/model_dataset=CELEBA_criterion=Johnit_sparsity=0.5_local=True.pt"))
model.load_state_dict(torch.load("/nfs/students/ayle/guided-research/glow/checkpoints/model_dataset=train_criterion=EmptyCrit_sparsity=0.0_local=False.pt"))

<All keys matched successfully>

In [6]:
backup_model = deepcopy(model)

In [39]:
# CELEBA 
# dataset = iter(get_celeba_loaders('/nfs/students/ayle/guided-research/', batch, img_size))
# len_dataset = len(datasets.CelebA(
#             '/nfs/students/ayle/guided-research/',
#             split='train',
#             download=True
#         ))

# CIFAR10
transform = transforms.Compose(
        [
            transforms.Resize(img_size),
            transforms.CenterCrop(img_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    )
path  ="/nfs/students/ayle/guided-research/CIFAR-10-images/train"
dataset = datasets.ImageFolder(path, transform=transform)
loader = DataLoader(dataset, shuffle=True, batch_size=batch, num_workers=4)
dataset = iter(loader)

# compute scores
n_bins = 2.0 ** n_bits

criterion = SNIP(limit=pruning_limit, model=model.module, generative=True, nbins=n_bins, img_size=img_size, channels=channels, loss_f=calc_loss)
criterion.prune(pruning_limit, train_loader=get_celeba_loaders('/nfs/students/ayle/guided-research/', batch, img_size), local=local_pruning)

Files already downloaded and verified


In [40]:
orig_grads = criterion.grads_abs

In [41]:
# backup_model = deepcopy(model)

In [42]:
batch = 1

In [44]:
# CIFAR10
transform = transforms.Compose(
        [
            transforms.Resize(img_size),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
        ]
    )
path  ="/nfs/students/ayle/guided-research/CIFAR-10-images/test"
dataset = datasets.ImageFolder(path, transform=transform)
dataset = DataLoader(dataset, shuffle=True, batch_size=batch, num_workers=4)

# CELEBA test set
test_transform = transforms.Compose(
        [
            transforms.Resize(img_size),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
        ]
    )
test_set = datasets.CelebA(
    '/nfs/students/ayle/guided-research/',
    split='test',
    download=True,
    transform=test_transform
)
dataset = DataLoader(test_set, shuffle=False, batch_size=batch, num_workers=4)

# compute scores
n_bins = 2.0 ** n_bits

norms = []
for i, (x, y) in enumerate(tqdm(dataset)):
    if i == 100:
        break
    model = deepcopy(backup_model)
    
    model.eval()
    
    criterion = SNIP(limit=pruning_limit, model=model.module, generative=True, nbins=n_bins, img_size=img_size, channels=channels, loss_f=calc_loss)
    criterion.prune(pruning_limit, train_loader=[(x, y)], local=local_pruning)
    
    layer_norms = []
    for grad1, grad2 in zip(orig_grads.values(), criterion.grads_abs.values()):
        layer_norms.append(torch.norm(grad1 - grad2, p=5).cpu().numpy())
    norms.append(np.mean(layer_norms))

Files already downloaded and verified


  1%|          | 100/19962 [02:37<8:41:59,  1.58s/it]


In [45]:
# SVHN data loader
transformers = transforms.Compose([transforms.ToTensor()
                                  ])
ood_test_set = datasets.SVHN(
        '/nfs/homedirs/ayle/guided-research/SNIP-it/gitignored/data',
        split='test',
        download=True,
        transform=transformers
    )
ood_dataset = torch.utils.data.DataLoader(
        ood_test_set,
        batch_size=batch,
        shuffle=False,
        pin_memory=True,
        num_workers=4
    )

ood_norms = []
for i, (x, y) in enumerate(tqdm(ood_dataset)):
    if i == 100:
        break
    
    model.eval()
    
    criterion = SNIP(limit=pruning_limit, model=model.module, generative=True, nbins=n_bins, img_size=img_size, channels=channels, loss_f=calc_loss)
    criterion.prune(pruning_limit, train_loader=[(x, y)], local=local_pruning)
    
    layer_norms = []
    for grad1, grad2 in zip(orig_grads.values(), criterion.grads_abs.values()):
        layer_norms.append(torch.norm(grad1 - grad2, p=5).cpu().numpy())
    ood_norms.append(np.mean(layer_norms))

Using downloaded and verified file: /nfs/homedirs/ayle/guided-research/SNIP-it/gitignored/data/test_32x32.mat


  0%|          | 100/26032 [02:14<9:37:09,  1.34s/it]Traceback (most recent call last):
  File "/nfs/homedirs/ayle/miniconda3/envs/gr/lib/python3.8/multiprocessing/queues.py", line 245, in _feed
    send_bytes(obj)
  File "/nfs/homedirs/ayle/miniconda3/envs/gr/lib/python3.8/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/nfs/homedirs/ayle/miniconda3/envs/gr/lib/python3.8/multiprocessing/connection.py", line 411, in _send_bytes
    self._send(header + buf)
  File "/nfs/homedirs/ayle/miniconda3/envs/gr/lib/python3.8/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
  0%|          | 100/26032 [02:14<9:42:47,  1.35s/it]


In [46]:
np.mean(norms)

0.019479753

In [48]:
norms = np.array(norms)
ood_norms = np.array(ood_norms)

In [49]:
calculate_auroc(np.concatenate((np.zeros_like(norms), np.ones_like(ood_norms))), np.concatenate((norms, ood_norms)))

0.8903