In [1]:
import os
current_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(current_dir,os.pardir))
os.chdir(parent_dir)

In [2]:
import numpy as np
import pandas as pd
import torch
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sns
from vae_earlystopping import EarlyStopping
from model.m2_bce import BCEcVAE
from model.m2_mse import MSEcVAE
from loss.l2_bce import l2_bce
from loss.l2_mse import l2_mse
import joblib
import torch
import numpy as np
import pandas as pd
from sklearn.metrics import r2_score, mean_absolute_error


In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
train_loader = torch.load('torch/pre_retrain_loader_min.pt',weights_only=False)
val_loader = torch.load('torch/pre_reval_loader_min.pt',weights_only=False)
test_loader = torch.load('torch/pre_retest_loader_min.pt',weights_only=False)


In [5]:
x_sample,c_sample = next(iter(train_loader))
x_dim = x_sample.shape[1]
c_dim = c_sample.shape[1]
x_dim,c_dim

(23, 15)

### BCE모델에 대해서 정함

In [6]:
model = BCEcVAE(x_dim,c_dim,z_dim=8).to(device)
early_stopping = EarlyStopping(patience=40,min_delta = 1e-9)
optimizer = optim.Adam(model.parameters(),lr = 1e-3, weight_decay=1e-5)
epochs = 600

### train_val loader에서의 학습

In [7]:
for epoch in range(1,epochs+1):
    model.train()
    t_loss= 0
    for x,c in train_loader:
        x,c = x.to(device),c.to(device)
        optimizer.zero_grad()
        bce_logit, mu, logvar = model(x,c)
        loss_dict = l2_bce(bce_logit, x,mu,logvar)
        loss_dict['loss'].backward()
        optimizer.step()
        t_loss +=loss_dict['loss'].item()
    model.eval()
    v_loss = 0
    with torch.no_grad():
        for v_x, v_c in val_loader:
            v_x,v_c = v_x.to(device),v_c.to(device)
            v_bce_logit, v_mu, v_logvar = model(v_x,v_c)
            loss_dict = l2_bce(v_bce_logit, v_x, v_mu,v_logvar)
            v_loss += loss_dict['loss'].item()
    avg_train_loss = t_loss/len(train_loader)
    avg_val_loss = v_loss/len(val_loader)

    if epoch % 20 ==0  or epoch ==2:
        print(f'Epoch [{epoch}/{epochs}]|Train:{avg_train_loss:.4f} |Val:{avg_val_loss:.4f}')
    if early_stopping(avg_val_loss,model):
        break


Epoch [2/600]|Train:3.5493 |Val:3.3952
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
Epoch [20/600]|Train:2.5585 |Val:2.6825
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 2 out of 40
Epoch [40/600]|Train:2.1536 |Val:2.3833
EarlyStopping counter: 3 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 2 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
Epoch [60/600]|Train:1.8257 |Val:2.0252
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 2 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 2 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
E

In [8]:
x_scaler = joblib.load('torch/min_x_scaler.pkl')

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [9]:
from bce_metrics.bce_solve import eval_bce_metrics

In [10]:
early_stopping.load_best_model(model)
model.eval()
all_x_hat = []
all_x_true = []
with torch.no_grad():
    for x_t, c_t in test_loader:
        x_t,c_t = x_t.to(device),c_t.to(device)
        bce_logit_t,mu_t,logvar_t = model(x_t,c_t)
        x_true = (x_t>0).float()
        all_x_hat.append(bce_logit_t.detach().cpu())
        all_x_true.append(x_true.detach().cpu())
all_x_hat_1 = torch.cat([a.flatten() for a in all_x_hat])
all_x_true = torch.cat([a.flatten() for a in all_x_true])

metrics = eval_bce_metrics(all_x_hat_1,all_x_true,threshold=0.5)
metrics

Restored best model with loss: 0.887356


{'threshold': 0.5,
 'bce': 0.030722837895154953,
 'tp': 1652,
 'fp': 106,
 'tn': 25272,
 'fn': 156,
 'precision': 0.9397042093234375,
 'recall': 0.9137168141542383,
 'f1': 0.9265283180468241,
 'accuracy': 0.9903626866766018}

### MSE 구하는 방법

In [11]:
train_loader = torch.load('torch/pre_retrain_loader_min.pt',weights_only=False)
val_loader = torch.load('torch/pre_reval_loader_min.pt',weights_only=False)
test_loader = torch.load('torch/pre_retest_loader_min.pt',weights_only=False)

In [12]:
x_sample,c_sample = next(iter(train_loader))
x_dim = x_sample.shape[1]
c_dim = c_sample.shape[1]
x_dim,c_dim

(23, 15)

In [13]:
model = MSEcVAE(x_dim,c_dim,z_dim=8).to(device)
early_stopping = EarlyStopping(patience=40,min_delta = 1e-9)
optimizer = optim.Adam(model.parameters(),lr = 1e-3, weight_decay=1e-5)
epochs = 800

In [14]:
for epoch in range(1,epochs+1):
    model.train()
    t_loss = 0
    for x,c in train_loader:
        x,c = x.to(device),c.to(device)
        x_hat,mu,logvar = model(x,c)
        loss_dict = l2_mse(x_hat,x,mu,logvar)
        optimizer.zero_grad()
        loss_dict['loss'].backward()
        optimizer.step()
        t_loss += loss_dict['loss'].item()
    v_loss = 0
    model.eval()
    with torch.no_grad():
        for v_x,v_c in val_loader:
            v_x,v_c = v_x.to(device),v_c.to(device)
            x_hat,v_mu,v_logvar = model(v_x,v_c)
            loss_dict = l2_mse(x_hat,v_x,v_mu,v_logvar)
            v_loss += loss_dict['loss'].item()
    avg_train_loss = t_loss/len(train_loader)
    avg_val_loss = v_loss/len(val_loader)

    if epoch % 20 ==0  or epoch ==2:
        print(f'Epoch [{epoch}/{epochs}]|Train:{avg_train_loss:.4f} |Val:{avg_val_loss:.4f}')
    if early_stopping(avg_val_loss,model):
        break

Epoch [2/800]|Train:0.1051 |Val:0.1286
Epoch [20/800]|Train:0.0741 |Val:0.0991
Epoch [40/800]|Train:0.0554 |Val:0.0782
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
Epoch [60/800]|Train:0.0481 |Val:0.0678
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 2 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 2 out of 40
EarlyStopping counter: 3 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
Epoch [80/800]|Train:0.0445 |Val:0.0628
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 2 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 2 out of 40
Epoch [100/800]|Train:0.0423 |Val:0.0599
EarlyStopping counter: 3 out of 40
EarlyStopping counter: 4 

In [15]:
early_stopping.load_best_model(model)
model.eval()
mse_logit_list = []
x_true_list = []
with torch.no_grad():
    for x_t, c_t in test_loader:
        x_t,c_t = x_t.to(device),c_t.to(device)
        x_hat,mu_t,logvar_t = model(x_t,c_t)
        mse_logit_list.append(x_hat.cpu().numpy())
        x_true_list.append(x_t.cpu().numpy())
mse_logits = np.vstack(mse_logit_list)
x_true = np.vstack(x_true_list)

Restored best model with loss: 0.047878


### BCE-> 0,1 표현하여 MSE에 곱하기

In [16]:
all_x_hat_tensor = torch.cat(all_x_hat, dim=0) # (Total_Samples, x_dim) 형태로 결합
bce_logits_np = all_x_hat_tensor.numpy()
def sigmoid(x):
    return 1 / (1 + np.exp(-x))
bce_prob = sigmoid(bce_logits_np)
bce_binary = (bce_prob >= 0.5).astype(np.float32)
# final_x_hat = bce_binary * mse_logits

  return 1 / (1 + np.exp(-x))


In [17]:
x_hat_fin = x_scaler.inverse_transform(mse_logits)
x_true = x_scaler.inverse_transform(x_true)
final_x_hat = x_hat_fin*bce_binary

In [19]:
final_x_sig = x_hat_fin*bce_prob

In [20]:
from sklearn.metrics import r2_score,mean_squared_error
r2_mse = r2_score(x_true.flatten(),x_hat_fin.flatten())
r2_bce_mse = r2_score(x_true.flatten(), final_x_hat.flatten())
r2_bce_mse_sig = r2_score(x_true.flatten(),final_x_sig.flatten())
print(f"r2_bce_binary_mse: {r2_bce_mse:.4f},r2_mse: {r2_mse:.4f},r2_bce_sigmoid_mse:{r2_bce_mse_sig:.4f}")

r2_bce_binary_mse: 0.8607,r2_mse: 0.8630,r2_bce_sigmoid_mse:0.8473
