# [고성능 MSE 반영] 분류 및 연속모델 복합 구조 cVAE 촉매 최적화
### NVAE 기반 KL Balancing [0.01, 0.005, 0.002] 및 Recon 1:1:1 가중치

In [None]:
import os, sys, torch, json
import numpy as np
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import r2_score
import matplotlib.pyplot as plt

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))
from multilayer_model.m3_multi_bce import M3_Multi_BCE
from multilayer_model.m3_multi_mse import M3_Multi_MSE
from multilayer_loss.l_multi3_final_logic import l_multi3_final_loss
from vae_earlystopping import EarlyStopping

# Upper bound on the number of epochs a single training run may take.
EPOCHS = 800

# Use CUDA when PyTorch reports it as available; otherwise fall back to CPU.
_backend = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_backend)
print(f"Using device: {device}")

In [None]:
# Load the three input arrays (x1..x3) and the condition array (c).
x1_raw = np.load('../data/metal.npy')
x2_raw = np.load('../data/support_norm.npy')
x3_raw = np.load('../data/pre_fin_norm.npy')
c_raw = np.load('../data/re_fin.npy')

# One index split (derived from x1_raw) is applied to all four arrays below,
# so they must hold the same number of rows. Fail early with a clear message
# instead of an IndexError later, or a silently mispaired dataset.
_row_counts = {
    'x1_raw': len(x1_raw),
    'x2_raw': len(x2_raw),
    'x3_raw': len(x3_raw),
    'c_raw': len(c_raw),
}
if len(set(_row_counts.values())) != 1:
    raise ValueError(f"Input arrays disagree on sample count: {_row_counts}")

# Hold out 10% of the rows for testing, then 10% of what remains for validation.
tr_idx, te_idx = train_test_split(np.arange(len(x1_raw)), test_size=0.1, random_state=42)
tr_idx, va_idx = train_test_split(tr_idx, test_size=0.1, random_state=42)

# One independent scaler per array.
sc1 = MinMaxScaler()
sc2 = MinMaxScaler()
sc3 = MinMaxScaler()
scc = MinMaxScaler()


def prep(d, i, s, fit=False):
    """Scale the rows ``d[i]`` with scaler ``s``.

    Args:
        d: Array to take rows from.
        i: Row indices selecting the subset.
        s: Scaler exposing ``fit_transform`` and ``transform``.
        fit: When true, fit ``s`` on the selected rows before transforming;
            otherwise apply ``s`` as it is.

    Returns:
        The scaled rows, as returned by the scaler.
    """
    rows = d[i]
    if fit:
        return s.fit_transform(rows)
    return s.transform(rows)


# Each scaler is fitted on the training rows only and then reused unchanged
# for the validation and test rows.
x1_tr = prep(x1_raw, tr_idx, sc1, fit=True)
x1_va = prep(x1_raw, va_idx, sc1)
x1_te = prep(x1_raw, te_idx, sc1)

x2_tr = prep(x2_raw, tr_idx, sc2, fit=True)
x2_va = prep(x2_raw, va_idx, sc2)
x2_te = prep(x2_raw, te_idx, sc2)

x3_tr = prep(x3_raw, tr_idx, sc3, fit=True)
x3_va = prep(x3_raw, va_idx, sc3)
x3_te = prep(x3_raw, te_idx, sc3)

c_tr = prep(c_raw, tr_idx, scc, fit=True)
c_va = prep(c_raw, va_idx, scc)
c_te = prep(c_raw, te_idx, scc)

def to_t(a):
    """Return ``a`` as a ``float32`` torch tensor."""
    return torch.tensor(a, dtype=torch.float32)


# Every dataset yields (x1, x2, x3, c) tuples; only the training loader shuffles.
_BATCH_SIZE = 32

_train_ds = TensorDataset(*(to_t(arr) for arr in (x1_tr, x2_tr, x3_tr, c_tr)))
train_loader = DataLoader(_train_ds, batch_size=_BATCH_SIZE, shuffle=True)

_val_ds = TensorDataset(*(to_t(arr) for arr in (x1_va, x2_va, x3_va, c_va)))
val_loader = DataLoader(_val_ds, batch_size=_BATCH_SIZE)

_test_ds = TensorDataset(*(to_t(arr) for arr in (x1_te, x2_te, x3_te, c_te)))
test_loader = DataLoader(_test_ds, batch_size=_BATCH_SIZE)

# Feature widths (second axis) of the three inputs and of the condition array.
x_dims = [arr.shape[1] for arr in (x1_tr, x2_tr, x3_tr)]
c_dim = c_tr.shape[1]

In [None]:
def train_expert(model, mode='mse', gamma_list=(0.01, 0.005, 0.002)):
    """Train ``model`` on ``train_loader`` with early stopping on ``val_loader``.

    Args:
        model: Network called as ``model(x1, x2, x3, c)``; its three return
            values are forwarded to ``l_multi3_final_loss``.
        mode: Loss mode passed through to ``l_multi3_final_loss``. ``'mse'``
            trains with a learning rate of 5e-4, any other value with 1e-3.
        gamma_list: Weights forwarded as ``gamma_list`` to the loss for BOTH
            the training and the validation pass, so early stopping monitors
            the same objective that is optimised. Presumably the KL balancing
            weights named in the notebook heading; confirm against the loss
            implementation.

    Returns:
        ``model`` after ``es.load_best_model(model)`` has been applied.
    """
    # Learning rate is tuned per mode (lower for the MSE model).
    optimizer = optim.Adam(
        model.parameters(),
        lr=5e-4 if mode == 'mse' else 1e-3,
        weight_decay=1e-5,
    )
    es = EarlyStopping(patience=40, min_delta=1e-9)
    # The loss originally received a list, so hand it a list whatever was passed.
    gammas = list(gamma_list)

    for epoch in range(1, EPOCHS + 1):
        model.train()
        for b1, b2, b3, bc in train_loader:
            b1, b2, b3, bc = b1.to(device), b2.to(device), b3.to(device), bc.to(device)
            optimizer.zero_grad()
            p, m, v = model(b1, b2, b3, bc)
            loss = l_multi3_final_loss(p, [b1, b2, b3], m, v, mode=mode, gamma_list=gammas)
            loss.backward()
            optimizer.step()

        model.eval()
        v_loss = 0.0
        with torch.no_grad():
            for v1, v2, v3, vc in val_loader:
                v1, v2, v3, vc = v1.to(device), v2.to(device), v3.to(device), vc.to(device)
                vp, vm, vv = model(v1, v2, v3, vc)
                # Same gamma weights as the training loss above; the original
                # omitted them here and fell back to the loss's own default.
                v_loss += l_multi3_final_loss(
                    vp, [v1, v2, v3], vm, vv, mode=mode, gamma_list=gammas
                ).item()

        # Mean of the per-batch validation losses.
        mean_v_loss = v_loss / len(val_loader)
        if epoch % 50 == 0:
            print(f"[{mode.upper()}] Epoch {epoch} | Val Loss: {mean_v_loss:.6f}")
        # A truthy return from the early-stopping helper ends training.
        if es(mean_v_loss, model):
            break
    es.load_best_model(model)
    return model

# Build each network on `device` and train it with its matching loss mode.
m_bce = train_expert(model=M3_Multi_BCE(x_dims, c_dim).to(device), mode='bce')
m_mse = train_expert(model=M3_Multi_MSE(x_dims, c_dim).to(device), mode='mse')

In [None]:
print("--- Final Evaluation (Metal Generation) ---")
m_bce.eval()
m_mse.eval()
all_gen = []
all_true = []

with torch.no_grad():
    for b1, b2, b3, bc in test_loader:
        # Only the condition goes to the device; b1 stays on CPU as ground truth.
        bc = bc.to(device)
        b_gen = m_bce.generate(bc, device)
        m_gen = m_mse.generate(bc, device)

        # Gating: the sigmoid of the BCE model's first output scales the MSE
        # model's first output element-wise.
        gate = torch.sigmoid(b_gen[0])
        final = gate * m_gen[0]
        # Noise threshold: zero out entries below 1e-3, applied before the
        # inverse scaling, i.e. in sc1's scaled space.
        below_threshold = final < 1e-3
        final[below_threshold] = 0

        gen_np = final.cpu().numpy()
        true_np = b1.numpy()
        all_gen.append(sc1.inverse_transform(gen_np))
        all_true.append(sc1.inverse_transform(true_np))

# Pool every sample and feature into one flat vector each, then score them.
y_p = np.concatenate(all_gen).flatten()
y_t = np.concatenate(all_true).flatten()
final_r2 = r2_score(y_t, y_p)
print(f"üèÜ Final R2 Score: {final_r2:.4f}")

# Scatter of generated vs. true values, with a dashed y = x reference line
# drawn from the origin to the largest true value.
plt.figure(figsize=(10, 5))
plt.scatter(x=y_t, y=y_p, alpha=0.5, color='orange')
diag_end = y_t.max()
plt.plot([0, diag_end], [0, diag_end], 'r--')
plt.title("High Performance Generation Results")
plt.show()