<a href="https://colab.research.google.com/github/AntonYermilov/deep-unsupervised-learning/blob/task6/Task_6.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Upgrade tqdm before importing it: the imports cell uses `tqdm.notebook`, which the
# preinstalled tqdm (4.28.1 according to the output below) presumably lacks — TODO confirm
# the minimum version and pin it here for reproducibility.
!pip3 install --upgrade tqdm

Collecting tqdm
[?25l  Downloading https://files.pythonhosted.org/packages/8c/c3/d049cf3fb31094ee045ec1ee29fffac218c91e82c8838c49ab4c3e52627b/tqdm-4.41.0-py2.py3-none-any.whl (56kB)
[K     |█████▊                          | 10kB 20.3MB/s eta 0:00:01[K     |███████████▌                    | 20kB 4.0MB/s eta 0:00:01[K     |█████████████████▎              | 30kB 5.7MB/s eta 0:00:01[K     |███████████████████████         | 40kB 7.1MB/s eta 0:00:01[K     |████████████████████████████▉   | 51kB 4.8MB/s eta 0:00:01[K     |████████████████████████████████| 61kB 3.7MB/s 
[?25hInstalling collected packages: tqdm
  Found existing installation: tqdm 4.28.1
    Uninstalling tqdm-4.28.1:
      Successfully uninstalled tqdm-4.28.1
Successfully installed tqdm-4.41.0


In [0]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from torch.utils.data import TensorDataset, DataLoader, RandomSampler
from torch.utils.data.dataset import random_split
from torch.optim import Adam
from torch.nn import NLLLoss, CrossEntropyLoss
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm

import plotly.offline as py
import plotly.graph_objs as go
import plotly.express as px
import plotly.figure_factory as ff

In [0]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [0]:
def sample_data_1():
    """Dataset 1: 100000 axis-aligned Gaussian points with mean (1, 2) and std (5, 1).

    Uses a fixed ``RandomState(0)``, so every call returns the same (100000, 2) array.
    """
    rng = np.random.RandomState(0)
    n_points = 100000
    centre = np.array([1.0, 2.0])
    stddev = np.array([5.0, 1.0])
    noise = rng.randn(n_points, 2)
    return centre + noise * stddev


def sample_data_2():
    """Dataset 2: dataset 1's anisotropic noise rotated by 45 degrees, then shifted to (1, 2).

    The standard-normal draws (``RandomState(0)``) are scaled by (5, 1) and multiplied on the
    right by a 45-degree rotation matrix, so the two coordinates become correlated.
    """
    rng = np.random.RandomState(0)
    n_points = 100000
    half_sqrt2 = np.sqrt(2) / 2
    rotation = np.array([[half_sqrt2, half_sqrt2],
                         [-half_sqrt2, half_sqrt2]])
    scaled = rng.randn(n_points, 2) * np.array([5.0, 1.0])
    return np.array([1.0, 2.0]) + scaled.dot(rotation)


def sample_data_3():
    """Dataset 3: two Gaussian blobs plus a noisy lower half-ellipse arc, with class labels.

    Returns ``(points, labels)``. ``points`` has shape (3 * (100000 // 3), 2). ``labels`` is 0
    for the blob at (-1.5, 2.5), 1 for the blob at (1.5, 2.5) and 2 for the arc. Both arrays are
    shuffled with the same permutation from a fixed ``RandomState(0)``.
    """
    rng = np.random.RandomState(0)
    per_part = 100000 // 3
    noise_scale = 0.2

    # The draw order (left blob, right blob, arc noise, permutation) determines the data.
    left_blob = np.array([[-1.5, 2.5]]) + rng.randn(per_part, 2) * noise_scale
    right_blob = np.array([[1.5, 2.5]]) + rng.randn(per_part, 2) * noise_scale

    angles = np.linspace(0, np.pi, per_part)
    arc = np.column_stack((2 * np.cos(angles), -np.sin(angles)))
    arc = arc + rng.randn(*arc.shape) * noise_scale

    points = np.concatenate((left_blob, right_blob, arc), axis=0)
    labels = np.repeat([0, 1, 2], [len(left_blob), len(right_blob), len(arc)])

    order = rng.permutation(len(points))
    return points[order], labels[order]

In [0]:
class Encoder(nn.Module):
    """Gaussian encoder q(z | x): one tanh hidden layer feeding two linear heads.

    ``forward`` returns ``(mu, logvar)``, each of width ``output_size``. SimpleVAE samples
    ``z = mu + exp(logvar / 2) * eps``, i.e. they are the mean and log-variance of a
    diagonal Gaussian over the latent code.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        super().__init__()
        self.input_size, self.hidden_size, self.output_size = input_size, hidden_size, output_size

        # Attribute names and creation order are kept stable: they fix the state_dict keys
        # and the order in which seeded weight initialisation consumes random numbers.
        self.encoder = nn.Sequential(nn.Linear(input_size, hidden_size), nn.Tanh())
        self.mu = nn.Linear(hidden_size, output_size)
        self.logvar = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        features = self.encoder(x)
        mean = self.mu(features)
        log_variance = self.logvar(features)
        return mean, log_variance


class MultivariateDecoder(nn.Module):
    """Gaussian decoder p(x | z) with one log-variance per output dimension (diagonal covariance).

    ``forward`` returns ``(mu, logvar)``, each of width ``output_size``.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        trunk = [nn.Linear(input_size, hidden_size), nn.Tanh()]
        self.decoder = nn.Sequential(*trunk)
        # Two heads on the shared trunk: per-dimension mean and log-variance.
        self.mu = nn.Linear(hidden_size, output_size)
        self.logvar = nn.Linear(hidden_size, output_size)

    def forward(self, z):
        hidden = self.decoder(z)
        return self.mu(hidden), self.logvar(hidden)


class ScalarDecoder(nn.Module):
    """Gaussian decoder p(x | z) with a single noise parameter shared by all output dimensions.

    ``forward`` returns ``(mu, sigma)``: ``mu`` has width ``output_size`` while ``sigma`` has
    width 1. Despite its name, SimpleVAE and vae_loss in this notebook use that second output
    as a log-variance and broadcast it across every output dimension (isotropic variance).
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        trunk_layers = [nn.Linear(input_size, hidden_size), nn.Tanh()]
        self.decoder = nn.Sequential(*trunk_layers)
        self.mu = nn.Linear(hidden_size, output_size)
        # A single output unit, so the same value applies to every dimension of x.
        self.sigma = nn.Linear(hidden_size, 1)

    def forward(self, z):
        h = self.decoder(z)
        mean = self.mu(h)
        shared = self.sigma(h)
        return mean, shared


class SimpleVAE(nn.Module):
    """Gaussian VAE that wires an encoder and a decoder together with the reparameterization trick.

    Args:
        input_size: dimensionality of the data x.
        hidden_size: dimensionality of the latent code z.
        encoder_type: module class called as ``encoder_type(input_size, intermediate_size, hidden_size)``;
            its forward must return ``(mu_z, logvar_z)``.
        decoder_type: module class called as ``decoder_type(hidden_size, intermediate_size, input_size)``;
            its forward must return ``(mu_y, logvar_y)``; ``logvar_y`` may broadcast against ``mu_y``.
        device: device on which prior samples are drawn by ``generate_rand`` / ``generate_mean``.
            It should match the device the module's parameters live on.
        intermediate_size: width of the encoder/decoder hidden layer (default 20).
    """

    def __init__(self, input_size, hidden_size, encoder_type, decoder_type, device, intermediate_size=20):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.device = device

        self.encoder = encoder_type(input_size, intermediate_size, hidden_size)
        self.decoder = decoder_type(hidden_size, intermediate_size, input_size)

    @staticmethod
    def _reparameterize(mu, logvar):
        """Return ``mu + exp(logvar / 2) * eps`` with ``eps ~ N(0, I)`` drawn on ``mu``'s device."""
        return mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)

    def _sample_prior(self, n):
        """Draw ``n`` latent codes from the standard-normal prior on ``self.device``."""
        return torch.randn(n, self.hidden_size, device=self.device)

    def forward(self, x):
        """Encode, sample z, decode, and sample a reconstruction.

        Returns ``(y, mu_z, logvar_z, mu_y, logvar_y)`` where ``y`` is the sampled reconstruction.
        """
        mu_z, logvar_z = self.encoder(x)
        z = self._reparameterize(mu_z, logvar_z)
        mu_y, logvar_y = self.decoder(z)
        y = self._reparameterize(mu_y, logvar_y)
        return y, mu_z, logvar_z, mu_y, logvar_y

    def generate_rand(self, n):
        """Sample ``n`` points: prior z, then a draw from the decoder's Gaussian."""
        mu_y, logvar_y = self.decoder(self._sample_prior(n))
        return self._reparameterize(mu_y, logvar_y)

    def generate_mean(self, n):
        """Return the decoder means for ``n`` prior samples (no observation noise)."""
        mu_y, _ = self.decoder(self._sample_prior(n))
        return mu_y

In [0]:
def vae_loss(x, mu_z, logvar_z, mu_y, logvar_y, clip_log_prob=True):
    """Return the two terms of the VAE objective as ``(kldiv_loss, decoder_loss)``.

    kldiv_loss: KL(N(mu_z, exp(logvar_z)) || N(0, I)), averaged over every batch and latent element.
    decoder_loss: negative Gaussian log-likelihood of ``x`` under N(mu_y, exp(logvar_y)), averaged
        over every element. ``logvar_y`` may broadcast against ``mu_y``, e.g. one shared
        log-variance per sample of shape (batch, 1).

    Args:
        clip_log_prob: when True (the default, matching earlier behaviour), per-element
            log-densities above 0 are clipped to 0, which also zeroes their gradient. A continuous
            density can exceed 1, so with clipping the decoder term is not the exact negative
            log-likelihood; pass False to get the unclipped value.
    """
    import math

    kldiv_loss = 0.5 * (logvar_z.exp() + mu_z ** 2 - 1 - logvar_z).mean()

    # Plain Python float: no per-call tensor allocation and no CPU/GPU device mixing.
    log_2pi = math.log(2 * math.pi)
    log_prob = -0.5 * (logvar_y + log_2pi + (x - mu_y) ** 2 / logvar_y.exp())
    if clip_log_prob:
        # Out-of-place equivalent of masking positive entries to 0 (same values and gradients).
        log_prob = log_prob.clamp(max=0.0)
    decoder_loss = -log_prob.mean()

    return kldiv_loss, decoder_loss

In [0]:
def plot_model(samples, model, dataset_name=''):
    """Scatter the first 1000 data points against 1000 sampled and 1000 mean decoder outputs.

    Samples come from ``model.generate_rand`` and means from ``model.generate_mean``. Both
    axes are fixed to [-20, 20].
    """
    n_points = 1000

    def to_numpy(tensor):
        return tensor.to('cpu').detach().numpy()

    data = samples[:n_points]
    rand_points = to_numpy(model.generate_rand(n_points))
    mean_points = to_numpy(model.generate_mean(n_points))

    labelled = ((data, 'samples'), (rand_points, 'generated random'), (mean_points, 'generated mean'))
    traces = [go.Scatter(x=pts[:, 0], y=pts[:, 1], mode='markers', name=label) for pts, label in labelled]

    figure = go.Figure(data=traces)
    figure.update_layout(title=dataset_name, width=900, height=900)
    figure.update_xaxes(range=[-20, 20])
    figure.update_yaxes(range=[-20, 20])
    figure.update_traces(marker=dict(size=3))
    figure.show()


def plot_labeled_model(samples_x, samples_y, model, dataset_name=''):
    """Overlay labelled data, its VAE reconstructions and fresh VAE samples in one plotly figure.

    Plots the first 1000 points of ``samples_x`` (coloured by ``samples_y``), the model's
    sampled reconstruction of those same points (first output of ``model(x)``), 1000 decoder
    samples from ``generate_rand`` and 1000 decoder means from ``generate_mean``. Both axes are
    fixed to [-5, 5].
    """
    N = 1000
    samples_x, samples_y = samples_x[:N], samples_y[:N]
    generated_rand = model.generate_rand(N).to('cpu').detach().numpy()
    generated_mean = model.generate_mean(N).to('cpu').detach().numpy()

    x1, y1 = samples_x[:, 0], samples_x[:, 1]
    x_rand, y_rand = generated_rand[:, 0], generated_rand[:, 1]
    x_mean, y_mean = generated_mean[:, 0], generated_mean[:, 1]

    # Reconstruction is for display only, so no autograd graph is needed.
    with torch.no_grad():
        restored, _, _, _, _ = model(torch.tensor(samples_x).float().to(model.device))
    restored = restored.to('cpu').detach().numpy()
    x2, y2 = restored[:, 0], restored[:, 1]

    traces = [
        go.Scatter(x=x1, y=y1, mode='markers', name='initial samples', marker=dict(color=samples_y)),
        # Plot the reconstructions (x2, y2), not the inputs again. A distinct marker symbol
        # keeps them distinguishable, since both traces share the per-label colours.
        go.Scatter(x=x2, y=y2, mode='markers', name='restored samples',
                   marker=dict(color=samples_y, symbol='x')),
        go.Scatter(x=x_rand, y=y_rand, mode='markers', name='generated random'),
        go.Scatter(x=x_mean, y=y_mean, mode='markers', name='generated mean')
    ]
    figure = go.Figure(data=traces)

    figure.update_layout(title=dataset_name, width=900, height=900)
    figure.update_xaxes(range=[-5, 5])
    figure.update_yaxes(range=[-5, 5])
    figure.update_traces(marker=dict(size=3))
    figure.show()

In [0]:
def make_loss_figure():
    """Create an empty plotly FigureWidget holding six named line traces for loss curves.

    Traces 0-2 are the train ELBO / KL-div / decoder losses and traces 3-5 the validation
    ones, in that order; fit_model appends points to them by index.
    """
    curve_names = (
        'train ELBO loss',
        'train KL-div loss',
        'train decoder loss',
        'valid ELBO loss',
        'valid KL-div loss',
        'valid decoder loss',
    )
    widget = go.FigureWidget(
        data=[go.Scatter(x=[], y=[], mode='lines', name=label) for label in curve_names]
    )
    widget.update_layout(title='Losses')
    return widget


def _run_epoch(model, loader, device, optimizer=None, beta=1.0):
    """Make one pass over ``loader``: train when ``optimizer`` is given, otherwise evaluate.

    Returns the per-batch averages ``[total, kldiv, decoder]``, or three NaNs for an empty loader.
    """
    training = optimizer is not None
    model.train(training)
    sums = [0.0, 0.0, 0.0]
    for batch in loader:
        # The condition tensor travels with the data but is not fed to the model.
        x, _ = [b.to(device) for b in batch]
        with torch.set_grad_enabled(training):
            _, mu_z, logvar_z, mu_y, logvar_y = model(x)
            kldiv_loss, decoder_loss = vae_loss(x, mu_z, logvar_z, mu_y, logvar_y)
            total_loss = beta * kldiv_loss + decoder_loss

        if training:
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        sums[0] += total_loss.item()
        sums[1] += kldiv_loss.item()
        sums[2] += decoder_loss.item()

    n_batches = len(loader)
    if n_batches == 0:
        return [float('nan')] * 3
    return [s / n_batches for s in sums]


def _append_losses(figure, first_trace, epoch, losses):
    """Append one epoch's ``losses`` to the consecutive traces starting at ``first_trace``."""
    for offset, loss in enumerate(losses):
        trace = figure.data[first_trace + offset]
        trace.x += (epoch,)
        trace.y += (loss,)


def fit_model(model, optimizer, data_x, data_c, epochs, batch_size, device, beta=1.0):
    """Train ``model`` on a random 80/20 train/validation split and plot the loss curves.

    Args:
        model: VAE whose forward returns ``(y, mu_z, logvar_z, mu_y, logvar_y)``.
        optimizer: optimizer over ``model``'s parameters.
        data_x: array-like of inputs, converted to a float32 tensor.
        data_c: optional per-sample labels/conditions (zeros when None). They are batched with
            ``data_x`` but not passed to the model.
        epochs: number of passes over the training split.
        batch_size: mini-batch size for both splits.
        device: device each batch is moved to.
        beta: weight of the KL term in the total loss (1.0 reproduces the original weighting).

    Prints the per-epoch train/valid total loss. At the end, shows a figure with the total, KL
    and decoder losses for both splits (traces 0-2 train, 3-5 valid).
    """
    data_x = torch.tensor(data_x).float()
    data_c = torch.zeros(data_x.shape[0]).float() if data_c is None else torch.tensor(data_c).float()

    dataset = TensorDataset(data_x, data_c)

    train_len = int(0.8 * len(dataset))
    valid_len = len(dataset) - train_len
    train_dataset, valid_dataset = random_split(dataset, [train_len, valid_len])

    train_sampler = RandomSampler(train_dataset)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

    figure = make_loss_figure()

    for epoch in tqdm(range(epochs)):
        print(f'epoch={epoch},\t', end='')

        train_losses = _run_epoch(model, train_loader, device, optimizer=optimizer, beta=beta)
        _append_losses(figure, 0, epoch, train_losses)
        print(f'train loss={train_losses[0]:.8f},\t', end='')

        valid_losses = _run_epoch(model, valid_loader, device, beta=beta)
        _append_losses(figure, 3, epoch, valid_losses)
        print(f'valid loss={valid_losses[0]:.8f}', end='\n')

    figure.show()

**Dataset 1, decoder with diagonal covariance matrix**

In [86]:
# Dataset 1 with a diagonal-covariance decoder (one log-variance per output dimension).
model = SimpleVAE(
    input_size=2,
    hidden_size=2,
    encoder_type=Encoder,
    decoder_type=MultivariateDecoder,
    device=device,
).to(device)
optimizer = Adam(model.parameters(), lr=0.001)
fit_model(model, optimizer, data_x=sample_data_1(), data_c=None, epochs=15, batch_size=256, device=device)

HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))

epoch=0,	train loss=3.97305792,	valid loss=2.55748921
epoch=1,	train loss=2.41298592,	valid loss=2.31990042
epoch=2,	train loss=2.28339027,	valid loss=2.24732587
epoch=3,	train loss=2.23940300,	valid loss=2.22935325
epoch=4,	train loss=2.22792080,	valid loss=2.22172744
epoch=5,	train loss=2.22500277,	valid loss=2.22229663
epoch=6,	train loss=2.22409276,	valid loss=2.22145236
epoch=7,	train loss=2.22326516,	valid loss=2.22157078
epoch=8,	train loss=2.22328481,	valid loss=2.22034789
epoch=9,	train loss=2.22330459,	valid loss=2.21992941
epoch=10,	train loss=2.22306013,	valid loss=2.22065513
epoch=11,	train loss=2.22235611,	valid loss=2.22039151
epoch=12,	train loss=2.22268748,	valid loss=2.22031891
epoch=13,	train loss=2.22278878,	valid loss=2.21989445
epoch=14,	train loss=2.22270736,	valid loss=2.22025747



In [87]:
# Compare dataset 1 with samples from the diagonal-covariance model.
plot_model(samples=sample_data_1(), model=model, dataset_name='dataset 1')

**Dataset 1, decoder with scalar variance**

In [89]:
# Dataset 1 with a scalar-variance decoder (one log-variance shared by both dimensions).
model = SimpleVAE(
    input_size=2,
    hidden_size=2,
    encoder_type=Encoder,
    decoder_type=ScalarDecoder,
    device=device,
).to(device)
optimizer = Adam(model.parameters(), lr=0.001)
fit_model(model, optimizer, data_x=sample_data_1(), data_c=None, epochs=15, batch_size=256, device=device)

HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))

epoch=0,	train loss=3.74396009,	valid loss=2.76644163
epoch=1,	train loss=2.67024063,	valid loss=2.62286560
epoch=2,	train loss=2.59824597,	valid loss=2.57492857
epoch=3,	train loss=2.54515529,	valid loss=2.49394435
epoch=4,	train loss=2.44503343,	valid loss=2.39131818
epoch=5,	train loss=2.35647051,	valid loss=2.32358406
epoch=6,	train loss=2.29781235,	valid loss=2.27411544
epoch=7,	train loss=2.26019982,	valid loss=2.25902234
epoch=8,	train loss=2.24806371,	valid loss=2.24474994
epoch=9,	train loss=2.24140509,	valid loss=2.24750786
epoch=10,	train loss=2.23740237,	valid loss=2.23572174
epoch=11,	train loss=2.23516031,	valid loss=2.23417625
epoch=12,	train loss=2.23144044,	valid loss=2.23398619
epoch=13,	train loss=2.23230693,	valid loss=2.23498652
epoch=14,	train loss=2.22973680,	valid loss=2.23197068



In [90]:
# Compare dataset 1 with samples from the scalar-variance model.
plot_model(samples=sample_data_1(), model=model, dataset_name='dataset 1')

**Dataset 2, decoder with diagonal covariance matrix**

In [91]:
# Dataset 2 (rotated, correlated Gaussian) with a diagonal-covariance decoder.
model = SimpleVAE(
    input_size=2,
    hidden_size=2,
    encoder_type=Encoder,
    decoder_type=MultivariateDecoder,
    device=device,
).to(device)
optimizer = Adam(model.parameters(), lr=0.001)
fit_model(model, optimizer, data_x=sample_data_2(), data_c=None, epochs=15, batch_size=256, device=device)

HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))

epoch=0,	train loss=5.06409428,	valid loss=3.19852514
epoch=1,	train loss=2.90337751,	valid loss=2.64277961
epoch=2,	train loss=2.55659272,	valid loss=2.46682310
epoch=3,	train loss=2.44374705,	valid loss=2.39987436
epoch=4,	train loss=2.38020342,	valid loss=2.34451194
epoch=5,	train loss=2.33585342,	valid loss=2.30548091
epoch=6,	train loss=2.29810619,	valid loss=2.27291144
epoch=7,	train loss=2.27110450,	valid loss=2.26087672
epoch=8,	train loss=2.25519741,	valid loss=2.24560566
epoch=9,	train loss=2.24624545,	valid loss=2.23951289
epoch=10,	train loss=2.24545497,	valid loss=2.23344197
epoch=11,	train loss=2.24084731,	valid loss=2.22821885
epoch=12,	train loss=2.23834909,	valid loss=2.22839344
epoch=13,	train loss=2.23314444,	valid loss=2.23184527
epoch=14,	train loss=2.23311105,	valid loss=2.22917719



In [92]:
# Compare dataset 2 with samples from the diagonal-covariance model.
plot_model(samples=sample_data_2(), model=model, dataset_name='dataset 2')

**Dataset 2, decoder with scalar variance**

In [93]:
# Dataset 2 (rotated, correlated Gaussian) with a scalar-variance decoder.
model = SimpleVAE(
    input_size=2,
    hidden_size=2,
    encoder_type=Encoder,
    decoder_type=ScalarDecoder,
    device=device,
).to(device)
optimizer = Adam(model.parameters(), lr=0.001)
fit_model(model, optimizer, data_x=sample_data_2(), data_c=None, epochs=15, batch_size=256, device=device)

HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))

epoch=0,	train loss=3.99841901,	valid loss=2.83237143
epoch=1,	train loss=2.70200061,	valid loss=2.63763939
epoch=2,	train loss=2.59350777,	valid loss=2.55727929
epoch=3,	train loss=2.49492181,	valid loss=2.44229436
epoch=4,	train loss=2.39161113,	valid loss=2.35539666
epoch=5,	train loss=2.31852104,	valid loss=2.29620071
epoch=6,	train loss=2.27489176,	valid loss=2.26461192
epoch=7,	train loss=2.25320236,	valid loss=2.24983038
epoch=8,	train loss=2.24358834,	valid loss=2.24422399
epoch=9,	train loss=2.23887596,	valid loss=2.24188453
epoch=10,	train loss=2.23782292,	valid loss=2.24389000
epoch=11,	train loss=2.23508962,	valid loss=2.23979218
epoch=12,	train loss=2.22999715,	valid loss=2.23610727
epoch=13,	train loss=2.23023388,	valid loss=2.23493235
epoch=14,	train loss=2.22738737,	valid loss=2.23492172



In [94]:
# Compare dataset 2 with samples from the scalar-variance model.
plot_model(samples=sample_data_2(), model=model, dataset_name='dataset 2')

**Dataset 3, decoder with diagonal covariance matrix**

In [117]:
# Dataset 3 (two blobs + arc, with labels) with a diagonal-covariance decoder; trained longer.
model = SimpleVAE(
    input_size=2,
    hidden_size=2,
    encoder_type=Encoder,
    decoder_type=MultivariateDecoder,
    device=device,
).to(device)
optimizer = Adam(model.parameters(), lr=0.001)
# sample_data_3() returns (points, labels), which fill data_x and data_c positionally.
fit_model(model, optimizer, *sample_data_3(), epochs=30, batch_size=256, device=device)

HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

epoch=0,	train loss=1.90229051,	valid loss=1.82470314
epoch=1,	train loss=1.82205700,	valid loss=1.82106242
epoch=2,	train loss=1.81669083,	valid loss=1.81487458
epoch=3,	train loss=1.78173854,	valid loss=1.72342337
epoch=4,	train loss=1.54735836,	valid loss=1.44805979
epoch=5,	train loss=1.37544150,	valid loss=1.33200757
epoch=6,	train loss=1.31565944,	valid loss=1.30128107
epoch=7,	train loss=1.27786661,	valid loss=1.27469171
epoch=8,	train loss=1.25910840,	valid loss=1.24972837
epoch=9,	train loss=1.24272541,	valid loss=1.24769099
epoch=10,	train loss=1.23735324,	valid loss=1.23438150
epoch=11,	train loss=1.22760285,	valid loss=1.22169984
epoch=12,	train loss=1.21755236,	valid loss=1.21916847
epoch=13,	train loss=1.21177353,	valid loss=1.21495055
epoch=14,	train loss=1.21358775,	valid loss=1.21929071
epoch=15,	train loss=1.20629108,	valid loss=1.20814534
epoch=16,	train loss=1.20631691,	valid loss=1.20785391
epoch=17,	train loss=1.20458078,	valid loss=1.20423227
epoch=18,	train loss

In [118]:
# Labelled view of dataset 3: inputs, reconstructions and fresh samples.
plot_labeled_model(*sample_data_3(), model=model, dataset_name='dataset 3')