In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.ensemble import RandomForestRegressor
from sklearn.multioutput import MultiOutputRegressor
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

In [2]:
import os
par_dir = os.path.abspath(os.path.join(os.getcwd(),os.pardir))
os.chdir(par_dir)
from model.m_randomforest_bce import WeightedBCEcVAE
from model.m2_randomforest_mse import WeightedMSEcVAE
from loss.l2_bce import l2_bce
from loss.l2_mse import l2_mse
from vae_earlystopping import EarlyStopping

In [3]:
x_data = np.load('./data/metal.npy')      # Target (X)
c_data = np.load('./data/pre_re_fin.npy') # Condition (C)

# 데이터 분할 (Train/Val/Test)
c_train, c_temp, x_train, x_temp = train_test_split(c_data, x_data, test_size=0.4, random_state=42)
c_val, c_test, x_val, x_test = train_test_split(c_temp, x_temp, test_size=0.5, random_state=42)

# 스케일링
c_scaler, x_scaler = MinMaxScaler(), MinMaxScaler()
c_train = c_scaler.fit_transform(c_train)
x_train = x_scaler.fit_transform(x_train)
c_val, x_val = c_scaler.transform(c_val), x_scaler.transform(x_val)
c_test, x_test = c_scaler.transform(c_test), x_scaler.transform(x_test)

# 텐서 변환
def to_tensor(arr): return torch.tensor(arr, dtype=torch.float32)
c_train, x_train = to_tensor(c_train), to_tensor(x_train)
c_val, x_val = to_tensor(c_val), to_tensor(x_val)
c_test, x_test = to_tensor(c_test), to_tensor(x_test)

train_loader = DataLoader(TensorDataset(x_train, c_train), batch_size=64, shuffle=True)
val_loader = DataLoader(TensorDataset(x_val, c_val), batch_size=64, shuffle=False)
test_loader = DataLoader(TensorDataset(x_test, c_test), batch_size=64, shuffle=False)

# 2. RandomForest를 통한 가중치(Feature Importance) 계산
rf = MultiOutputRegressor(RandomForestRegressor(n_estimators=100, random_state=42))
rf.fit(c_train.numpy(), x_train.numpy()) # Train 데이터로만 중요도 계산

importances = np.mean([est.feature_importances_ for est in rf.estimators_], axis=0)
condition_weights = torch.tensor(importances / np.max(importances), dtype=torch.float32)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
x_dim, c_dim, z_dim = x_train.shape[1], c_train.shape[1], 8

# 3. BCE 및 MSE 모델 정의
model_bce = WeightedBCEcVAE(x_dim, c_dim, z_dim, condition_weights).to(device)
model_mse = WeightedMSEcVAE(x_dim, c_dim, z_dim, condition_weights).to(device)

In [9]:
optimizer_bce = optim.Adam(model_bce.parameters(), lr=1e-3, weight_decay=1e-5)
optimizer_mse = optim.Adam(model_mse.parameters(),lr=1e-3, weight_decay = 1e-5)
es_bce= EarlyStopping(patience=40, min_delta=1e-9)
es_mse = EarlyStopping(patience=40,min_delta=1e-9)

In [10]:
for epoch in range(1, 801):
    model_bce.train()
    t_loss = 0
    for x, c in train_loader:
        x, c = x.to(device), c.to(device)
        optimizer_bce.zero_grad()
        logit, mu, logvar = model_bce(x, c)
        loss = l2_bce(logit, x, mu, logvar)['loss']
        loss.backward()
        optimizer_bce.step()
        t_loss += loss.item()
    
    model_bce.eval()
    v_loss = 0
    with torch.no_grad():
        for vx, vc in val_loader:
            vx, vc = vx.to(device), vc.to(device)
            v_logit, v_mu, v_logvar = model_bce(vx, vc)
            v_loss += l2_bce(v_logit, vx, v_mu, v_logvar)['loss'].item()
    
    avg_v_loss = v_loss / len(val_loader)
    if epoch % 50 == 0: print(f"Epoch {epoch} | Val Loss: {avg_v_loss:.6f}")
    if es_bce(avg_v_loss, model_bce): break

es_bce.load_best_model(model_bce)

# --- 4. MSE 모델 학습 (수치 예측) ---
print("\n[Phase 2] Training Weighted MSEcVAE...")
model_mse = WeightedMSEcVAE(x_dim, c_dim, z_dim, condition_weights).to(device)
optimizer_mse = optim.Adam(model_mse.parameters(), lr=1e-3, weight_decay=1e-5)
es_mse = EarlyStopping(patience=40, min_delta=1e-9)

for epoch in range(1, 801):
    model_mse.train()
    t_loss = 0
    for x, c in train_loader:
        x, c = x.to(device), c.to(device)
        optimizer_mse.zero_grad()
        x_hat, mu, logvar = model_mse(x, c)
        loss = l2_mse(x_hat, x, mu, logvar)['loss']
        loss.backward()
        optimizer_mse.step()
        t_loss += loss.item()

    model_mse.eval()
    v_loss = 0
    with torch.no_grad():
        for vx, vc in val_loader:
            vx, vc = vx.to(device), vc.to(device)
            vh, vm, vl = model_mse(vx, vc)
            v_loss += l2_mse(vh, vx, vm, vl)['loss'].item()
    
    avg_v_loss = v_loss / len(val_loader)
    if epoch % 50 == 0: print(f"Epoch {epoch} | Val Loss: {avg_v_loss:.6f}")
    if es_mse(avg_v_loss, model_mse): break

es_mse.load_best_model(model_mse)

# --- 5. 최종 평가 (Hurdle: BCE Prob * MSE Value) ---
model_bce.eval()
model_mse.eval()

all_pred = []
all_true = []

with torch.no_grad():
    for xt, ct in test_loader:
        xt, ct = xt.to(device), ct.to(device)
        
        # BCE 확률 (P)
        b_logit, _, _ = model_bce(xt, ct)
        b_prob = torch.sigmoid(b_logit)
        
        # MSE 수치 (V)
        m_val, _, _ = model_mse(xt, ct)
        
        # 결합 (P * V)
        combined_s = b_prob * m_val
        
        all_pred.append(combined_s.cpu().numpy())
        all_true.append(xt.cpu().numpy())

# Inverse Scaling 및 최종 R2 점수
y_pred_inv = x_scaler.inverse_transform(np.vstack(all_pred))
y_true_inv = x_scaler.inverse_transform(np.vstack(all_true))

final_score = r2_score(y_true_inv.flatten(), y_pred_inv.flatten())
print(f"\nFinal Combined R2 Score: {final_score:.4f}")

EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 2 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 2 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 2 out of 40
EarlyStopping counter: 3 out of 40
EarlyStopping counter: 4 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 2 out of 40
EarlyStopping counter: 3 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 2 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 1 out of 40
Epoch 50 | Val Loss: 2.383520
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 2 out of 40
EarlyStopping counter: 1 out of 40
EarlyStopping counter: 2 out of 40
EarlyStopping counter: 3 out of 40
EarlyStopping counter: 1 