In [1]:
import os
par_dir = os.path.abspath(os.path.join(os.getcwd(),os.pardir))
os.chdir(par_dir)

In [2]:
import torch
import torch.optim as optim
import joblib
import optuna
from model.m26_prob_1 import MultiDecoderCondVAE
from loss.l26oss_all import integrated_loss_fn

# 1. 환경 및 데이터 준비
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 데이터 로더 및 스케일러 로드
x_scaler = joblib.load('torch/abs_x_scaler.pkl')
c_saler = joblib.load('torch/sta_pre_re_scaler.pkl')
train_loader = torch.load('torch/pre_retrain_loader.pt', weights_only=False)
val_loader = torch.load('torch/pre_reval_loader.pt', weights_only=False)

# 입력 차원 자동 추출 (첫 번째 배치를 통해 확인)
x_sample, c_sample = next(iter(train_loader))
x_dim = x_sample.shape[1]
c_dim = c_sample.shape[1]

def objective(trial):
    # 2. 튜닝할 하이퍼파라미터 제안
    lr = trial.suggest_float("lr", 1e-4, 1e-2, log=True)
    z_dim = trial.suggest_int("z_dim", 4, 32)
    alpha = trial.suggest_float("alpha", 0.1, 5.0)
    beta = trial.suggest_float("beta", 0.1, 5) # KL 가중치는 보통 작게 시작
    gamma = trial.suggest_float("gamma", 0.001,0.1)
    # 3. 모델 및 옵티마이저 선언
    model = MultiDecoderCondVAE(x_dim, c_dim, z_dim).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    # 각 trial당 학습할 에포크 수 (성능 확인을 위해 최소 20~50회 권장)
    epochs = 50 
    
    for epoch in range(epochs):
        # --- Training Loop ---
        model.train()
        for x, c in train_loader:
            x, c = x.to(device), c.to(device)
            optimizer.zero_grad()
            
            bce_logit, binary_out, x_hat, z_mu, z_logvar = model(x, c)
            
            # integrated_loss_fn에 제안된 가중치 적용
            loss_dict = integrated_loss_fn(
                bce_logit, x_hat, x, z_mu, z_logvar
                ,alpha=alpha, beta=beta, gamma=gamma)
            
            loss_dict['loss'].backward()
            optimizer.step()
        
        # --- Validation Loop ---
        model.eval()
        v_total_loss = 0
        with torch.no_grad():
            for v_x, v_c in val_loader:
                v_x, v_c = v_x.to(device), v_c.to(device)
                v_bce_logit,v_binary_out, v_x_hat, v_z_mu, v_z_logvar = model(v_x, v_c)
                
                v_loss_dict = integrated_loss_fn(
                    v_bce_logit, v_x_hat, v_x, v_z_mu, v_z_logvar,
                    alpha=alpha, beta=beta, gamma=gamma
                )
                v_total_loss += v_loss_dict['loss'].item()
        
        avg_val_loss = v_total_loss / len(val_loader)
        
        # Pruning: 성능이 개선되지 않는 trial은 조기 종료하여 시간 절약
        trial.report(avg_val_loss, epoch)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return avg_val_loss

# 4. 최적화 실행
# n_trials: 총 시도 횟수 (예: 30)
study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=30)

# 5. 최적화 결과 확인 및 모델 재학습 활용
print("-" * 30)
print("Best hyperparameters:", study.best_params)
print("Best validation loss:", study.best_value)



  from .autonotebook import tqdm as notebook_tqdm
[32m[I 2026-01-29 11:06:37,008][0m A new study created in memory with name: no-name-f82737d7-8605-4c40-bf0c-67f7b515da43[0m
[32m[I 2026-01-29 11:06:55,770][0m Trial 0 finished with value: 0.504323746812971 and parameters: {'lr': 0.00347274782846782, 'z_dim': 16, 'alpha': 3.645904427253932, 'beta': 0.7522006229160809, 'gamma': 0.09885897330577079}. Best is trial 0 with value: 0.504323746812971.[0m
[32m[I 2026-01-29 11:07:09,387][0m Trial 1 finished with value: 1.7800922770249217 and parameters: {'lr': 0.007888216271758123, 'z_dim': 23, 'alpha': 0.8642798346871813, 'beta': 3.634497520648618, 'gamma': 0.06803779965153234}. Best is trial 0 with value: 0.504323746812971.[0m
[32m[I 2026-01-29 11:07:23,284][0m Trial 2 finished with value: 1.1241484347142672 and parameters: {'lr': 0.0007726690584931776, 'z_dim': 19, 'alpha': 3.956443745259353, 'beta': 1.1376437162430781, 'gamma': 0.08813492416850997}. Best is trial 0 with value: 0.50

------------------------------
Best hyperparameters: {'lr': 0.005973210895790041, 'z_dim': 21, 'alpha': 1.4067140824040858, 'beta': 0.16236664842673776, 'gamma': 0.0016322069344605304}
Best validation loss: 0.09428630151638859


In [3]:
optuna.visualization.plot_param_importances(study).show()

In [4]:
optuna.visualization.plot_parallel_coordinate(study).show()