In [63]:
import torch
from sklearn.datasets import make_moons
from torch import nn, optim
from torch.autograd import Variable
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, TensorDataset
from torchvision.utils import save_image
import os
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons, load_digits
from sklearn.model_selection import train_test_split
import numpy as np
from torchvision.transforms.functional import rotate
from scipy import ndimage

In [70]:
class RealNVP(nn.Module):
    def __init__(self, input_size, hidden_size, blocks):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.blocks = blocks
        self.skip_connection_dimension = int(np.floor(input_size / 2))
        self.transform_dimension = int(input_size - self.skip_connection_dimension)
        
        self.scalling_coefficient_network = nn.Sequential(
            *[nn.Linear(self.skip_connection_dimension, self.hidden_size),
              nn.Linear(self.hidden_size, self.hidden_size),
              nn.ReLU(),
              nn.Linear(self.hidden_size, self.hidden_size),
              nn.ReLU(),
              nn.Linear(self.hidden_size, self.transform_dimension),
              nn.Tanh()]
        )
        
        self.translation_coefficient_network = nn.Sequential(
            *[nn.Linear(self.skip_connection_dimension, self.hidden_size),
              nn.Linear(self.hidden_size, self.hidden_size),
              nn.ReLU(),
              nn.Linear(self.hidden_size, self.hidden_size),
              nn.ReLU(),
              nn.Linear(self.hidden_size, self.transform_dimension)]
        )
        
    def forward(self, X):
        for l in range(self.blocks):
            X_skip = X[:, :self.skip_connection_dimension]
            X_transform = X[:, self.skip_connection_dimension:]
            s_tilde = self.scalling_coefficient_network(X_skip)
            s = torch.exp(s_tilde)
            t = self.translation_coefficient_network(X_skip)
            X_transform = torch.mul(X_transform, s) + t
            X = torch.cat((X_skip, X_transform), axis=1)
            random_rotation_angle = np.random.uniform(-180, 180)
            # Since multiplying a vector by an orthogonal matrix represents a rotation or reflection, 
            # we apply a rotation with random angle to improve performance. (Skip QR factorization of random matrix) 
            X = torch.from_numpy(ndimage.rotate(X.detach().numpy(), random_rotation_angle, reshape=False))
        return X
    
    def train_inn(self, X_train,  ts_size, epochs, lr=0.001, batch_size=128):
            assert ts_size % batch_size == 0
            assert batch_size % 2 == 0
            indices = torch.randperm(len(X_train))[:ts_size]
            X_train = X_train[indices]
            error_type = nn.NLLLoss()
            optimizer = torch.optim.Adam(self.parameters(), lr=lr)
            losses = []
            for epoch in range(epochs):
                for batch in range(0, ts_size - batch_size, batch_size):
                    X_batch = X_train[batch : batch + batch_size]
                    output = self.forward(X_batch)
                    print(output.shape)
                    print(X_batch.shape)
                    loss = error_type(output, X_batch)
                    losses.append(loss)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()


            return self, losses

        
        

In [56]:
def import_data(noise=0, random_state=1, shuffle=True, n_test=0.5, name="moons", n_samples = 2000):
    if name == "moons":
        data, _ = make_moons(noise=noise, random_state=random_state, shuffle=shuffle, n_samples=n_samples)
        X_train, X_test= train_test_split(data, test_size=n_test, random_state=random_state)
        return torch.FloatTensor(X_train), torch.FloatTensor(X_test), None, None

    elif name == "digits":
        data, labels = load_digits(return_X_y = True)
        data = data.reshape((len(data), -1))
        noise =  np.random.normal(0, noise * data.max(), data.shape)
        X_train, X_test, Y_train, Y_test = train_test_split(data, labels, test_size=n_test, random_state=random_state)
        return torch.FloatTensor(X_train), torch.FloatTensor(X_test), torch.FloatTensor(Y_train), torch.FloatTensor(Y_test)

In [57]:
X_train, X_test, Y_train, Y_test = import_data(noise=0.1, n_test=0.5,  name="moons")

In [71]:
nvp = RealNVP(2, 5, 5).train_inn(X_train, 512, 100)

torch.Size([128, 2])
torch.Size([128, 2])


RuntimeError: 0D or 1D target tensor expected, multi-target not supported