In [1]:
%matplotlib inline
import os
import vision
from docopt import docopt
from torchvision import transforms
from glow.builder_new import build
from glow.trainer import Trainer
from glow.config import JsonConfig
import cv2
import random
import torch
import numpy as np
from matplotlib import pyplot as plt
from glow.models_new import Glow
from glow import learning_rate_schedule
from torch.utils.data import DataLoader
from glow import thops
from glow.utils import get_proper_device
import datetime
from platform import python_version

In [2]:
print(python_version())

3.6.6


# fix random seeds

In [3]:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)
np.random.seed(30)

# load data and transform

In [4]:
hparams = JsonConfig('./hparams/celeba.json')
dataset_name = 'celeba'
dataset_root = 'dataset/CelebA'
dataset = vision.Datasets[dataset_name]
# set transform of dataset
transform = transforms.Compose([
        transforms.CenterCrop(hparams.Data.center_crop),
        transforms.Resize(hparams.Data.resize),
        transforms.ToTensor()])
dataset = dataset(dataset_root, transform=transform)
data_loader = DataLoader(dataset,batch_size=hparams.Train.batch_size,shuffle=False,drop_last=True)
num_batches = len(data_loader)
print('number of batches:', num_batches)

Begin to parse all image attrs
Find 202599 images, with 40 attrs
number of batches: 22511


# initialize Glow network

In [5]:
Glownet = Glow(hparams)
Glownet.device = hparams.Device.glow
devices = get_proper_device(hparams.Device.glow)
if len(devices) > 0:
    device = Glownet.device[0]
    Glownet = Glownet.to(device)
else:
    device = 'cpu'

[Builder]: Found 1 gpu
[Builder]: cuda:1 is not found, ignore.
[Builder]: cuda:2 is not found, ignore.
[Builder]: cuda:3 is not found, ignore.


# initialize checkpoints

In [6]:
date = str(datetime.datetime.now())
date = date[:date.rfind(":")].replace("-", "")\
                                     .replace(":", "")\
                                     .replace(" ", "_")
log_dir = os.path.join(hparams.Dir.log_root, "log_" + date)
checkpoints_dir = os.path.join(log_dir, "checkpoints")

# learning rate schedule

In [7]:
schedule_name = "default"
schedule_args = {}
if "Schedule" in hparams.Optim:
    schedule_name = hparams.Optim.Schedule.name
    schedule_args = hparams.Optim.Schedule.args.to_dict()
if not ("init_lr" in schedule_args):
        schedule_args["init_lr"] = hparams.Optim.args.lr
assert schedule_args["init_lr"] == hparams.Optim.args.lr,\
                "Optim lr {} != Schedule init_lr {}".format(hparams.Optim.args.lr, schedule_args["init_lr"])
lrschedule = {
                "func": getattr(learning_rate_schedule, schedule_name),
                "args": schedule_args
              }
opt_params = hparams.Optim.args

# initialize optimizer

In [8]:
optim_name = hparams.Optim.name
if optim_name == 'adam':
    optimizer = torch.optim.Adam(Glownet.parameters(), opt_params['lr'], opt_params['betas'], opt_params['eps'])

# train Glow network

In [None]:
# initialize global_step : cumulative no. of optimizer steps = no. epochs * no. batches
global_step = 0 
generative_loss_perNepoch = []
classification_loss_perNepoch = []
check_images = False
for epoch in range(1):
    print("epoch no.:", epoch)
    for i_batch, batch in enumerate(data_loader):
        print('epoch no.',epoch,', batch no.',i_batch,' of ', num_batches, end='\r')
        
        # update learning rate
        lr = lrschedule["func"](global_step=0,**lrschedule["args"])
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
            
        # initialize optimizer
        optimizer.zero_grad()
        
        # send data to device and extract
        for k in batch:
            batch[k] = batch[k].to(device)
        # extract images x
        x = batch["x"]
        # extract labels y, y_onehot
        y = None
        y_onehot = None
        if hparams.Glow.y_condition:
            if hparams.Criterion.y_condition == "multi-classes":
                assert "y_onehot" in batch, "multi-classes ask for `y_onehot` (torch.FloatTensor onehot)"
                y_onehot = batch["y_onehot"]
            elif hparams.Criterion.y_condition == "single-class":
                assert "y" in batch, "single-class ask for `y` (torch.LongTensor indexes)"
                y = batch["y"]
                y_onehot = thops.onehot(y, num_classes=hparams.Glow.y_classes)

        # initialize ActNorm (first iteration only)
        if global_step == 0:
            Glownet(x[:hparams.Train.batch_size // len(devices), ...],y_onehot[:hparams.Train.batch_size // len(devices), ...] if y_onehot is not None else None)
        
        # parallel 
        if len(devices) > 1 and not hasattr(Glownet, "module"):
            print("[Parallel] move to {}".format(self.devices))
            self.graph = torch.nn.parallel.DataParallel(self.graph, self.devices, self.devices[0])
            
        # forward phase
        z, nll, y_logits = Glownet(x=x, y_onehot=y_onehot)
        
        # construct genetative loss
        loss_generative = Glownet.loss_generative(nll)
        
        # construct classification loss
        loss_classes = 0
        if hparams.Glow.y_condition:
            loss_classes = (Glownet.loss_multi_classes(y_logits, y_onehot)
                            if self.y_criterion == "multi-classes" else
                                    Glownet.loss_class(y_logits, y))
        
        #construct overall loss function
        if global_step % hparams.Train.scalar_log_gap == 0:
            generative_loss_perNepoch.append(loss_generative)
            print("\ngenerative loss:", loss_generative.detach().cpu().numpy())
            if hparams.Glow.y_condition:
                generative_loss_perNepoch.append(loss_generative)
                print("classification loss:", loss_classes)
        loss = loss_generative + loss_classes * hparams.Train.weight_y

        # backpropagate gradients
        Glownet.zero_grad()
        optimizer.zero_grad()
        loss.backward()
        
        # clip gradients
        if hparams.Train.max_grad_clip is not None and hparams.Train.max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(Glownet.parameters(), hparams.Train.max_grad_clip)
        if hparams.Train.max_grad_norm is not None and hparams.Train.max_grad_norm > 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(Glownet.parameters(), hparams.Train.max_grad_norm)
            if global_step % hparams.Train.scalar_log_gap == 0:
                print("norm of gradients:", grad_norm)
        
        # gradient step
        optimizer.step()

        
        # checkpoints
        if global_step % hparams.Train.checkpoints_gap == 0 and global_step > 0:
            save(global_step=global_step,
                         graph=Glownet,
                         optim=optimizer,
                         pkg_dir=checkpoints_dir,
                         is_best=True,
                         max_checkpoints=hparams.Train.max_checkpoints)
        
        # check generated images and plot
        if check_images:
            if global_step % hparams.Train.plot_gap == 0:
                img = Glownet(z=z, y_onehot=y_onehot, reverse=True)
                # img = torch.clamp(img, min=0, max=1.0)
                if hparams.Glow.y_condition:
                    if hparams.Criterion.y_condition == "multi-classes":
                        y_pred = torch.sigmoid(y_logits)
                    elif hparams.Criterion.y_condition == "single-class":
                        y_pred = thops.onehot(torch.argmax(F.softmax(y_logits, dim=1), dim=1, keepdim=True),
                                                      self.y_classes)
                    y_true = y_onehot

                for bi in range(min([len(img), 4])):
                    self.writer.add_image("0_reverse/{}".format(bi), torch.cat((img[bi], batch["x"][bi]), dim=1), self.global_step)
                    if hparams.Glow.y_condition:
                        self.writer.add_image("1_prob/{}".format(bi), plot_prob([y_pred[bi], y_true[bi]], ["pred", "true"]), self.global_step)


            # inference
            if hparams.Train.inference_gap is not None:
                if global_step % hparams.Train.inference_gap == 0:
                    img = Glownet(z=None, y_onehot=y_onehot, eps_std=0.5, reverse=True)
                    for bi in range(min([len(img), 4])):
                        self.writer.add_image("2_sample/{}".format(bi), img[bi], self.global_step)

        # global step
        global_step += 1

epoch no.: 0
epoch no. 0 , batch no. 0  of  22511
generative loss: 0.14936174
norm of gradients: 82.2569327122766
epoch no. 0 , batch no. 20  of  22511
generative loss: 0.4281676
norm of gradients: 102.09868519696306
epoch no. 0 , batch no. 37  of  22511

In [None]:
if check_images:
    self.writer.export_scalars_to_json(os.path.join(self.log_dir, "all_scalars.json"))
    self.writer.close()