In [1]:
import os
current_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(current_dir,os.pardir))
os.chdir(parent_dir)

In [2]:
import numpy as np
import pandas as pd
import torch
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sns
from vae_earlystopping import EarlyStopping
from model.m2_bce import BCEcVAE
from model.m2_mse import MSEcVAE
from loss.l2_bce import l2_bce
from loss.l2_mse import l2_mse
import joblib
import torch
import numpy as np
import pandas as pd
from sklearn.metrics import r2_score, mean_absolute_error


In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### MSE 구하는 방법

In [4]:
train_loader = torch.load('torch/pre_retrain_loader_min.pt',weights_only=False)
val_loader = torch.load('torch/pre_reval_loader_min.pt',weights_only=False)
test_loader = torch.load('torch/pre_retest_loader_min.pt',weights_only=False)

In [5]:
x_sample,c_sample = next(iter(train_loader))
x_dim = x_sample.shape[1]
c_dim = c_sample.shape[1]
x_dim,c_dim

(23, 15)

In [6]:
model = MSEcVAE(x_dim,c_dim,z_dim=8).to(device)
early_stopping = EarlyStopping(patience=40,min_delta = 1e-9)
optimizer = optim.Adam(model.parameters(),lr = 1e-3, weight_decay=1e-5)
epochs = 800

In [7]:
for epoch in range(1,epochs+1):
    model.train()
    t_loss = 0
    for x,c in train_loader:
        x,c = x.to(device),c.to(device)
        x_hat,mu,logvar = model(x,c)
        loss_dict = l2_mse(x_hat,x,mu,logvar)
        optimizer.zero_grad()
        loss_dict['loss'].backward()
        optimizer.step()
        t_loss += loss_dict['loss'].item()
    model.eval()
    v_loss = 0
    with torch.no_grad():
        for v_x,v_c in val_loader:
            v_x,v_c = v_x.to(device),v_c.to(device)
            x_hat,v_mu,v_logvar = model(v_x,v_c)
            loss_dict = l2_mse(x_hat,v_x,v_mu,v_logvar)
            v_loss += loss_dict['loss'].item()
        avg_train_loss = t_loss/len(train_loader)
        avg_val_loss = v_loss/len(val_loader)

    if epoch % 20 ==0  or epoch ==2:
        print(f'Epoch [{epoch}/{epochs}]|Train:{avg_train_loss:.4f} |Val:{avg_val_loss:.4f}')
    if early_stopping(avg_val_loss,model):
        break

Epoch [2/800]|Train:0.1051 |Val:0.1278
EarlyStopping counter: 1 out of 40
Epoch [20/800]|Train:0.0807 |Val:0.1068
Epoch [40/800]|Train:0.0610 |Val:0.0851
EarlyStopping counter: 1 out of 40
Epoch [60/800]|Train:0.0520 |Val:0.0730
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
Epoch [80/800]|Train:0.0473 |Val:0.0659
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 2 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
Epoch [100/800]|Train:0.0446 |Val:0.0616
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 2 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 2 

In [8]:
early_stopping.load_best_model(model)
model.eval()
mse_logit_list = []
x_true_list = []
with torch.no_grad():
    for x_t, c_t in test_loader:
        x_t,c_t = x_t.to(device),c_t.to(device)
        x_hat,mu_t,logvar_t = model(x_t,c_t)
        mse_logit_list.append(x_hat.cpu().numpy())
        x_true_list.append(x_t.cpu().numpy())
mse_logits = np.vstack(mse_logit_list)
x_true = np.vstack(x_true_list)

Restored best model with loss: 0.049206


In [9]:
x_scaler = joblib.load('torch/min_x_scaler.pkl')

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [10]:
x_hat_fin = x_scaler.inverse_transform(mse_logits)
x_true = x_scaler.inverse_transform(x_true)

In [11]:
from sklearn.metrics import r2_score,mean_squared_error
r2_mse = r2_score(x_true.flatten(),x_hat_fin.flatten())
print(f"{r2_mse:.4f}")

0.8469
