In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.ensemble import RandomForestRegressor
from sklearn.multioutput import MultiOutputRegressor
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

In [2]:
import os
par_dir = os.path.abspath(os.path.join(os.getcwd(),os.pardir))
os.chdir(par_dir)
from multilayer_model.m2_bce import BCEcVAE
from multilayer_model.m2_mse import MSEcVAE
from multilayer_loss.l2_bce import l2_bce
from multilayer_loss.l2_mse import l2_mse
from vae_earlystopping import EarlyStopping

In [3]:
x_data = np.load('./data/metal.npy')      # Target (X)
x2_data = np.load('./data/support_norm.npy')
c_data = np.load('./data/pre_re_change_temp_logconst.npy') # Condition (C)

# 데이터 분할 (Train/Val/Test)
c_train, c_temp, x_train, x_temp, x2_train, x2_temp = train_test_split(c_data, x_data,x2_data, test_size=0.4, random_state=42)
c_val, c_test, x_val, x_test, x2_val, x2_test = train_test_split(c_temp, x_temp,x2_temp, test_size=0.5, random_state=42)

# 스케일링
c_scaler, x_scaler,x2_scaler = MinMaxScaler(), MinMaxScaler(), MinMaxScaler()
c_train = c_scaler.fit_transform(c_train)
x_train = x_scaler.fit_transform(x_train)
x2_train = x2_scaler.fit_transform(x2_train)
c_val, x_val, x2_val = c_scaler.transform(c_val), x_scaler.transform(x_val), x2_scaler.transform(x2_val)
c_test, x_test, x2_test = c_scaler.transform(c_test), x_scaler.transform(x_test), x2_scaler.transform(x2_test)

# 텐서 변환
def to_tensor(arr): return torch.tensor(arr, dtype=torch.float32)
c_train, x_train, x2_train = to_tensor(c_train), to_tensor(x_train), to_tensor(x2_train)
c_val, x_val, x2_val = to_tensor(c_val), to_tensor(x_val), to_tensor(x2_val)
c_test, x_test, x2_test = to_tensor(c_test), to_tensor(x_test), to_tensor(x2_test)

train_loader = DataLoader(TensorDataset(x_train, x2_train, c_train), batch_size=64, shuffle=True)
val_loader = DataLoader(TensorDataset(x_val, x2_val, c_val), batch_size=64, shuffle=False)
test_loader = DataLoader(TensorDataset(x_test, x2_test, c_test), batch_size=64, shuffle=False)

x,x2, c = next(iter(train_loader))
x_dim = x.shape[1]
x2_dim = x2.shape[1]
c_dim = c.shape[1]
z_dim = 8
z2_dim = 8
device= torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 3. BCE 및 MSE 모델 정의
model_bce = BCEcVAE(x_dim, x2_dim, c_dim, z_dim, z2_dim).to(device)
model_mse = MSEcVAE(x_dim, x2_dim, c_dim, z_dim, z2_dim).to(device)

In [4]:
optimizer_bce = optim.Adam(model_bce.parameters(), lr=1e-3, weight_decay=1e-5)
optimizer_mse = optim.Adam(model_mse.parameters(),lr=1e-3, weight_decay = 1e-5)
es_bce= EarlyStopping(patience=40, min_delta=1e-9)
es_mse = EarlyStopping(patience=40,min_delta=1e-9)

In [None]:
for epoch in range(1, 801):
    model_bce.train()
    t_loss = 0
    for x, x2, c in train_loader:
        x, x2, c = x.to(device), x2.to(device), c.to(device)
        optimizer_bce.zero_grad()
        logit2, mu, logvar = model_bce(x, x2, c)
        loss = l2_bce(logit2, x2, mu, logvar)['loss']
        loss.backward()
        optimizer_bce.step()
        t_loss += loss.item()
    
    model_bce.eval()
    v_loss = 0
    with torch.no_grad():
        for vx, vx2, vc in val_loader:
            vx, vx2, vc = vx.to(device), vx2.to(device), vc.to(device)
            v2_logit, v_mu, v_logvar = model_bce(vx, vx2, vc)
            v_loss += l2_bce(v2_logit, vx2, v_mu, v_logvar)['loss'].item()
    
    avg_v_loss = v_loss / len(val_loader)
    if epoch % 50 == 0: print(f"Epoch {epoch} | Val Loss: {avg_v_loss:.6f}")
    if es_bce(avg_v_loss, model_bce): break

es_bce.load_best_model(model_bce)

# --- 4. MSE 모델 학습 (수치 예측) ---
print("\n[Phase 2] Training Weighted MSEcVAE...")
model_mse = MSEcVAE(x_dim, x2_dim, c_dim, z_dim,z2_dim).to(device)
optimizer_mse = optim.Adam(model_mse.parameters(), lr=1e-3, weight_decay=1e-5)
es_mse = EarlyStopping(patience=40, min_delta=1e-9)

for epoch in range(1, 801):
    model_mse.train()
    t_loss = 0
    for x, x2, c in train_loader:
        x, x2, c = x.to(device), x2.to(device), c.to(device)
        optimizer_mse.zero_grad()
        x2_hat, mu, logvar = model_mse(x, x2, c)
        loss = l2_mse(x2_hat, x2, mu, logvar)['loss']
        loss.backward()
        optimizer_mse.step()
        t_loss += loss.item()

    model_mse.eval()
    v_loss = 0
    with torch.no_grad():
        for vx, vx2, vc in val_loader:
            vx, vx2, vc = vx.to(device), vx2.to(device), vc.to(device)
            vh2, vm, vl = model_mse(vx, vx2, vc)
            v_loss += l2_mse(vh2, vx2, vm, vl)['loss'].item()
    
    avg_v_loss = v_loss / len(val_loader)
    if epoch % 50 == 0: print(f"Epoch {epoch} | Val Loss: {avg_v_loss:.6f}")
    if es_mse(avg_v_loss, model_mse): break

es_mse.load_best_model(model_mse)



EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 2 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 2 out of 40
EarlyStopping counter: 1 out of 40
Epoch 50 | Val Loss: 0.286608
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 2 out of 40
EarlyStopping counter: 3 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 2 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 2 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 2 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 2 out of 40
EarlyStopping counter: 3 out of 40
EarlyStopping counter: 4 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 2 out of 40
EarlyStopping counter: 3 out of 40
EarlyStopping counter: 4 

In [16]:
# --- 5. 최종 평가 (Hurdle: BCE Prob * MSE Value) ---
model_bce.eval()
model_mse.eval()

all_pred = []
all_true = []
all_bce =[]
all_prob=[]
all_binary=[]
all_ct=[]
all_xt=[]
with torch.no_grad():
    for xt, xt2, ct in test_loader:
        xt, xt2, ct = xt.to(device), xt2.to(device), ct.to(device)
        
        # BCE 확률 (P)
        b_logit, _, _ = model_bce(xt, xt2, ct)
        b_prob = torch.sigmoid(b_logit)
        b_binary = (b_prob > 0.5).float()
        
        # MSE 수치 (V)
        m_val, _, _ = model_mse(xt, xt2, ct)
        
        # 결합 (P * V)
        combined_s = b_prob * m_val
        all_bce.append(b_logit.cpu().numpy())
        all_prob.append(b_prob.cpu().numpy())
        all_xt.append(xt.cpu().numpy())
        all_binary.append(b_binary.cpu().numpy())
        all_pred.append(combined_s.cpu().numpy())
        all_ct.append(ct.cpu().numpy())
        all_true.append(xt2.cpu().numpy())


In [6]:
temp = all_xt[10:13]   # (3, 64, 23)
all_xt1 = [x[0] for x in temp]   # (3, 23)
all_xt1

[array([0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.26193333, 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        ], dtype=float32),
 array([0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.2, 0.2, 0. , 0. , 0. ,
        0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ], dtype=float32),
 array([0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.16666667, 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        ], dtype=float32)]

In [7]:
temp = all_binary[10:13]   # (3, 64, 23)
all_binary1 = [x[0] for x in temp]   # (3, 23)
all_binary1

[array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
        1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)]

In [8]:
temp = all_bce[10:13]   # (3, 64, 23)
all_bce1 = [x[0] for x in temp]   # (3, 23)
all_bce1

[array([-223.40073 ,  -67.72378 ,  -89.833984, -241.93036 , -181.2636  ,
         -70.433815,  -93.44865 ,  -90.923775, -104.15372 ,  -74.21316 ,
        -109.117744,  -22.795023, -103.38128 ,  -59.63415 ,  -68.844604,
         -70.33725 ,  -74.907166,  -20.117456,  -20.433931, -157.11824 ,
        -151.01236 ,  -41.533363,  -14.395519, -190.12823 , -139.0396  ,
        -235.78098 , -197.73349 , -183.5798  , -240.984   ,  -48.081142,
          16.729588,   15.203286, -180.71912 , -183.94852 , -105.996994,
        -136.61327 , -176.26143 , -127.679276, -347.01685 , -175.96822 ,
        -153.2267  ,  -23.227345, -147.61281 , -139.30634 , -109.87209 ],
       dtype=float32),
 array([-149.71562 ,  -15.844072,  -90.54253 , -151.33232 ,  -50.177315,
         -97.27402 ,  -93.4712  ,  -55.404133,  -54.300095,  -90.79199 ,
         -82.95516 , -107.144485,  -99.10837 ,  -98.835495,  -32.09961 ,
         -83.407005, -137.52417 , -102.66231 ,  -67.63581 , -140.7391  ,
         -79.15925 ,  -97.4

In [9]:
temp = all_prob[10:13]   # (3, 64, 23)
all_prob1 = [x[0] for x in temp]   # (3, 23)
all_prob1

[array([0.0000000e+00, 3.8720148e-30, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 2.5762191e-31, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 5.8834918e-33, 0.0000000e+00, 1.2596425e-10,
        0.0000000e+00, 1.2624584e-26, 1.2623163e-30, 2.8374000e-31,
        2.9392151e-33, 1.8327350e-09, 1.3355383e-09, 0.0000000e+00,
        0.0000000e+00, 9.1683138e-19, 5.5989318e-07, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 1.3140901e-21, 1.0000000e+00, 9.9999976e-01,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 8.1750884e-11, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00], dtype=float32),
 array([0.0000000e+00, 1.3152453e-07, 0.0000000e+00, 0.0000000e+00,
        1.6153592e-22, 0.0000000e+00, 0.0000000e+00, 8.6754258e-25,
        2.6167881e-24, 0.0000000e+00, 9.3979012e-37, 0.0000000e+00,
        

In [10]:
temp = all_ct[10:13]   # (3, 64, 23)
all_ct1 = [x[0] for x in temp]   # (3, 23)
all_ct1

[array([0.21501848, 0.31034482, 0.08420593, 0.04166667, 1.        ,
        0.23285207, 0.5       , 0.16424474, 0.16424474, 0.6295575 ,
        0.65656567, 0.97998303], dtype=float32),
 array([0.50597894, 0.1724138 , 0.05090621, 0.08333334, 0.5       ,
        0.10557704, 0.445896  , 0.01088919, 0.00916972, 0.6014336 ,
        0.26262626, 0.9634164 ], dtype=float32),
 array([0.11020738, 0.10344828, 0.06674241, 0.125     , 1.        ,
        0.10557704, 0.5       , 0.16424474, 0.16424474, 0.59597725,
        0.13131313, 0.97032344], dtype=float32)]

In [12]:
all_pred

[array([[-0.0000000e+00, -1.1342997e-37, -0.0000000e+00, ...,
         -0.0000000e+00, -0.0000000e+00, -0.0000000e+00],
        [ 9.8648280e-01, -2.0378531e-17, -0.0000000e+00, ...,
          0.0000000e+00, -6.6871143e-28, -0.0000000e+00],
        [ 9.7434455e-01, -1.5820460e-18, -0.0000000e+00, ...,
          0.0000000e+00,  7.3409227e-29, -0.0000000e+00],
        ...,
        [-0.0000000e+00,  0.0000000e+00, -0.0000000e+00, ...,
          7.1507856e-23,  0.0000000e+00,  0.0000000e+00],
        [ 0.0000000e+00,  2.0447574e-32, -0.0000000e+00, ...,
          0.0000000e+00, -0.0000000e+00, -0.0000000e+00],
        [ 9.0924948e-01,  1.3953607e-33, -0.0000000e+00, ...,
          0.0000000e+00, -3.4433946e-23, -0.0000000e+00]], dtype=float32),
 array([[-0.0000000e+00, -0.0000000e+00,  0.0000000e+00, ...,
          1.7354339e-22, -0.0000000e+00,  0.0000000e+00],
        [-0.0000000e+00,  0.0000000e+00, -0.0000000e+00, ...,
         -1.1744180e-08, -0.0000000e+00,  3.5400982e-35],
        [ 

In [17]:
# Inverse Scaling 및 최종 R2 점수
y_pred_inv = x2_scaler.inverse_transform(np.vstack(all_pred))
y_true_inv = x2_scaler.inverse_transform(np.vstack(all_true))

final_score = r2_score(y_true_inv.flatten(), y_pred_inv.flatten())
print(f"\nFinal Combined R2 Score: {final_score:.4f}")


Final Combined R2 Score: 0.9628


In [None]:
temp = x_true_inv[10:13]
all_xt = [x for x in temp]
all_xt

[array([0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 5.275, 0.   ,
        0.   , 1.   , 0.   , 0.   , 0.   , 0.   , 1.   , 0.   , 0.   ,
        0.   , 0.   , 0.   , 0.   , 0.   ], dtype=float32),
 array([0. , 0. , 0. , 0. , 0. , 0. , 0. , 6.1, 0. , 0. , 0. , 0. , 0. ,
        0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ], dtype=float32),
 array([ 0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  , 14.45, 10.  ,  0.  ,
         0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,
         0.  ,  0.  ,  0.  ,  0.  ,  0.  ], dtype=float32)]

In [None]:
temp = c_true_inv[10:13]   # (3, 64, 23)
all_ct1 = [x for x in temp]   # (3, 23)
all_ct1

[array([6.0000000e+02, 6.0000000e+00, 5.5000000e+02, 3.0000000e+00,
        1.0000000e+02, 6.0000000e+02, 0.0000000e+00, 4.4186845e-01,
        4.4186845e-01, 9.8095423e-01, 4.1108742e+00, 3.8000000e+01,
        9.0582526e-01], dtype=float32),
 array([550.       ,   3.       , 700.       ,   2.       ,  50.       ,
        800.       ,   0.       ,   2.08156  ,   2.08156  ,   3.0934486,
          5.1984973,  78.       ,  42.17963  ], dtype=float32),
 array([700.00006  ,   2.       , 700.       ,   4.       , 100.       ,
        700.       ,   0.       ,   0.9807668,   0.9807668,   1.4662602,
          5.1357985,  70.       ,   5.692583 ], dtype=float32)]

In [None]:
import numpy as np

all_ct1 = np.array(all_ct1)

all_ct1[:,7:11] = np.expm1(all_ct1[:,7:11])
all_ct1[:,6] = np.exp(all_ct1[:,6]) #6번째

In [None]:
all_ct1

array([[6.0000000e+02, 6.0000000e+00, 5.5000000e+02, 3.0000000e+00,
        1.0000000e+02, 6.0000000e+02, 1.0000000e+00, 5.5561113e-01,
        5.5561113e-01, 1.6670001e+00, 6.0000019e+01, 3.8000000e+01,
        9.0582526e-01],
       [5.5000000e+02, 3.0000000e+00, 7.0000000e+02, 2.0000000e+00,
        5.0000000e+01, 8.0000000e+02, 1.0000000e+00, 7.0169649e+00,
        7.0169649e+00, 2.1052999e+01, 1.8000005e+02, 7.8000000e+01,
        4.2179630e+01],
       [7.0000006e+02, 2.0000000e+00, 7.0000000e+02, 4.0000000e+00,
        1.0000000e+02, 7.0000000e+02, 1.0000000e+00, 1.6665001e+00,
        1.6665001e+00, 3.3330002e+00, 1.6900000e+02, 7.0000000e+01,
        5.6925831e+00]], dtype=float32)