In [1]:
import torch
import time
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as T
import torch.optim as optim
from torchvision.utils import save_image
import pickle
import os
from PIL import Image
%matplotlib inline

# Remove the normalization 


# Compute device shared by the whole notebook: CUDA when available, CPU otherwise.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)

class Parser(object):
    """Hyper-parameter container for the notebook.

    Settings are grouped into four plain attribute bags:

    * ``Glow``  - model architecture options,
    * ``Data``  - dataset name and image size,
    * ``Optim`` - Adam learning rate and betas,
    * ``Train`` - batch size, epochs, logging/saving frequencies.

    Args:
        name: dataset identifier, ``'mnist'`` or ``'celeba'``.

    Raises:
        NotImplementedError: for any other dataset name.
    """

    def __init__(self, name):
        class Type(object):
            """Empty object used purely as an attribute namespace."""
            pass

        def make_group(**settings):
            # Build one namespace and attach every keyword as an attribute.
            group = Type()
            for key, value in settings.items():
                setattr(group, key, value)
            return group

        # Model hyper-parameters.
        self.Glow = make_group(
            hidden_channels=96,
            flow_depth=2,
            flow_levels=3,
            actnorm_scale=1.,
            logscale_factor=3.,
            flow_permutation="invconv",
            flow_coupling="affine",
            LU_decomposed=False,
            learn_top=False,
            y_condition=True,
            K=32,
            L=3,
            model_name='Glow',
        )

        # Dataset description (completed by the dataset-specific branch below).
        self.Data = make_group(name=name)

        # Optimizer hyper-parameters.
        self.Optim = make_group(lr=1e-3, beta1=0.9, beta2=0.999)

        # Training-loop hyper-parameters.
        self.Train = make_group(
            batch_size=64,
            n_epoch=30,
            with_attention_guide=True,
            max_grad_clip=5,
            max_grad_norm=100,
            time_steps_mask=True,
            scalar_log_gap=100,
            weight_y=0.5,
            img_show_freq=50,
            img_save_freq=100,
            model_save_freq=1000,
        )

        # Dataset-specific shapes and class counts.
        dataset = self.Data.name
        if dataset == 'mnist':
            self.Glow.in_channels = 1
            self.Glow.image_shape = (1, 32, 32)
            self.Glow.y_classes = 10
            self.Data.img_size = 32
        elif dataset == 'celeba':
            self.Glow.in_channels = 3
            self.Glow.image_shape = (3, 64, 64)
            self.Glow.y_classes = 2
            self.Data.img_size = 64
            self.Data.train_img_path = "./data/celeba/"
        else:
            raise NotImplementedError
        
# Hyper-parameters for this run (CelebA configuration).
hparams = Parser('celeba')

# Output folders for sample images and model checkpoints, named after the model.
img_path = './'+ hparams.Glow.model_name +'/Img/'
model_path = './'+ hparams.Glow.model_name +'/Model/'
# exist_ok=True makes the cell safely re-runnable without a separate
# check-then-create step.
for _out_dir in (model_path, img_path):
    os.makedirs(_out_dir, exist_ok=True)


class CelebADataset(Dataset):
    """Image dataset returning ``(image_tensor, is_male)`` pairs.

    File names and labels are read from ``Path_Male.pickle`` (a pickled pair
    ``(paths, male_flags)``) in the working directory; images are loaded from
    ``hparams.Data.train_img_path``.  The last 2000 entries are held out:
    ``mode='train'`` uses everything before them, any other mode uses only
    those 2000.

    Images are resized to ``hparams.Data.img_size`` and normalised with
    mean/std 0.5 per channel.  Random horizontal flipping is an augmentation
    and is therefore applied to the training split only.
    """

    def __init__(self, mode='train'):
        is_train = mode == 'train'
        size = hparams.Data.img_size

        steps = [T.Resize((size, size))]
        if is_train:
            # Augment only the training split; held-out images stay deterministic.
            steps.append(T.RandomHorizontalFlip())
        steps.extend([
            T.ToTensor(),
            T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        self.image_transform = T.Compose(steps)

        # NOTE(review): pickle.load can execute arbitrary code - only use a
        # pickle file that was produced locally / comes from a trusted source.
        with open("Path_Male.pickle", 'rb') as f:
            self.path, self.male = pickle.load(f)
            print('Loaded')
        self.path = self.path[:-2000] if is_train else self.path[-2000:]
        self.male = self.male[:-2000] if is_train else self.male[-2000:]

    def __getitem__(self, index):
        """Load, transform and return the image at ``index`` with its label.

        ``index`` wraps around the dataset length.
        """
        idx = index % len(self.path)
        full_path = os.path.join(hparams.Data.train_img_path, self.path[idx])
        # The context manager releases the file handle; convert('RGB') guarantees
        # the three channels that the 3-value Normalize above expects.
        with Image.open(full_path) as raw:
            img = self.image_transform(raw.convert('RGB'))

        is_male = self.male[idx]
        return img, is_male

    def __len__(self):
        return len(self.path)

def thops_onehot(y, num_classes):
    """Convert integer class labels to a float one-hot / multi-hot matrix.

    Args:
        y: integer tensor of shape ``[B]`` (one label per row) or ``[B, C]``
            (several labels per row, all of which are set to 1).  Any integer
            dtype is accepted; it is cast to int64 for ``scatter_``.
        num_classes: number of columns of the result.

    Returns:
        Float tensor of shape ``[B, num_classes]`` on the same device as ``y``.

    Raises:
        ValueError: if ``y`` is neither 1-D nor 2-D.
    """
    if y.dim() == 1:
        index = y.unsqueeze(-1)
    elif y.dim() == 2:
        index = y
    else:
        raise ValueError("[onehot]: y should be in shape [B], or [B, C]")
    # Allocate directly on the target device instead of creating on CPU and copying.
    y_onehot = torch.zeros(y.size(0), num_classes, device=y.device)
    # scatter_ only accepts int64 indices, so normalise the dtype first.
    return y_onehot.scatter_(1, index.long(), 1)

def thops_sum(tensor, dim=None, keepdim=False):
    """Sum ``tensor`` over one or several dimensions.

    Args:
        tensor: input tensor.
        dim: ``None`` (sum over everything), a single int, or an iterable of
            ints.  Negative indices are supported.
        keepdim: keep the reduced dimensions with size 1.

    Returns:
        The reduced tensor.  An empty ``dim`` iterable returns ``tensor``
        unchanged.
    """
    if dim is None:
        # sum up all dim
        return torch.sum(tensor)
    dims = [dim] if isinstance(dim, int) else list(dim)
    if not dims:
        return tensor
    # A single multi-dim reduction lets torch resolve negative indices itself;
    # the previous manual squeeze(d - i) bookkeeping was only valid for
    # non-negative dims.
    return torch.sum(tensor, dim=dims, keepdim=keepdim)


def thops_mean(tensor, dim=None, keepdim=False):
    """Average ``tensor`` over one or several dimensions.

    Args:
        tensor: input tensor.
        dim: ``None`` (mean over everything), a single int, or an iterable of
            ints.  Negative indices are supported.
        keepdim: keep the reduced dimensions with size 1.

    Returns:
        The reduced tensor.  An empty ``dim`` iterable returns ``tensor``
        unchanged.
    """
    if dim is None:
        # mean all dim
        return torch.mean(tensor)
    dims = [dim] if isinstance(dim, int) else list(dim)
    if not dims:
        return tensor
    # One multi-dim reduction instead of per-dim reduce + manual squeeze, which
    # mis-handled negative dims.
    return torch.mean(tensor, dim=dims, keepdim=keepdim)


def thops_split_feature(tensor, type="split"):
    """Split a tensor into two halves along the channel dimension (dim 1).

    Args:
        tensor: tensor with at least two dimensions.
        type: ``"split"`` returns the first and second half of the channels;
            ``"cross"`` returns the even-indexed and odd-indexed channels.

    Returns:
        Tuple ``(a, b)`` of views into ``tensor``.

    Raises:
        ValueError: if ``type`` is not one of the two supported modes
            (previously this fell through and silently returned ``None``).
    """
    C = tensor.size(1)
    if type == "split":
        return tensor[:, :C // 2, ...], tensor[:, C // 2:, ...]
    elif type == "cross":
        return tensor[:, 0::2, ...], tensor[:, 1::2, ...]
    raise ValueError(
        "[split_feature]: type should be 'split' or 'cross', got {!r}".format(type))


def thops_cat_feature(tensor_a, tensor_b):
    """Concatenate two tensors along the channel dimension (dim 1)."""
    pieces = [tensor_a, tensor_b]
    return torch.cat(pieces, 1)


def thops_pixels(tensor):
    """Return size(2) * size(3) of ``tensor`` as a Python int (H * W for BCHW input)."""
    height = tensor.size(2)
    width = tensor.size(3)
    return int(height * width)


class _ActNorm(nn.Module):
    """
    Activation Normalization
    Initialize the bias and scale with a given minibatch,
    so that the output per-channel have zero mean and unit variance for that.
    After initialization, `bias` and `logs` will be trained as parameters.
    "A trainable normalisation, similar to BN but use less params"
    """
    
    def __init__(self, num_features, scale=1.):
        super().__init__()
        # register mean and scale
        size = [1, num_features, 1, 1]
        self.register_parameter("bias", nn.Parameter(torch.zeros(*size))) # turn it to self.bias = nn.Parameter
        self.register_parameter("logs", nn.Parameter(torch.zeros(*size)))
        self.num_features = num_features
        self.scale = float(scale)
        self.inited = False

    def _check_input_dim(self, input):
        return NotImplemented

    def initialize_parameters(self, input):
        self._check_input_dim(input)
        if not self.training:
            return
        assert input.device == self.bias.device
        with torch.no_grad():
            # Similar to BN
            bias = thops_mean(input, dim=[0, 2, 3], keepdim=True) * -1.0
            vars = thops_mean((input + bias) ** 2, dim=[0, 2, 3], keepdim=True)
            logs = torch.log(self.scale/(torch.sqrt(vars)+1e-6))
            self.bias.data.copy_(bias.data)
            self.logs.data.copy_(logs.data)
            self.inited = True

    def _center(self, input, reverse=False):
        if not reverse:
            return input + self.bias
        else:
            return input - self.bias

    def _scale(self, input, logdet=None, reverse=False):
        logs = self.logs
        if not reverse:
            input = input * torch.exp(logs)
        else:
            input = input * torch.exp(-logs)
        if logdet is not None:
            """
            logs is log_std of `mean of channels`
            so we need to multiply pixels
            """
            dlogdet = thops_sum(logs) * thops_pixels(input)
            if reverse:
                dlogdet *= -1
            logdet = logdet + dlogdet
        return input, logdet

    def forward(self, input, logdet=None, reverse=False):
        if not self.inited:
            self.initialize_parameters(input) # use the first input to initialise the params
        self._check_input_dim(input)
        if not reverse:
            # center and scale
            input = self._center(input, reverse)
            input, logdet = self._scale(input, logdet, reverse)
        else:
            # scale and center
            input, logdet = self._scale(input, logdet, reverse)
            input = self._center(input, reverse)
        return input, logdet


class ActNorm2d(_ActNorm):
    """Activation normalisation for 4-D inputs laid out as ``BCHW``."""

    def __init__(self, num_features, scale=1.):
        # Nothing extra to set up; arguments are passed straight to the base class.
        super().__init__(num_features, scale)

    def _check_input_dim(self, input):
        """Assert that ``input`` is 4-D and has ``num_features`` channels."""
        n_dims = len(input.size())
        assert n_dims == 4
        assert input.size(1) == self.num_features, (
            "[ActNorm]: input should be in shape as `BCHW`,"
            " channels should be {} rather than {}".format(
                self.num_features, input.size()))


class LinearZeros(nn.Linear):
    """Linear layer initialised to output zeros, with a learnable output scale.

    Weight and bias start at zero; the output is multiplied by
    ``exp(logs * logscale_factor)`` where ``logs`` is a per-output parameter
    that also starts at zero.
    """

    def __init__(self, in_channels, out_channels, logscale_factor=3):
        super().__init__(in_channels, out_channels)
        self.logscale_factor = logscale_factor
        # Learnable per-output log-scale.
        self.logs = nn.Parameter(torch.zeros(out_channels))
        # Zero-initialise so the layer starts as the zero map.
        for tensor in (self.weight, self.bias):
            tensor.data.zero_()

    def forward(self, input):
        linear_out = super().forward(input)
        gain = torch.exp(self.logs * self.logscale_factor)
        return linear_out * gain


class Conv2d(nn.Conv2d):
    """2-D convolution with string padding modes and optional ActNorm.

    Weights are drawn from ``N(0, weight_std)``.  With ``do_actnorm=True`` the
    convolution has no bias and its output is passed through an ``ActNorm2d``;
    otherwise a zero-initialised bias is used.
    """

    # Padding per spatial dimension for each supported string mode.
    pad_dict = {
        "same": lambda kernel, stride: [((k - 1) * s + 1) // 2 for k, s in zip(kernel, stride)],
        "valid": lambda kernel, stride: [0 for _ in kernel]
    }

    @staticmethod
    def get_padding(padding, kernel_size, stride):
        """Resolve ``padding`` to numeric values.

        Non-string paddings are returned untouched.  String modes
        (case-insensitive ``"same"`` / ``"valid"``) are looked up in
        ``pad_dict``; integer ``kernel_size`` / ``stride`` are expanded to
        two-element lists first.

        Raises:
            ValueError: for an unknown string mode.
        """
        if not isinstance(padding, str):
            return padding
        if isinstance(kernel_size, int):
            kernel_size = [kernel_size, kernel_size]
        if isinstance(stride, int):
            stride = [stride, stride]
        mode = padding.lower()
        if mode not in Conv2d.pad_dict:
            raise ValueError("{} is not supported".format(mode))
        return Conv2d.pad_dict[mode](kernel_size, stride)

    def __init__(self, in_channels, out_channels,
                 kernel_size=[3, 3], stride=[1, 1],
                 padding="same", do_actnorm=True, weight_std=0.05):
        resolved_padding = Conv2d.get_padding(padding, kernel_size, stride)
        use_bias = not do_actnorm
        super().__init__(in_channels, out_channels, kernel_size, stride,
                         resolved_padding, bias=use_bias)
        # Gaussian weight initialisation with the requested std.
        self.weight.data.normal_(mean=0.0, std=weight_std)
        if do_actnorm:
            self.actnorm = ActNorm2d(out_channels)
        else:
            self.bias.data.zero_()
        self.do_actnorm = do_actnorm

    def forward(self, input):
        out = super().forward(input)
        if not self.do_actnorm:
            return out
        out, _ = self.actnorm(out)
        return out


class Conv2dZeros(nn.Conv2d):
    """2-D convolution initialised to output zeros, with a learnable output scale.

    Weight and bias start at zero; the output is multiplied by
    ``exp(logs * logscale_factor)`` with one learnable ``logs`` value per
    output channel (also starting at zero).
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size=[3, 3], stride=[1, 1],
                 padding="same", logscale_factor=3):
        resolved_padding = Conv2d.get_padding(padding, kernel_size, stride)
        super().__init__(in_channels, out_channels, kernel_size, stride, resolved_padding)
        self.logscale_factor = logscale_factor
        # Learnable per-channel log-scale, shaped to broadcast over H and W.
        self.logs = nn.Parameter(torch.zeros(out_channels, 1, 1))
        # Zero-initialise so the layer starts as the zero map.
        for tensor in (self.weight, self.bias):
            tensor.data.zero_()

    def forward(self, input):
        conv_out = super().forward(input)
        gain = torch.exp(self.logs * self.logscale_factor)
        return conv_out * gain


class Permute2d(nn.Module):
    '''
    Fixed channel permutation, used when the flow permutation is "reverse" or
    "shuffle" rather than "invconv".

    ``indices`` holds the forward channel order and ``indices_inverse`` the
    order that undoes it.  With ``shuffle=False`` the channels are simply
    reversed (index arrays are used because torch tensors do not accept
    negative slice steps such as ``z[:, ::-1]``).

    Args:
        num_channels: number of channels of the tensors to permute.
        shuffle: draw a random permutation (via ``np.random``) instead of
            the plain reversal.
    '''
    def __init__(self, num_channels, shuffle):
        super().__init__()
        self.num_channels = num_channels
        # np.int64 replaces the `np.long` alias, which no longer exists in
        # current NumPy releases.
        self.indices = np.arange(self.num_channels - 1, -1, -1).astype(np.int64)
        self.indices_inverse = np.zeros((self.num_channels), dtype=np.int64)
        self._refresh_inverse()
        if shuffle:
            self.reset_indices()

    def _refresh_inverse(self):
        # indices_inverse[indices[i]] = i for every i, written in place.
        self.indices_inverse[self.indices] = np.arange(self.num_channels)

    def reset_indices(self):
        """Draw a new random permutation and update its inverse."""
        np.random.shuffle(self.indices)
        self._refresh_inverse()

    def forward(self, input, reverse=False):
        """Permute the channels of a 4-D tensor (or undo it when ``reverse``)."""
        assert len(input.size()) == 4
        order = self.indices_inverse if reverse else self.indices
        return input[:, order, :, :]
        

class InvertibleConv1x1(nn.Module):
    """Invertible 1x1 convolution: a learned ``C x C`` channel-mixing matrix.

    The weight is initialised to a random orthogonal matrix (QR of a Gaussian
    matrix drawn with ``np.random``).

    Args:
        num_channels: number of input (= output) channels.
        LU_decomposed: must be False; the LU parameterisation is not
            implemented and raises ``NotImplementedError``.
    """

    def __init__(self, num_channels, LU_decomposed=False):
        super().__init__()
        if LU_decomposed:
            raise NotImplementedError()
        w_shape = [num_channels, num_channels]
        # Sample a random orthogonal matrix:
        w_init = np.linalg.qr(
            np.random.randn(*w_shape))[0].astype(np.float32)
        self.register_parameter("weight",
                                nn.Parameter(torch.Tensor(w_init)))

    def forward(self, input, logdet=None, reverse=False):
        """
        Apply ``W`` (or ``W^-1`` when ``reverse``) to every spatial position.

        log-det = log|det(W)| * pixels, added to ``logdet`` in the forward
        direction and subtracted in reverse.  ``logdet=None`` is passed
        through untouched.

        Returns:
            ``(z, logdet)``.
        """
        rows, cols = self.weight.size()
        matrix = torch.inverse(self.weight) if reverse else self.weight
        z = F.conv2d(input, matrix.view(rows, cols, 1, 1))
        if logdet is not None:
            # Number of spatial positions (H * W) the matrix is applied at.
            pixels = int(input.size(2) * input.size(3))
            # slogdet returns log|det| directly; log(abs(det(W))) under- or
            # overflows to -inf/+inf in float32 once |det| leaves the
            # representable range.
            dlogdet = torch.slogdet(self.weight)[1] * pixels
            logdet = logdet - dlogdet if reverse else logdet + dlogdet
        return z, logdet


class GaussianDiag:
    """Helpers for a diagonal Gaussian parameterised by ``mean`` and ``logs`` (log std)."""

    # ln(2 * pi) as a plain Python float.
    Log2PI = float(np.log(2 * np.pi))

    @staticmethod
    def likelihood(mean, logs, x):
        """
        Element-wise log-density of ``x``:

            lnL = -1/2 * { ln(Var) + (x - mean)^2 / Var + ln(2*PI) }
                  Var = exp(2 * logs)
        """
        log_var = logs * 2.
        sq_err = ((x - mean) ** 2) / torch.exp(logs * 2.)
        return -0.5 * (log_var + sq_err + GaussianDiag.Log2PI)

    @staticmethod
    def logp(mean, logs, x):
        '''
        Log-probability per sample: the element-wise log-density summed over
        dims 1, 2 and 3.
        '''
        return thops_sum(GaussianDiag.likelihood(mean, logs, x), dim=[1, 2, 3])

    @staticmethod
    def sample(mean, logs, eps_std=None):
        """Draw ``mean + exp(logs) * eps`` with ``eps ~ N(0, eps_std)``.

        ``eps_std=None`` means a standard deviation of 1.  The noise is drawn
        with ``np.random`` and moved to ``mean``'s device.
        """
        noise_shape = [int(d) for d in mean.size()]
        noise_std = 1 if eps_std is None else eps_std
        noise = np.random.normal(0, noise_std, noise_shape)
        eps = torch.Tensor(noise).to(mean.device)
        return mean + torch.exp(logs) * eps


class Split2d(nn.Module):
    """Multi-scale split: keep half of the channels, model the rest as Gaussian.

    A zero-initialised convolution maps the kept half (``C // 2`` channels)
    to ``C`` channels, which are read as interleaved mean / log-std of the
    other half.
    """

    def __init__(self, num_channels):
        '''
        C -> C//2
        '''
        super().__init__()
        self.conv = Conv2dZeros(num_channels // 2, num_channels)

    def split2d_prior(self, z):
        """Return ``(mean, logs)`` of the dropped half, predicted from ``z``."""
        return thops_split_feature(self.conv(z), "cross")

    def forward(self, input, logdet=0., reverse=False, eps_std=None):
        """Forward: drop the second half and add its log-probability to ``logdet``.

        Reverse: sample the missing half from the predicted Gaussian (noise
        std ``eps_std``) and concatenate it back; ``logdet`` is returned
        unchanged.
        """
        if reverse:
            kept = input
            mean, logs = self.split2d_prior(kept)
            resampled = GaussianDiag.sample(mean, logs, eps_std)
            return thops_cat_feature(kept, resampled), logdet
        kept, dropped = thops_split_feature(input, "split")
        mean, logs = self.split2d_prior(kept)
        logdet = GaussianDiag.logp(mean, logs, dropped) + logdet
        return kept, logdet


def squeeze2d(input, factor=2):
    """Trade spatial resolution for channels.

    Every ``factor x factor`` spatial block is moved into the channel
    dimension: ``(B, C, H, W) -> (B, C * factor^2, H / factor, W / factor)``.
    ``factor == 1`` returns the input unchanged.
    """
    assert factor >= 1 and isinstance(factor, int)
    if factor == 1:
        return input
    dims = input.size()
    batch, channels, height, width = dims[0], dims[1], dims[2], dims[3]
    assert height % factor == 0 and width % factor == 0, "{}".format((height, width))
    out_h, out_w = height // factor, width // factor
    # Expose the intra-block offsets as their own axes, then move them next to
    # the channel axis before flattening.
    blocks = input.view(batch, channels, out_h, factor, out_w, factor)
    blocks = blocks.permute(0, 1, 3, 5, 2, 4).contiguous()
    return blocks.view(batch, channels * factor * factor, out_h, out_w)


def unsqueeze2d(input, factor=2):
    '''
    Trade channels for spatial resolution (same layout as
    ``F.pixel_shuffle(input, factor)``):
    ``(B, C, H, W) -> (B, C / factor^2, H * factor, W * factor)``.
    ``factor == 1`` returns the input unchanged.
    '''
    assert factor >= 1 and isinstance(factor, int)
    block_area = factor ** 2
    if factor == 1:
        return input
    dims = input.size()
    batch, channels, height, width = dims[0], dims[1], dims[2], dims[3]
    assert channels % (block_area) == 0, "{}".format(channels)
    out_channels = channels // block_area
    # Split the channel axis into (channels, row offset, col offset), then
    # interleave the offsets with the spatial axes.
    blocks = input.view(batch, out_channels, factor, factor, height, width)
    blocks = blocks.permute(0, 1, 4, 2, 5, 3).contiguous()
    return blocks.view(batch, out_channels, height * factor, width * factor)


class SqueezeLayer(nn.Module):
    """Parameter-free layer wrapping ``squeeze2d`` (forward) / ``unsqueeze2d`` (reverse)."""

    def __init__(self, factor):
        super().__init__()
        self.factor = factor

    def forward(self, input, logdet=None, reverse=False):
        """Return ``(reshaped_input, logdet)``; ``logdet`` is passed through unchanged."""
        reshape = unsqueeze2d if reverse else squeeze2d
        return reshape(input, self.factor), logdet

def f(in_channels, out_channels, hidden_channels):
    '''
    Coupling network: 3x3 conv -> ReLU -> 1x1 conv -> ReLU -> zero-initialised
    3x3 conv producing ``out_channels`` channels.
    '''
    stages = [
        Conv2d(in_channels, hidden_channels),
        nn.ReLU(inplace=True),
        Conv2d(hidden_channels, hidden_channels, kernel_size=[1, 1]),
        nn.ReLU(inplace=True),
        Conv2dZeros(hidden_channels, out_channels),
    ]
    return nn.Sequential(*stages)

 
class FlowStep(nn.Module):
    """One step of flow: ActNorm -> channel permutation -> coupling layer.

    Args:
        in_channels: number of channels of the input (must be even).
        hidden_channels: width of the coupling network ``f``.
        actnorm_scale: target scale passed to ``ActNorm2d``.
        flow_permutation: ``"invconv"``, ``"shuffle"`` or ``"reverse"``.
        flow_coupling: ``"additive"`` or ``"affine"``.
        LU_decomposed: forwarded to ``InvertibleConv1x1``.

    ``forward(input, logdet, reverse)`` returns ``(output, logdet)``.  The
    coupling is computed out-of-place so that neither the caller's tensor nor
    tensors saved for the backward pass are modified.
    """

    FlowCoupling = ["additive", "affine"]
    # Dispatch table: how each permutation type is applied to (z, logdet).
    FlowPermutation = {
        "reverse": lambda obj, z, logdet, rev: (obj.reverse(z, rev), logdet),
        "shuffle": lambda obj, z, logdet, rev: (obj.shuffle(z, rev), logdet),
        "invconv": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev)
    }

    def __init__(self, in_channels, hidden_channels,
                 actnorm_scale=1.0,
                 flow_permutation="invconv",
                 flow_coupling="additive",
                 LU_decomposed=False):
        '''
        1. actnorm
        2. flow permutation (invconv / shuffle / reverse)
        3. split z into z1, z2
        4. coupling: additive  z2 = z2 + f(z1)
                     affine    z2 = (z2 + shift) * scale, (shift, scale) from f(z1)
        '''
        # check configures
        assert flow_coupling in FlowStep.FlowCoupling,\
            "flow_coupling should be in `{}`".format(FlowStep.FlowCoupling)
        assert flow_permutation in FlowStep.FlowPermutation,\
            "flow_permutation should be in `{}`".format(
                FlowStep.FlowPermutation.keys())
        super().__init__()
        self.flow_permutation = flow_permutation
        self.flow_coupling = flow_coupling
        # 1. actnorm
        self.actnorm = ActNorm2d(in_channels, actnorm_scale)
        # 2. permute
        if flow_permutation == "invconv":
            self.invconv = InvertibleConv1x1(
                in_channels, LU_decomposed=LU_decomposed)
        elif flow_permutation == "shuffle":
            self.shuffle = Permute2d(in_channels, shuffle=True)
        else:
            self.reverse = Permute2d(in_channels, shuffle=False)
        # 3. coupling: the affine variant needs shift AND scale, hence twice
        # as many output channels.
        if flow_coupling == "additive":
            self.f = f(in_channels // 2, in_channels // 2, hidden_channels)
        elif flow_coupling == "affine":
            self.f = f(in_channels // 2, in_channels, hidden_channels)

    def forward(self, input, logdet=None, reverse=False):
        if not reverse:
            return self.normal_flow(input, logdet)
        else:
            return self.reverse_flow(input, logdet)

    def _affine_params(self, z1):
        # Shift and (0, 1)-bounded scale predicted from the untouched half.
        h = self.f(z1)
        shift, scale = thops_split_feature(h, "cross")
        return shift, torch.sigmoid(scale + 2.)

    def normal_flow(self, input, logdet):
        """Data -> latent direction."""
        assert input.size(1) % 2 == 0
        # 1. actnorm
        z, logdet = self.actnorm(input, logdet=logdet, reverse=False)
        # 2. permute
        z, logdet = FlowStep.FlowPermutation[self.flow_permutation](
            self, z, logdet, False)
        # 3. coupling (out-of-place: z1 and z2 are views of the same tensor and
        # z1 is saved by f() for its backward pass)
        z1, z2 = thops_split_feature(z, "split")
        if self.flow_coupling == "additive":
            z2 = z2 + self.f(z1)
        elif self.flow_coupling == "affine":
            shift, scale = self._affine_params(z1)
            z2 = (z2 + shift) * scale
            if logdet is not None:
                logdet = thops_sum(torch.log(scale), dim=[1, 2, 3]) + logdet
        z = thops_cat_feature(z1, z2)
        return z, logdet

    def reverse_flow(self, input, logdet):
        """Latent -> data direction (exact inverse of ``normal_flow``)."""
        assert input.size(1) % 2 == 0
        # 1.coupling (out-of-place: z2 is a view of the caller's `input`)
        z1, z2 = thops_split_feature(input, "split")
        if self.flow_coupling == "additive":
            z2 = z2 - self.f(z1)
        elif self.flow_coupling == "affine":
            shift, scale = self._affine_params(z1)
            z2 = z2 / scale - shift
            if logdet is not None:
                logdet = -thops_sum(torch.log(scale), dim=[1, 2, 3]) + logdet
        z = thops_cat_feature(z1, z2)
        # 2. permute
        z, logdet = FlowStep.FlowPermutation[self.flow_permutation](
            self, z, logdet, True)
        # 3. actnorm
        z, logdet = self.actnorm(z, logdet=logdet, reverse=True)
        return z, logdet


class FlowNet(nn.Module):
    """Multi-scale stack of flow layers.

    For each of the ``L`` levels: one ``SqueezeLayer`` (factor 2), ``K``
    ``FlowStep`` layers and - on every level but the last - a ``Split2d``.
    ``output_shapes`` records ``[-1, C, H, W]`` after each layer, in the same
    order as ``layers``.
    """

    def __init__(self, image_shape, hidden_channels, K, L,
                 actnorm_scale=1.0,
                 flow_permutation="invconv",
                 flow_coupling="additive",
                 LU_decomposed=False):
        """
                             K                                      K
        --> [Squeeze] -> [FlowStep] -> [Split] -> [Squeeze] -> [FlowStep]
               ^                           v
               |          (L - 1)          |
               + --------------------------+
        """
        super().__init__()
        self.layers = nn.ModuleList()
        self.output_shapes = []
        self.K = K
        self.L = L
        channels, height, width = image_shape
        assert channels == 1 or channels == 3, ("image_shape should be CHW, like (3, 64, 64)"
                                                "C == 1 or C == 3")
        last_level = L - 1
        for level in range(L):
            # 1. Squeeze: x4 channels, half the resolution.
            channels, height, width = channels * 4, height // 2, width // 2
            self._append(SqueezeLayer(factor=2), channels, height, width)
            # 2. K flow steps (shape-preserving).
            for _ in range(K):
                step = FlowStep(in_channels=channels,
                                hidden_channels=hidden_channels,
                                actnorm_scale=actnorm_scale,
                                flow_permutation=flow_permutation,
                                flow_coupling=flow_coupling,
                                LU_decomposed=LU_decomposed)
                self._append(step, channels, height, width)
            # 3. Split: halve the channels on every level except the last.
            if level < last_level:
                split = Split2d(num_channels=channels)
                channels = channels // 2
                self._append(split, channels, height, width)

    def _append(self, layer, channels, height, width):
        # Register a layer together with the shape of its output.
        self.layers.append(layer)
        self.output_shapes.append([-1, channels, height, width])

    def forward(self, input, logdet=0., reverse=False, eps_std=None):
        """Encode ``input`` (returns ``(z, logdet)``) or decode it (returns ``x`` only)."""
        if reverse:
            return self.decode(input, eps_std)
        return self.encode(input, logdet)

    def encode(self, z, logdet=0.0):
        """Run every layer in order, accumulating ``logdet``."""
        for layer in self.layers:
            z, logdet = layer(z, logdet, reverse=False)
        return z, logdet

    def decode(self, z, eps_std=None):
        """Run every layer backwards; ``Split2d`` layers re-sample with ``eps_std``."""
        for layer in reversed(self.layers):
            if isinstance(layer, Split2d):
                z, _ = layer(z, logdet=0, reverse=True, eps_std=eps_std)
            else:
                z, _ = layer(z, logdet=0, reverse=True)
        return z


class Glow(nn.Module):
    """Glow model: a ``FlowNet`` plus a (optionally class-conditional) Gaussian prior.

    Args:
        hparams: configuration object exposing the ``Glow`` and ``Train``
            attribute groups read below (see ``Parser``).

    Notes:
        * ``prior_h`` is allocated with ``hparams.Train.batch_size`` rows, so
          every forward / reverse call must use exactly that batch size.
        * All helper methods read ``self.hparams`` (the object given to the
          constructor) rather than a notebook-level global.
    """

    BCE = nn.BCEWithLogitsLoss()
    CE = nn.CrossEntropyLoss()

    def __init__(self, hparams):
        super().__init__()
        self.flow = FlowNet(image_shape=hparams.Glow.image_shape,
                            hidden_channels=hparams.Glow.hidden_channels,
                            K=hparams.Glow.K,
                            L=hparams.Glow.L,
                            actnorm_scale=hparams.Glow.actnorm_scale,
                            flow_permutation=hparams.Glow.flow_permutation,
                            flow_coupling=hparams.Glow.flow_coupling,
                            LU_decomposed=hparams.Glow.LU_decomposed)
        self.hparams = hparams
        self.y_classes = hparams.Glow.y_classes
        # Shape of the final latent produced by the flow: [-1, C, H, W].
        _, top_c, top_h, top_w = self.flow.output_shapes[-1]
        # for prior
        if hparams.Glow.learn_top:
            self.learn_top = Conv2dZeros(top_c * 2, top_c * 2)
        if hparams.Glow.y_condition:
            # class -> (mean, logs) offset of the prior, and latent -> class logits
            self.project_ycond = LinearZeros(
                hparams.Glow.y_classes, 2 * top_c)
            self.project_class = LinearZeros(
                top_c, hparams.Glow.y_classes)
        # Prior template: 2*C channels so it can be split into mean and logs.
        self.register_parameter(
            "prior_h",
            nn.Parameter(torch.zeros([hparams.Train.batch_size,
                                      top_c * 2,
                                      top_h,
                                      top_w])))

    def prior(self, y_onehot=None):
        """Return ``(mean, logs)`` of the prior over the final latent.

        Starts from a detached copy of ``prior_h`` (asserted to be all zeros),
        optionally passes it through ``learn_top`` and, when conditioning on
        the class, adds the projection of ``y_onehot``.
        """
        B, C = self.prior_h.size(0), self.prior_h.size(1)
        h = self.prior_h.detach().clone()
        assert torch.sum(h) == 0.0
        if self.hparams.Glow.learn_top:
            h = self.learn_top(h)
        if self.hparams.Glow.y_condition:
            assert y_onehot is not None
            yp = self.project_ycond(y_onehot).view(B, C, 1, 1)
            h = h + yp
        return thops_split_feature(h, "split")

    def forward(self, x=None, y_onehot=None, z=None,
                eps_std=None, reverse=False):
        """Encode ``x`` (default) or, with ``reverse=True``, decode ``z`` / sample."""
        if not reverse:
            return self.normal_flow(x, y_onehot)
        else:
            return self.reverse_flow(z, y_onehot, eps_std)

    def normal_flow(self, x, y_onehot):
        '''
        Forward pass:

        1. add Gaussian noise (std 1/256) to the input
        2. initialise the log-det with -log(256) * pixels
        3. run the flow and add the prior log-probability of the latent

        Returns ``(z, nll, y_logits)`` where ``nll`` is the negated objective
        divided by ``log(2) * pixels`` and ``y_logits`` is None without class
        conditioning.

        NOTE(review): ``pixels`` is H * W only; the usual bits-per-dimension
        normalisation (and dequantisation offset) uses C * H * W.  Left as in
        the original so reported losses stay comparable - confirm intent.
        '''
        pixels = thops_pixels(x)
        z = x + torch.normal(mean=torch.zeros_like(x),
                             std=torch.ones_like(x) * (1. / 256.))
        logdet = torch.zeros_like(x[:, 0, 0, 0])
        logdet = logdet + float(-np.log(256.) * pixels)
        # encode
        z, objective = self.flow(z, logdet=logdet, reverse=False)
        # prior log-probability of the latent
        mean, logs = self.prior(y_onehot)
        objective = objective + GaussianDiag.logp(mean, logs, z)

        if self.hparams.Glow.y_condition:
            # classify from the spatially averaged latent
            y_logits = self.project_class(z.mean(2).mean(2))
        else:
            y_logits = None

        nll = (-objective) / float(np.log(2.) * pixels)
        return z, nll, y_logits

    def reverse_flow(self, z, y_onehot, eps_std):
        """Decode ``z`` to image space; ``z=None`` samples it from the prior.

        Runs under ``torch.no_grad()``.
        """
        with torch.no_grad():
            mean, logs = self.prior(y_onehot)
            if z is None:
                z = GaussianDiag.sample(mean, logs, eps_std)
            x = self.flow(z, eps_std=eps_std, reverse=True)
        return x

    def set_actnorm_init(self, inited=True):
        """Set the ``inited`` flag of every sub-module whose class name contains "ActNorm"."""
        for name, m in self.named_modules():
            if (m.__class__.__name__.find("ActNorm") >= 0):
                m.inited = inited

    def sample(self, temperatures=(0.25,)):
        """Generate one batch of images per temperature.

        Class labels cycle ``0, 1, ..., y_classes - 1`` across the batch.

        Returns:
            Tensor of shape ``[len(temperatures), batch_size, C, H, W]``.
        """
        #[0., .25, .5, .6, .7, .8, .9, 1.]
        batch_size = self.hparams.Train.batch_size
        n_classes = self.hparams.Glow.y_classes
        # Build labels for exactly batch_size rows (prior_h fixes the batch
        # size) on whatever device the model lives on.
        y = torch.arange(batch_size, device=self.prior_h.device) % n_classes
        y_onehot = thops_onehot(y, n_classes)
        return torch.stack([self.reverse_flow(None, y_onehot, t)
                            for t in temperatures])

    @staticmethod
    def loss_generative(nll):
        # Generative loss
        return torch.mean(nll)

    @staticmethod
    def loss_multi_classes(y_logits, y_onehot):
        """Multi-label BCE-with-logits loss; 0 when there are no logits."""
        if y_logits is None:
            return 0
        else:
            return Glow.BCE(y_logits, y_onehot.float())

    @staticmethod
    def loss_class(y_logits, y):
        """Single-label cross-entropy loss; 0 when there are no logits."""
        if y_logits is None:
            return 0
        else:
            return Glow.CE(y_logits, y.long())

    def _artifact_path(self, folder, step, extension):
        # ./<model>/<folder>/<model>_Step_<step><extension>
        name = self.hparams.Glow.model_name
        return './' + name + '/' + folder + '/' + name + '_Step_' + str(step) + extension

    def image_save(self, img, step):
        """Save ``img`` as an 8-per-row grid, mapping values from [-1, 1] to [0, 1]."""
        path = self._artifact_path('Img', step, '.png')
        try:
            # Current torchvision name of the normalisation-range argument.
            save_image(img, path, nrow=8, normalize=True, value_range=(-1, 1))
        except TypeError:
            # Older torchvision releases only know the argument as `range`.
            save_image(img, path, nrow=8, normalize=True, range=(-1, 1))
        print('Image saved')

    def model_save(self, step):
        """Save the state dict under the model name as key."""
        path = self._artifact_path('Model', step, '.pth')
        torch.save({self.hparams.Glow.model_name: self.state_dict()}, path)
        print('Model saved')

    def load_step_dict(self, step):
        """Load the checkpoint written by ``model_save(step)`` (tensors mapped to CPU)."""
        path = self._artifact_path('Model', step, '.pth')
        checkpoint = torch.load(path, map_location=lambda storage, loc: storage)
        self.load_state_dict(checkpoint[self.hparams.Glow.model_name])
        # The ActNorm `inited` flag is not stored in the state dict; without
        # this the next training-mode forward would re-run the data-dependent
        # initialisation and overwrite the loaded bias / logs.
        self.set_actnorm_init(True)

    def num_all_params(self,):
        """Total number of elements over all parameters."""
        return sum(param.nelement() for param in self.parameters())

    def plot_all_loss(self, train_hist, step):
        """Plot every series in ``train_hist`` and save it as ``<model>_Loss_<step>.png``."""
        fig, ax = plt.subplots(figsize=(20, 8))
        for label, values in train_hist.items():
            ax.plot(values, label=label)
        ax.set_ylabel('Loss', fontsize=15)
        ax.set_xlabel('Number of Steps', fontsize=15)
        ax.set_title('Loss', fontsize=30, fontweight="bold")
        ax.legend(loc='upper left')
        fig.savefig(self.hparams.Glow.model_name + "_Loss_" + str(step) + ".png")

In [2]:
# Tensor <-> PIL helpers used for previews.
to_img = T.Compose([T.ToPILImage()])
to_tensor = T.Compose([T.ToTensor()])

# One mean/std entry per image channel, so the same pipeline is valid for
# 1-channel MNIST as well as 3-channel CelebA (a fixed 3-value Normalize does
# not match single-channel tensors).
_norm_stats = (0.5,) * hparams.Glow.in_channels
load_norm = T.Compose([#T.CenterCrop(hparams.Data.img_size,hparams.Data.img_size),
                       T.Resize((hparams.Data.img_size, hparams.Data.img_size)),
                       T.ToTensor(),
                       T.Normalize(_norm_stats, _norm_stats)
                      ])

# drop_last=True everywhere: the model's prior is allocated for exactly
# hparams.Train.batch_size samples.
if hparams.Data.name == 'celeba':
    dataset = CelebADataset(mode='train')
    training_loader = DataLoader(dataset,
                                 batch_size=hparams.Train.batch_size,
                                 shuffle=True,
                                 drop_last=True,
                                 pin_memory=True)
elif hparams.Data.name == 'mnist':
    training_loader = DataLoader(datasets.MNIST('./data/mnist', train=True,
                                                download=True, transform=load_norm),
                                 batch_size=hparams.Train.batch_size,
                                 shuffle=True,
                                 drop_last=True)

Loaded


In [3]:
# Training state shared with the training-loop cell.
epoch = 0
all_steps = 1  # global step counter, 1-based
# Per-step history of the generative (G) and classification (C) losses.
train_hist = {'G_Loss': [], 'C_Loss': []}
glow = Glow(hparams)
glow = glow.to(device)

In [4]:
# Adam over all model parameters, configured from hparams.Optim.
adam_betas = (hparams.Optim.beta1, hparams.Optim.beta2)
optimizer = optim.Adam(glow.parameters(), lr=hparams.Optim.lr, betas=adam_betas)
# Halves the learning rate every 10000 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.5)

In [5]:
# optimizer.param_groups[0]['lr']

In [6]:
# Hard stop for this run.  The loop exits cleanly once this many steps were
# taken instead of aborting the cell with an exception.
MAX_STEPS = 1000
reached_step_limit = False

while epoch < hparams.Train.n_epoch and not reached_step_limit:
    for i, (img, y) in enumerate(training_loader):
        start_t = time.time()
        optimizer.zero_grad()
        img = img.to(device)
        y = y.to(device)
        # y already lives on `device`, and so does its one-hot encoding.
        y_onehot = thops_onehot(y, hparams.Glow.y_classes)
        z, nll, y_logits = glow(img, y_onehot)
        loss_generative = glow.loss_generative(nll)
        loss_classes = glow.loss_class(y_logits, y)
        loss = loss_generative + loss_classes * hparams.Train.weight_y
        loss.backward()

        # Apply the gradient limits declared in hparams.Train (they were
        # defined but never used); first element-wise, then on the global norm.
        if hparams.Train.max_grad_clip is not None and hparams.Train.max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(glow.parameters(), hparams.Train.max_grad_clip)
        if hparams.Train.max_grad_norm is not None and hparams.Train.max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(glow.parameters(), hparams.Train.max_grad_norm)

        optimizer.step()
        # NOTE(review): `scheduler` is created but scheduler.step() was left
        # disabled in the original loop; it is kept disabled here - confirm.
        #scheduler.step()
        end_t = time.time()

        g_loss_value = loss_generative.item()
        # loss_class returns the plain int 0 when y_condition is off; float()
        # handles both that and a 0-dim tensor.
        c_loss_value = float(loss_classes)
        train_hist['G_Loss'].append(g_loss_value)
        train_hist['C_Loss'].append(c_loss_value)

        print('| Step [%d] | lr [%.4f] | G_Loss: [%.3f] | C_Loss: [%.3f] | Time: %.1fs' %\
              ( all_steps, optimizer.param_groups[0]['lr'], g_loss_value, c_loss_value,
               end_t - start_t))

        # Preview / save / checkpoint.  The checks are nested, so the save
        # frequencies only fire on steps that are also multiples of the
        # enclosing frequency.
        if all_steps % hparams.Train.img_show_freq == 0:
            fig = plt.figure(figsize=(8, 8))
            fig.add_subplot(1,3,1)
            sample_img = glow.sample().squeeze(0)
            # Map [-1, 1] -> [0, 1] and clamp: samples can fall outside the
            # range, which would wrap around in the uint8 conversion.
            preview = (sample_img[0].cpu()*0.5+0.5).clamp(0, 1)
            plt.imshow(to_img(preview))
            plt.show()
            if all_steps % hparams.Train.img_save_freq == 0:
                glow.plot_all_loss(train_hist, 'Training')
                glow.image_save(sample_img,all_steps)
                if all_steps % hparams.Train.model_save_freq == 0:
                    glow.model_save(all_steps)

        all_steps += 1
        if all_steps > MAX_STEPS:
            reached_step_limit = True
            break
    else:
        # Only count the epoch when the loader was fully consumed.
        epoch += 1


| Step [1] | lr [0.0010] | G_Loss: [3.393] | C_Loss: [0.693] | Time: 102.6s
| Step [2] | lr [0.0010] | G_Loss: [33.711] | C_Loss: [0.686] | Time: 93.6s


KeyboardInterrupt: 