In [2]:
%matplotlib inline
# Standard library
import os
import random

# Third-party
import cv2
import numpy as np
import torch
from docopt import docopt
from matplotlib import pyplot as plt
from torchvision import transforms

# Project-local
import vision
from glow import learning_rate_schedule
from glow.builder_new import build
from glow.config import JsonConfig
from glow.models_new import Glow
from glow.trainer import Trainer

In [3]:
# Load hyper-parameters and pick the dataset. Note that `dataset` is bound to
# the dataset *class* here; it is instantiated in a later cell.
HPARAMS_PATH = './hparams/celeba.json'
SEED = 0  # seeds Python / NumPy / torch RNGs so model init and later random ops repeat

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

hparams = JsonConfig(HPARAMS_PATH)
dataset_name = 'celeba'
dataset_root = 'dataset/CelebA'
dataset = vision.Datasets[dataset_name]

# Image preprocessing: center-crop, resize, then convert to a tensor.
transform = transforms.Compose([
    transforms.CenterCrop(hparams.Data.center_crop),
    transforms.Resize(hparams.Data.resize),
    transforms.ToTensor(),
])

# Build the Glow model directly. This is a bare model object, not the
# component mapping that `build(hparams, True)` presumably returns for
# `Trainer(**...)`.
built = Glow(hparams)

In [None]:
# Learning-rate schedule: a function looked up by name in
# glow.learning_rate_schedule ("default" unless hparams.Optim.Schedule names
# another one), plus the keyword arguments it is called with.
schedule_name = "default"
schedule_args = {}
if "Schedule" in hparams.Optim:
    schedule_name = hparams.Optim.Schedule.name
    schedule_args = hparams.Optim.Schedule.args.to_dict()

# The schedule's starting lr defaults to, and must agree with, the optimizer lr.
schedule_args.setdefault("init_lr", hparams.Optim.args.lr)
assert schedule_args["init_lr"] == hparams.Optim.args.lr, \
    "Optim lr {} != Schedule init_lr {}".format(hparams.Optim.args.lr, schedule_args["init_lr"])

schedule_func = getattr(learning_rate_schedule, schedule_name, None)
if schedule_func is None:
    raise ValueError("Unknown learning-rate schedule {!r}".format(schedule_name))

lrschedule = {
    "func": schedule_func,
    "args": schedule_args,
}

In [None]:
# Optimizer over the Glow model built in the setup cell (`built`). The kwargs
# in hparams.Optim.args (lr, ...) must be keyword-unpacked; passing the dict
# positionally would make it Adam's `lr`.
optim_name = hparams.Optim.name
if optim_name.lower() == 'adam':
    optimizer = torch.optim.Adam(built.parameters(), **hparams.Optim.args.to_dict())
else:
    raise ValueError("Unsupported optimizer {!r}; only 'adam' is wired up here".format(optim_name))

In [None]:
# Inline Glow training loop, ported from `Trainer.train`. All state is passed
# explicitly instead of being read from `self.` attributes, which do not exist
# in a notebook. This cell only defines the function. Call it from a separate
# cell once a data loader exists, e.g. (sketch; batch size etc. to be chosen):
#     loader = torch.utils.data.DataLoader(dataset_instance, batch_size=...,
#                                          shuffle=True, drop_last=True)
#     built, step = train_glow(built, optimizer, lrschedule, loader)
def train_glow(graph, optimizer, lrschedule, data_loader, *,
               num_epochs=1000, global_step=0,
               devices=("cpu",), data_device="cpu",
               y_condition=False, y_criterion="multi-classes",
               y_classes=None, weight_y=0.5,
               max_grad_clip=None, max_grad_norm=None,
               writer=None, scalar_log_gap=50,
               plot_gap=None, inference_gap=None, plot_prob_fn=None,
               checkpoint_fn=None, checkpoints_gap=None,
               loss_fns=None):
    """Train a Glow model for ``num_epochs`` passes over ``data_loader``.

    Args:
        graph: the flow model. It is called three ways: as
            ``graph(x, y_onehot)`` once at step 0 to initialise ActNorm; as
            ``graph(x=..., y_onehot=...)``, returning ``(z, nll, y_logits)``,
            for training; and with ``reverse=True`` for image reconstruction
            and sampling.
        optimizer: a ``torch.optim`` optimizer over ``graph``'s parameters.
            Its learning rate is overwritten every step.
        lrschedule: dict with ``"func"``, called as
            ``func(global_step=step, **args)`` and returning the lr, and
            ``"args"``.
        data_loader: iterable of dict batches. Each batch holds ``"x"`` and,
            when ``y_condition`` is set, ``"y_onehot"`` or ``"y"``.
        num_epochs: number of passes over ``data_loader``.
        global_step: step counter to start from. The ActNorm initialisation
            pass only runs when it is 0.
        devices: device list for ``DataParallel``. The graph is wrapped only
            when more than one device is given.
        data_device: device every batch tensor is moved to.
        y_condition, y_criterion, y_classes, weight_y: class-conditioning
            settings. ``y_criterion`` must be "multi-classes" or
            "single-class".
        max_grad_clip, max_grad_norm: gradient value / norm clipping. Each
            is skipped when None or <= 0.
        writer: optional object with ``add_scalar`` / ``add_image`` (e.g. a
            tensorboardX ``SummaryWriter``). All logging is skipped when it
            is None. The caller owns it and is responsible for closing it.
        scalar_log_gap, plot_gap, inference_gap: logging intervals in steps.
            A falsy value disables that kind of logging.
        plot_prob_fn: optional callable that renders predicted vs. true class
            probabilities as an image (e.g. the project's ``plot_prob``).
        checkpoint_fn: optional callable invoked as
            ``checkpoint_fn(global_step=..., graph=..., optim=...)`` every
            ``checkpoints_gap`` steps, excluding step 0.
        loss_fns: object providing ``loss_generative`` and, when
            class-conditioned, ``loss_multi_classes`` / ``loss_class``.
            Defaults to ``Glow``.

    Returns:
        ``(graph, global_step)``: the graph, ``DataParallel``-wrapped if
        several devices were used, and the step counter after training.

    Raises:
        ValueError: if ``y_criterion`` is not a supported value.
    """
    if loss_fns is None:
        loss_fns = Glow
    if y_criterion not in ("multi-classes", "single-class"):
        raise ValueError("unknown y_criterion {!r}".format(y_criterion))
    devices = list(devices)
    one_hot = torch.nn.functional.one_hot

    graph.train()
    for epoch in range(num_epochs):
        print("epoch", epoch)
        for batch in data_loader:
            # Refresh the learning rate from the schedule at the *current* step.
            lr = lrschedule["func"](global_step=global_step, **lrschedule["args"])
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr
            log_scalars = (writer is not None and bool(scalar_log_gap)
                           and global_step % scalar_log_gap == 0)
            if log_scalars:
                writer.add_scalar("lr/lr", lr, global_step)

            # Move tensors to the data device without mutating the caller's dict.
            batch = {k: v.to(data_device) for k, v in batch.items()}
            x = batch["x"]
            y = None
            y_onehot = None
            if y_condition:
                if y_criterion == "multi-classes":
                    assert "y_onehot" in batch, "multi-classes ask for `y_onehot` (torch.FloatTensor onehot)"
                    y_onehot = batch["y_onehot"]
                else:
                    assert "y" in batch, "single-class ask for `y` (torch.LongTensor indexes)"
                    y = batch["y"]
                    y_onehot = one_hot(y.reshape(-1), y_classes).float()

            # First step: one forward pass on a per-device slice to initialise
            # ActNorm, done before any DataParallel wrapping.
            if global_step == 0:
                n_init = max(1, x.size(0) // len(devices))
                graph(x[:n_init, ...],
                      y_onehot[:n_init, ...] if y_onehot is not None else None)
            if len(devices) > 1 and not hasattr(graph, "module"):
                print("[Parallel] move to {}".format(devices))
                graph = torch.nn.parallel.DataParallel(graph, devices, devices[0])

            # Forward pass and loss.
            z, nll, y_logits = graph(x=x, y_onehot=y_onehot)
            loss_generative = loss_fns.loss_generative(nll)
            loss_classes = 0
            if y_condition:
                loss_classes = (loss_fns.loss_multi_classes(y_logits, y_onehot)
                                if y_criterion == "multi-classes"
                                else loss_fns.loss_class(y_logits, y))
            if log_scalars:
                writer.add_scalar("loss/loss_generative", loss_generative, global_step)
                if y_condition:
                    writer.add_scalar("loss/loss_classes", loss_classes, global_step)
            loss = loss_generative + loss_classes * weight_y

            # Backward pass, optional gradient clipping, then the update.
            graph.zero_grad()
            optimizer.zero_grad()
            loss.backward()
            if max_grad_clip is not None and max_grad_clip > 0:
                torch.nn.utils.clip_grad_value_(graph.parameters(), max_grad_clip)
            if max_grad_norm is not None and max_grad_norm > 0:
                grad_norm = torch.nn.utils.clip_grad_norm_(graph.parameters(), max_grad_norm)
                if log_scalars:
                    writer.add_scalar("grad_norm/grad_norm", grad_norm, global_step)
            optimizer.step()

            if (checkpoint_fn is not None and checkpoints_gap
                    and global_step > 0 and global_step % checkpoints_gap == 0):
                checkpoint_fn(global_step=global_step, graph=graph, optim=optimizer)

            # Reconstruction images (reverse pass from z); no gradients needed.
            if writer is not None and plot_gap and global_step % plot_gap == 0:
                with torch.no_grad():
                    img = graph(z=z, y_onehot=y_onehot, reverse=True)
                if y_condition:
                    if y_criterion == "multi-classes":
                        y_pred = torch.sigmoid(y_logits)
                    else:
                        # argmax of softmax(logits) equals argmax of the raw logits.
                        y_pred = one_hot(torch.argmax(y_logits, dim=1), y_classes).float()
                    y_true = y_onehot
                for bi in range(min(len(img), 4)):
                    writer.add_image("0_reverse/{}".format(bi),
                                     torch.cat((img[bi], batch["x"][bi]), dim=1), global_step)
                    if y_condition and plot_prob_fn is not None:
                        writer.add_image("1_prob/{}".format(bi),
                                         plot_prob_fn([y_pred[bi], y_true[bi]], ["pred", "true"]),
                                         global_step)

            # Samples drawn with z=None and eps_std=0.5.
            if writer is not None and inference_gap and global_step % inference_gap == 0:
                with torch.no_grad():
                    img = graph(z=None, y_onehot=y_onehot, eps_std=0.5, reverse=True)
                for bi in range(min(len(img), 4)):
                    writer.add_image("2_sample/{}".format(bi), img[bi], global_step)

            global_step += 1

    return graph, global_step

In [None]:
# Instantiate the dataset. The class is looked up again so that re-running
# this cell does not try to call an already-constructed dataset instance.
dataset = vision.Datasets[dataset_name](dataset_root, transform=transform)

# `Trainer` is fed by keyword-unpacking the components returned by
# `build(hparams, True)`. A bare Glow model cannot be `**`-unpacked, so when
# `built` is not such a dict, build the components here instead. NOTE(review):
# in that case Trainer uses a freshly built graph, not the `built` model or
# the `optimizer` / `lrschedule` defined in earlier cells.
trainer_components = built if isinstance(built, dict) else build(hparams, True)
trainer = Trainer(**trainer_components, dataset=dataset, hparams=hparams)
trainer.train()