In [1]:
import os

# Switch the working directory to the parent of the directory the notebook
# was launched from.  Presumably this is so the relative 'torch/...' data
# paths and the project-local imports used further down resolve -- confirm.
#
# The launch directory is captured only the first time this cell runs in a
# kernel session.  Without this guard, re-running the cell would read the
# already-changed cwd and climb one more directory level on every execution.
if "current_dir" not in globals():
    current_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(current_dir, os.pardir))
os.chdir(parent_dir)

In [2]:
import numpy as np
import pandas as pd
import torch
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sns
from vae_earlystopping import EarlyStopping
from model.m2_bce import BCEcVAE
from model.m2_mse import MSEcVAE
from loss.l2_bce import l2_bce
from loss.l2_mse import l2_mse
import joblib
import torch
import numpy as np
import pandas as pd
from sklearn.metrics import r2_score, mean_absolute_error


In [3]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)

In [4]:
def _load_pickled_loader(path):
    """Load one serialized loader object from ``path`` with ``torch.load``.

    ``weights_only=False`` makes ``torch.load`` unpickle arbitrary Python
    objects, so only files from a trusted source should be loaded this way.
    """
    return torch.load(path, weights_only=False)


# Train / validation / test loaders saved by an earlier step (not shown here).
train_loader = _load_pickled_loader('torch/pre_retrain_loader_min.pt')
val_loader = _load_pickled_loader('torch/pre_reval_loader_min.pt')
test_loader = _load_pickled_loader('torch/pre_retest_loader_min.pt')


In [5]:
# Pull a single (x, c) batch from the training loader and read the size of
# axis 1 of each tensor; these sizes are passed to the model constructor.
x_sample, c_sample = next(iter(train_loader))
x_dim, c_dim = (batch.shape[1] for batch in (x_sample, c_sample))
x_dim, c_dim

(23, 15)

### BCE 모델 설정 (모델 · 조기 종료 · 옵티마이저 · 에폭 수)

In [6]:
# --- Reproducibility ---------------------------------------------------
# Seed PyTorch before the model is built so that weight initialisation and
# any later torch-driven randomness start from the same state every time
# this cell is re-run.  Results on CUDA may still differ slightly between
# runs; seeding alone does not force deterministic kernels.
SEED = 42
torch.manual_seed(SEED)

# --- Hyperparameters ---------------------------------------------------
Z_DIM = 8             # passed to BCEcVAE as z_dim
LEARNING_RATE = 1e-3  # Adam learning rate
WEIGHT_DECAY = 1e-5   # Adam weight decay
PATIENCE = 40         # EarlyStopping patience (the training log counts "out of 40")
MIN_DELTA = 1e-9      # EarlyStopping min_delta -- presumably the minimum
                      # improvement that resets the counter; confirm in vae_earlystopping
epochs = 600          # upper bound on training epochs

# --- Model, early stopping and optimizer ---------------------------------
model = BCEcVAE(x_dim, c_dim, z_dim=Z_DIM).to(device)
early_stopping = EarlyStopping(patience=PATIENCE, min_delta=MIN_DELTA)
optimizer = optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

### train/val loader를 이용한 학습 및 검증

In [7]:
for epoch in range(1, epochs + 1):
    # ---- Training pass: one optimizer step per mini-batch ----------------
    model.train()
    t_loss = 0
    for x, c in train_loader:
        x = x.to(device)
        c = c.to(device)
        optimizer.zero_grad()
        logits, mean, log_var = model(x, c)
        # l2_bce returns a dict; its 'loss' entry is the tensor to optimize.
        batch_loss = l2_bce(logits, x, mean, log_var)['loss']
        batch_loss.backward()
        optimizer.step()
        t_loss += batch_loss.item()
    # Mean of the per-batch losses (not weighted by batch size).
    avg_train_loss = t_loss / len(train_loader)

    # ---- Validation pass: no gradient tracking ---------------------------
    model.eval()
    v_loss = 0
    with torch.no_grad():
        for v_x, v_c in val_loader:
            v_x = v_x.to(device)
            v_c = v_c.to(device)
            v_logits, v_mean, v_log_var = model(v_x, v_c)
            v_loss += l2_bce(v_logits, v_x, v_mean, v_log_var)['loss'].item()
    avg_val_loss = v_loss / len(val_loader)

    # Report on epoch 2 and on every 20th epoch.
    is_report_epoch = epoch == 2 or epoch % 20 == 0
    if is_report_epoch:
        print(f'Epoch [{epoch}/{epochs}]|Train:{avg_train_loss:.4f} |Val:{avg_val_loss:.4f}')

    # early_stopping is called with the validation loss and the model; a
    # truthy return value ends training.
    stop_training = early_stopping(avg_val_loss, model)
    if stop_training:
        break


Epoch [2/600]|Train:3.5288 |Val:3.3865
EarlyStopping counter: 1 out of 40
Epoch [20/600]|Train:2.5546 |Val:2.6761
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 2 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 2 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 2 out of 40
EarlyStopping counter: 3 out of 40
Epoch [40/600]|Train:2.1915 |Val:2.3479
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 2 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
Epoch [60/600]|Train:1.9229 |Val:2.1250
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 2 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 2 out of 40
E

In [8]:
# Load the persisted scaler for x.  joblib files are pickles, so only load
# files from a trusted source.
# NOTE(review): the recorded output links to scikit-learn's model-persistence
# limitations page, which suggests a version-mismatch warning -- confirm the
# scaler was saved with the same scikit-learn version used here.
_X_SCALER_PATH = 'torch/min_x_scaler.pkl'
x_scaler = joblib.load(_X_SCALER_PATH)

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [38]:
from bce_metrics.bce_solve import eval_bce_metrics
import numpy as np
import torch
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

In [44]:
# Restore the best model tracked by early stopping before scoring.
early_stopping.load_best_model(model)
model.eval()

# Per-batch flattened model outputs and binary targets, kept on the CPU.
logit_chunks = []
target_chunks = []
with torch.no_grad():
    for x_t, c_t in test_loader:
        x_t = x_t.to(device)
        c_t = c_t.to(device)
        bce_logit_t, mu_t, logvar_t = model(x_t, c_t)
        # Binary target: 1.0 where the input entry is strictly positive.
        x_true = x_t.gt(0).float()
        logit_chunks.append(bce_logit_t.detach().cpu().flatten())
        target_chunks.append(x_true.detach().cpu().flatten())

# One flat vector each for outputs and targets across the whole test set.
all_x_hat = torch.cat(logit_chunks)
all_x_true = torch.cat(target_chunks)

# NOTE(review): all_x_hat holds the model's first output, which this notebook
# names a logit.  Confirm that eval_bce_metrics expects logits (and applies
# the sigmoid itself) rather than probabilities when using threshold=0.5.
metrics = eval_bce_metrics(all_x_hat, all_x_true, threshold=0.5)
metrics

Restored best model with loss: 0.798294


{'threshold': 0.5,
 'bce': 0.028044994920492172,
 'tp': 1650,
 'fp': 92,
 'tn': 25286,
 'fn': 158,
 'precision': 0.9471871412115547,
 'recall': 0.9126106194639789,
 'f1': 0.9295774597852237,
 'accuracy': 0.990804090340252}