In [1]:
import os

# Run from the notebook's parent directory (presumably the project root) so that
# the relative 'torch/...' paths and the local modules imported afterwards resolve.
# Guarded so that re-running this cell does not climb one more directory each time:
# parent_dir is computed only on the first execution in a fresh kernel.
if 'parent_dir' not in globals():
    current_dir = os.getcwd()
    parent_dir = os.path.abspath(os.path.join(current_dir, os.pardir))
os.chdir(parent_dir)

In [2]:
import numpy as np
import pandas as pd
import torch
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sns
from vae_earlystopping import EarlyStopping
from model.m26_bin import MultiDecoderCondVAE
from loss.l26oss_all import integrated_loss_fn


In [3]:
# Fix the RNG so weight initialisation (and any DataLoader shuffling) is reproducible
# under Restart & Run All; torch.manual_seed seeds the CPU and all CUDA devices.
SEED = 42
torch.manual_seed(SEED)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
# Load the pre-built train / validation / test DataLoaders saved earlier.
# Paths are relative to the current working directory.
# NOTE(review): weights_only=False makes torch.load unpickle arbitrary Python objects
# (here whole DataLoader objects), which can execute code on load -- only use files
# this project produced itself, never untrusted ones.
_loader_files = (
    'torch/pre_retrain_loader.pt',
    'torch/pre_reval_loader.pt',
    'torch/pre_retest_loader.pt',
)
train_loader, val_loader, test_loader = [
    torch.load(path, weights_only=False) for path in _loader_files
]


In [None]:
# Draw one training batch to discover the feature width (x) and the
# conditioning width (c) the model has to be built for (dimension 1 of each tensor).
x_sample, c_sample = next(iter(train_loader))
x_dim, c_dim = (tensor.shape[1] for tensor in (x_sample, c_sample))
x_dim, c_dim

In [None]:
# Hyper-parameters, gathered here so they are easy to find and tune.
Z_DIM = 8             # passed as z_dim; presumably the latent size of the VAE
LEARNING_RATE = 1e-3  # Adam's standard default. The previous 1e-12 bounded every
                      # Adam step to ~1e-12 per weight, i.e. the model never trained.
WEIGHT_DECAY = 1e-5
ES_PATIENCE = 40      # presumably epochs without improvement before stopping -- confirm in EarlyStopping
ES_MIN_DELTA = 1e-9   # presumably the minimum change counted as an improvement -- confirm

model = MultiDecoderCondVAE(x_dim, c_dim, z_dim=Z_DIM).to(device)
early_stopping = EarlyStopping(patience=ES_PATIENCE, min_delta=ES_MIN_DELTA)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

In [None]:
# Per-epoch metric curves for both splits; every key gets its own empty list.
history = {
    key: []
    for key in (
        'train_loss', 'train_bce', 'train_mse', 'train_kl',
        'val_loss', 'val_bce', 'val_mse', 'val_kl',
    )
}
epochs = 600  # maximum number of training epochs

In [None]:
# Main loop: each epoch runs one optimisation pass over train_loader, then a
# gradient-free pass over val_loader, summing per-batch loss components.
for epoch in range(1, epochs + 1):
    # ---- training pass ----
    model.train()
    t_loss, t_mse, t_bce, t_kl = 0, 0, 0, 0
    for x, c in train_loader:
        # Both the features and the conditioning tensor go to the device;
        # c must stay the conditioning input, not a copy of x.
        x, c = x.to(device), c.to(device)
        optimizer.zero_grad()
        bce_logit, binary_out, x_hat, z_mu, z_logvar = model(x, c)
        # Use the posterior statistics produced for THIS batch (z_mu / z_logvar).
        loss_dict = integrated_loss_fn(bce_logit, x_hat, x, z_mu, z_logvar)
        loss_dict['loss'].backward()
        optimizer.step()
        # .item() converts to a Python float so no autograd graph is kept alive.
        t_loss += loss_dict['loss'].item()
        t_mse += loss_dict['mse_loss'].item()
        t_bce += loss_dict['bce_loss'].item()
        t_kl += loss_dict['kl_loss'].item()

    # ---- validation pass ----
    model.eval()
    v_loss, v_mse, v_bce, v_kl = 0, 0, 0, 0
    x_true_all, x_pred_all, x_hat_all = [], [], []
    threshold = 0.5  # NOTE(review): presumably the cut-off for the binary head; confirm where it is used
    with torch.no_grad():
        for v_x, v_c in val_loader:
            v_x, v_c = v_x.to(device), v_c.to(device)
            v_bce_logit, v_binary_out, v_x_hat, v_mu, v_logvar = model(v_x, v_c)
            loss_dict = integrated_loss_fn(v_bce_logit, v_x_hat, v_x, v_mu, v_logvar)
            v_loss += loss_dict['loss'].item()
            v_mse += loss_dict['mse_loss'].item()
            v_bce += loss_dict['bce_loss'].item()
            # Same key and float conversion as the training pass.
            v_kl += loss_dict['kl_loss'].item()
