In [8]:
import os
import sys
import pandas as pd
import numpy as np
import datetime  
import time
import torch
from torch import nn, Tensor
from torch.utils.data import DataLoader
from torchsummary import summary
from utils.dataset import *
from typing import List
from torch.nn import functional as F
from scipy.stats import entropy
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [9]:
# Load data for the first `pairs` pairs via the project helpers in utils.dataset.
# NOTE(review): `sizedata` presumably holds one size histogram per row over `n_size` bins — confirm against get_data.
pairs = 1000
pairdata, freqpairs, n_size, n_interval = get_univ_data(pairs)
sizedata = get_data(pairdata, freqpairs, 'size_index', n_size)

# Despite its name, `size_cdf` ends up as one representative size per bin:
# 0 for the first bin, then the midpoint of each pair of consecutive 'size' values in the CSV.
size_table = pd.read_csv('data/univ/size_cdf.csv')
bin_edges = size_table['size'].values
size_cdf = np.concatenate(([0], (bin_edges[1:] + bin_edges[:-1]) / 2))

# Per-row weighted sum of the bin representatives (the mean size when each row sums to 1).
mean_sizes = (sizedata * size_cdf).sum(axis=1)

In [10]:
class SizeEncoder(nn.Module):
    """MLP encoder of the VAE: maps an input vector to the parameters of a diagonal Gaussian.

    Args:
        input_dim: size of the last dimension of the input.
        hidden_dims: widths of the hidden Linear+ReLU layers, applied in order. May be empty,
            in which case the Gaussian heads read the input directly.
        latent_dim: size of the latent vector.

    ``forward`` returns ``[mu, log_var]``, each with last dimension ``latent_dim``.
    """

    def __init__(self, input_dim, hidden_dims, latent_dim):
        super().__init__()
        self.encoder = nn.ModuleList()
        in_dim = input_dim
        for h_dim in hidden_dims:
            self.encoder.append(
                nn.Sequential(
                    nn.Linear(in_dim, out_features=h_dim),
                    nn.ReLU())
            )
            in_dim = h_dim
        # `in_dim` is now the width of the last hidden layer, or `input_dim` when there are no
        # hidden layers, so an empty `hidden_dims` no longer raises IndexError.
        self.fc_mu = nn.Linear(in_dim, latent_dim)
        self.fc_var = nn.Linear(in_dim, latent_dim)

    def forward(self, x: Tensor) -> List[Tensor]:
        """Return ``[mu, log_var]`` for input ``x``."""
        for module in self.encoder:
            x = module(x)
        mu = self.fc_mu(x)
        # Interpreted as log-variance downstream (exp'd in reparameterize and in the KL term).
        log_var = self.fc_var(x)
        return [mu, log_var]

class SizeDecoder(nn.Module):
    """MLP decoder of the VAE: maps a latent vector to a probability vector of length ``output_dim``.

    Args:
        output_dim: length of the output distribution.
        hidden_dims: widths of the hidden Linear+ReLU layers, applied in order. May be empty,
            in which case the output layer reads the latent directly.
        latent_dim: size of the latent input.
    """

    def __init__(self, output_dim, hidden_dims, latent_dim):
        super().__init__()
        self.decoder = nn.ModuleList()
        in_dim = latent_dim
        for h_dim in hidden_dims:
            self.decoder.append(
                nn.Sequential(
                    nn.Linear(in_dim, out_features=h_dim),
                    nn.ReLU())
            )
            in_dim = h_dim
        # Width of the last hidden layer, or `latent_dim` when there are no hidden layers.
        self.output = nn.Linear(in_dim, output_dim)

    def forward(self, x: Tensor) -> Tensor:
        """Return softmax probabilities over the last dimension (each output vector sums to 1)."""
        for module in self.decoder:
            x = module(x)
        logits = self.output(x)
        # dim=-1 equals dim=1 for (batch, features) input and also works for unbatched input.
        return F.softmax(logits, dim=-1)

In [11]:
def reparameterize(mu: Tensor, logvar: Tensor) -> Tensor:
    """Sample ``z = mu + exp(logvar / 2) * eps`` with ``eps ~ N(0, I)`` (reparameterization trick).

    The randomness is isolated in ``eps``, so gradients flow through ``mu`` and ``logvar``.
    """
    sigma = (0.5 * logvar).exp()
    noise = torch.randn_like(sigma)
    return noise * sigma + mu

# Weight of the KL term relative to the L1 reconstruction loss; `train` uses it unless an
# explicit `beta` is passed.
kld_weight = 1e-5


def train(encoder, decoder, dataloader, optimizer, beta=None):
    """Run one training epoch of the VAE and return sample-weighted epoch averages.

    Args:
        encoder: callable mapping a batch to ``[mu, log_var]``.
        decoder: callable mapping sampled latents to a reconstruction shaped like the batch.
        dataloader: iterable of input tensors; each batch is moved to the global ``device``.
        optimizer: optimizer over the encoder/decoder parameters.
        beta: weight of the KL term. ``None`` (default) uses the module-level ``kld_weight``.

    Returns:
        ``(loss, recon, kld, size)`` averaged over samples. ``size`` is never accumulated and is
        always ``0.0``; it is kept so existing callers can still unpack four values.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    weight = kld_weight if beta is None else beta
    epoch_loss = epoch_kld = epoch_recon = 0.0
    sample_num = 0
    for data in dataloader:
        data = data.to(device)
        optimizer.zero_grad()
        mu, log_var = encoder(data)
        z = reparameterize(mu, log_var)
        recon = decoder(z)
        recon_loss = F.l1_loss(recon, data)
        # Closed-form KL(N(mu, exp(log_var)) || N(0, I)): summed over latent dims, averaged over the batch.
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
        loss = recon_loss + weight * kld_loss
        loss.backward()
        optimizer.step()
        # Weight each batch's mean by its size so a short final batch is not over-counted.
        batch_size = len(data)
        epoch_loss += loss.item() * batch_size
        epoch_kld += kld_loss.item() * batch_size
        epoch_recon += recon_loss.item() * batch_size
        sample_num += batch_size

    if sample_num == 0:
        raise ValueError("dataloader yielded no samples")
    return epoch_loss / sample_num, epoch_recon / sample_num, epoch_kld / sample_num, 0.0

In [12]:
def cramer_dis(x, y):
    """Sum of absolute gaps between the cumulative sums of ``x`` and ``y``, divided by ``x.shape[0]``.

    ``np.cumsum`` without an axis flattens its input, so ``y`` may be a single-row 2-D slice.
    This is an L1 gap between CDFs (not the squared form). ``x`` must be an array (``.shape`` is read).
    """
    gap = np.abs(np.cumsum(x) - np.cumsum(y))
    return gap.sum() / x.shape[0]

def js_dis(p, q):
    """Jensen-Shannon divergence between two discrete distributions (natural log).

    The shorter sequence is right-padded with zeros. ``scipy.stats.entropy`` normalizes each
    argument to sum to 1 and computes KL(p || m) when given two arguments.
    """
    p_list, q_list = list(p), list(q)
    width = max(len(p_list), len(q_list))
    p_list.extend([0.0] * (width - len(p_list)))
    q_list.extend([0.0] * (width - len(q_list)))
    assert len(p_list) == len(q_list)
    midpoint = np.sum([p_list, q_list], axis=0) / 2
    return 0.5 * entropy(p_list, midpoint) + 0.5 * entropy(q_list, midpoint)

def model_test(encoder, decoder, n_samples=1000):
    """Reconstruct the first ``n_samples`` rows of the global ``sizedata`` and score them.

    The rows are encoded and a latent is sampled with ``reparameterize``. Each decoded row is
    cleaned: entries below 1e-3 are zeroed and the row is renormalized to sum to 1.

    Args:
        encoder: callable returning ``[mu, log_var]`` for a batch.
        decoder: callable mapping latents to one distribution per row.
        n_samples: number of leading rows to evaluate (clamped to ``len(sizedata)``).

    Returns:
        ``(mean_js, rel_errors)``:
        - ``mean_js`` is the mean of ``js_dis`` between each reconstruction and its input row.
        - ``rel_errors`` is a list of ``|recon_mean - mean_sizes[i]| / mean_sizes[i]``, where
          ``recon_mean`` weights ``size_cdf`` by the reconstruction.
    """
    n = min(n_samples, len(sizedata))
    targets = np.asarray(sizedata[:n])
    # Inference only: no autograd graph, and all rows go through the model in a single batch.
    with torch.no_grad():
        batch = torch.tensor(targets, dtype=torch.float).to(device)
        mu, log_var = encoder(batch)
        recon = decoder(reparameterize(mu, log_var)).cpu().numpy()
    recon[recon < 1e-3] = 0
    recon /= recon.sum(axis=1, keepdims=True)
    recon_means = (recon * size_cdf).sum(axis=1)
    size_dis = [js_dis(row, target) for row, target in zip(recon, targets)]
    target_means = np.asarray(mean_sizes[:n])
    mean_size_dis = (np.abs(recon_means - target_means) / target_means).tolist()
    return np.mean(size_dis), mean_size_dis

In [20]:
# Build the VAE: the decoder uses the encoder's hidden widths in reverse order.
latent_dim = 32
hidden_dims = [64, 128, 256, 128, 64]
encoder = SizeEncoder(n_size, hidden_dims, latent_dim).to(device)
# Reversed copy instead of hidden_dims.reverse(), so the encoder config list is not mutated.
decoder = SizeDecoder(n_size, hidden_dims[::-1], latent_dim).to(device)
# torchsummary.summary prints its table itself and returns None, so it is called directly rather
# than passed to print(). Its `device` argument defaults to "cuda", so pass the device in use.
print('encoder:')
summary(encoder, [[n_size]], device=device.type)
print('decoder:')
summary(decoder, [[latent_dim]], device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                   [-1, 64]           1,984
              ReLU-2                   [-1, 64]               0
            Linear-3                  [-1, 128]           8,320
              ReLU-4                  [-1, 128]               0
            Linear-5                  [-1, 256]          33,024
              ReLU-6                  [-1, 256]               0
            Linear-7                  [-1, 128]          32,896
              ReLU-8                  [-1, 128]               0
            Linear-9                   [-1, 64]           8,256
             ReLU-10                   [-1, 64]               0
           Linear-11                   [-1, 32]           2,080
           Linear-12                   [-1, 32]           2,080
Total params: 88,640
Trainable params: 88,640
Non-trainable params: 0
---------------------------------

In [24]:
# Whole size table as a float32 tensor; the DataLoader shuffles and batches its rows.
dataset = torch.tensor(sizedata, dtype=torch.float)
dataloader = DataLoader(dataset, shuffle=True, batch_size=128)
lr = 1e-3
# One Adam parameter group per sub-network, both at the same learning rate.
optimizer = torch.optim.Adam(
    [{'params': module.parameters()} for module in (encoder, decoder)],
    lr=lr,
)

In [25]:
from tqdm import tqdm
def get_dis(decoder, latent_dim, sizedata, pairs):
    """Compare ``pairs`` decoder samples drawn from the prior against the first ``pairs`` rows of ``sizedata``.

    Latents are drawn from N(0, I). Each decoded distribution has entries below 1e-3 zeroed and
    is renormalized to sum to 1. Entry ``[i, j]`` of the returned matrix is the same quantity
    ``cramer_dis`` computes between generated sample ``i`` and reference row ``j``:
    ``sum(|cumsum(gen_i) - cumsum(ref_j)|) / n_bins``.

    The decoder is put in eval mode for sampling, and its previous train/eval mode is restored
    before returning.

    Returns:
        ``(size_dis, mean_sizes)``:
        - ``size_dis`` is a ``(pairs, pairs)`` float array of distances.
        - ``mean_sizes`` is a list with the ``size_cdf``-weighted mean size of each generated sample.
    """
    was_training = decoder.training
    decoder.eval()
    try:
        with torch.no_grad():
            z = torch.randn((pairs, latent_dim)).to(device)
            generated = decoder(z).to('cpu').numpy()
    finally:
        # Restore the caller's mode so training is not silently continued in eval mode.
        decoder.train(was_training)
    generated[generated < 1e-3] = 0
    generated /= generated.sum(axis=1, keepdims=True)
    gen_mean_sizes = (generated * size_cdf).sum(axis=1).tolist()

    n_bins = generated.shape[1]
    # Reference CDFs do not depend on i, so compute them once instead of once per (i, j) pair.
    ref_cdfs = np.cumsum(np.asarray(sizedata[:pairs]), axis=1)
    size_dis = np.zeros((pairs, pairs))
    for i, gen_cdf in enumerate(np.cumsum(generated, axis=1)):
        size_dis[i] = np.abs(gen_cdf - ref_cdfs).sum(axis=1) / n_bins
    return size_dis, gen_mean_sizes

In [27]:
# Train until the epoch loss falls below `stop_loss`. Every `report_every` epochs, print
# model_test() metrics; every `coverage_every` epochs, print get_dis() prior-coverage statistics.
stop_loss = 1e-3
max_epochs = 100001      # upper bound on epochs if stop_loss is never reached
report_every = 100
coverage_every = 1000
coverage_pairs = 1000    # number of generated samples / reference rows compared by get_dis()


def _print_progress(epoch, epoch_loss, min_epoch_loss, epoch_kld, epoch_recon, epoch_size,
                    size_dis, mean_size_dis, elapsed):
    """Print one progress line from train() averages and model_test() results."""
    print("epoch=%d, loss=%.2e, min_loss=%.2e, kld=%.2f, recon=%.2e, size=%.2e, size_dis=%.3f, mean_size=%.2f, max_size=%.2f(%d), time=%.2f" % (epoch, epoch_loss, min_epoch_loss, epoch_kld, epoch_recon, epoch_size, size_dis, np.mean(mean_size_dis), np.max(mean_size_dis), np.argmax(mean_size_dis), elapsed))


def _print_coverage(pair_dis, pred_mean_sizes, ref_mean_sizes):
    """Print mean sizes and the worst nearest-neighbour distances from a get_dis() matrix.

    Row i of `pair_dis` is a generated sample and column j is a reference row. The column-wise
    minimum is each reference row's nearest generated sample; the row-wise minimum is each
    generated sample's nearest reference row. The 100th/50th/10th/1st largest of each are
    printed, so this needs at least 100 rows and columns.
    """
    real_to_gen = np.sort(np.min(pair_dis, axis=0))
    gen_to_real = np.sort(np.min(pair_dis, axis=1))
    print("%d, %d, %.2e, %.2e, %.2e, %.2e, %.2e, %.2e, %.2e, %.2e" % (np.mean(pred_mean_sizes), np.mean(ref_mean_sizes), real_to_gen[-100], real_to_gen[-50], real_to_gen[-10], real_to_gen[-1], gen_to_real[-100], gen_to_real[-50], gen_to_real[-10], gen_to_real[-1]))


encoder.train()
decoder.train()
start_time = time.time()
min_epoch_loss = 100
for epoch in range(max_epochs):
    epoch_loss, epoch_recon, epoch_kld, epoch_size = train(encoder, decoder, dataloader, optimizer)
    min_epoch_loss = min(epoch_loss, min_epoch_loss)
    if epoch and epoch % report_every == 0:
        size_dis, mean_size_dis = model_test(encoder, decoder)
        cur_time = time.time()
        _print_progress(epoch, epoch_loss, min_epoch_loss, epoch_kld, epoch_recon, epoch_size,
                        size_dis, mean_size_dis, cur_time - start_time)
        min_epoch_loss = 100

    if epoch and epoch % coverage_every == 0:
        size_dis, pred_mean_sizes = get_dis(decoder, latent_dim, sizedata[:coverage_pairs], coverage_pairs)
        # get_dis() may leave the decoder in eval mode; resume training mode explicitly.
        decoder.train()
        _print_coverage(size_dis, pred_mean_sizes, mean_sizes)
    if epoch_loss < stop_loss:
        size_dis, mean_size_dis = model_test(encoder, decoder)
        # Take a fresh timestamp: the periodic one is stale, or unset if we stop before `report_every`.
        cur_time = time.time()
        _print_progress(epoch, epoch_loss, min_epoch_loss, epoch_kld, epoch_recon, epoch_size,
                        size_dis, mean_size_dis, cur_time - start_time)
        size_dis, pred_mean_sizes = get_dis(decoder, latent_dim, sizedata[:coverage_pairs], coverage_pairs)
        _print_coverage(size_dis, pred_mean_sizes, mean_sizes)
        break

epoch=100, loss=2.82e-03, min_loss=2.81e-03, kld=18.54, recon=2.64e-03, size=0.00e+00, size_dis=0.015, mean_size=0.24, max_size=70.26(998), time=4.06
epoch=200, loss=2.86e-03, min_loss=2.76e-03, kld=18.05, recon=2.68e-03, size=0.00e+00, size_dis=0.015, mean_size=0.31, max_size=134.68(998), time=8.07
epoch=300, loss=2.84e-03, min_loss=2.78e-03, kld=18.66, recon=2.65e-03, size=0.00e+00, size_dis=0.015, mean_size=0.28, max_size=108.04(998), time=12.02
epoch=400, loss=2.97e-03, min_loss=2.76e-03, kld=18.88, recon=2.78e-03, size=0.00e+00, size_dis=0.016, mean_size=0.26, max_size=44.64(496), time=15.96
epoch=500, loss=3.13e-03, min_loss=2.83e-03, kld=18.85, recon=2.94e-03, size=0.00e+00, size_dis=0.016, mean_size=0.23, max_size=36.16(497), time=19.89
epoch=600, loss=2.95e-03, min_loss=2.85e-03, kld=18.32, recon=2.77e-03, size=0.00e+00, size_dis=0.016, mean_size=0.21, max_size=23.55(496), time=26.10
epoch=700, loss=3.00e-03, min_loss=2.88e-03, kld=19.36, recon=2.80e-03, size=0.00e+00, size_di