In [1]:
import torch; torch.manual_seed(0)
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.distributions
import torchvision

from sklearn.neighbors import NearestNeighbors

import numpy as np
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200

from path import Path
import os
import json


In [5]:
# Run on the GPU whenever PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

In [19]:
class VAE(nn.Module):
    """Fully connected variational autoencoder over flattened inputs.

    The encoder maps a flattened sample to the mean and log-variance of a
    diagonal Gaussian over the latent space. The decoder maps a latent sample
    back to a vector with one sigmoid-activated value per input element.

    Args:
        latent_dims: size of the latent vector.
        input_dims: number of elements in one flattened sample. Defaults to
            800, the value that used to be hard-coded.
        hidden_dims: width of the single hidden layer used by both encoder
            and decoder. Defaults to 200, the value that used to be hard-coded.
    """

    def __init__(self, latent_dims, input_dims=800, hidden_dims=200):
        super().__init__()
        self.latent_dims = latent_dims
        self.input_dims = input_dims
        self.hidden_dims = hidden_dims

        # Layers are created in the same order as before so that a seeded
        # run initialises every weight exactly as it did previously.
        # Encoder: input -> hidden -> (mean, log-variance).
        self.fc1 = nn.Linear(input_dims, hidden_dims)
        self.fc21 = nn.Linear(hidden_dims, latent_dims)
        self.fc22 = nn.Linear(hidden_dims, latent_dims)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

        # Decoder: latent -> hidden -> input.
        self.fc3 = nn.Linear(latent_dims, hidden_dims)
        self.fc4 = nn.Linear(hidden_dims, input_dims)

    def encode(self, x):
        """Return ``(mu, sigma)`` for a batch; ``sigma`` is the log-variance.

        Everything after the batch dimension is flattened, so the input only
        needs ``input_dims`` elements per sample.
        """
        x = torch.flatten(x, start_dim=1)
        hidden = self.relu(self.fc1(x))
        return self.fc21(hidden), self.fc22(hidden)

    def reparameterize(self, mu, sigma):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        ``sigma`` holds the log-variance (not the standard deviation), so the
        standard deviation is recovered as ``exp(0.5 * sigma)``. Sampling this
        way keeps the draw differentiable with respect to ``mu`` and ``sigma``.
        """
        std = torch.exp(sigma * 0.5)
        eps = torch.randn_like(std)
        return eps * std + mu

    def decode(self, z):
        """Map latent vectors to reconstructions with values in the unit interval."""
        hidden = self.relu(self.fc3(z))
        return self.sigmoid(self.fc4(hidden))

    def forward(self, x):
        """Encode, sample and decode; return ``(reconstruction, mu, sigma)``."""
        mu, sigma = self.encode(x)
        z = self.reparameterize(mu, sigma)
        return self.decode(z), mu, sigma
        

In [30]:
def loss_function(z, x, mu, sigma):
    """Summed VAE loss: reconstruction error plus KL divergence.

    Args:
        z: the model's reconstruction (despite the name, not the latent
            code), with values in [0, 1] and one row per sample.
        x: the target batch; everything after the batch dimension is
            flattened before comparison with ``z``.
        mu: latent means produced by the encoder.
        sigma: latent log-variances produced by the encoder.

    Returns:
        A scalar tensor: binary cross-entropy summed over all elements, plus
        ``-0.5 * sum(1 + sigma - mu^2 - exp(sigma))``.
    """
    target = torch.flatten(x, start_dim=1)
    reconstruction_term = F.binary_cross_entropy(z, target, reduction='sum')
    kl_term = -0.5 * torch.sum(1 + sigma - mu.pow(2) - sigma.exp())
    return reconstruction_term + kl_term

In [23]:
def train(model, dataloader, epochs=30, lr=0.003):
    """Fit ``model`` on ``dataloader`` with Adam and return the trained model.

    Args:
        model: module whose forward pass returns ``(reconstruction, mu, sigma)``.
        dataloader: iterable of batches; each batch is a mapping with a
            ``'points'`` tensor.
        epochs: number of full passes over ``dataloader``.
        lr: Adam learning rate. Defaults to 0.003, the value that used to be
            hard-coded.

    Returns:
        The same ``model`` object, updated in place. Returning it keeps
        ``model = train(model, ...)`` from rebinding the name to ``None``.
    """
    optim = torch.optim.Adam(model.parameters(), lr=lr)
    # Send batches to wherever the model's weights live instead of relying on
    # a notebook-level global being defined and in sync with the model.
    model_device = next(model.parameters()).device
    for epoch in range(epochs):
        print(epoch)
        model.train()
        for data in dataloader:
            inputs = data['points'].to(model_device).float()
            optim.zero_grad()
            # The network is fed the batch with its last two axes swapped,
            # while the loss flattens the un-swapped batch, so the input and
            # the reconstruction target use different element orders.
            # NOTE(review): confirm this asymmetry is intended.
            x_hat, mu, sigma = model(inputs.transpose(1, 2))
            loss = loss_function(x_hat, inputs, mu, sigma)
            loss.backward()
            optim.step()
    return model


In [20]:
latent_dims = 5
# Build the model from the named setting rather than a repeated literal, so
# changing latent_dims above actually changes the network.
vae = VAE(latent_dims).to(device)

# Dataset

In [10]:
class DRPointData(Dataset):
    """Dataset that reads one point set per ``.json`` file in a directory.

    Each item is a dict with the min-max normalised points under ``'points'``
    and a dense k-nearest-neighbour graph of the raw points under ``'KNN'``.
    The original author's note gives the shape of ``'points'`` as [400, 2].

    Args:
        root_dir: directory that is scanned (non-recursively) for ``.json`` files.
        valid: stored on the instance; not otherwise used by this class.
        max_files: keep at most this many files. Defaults to 4000, the cap
            that used to be hard-coded; ``None`` keeps every file.
        n_neighbors: neighbour count for the KNN graph. Defaults to 11, the
            value that used to be hard-coded.
    """

    def __init__(self, root_dir, valid=False, max_files=4000, n_neighbors=11):
        self.root_dir = root_dir
        self.valid = valid
        self.n_neighbors = n_neighbors
        # os.listdir returns names in arbitrary order. Sorting makes both the
        # capped subset and the index -> file mapping reproducible.
        json_files = sorted(
            name for name in os.listdir(root_dir) if name.endswith('.json')
        )
        self.files = json_files[:max_files]

    def __len__(self):
        """Number of files kept by the constructor."""
        return len(self.files)

    def __preproc__(self, file):
        """Load one JSON file and return its normalised points and KNN graph."""
        with open(file, encoding="UTF-8") as f:
            np_points = np.array(json.load(f, strict=False))

        # The graph is built on the raw (un-normalised) coordinates.
        # NOTE(review): the fitted points are also the query points, so each
        # point is presumably counted as its own neighbour; confirm that
        # n_neighbors=11 is meant as "self + 10".
        nbrs = NearestNeighbors(n_neighbors=self.n_neighbors, algorithm='auto').fit(np_points)
        matrix = torch.from_numpy(nbrs.kneighbors_graph(np_points).toarray())

        points = torch.from_numpy(np_points)
        # Per-coordinate min-max scaling; the reduction runs over dim -2.
        max_val = torch.max(points, -2).values.view(1, -1)
        min_val = torch.min(points, -2).values.view(1, -1)
        diff = max_val - min_val
        # A constant coordinate has zero range. Dividing by 1 there maps it to
        # 0 instead of producing NaN, which would poison the loss downstream.
        diff = torch.where(diff == 0, torch.ones_like(diff), diff)
        points = (points - min_val) / diff
        return {'points': points,
                'KNN': matrix}

    def __getitem__(self, idx):
        """Return the preprocessed item for the ``idx``-th file."""
        json_file = os.path.join(self.root_dir, self.files[idx])
        return self.__preproc__(json_file)

    def __filename__(self, idx):
        """Return the file name (without directory) behind item ``idx``."""
        return self.files[idx]
        

In [11]:
# Point the dataset at the folder of JSON files, then show how many it picked up.
path = Path("data_0610")
train_dr = DRPointData(root_dir=path)
len(train_dr)

4000

In [12]:
# Batches of 32 samples; drop_last discards a trailing incomplete batch.
# NOTE(review): shuffle is left at its default, so every epoch visits the
# files in the same order. Consider shuffle=True if nothing downstream
# depends on the loader's ordering.
dataloader = DataLoader(
    dataset=train_dr,
    batch_size=32,
    drop_last=True,
)

# Train

In [None]:
# train() optimises the model's parameters in place, so `vae` already refers
# to the trained network once the call finishes. The result is deliberately
# not bound back to `vae`: doing so would replace the model with whatever
# train() happens to return.
train(vae, dataloader, epochs=10);