In [None]:
import os
import torch
import torch.nn as nn
import torch
import cv2
import numpy as np
import torch.optim as optim
import torch.distributions as distributions
import torchvision
import torchvision.utils as utils 
import torchvision.transforms as transforms
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Directory that the training loop writes generated sample grids into.
# os.makedirs with exist_ok=True is portable and safe to re-run.
os.makedirs("samples", exist_ok=True)

In [None]:
class res_block(torch.nn.Module):
    """Residual block: three (BatchNorm -> ReLU -> weight-normalised conv) stages plus identity skip.

    The convolutions are 1x1, 3x3 (padding 1) and 1x1, all mapping ``dim`` to
    ``dim`` channels with stride 1, so the output has the same shape as the input.

    The weight-norm gain (``weight_g``) of the first two convolutions is pinned
    to one and excluded from training; each of them is immediately followed by a
    BatchNorm layer (presumably this makes a learnable gain redundant). The gain
    of the last convolution stays trainable.

    Args:
        dim: number of input and output channels.
    """

    def __init__(self,dim):
        super(res_block,self).__init__()
        self.layer1=nn.BatchNorm2d(dim)
        self.residual_block_layer1=nn.utils.weight_norm(nn.Conv2d(dim,dim,(1,1),stride=1,padding=0,bias=False))
        self.residual_block_layer2= nn.BatchNorm2d(dim)
        self.residual_block_layer3=nn.utils.weight_norm(nn.Conv2d(dim,dim,(3,3),stride=1,padding=1,bias=False))
        self.residual_block_layer4=nn.BatchNorm2d(dim)
        self.residual_block_layer5=nn.utils.weight_norm(nn.Conv2d(dim,dim,(1,1),stride=1,padding=0,bias=True))
        # Pin the fixed gains at construction time so that code which filters
        # parameters on requires_grad (regularisers, optimisers) sees them as
        # frozen before the first forward pass, not only after it.
        self._pin_unit_gain()

    def _pin_unit_gain(self):
        """Set the gains of the two BatchNorm-followed convolutions to one and freeze them."""
        for conv in (self.residual_block_layer1,self.residual_block_layer3):
            conv.weight_g.data=torch.ones_like(conv.weight_g)
            conv.weight_g.requires_grad=False

    def forward(self,x):
        """Return ``x + F(x)`` where ``F`` is the three-stage convolutional branch."""
        # Re-pin on every call so the gains stay at one even if something
        # overwrote them in the meantime (for example loading a state dict).
        self._pin_unit_gain()

        op=F.relu(self.layer1(x))
        op=self.residual_block_layer1(op)

        op=F.relu(self.residual_block_layer2(op))
        op=self.residual_block_layer3(op)

        op=F.relu(self.residual_block_layer4(op))
        op=self.residual_block_layer5(op)
        return (x+op)

In [None]:
class residual_module(torch.nn.Module):
    """Stack of four residual blocks with summed 1x1 skip projections.

    Pipeline: a 3x3 weight-normalised conv lifts ``in_dim`` to ``res_dim``
    channels; four ``res_block`` layers follow. The output of the input conv and
    of every residual block is passed through its own 1x1 conv and the five
    results are summed. The sum goes through BatchNorm, ReLU and a final 1x1
    conv producing ``out_dim`` channels. Spatial size is preserved throughout.

    The weight-norm gain of the input convolution is pinned to one and frozen.

    Args:
        in_dim: channels of the input tensor.
        res_dim: channels used inside the residual stack.
        out_dim: channels of the returned tensor.
    """

    def __init__(self,in_dim,res_dim,out_dim):
        super(residual_module,self).__init__()
        self.layer1_1=nn.utils.weight_norm(nn.Conv2d(in_dim,res_dim,(3,3),stride=1,padding=1,bias=True))
        self.layer1_2=res_block(res_dim)
        self.layer1_3=res_block(res_dim)
        self.layer1_4=res_block(res_dim)
        self.layer1_5=res_block(res_dim)
        self.layer1_6=nn.BatchNorm2d(res_dim)
        self.layer1_7=nn.utils.weight_norm(nn.Conv2d(res_dim,out_dim,(1,1),stride=1,padding=0,bias=True))

        # One 1x1 skip projection per stage (input conv + four residual blocks).
        self.layer2_1=nn.utils.weight_norm(nn.Conv2d(res_dim,res_dim,(1,1),stride=1,padding=0,bias=True))
        self.layer2_2=nn.utils.weight_norm(nn.Conv2d(res_dim,res_dim,(1,1),stride=1,padding=0,bias=True))
        self.layer2_3=nn.utils.weight_norm(nn.Conv2d(res_dim,res_dim,(1,1),stride=1,padding=0,bias=True))
        self.layer2_4=nn.utils.weight_norm(nn.Conv2d(res_dim,res_dim,(1,1),stride=1,padding=0,bias=True))
        self.layer2_5=nn.utils.weight_norm(nn.Conv2d(res_dim,res_dim,(1,1),stride=1,padding=0,bias=True))

        # Freeze the input-conv gain at construction so that filters on
        # requires_grad see it as frozen before the first forward pass.
        self._pin_unit_gain()

    def _pin_unit_gain(self):
        """Set the gain of the input convolution to one and exclude it from training."""
        self.layer1_1.weight_g.data=torch.ones_like(self.layer1_1.weight_g)
        self.layer1_1.weight_g.requires_grad=False

    def forward(self,x):
        """Map a ``(B, in_dim, H, W)`` tensor to ``(B, out_dim, H, W)``."""
        # Re-pin on every call so the gain stays at one even if it was overwritten.
        self._pin_unit_gain()

        x=self.layer1_1(x)
        skip_sum=self.layer2_1(x)

        stages=((self.layer1_2,self.layer2_2),
                (self.layer1_3,self.layer2_3),
                (self.layer1_4,self.layer2_4),
                (self.layer1_5,self.layer2_5))
        for block,projection in stages:
            x=block(x)
            skip_sum=skip_sum+projection(x)

        # Only the summed skip projections feed the output head.
        x=F.relu(self.layer1_6(skip_sum))
        op=self.layer1_7(x)
        return(op)

NameError: ignored

In [None]:
def checkerboard_mask(config,size):
    """Build a ``(1, 1, size, size)`` float32 checkerboard mask.

    Args:
        config: ``1`` puts ones on cells whose row+column index is even (so the
            top-left cell is one); any other value yields the complementary
            pattern.
        size: side length of the square mask.

    Returns:
        torch.Tensor of zeros and ones with dtype float32.
    """
    rows,cols=np.indices((size,size))
    odd_cells=(rows+cols)%2
    board=(1-odd_cells) if config == 1 else odd_cells
    board=board.astype(np.float32).reshape(-1,1,size,size)
    return torch.tensor(board)

def mean_var(x):
  """Per-channel mean and biased variance of a 4-D tensor.

  Statistics are reduced over dimensions 0, 2 and 3 and returned with those
  dimensions kept as size one, i.e. with shape ``(1, C, 1, 1)``.
  """
  reduce_dims=(0, 2, 3)
  mean=x.mean(dim=reduce_dims, keepdim=True)
  centred=x - mean
  # Biased estimator: E[(x - E[x])^2]
  var=centred.pow(2).mean(dim=reduce_dims, keepdim=True)
  return mean,var

In [None]:
# Scratch cell: standalone 5-channel BatchNorm2d, used only to inspect its buffers.
p=nn.BatchNorm2d(5)

NameError: ignored

In [None]:
# Display the running-mean buffer of the BatchNorm2d layer `p`.
p.running_mean

In [None]:
class checkerboard_coupling(torch.nn.Module):
    """Affine coupling layer that uses a spatial checkerboard mask.

    Positions where the mask is one pass through unchanged and condition the
    scale ``s`` and translation ``t`` applied to the positions where the mask is
    zero. The transformed positions are then batch-normalised by ``layer3``
    (no affine parameters) and the normalisation is included in the log-determinant.

    Args:
        in_dim: channels of the input tensor.
        mid_dim: channel width of the residual conditioner network.
        out_dim: channels of ``s`` and of ``t``; the split in the conditioner
            uses the input's channel count, so this is expected to equal ``in_dim``.
        size: spatial side length of the (square) input.
        config: mask parity, passed to ``checkerboard_mask``.
    """

    def __init__(self,in_dim,mid_dim,out_dim,size,config):
        super(checkerboard_coupling,self).__init__()
        self.size=size
        self.config=config
        # Registered as a non-persistent buffer: it follows Module.to()/.cuda()
        # together with the parameters but stays out of the state dict, so
        # checkpoints keep the same keys as before.
        self.register_buffer("mask",checkerboard_mask(self.config,self.size).to(device),persistent=False)
        # Learnable rescaling and offset applied to tanh(s); both start at zero.
        self.scale=torch.nn.Parameter(torch.zeros(1),requires_grad=True)
        self.scale_shift=torch.nn.Parameter((torch.zeros(1)),requires_grad=True)
        self.layer1=nn.BatchNorm2d(in_dim)
        # Conditioner input: normalised features, their negation, and the mask.
        self.layer2=residual_module(2*in_dim+1,mid_dim,2*out_dim)
        self.layer3=nn.BatchNorm2d(out_dim,affine=False)

    def _scale_and_shift(self,x):
        """Compute ``(mask, s, t)`` from the masked-in part of ``x``.

        ``s`` and ``t`` are zeroed at the positions the mask keeps, so those
        positions are left untouched by the affine map.
        """
        mask=self.mask.repeat(x.size(0),1,1,1)
        h=self.layer1(x*mask)
        # Features and their negation are both fed to the ReLU, so positive and
        # negative parts are each preserved (the original author notes that the
        # official implementation does this).
        h=torch.cat((h,-h),dim=1)
        h=torch.cat((h,mask),dim=1)
        h=F.relu(h)
        t,s=self.layer2(h).split(x.size(1),dim=1)
        s=self.scale*torch.tanh(s) + self.scale_shift
        keep_out=1.-mask
        return mask,s*keep_out,t*keep_out

    def forward(self,x):
        """Apply the coupling transform.

        Returns:
            ``(y, log_det_jacobian)`` where the second element is the
            element-wise log-determinant contribution, same shape as ``y``.
        """
        mask,s,t=self._scale_and_shift(x)

        log_det_jacobian=s
        x=x*torch.exp(s) + t
        if self.training:
            # Batch statistics, matching what layer3 normalises with in train mode.
            _,var=mean_var(x)
        else:
            var=self.layer3.running_var.reshape(1,-1,1,1)
        # Only the transformed positions are normalised; the rest pass through.
        x=self.layer3(x) * (1. - mask) + x * mask
        # 1e-5 matches the default eps of BatchNorm2d.
        log_det_jacobian=log_det_jacobian - 0.5 * torch.log(var + 1e-5) * (1. - mask)
        return (x,log_det_jacobian)

    def backward(self,x):
        """Invert ``forward`` using the running statistics of ``layer3``.

        Returns:
            ``(x, s)``. Note that the second element is the forward log-scale,
            not the log-determinant of the inverse; callers in this file
            discard it.
        """
        mask,s,t=self._scale_and_shift(x)

        log_det_jacobian=s
        mean=self.layer3.running_mean.reshape(1,-1,1,1)
        var=self.layer3.running_var.reshape(1,-1,1,1)
        # Undo the batch normalisation on the transformed positions only.
        x = x * torch.exp(0.5 * torch.log(var + 1e-5) * (1. - mask)) + mean * (1. - mask)
        # Undo the affine map.
        x=(x-t)*torch.exp(-s)
        return (x,log_det_jacobian)

In [None]:
#The input to the channel-wise masking layer will be a squeezed tensor with size=(b,c*4,h/2,w/2)
class channel_coupling(torch.nn.Module):
    """Affine coupling layer that splits the input along the channel axis.

    The input is cut into two halves along dim 1. One half is left unchanged
    and conditions the scale ``s`` and translation ``t`` applied to the other
    half; the transformed half is then batch-normalised by ``layer3`` (no affine
    parameters).

    Args:
        in_dim: channels of the input tensor (expected to be even).
        mid_dim: channel width of the residual conditioner network.
        out_dim: channels produced by the conditioner; it is split into ``t``
            and ``s`` of ``in_dim // 2`` channels each.
        config: ``1`` transforms the first half of the channels, any other
            value transforms the second half.
    """

    def __init__(self,in_dim,mid_dim,out_dim,config=1):
        super(channel_coupling,self).__init__()
        self.config=config
        # Learnable rescaling and offset applied to tanh(s); both start at zero.
        self.scale=torch.nn.Parameter(torch.zeros(1),requires_grad=True)
        self.scale_shift=torch.nn.Parameter((torch.zeros(1)),requires_grad=True)
        self.layer1=nn.BatchNorm2d(in_dim//2)
        self.layer2=residual_module(in_dim,mid_dim,out_dim)
        self.layer3=nn.BatchNorm2d(in_dim//2,affine=False)

    def _split(self,x):
        """Return ``(to_transform, conditioner)`` halves according to ``config``."""
        half=x.size(1)//2
        if self.config == 1:
            x1,x2=x.split(half,dim=1)
        else:
            x2,x1=x.split(half,dim=1)
        return x1,x2

    def _scale_and_shift(self,x2):
        """Compute ``(s, t)`` from the unchanged half ``x2``."""
        h=self.layer1(x2)
        # Features and their negation are both fed to the ReLU.
        h=torch.cat((h,-h),dim=1)
        h=F.relu(h)
        t,s=self.layer2(h).split(x2.size(1),dim=1)
        s=self.scale*torch.tanh(s) + self.scale_shift
        return s,t

    def _merge(self,x1,x2,log_det_jacobian_tr):
        """Reassemble both halves in their original channel order.

        The unchanged half contributes zeros to the log-determinant. The zeros
        are created with ``zeros_like`` so they share device and dtype with the
        transformed half instead of depending on the global ``device``.
        """
        untouched=torch.zeros_like(log_det_jacobian_tr)
        if self.config == 1 :
            op=torch.cat((x1,x2),dim=1)
            log_det_jacobian=torch.cat((log_det_jacobian_tr,untouched),dim=1)
        else:
            op=torch.cat((x2,x1),dim=1)
            log_det_jacobian=torch.cat((untouched,log_det_jacobian_tr),dim=1)
        return (op,log_det_jacobian)

    def forward(self,x):
        """Apply the coupling transform.

        Returns:
            ``(y, log_det_jacobian)``, both with the same shape as ``x``.
        """
        x1,x2=self._split(x)
        s,t=self._scale_and_shift(x2)

        x1=x1*torch.exp(s) + t
        if self.training:
            # Batch statistics, matching what layer3 normalises with in train mode.
            _,var=mean_var(x1)
        else:
            var=self.layer3.running_var.reshape(1,-1,1,1)
        x1=self.layer3(x1)
        # 1e-5 matches the default eps of BatchNorm2d.
        log_det_jacobian_tr=s - 0.5 * torch.log(var + 1e-5)
        return self._merge(x1,x2,log_det_jacobian_tr)

    def backward(self,x):
        """Invert ``forward`` using the running statistics of ``layer3``.

        Returns:
            ``(x, log_det)`` where the second element carries the forward
            log-scale ``s`` for the transformed half (not the log-determinant of
            the inverse); callers in this file discard it.
        """
        x1,x2=self._split(x)
        s,t=self._scale_and_shift(x2)

        mean=self.layer3.running_mean.reshape(1,-1,1,1)
        var=self.layer3.running_var.reshape(1,-1,1,1)
        # Undo the batch normalisation, then the affine map.
        x1=x1 * torch.exp(0.5 * torch.log(var + 1e-5)) + mean
        x1=(x1-t) * torch.exp(-s)
        return self._merge(x1,x2,s)

In [None]:
class real_nvp(torch.nn.Module):
    """Two-scale RealNVP normalising flow built from coupling layers.

    Scale 1 applies three checkerboard couplings at full resolution, squeezes
    the tensor to ``(B, 4C, H/2, W/2)``, applies three channel couplings and
    unsqueezes again. A fixed stride-2 convolution then rearranges pixels into
    channels and half of those channels are factored out. Scale 2 applies four
    checkerboard couplings to the remaining half at half resolution. Finally
    both halves are put back into image layout.

    Args:
        prior_dist: distribution over latents; must provide ``log_prob`` and ``sample``.
        in_dim: channels of the input images.
        mid_dim: channel width of the conditioner networks at scale 1.
        out_dim: channels of the scale/translation outputs at scale 1.
        size: spatial side length of the (square) input images.
    """

    def __init__(self,prior_dist,in_dim,mid_dim,out_dim,size):
        super(real_nvp,self).__init__()
        self.prior=prior_dist
        self.checkerboard_coupling_layer1=nn.ModuleList([checkerboard_coupling(in_dim,mid_dim,out_dim,size,config=1),
                                                      checkerboard_coupling(in_dim,mid_dim,out_dim,size,config=0),
                                                      checkerboard_coupling(in_dim,mid_dim,out_dim,size,config=1)])

        self.channel_coupling_layer1=nn.ModuleList([channel_coupling(in_dim*4,mid_dim*2,out_dim*4,config=0),
                                             channel_coupling(in_dim*4,mid_dim*2,out_dim*4,config=1),
                                             channel_coupling(in_dim*4,mid_dim*2,out_dim*4,config=0)])

        # Fixed (non-learned) rearrangement kernel. Registered as a
        # non-persistent buffer so it follows Module.to()/.cuda() without
        # adding a key to the state dict.
        self.register_buffer("order_matrix_1",self.order_matrix(in_dim).to(device),persistent=False)

        # Second scale works on half of the 4x channels at half the resolution.
        in_dim=in_dim*2
        out_dim=out_dim*2
        mid_dim=mid_dim*2
        size=size//2

        self.checkerboard_coupling_layer2=nn.ModuleList([checkerboard_coupling(in_dim,mid_dim,out_dim,size,config=1),
                                                         checkerboard_coupling(in_dim,mid_dim,out_dim,size,config=0),
                                                         checkerboard_coupling(in_dim,mid_dim,out_dim,size,config=1),
                                                         checkerboard_coupling(in_dim,mid_dim,out_dim,size,config=0)])

    def squeeze(self,x):
        """Reshape ``(B, C, H, W)`` to ``(B, 4C, H/2, W/2)`` by moving 2x2 blocks into channels."""
        [B, C, H, W] = list(x.size())
        x = x.reshape(B, C, H//2, 2, W//2, 2)
        x = x.permute(0, 1, 3, 5, 2, 4)
        x = x.reshape(B, C*4, H//2, W//2)
        return x

    def unsqueeze(self,x):
        """Inverse of ``squeeze``: reshape ``(B, 4C, H, W)`` to ``(B, C, 2H, 2W)``."""
        b,c,h,w=x.size()
        x=x.reshape(b,c//4,2,2,h,w)
        x=x.permute(0,1,4,2,5,3)
        x=x.reshape(b,c//4,h*2,w*2)
        return (x)

    #Reference : https://github.com/tensorflow/models/blob/master/research/real_nvp/real_nvp_utils.py
    def order_matrix(self,channel):
        """Build the ``(4*channel, channel, 2, 2)`` one-hot kernel used to factor out dimensions.

        Each output channel picks exactly one pixel of a 2x2 block of one input
        channel. Output channels are ordered by position within the block first
        and by input channel second.
        """
        weights = np.zeros((channel*4, channel, 2, 2))
        ordering = np.array([[[[1., 0.],
                               [0., 0.]]],
                             [[[0., 0.],
                               [0., 1.]]],
                             [[[0., 1.],
                               [0., 0.]]],
                             [[[0., 0.],
                               [1., 0.]]]])
        for i in range(channel):
            s1 = slice(i, i+1)
            s2 = slice(4*i, 4*(i+1))
            weights[s2, s1, :, :] = ordering
        shuffle = np.array([4*i for i in range(channel)]
                         + [4*i+1 for i in range(channel)]
                         + [4*i+2 for i in range(channel)]
                         + [4*i+3 for i in range(channel)])
        weights = weights[shuffle, :, :, :].astype('float32')
        return torch.tensor(weights)

    def inference(self,x):
        """Map data ``x`` to latents.

        Returns:
            ``(z, log_det_jacobian)``, both shaped like ``x``; the second element
            holds element-wise log-determinant contributions.
        """
        # Scale 1: checkerboard couplings at full resolution.
        z=x
        log_det_jacobian=torch.zeros_like(z)
        for layer in self.checkerboard_coupling_layer1:
            z,jacobian=layer(z)
            log_det_jacobian=log_det_jacobian + jacobian

        # Scale 1: channel couplings on the squeezed tensor.
        z=self.squeeze(z)
        log_det_jacobian=self.squeeze(log_det_jacobian)
        for layer in self.channel_coupling_layer1:
            z,jacobian=layer(z)
            log_det_jacobian=log_det_jacobian + jacobian
        z=self.unsqueeze(z)
        log_det_jacobian=self.unsqueeze(log_det_jacobian)

        # Rearrange pixels into channels and factor out half of them (z1).
        y=F.conv2d(z,self.order_matrix_1,stride=2,padding=0)
        z,z1=y.split(y.size(1)//2,dim=1)

        y=F.conv2d(log_det_jacobian,self.order_matrix_1,stride=2,padding=0)
        log_det_jacobian,log_det_jacobian1=y.split(y.size(1)//2,dim=1)

        # Scale 2. The accumulation must be out-of-place: log_det_jacobian is a
        # view returned by split(), and autograd rejects in-place updates of
        # views that come from multi-output functions.
        for layer in self.checkerboard_coupling_layer2:
            z,jacobian=layer(z)
            log_det_jacobian=log_det_jacobian + jacobian

        # Put both halves back into image layout.
        final_op=torch.cat((z,z1),dim=1)
        final_op=F.conv_transpose2d(final_op,self.order_matrix_1,stride=2, padding=0)

        final_det_jacobian=torch.cat((log_det_jacobian,log_det_jacobian1),dim=1)
        final_det_jacobian=F.conv_transpose2d(final_det_jacobian,self.order_matrix_1,stride=2, padding=0)

        return (final_op,final_det_jacobian)

    def sampling(self,z):
        """Map latents ``z`` back to data space by running ``inference`` in reverse."""
        y=F.conv2d(z,self.order_matrix_1,stride=2,padding=0)
        x,x1=y.split(y.size(1)//2,dim=1)

        for coup_layer in self.checkerboard_coupling_layer2[::-1]:
            x,_=coup_layer.backward(x)

        x=torch.cat((x,x1),dim=1)
        x=F.conv_transpose2d(x,self.order_matrix_1,stride=2,padding=0)

        x=self.squeeze(x)
        for coup_layer in self.channel_coupling_layer1[::-1]:
            x,_=coup_layer.backward(x)
        x=self.unsqueeze(x)

        for coup_layer in self.checkerboard_coupling_layer1[::-1]:
            x,_=coup_layer.backward(x)

        return (x)

    def likelihood(self,x):
        """Per-sample log-likelihood: ``log p_Z(f(x)) + log|det df/dx|``, shape ``(B,)``."""
        x_,det_jacobian=self.inference(x)
        assert (x_.shape == det_jacobian.shape)
        det_jacobian=torch.sum(det_jacobian,dim=(1,2,3))
        log_likelihood=torch.sum(self.prior.log_prob(x_),dim=(1,2,3))
        return(log_likelihood+det_jacobian)

    def sample_images(self,number,channel,width,heigth):
        """Draw latents of shape ``(number, channel, width, heigth)`` from the prior and decode them."""
        z=self.prior.sample((number,channel,width,heigth))
        x=self.sampling(z)
        return (x)

    def forward(self,x):
        """Return ``(log_likelihood, weight_scale)``.

        ``weight_scale`` is the sum of squares of every trainable parameter
        whose name ends in ``weight_g`` or ``scale``; the training loop uses it
        as a regularisation term.
        """
        weight_scale=0
        for name,parameter in self.named_parameters():
            p_name=name.split(".")[-1]
            if p_name in ("weight_g","scale") and parameter.requires_grad:
                weight_scale=weight_scale + torch.pow(parameter,2).sum()
        x_ll=self.likelihood(x)
        return (x_ll,weight_scale)

In [None]:
# Smoke test: push a random (8, 1, 28, 28) batch through an untrained model and
# check that likelihood evaluation and sampling run and return sane shapes.
mean=torch.tensor(0.).to(device)
std_dev=torch.tensor(1.).to(device)
prior_dist=distributions.Normal(mean,std_dev)
r=torch.randn(8,1,28,28).to(device)
x=real_nvp(prior_dist,1,64,1,28).to(device)
print("ll",x.forward(r)[0].shape)
# Sample once and reuse the batch, so the printed shape, the plotted image and
# the printed dtype all describe the same tensor.
generated=x.sample_images(5,1,28,28)
print("generated_images",generated.shape)
plt.imshow(generated[0,0,:,:].cpu().detach().numpy(),cmap='gray')
print(generated.dtype)

In [None]:
# CIFAR-10 input pipeline. Training images get a random horizontal flip before
# conversion to tensors; validation images are only converted to tensors.
transforms_train = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.ToTensor()])
transforms_val=transforms.ToTensor()
batch_size=32
# download=True fetches the dataset into root if it is not already there.
training_data=torchvision.datasets.CIFAR10(root='torch/data/cifar10',train=True, download=True, transform=transforms_train)
train_loader=torch.utils.data.DataLoader(training_data,batch_size=batch_size, shuffle=True, num_workers=8)

# train=False selects the CIFAR-10 test split, used here for validation.
validation_data=torchvision.datasets.CIFAR10(root='torch/data/cifar10',train=False,download=True,transform=transforms_val)
val_loader=torch.utils.data.DataLoader(validation_data,batch_size=batch_size,shuffle=True,num_workers=8)

In [None]:
# Standard-normal prior over latents, created on the training device.
mean=torch.tensor(0.).to(device)
std_dev=torch.tensor(1.).to(device)
prior_dist=distributions.Normal(mean,std_dev)
# Flow for 3-channel 32x32 images (the CIFAR-10 shape loaded above).
flow_model=real_nvp(prior_dist,in_dim=3,mid_dim=64,out_dim=3,size=32).to(device)
# momentum and decay are passed to Adamax as its two beta coefficients.
momentum=0.9
decay=0.999
lr=1e-3
optimizer=optim.Adamax(flow_model.parameters(),lr=lr,betas=(momentum,decay), eps=1e-7)

In [None]:
# Scratch cell: check that Uniform(0, 1) noise sampled with r's shape matches it.
# Relies on `r` from the smoke-test cell above.
print(r.shape)
noise=torch.distributions.Uniform(0,1).sample(r.shape)
print(noise.shape)

In [None]:
max_epoch=5000
scale_reg=5e-5
epoch=0
avg_loss=0.
avg_likelihood=0.
# logit(0.9); used for the constant log-Jacobian term of the rescaling by 0.9.
pre_logit_scale=torch.tensor(np.log(0.9) - np.log(1. - 0.9))

def dequantize_to_logits(img):
  """Dequantise a batch of [0, 1] images and map it to logit space.

  Uniform noise is added per pixel, values are squeezed into [0.05, 0.95] and
  the logit is taken. Returns the logit images and the per-sample log-Jacobian
  of this preprocessing, both moved to `device`.
  """
  noise = distributions.Uniform(0., 1.).sample(img.shape)
  img=(img *255.+noise) /256.
  img*= 2.             # [0, 2]
  img-= 1.             # [-1, 1]
  img*=0.9     # [-0.9, 0.9]
  img+= 1.             # [0.1, 1.9]
  img/=2.             # [0.05, 0.95]
  logit_img=torch.log(img) - torch.log(1. - img)
  log_jacobian=F.softplus(logit_img) + F.softplus(-logit_img)-F.softplus(-pre_logit_scale)
  log_jacobian=torch.sum(log_jacobian,dim=(1,2,3))
  return logit_img.to(device),log_jacobian.to(device)

while epoch < max_epoch:
  epoch=epoch+1
  print("Epoch np:",epoch)
  flow_model.train()
  n_batches=0
  for  batch_idx, data in enumerate(train_loader):
    optimizer.zero_grad()
    img,log_jacobian=dequantize_to_logits(data[0])

    batch_likelihood,weight_scale=flow_model.forward(img)
    log_likelihood = (batch_likelihood + log_jacobian).mean()

    loss=-log_likelihood + scale_reg*weight_scale
    avg_loss+=loss.item()
    avg_likelihood+=log_likelihood.item()
    n_batches+=1

    loss.backward()
    optimizer.step()

    if batch_idx% 1000 == 0:
      print("[%d/%d]\tloss: %.3f\tlog-ll: %.3f" % (batch_idx*batch_size, len(train_loader.dataset),loss.item(),log_likelihood.item()))

  # Average over the number of batches seen. The last batch index is one less
  # than that, which would bias the mean and divide by zero for a single batch.
  mean_loss=avg_loss/max(n_batches,1)
  mean_likelihood=avg_likelihood/max(n_batches,1)

  print('===> Average train loss: %.3f' % mean_loss)
  print('===> Average train log-likelihood: %.3f' % mean_likelihood)
  avg_loss=0
  avg_likelihood=0

  flow_model.eval()
  with torch.no_grad():
    n_batches=0
    for  batch_idx, data in enumerate(val_loader):
      img,log_jacobian=dequantize_to_logits(data[0])

      batch_likelihood,weight_scale=flow_model.forward(img)
      log_likelihood = (batch_likelihood + log_jacobian).mean()

      loss=-log_likelihood + scale_reg*weight_scale
      avg_loss+=loss.item()
      avg_likelihood+=log_likelihood.item()
      n_batches+=1
    mean_loss=avg_loss/max(n_batches,1)
    mean_likelihood=avg_likelihood/max(n_batches,1)

    print('===> Average validation loss: %.3f' % mean_loss)
    print('===> Average validation log-likelihood: %.3f' % mean_likelihood)
    avg_loss=0
    avg_likelihood=0

    # Sample a grid of images and undo the preprocessing (sigmoid, then rescale).
    gen_images=flow_model.sample_images(64,3,32,32)
    gen_images=1./(torch.exp(-gen_images) + 1.)    # [0.05, 0.95]
    gen_images*=2.             # [0.1, 1.9]
    gen_images-=1.             # [-0.9, 0.9]
    gen_images/=0.9     # [-1, 1]
    gen_images+=1.             # [0, 2]
    gen_images/=2.             # [0, 1]

    # gen_images is already a tensor; re-wrapping it in torch.tensor() only
    # makes a redundant copy and triggers a warning.
    utils.save_image(utils.make_grid(gen_images.detach()),'./samples/'+'_ep%d.png' % epoch)