In [0]:
import os, sys
sys.path.append("/Workspace/Users/1dt003@msacademy.msai.kr/team4-CICD/src")

from src.config_loader import load_config
from src.data_preprocessor import preprocess_dataframe, merge_batch_recent
from src.model_trainer import build_preprocessor, train_models
from src.mlflow_manager import init_experiment, log_model_mlflow
from src.model_serving import predict_with_model
from pyspark.sql import SparkSession
import pandas as pd

# ==============================
# Spark session
# ==============================
spark = SparkSession.builder.getOrCreate()

# ==============================
# Load config
# ==============================
config = load_config("/Workspace/Users/1dt003@msacademy.msai.kr/team4-CICD/configs/config_v2.json")

jdbc_url = config['jdbc_url']
connection_properties = config.get('connection_properties')
batch_table = config['batch_table']
recent_table = config['recent_table']
# `target` is read before `rename_map` so the rename_map default can refer to it;
# previously the module-level default hard-coded 'realScore_clean' while the merge
# call used `target`, so the two could disagree when only "target" was overridden.
target = config.get("target", "realScore_clean")
exclude_cols = config.get('exclude_cols', ['learnerID','testID','correct_cnt','items_attempted'])
categorical_candidates = config.get('categorical_candidates', ['gender','grade'])
# Default maps the prediction column to the configured target column name
# (presumably merge_batch_recent renames keys -> values; confirm in src/data_preprocessor).
rename_map = config.get('rename_map', {'pred_realScore_clean': target})
drop_cols = config.get('drop_cols', ['percent_rank','grade_percentile_calc'])

# ==============================
# Model / experiment names from environment variables
# ==============================
# `or` (rather than a .get default) makes an empty-string env var fall back too.
PREDICT_EXPERIMENT = os.environ.get("PREDICT_EXPERIMENT") or "/Workspace/Users/1dt003@msacademy.msai.kr/team4_pred_experiment"
BASE_MODEL_NAME = os.environ.get("BASE_MODEL_NAME") or "team4-pred-model"

# ==============================
# Load data over JDBC & merge
# ==============================
df_batch = spark.read.jdbc(url=jdbc_url, table=batch_table, properties=connection_properties)
df_recent = spark.read.jdbc(url=jdbc_url, table=recent_table, properties=connection_properties)

# Use the values resolved above so the config is read in exactly one place.
df_merge = merge_batch_recent(
    df_batch,
    df_recent,
    rename_map=rename_map,
    drop_cols=drop_cols,
)

# ==============================
# Preprocessing
# ==============================
df_processed, categorical_cols = preprocess_dataframe(
    df_merge,
    categorical_candidates=categorical_candidates,
)

# ==============================
# Separate features from the target
# ==============================
# Every column except the target and the configured exclusions is a feature;
# the column order of df_processed is kept.
_non_feature_cols = [target] + exclude_cols
feature_cols = [col for col in df_processed.columns if col not in _non_feature_cols]
X, y = df_processed[feature_cols], df_processed[target]

# ==============================
# Train/test split: 20% held out, fixed seed for reproducibility
# ==============================
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

# ==============================
# Build the preprocessing pipeline from the training split
# ==============================
preprocessor = build_preprocessor(X_train, categorical_cols)

# ==============================
# Train candidate models and report metrics
# ==============================
# The returned dict is used below via the keys 'results' (model name -> metrics dict
# with 'rmse'/'mae'/'r2'), 'best_name' and 'best_model'.
train_results = train_models(X_train, y_train, X_test, y_test, preprocessor)

print("✅ 학습 완료. 평가 결과:")
for model_label, scores in train_results['results'].items():
    print(
        f"{model_label}: RMSE={scores['rmse']:.4f}, "
        f"MAE={scores['mae']:.4f}, R2={scores['r2']:.4f}"
    )

# ==============================
# MLflow registration
# ==============================

init_experiment(PREDICT_EXPERIMENT)

log_model_mlflow(
    best_name=train_results['best_name'],
    best_model=train_results['best_model'],
    results=train_results['results'],
    model_name=BASE_MODEL_NAME,
)

from mlflow.tracking import MlflowClient

client = MlflowClient()

# Fetch the newest registered version(s) still in the "None" stage -- presumably the
# one log_model_mlflow just registered under BASE_MODEL_NAME (confirm in src/mlflow_manager).
# NOTE(review): registry stages and get_latest_versions are deprecated in recent MLflow
# releases and unsupported in Unity Catalog; consider aliases if the registry moves.
latest_versions = client.get_latest_versions(name=BASE_MODEL_NAME, stages=["None"])
if not latest_versions:
    raise ValueError(f"No versions found for {BASE_MODEL_NAME}")

latest_version_num = max(int(v.version) for v in latest_versions)

# Actually promote the version; the message printed below reports this transition,
# so it must happen before the print.
client.transition_model_version_stage(
    name=BASE_MODEL_NAME,
    version=str(latest_version_num),
    stage="Staging",
)

print(f"✅ 모델 {BASE_MODEL_NAME} version {latest_version_num}를 Staging stage로 전환 완료")

# ==============================
# Sample serving smoke test
# ==============================
# Score the first five held-out rows with the in-memory best model from training
# (not the registry copy); feature_cols is passed through to predict_with_model.
sample_preds = predict_with_model(
    train_results['best_model'],
    X_test.head(5),
    feature_cols,
)
print("샘플 예측값:\n", sample_preds)
