# Glow

In [None]:
# Import required packages
import torch
import torchvision as tv
import numpy as np
import normflow as nf

from matplotlib import pyplot as plt
from tqdm import tqdm

In [None]:
# Set up model

# Seed the RNG before any module is constructed so that parameter
# initialisation is reproducible across runs.
torch.manual_seed(0)

# Architecture hyperparameters
K = 32                      # number of stacked GlowBlocks
latent_shape = (3, 16, 16)  # shape handed to the base distribution; 3 channels at 16x16
channels = latent_shape[0]  # every GlowBlock operates on the latent's channel count
hidden_channels = 128
split_mode = 'channel'
scale = True
num_classes = 10            # number of class labels the model is conditioned on

# K identically configured GlowBlocks, constructed in order.
flows = [
    nf.flows.GlowBlock(channels, hidden_channels, split_mode=split_mode, scale=scale)
    for _ in range(K)
]

# Class-conditional diagonal Gaussian base distribution over latent_shape.
q0 = nf.distributions.ClassCondDiagGaussian(latent_shape, num_classes)

# Class-conditional flow model combining the base distribution with the flow stack.
model = nf.ClassCondFlow(q0=q0, flows=flows)

# Run on the GPU when one is available (and enabled), otherwise on the CPU;
# all parameters are then converted to float64.
enable_cuda = True
device = torch.device('cuda' if torch.cuda.is_available() and enable_cuda else 'cpu')
model = model.to(device).double()

In [None]:
# Prepare training data
batch_size = 192

# Images are resized to the latent's spatial size (16 for latent_shape = (3, 16, 16)).
image_size = latent_shape[-1]
# NOTE(review): machine-specific download/cache location; adjust for your environment.
data_root = '/scratch2/vs488/flow/lars/datasets/'

# Kept as a module-level name: the sampling cell uses logit.inverse to map
# model samples back to pixel space.
logit = nf.utils.Logit(alpha=0.05)

# Per-image pipeline: resize -> tensor -> Jitter -> Logit -> move to `device`.
preprocessing = [
    tv.transforms.Resize(image_size),
    tv.transforms.ToTensor(),
    nf.utils.Jitter(),
    logit,
    nf.utils.ToDevice(device),
]
transform = tv.transforms.Compose(preprocessing)

train_data = tv.datasets.CIFAR10(data_root, train=True, download=True, transform=transform)
# drop_last keeps every batch at exactly batch_size samples.
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=batch_size, shuffle=True, drop_last=True
)
train_iter = iter(train_loader)

In [None]:
# Train model
max_iter = 1000

# Losses are gathered in a plain list and converted to an array once after the
# loop; calling np.append every iteration copies the whole array each time.
loss_hist = []

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-5)

# The model was converted with model.double(), while the transform pipeline
# (ToTensor -> Jitter -> Logit -> ToDevice) produces float32 tensors. Cast
# every batch to the parameters' dtype so inputs and weights agree.
model_dtype = next(model.parameters()).dtype

for _ in tqdm(range(max_iter)):
    # Draw the next batch, restarting the loader once an epoch is exhausted.
    try:
        x, y = next(train_iter)
    except StopIteration:
        train_iter = iter(train_loader)
        x, y = next(train_iter)
    x = x.to(device=device, dtype=model_dtype)
    y = y.to(device)

    optimizer.zero_grad()
    loss = model.forward_kld(x, y)

    # Skip the parameter update when the loss is NaN or +/-inf so a single
    # bad batch cannot corrupt the weights; the value is still recorded.
    if torch.isfinite(loss):
        loss.backward()
        optimizer.step()

    loss_hist.append(loss.item())
    del x, y, loss

loss_hist = np.array(loss_hist)

plt.figure(figsize=(10, 10))
plt.plot(loss_hist, label='loss')
plt.legend()
plt.show()

In [None]:
# Model samples
num_sample = 10

model.eval()
try:
    # Samples are only plotted, so no autograd graph is needed.
    with torch.no_grad():
        # Labels 0..num_classes-1, repeated num_sample times; with
        # nrow=num_classes below, each grid column shows a single class.
        y = torch.arange(num_classes).repeat(num_sample).to(device)
        x, _ = model.sample(y=y)
        # Undo the Logit preprocessing and clamp to the displayable [0, 1] range.
        x_ = torch.clamp(logit.inverse(x).cpu(), 0, 1)

    plt.figure(figsize=(10, 10))
    # make_grid returns (C, H, W); imshow expects (H, W, C).
    plt.imshow(np.transpose(tv.utils.make_grid(x_, nrow=num_classes).numpy(), (1, 2, 0)))
    plt.show()

    del x, y, x_
finally:
    # Always release cached GPU memory and return to training mode, even if
    # sampling or plotting fails, so a later training run is not in eval mode.
    torch.cuda.empty_cache()
    model.train()