In [1]:
import os
# Switch the working directory to the Glow project checkout so that the relative
# references used later in this notebook (the `model` import, "checkpoint/...",
# "sample/...") are resolved from there.
# NOTE(review): absolute, user-specific path — adjust when running elsewhere.
os.chdir("/nfs/homedirs/ayle/guided-research/glow")

In [2]:
from tqdm import tqdm
import numpy as np
from PIL import Image
from math import log, sqrt, pi

import argparse

import torch
from torch import nn, optim
from torch.autograd import Variable, grad
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils

from model import Glow

In [3]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Notebook-wide settings read by the cells below.
batch = 16  # batch size handed to sample_data(...)
n_flow = 32  # passed to Glow(...) and calc_z_shapes(...) as n_flow
n_block = 4  # passed to Glow(...) and calc_z_shapes(...) as n_block
no_lu = False  # Glow is built with conv_lu=not no_lu
affine = False  # forwarded to Glow(..., affine=affine)
n_bits = 5  # images are quantized to n_bins = 2 ** n_bits levels
lr = 1e-5  # learning rate used for the Adam optimizer
img_size = 32  # images are resized and center-cropped to this size
channels = 3  # channel count given to Glow, calc_z_shapes and calc_loss
temp = 0.7  # scale applied to the random latents used for sampling
n_sample = 20  # number of latent samples drawn per latent shape
iterations = 5  # number of iterations of the training loop

In [4]:
def sample_data(path, batch_size, image_size):
    """Yield shuffled image batches from an image folder, forever.

    Images under ``path`` are read with ``datasets.ImageFolder``; each one is
    resized, center-cropped to ``image_size``, randomly flipped horizontally
    and converted to a tensor. Whenever one pass over the data is exhausted a
    new shuffled ``DataLoader`` is created, so the generator never terminates
    on its own.

    Args:
        path: root directory handed to ``datasets.ImageFolder``.
        batch_size: batch size given to the ``DataLoader``.
        image_size: size used for both the resize and the center crop.

    Yields:
        One ``DataLoader`` batch at a time (the callers in this notebook
        unpack it as ``image, _``).
    """
    preprocessing = transforms.Compose(
        [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    )
    image_folder = datasets.ImageFolder(path, transform=preprocessing)

    def shuffled_pass():
        # A fresh loader per pass, configured exactly the same way every time.
        return DataLoader(
            image_folder, shuffle=True, batch_size=batch_size, num_workers=4
        )

    # Cycle through the data indefinitely, one shuffled pass after another.
    while True:
        for batch_items in shuffled_pass():
            yield batch_items


def calc_z_shapes(n_channel, input_size, n_flow, n_block):
    """Return the ``(channels, height, width)`` shape of each latent tensor.

    Every block except the last halves the spatial size and doubles the
    channel count. The last block halves the spatial size once more and uses
    four times the channel count reached so far.

    Args:
        n_channel: channel count of the input image.
        input_size: side length of the (square) input image.
        n_flow: accepted for signature compatibility; not used here.
        n_block: number of blocks, i.e. number of shapes returned.

    Returns:
        A list of ``(channels, size, size)`` tuples, one per block.
    """
    shapes = []
    width, side = n_channel, input_size

    # All blocks but the final one.
    for _ in range(n_block - 1):
        side = side // 2
        width = width * 2
        shapes.append((width, side, side))

    # Final block.
    side = side // 2
    shapes.append((width * 4, side, side))

    return shapes


def calc_loss(log_p, logdet, image_size, n_bins, channels):
    """Convert log-likelihood terms into per-dimension values in bits.

    The total log-likelihood is ``-log(n_bins) * n_pixel + logdet + log_p``
    with ``n_pixel = image_size * image_size * channels``. Each returned
    quantity is divided by ``log(2) * n_pixel`` and averaged with ``.mean()``.

    Args:
        log_p: log-probability term (tensor).
        logdet: log-determinant term (tensor).
        image_size: side length of the square image.
        n_bins: number of quantization bins.
        channels: number of image channels.

    Returns:
        A tuple ``(loss, log_p, logdet)``: the negated total log-likelihood,
        the ``log_p`` term and the ``logdet`` term, all rescaled as above.
    """
    n_pixel = image_size * image_size * channels
    bits_scale = log(2) * n_pixel

    # Discretization term first, then logdet, then log_p.
    log_likelihood = -log(n_bins) * n_pixel + logdet + log_p

    bits_per_dim = (-log_likelihood / bits_scale).mean()
    log_p_bits = (log_p / bits_scale).mean()
    logdet_bits = (logdet / bits_scale).mean()

    return bits_per_dim, log_p_bits, logdet_bits

In [5]:
# Build the Glow network from the notebook settings; the LU-decomposed
# invertible convolution is enabled unless `no_lu` is set.
model_single = Glow(
    channels, n_flow, n_block, affine=affine, conv_lu=not no_lu
)
# The checkpoint is loaded into the DataParallel wrapper, so wrap first.
model = nn.DataParallel(model_single)
model = model.to(device)

# map_location=device lets a checkpoint saved from a CUDA device load on a
# CPU-only machine as well (device falls back to "cpu" when CUDA is missing);
# without it torch.load tries to restore tensors onto their original device.
model.load_state_dict(
    torch.load("checkpoint/model_042001.pt", map_location=device)
)

<All keys matched successfully>

In [6]:
# Global magnitude pruning: zero out the `percentage` fraction of conv weights
# with the smallest absolute value, using ONE threshold shared by all layers.
# NOTE(review): because the threshold is global, layers whose weights are
# small overall lose far more than `percentage` of their entries (see the
# per-layer zero fractions printed by the next cell) — confirm this is intended.
percentage = 0.2

# One mask per nn.Conv2d weight, keyed by the parameter name reported by
# model.module.named_parameters(). Starts as all-ones (keep everything).
mask = {
    name + ".weight": torch.ones_like(module.weight.data).to(device)
    for name, module in model.module.named_modules()
    if isinstance(module, nn.Conv2d)
}
names = list(mask.keys())

# Thresholding and weight editing need no autograd graph; running them under
# no_grad avoids building one over every conv weight in the model.
with torch.no_grad():
    all_weights = torch.cat(
        [torch.flatten(x) for name, x in model.module.named_parameters() if name in mask]
    )
    count = len(all_weights)
    amount = int(count * percentage)

    # With amount == 0 there is nothing to prune (and topk(...).values[-1]
    # would index an empty tensor), so the all-ones masks are kept as they are.
    if amount > 0:
        # Largest magnitude among the `amount` smallest ones.
        limit = torch.topk(all_weights.abs(), amount, largest=False).values[-1]

        for name, weights in model.module.named_parameters():
            if name in mask:
                # Keep only entries strictly above the threshold (L1 magnitude).
                mask[name] = weights.abs() > limit

    # Apply the masks in place.
    for name, tensor in model.module.named_parameters():
        if name in mask:
            tensor.data *= mask[name]

In [7]:
# For every masked parameter, print its name followed by the fraction of its
# entries that are exactly zero.
for name, tensor in model.module.named_parameters():
    if name not in mask:
        continue

    values = tensor.data
    n_zero = torch.sum(values == 0).to(torch.float32)
    zero_fraction = n_zero / torch.numel(values)

    print(name)
    print(zero_fraction.item())

blocks.0.flows.0.coupling.net.0.weight
0.02484809048473835
blocks.0.flows.0.coupling.net.2.weight
0.026142120361328125
blocks.0.flows.0.coupling.net.4.conv.weight
0.9449508190155029
blocks.0.flows.1.coupling.net.0.weight
0.026005497202277184
blocks.0.flows.1.coupling.net.2.weight
0.026203155517578125
blocks.0.flows.1.coupling.net.4.conv.weight
0.9575737714767456
blocks.0.flows.2.coupling.net.0.weight
0.0272352434694767
blocks.0.flows.2.coupling.net.2.weight
0.026401519775390625
blocks.0.flows.2.coupling.net.4.conv.weight
0.9479166865348816
blocks.0.flows.3.coupling.net.0.weight
0.0271267369389534
blocks.0.flows.3.coupling.net.2.weight
0.026180267333984375
blocks.0.flows.3.coupling.net.4.conv.weight
0.9428168535232544
blocks.0.flows.4.coupling.net.0.weight
0.028139468282461166
blocks.0.flows.4.coupling.net.2.weight
0.025722503662109375
blocks.0.flows.4.coupling.net.4.conv.weight
0.9519675970077515
blocks.0.flows.5.coupling.net.0.weight
0.025209780782461166
blocks.0.flows.5.coupling.net.

In [8]:
###### TRAINING
# Fine-tune the pruned model for `iterations` steps. After every optimizer
# step the pruning masks are re-applied so pruned weights stay at zero, and a
# grid of samples generated from fixed latents is written to "sample/".

path = "/nfs/students/ayle/guided-research/FASHION-jpg/training"
dataset = iter(sample_data(path, batch, img_size))
# Number of quantization levels per channel value.
n_bins = 2.0 ** n_bits

# Fixed latents (one tensor per latent shape, scaled by `temp`) that are
# decoded after every step, so the saved samples are comparable across steps.
z_sample = []
z_shapes = calc_z_shapes(channels, img_size, n_flow, n_block)
for z in z_shapes:
    z_new = torch.randn(n_sample, *z) * temp
    z_sample.append(z_new.to(device))


optimizer = optim.Adam(model.parameters(), lr=lr)

with tqdm(range(iterations)) as pbar:
    for i in pbar:
        # Labels are discarded; only the images are used.
        image, _ = next(dataset)
        image = image.to(device)

        # Assumes pixel values arrive in [0, 1] (ToTensor output): scale to 0..255.
        image = image * 255

        # Drop the low-order bits so only n_bits of precision remain.
        if n_bits < 8:
            image = torch.floor(image / 2 ** (8 - n_bits))

        # Rescale by the number of bins and shift by -0.5.
        image = image / n_bins - 0.5

        if i == 0:
            # The first iteration only runs a forward pass on the unwrapped
            # model, without gradients, and then skips the rest of the loop
            # body: no parameter update and no sample image for i == 0.
            # NOTE(review): looks like a data-dependent initialization pass —
            # confirm it is still needed when starting from a checkpoint.
            with torch.no_grad():
                log_p, logdet, _ = model.module(
                    image + torch.rand_like(image) / n_bins
                )

                continue

        else:
            # Uniform noise of one bin width is added to the quantized input.
            log_p, logdet, _ = model(image + torch.rand_like(image) / n_bins)

        logdet = logdet.mean()

        # From here on log_p / log_det hold the rescaled values returned by calc_loss.
        loss, log_p, log_det = calc_loss(log_p, logdet, img_size, n_bins, channels=channels)
        model.zero_grad()
        loss.backward()
        # Learning-rate warm-up is disabled; the constant `lr` is used instead.
        # warmup_lr = args.lr * min(1, i * batch_size / (50000 * 10))
        warmup_lr = lr
        optimizer.param_groups[0]["lr"] = warmup_lr
        optimizer.step()

        pbar.set_description(
            f"Loss: {loss.item():.5f}; logP: {log_p.item():.5f}; logdet: {log_det.item():.5f}; lr: {warmup_lr:.7f}"
        )

        #############
        # Re-apply the pruning masks: the optimizer step above may have moved
        # pruned weights away from zero.
        with torch.no_grad():

            for name, tensor in model.module.named_parameters():
                if name in mask:
                    tensor.data *= mask[name]
        ############

        # `i % 1 == 0` is always true, so a sample grid is saved every iteration.
        if i % 1 == 0:
            with torch.no_grad():
                # NOTE(review): newer torchvision releases name the `range`
                # argument `value_range` — confirm against the installed version.
                utils.save_image(
                    model_single.reverse(z_sample).cpu().data,
                    f"sample/{str(i + 1).zfill(6)}.png",
                    normalize=True,
                    nrow=10,
                    range=(-0.5, 0.5),
                )

Loss: 4.30087; logP: -5.37056; logdet: 6.06969; lr: 0.0000100: 100%|██████████| 5/5 [00:17<00:00,  3.46s/it]  


In [10]:
###### EVALUATION
# Run the (pruned, fine-tuned) model over batches of the test images and track
# the loss terms in the progress bar. `cum_log_p` collects the raw per-batch
# log_p arrays returned by the model.
# NOTE(review): sample_data shuffles, applies RandomHorizontalFlip and cycles
# forever, so this evaluates int(10000 / batch) randomly drawn, augmented
# batches rather than one deterministic pass over the test set — confirm.

path = "/nfs/students/ayle/guided-research/FASHION-jpg/testing"
dataset = iter(sample_data(path, batch, img_size))
# Number of quantization levels per channel value.
n_bins = 2.0 ** n_bits

z_sample = []
z_shapes = calc_z_shapes(channels, img_size, n_flow, n_block)
for z in z_shapes:
    z_new = torch.randn(n_sample, *z) * temp
    z_sample.append(z_new.to(device))

cum_log_p = []

# Evaluation never calls backward(), so the whole loop runs without autograd;
# this avoids building and holding a graph for every forward pass.
with torch.no_grad():
    with tqdm(range(int(10000 / batch))) as pbar:
        for i in pbar:
            # Labels are discarded; only the images are used.
            image, _ = next(dataset)
            image = image.to(device)

            image = image * 255

            # Drop the low-order bits so only n_bits of precision remain.
            if n_bits < 8:
                image = torch.floor(image / 2 ** (8 - n_bits))

            image = image / n_bins - 0.5

            if i == 0:
                # The first batch goes through the unwrapped model and does
                # not update the progress-bar description.
                log_p, logdet, _ = model.module(
                    image
                )
                cum_log_p.append(log_p.cpu().detach().numpy())

                continue

            else:
                log_p, logdet, _ = model(image)
                cum_log_p.append(log_p.cpu().detach().numpy())

            logdet = logdet.mean()

            # From here on log_p / log_det hold the rescaled values returned
            # by calc_loss; "Avg logP" below averages the raw values instead.
            loss, log_p, log_det = calc_loss(log_p, logdet, img_size, n_bins, channels=channels)

            pbar.set_description(
                f"Loss: {loss.item():.5f}; logP: {log_p.item():.5f}; Avg logP: {np.mean(cum_log_p).item():.5f}; logdet: {log_det.item():.5f}"
            )

Loss: 3.54135; logP: -4.61029; Avg logP: -9514.09375; logdet: 6.06894:   4%|▎         | 23/625 [00:04<02:02,  4.89it/s]


KeyboardInterrupt: 