In [None]:
import numpy as np 
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, TensorDataset, random_split
import torch.optim as optimizer
from collections import OrderedDict
from torch.optim.lr_scheduler import ReduceLROnPlateau

import random
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import time
from scipy.optimize import root
import math
import os
import copy

In [None]:
# Run two other notebooks with %run before anything else in this notebook.
# NOTE(review): `VAE` is used further down but is not defined in this notebook, so it
# presumably comes from "VAE Model.ipynb" -- confirm.
# NOTE(review): both paths are absolute and machine-specific; the notebook will not
# run on another machine until they are adjusted.
%run "C:\SBU-3\Jupyter-Research\RSCR\\2 Machine Learning\\2 VAE Anar\\VAE Model.ipynb"
%run "C:\SBU-3\Jupyter-Research\RSCR\Functions.ipynb"

In [None]:
class VAEDataset(Dataset):
    """Dataset of (path, mechanism) pairs loaded from two ``.npy`` files.

    Every sample of both arrays is flattened to one row, the two arrays are
    joined column-wise (path columns first, mechanism columns after) and the
    rows are shuffled once at construction time using Python's ``random``
    module, so ``random.seed`` controls the order.

    Parameters
    ----------
    path_file : str, optional
        ``.npy`` file holding the path array; the first axis indexes samples.
    mech_file : str, optional
        ``.npy`` file holding the mechanism array; the first axis indexes
        samples and must have the same length as in ``path_file``.

    Raises
    ------
    ValueError
        If the two files do not contain the same number of samples.
    """

    def __init__(self,
                 path_file=r'C:\SBU-3\Jupyter-Research\RSCR\Saved Data\555\5_Path_555.npy',
                 mech_file=r'C:\SBU-3\Jupyter-Research\RSCR\Saved Data\555\5_Mec_555.npy'):
        paths = np.load(path_file)
        mechs = np.load(mech_file)

        # One row per sample, whatever the trailing dimensions are (2-D input
        # is kept as is, 3-D input is flattened).
        paths = paths.reshape(paths.shape[0], -1)
        mechs = mechs.reshape(mechs.shape[0], -1)

        if paths.shape[0] != mechs.shape[0]:
            raise ValueError(
                "path and mechanism files hold different numbers of samples: "
                f"{paths.shape[0]} vs {mechs.shape[0]}"
            )

        # Column index at which a joined row switches from path to mechanism.
        self.path_dim = paths.shape[1]
        joined = np.hstack((paths, mechs))

        # Shuffle an index list, not the array: random.shuffle swaps elements
        # with tuple assignment, and rows of a 2-D ndarray are views, so
        # shuffling the array in place overwrites rows (duplicating some
        # samples and losing others).
        order = list(range(len(joined)))
        random.shuffle(order)
        self.data = joined[order]

    def __len__(self):
        """Number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(path, mechanism)`` for sample ``idx`` as float32 arrays."""
        path_mech = self.data[idx]
        path = path_mech[:self.path_dim].astype(np.float32)
        mech = path_mech[self.path_dim:].astype(np.float32)
        return path, mech

class CustomDataset(Dataset):
    """Samples gathered from every ``.npy`` file in the ``latents`` directory.

    Each loaded array contributes its first-axis entries as individual
    samples. The combined list is shuffled once with ``random.shuffle``.
    ``latent_dim`` is the index at which a sample is cut into the two parts
    returned by ``__getitem__``.
    """

    def __init__(self, latent_dim):
        self.latent_dim = latent_dim

        npy_names = [name for name in os.listdir('latents') if name.endswith('.npy')]
        samples = []
        for name in npy_names:
            samples.extend(np.load(os.path.join('latents', name)))

        random.shuffle(samples)
        self.data = samples

    def __len__(self):
        """Number of collected samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return sample ``idx`` split at ``latent_dim`` into two float32 arrays."""
        sample = self.data[idx]
        cut = self.latent_dim
        head, tail = sample[:cut], sample[cut:]
        return head.astype(np.float32), tail.astype(np.float32)


In [None]:
def loss_eval_vae(data_loader, model, loss_fn):
    """Evaluate a VAE on every batch of ``data_loader`` without computing gradients.

    Parameters
    ----------
    data_loader : iterable
        Yields batches; element ``[0]`` of each batch is moved to ``DEVICE``
        and used as the model input.
    model : torch.nn.Module
        Called as ``model(x)``; must return three values that are forwarded
        to ``loss_fn`` after the input.
    loss_fn : callable
        Called as ``loss_fn(xhat, x, mu, sigma)``; must return three scalar
        tensors ``(total, kld, mse)``.

    Returns
    -------
    tuple of float
        ``(total, kld, mse)``, each averaged over the number of batches (not
        the number of samples). All three are NaN when the loader yields no
        batches.

    Notes
    -----
    The model is switched to eval mode for the duration of the call and then
    put back into whatever mode it was in on entry, so calling this from a
    training loop does not silently leave the model in eval mode.
    """
    was_training = model.training
    model.eval()

    overall_loss = 0.0
    kld_loss = 0.0
    mse_loss = 0.0
    n_batches = 0

    try:
        with torch.no_grad():
            for data in data_loader:
                x = data[0].to(DEVICE)

                xhat, mu, sigma = model(x)
                loss, kld, mse = loss_fn(xhat, x, mu, sigma)
                overall_loss += loss.item()
                kld_loss += kld.item()
                mse_loss += mse.item()
                n_batches += 1
    finally:
        # Restore the caller's train/eval mode even if a batch raised.
        model.train(was_training)

    if n_batches == 0:
        # A mean over zero batches is undefined; report NaN instead of
        # failing with ZeroDivisionError.
        nan = float('nan')
        return nan, nan, nan

    return overall_loss/n_batches, kld_loss/n_batches, mse_loss/n_batches

In [None]:
def VAE_Train(train_loader, val_loader, BATCH_SIZE, LR, EPOCH, model, test_loader=None):
    """Train a VAE with Adam and report per-epoch training/validation losses.

    Parameters
    ----------
    train_loader : iterable
        Yields training batches; element ``[0]`` of each batch is the input.
    val_loader : iterable
        Batches evaluated with ``loss_eval_vae`` after every epoch.
    BATCH_SIZE : int
        Unused; kept so existing call sites keep working.
    LR : float
        Learning rate handed to Adam.
    EPOCH : int
        Number of passes over ``train_loader``.
    model : torch.nn.Module
        Called as ``model(x)``; must return ``(reconstruction, mean, third)``
        where the loss treats ``third`` as a log-variance.
        NOTE(review): the call site names this value ``sigma`` -- confirm the
        model really returns log-variance.
    test_loader : iterable, optional
        When given, the trained model is evaluated on it and those metrics
        are returned.

    Returns
    -------
    tuple of float
        ``(loss, kld, mse)`` averaged per batch: the test metrics when
        ``test_loader`` is given, otherwise the validation metrics of the
        last epoch (or of the untouched model when ``EPOCH`` is 0).

    Raises
    ------
    ValueError
        If ``train_loader`` yields no batches.
    """
    start_time = time.time()

    # Loss function
    def loss_fn(x_recon, x, mean, logvar):
        # KL divergence to N(0, I): -0.5 * (log(sigma^2) - sigma^2 - mu^2 + 1)
        kld_loss = -0.5*torch.mean(logvar - logvar.exp() - mean.pow(2) + 1)
        # Alternative reconstruction term kept from the original for reference:
        # nn.functional.binary_cross_entropy(x_recon, x, reduction='sum')
        mse_loss = nn.MSELoss(reduction='sum')(x_recon, x)
        # Open question from the original author: when should MSE be used and
        # when BCE?
        # NOTE(review): the KLD term is a mean while the MSE term is a sum, so
        # the two are on different scales -- confirm this weighting is intended.
        return kld_loss+mse_loss, kld_loss, mse_loss

    # Optimizer. Learning-rate scheduling (ReduceLROnPlateau on the validation
    # loss) and early stopping were disabled in the original and stay disabled.
    optim = optimizer.Adam(model.parameters(), lr=LR)
    epochs = EPOCH
    val_metrics = None

    # Train
    for epoch in range(epochs):
        # Re-enter training mode every epoch: the validation pass below may
        # leave the model in eval mode.
        model.train()

        train_loss = 0.0
        train_kld = 0.0
        train_mse = 0.0
        train_inx = 0
        for data in train_loader:
            train_inx += 1
            x = data[0].to(DEVICE)

            optim.zero_grad()
            xhat, mu, sigma = model(x)
            loss, kld, mse = loss_fn(xhat, x, mu, sigma)
            train_loss += loss.item()
            train_kld += kld.item()
            train_mse += mse.item()

            loss.backward()
            optim.step()

        if train_inx == 0:
            raise ValueError("train_loader yielded no batches; nothing to train on")

        train_loss = train_loss/train_inx
        train_kld = train_kld/train_inx
        train_mse = train_mse/train_inx

        # Validation after each epoch
        val_loss, val_kld, val_mse = loss_eval_vae(val_loader, model, loss_fn)
        val_metrics = (val_loss, val_kld, val_mse)

        if epoch%50 == 0:
            print('-----------------------------------------------------------------------------------------------------')
            print("\tEpoch:", epoch, "complete!", "\tTime Taken:", round((time.time() - start_time)/60,2),"mins")
            print("\tTrain Loss:", train_loss, "\tTrain KLD:", train_kld,"\tTrain MSE:", train_mse)
            print("\tVal Loss:", val_loss, "\tVal KLD:", val_kld,"\tVal MSE:", val_mse)

    print("Training Complete!")

    # Optional test pass
    if test_loader is not None:
        print("Running Test...")
        test_loss, test_kld, test_mse = loss_eval_vae(test_loader, model, loss_fn)
        print("\tTest Loss:", test_loss, "\tTest KLD:",test_kld,"\tTest MSE:",test_mse)
        return test_loss, test_kld, test_mse

    if val_metrics is None:
        # No epoch ran (EPOCH == 0): still return metrics for the given model.
        val_metrics = tuple(loss_eval_vae(val_loader, model, loss_fn))
    return val_metrics

In [None]:
# Device used for all tensors and models below: the GPU when CUDA is available, else the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Running log of training results; the training cell appends a label and the returned metrics.
LOSS = list()

In [None]:
# --- Configuration ---------------------------------------------------------
SEED = 42          # one seed for the dataset shuffle, the train/val split and batch order
input_size = 300   # first argument of VAE(...); expected to equal the path width produced by VAEDataset
joint_size = 21    # not used in the cells visible here
batch_size = 512
latent_dim = 64

# Seed every RNG in play so that Restart & Run All reproduces the same split
# and the same training run.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- Data ------------------------------------------------------------------
dataset = VAEDataset()
dataset_size = len(dataset)
train_size = int(0.8 * dataset_size)   # 80 % train, remainder validation
val_size = dataset_size - train_size

train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SEED),
)

train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
val_loader = DataLoader(val_dataset, batch_size=batch_size)


In [None]:
# Training run 1: build a fresh VAE on DEVICE, train it, and record the returned
# metrics in LOSS under the label "Train 1".
LR = 1e-3
EPOCH = 10
model_1 = VAE(input_size, latent_dim).to(DEVICE)

loss_temp = VAE_Train(train_loader, val_loader, batch_size, LR, EPOCH, model_1)
LOSS.extend(["Train 1", loss_temp])