In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
%matplotlib inline

from os import listdir
from os.path import isfile, isdir, join

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import distributions

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fix the RNG seeds so weight initialisation, DataLoader shuffling, the input
# jitter and the sampled previews are reproducible on Restart & Run All.
# torch.manual_seed seeds CUDA as well; some GPU kernels may still be
# nondeterministic.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Utility functions

In [None]:
def plot_points(cloud, xlim=None, ylim=None, zlim=None, save_name=None, s=10, alpha=1):
    '''
    Plot a point cloud as a 3-D scatter plot.

    Parameters
    ----------
    cloud : array-like
        Indexed as ``cloud[:, 0]``, ``cloud[:, 1]``, ``cloud[:, 2]``
        (presumably shape (n_points, 3)).
    xlim, ylim, zlim : number or None
        Symmetric half-range per axis. When truthy, the axis is set to
        ``[-lim, lim]``. Otherwise matplotlib autoscales it.
    save_name : str or None
        If given, the figure is also written to ``Plots/<save_name>.png``.
        The ``Plots`` directory is created if it does not exist.
    s, alpha :
        Marker size and opacity forwarded to ``Axes3D.scatter``.
    '''
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(cloud[:, 0], cloud[:, 1], cloud[:, 2], s=s, alpha=alpha)

    for set_lim, lim in ((ax.set_xlim, xlim), (ax.set_ylim, ylim), (ax.set_zlim, zlim)):
        if lim:
            set_lim([-lim, lim])

    if save_name:
        import os
        # savefig does not create missing directories.
        os.makedirs('Plots', exist_ok=True)
        fig.savefig(os.path.join('Plots', save_name + '.png'), bbox_inches='tight')

    plt.show()
    # This helper is called in loops; close the figure explicitly so that
    # non-inline backends do not keep every figure alive.
    plt.close(fig)

In [None]:
class ShapeNet(Dataset):
    '''
    Point-cloud dataset read from a ShapeNet ``.npz`` archive.

    ``archive[cls]`` is iterated object by object for every requested class.
    Every object's points are flattened to rows of three coordinates. All
    points are multiplied by ``_ROTATION``, which maps (x, y, z) -> (y, -x, z).

    Parameters
    ----------
    path : str
        Location of the ``.npz`` archive.
    classes : sequence of str
        Archive keys (class names) to load, in order.
    objects_ids : sequence of int or None
        If given and non-empty, only these object indices are taken from
        every class. Otherwise all objects of the class are used.
    mixed : bool
        True: one item per point. ``cloud`` is (n_points, 3) and
        ``targets[i]`` is the label of the object that point ``i`` came from.
        False: one item per object. ``cloud`` stacks whole objects (they must
        then all have the same number of points) and ``targets`` is
        ``0 .. n_objects - 1``.

    Object labels are assigned consecutively across all requested classes,
    so objects from different classes never share a label.
    '''

    # Fixed rotation applied once to every point: (x, y, z) -> (y, -x, z).
    _ROTATION = np.array([[0, -1, 0],
                          [1, 0, 0],
                          [0, 0, 1]])

    def __init__(self, path, classes=('chair',), objects_ids=None, mixed=True):
        self.objects_ids = objects_ids

        # Read everything while the archive is open. The ``with`` block
        # releases the underlying file handle afterwards.
        per_class = []
        with np.load(path) as shapenet:
            for _class in classes:
                # NpzFile re-reads a member on every access, so read it once.
                data = shapenet[_class]
                if objects_ids:
                    data = [data[idx] for idx in objects_ids]
                per_class.append(data)

        if mixed:
            points, labels = [], []
            label = 0
            for data in per_class:
                for obj in data:
                    obj_points = np.asarray(obj).reshape(-1, 3)
                    points.append(obj_points)
                    labels.extend([label] * len(obj_points))
                    label += 1
            cloud = np.concatenate(points) if points else np.empty((0, 3))
            self.targets = labels
        else:
            objects = [np.asarray(obj) for data in per_class for obj in data]
            cloud = np.array(objects) if objects else np.empty((0, 0, 3))
            self.targets = list(range(len(objects)))

        # Rotate once, after all classes have been collected.
        self.cloud = np.dot(cloud, self._ROTATION)

    def __len__(self):
        '''Number of items: points when ``mixed``, otherwise objects.'''
        return len(self.cloud)

    def __getitem__(self, idx):
        '''Return ``(cloud[idx], targets[idx])`` converted to torch tensors.'''
        point, target = self.cloud[idx], self.targets[idx]
        point, target = torch.from_numpy(point), torch.tensor(target)
        return point, target

# Model implementation

In [None]:
class F_MulNet(nn.Module):        
    def __init__(self, emb_dim):
        super(F_MulNet, self).__init__()
        
        self.layer0 = nn.Sequential(
            nn.Linear(emb_dim, 256),
            nn.LeakyReLU(0.1)
        )   
        
        self.layer1 = nn.Sequential(
            nn.Linear(2, 256),
            nn.LeakyReLU(0.1)
        )
        
        self.layer2 = nn.Sequential(
            nn.Linear(512, 512),
            nn.LeakyReLU(0.1)
        )
        
        self.layer3 = nn.Sequential(
            nn.Linear(512, 512),
            nn.LeakyReLU(0.1)
        )   
            
        self.layer4 = nn.Linear(512, 512)
            
        self.layer5 = nn.Sequential(
            nn.Linear(512, 512),
            nn.LeakyReLU(0.1)
        )   
            
        self.layer6 = nn.Linear(512, 512)
        
        self.layer7 = nn.Sequential(
            nn.Linear(512, 512),
            nn.LeakyReLU(0.1)
        )   
            
        self.layer8 = nn.Linear(512, 512)
        
        self.layer9 = nn.Sequential(
            nn.Linear(512, 512),
            nn.LeakyReLU(0.1)
        )   
            
        self.layer10 = nn.Linear(512, 512)
        
        self.layer11 = nn.Sequential(
            nn.Linear(512,1),
            nn.Tanh()
        )
    
    def forward(self, x, emb):
        emb = self.layer0(emb)
        x = self.layer1(x)
        x = self.layer2(torch.cat([x,emb], dim=1))
        
        _x = x
        x = self.layer3(x)
        x = self.layer4(x)
        x = F.leaky_relu(x + _x, negative_slope=0.1)
        
        _x = x
        x = self.layer5(x)
        x = self.layer6(x)
        x = F.leaky_relu(x + _x, negative_slope=0.1)
        
        _x = x
        x = self.layer7(x)
        x = self.layer8(x)
        x = F.leaky_relu(x + _x, negative_slope=0.1)
        
        _x = x
        x = self.layer9(x)
        x = self.layer10(x)
        x = F.leaky_relu(x + _x, negative_slope=0.1)
        
        x = self.layer11(x)
        
        return x

In [None]:
class F_AddNet(nn.Module):
    '''
    Conditional shift network ``t(x_1, e)`` of an affine coupling step.

    Same layer layout as ``F_MulNet``, except that the head is a plain Linear
    layer with no Tanh, so the shift is unbounded.

    Inputs
    ------
    x   : conditioning coordinates, last dimension 2.
    emb : conditioning embedding, last dimension ``emb_dim``.

    Output
    ------
    One unbounded value per row, shape (batch, 1).
    '''

    def __init__(self, emb_dim):
        super(F_AddNet, self).__init__()

        # Modules are created in the order layer0 .. layer11 under these exact
        # attribute names. The order fixes the random numbers each default init
        # consumes; the names fix the state_dict keys.
        self.layer0 = nn.Sequential(       # embedding branch
            nn.Linear(emb_dim, 256),
            nn.LeakyReLU(0.1)
        )
        self.layer1 = nn.Sequential(       # coordinate branch
            nn.Linear(2, 256),
            nn.LeakyReLU(0.1)
        )
        self.layer2 = nn.Sequential(       # joint trunk after concatenation
            nn.Linear(512, 512),
            nn.LeakyReLU(0.1)
        )

        # Residual blocks: (Linear+LeakyReLU -> Linear), add skip, LeakyReLU.
        self.layer3 = nn.Sequential(nn.Linear(512, 512), nn.LeakyReLU(0.1))
        self.layer4 = nn.Linear(512, 512)
        self.layer5 = nn.Sequential(nn.Linear(512, 512), nn.LeakyReLU(0.1))
        self.layer6 = nn.Linear(512, 512)
        self.layer7 = nn.Sequential(nn.Linear(512, 512), nn.LeakyReLU(0.1))
        self.layer8 = nn.Linear(512, 512)
        self.layer9 = nn.Sequential(nn.Linear(512, 512), nn.LeakyReLU(0.1))
        self.layer10 = nn.Linear(512, 512)

        # Kept as a one-element Sequential so the parameter keys stay
        # 'layer11.0.weight' / 'layer11.0.bias'.
        self.layer11 = nn.Sequential(
            nn.Linear(512, 1)
        )

    def forward(self, x, emb):
        emb = self.layer0(emb)
        x = self.layer1(x)
        x = self.layer2(torch.cat([x, emb], dim=1))

        for inner, outer in ((self.layer3, self.layer4),
                             (self.layer5, self.layer6),
                             (self.layer7, self.layer8),
                             (self.layer9, self.layer10)):
            x = F.leaky_relu(outer(inner(x)) + x, negative_slope=0.1)

        return self.layer11(x)

In [None]:
def init_weights(Layer):
    '''
    Re-initialise ``Layer`` in place if its class is named ``Linear``.

    Meant to be passed to ``module.apply``. Weights are drawn from
    N(0, 0.02^2) and the bias, if any, is set to zero. Modules whose class
    name is anything other than ``Linear`` are left untouched.
    '''
    if type(Layer).__name__ != 'Linear':
        return

    torch.nn.init.normal_(Layer.weight, mean=0, std=0.02)
    if Layer.bias is not None:
        torch.nn.init.zeros_(Layer.bias)

In [None]:
def optim(F_flows, l_rate):
    '''
    Build a single Adam optimizer over every network in ``F_flows``.

    Parameters
    ----------
    F_flows : dict[str, nn.Module]
        The coupling networks. Their parameters are gathered in dict
        iteration order into one parameter group.
    l_rate : float
        Learning rate for Adam.

    NOTE: a module-level function named ``optim`` shadows the
    ``torch.optim as optim`` alias from the import cell, which is why
    ``torch.optim`` is spelled out in full here.
    '''
    all_params = [p for net in F_flows.values() for p in net.parameters()]
    return torch.optim.Adam(all_params, lr=l_rate)

In [None]:
def loss_fun(z, logdetJ, prior_z):
    '''
    Negative mean log-likelihood of a batch (change-of-variables formula).

        log p(x) = log p_Z(z) + log|det J|,  loss = -mean over the batch

    Parameters
    ----------
    z : torch.Tensor
        Latent points produced by ``F_flow``, shape (batch, dim).
    logdetJ : torch.Tensor or number
        Accumulated log-determinant. ``F_flow`` returns it as (batch, 1), and
        a plain 0 (no flows) is also accepted.
    prior_z : torch.distributions.Distribution
        Prior over z. ``log_prob`` is evaluated on a CPU copy of ``z``,
        presumably because the prior's parameters are CPU tensors. The
        result is moved back to ``z``'s device.

    Returns
    -------
    Scalar tensor: the loss to minimise.
    '''
    # For a multivariate prior, log_prob yields one value per row: (batch,).
    log_pz = prior_z.log_prob(z.cpu()).to(z.device)
    # Adding a (batch, 1) log-det to a (batch,) log-prob would broadcast into
    # a (batch, batch) matrix. The mean is unchanged, but memory becomes
    # O(batch^2), so flatten the log-det first.
    if torch.is_tensor(logdetJ) and logdetJ.dim() == 2 and logdetJ.shape[1] == 1:
        logdetJ = logdetJ.squeeze(1)
    return -torch.mean(log_pz + logdetJ)

Implement the transformations of the CIF model here. You need to fill in the bodies of two functions:
* `F_flow` - forward pass of the CIF
$$\begin{cases}
y_1 =& x_1\\ 
y_2 =& x_2 \odot \exp (s(x_1, e)) + t(x_1, e)
\end{cases}$$
* `F_inv_flow` - inversion of the forward pass
$$\begin{cases}
x_1 =& y_1\\ 
x_2 =& (y_2 - t(y_1, e)) \odot \exp (-s(y_1, e))
\end{cases}$$
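Don't use masks. Instead, split the dimensions, perform the operations and concatenate the results in the proper order. Here $x_1$ denotes the first two coordinates (the conditioning part) and $x_2$ the remaining one. After every coupling step, the transformed coordinate is moved to the front, so three consecutive steps transform each coordinate exactly once. \\
Hint: use `torch.cat()`.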

In [None]:
def F_flow(x, emb, F_flows, n_flows_F):
    '''
    Forward CIF pass: data points -> latent points.

    Each of the ``n_flows_F`` flows applies three affine coupling steps. In a
    step, the first two columns (x_1) condition the transform of the third
    column (x_2):

        y_2 = x_2 * exp(s(x_1, e)) + t(x_1, e)

    ``s`` is ``F_flows['MNet<flow><step>']`` and ``t`` is
    ``F_flows['ANet<flow><step>']``. After each step the new column is put in
    front of x_1, so the three steps transform every coordinate once.

    Returns
    -------
    (z, log_det) where log_det is the accumulated log|det J| of shape
    (batch, 1), or the integer 0 when ``n_flows_F == 0``.
    '''
    log_det = 0
    for flow in range(n_flows_F):
        for step in range(3):
            tag = str(flow) + str(step)
            cond, free = x[:, :2], x[:, 2:3]

            log_scale = F_flows['MNet' + tag](cond, emb)
            shift = F_flows['ANet' + tag](cond, emb)
            moved = free * torch.exp(log_scale) + shift

            log_det = log_det + log_scale.sum(dim=1).view(-1, 1)
            x = torch.cat([moved, cond], dim=1)
    return x, log_det

In [None]:
def F_inv_flow(z, emb, F_flows, n_flows_F):
    '''
    Inverse CIF pass: latent points -> data points.

    Undoes ``F_flow`` by running its flows and coupling steps in reverse
    order. At each step the transformed coordinate sits in column 0 and the
    conditioning pair in columns 1-2:

        x_2 = (y_2 - t(y_1, e)) * exp(-s(y_1, e))

    The recovered coordinate is appended after the conditioning pair, which
    restores the column order that existed before the forward step.
    '''
    for flow in reversed(range(n_flows_F)):
        for step in (2, 1, 0):
            tag = str(flow) + str(step)
            moved, cond = z[:, :1], z[:, 1:]

            inv_scale = torch.exp(-F_flows['MNet' + tag](cond, emb))
            shift = F_flows['ANet' + tag](cond, emb)
            restored = (moved - shift) * inv_scale

            z = torch.cat([cond, restored], dim=1)
    return z

# Experiments

In [None]:
# Location of the ShapeNet archive (Google Drive mount on Colab); adjust locally.
DATA_PATH = r'/content/drive/My Drive/MLinPL/Notebooks/ShapeNet/3chairs.npz'
# Number of points stored per object; the per-object slicing below relies on it.
POINTS_PER_OBJECT = 2048

cloud = ShapeNet(path=DATA_PATH)
dataloader = DataLoader(cloud, batch_size=5000, shuffle=True)

# Guard the per-object slicing: a remainder would silently be dropped.
assert cloud.cloud.shape[0] % POINTS_PER_OBJECT == 0, \
    'point count is not a multiple of POINTS_PER_OBJECT'

# Show every object in the archive, one figure each.
for i in range(cloud.cloud.shape[0] // POINTS_PER_OBJECT):
    plot_points(cloud.cloud[POINTS_PER_OBJECT * i:POINTS_PER_OBJECT * (i + 1)])

In [None]:
# Shape of the flattened (and rotated) point array: (n_points, 3).
np.shape(cloud.cloud)

In [None]:
# Training hyper-parameters.
n_epochs = 1000
n_flows_F = 5          # number of flows; F_flow runs 3 coupling steps per flow
l_rate = 1e-4
# One-hot embedding size: one slot per object label in the dataset. Labels
# are 0-based and scatter/one-hot needs every label < emb_dim, hence max + 1.
emb_dim = int(max(cloud.targets)) + 1
n_sample_points = 1000  # NOTE(review): unused; the final sampling cell overrides it

# Standard normal prior over the 3-D latent space.
prior_z = distributions.MultivariateNormal(torch.zeros(3), torch.eye(3))

In [None]:
# One scale network (MNet) and one shift network (ANet) per coupling step:
# n_flows_F flows x 3 steps, keyed 'MNet<flow><step>' / 'ANet<flow><step>'.
F_flows = {}
for n in range(n_flows_F):
    for i in range(3):
        suffix = str(n) + str(i)
        for prefix, net_cls in (('MNet', F_MulNet), ('ANet', F_AddNet)):
            net = net_cls(emb_dim).to(device)
            net.apply(init_weights)
            F_flows[prefix + suffix] = net

# A single Adam optimiser over all networks. The learning rate is multiplied
# by 0.8 every 100 scheduler steps; the training loop steps once per epoch.
optimizer = optim(F_flows, l_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.8)

In [None]:
# Maximum-likelihood training of all coupling networks. Every `log_every`
# epochs the mean batch loss is printed and one sample per object is plotted.
n_preview_points = 2048  # points sampled per object for the preview plots
log_every = 50

for net in F_flows.values():
    net.train()

for i in range(n_epochs):
    loss_acc = 0
    for j, (x, targets) in enumerate(dataloader):
        # Tiny uniform jitter on the coordinates (presumably dequantisation).
        x = (x.float() + 1e-4 * torch.rand(x.shape)).to(device)
        # One-hot object embedding that conditions every coupling network.
        emb = F.one_hot(targets.view(-1).long(), emb_dim).float().to(device)

        z, logdetJ = F_flow(x, emb, F_flows, n_flows_F)

        loss = loss_fun(z, logdetJ, prior_z)
        loss_acc += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    scheduler.step()

    if i % log_every == 0:
        print('Epoch: {}/{} Loss: {:.4f}'.format(i + 1, n_epochs, loss_acc / len(dataloader)))
        with torch.no_grad():
            for obj_id in range(emb_dim):
                z = torch.randn(n_preview_points, 3).to(device)
                emb = torch.zeros(n_preview_points, emb_dim, device=device)
                emb[:, obj_id] = 1  # condition on object `obj_id`

                z = F_inv_flow(z, emb, F_flows, n_flows_F)

                plot_points(z.cpu().numpy())

In [None]:
# Draw a dense sample for every object from the trained conditional flow.
n_sample_points = 100000

for net in F_flows.values():
    net.eval()

with torch.no_grad():
    for obj_id in range(emb_dim):
        z = torch.randn(n_sample_points, 3).to(device)
        emb = torch.zeros(n_sample_points, emb_dim, device=device)
        emb[:, obj_id] = 1  # one-hot embedding selecting object `obj_id`

        z = F_inv_flow(z, emb, F_flows, n_flows_F)

        plot_points(z.cpu().numpy(), alpha=0.05)