In [1]:
%matplotlib inline
import os
import vision
from docopt import docopt
from torchvision import transforms
from glow.builder_new import build
from glow.trainer import Trainer
from glow.config import JsonConfig
import cv2
import random
import torch
import numpy as np
from matplotlib import pyplot as plt
from glow.models_new import Glow
from glow import learning_rate_schedule
from glow import thops
from glow.utils import get_proper_device
from glow.utils import save
import datetime
from platform import python_version
from torch.utils.data import Dataset, DataLoader

In [2]:
print(python_version())

3.6.6


# fix random seeds

In [3]:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)
np.random.seed(30)

# load data and transform - CelebA

In [4]:
hparams = JsonConfig('./hparams/celeba.json')
dataset_name = 'celeba'
if dataset_name == 'celeba':
    dataset_root = 'dataset/CelebA'
    dataset = vision.Datasets[dataset_name]
    # set transform of dataset
    transform = transforms.Compose([
            transforms.CenterCrop(hparams.Data.center_crop),
            transforms.Resize(hparams.Data.resize),
            transforms.ToTensor()])
    dataset = dataset(dataset_root, transform=transform)
    data_loader = DataLoader(dataset,batch_size=hparams.Train.batch_size,shuffle=False,drop_last=True)
    num_batches = len(data_loader)
    print('number of batches:', num_batches)

Begin to parse all image attrs
Find 202599 images, with 40 attrs
number of batches: 22511


# load data - sinusoidal

In [5]:
dataset_name = 'pde'
if dataset_name == 'pde':
    class NumpyDataset(Dataset):
        def __init__(self, path):
            super().__init__()
            self.data = np.load(path)
            self.transform = transforms.Compose([transforms.ToTensor()])

        def __getitem__(self, idx):
            return self.transform(np.float32(self.data[idx]))

        def __len__(self):
            return self.data.shape[0]

    def get_dataloader(path, batchsize):
        ds = NumpyDataset(path)
        dl = DataLoader(ds, batch_size=batchsize, drop_last=True, shuffle=True)
        return dl
    data_loader = get_dataloader('pdedata/data_lowranknoise.npy',8)

In [6]:
num_batches = len(data_loader)
print('number of batches:', num_batches)
hparams = JsonConfig('./hparams/pde.json')

number of batches: 1250


# initialize Glow network

In [7]:
Glownet = Glow(hparams)
Glownet.device = hparams.Device.glow
devices = get_proper_device(hparams.Device.glow)
if len(devices) > 0:
    device = Glownet.device[0]
    Glownet = Glownet.to(device)
else:
    device = 'cpu'

[Builder]: Found 1 gpu
[Builder]: cuda:1 is not found, ignore.
[Builder]: cuda:2 is not found, ignore.
[Builder]: cuda:3 is not found, ignore.


# initialize checkpoints

In [8]:
date = str(datetime.datetime.now())
date = date[:date.rfind(":")].replace("-", "")\
                                     .replace(":", "")\
                                     .replace(" ", "_")
log_dir = os.path.join(hparams.Dir.log_root, "log_" + date)
checkpoints_dir = os.path.join(log_dir, "checkpoints")
if not os.path.exists(log_dir):
    os.makedirs(log_dir)
# write hparams
hparams.dump(log_dir)
if not os.path.exists(checkpoints_dir):
    os.makedirs(checkpoints_dir)

 {
  Dir {
    log_root: results/pde
  }
  Glow {
    image_shape: [64, 64, 1]
    hidden_channels: 512
    K: 32
    L: 3
    actnorm_scale: 1.0
    flow_permutation: invconv
    flow_coupling: affine
    LU_decomposed: False
    learn_top: False
    y_condition: False
    y_classes: 40
  }
  Criterion {
    y_condition: multi-classes
  }
  Data {
    center_crop: 160
    resize: 64
  }
  Optim {
    name: adam
      args {
      lr: 0.0001
      betas: [0.9, 0.9999]
      eps: 1e-08
    }
      Schedule {
      name: noam_learning_rate_decay
          args {
        warmup_steps: 1000
        minimum: 0.0001
      }
    }
  }
  Device {
    glow: ['cuda:0', 'cuda:1', 'cuda:2', 'cuda:3']
    data: cuda:0
  }
  Train {
    batch_size: 8
    num_batches: 10000
    max_grad_clip: 5
    max_grad_norm: 100
    max_checkpoints: 20
    checkpoints_gap: 5000
    num_plot_samples: 1
    scalar_log_gap: 20
    plot_gap: 20
    inference_gap: 20
    warm_start: 
    weight_y: 0.5
  }
  Infer {
 

# learning rate schedule

In [9]:
schedule_name = "default"
schedule_args = {}
if "Schedule" in hparams.Optim:
    schedule_name = hparams.Optim.Schedule.name
    schedule_args = hparams.Optim.Schedule.args.to_dict()
if not ("init_lr" in schedule_args):
        schedule_args["init_lr"] = hparams.Optim.args.lr
assert schedule_args["init_lr"] == hparams.Optim.args.lr,\
                "Optim lr {} != Schedule init_lr {}".format(hparams.Optim.args.lr, schedule_args["init_lr"])
lrschedule = {
                "func": getattr(learning_rate_schedule, schedule_name),
                "args": schedule_args
              }
opt_params = hparams.Optim.args

# initialize optimizer

In [10]:
optim_name = hparams.Optim.name
if optim_name == 'adam':
    optimizer = torch.optim.Adam(Glownet.parameters(), opt_params['lr'], opt_params['betas'], opt_params['eps'])

# train Glow network

In [None]:
# initialize global_step : cumulative no. of optimizer steps = no. epochs * no. batches
global_step = 0 
generative_loss_perNepoch = []
classification_loss_perNepoch = []
check_images = False
for epoch in range(10):
    print("epoch no.:", epoch)
    for i_batch, batch in enumerate(data_loader):
        print('epoch no.',epoch,', batch no.',i_batch,' of ', num_batches, end='\r')
        
        # update learning rate
        lr = lrschedule["func"](global_step=0,**lrschedule["args"])
        #print(lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
            
        # clear gradients for current mini-batch
        optimizer.zero_grad()
        
        # send data to device and extract
        if dataset_name=='celeba':
            for k in batch:
                print(k.shape)
                batch[k] = batch[k].to(device)
        elif dataset_name=='pde':
            x = batch.to(device)
            
        # extract images x
        if dataset_name =='celeba':
            x = batch["x"]
            
        # extract labels y, y_onehot
        y = None
        y_onehot = None
        if hparams.Glow.y_condition:
            if hparams.Criterion.y_condition == "multi-classes":
                assert "y_onehot" in batch, "multi-classes ask for `y_onehot` (torch.FloatTensor onehot)"
                y_onehot = batch["y_onehot"]
            elif hparams.Criterion.y_condition == "single-class":
                assert "y" in batch, "single-class ask for `y` (torch.LongTensor indexes)"
                y = batch["y"]
                y_onehot = thops.onehot(y, num_classes=hparams.Glow.y_classes)

        # initialize ActNorm (first iteration only)
        if global_step == 0:
            Glownet(x[:hparams.Train.batch_size // len(devices), ...],y_onehot[:hparams.Train.batch_size // len(devices), ...] if y_onehot is not None else None)
        
        # parallel 
        if len(devices) > 1 and not hasattr(Glownet, "module"):
            print("[Parallel] move to {}".format(self.devices))
            self.graph = torch.nn.parallel.DataParallel(self.graph, self.devices, self.devices[0])
            
        # forward phase
        z, nll, y_logits = Glownet(x=x, y_onehot=y_onehot)
        
        # construct genetative loss
        loss_generative = Glownet.loss_generative(nll)
        
        # construct classification loss
        loss_classes = 0
        if hparams.Glow.y_condition:
            loss_classes = (Glownet.loss_multi_classes(y_logits, y_onehot)
                            if self.y_criterion == "multi-classes" else
                                    Glownet.loss_class(y_logits, y))
        
        #construct overall loss function
        if global_step % hparams.Train.scalar_log_gap == 0:
            generative_loss_perNepoch.append(loss_generative)
            print("\ngenerative loss:", loss_generative.detach().cpu().numpy())
            if hparams.Glow.y_condition:
                generative_loss_perNepoch.append(loss_generative)
                print("classification loss:", loss_classes)
        loss = loss_generative + loss_classes * hparams.Train.weight_y

        # backpropagate gradients
        loss.backward()
        
        # clip gradients
        if hparams.Train.max_grad_clip is not None and hparams.Train.max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(Glownet.parameters(), hparams.Train.max_grad_clip)
        if hparams.Train.max_grad_norm is not None and hparams.Train.max_grad_norm > 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(Glownet.parameters(), hparams.Train.max_grad_norm)
            if global_step % hparams.Train.scalar_log_gap == 0:
                print("norm of gradients:", grad_norm)
        
        # gradient step
        optimizer.step()

        
        # checkpoints
        if global_step % hparams.Train.checkpoints_gap == 0 and global_step > 0:
            save(global_step=global_step,
                         graph=Glownet,
                         optim=optimizer,
                         pkg_dir=checkpoints_dir,
                         is_best=True,
                         max_checkpoints=hparams.Train.max_checkpoints)
        
        # check generated images and plot
        if check_images:
            if global_step % hparams.Train.plot_gap == 0:
                img = Glownet(z=z, y_onehot=y_onehot, reverse=True)
                # img = torch.clamp(img, min=0, max=1.0)
                if hparams.Glow.y_condition:
                    if hparams.Criterion.y_condition == "multi-classes":
                        y_pred = torch.sigmoid(y_logits)
                    elif hparams.Criterion.y_condition == "single-class":
                        y_pred = thops.onehot(torch.argmax(F.softmax(y_logits, dim=1), dim=1, keepdim=True),
                                                      self.y_classes)
                    y_true = y_onehot

                for bi in range(min([len(img), 4])):
                    self.writer.add_image("0_reverse/{}".format(bi), torch.cat((img[bi], batch["x"][bi]), dim=1), self.global_step)
                    if hparams.Glow.y_condition:
                        self.writer.add_image("1_prob/{}".format(bi), plot_prob([y_pred[bi], y_true[bi]], ["pred", "true"]), self.global_step)


            # inference
            if hparams.Train.inference_gap is not None:
                if global_step % hparams.Train.inference_gap == 0:
                    img = Glownet(z=None, y_onehot=y_onehot, eps_std=0.5, reverse=True)
                    for bi in range(min([len(img), 4])):
                        self.writer.add_image("2_sample/{}".format(bi), img[bi], self.global_step)

        # global step
        global_step += 1

epoch no.: 0
epoch no. 0 , batch no. 0  of  1250
generative loss: 4.115608
norm of gradients: 156.08403423899406
epoch no. 0 , batch no. 20  of  1250
generative loss: 4.0624237
norm of gradients: 72.67700821575286
epoch no. 0 , batch no. 40  of  1250
generative loss: 4.0241942
norm of gradients: 57.897531735818205
epoch no. 0 , batch no. 60  of  1250
generative loss: 3.9914937
norm of gradients: 54.73581612716947
epoch no. 0 , batch no. 80  of  1250
generative loss: 3.9562435
norm of gradients: 56.825706512769266
epoch no. 0 , batch no. 100  of  1250
generative loss: 3.9247472
norm of gradients: 50.240040211846804
epoch no. 0 , batch no. 120  of  1250
generative loss: 3.8969753
norm of gradients: 44.43958457123058
epoch no. 0 , batch no. 140  of  1250
generative loss: 3.8680391
norm of gradients: 41.67861898419261
epoch no. 0 , batch no. 160  of  1250
generative loss: 3.843376
norm of gradients: 40.864850461710084
epoch no. 0 , batch no. 180  of  1250
generative loss: 3.8248544
norm of

In [None]:
if check_images:
    self.writer.export_scalars_to_json(os.path.join(self.log_dir, "all_scalars.json"))
    self.writer.close()

In [None]:
hparams.Train.checkpoints_gap