# Glow

In [None]:
# Import required packages
import torch
import torchvision as tv
from torchvision import datasets, transforms
import numpy as np
import normflow as nf

from matplotlib import pyplot as plt
from tqdm import tqdm

In [None]:
# Set up model: a stack of K GlowBlock flows on top of a trainable
# diagonal Gaussian base distribution, moved to the chosen device in float64.
torch.manual_seed(0)

K = 32                      # number of GlowBlock layers in the flow
latent_shape = (3, 16, 16)  # shape passed to the base distribution
channels = 3                # channel count passed to every GlowBlock
hidden_channels = 128       # hidden channel count passed to every GlowBlock
split_mode = 'channel'
scale = True

# Define flows
flows = [
    nf.flows.GlowBlock(channels, hidden_channels,
                       split_mode=split_mode, scale=scale)
    for _ in range(K)
]

# Base distribution (prior / q0) with trainable parameters
q0 = nf.distributions.DiagGaussian(latent_shape, trainable=True)

# Assemble the normalizing flow
model = nf.NormalizingFlow(q0=q0, flows=flows)

# Use the GPU when one is present and enable_cuda allows it
enable_cuda = True
if torch.cuda.is_available() and enable_cuda:
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model = model.to(device).double()

In [None]:
# Prepare training data: CIFAR-10 training images resized to 16 pixels,
# restricted to a single class.
batch_size = 96

# NOTE: machine-specific location; adjust to a local directory when needed.
data_root = '/scratch2/vs488/flow/lars/datasets/'
# Class to train on (previously hard-coded as the label index 6).
target_class = 'frog'

# `logit` is kept as a named object because the sampling cell calls
# logit.inverse to map model samples back to image space.
logit = nf.utils.Logit(alpha=0.05)
transform = transforms.Compose([transforms.Resize(16), transforms.ToTensor(),
                                nf.utils.Jitter(), logit, nf.utils.ToDevice(device)])
train_data = datasets.CIFAR10(data_root, train=True, download=True,
                              transform=transform)

# Keep only samples of the target class. Look the label index up by name
# instead of relying on a magic number.
class_idx = train_data.class_to_idx[target_class]
idx = np.where(np.array(train_data.targets) == class_idx)[0]
train_data.targets = [train_data.targets[j] for j in idx]
train_data.data = train_data.data[idx]
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
                                           shuffle=True)

In [None]:
# Train model by minimizing the forward KL divergence (model.forward_kld),
# then plot the per-batch loss history.
num_epoch = 20

# Create the loss history and optimizer only on the first run of this cell.
# Re-running the cell then resumes training with the existing Adam state,
# instead of raising NameError (both were previously commented out).
if 'loss_hist' not in globals():
    loss_hist = np.array([])
if 'optimizer' not in globals():
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-5)

# The model was converted with model.double(), while transforms.ToTensor
# produces float32 tensors. Cast each batch to the model's parameter dtype.
model_dtype = next(model.parameters()).dtype

for ep in tqdm(range(num_epoch)):
    for x, _ in train_loader:
        optimizer.zero_grad()
        loss = model.forward_kld(x.to(model_dtype))

        # Skip the update on a NaN/inf loss so that one bad batch cannot
        # corrupt the parameters. The loss is still recorded below.
        if torch.isfinite(loss):
            loss.backward()
            optimizer.step()

        loss_hist = np.append(loss_hist, loss.item())
    torch.cuda.empty_cache()

plt.figure(figsize=(10, 10))
plt.plot(loss_hist, label='loss')
plt.legend()
plt.show()

In [None]:
# Model samples: draw images from the flow, map them back through
# logit.inverse, clamp to [0, 1] and show them as a grid.
num_samples = 64
# Sampling only, so skip building an autograd graph through all K blocks.
with torch.no_grad():
    x, _ = model.sample(num_samples)
x_ = torch.clamp(logit.inverse(x).cpu(), 0, 1)
plt.figure(figsize=(10, 10))
plt.imshow(np.transpose(tv.utils.make_grid(x_).numpy(), (1, 2, 0)))
plt.show()