In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from sklearn import datasets,cluster,mixture
import numpy as np
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import matplotlib.pyplot as plt
from google.colab import files
%matplotlib inline

In [6]:
# Show the attached GPU, the driver/CUDA versions and the current memory usage.
! nvidia-smi

Mon Sep 14 23:43:12 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.67       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   38C    P8     9W /  70W |      0MiB / 15079MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|  No ru

In [None]:
print(torch.cuda.is_available())
%mkdir new_samples

True


In [None]:
class coupling_layer(torch.nn.Module):
    """Additive coupling layer over the interleaved halves of a flattened input.

    The (batch, D) input is viewed as D/2 consecutive pairs. One element of
    every pair passes through untouched and feeds a small MLP; the MLP output
    is added to the other element (or subtracted from it when inverting).

    mask == "odd":  slot 0 of each pair is shifted, slot 1 is the conditioner.
    any other mask: slot 1 of each pair is shifted, slot 0 is the conditioner.
    """

    def __init__(self,in_dim,hidden_dim,hidden_layers,mask):
        super().__init__()
        half_dim = int(in_dim / 2)

        self.mask = mask
        self.hidden_layers = hidden_layers
        # NOTE: only one hidden Linear is created; it is re-applied
        # `hidden_layers` times, so those applications share their weights.
        self.layer1 = torch.nn.Linear(half_dim, hidden_dim)
        self.layer2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.layer3 = torch.nn.Linear(hidden_dim, half_dim)

    def _shift(self, conditioner):
        """Run the conditioner half through the MLP to get the additive shift."""
        hidden = F.relu(self.layer1(conditioner))
        for _ in range(self.hidden_layers):
            hidden = F.relu(self.layer2(hidden))
        # No activation on the output: the shift may take either sign.
        return self.layer3(hidden)

    def forward(self,x,backward=False):
        """Apply the coupling (backward=False) or its inverse (backward=True).

        x is a 2-D tensor (batch, D); the result has the same shape.
        """
        n_rows, n_cols = x.size()[0], x.size()[1]
        pairs = x.reshape((n_rows, n_cols // 2, 2))
        slot0, slot1 = pairs[:, :, 0], pairs[:, :, 1]

        shift_slot0 = self.mask == "odd"
        if shift_slot0:
            target, conditioner = slot0, slot1
        else:
            target, conditioner = slot1, slot0

        shift = self._shift(conditioner)
        # Equality test kept on purpose: only True (or 1) selects the inverse.
        if backward == True:
            moved = target - shift
        else:
            moved = target + shift

        # Re-interleave so every element returns to its original position.
        if shift_slot0:
            merged = torch.stack((moved, conditioner), dim=2)
        else:
            merged = torch.stack((conditioner, moved), dim=2)
        return merged.reshape(merged.shape[0], merged.shape[1] * merged.shape[2])

In [None]:
# The scaling layer is applied after the additive coupling transformations.
class scaling_layer(torch.nn.Module):
    """Learned per-dimension rescaling, parameterised by log-scales.

    Forward multiplies each dimension by exp(s_i); the inverse multiplies by
    exp(-s_i). The log-scales start at zero, so the layer is initially the
    identity.
    """

    def __init__(self,in_dim):
        super(scaling_layer,self).__init__()
        # nn.Parameter already tracks gradients; no explicit requires_grad needed.
        self.scaling_layer=torch.nn.Parameter(torch.zeros(in_dim))

    def forward(self,x,backward=False):
        """Rescale x and return (rescaled_x, log_det_jacobian).

        backward=False: x * exp(s),  log-det = +sum(s)
        backward=True:  x * exp(-s), log-det = -sum(s)

        The log-determinant always describes the transformation actually
        applied, so the inverse direction reports the negated sum.
        """
        if backward:
            log_scale=-self.scaling_layer
        else:
            log_scale=self.scaling_layer
        x=x*torch.exp(log_scale)
        log_det_jacobian=torch.sum(log_scale)
        return (x,log_det_jacobian)

In [None]:
class NICE(torch.nn.Module):
    """Flow model: a stack of coupling layers followed by one scaling layer.

    Args:
        prior_dist: object exposing log_prob(tensor) and sample(shape); used as
            the latent distribution.
        n_coupling: number of coupling layers (masks alternate "even", "odd",
            starting with "even").
        in_dim: size of the flattened input.
        hidden_dim: width of the MLP inside each coupling layer.
        hidden_layers: number of hidden-layer applications inside each
            coupling layer.
    """

    def __init__(self,prior_dist,n_coupling,in_dim,hidden_dim,hidden_layers):
        super(NICE,self).__init__()
        self.prior_dist=prior_dist
        self.n_coupling=n_coupling
        # Kept on the instance so sampling does not depend on a notebook global.
        self.in_dim=in_dim
        self.coupling_layers=torch.nn.ModuleList([
            coupling_layer(in_dim,hidden_dim,hidden_layers,mask="even" if i%2==0 else "odd")
            for i in range(self.n_coupling)
        ])
        self.scaling_layers=scaling_layer(in_dim)

    def inference(self,x):#image_space ==> prior distribution
        """Map data to latent space; returns (z, log_det_jacobian)."""
        for i in range(self.n_coupling):
            x=self.coupling_layers[i](x)
        # The scaling layer returns a (tensor, log-det) pair; unpack it
        # explicitly so the tuple return of this method is visible.
        z,log_det_jacobian=self.scaling_layers(x)
        return (z,log_det_jacobian)

    def sampling(self,z):#prior distribution ==> image space
        """Map latent samples back to data space (exact inverse of inference)."""
        # The log-determinant of the inverse scaling is not needed here.
        x,_=self.scaling_layers(z,backward=True)
        for i in reversed(range(self.n_coupling)):
            x=self.coupling_layers[i](x,backward=True)
        return (x)

    def sample_images(self,number):
        """Draw `number` latent vectors from the prior and decode them.

        Returns a (number, in_dim) tensor on the model's device.
        """
        # Use the model's own size and device rather than the notebook-level
        # `in_dim` / `device` globals.
        model_device=next(self.parameters()).device
        z=self.prior_dist.sample((number,self.in_dim)).to(model_device)
        gen_images=self.sampling(z)
        return(gen_images)

    def likelihood(self,x):
        """Per-example log-likelihood, shape (batch,).

        log p(x) = sum_i log p_prior(f(x)_i) + log|det df/dx|
        """
        x_,det_jacobian=self.inference(x)
        log_likelihood=torch.sum(self.prior_dist.log_prob(x_),dim=1)
        return(log_likelihood + det_jacobian)

    def forward(self,x):
        """Alias for likelihood(x), so the model can be called directly."""
        return(self.likelihood(x))

In [None]:
class logistic_di(torch.distributions.Distribution):
    """Standard logistic distribution (location 0, scale 1).

    NOTE: this duplicates `logistic_distribution`, which is defined later in
    this notebook and is the class passed to the model.
    """

    def __init__(self):
        # No constructor arguments to validate; this also avoids the warning
        # newer PyTorch emits when `arg_constraints` is not defined.
        super(logistic_di, self).__init__(validate_args=False)

    def log_prob(self, x):
        """Elementwise log-density: -(softplus(x) + softplus(-x))."""
        return -(F.softplus(x) + F.softplus(-x))

    def sample(self, size):
        """Draw samples of shape `size` via the inverse CDF, log(u) - log(1 - u).

        The result is placed on the notebook-level `device`.
        """
        z = torch.distributions.Uniform(0., 1.).sample(size).to(device)
        # Uniform(0, 1) samples from [0, 1): an exact 0 would become log(0) = -inf.
        # Keep u strictly inside the open interval.
        eps = torch.finfo(z.dtype).eps
        z = torch.clamp(z, min=eps, max=1. - eps)
        return torch.log(z) - torch.log(1. - z)

# MNIST generative model

In [None]:
# Tensor stored next to the notebook; presumably the per-pixel MNIST mean that
# the training cell subtracts from each flattened image (TODO confirm how it was built).
mean = torch.load("./mnist_mean.pt")

# The only preprocessing is the PIL-image -> tensor conversion.
transforms = torchvision.transforms.ToTensor()
batch_size = 200

# Training split, reshuffled each epoch.
training_data = torchvision.datasets.MNIST(
    root='torch/data/MNIST',
    train=True,
    download=True,
    transform=transforms,
)
train_loader = torch.utils.data.DataLoader(
    training_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

# Held-out split (train=False), loaded with the same settings.
validation_data = torchvision.datasets.MNIST(
    root='torch/data/MNIST',
    train=False,
    download=True,
    transform=transforms,
)
val_loader = torch.utils.data.DataLoader(
    validation_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

In [None]:
class logistic_distribution(torch.distributions.Distribution):
    """Standard logistic distribution (location 0, scale 1) used as the prior."""

    def __init__(self):
        # No constructor arguments to validate; this also avoids the warning
        # newer PyTorch emits when `arg_constraints` is not defined.
        super(logistic_distribution, self).__init__(validate_args=False)

    def log_prob(self, x):
        """Elementwise log-density: -softplus(x) - softplus(-x)."""
        likelihood=-F.softplus(x)-F.softplus(-x)
        return (likelihood)

    def sample(self, size):
        """Draw samples of shape `size` via the inverse CDF, log(u) - log(1 - u).

        The result is placed on the notebook-level `device`.
        """
        z = torch.distributions.Uniform(0., 1.).sample(size).to(device)
        # Uniform(0, 1) samples from [0, 1): an exact 0 would become log(0) = -inf
        # and turn the decoded sample into inf/NaN. Keep u strictly inside (0, 1).
        eps = torch.finfo(z.dtype).eps
        z = torch.clamp(z, min=eps, max=1. - eps)
        return torch.log(z) - torch.log(1. - z)

In [None]:
# Latent prior of the flow.
prior_dist = logistic_distribution()

# Architecture hyper-parameters.
n_coupling = 4            # number of coupling layers
in_dim = 1*28*28          # flattened 1x28x28 image
hidden_dim = 1000         # width of each coupling MLP
hidden_layers = 5         # hidden-layer applications per coupling MLP

# Build the model and place it on the selected device in one step.
flow_model = NICE(
    prior_dist=prior_dist,
    n_coupling=n_coupling,
    in_dim=in_dim,
    hidden_dim=hidden_dim,
    hidden_layers=hidden_layers,
).to(device)

In [None]:
# Train the flow by maximum likelihood: the loss is the negative mean
# log-likelihood of dequantized, mean-centred, flattened images.
# Presumably `mean` broadcasts against a (batch, 784) tensor - TODO confirm
# against how mnist_mean.pt was produced.
mean=torch.load('./mnist_mean.pt')
mean=mean.to(device)
optimizer=torch.optim.Adam(flow_model.parameters(),lr=1e-3,betas=(0.9,0.999), eps=1e-8)
total_iter=0
train=True
running_loss=0.0
max_iter=50000
log_every=1000  # iterations between loss reports / sample dumps
# Built once and reused for every batch.
dequant_noise=torch.distributions.Uniform(0.,1.)
while train:
  for batch_idx,data in enumerate(train_loader):
    # Re-enter training mode every step because the checkpoint branch below
    # switches the model to eval().
    flow_model.train()
    if total_iter == max_iter:
      train=False
      break
    total_iter+=1
    optimizer.zero_grad()

    images,label=data[0],data[1]
    images=images.to(device)
    # Uniform dequantization: scale to the 0..255 grid, add noise in [0, 1),
    # then rescale by 1/256.
    noise=dequant_noise.sample(images.size())
    noise=noise.to(device)
    images=(images*255.+noise)
    images=images/256.
    B,C,H,W=images.size()
    images=images.reshape(B,C*H*W)
    images=images-mean

    loss=-flow_model(images).mean()
    running_loss+=float(loss)

    loss.backward()
    optimizer.step()

    if total_iter%log_every == 0:
      mean_loss=running_loss / log_every
      print("iter %s:" % total_iter, 'loss= %.3f'% mean_loss)
      running_loss=0.0


      flow_model.eval()
      with torch.no_grad():
        # Encode/decode round trip of the last batch; the result is not used
        # further in this cell.
        z,det=flow_model.inference(images)
        gen_img=flow_model.sampling(z)
        gen_img=gen_img.cpu()
        gen_samples=flow_model.sample_images(64).cpu()
        [B,H]=gen_samples.size()
        # Was `assert [H==1*28*28]`, which asserts a non-empty list and so
        # could never fail.
        assert H==1*28*28
        mean_cpu=mean.cpu()
        # Undo the mean-centring before saving the images.
        gen_samples=gen_samples+mean_cpu
        gen_samples=gen_samples.reshape((B,1,28,28))
        torchvision.utils.save_image(torchvision.utils.make_grid(gen_samples),'./new_samples/'+'iter%d.png' % total_iter)
        print("Images saved")

iter 1000: loss= -989.920
Images saved
iter 2000: loss= -1340.552
Images saved
iter 3000: loss= -1603.609
Images saved
iter 4000: loss= -1795.998
Images saved
iter 5000: loss= -1898.828
Images saved
iter 6000: loss= -1938.368
Images saved
iter 7000: loss= -1961.485
Images saved
iter 8000: loss= -1976.175
Images saved
iter 9000: loss= -1987.314
Images saved
iter 10000: loss= -1996.313
Images saved
iter 11000: loss= -2002.843
Images saved
iter 12000: loss= -2009.081
Images saved
iter 13000: loss= -2014.430
Images saved
iter 14000: loss= -2019.396
Images saved
iter 15000: loss= -2023.134
Images saved
iter 16000: loss= -2028.518
Images saved
iter 17000: loss= -2032.474
Images saved
iter 18000: loss= -2035.562
Images saved
iter 19000: loss= -2040.874
Images saved
iter 20000: loss= -2041.354
Images saved
iter 21000: loss= -2045.219
Images saved
iter 22000: loss= -2047.551
Images saved
iter 23000: loss= -2050.529
Images saved
iter 24000: loss= -2052.647
Images saved
iter 25000: loss= -2055.18

In [None]:
# Bundle the whole new_samples directory into a single zip archive.
!zip -r ./new_sample.zip ./new_samples
# Hand the archive to the browser using the google.colab `files` helper imported in the first cell.
files.download("./new_sample.zip")

  adding: new_samples/ (stored 0%)
  adding: new_samples/iter24000.png (deflated 5%)
  adding: new_samples/iter29000.png (deflated 5%)
  adding: new_samples/iter34000.png (deflated 5%)
  adding: new_samples/iter36000.png (deflated 5%)
  adding: new_samples/iter32000.png (deflated 5%)
  adding: new_samples/iter4000.png (deflated 4%)
  adding: new_samples/iter37000.png (deflated 5%)
  adding: new_samples/iter19000.png (deflated 5%)
  adding: new_samples/iter2000.png (deflated 4%)
  adding: new_samples/iter45000.png (deflated 5%)
  adding: new_samples/iter46000.png (deflated 5%)
  adding: new_samples/iter21000.png (deflated 5%)
  adding: new_samples/iter12000.png (deflated 5%)
  adding: new_samples/iter30000.png (deflated 5%)
  adding: new_samples/iter1000.png (deflated 4%)
  adding: new_samples/iter3000.png (deflated 4%)
  adding: new_samples/iter44000.png (deflated 5%)
  adding: new_samples/iter47000.png (deflated 5%)
  adding: new_samples/iter22000.png (deflated 5%)
  adding: new_sampl