In [None]:
import os

# Relative paths used below (`checkpoint/...`, `sample/...`) and the local
# `model` import depend on the working directory. Set GLOW_WORKDIR to run from a
# different checkout; the default is the original cluster location.
os.chdir(os.environ.get("GLOW_WORKDIR", "/nfs/homedirs/ayle/guided-research/SNIP-it/glow"))

In [None]:
# !python train.py CELEBA --img_size 32 --channels 3 --batch 64 --prune_criterion SNIPit --pruning_limit 0.5 --local_pruning --checkpoint checkpoint/model_dataset=CELEBA_criterion=EmptyCrit_sparsity=0.0_local=False.pt  # --optim_checkpoint checkpoint/optim_dataset=train_criterion=EmptyCrit_sparsity=0.0_local=False.pt

In [None]:
import argparse
from math import log, pi, sqrt

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch import nn, optim
from torch.autograd import Variable, grad
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils
from tqdm import tqdm

from model import Glow
from glow.Johnit import Johnit
from glow.John import John
from glow.SNIPit import SNIPit
from glow.criterions.StructuredEFGit import StructuredEFGit
from glow.criterions.SNAP import SNAP
from glow.train import get_celeba_loaders

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch and numpy so latent samples, dequantisation noise and data
# shuffling are reproducible across runs.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

batch = 32          # images per batch
n_flow = 32         # passed to Glow as n_flow
n_block = 4         # passed to Glow as n_block (also sets the number of latent levels)
no_lu = False       # Glow is built with conv_lu=not no_lu
affine = False      # passed to Glow as affine=
n_bits = 5          # input pixels are quantised to n_bits bits
lr = 1e-5           # Adam learning rate
img_size = 32       # images are resized / cropped to img_size x img_size
channels = 3
temp = 0.7          # multiplier on the standard-normal latents used for samples
n_sample = 20       # number of latent samples per saved grid
iterations = 1000   # training iterations
n_bins = 2.0 ** n_bits

pruning_limit = 0.9     # passed to the pruning criterion and to prune()
local_pruning = True    # passed as prune(local=...)

In [None]:
def sample_data(path, batch_size, image_size):
    """Generate ``(images, labels)`` batches from an ImageFolder indefinitely.

    Images found under ``path`` are resized and centre-cropped to
    ``image_size``, randomly flipped horizontally and converted to tensors.
    Batches of ``batch_size`` are drawn in shuffled order; whenever one pass
    over the data is exhausted, a new shuffled DataLoader is started and
    iteration continues from it.
    """
    preprocess = transforms.Compose(
        [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    )
    image_folder = datasets.ImageFolder(path, transform=preprocess)

    def new_epoch():
        # A fresh DataLoader iterator starts a new, reshuffled pass.
        return iter(
            DataLoader(image_folder, shuffle=True, batch_size=batch_size, num_workers=4)
        )

    batches = new_epoch()
    while True:
        try:
            yield next(batches)
        except StopIteration:
            batches = new_epoch()
            yield next(batches)


def calc_z_shapes(n_channel, input_size, n_flow, n_block):
    """Return the ``(channels, height, width)`` shape of each latent level.

    Each of the first ``n_block - 1`` levels halves the spatial size and
    doubles the channel count. The final level halves the size once more and
    carries four times the channels of the level before it. ``n_flow`` is
    accepted for signature compatibility but is not used.
    """
    shapes = []
    level_channels, level_size = n_channel, input_size

    for _ in range(n_block - 1):
        level_channels, level_size = level_channels * 2, level_size // 2
        shapes.append((level_channels, level_size, level_size))

    final_size = level_size // 2
    shapes.append((level_channels * 4, final_size, final_size))
    return shapes


def calc_loss(log_p, logdet, image_size, n_bins, channels):
    """Express Glow's log-density terms in bits per dimension.

    Args:
        log_p: log-density term returned by the model (presumably one value
            per image — confirm against the model).
        logdet: log-determinant term returned by the model.
        image_size: side length of the square images.
        n_bins: number of quantisation bins per pixel value.
        channels: number of image channels.

    Returns:
        ``(nll_bits, log_p_bits, logdet_bits)``: batch means of the negative
        discretised log-likelihood, ``log_p`` and ``logdet``, each divided by
        ``log(2) * n_pixel`` where ``n_pixel = image_size**2 * channels``.
    """
    n_pixel = image_size * image_size * channels
    to_bits = log(2) * n_pixel

    # Each dimension occupies a bin of width 1 / n_bins, contributing
    # -log(n_bins) to the discretised log-likelihood.
    log_likelihood = -log(n_bins) * n_pixel + logdet + log_p

    nll_bits = (-log_likelihood / to_bits).mean()
    log_p_bits = (log_p / to_bits).mean()
    logdet_bits = (logdet / to_bits).mean()
    return nll_bits, log_p_bits, logdet_bits

In [None]:
# Unwrapped Glow flow, kept as `model_single` so sampling can call
# `model_single.reverse(...)` directly.
model_single = Glow(channels, n_flow, n_block, affine=affine, conv_lu=not no_lu)
# DataParallel wrapper used for forward passes, moved to the selected device.
model = nn.DataParallel(model_single).to(device)

In [None]:
# Checkpoint to load. Alternatives used in earlier runs:
# # model_path = "checkpoint/model_dataset=CELEBA_criterion=Johnit_sparsity=0.5_local=True.pt"
# model_path = "/nfs/students/ayle/guided-research/glow/checkpoints/model_criterion=EmptyCrit_sparsity=0.0_local=False.pt"
# # model_path = "/nfs/students/ayle/guided-research/glow/checkpoints/model_criterion=Johnit_sparsity=0.9_local=True.pt"
model_path = "checkpoint/model_dataset=CELEBA_criterion=EmptyCrit_sparsity=0.0_local=False.pt"

# map_location lets a checkpoint saved on GPU load when `device` has fallen
# back to the CPU.
model.load_state_dict(torch.load(model_path, map_location=device))

In [None]:
# Switch to evaluation mode; the trailing semicolon suppresses the long module repr.
model.eval();

In [None]:
# for name, module in model.module.named_modules():
#     if name + ".weight" in model.module.mask:
# #         torch.nn.init.kaiming_normal_(
# #                         module.weight.data, mode='fan_in', nonlinearity='relu'
# #                     )
#         torch.nn.init.kaiming_normal_(module.weight.data)

In [None]:
# for name, param in model.named_parameters():
#     print(name)
#     print((param == 0).float().sum() / torch.numel(param))

In [None]:
# path = "/nfs/students/ayle/guided-research/FASHION-jpg/training"

# dataset = iter(sample_data(path, batch, img_size))
# n_bins = 2.0 ** n_bits

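# Load training data

The FASHION and CELEBA cells below both assign `dataset`, so whichever one runs last is the data used for pruning and training.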

In [None]:
# FASHION train data: endless shuffled, randomly flipped batches from `path`.
path = "/nfs/students/ayle/guided-research/FASHION-jpg/training"
dataset = sample_data(path=path, batch_size=batch, image_size=img_size)

In [None]:
# CELEBA train data
# Running this cell replaces any `dataset` built by the FASHION cell above.
# NOTE(review): `get_celeba_loaders` is project code (glow.train). Its result is
# passed to `criterion.prune(train_loader=...)` and later wrapped with
# `iter(...)`. Confirm it yields `(image, label)` batches.
dataset = get_celeba_loaders('/nfs/students/ayle/guided-research/', batch, img_size)

In [None]:
# CIFAR10 train data
# NOTE(review): unlike the FASHION and CELEBA cells, this one only sets `path`
# and does not rebuild `dataset`. Add `dataset = sample_data(path, batch, img_size)`
# to actually train on CIFAR-10.
path = os.path.join("/nfs/students/ayle/guided-research", "CIFAR-10-images", "train")

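# Prune

Run only one of the three criterion cells below (CROPit, CROP or SNIPit). Each one calls `prune` on the same `model.module`, so running several applies the criteria one after another.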

In [None]:
# CROPit
# NOTE(review): the CROPit, CROP and SNIPit cells all prune the same
# `model.module`; presumably only one of them is meant to be run.
criterion = Johnit(
    limit=pruning_limit,
    model=model.module,
    generative=True,
    nbins=n_bins,
    img_size=img_size,
    channels=channels,
    loss_f=calc_loss,
)
criterion.prune(
    pruning_limit,
    train_loader=dataset,
    local=local_pruning,
)

In [None]:
# CROP
# NOTE(review): the CROPit, CROP and SNIPit cells all prune the same
# `model.module`; presumably only one of them is meant to be run.
criterion = John(
    limit=pruning_limit,
    model=model.module,
    generative=True,
    nbins=n_bins,
    img_size=img_size,
    channels=channels,
    loss_f=calc_loss,
)
criterion.prune(
    pruning_limit,
    train_loader=dataset,
    local=local_pruning,
)

In [None]:
# SNIPit
# NOTE(review): the CROPit, CROP and SNIPit cells all prune the same
# `model.module`; presumably only one of them is meant to be run.
criterion = SNIPit(
    limit=pruning_limit,
    model=model.module,
    generative=True,
    nbins=n_bins,
    img_size=img_size,
    channels=channels,
    loss_f=calc_loss,
)
criterion.prune(
    pruning_limit,
    train_loader=dataset,
    local=local_pruning,
)

# Train

In [None]:
# Switch to training mode; the trailing semicolon suppresses the long module repr.
model.train();

In [None]:
# Adam over every parameter of the DataParallel-wrapped model.
optimizer = optim.Adam(params=model.parameters(), lr=lr)
# The training loop calls next(dataset), so make sure it is an iterator
# (iter() returns generators such as sample_data's unchanged).
dataset = iter(dataset)

In [None]:
# TRAINING

# Fixed latents (standard normal scaled by `temp`) that are decoded into the
# saved sample grids, so grids from different iterations are comparable.
z_sample = []
z_shapes = calc_z_shapes(channels, img_size, n_flow, n_block)
for z in z_shapes:
    z_new = torch.randn(n_sample, *z) * temp
    z_sample.append(z_new.to(device))

# Sample grids are written into sample/; create it so save_image can write there.
os.makedirs("sample", exist_ok=True)


def save_samples(step):
    """Decode `z_sample` with the unwrapped model and save it as sample/<step + 1>.png."""
    with torch.no_grad():
        utils.save_image(
            model_single.reverse(z_sample).cpu().data,
            f"sample/{str(step + 1).zfill(6)}.png",
            normalize=True,
            nrow=10,
            range=(-0.5, 0.5),
        )


with tqdm(range(iterations)) as pbar:
    for i in pbar:
        image, _ = next(dataset)
        image = image.to(device) * 255

        # Reduce pixel values to n_bits bits, then shift into [-0.5, 0.5).
        if n_bits < 8:
            image = torch.floor(image / 2 ** (8 - n_bits))
        image = image / n_bins - 0.5

        # NOTE(review): presumably re-zeroes pruned weights via the masks.
        model.module.apply_weight_mask()

        # Uniform dequantisation noise of width 1 / n_bins.
        noisy_image = image + torch.rand_like(image) / n_bins

        if i == 0:
            # First batch: forward pass through the unwrapped module only, with
            # no optimiser step (presumably for data-dependent initialisation
            # of the flow), then save an initial sample grid.
            with torch.no_grad():
                log_p, logdet, _ = model.module(noisy_image)
            save_samples(i)
            continue

        log_p, logdet, _ = model(noisy_image)
        logdet = logdet.mean()

        loss, log_p, log_det = calc_loss(log_p, logdet, img_size, n_bins, channels=channels)
        model.zero_grad()
        loss.backward()
        # warmup_lr = args.lr * min(1, i * batch_size / (50000 * 10))
        warmup_lr = lr
        optimizer.param_groups[0]["lr"] = warmup_lr
        optimizer.step()

        # Re-apply the masks after the update so pruned weights stay pruned.
        model.module.apply_weight_mask()

        pbar.set_description(
            f"Loss: {loss.item():.5f}; logP: {log_p.item():.5f}; logdet: {log_det.item():.5f}; lr: {warmup_lr:.7f}"
        )

        if i % 10 == 0:
            save_samples(i)

# Evaluation

In [None]:
# Switch to evaluation mode; the trailing semicolon suppresses the long module repr.
model.eval();

In [None]:
# FASHION test set.
# Deterministic preprocessing and order (resize + centre crop, no shuffling),
# matching the CELEBA test-set cell. `sample_data` would add random horizontal
# flips and shuffling.
path = "/nfs/students/ayle/guided-research/FASHION-jpg/testing"
fashion_test_set = datasets.ImageFolder(
    path,
    transform=transforms.Compose(
        [
            transforms.Resize(img_size),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
        ]
    ),
)
dataset = iter(DataLoader(fashion_test_set, shuffle=False, batch_size=batch, num_workers=4))
len_dataset = len(fashion_test_set)

In [None]:
# CIFAR10 test set.
# Deterministic preprocessing and order (resize + centre crop, no shuffling),
# matching the CELEBA test-set cell. `sample_data` would add random horizontal
# flips and shuffling.
path = "/nfs/students/ayle/guided-research/CIFAR-10-images/test"
cifar_test_set = datasets.ImageFolder(
    path,
    transform=transforms.Compose(
        [
            transforms.Resize(img_size),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
        ]
    ),
)
dataset = iter(DataLoader(cifar_test_set, shuffle=False, batch_size=batch, num_workers=4))
len_dataset = len(cifar_test_set)

In [None]:
# CELEBA test set: deterministic resize + centre crop, no shuffling.
test_transform = transforms.Compose(
    [
        transforms.Resize(img_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
    ]
)
test_set = datasets.CelebA(
    '/nfs/students/ayle/guided-research/',
    split='test',
    download=True,
    transform=test_transform,
)
dataset = iter(DataLoader(test_set, shuffle=False, batch_size=batch, num_workers=4))
# Reuse the dataset object instead of constructing a second CelebA instance
# just to count its samples.
len_dataset = len(test_set)

In [None]:
###### EVALUATION
# Collect per-batch `log_p` on the selected in-distribution test set.
# Only full batches are used (len_dataset // batch of them).

n_bins = 2.0 ** n_bits

cum_log_p = []

# no_grad for every batch: evaluation needs no autograd graph.
with torch.no_grad(), tqdm(range(len_dataset // batch)) as pbar:
    for i in pbar:
        image, _ = next(dataset)
        image = image.to(device) * 255

        # Same quantisation / shift as training (no dequantisation noise here).
        if n_bits < 8:
            image = torch.floor(image / 2 ** (8 - n_bits))
        image = image / n_bins - 0.5

        # The first batch goes through the unwrapped module, mirroring the
        # training loop's first step.
        net = model.module if i == 0 else model
        log_p, logdet, _ = net(image)
        # NOTE(review): only `log_p` is collected, while calc_loss also adds
        # `logdet` to form the likelihood. Confirm whether the histogram should
        # use log_p + logdet instead.
        cum_log_p.append(log_p.cpu().numpy())

        logdet = logdet.mean()

        loss, log_p, log_det = calc_loss(log_p, logdet, img_size, n_bins, channels=channels)

        pbar.set_description(
            f"Loss: {loss.item():.5f}; logP: {log_p.item():.5f}; Avg logP: {np.mean(cum_log_p).item():.5f}; logdet: {log_det.item():.5f}"
        )

## Out-of-distribution (OOD) data

In [None]:
# SVHN test split as OOD data.
# Resize + centre crop to `img_size` like the other test sets, so the cell also
# works when img_size is not 32 (SVHN's native size).
svhn_transform = transforms.Compose(
    [
        transforms.Resize(img_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
    ]
)
test_set = datasets.SVHN(
    '/nfs/students/ayle/guided-research/gitignored/data',
    split='test',
    download=True,
    transform=svhn_transform,
)
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=batch,
    shuffle=False,
    pin_memory=True,
    num_workers=4,
)

dataset = iter(test_loader)
len_dataset = len(test_set)

In [None]:
# MNIST test set (OOD).
# Deterministic preprocessing and order (resize + centre crop, no shuffling),
# matching the CELEBA test-set cell. `sample_data` would add random horizontal
# flips and shuffling.
path = "/nfs/students/ayle/guided-research/MNIST-jpg/testing"
mnist_test_set = datasets.ImageFolder(
    path,
    transform=transforms.Compose(
        [
            transforms.Resize(img_size),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
        ]
    ),
)
dataset = iter(DataLoader(mnist_test_set, shuffle=False, batch_size=batch, num_workers=4))
len_dataset = len(mnist_test_set)

In [None]:
###### EVALUATION
# Collect per-batch `log_p` on the selected OOD test set.
# Only full batches are used (len_dataset // batch of them).
n_bins = 2.0 ** n_bits

ood_cum_log_p = []

# no_grad for every batch: evaluation needs no autograd graph.
with torch.no_grad(), tqdm(range(len_dataset // batch)) as pbar:
    for i in pbar:
        image, _ = next(dataset)
        image = image.to(device) * 255

        # Same quantisation / shift as training (no dequantisation noise here).
        if n_bits < 8:
            image = torch.floor(image / 2 ** (8 - n_bits))
        image = image / n_bins - 0.5

        # The first batch goes through the unwrapped module, mirroring the
        # training loop's first step.
        net = model.module if i == 0 else model
        log_p, logdet, _ = net(image)
        # NOTE(review): only `log_p` is collected, while calc_loss also adds
        # `logdet` to form the likelihood. Confirm whether the histogram should
        # use log_p + logdet instead.
        ood_cum_log_p.append(log_p.cpu().numpy())

        logdet = logdet.mean()

        loss, log_p, log_det = calc_loss(log_p, logdet, img_size, n_bins, channels=channels)

        pbar.set_description(
            f"Loss: {loss.item():.5f}; logP: {log_p.item():.5f}; Avg logP: {np.mean(ood_cum_log_p).item():.5f}; logdet: {log_det.item():.5f}"
        )

# Plot

In [None]:
# Flatten the per-batch arrays into one array. The isinstance guard keeps the
# cell safe to re-run: concatenating the already-flat array again would fail.
if isinstance(cum_log_p, list):
    cum_log_p = np.concatenate(cum_log_p)

In [None]:
# Flatten the per-batch arrays into one array. The isinstance guard keeps the
# cell safe to re-run: concatenating the already-flat array again would fail.
if isinstance(ood_cum_log_p, list):
    ood_cum_log_p = np.concatenate(ood_cum_log_p)

In [None]:
# Overlay the per-image log_p distributions of the in-distribution and OOD test
# sets. Both histograms are density-normalised, so sets of different sizes are
# comparable.
fig, ax = plt.subplots()
ax.hist(cum_log_p, density=True, bins=100, histtype='stepfilled', label='In-dist', alpha=0.7)
ax.hist(ood_cum_log_p, density=True, bins=100, histtype='stepfilled', label='OOD', alpha=0.7)
ax.set(xlabel="log_p per image", ylabel="density", title="In-distribution vs. OOD log_p")
ax.legend()
plt.show()

In [None]:
# Report, for every entry in the model's mask dict, the fraction of zero
# entries (that layer's sparsity).
for layer_name, mask_tensor in model.module.mask.items():
    print(layer_name)
    print((mask_tensor == 0).float().sum() / mask_tensor.numel())