<a href="https://colab.research.google.com/github/VadimFarutin/deep-unsupervised-learning/blob/hw02/hw02_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# HW02

## 2 High-dimensional data

### Imports and Data

In [0]:
%%capture
# %%capture hides the installer output of this cell.
# graphql-core is pinned to 2.0 and installed before wandb; wandb itself is installed unpinned (-q = quiet).
!pip install graphql-core==2.0
!pip install wandb -q

In [0]:
import wandb
# Runs the wandb command-line login in the notebook's shell so later wandb.init / wandb.log calls are authenticated.
!wandb login

In [0]:
from collections import defaultdict
from tqdm import tnrange, tqdm_notebook
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split
import numpy as np
from matplotlib import pyplot as plt
import random
import pickle

import torch
import torch.nn as nn
from torch.nn.modules import loss
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data.dataset import Dataset
from torch.nn.utils.rnn import pad_sequence
from torch.distributions import Normal, Uniform, MultivariateNormal

In [0]:
# A single seed shared by every random number generator this notebook touches:
# Python's `random`, NumPy's legacy global RNG, and PyTorch (CPU and all CUDA devices).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

In [5]:
# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
# To force CPU execution, overwrite DEVICE with torch.device('cpu') after this block.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print(DEVICE)

cuda


In [6]:
# Colab-only: mount Google Drive at /content/drive (interactive authorization).
# read_data() later opens the dataset through the relative path 'drive/My Drive/...'.
from google.colab import drive
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


In [0]:
def read_data(path='drive/My Drive/hw2_q2.pkl'):
    """Load the train/test image arrays from a pickle file.

    Args:
        path: location of the pickle file. Defaults to the Google Drive path
            used by this notebook, so `read_data()` keeps working unchanged.

    Returns:
        A `(train, test)` tuple built from the `'train'` and `'test'` entries
        of the pickled dict. Each array has its last axis moved to position 1
        (axes reordered as 0, 3, 1, 2), i.e. channels-last to channels-first
        for 4-D image batches.
    """
    # NOTE(security): pickle.load can execute arbitrary code while
    # deserializing; only point this at files from a trusted source.
    with open(path, 'rb') as file:
        dataset = pickle.load(file)

    train = dataset['train'].transpose(0, 3, 1, 2)
    test = dataset['test'].transpose(0, 3, 1, 2)
    return train, test


In [0]:
# The 'test' entry of the pickle is used as the validation split throughout this notebook.
train, val = read_data()

In [17]:
# Sanity check: shapes after the transpose done in read_data().
print(train.shape, val.shape)

(20000, 3, 32, 32) (6838, 3, 32, 32)


In [0]:
def MyNLLLoss(y):
    """Return the negated mean of log(y), divided by two.

    Args:
        y: tensor of positive values; the logarithm is taken element-wise.

    Returns:
        A scalar tensor equal to -mean(log(y)) / 2.
    """
    mean_log = torch.log(y).mean()
    return -mean_log / 2

In [0]:
def fit(model, train, val, optimizer, loss_function, epoch_cnt, batch_size):
    """Train `model` and evaluate it on `val` after every epoch.

    Args:
        model: module called as `model(x)` on float batches moved to DEVICE.
        train: NumPy array of training samples (first axis = sample index).
        val: NumPy array of validation samples.
        optimizer: optimizer already bound to the model's parameters.
        loss_function: callable applied to the model output; must return a
            scalar tensor.
        epoch_cnt: number of passes over the training data.
        batch_size: batch size for both loaders.

    Returns:
        `(train_loss_values, val_loss_values)`: the training loss of every
        batch and the mean validation loss of every epoch, both as plain
        Python floats.

    Side effects:
        Logs "Train Loss" per batch and "Validation Loss" per epoch to wandb.
    """
    train_loader = torch.utils.data.DataLoader(torch.from_numpy(train), batch_size=batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(torch.from_numpy(val), batch_size=batch_size)
    train_loss_values = []
    val_loss_values = []

    for epoch in tnrange(epoch_cnt, desc='Epoch'):
        model.train()
        for batch_data in train_loader:
            x = batch_data.float().to(DEVICE)
            optimizer.zero_grad()
            output = model(x)
            loss = loss_function(output)
            loss.backward()
            optimizer.step()

            # Store/log a detached Python float: keeping the tensor itself
            # would retain every batch's autograd graph (and GPU memory) and
            # could not be plotted afterwards.
            loss_value = loss.item()
            train_loss_values.append(loss_value)
            wandb.log({"Train Loss": loss_value})

        model.eval()
        batch_losses = []
        # No gradients are needed for evaluation; skipping them saves memory.
        with torch.no_grad():
            for batch_data in val_loader:
                x = batch_data.float().to(DEVICE)
                output = model(x)
                batch_losses.append(loss_function(output).item())
        # Unweighted mean over batches (a smaller last batch counts as much as a full one).
        val_loss_values.append(float(np.mean(np.array(batch_losses))))

        wandb.log({"Validation Loss": val_loss_values[-1]})

    return train_loss_values, val_loss_values


In [0]:
def plot_loss_values(train_loss_values, val_loss_values):
    """Plot per-iteration training loss together with per-epoch validation loss.

    Args:
        train_loss_values: sequence of training losses, one per iteration
            (floats or single-element tensors).
        val_loss_values: sequence of validation losses, one per epoch. They
            are spread evenly over the iteration axis, starting at 0.
    """
    # float() also unwraps single-element tensors (including ones on the GPU
    # or attached to a graph), which matplotlib cannot convert on its own.
    train_curve = [float(value) for value in train_loss_values]
    val_curve = [float(value) for value in val_loss_values]
    n_iterations = len(train_curve)

    plt.plot(np.arange(n_iterations), train_curve, color='blue', label='train')
    if val_curve:
        # Derive the spacing from the data itself instead of a global config,
        # and build exactly len(val_curve) positions (a float-step arange can
        # produce one element too many or too few).
        iterations_per_epoch = n_iterations / len(val_curve)
        val_positions = np.arange(len(val_curve)) * iterations_per_epoch
        plt.plot(val_positions, val_curve, color='red', label='validation')
    plt.legend()
    plt.title("Loss values")
    plt.xlabel("iteration")
    plt.ylabel("loss")
    plt.show()

### RealNVP

In [0]:
class ResBlock(nn.Module):
    """Small convolutional block that keeps channel count and spatial size.

    Args:
        channels: number of input and output channels.
        type: `'A'` builds 1x1 conv -> ReLU -> 3x3 conv (padding 1);
            any other value builds ReLU -> 1x1 conv.

    Note:
        The block itself adds no skip connection; it only applies `layers`.
    """

    def __init__(self, channels, type):
        super(ResBlock, self).__init__()

        if type == 'A':
            self.layers = nn.Sequential(nn.Conv2d(channels, channels, kernel_size=(1, 1), stride=(1, 1), padding=0),
                                        nn.ReLU(),
                                        nn.Conv2d(channels, channels, kernel_size=(3, 3), stride=(1, 1), padding=1))
        else:
            self.layers = nn.Sequential(nn.ReLU(),
                                        nn.Conv2d(channels, channels, kernel_size=(1, 1), stride=(1, 1), padding=0))

    def forward(self, x):
        """Apply the block's layers to `x` and return the result."""
        return self.layers(x)

    # __call__ is intentionally NOT overridden: nn.Module.__call__ already
    # dispatches to forward() and additionally runs registered hooks, which a
    # custom __call__ would silently skip.

In [0]:
class Resnet(nn.Module):
    """Convolutional network that maps `in_channels` to `2 * in_channels` maps.

    Args:
        in_channels: channels of the input tensor.
        hidden_size: channels used by the hidden feature maps.
        n_blocks: number of (type-A, type-B) ResBlock pairs.

    The output has the same spatial size as the input: `conv1` (3x3,
    padding 2) grows each spatial dimension by 2 and `conv2` (3x3, no
    padding) shrinks it back by 2.
    """

    def __init__(self, in_channels, hidden_size, n_blocks):
        super(Resnet, self).__init__()
        out_channels = in_channels * 2

        self.conv1 = nn.Conv2d(in_channels, hidden_size, kernel_size=(3, 3), stride=(1, 1), padding=2)

        self.res_blocks_a = torch.nn.ModuleList([ResBlock(hidden_size, 'A')
                                                 for _ in range(n_blocks)])
        self.res_blocks_b = torch.nn.ModuleList([ResBlock(hidden_size, 'B')
                                                 for _ in range(n_blocks)])

        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(hidden_size, out_channels, kernel_size=(3, 3), stride=(1, 1))

    def forward(self, x):
        """Run the network on `x` of shape (batch, in_channels, H, W)."""
        h = self.conv1(x)

        # Use the loop variables: the blocks live in the two ModuleLists, there
        # are no `self.res_block_a` / `self.res_block_b` attributes.
        for res_block_a, res_block_b in zip(self.res_blocks_a, self.res_blocks_b):
            _h = res_block_a(h)
            h = res_block_b(_h)
            # Skip connection around the type-B block only.
            h = h + _h

        h = self.relu(h)
        x = self.conv2(h)

        return x

    # __call__ is intentionally not overridden so nn.Module.__call__ keeps
    # running registered hooks before/after forward().

In [0]:
class AffineCoupling(nn.Module):
    """Masked affine coupling layer.

    Entries where `mask == 1` pass through unchanged and are the only input
    of the conditioner network; entries where `mask == 0` are transformed as
    `y = exp(log_s) * (x + t)`.

    Args:
        mask: tensor of zeros and ones, broadcastable against the layer
            input. It is stored as a plain attribute (not a buffer), so it
            must already live on the device the layer will run on.
        in_channels: number of channels of the layer input.
    """

    def __init__(self, mask, in_channels):
        super(AffineCoupling, self).__init__()
        self.mask = mask
        self.resnet = Resnet(in_channels, hidden_size=256, n_blocks=8)

    def _condition(self, v):
        """Return the pass-through part of `v` plus (log_s, t) predicted from it."""
        kept = self.mask * v
        log_s, t = torch.chunk(self.resnet(kept), 2, dim=1)
        return kept, log_s, t

    def forward(self, x):
        """Return the per-sample log-determinant of `latent` at `x`, shape (batch,)."""
        _, log_s, _ = self._condition(x)
        # Only transformed entries contribute; pass-through entries have a
        # unit Jacobian, so their log_s must not be counted.
        log_det = ((1 - self.mask) * log_s).reshape(x.shape[0], -1).sum(dim=1)

        return log_det

    def latent(self, x):
        """Map data `x` to the latent space (the forward transform)."""
        y1, log_s, t = self._condition(x)
        y2 = (1 - self.mask) * torch.exp(log_s) * (x + t)

        return y1 + y2

    def inverse(self, y):
        """Invert `latent`: recover `x` from `y`."""
        x1, log_s, t = self._condition(y)
        # The shift must stay inside the mask; otherwise `-t` leaks into the
        # pass-through entries and the result no longer inverts `latent`.
        x2 = (1 - self.mask) * (y * torch.exp(-log_s) - t)

        return x1 + x2

    # __call__ is intentionally not overridden so nn.Module.__call__ keeps
    # running registered hooks around forward().

In [0]:
class MyRealNVP(nn.Module):
    """Multi-scale flow built from masked affine coupling layers.

    Layout (C = `in_channels`, HxW = `size`):
        couplings1: 4 layers, checkerboard masks, C channels at HxW
        squeeze -> 4C channels at H/2 x W/2
        couplings2: 3 layers, channel-split masks
        couplings3: 3 layers, checkerboard masks
        squeeze -> 16C channels at H/4 x W/4
        couplings4: 3 layers, channel-split masks
        couplings5: 3 layers, checkerboard masks

    Args:
        in_channels: channels of the input images.
        size: `(height, width)` of the input images; both should be
            divisible by 4 because of the two squeeze steps.

    Masks are created on the module-level `DEVICE`.
    """

    def __init__(self, in_channels, size):
        super(MyRealNVP, self).__init__()

        # Scale 1: original resolution.
        mask = self.checkerboard_mask(in_channels, size)
        self.couplings1 = self._alternating_couplings(mask, in_channels, 4)

        # Scale 2: after the first squeeze the tensor has 4x the channels and
        # half the resolution, so masks and conditioners must use those sizes.
        channels = in_channels * 4
        size = size[0] // 2, size[1] // 2
        mask = self.channel_split_mask(channels, size)
        self.couplings2 = self._alternating_couplings(mask, channels, 3)

        mask = self.checkerboard_mask(channels, size)
        self.couplings3 = self._alternating_couplings(mask, channels, 3)

        # Scale 3: after the second squeeze (16x channels, quarter resolution).
        channels = in_channels * 16
        size = size[0] // 2, size[1] // 2
        mask = self.channel_split_mask(channels, size)
        self.couplings4 = self._alternating_couplings(mask, channels, 3)

        mask = self.checkerboard_mask(channels, size)
        self.couplings5 = self._alternating_couplings(mask, channels, 3)

    @staticmethod
    def _alternating_couplings(mask, channels, count):
        """Build `count` coupling layers that alternate between `mask` and `1 - mask`."""
        layers = [AffineCoupling(mask if i % 2 == 0 else 1 - mask, channels)
                  for i in range(count)]
        return nn.Sequential(*layers)

    def checkerboard_mask(self, in_channels, size):
        """Return a (1, in_channels, H, W) float mask that is 1 where row + column is even."""
        # Plain `bool` instead of `np.bool`, which was removed from NumPy.
        black = np.ones([1, in_channels, size[0], size[1]], dtype=bool)
        white = np.ones([1, in_channels, size[0], size[1]], dtype=bool)
        black[:, :, np.arange(1, size[0], 2), :] = False
        white[:, :, :, np.arange(0, size[1], 2)] = False
        mask = torch.tensor(black ^ white, dtype=torch.float32).to(DEVICE)

        return mask

    def channel_split_mask(self, in_channels, size):
        """Return a (1, in_channels, H, W) float mask selecting the first two channels of every group of four.

        With fewer than 4 channels no group exists and the mask is all zeros.
        """
        mask = torch.zeros([1, in_channels, size[0], size[1]], dtype=torch.float32).to(DEVICE)
        channels_i = np.arange(0, in_channels // 4) * 4
        channels_i = np.stack((channels_i, channels_i + 1), axis=1).reshape(-1)
        mask[:, channels_i.tolist(), :, :] = 1.0

        return mask

    def squeeze(self, x):
        """Fold each non-overlapping 2x2 spatial patch into channels: (b, c, h, w) -> (b, 4c, h/2, w/2)."""
        b, c, h, w = x.shape
        return F.unfold(x, (2, 2), stride=2).reshape(b, 4 * c, h // 2, w // 2)

    def unsqueeze(self, x):
        """Inverse of `squeeze`: (b, c, h, w) -> (b, c/4, 2h, 2w)."""
        b, c, h, w = x.shape
        return F.fold(x.reshape(b, c, -1), (h * 2, w * 2), (2, 2), stride=2)

    def _run_stage(self, stage, y, logdet):
        """Push `y` through one stage; accumulate log-dets when `logdet` is not None."""
        for layer in stage:
            if logdet is not None:
                # Out-of-place add of per-sample values; `layer(y)` is evaluated
                # on the layer's input, before `y` is transformed.
                logdet = logdet + layer(y)
            y = layer.latent(y)
        return y, logdet

    def _encode(self, x, logdet=None):
        """Run all stages and squeezes; returns the (still squeezed) latent and the log-det sum."""
        y, logdet = self._run_stage(self.couplings1, x, logdet)

        y = self.squeeze(y)
        y, logdet = self._run_stage(self.couplings2, y, logdet)
        y, logdet = self._run_stage(self.couplings3, y, logdet)

        y = self.squeeze(y)
        y, logdet = self._run_stage(self.couplings4, y, logdet)
        y, logdet = self._run_stage(self.couplings5, y, logdet)

        return y, logdet

    def log_det(self, x):
        """Return the summed log-determinant of all coupling layers, shape (batch,)."""
        # Shape (batch,) matches what each coupling layer returns; a (batch, 1)
        # accumulator cannot be combined with it in place for batch > 1.
        start = torch.zeros(x.shape[0], device=x.device, dtype=x.dtype)
        _, logdet = self._encode(x, start)
        return logdet

    def forward(self, x):
        """Return exp(log_det(x)), shape (batch,).

        NOTE(review): the exponential is kept because the notebook's loss
        takes the logarithm of this output. Summed over thousands of
        dimensions it can overflow/underflow, and no prior density of the
        latent is included — consider training on `log_det` plus a prior
        log-probability instead.
        """
        return torch.exp(self.log_det(x))

    def latent(self, x):
        """Map data `x` to a latent tensor of the same shape as `x`."""
        y, _ = self._encode(x)

        y = self.unsqueeze(y)
        y = self.unsqueeze(y)

        return y

    def inverse(self, y):
        """Invert `latent`: map a latent tensor back to data space."""
        x = y
        x = self.squeeze(x)
        x = self.squeeze(x)

        for layer in reversed(self.couplings5):
            x = layer.inverse(x)
        for layer in reversed(self.couplings4):
            x = layer.inverse(x)

        x = self.unsqueeze(x)

        for layer in reversed(self.couplings3):
            x = layer.inverse(x)
        for layer in reversed(self.couplings2):
            x = layer.inverse(x)

        x = self.unsqueeze(x)

        for layer in reversed(self.couplings1):
            x = layer.inverse(x)

        return x

    # __call__ is intentionally not overridden so nn.Module.__call__ keeps
    # running registered hooks around forward().

In [41]:
# Start a Weights & Biases run; entity and project are hard-coded to the author's account.
wandb.init(entity="vadim-farutin", project="HSE-DUL-HW02-2")
wandb.watch_called = False # Re-run the model without restarting the runtime, unnecessary after our next release

# Hyperparameters are stored on wandb.config; they are read below when building
# the optimizer, when calling fit(), and inside plot_loss_values().
config = wandb.config
config.lr = 1e-4
config.batch_size = 32
config.epochs = 2

In [0]:
# Build the flow for 3-channel 32x32 inputs, cast it to float32 and place it on DEVICE.
channels = 3
size = (32, 32)
model = MyRealNVP(channels, size).float().to(DEVICE)

In [0]:
# fit() applies this loss directly to the model output (no separate target tensor).
loss_function = MyNLLLoss
# Adam with the configured learning rate; weight decay is explicitly disabled.
optimizer = optim.Adam(model.parameters(), lr=config.lr, weight_decay=0)

In [0]:
# Train for config.epochs epochs; returns per-batch training losses and per-epoch validation losses.
train_loss_values, val_loss_values =\
    fit(model, train, val, optimizer, loss_function, config.epochs, config.batch_size)

In [0]:
# Draw both loss curves on a shared iteration axis.
plot_loss_values(train_loss_values, val_loss_values)

In [0]:
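# NOTE: disabled sampling snippet. It draws 2-D samples (size-2 tensors) and unpacks two
# values from model.inverse, so it does not match the image model defined in this notebook.
# z = MultivariateNormal(torch.zeros(2), torch.eye(2)).sample((10000, 2))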
# low = torch.tensor([0.0, 0.0]).to(DEVICE)
# high = torch.tensor([1.0, 1.0]).to(DEVICE)
# z = Uniform(low, high).rsample((10000, 2))

# x1, x2 = model.inverse(z)
# x1 = x1.cpu().detach().numpy()
# x2 = x2.cpu().detach().numpy()
# plt.figure()
# plt.scatter(x1, x2)