In [6]:
# ================================================================
# 모듈 9: 본부 단위 CL별 정규화 모듈 (연말 전용) - 주피터 노트북용
# ================================================================

from typing import Annotated, List, Literal, TypedDict, Dict, Optional
from langchain_core.messages import HumanMessage 
import operator
from langgraph.graph import StateGraph, START, END
import json
import re
import sys
import os
import statistics

# 기존 imports
from sqlalchemy import create_engine, text
from sqlalchemy.engine import Row
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import SystemMessage, AIMessage
from dotenv import load_dotenv

# Load environment variables (python-dotenv) before any configuration object
# or client below is constructed.
load_dotenv()

# Database setup: put the directory four levels above the current working
# directory on sys.path so that the `config` package can be imported.
# NOTE(review): presumably that directory is the project root — this depends
# on where the notebook is launched from; confirm.
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '../../../..')))
from config.settings import DatabaseConfig

db_config = DatabaseConfig()
DATABASE_URL = db_config.DATABASE_URL
# Module-wide engine shared by every query helper in this file.
engine = create_engine(DATABASE_URL, pool_pre_ping=True)

# LLM client setup
llm_client = ChatOpenAI(model="gpt-4o-mini", temperature=0)
print(f"LLM Client initialized: {llm_client.model_name}")

def row_to_dict(row: Row) -> Dict:
    """Convert a SQLAlchemy ``Row`` into a plain dictionary.

    Returns an empty dict when ``row`` is ``None``; otherwise returns
    whatever the row's ``_asdict()`` produces.
    """
    return {} if row is None else row._asdict()

def _extract_json_from_llm_response(text: str) -> str:
    """Pull the JSON payload out of an LLM reply.

    If the reply contains a Markdown code fence (optionally tagged ``json``),
    the content of the first fence is returned; otherwise the whole reply is
    returned. Either way the result is stripped of surrounding whitespace.
    """
    fenced = re.search(r"```(?:json)?\s*(.*?)\s*```", text, re.DOTALL)
    payload = fenced.group(1) if fenced else text
    return payload.strip()

# ================================================================
# HeadquarterModule9AgentState 정의 - 본부 단위 처리
# ================================================================

class HeadquarterModule9AgentState(TypedDict):
    """State for module 9 (per-CL score normalization), handled one headquarter at a time."""
    # Message log; the annotation attaches operator.add as the merge function.
    messages: Annotated[List[HumanMessage], operator.add]

    # Basic headquarter information
    headquarter_id: str
    period_id: int  # year-end: 4

    # Data for the whole headquarter
    headquarter_members: List[Dict]  # data for every employee in the headquarter
    cl_groups: Dict  # members grouped by CL

    # Normalization results
    normalized_scores: List[Dict]  # normalized scores and comments

    # Processing results
    processed_count: int
    failed_members: List[str]

# ================================================================
# CL별 정규화 관련 함수들 (모듈 7에서 재사용)
# ================================================================

def get_cl_normalization_params(cl) -> Dict[str, float]:
    """Return the target score distribution for a career level (CL).

    The original docstring labels these values as the "SK standard".

    Args:
        cl: Career level, given either as a number (``3``, ``3.0``,
            ``Decimal("3.0")``) or as text (``"3"``, ``"cl3"``, ``" CL3 "``).
            Numbers are truncated to an integer; text is stripped,
            upper-cased and prefixed with ``"CL"`` when the prefix is missing.

    Returns:
        A freshly built dict with ``target_mean`` and ``target_stdev``.
        Levels other than CL1/CL2/CL3 get the default
        ``{"target_mean": 3.5, "target_stdev": 1.5}``.
    """
    # Local import keeps this function self-contained; manager scores in this
    # file arrive as Decimal, so a numeric CL column may as well.
    from decimal import Decimal

    # Non-finite Decimals cannot be converted with int(); let them fall
    # through to the text path, where they end up on the default parameters.
    is_number = isinstance(cl, (int, float)) or (
        isinstance(cl, Decimal) and cl.is_finite()
    )

    if is_number:
        cl_key = f"CL{int(cl)}"
    else:
        # strip() so that padded values such as " cl3 " still match their
        # level instead of silently getting the default spread.
        cl_key = str(cl).strip().upper()
        if not cl_key.startswith("CL"):
            cl_key = f"CL{cl_key}"

    params = {
        "CL3": {"target_mean": 3.5, "target_stdev": 1.7},
        "CL2": {"target_mean": 3.5, "target_stdev": 1.5},
        "CL1": {"target_mean": 3.5, "target_stdev": 1.4},
    }
    return params.get(cl_key, {"target_mean": 3.5, "target_stdev": 1.5})

def normalize_cl_group(members: List[Dict], cl: str) -> List[Dict]:
    """Rescale the manager scores of one CL group onto its target distribution.

    Normalization is always applied. Each member dict is updated in place
    with ``final_score`` (rounded to two decimals) and ``cl_reason``; the
    same list is returned.

    Args:
        members: Member dicts; each must provide ``manager_score`` and
            ``emp_no``.
        cl: CL label used to look up the target mean / standard deviation and
            quoted in the reason text.

    Returns:
        ``members`` itself (unchanged when the list is empty).
    """
    if not members:
        return members

    print(f"   {cl} 그룹 ({len(members)}명) 정규화 처리:")

    # Target distribution for this CL
    targets = get_cl_normalization_params(cl)
    goal_mean = targets["target_mean"]
    goal_stdev = targets["target_stdev"]

    # Raw manager scores; scores may arrive as Decimal, so coerce to float
    observed = [float(person["manager_score"]) for person in members]

    # Statistics of the group as it currently stands
    observed_mean = statistics.mean(observed)
    observed_stdev = statistics.stdev(observed) if len(observed) > 1 else 0

    print(f"     정규화 적용: 평균 {observed_mean:.2f} → {goal_mean}, 표준편차 {observed_stdev:.2f} → {goal_stdev}")

    # A single member or identical scores leave no spread to rescale
    no_spread = observed_stdev == 0 or len(members) == 1

    for person, score in zip(members, observed):
        if no_spread:
            adjusted = goal_mean
            note = f"본부 내 {cl} 정규화 → 평균 {goal_mean}점"
        else:
            # Map the z-score onto the target distribution
            z = (score - observed_mean) / observed_stdev
            adjusted = goal_mean + (z * goal_stdev)

            # Clamp to the 0.0-5.0 scoring range
            adjusted = max(0.0, min(5.0, adjusted))

            note = f"본부 내 {cl} 정규화 (Z-Score: {z:.2f})"

        person["final_score"] = round(adjusted, 2)
        person["cl_reason"] = note

        print(f"     {person['emp_no']}: {score:.2f} → {adjusted:.2f} ({note})")

    return members

# ================================================================
# 본부 단위 DB 조회 함수들
# ================================================================

def fetch_headquarter_members(headquarter_id: str, period_id: int) -> List[Dict]:
    """Load manager scores plus basic employee info for one headquarter.

    The query joins employees to their team, headquarter, temp evaluation,
    team evaluation and final evaluation report, keeping only rows for the
    given headquarter and period whose ``manager_score`` is not NULL. Rows
    are ordered by ``cl`` and then ``position``, both descending.

    Args:
        headquarter_id: Headquarter to load.
        period_id: Evaluation period matched against ``team_evaluations``.

    Returns:
        One dict per result row (converted with ``row_to_dict``).
    """
    bind_params = {
        "headquarter_id": headquarter_id,
        "period_id": period_id,
    }

    with engine.connect() as conn:
        statement = text("""
            SELECT 
                e.emp_no, e.emp_name, e.cl, e.position, e.team_id,
                te.manager_score,
                fer.final_evaluation_report_id
            FROM employees e
            JOIN teams t ON e.team_id = t.team_id
            JOIN headquarters h ON t.headquarter_id = h.headquarter_id
            JOIN temp_evaluations te ON e.emp_no = te.TempEvaluation_empNo
            JOIN team_evaluations tea ON t.team_id = tea.team_id
            JOIN final_evaluation_reports fer ON (e.emp_no = fer.emp_no AND tea.team_evaluation_id = fer.team_evaluation_id)
            WHERE h.headquarter_id = :headquarter_id 
              AND tea.period_id = :period_id
              AND te.manager_score IS NOT NULL
            ORDER BY e.cl DESC, e.position DESC
        """)
        fetched = conn.execute(statement, bind_params).fetchall()
        return [row_to_dict(record) for record in fetched]

def batch_update_final_evaluation_reports(score_data: List[Dict]) -> Dict:
    """Write normalized scores back to ``final_evaluation_reports`` in one batch.

    Each entry updates ``score`` and ``cl_reason`` of the report identified
    by ``final_evaluation_report_id``. A row that raises, or that matches no
    report, is recorded as failed and the loop continues. The batch is
    committed once after the loop.

    Args:
        score_data: Dicts providing ``final_evaluation_report_id``,
            ``final_score``, ``cl_reason`` and ``emp_no``.

    Returns:
        ``{"success_count": int, "failed_members": [emp_no, ...]}``. If the
        batch as a whole fails it is rolled back and every employee is
        reported as failed with a success count of 0.
    """
    updated = 0
    failures = []

    with engine.connect() as conn:
        try:
            for entry in score_data:
                try:
                    statement = text("""
                        UPDATE final_evaluation_reports 
                        SET score = :score,
                            cl_reason = :cl_reason
                        WHERE final_evaluation_report_id = :report_id
                    """)
                    bound_values = {
                        "report_id": entry["final_evaluation_report_id"],
                        "score": entry["final_score"],
                        "cl_reason": entry["cl_reason"],
                    }
                    outcome = conn.execute(statement, bound_values)

                    if outcome.rowcount <= 0:
                        # No report row matched this id
                        failures.append(entry["emp_no"])
                        print(f"DB 업데이트 실패: {entry['emp_no']} (행 없음)")
                    else:
                        updated += 1
                        print(f"DB 업데이트 성공: {entry['emp_no']} (정규화: {entry['final_score']})")

                except Exception as e:
                    # Best effort: note the failure and move on to the next row
                    failures.append(entry["emp_no"])
                    print(f"DB 업데이트 실패: {entry['emp_no']} - {e}")

            conn.commit()
            print(f"배치 업데이트 완료: 성공 {updated}건, 실패 {len(failures)}건")

            summary = {
                "success_count": updated,
                "failed_members": failures,
            }
            return summary

        except Exception as e:
            print(f"배치 업데이트 실패: {e}")
            conn.rollback()
            everyone = [entry["emp_no"] for entry in score_data]
            return {
                "success_count": 0,
                "failed_members": everyone,
            }

# ================================================================
# 본부 단위 서브모듈 함수들
# ================================================================

def headquarter_data_collection_submodule(state: HeadquarterModule9AgentState) -> HeadquarterModule9AgentState:
    """Step 1: load every scored member of the headquarter into the state.

    Args:
        state: Workflow state; ``headquarter_id`` and ``period_id`` are read.

    Returns:
        A shallow copy of ``state`` in which ``headquarter_members`` holds the
        fetched rows and ``messages`` holds a single completion message.

    Raises:
        Exception: whatever ``fetch_headquarter_members`` raises propagates
            unchanged.
    """
    headquarter_id = state["headquarter_id"]
    period_id = state["period_id"]

    print(f"🔍 본부 데이터 수집 시작: {headquarter_id}")

    # Load every employee of the headquarter. Errors are left to propagate:
    # a failure state assembled here could never be returned once the
    # exception is raised, so no except handler is kept.
    headquarter_members = fetch_headquarter_members(headquarter_id, period_id)
    print(f"   본부 내 직원 수: {len(headquarter_members)}명")

    updated_state = state.copy()
    updated_state.update({
        "messages": [HumanMessage(content=f"본부 데이터 수집 완료: {len(headquarter_members)}명")],
        "headquarter_members": headquarter_members
    })
    return updated_state

def headquarter_cl_grouping_submodule(state: HeadquarterModule9AgentState) -> HeadquarterModule9AgentState:
    """Step 2: bucket the headquarter members by career level (CL).

    Each member's ``cl`` value is canonicalized to ``"CL1"``, ``"CL2"`` or
    ``"CL3"`` and written back onto the member dict. Numbers are truncated to
    an integer; text is stripped, upper-cased and prefixed with ``"CL"`` when
    the prefix is missing. A missing ``cl`` key defaults to CL2, and any
    value that does not resolve to a known level is reported and filed under
    CL2.

    Args:
        state: Workflow state; ``headquarter_members`` is read and its member
            dicts are updated in place.

    Returns:
        A shallow copy of ``state`` with ``cl_groups`` populated and
        ``messages`` holding a single completion message.

    Raises:
        Exception: any error raised while grouping propagates unchanged.
    """
    # Local import keeps the block self-contained; a numeric DB column may
    # arrive as Decimal rather than int/float.
    from decimal import Decimal

    headquarter_members = state["headquarter_members"]

    print("📊 본부 내 CL별 그룹화 시작...")

    # Buckets per CL (numeric and text inputs are both handled below)
    cl_groups = {
        "CL1": [],
        "CL2": [],
        "CL3": []
    }

    for member in headquarter_members:
        cl_raw = member.get("cl", 2)  # missing key defaults to level 2

        # Non-finite Decimals cannot go through int(); treat them as text so
        # they land in the unknown-CL fallback below.
        is_number = isinstance(cl_raw, (int, float)) or (
            isinstance(cl_raw, Decimal) and cl_raw.is_finite()
        )

        if is_number:
            cl = f"CL{int(cl_raw)}"
        else:
            # strip() so padded values such as " cl3 " are not misfiled
            cl = str(cl_raw).strip().upper()
            if not cl.startswith("CL"):
                cl = f"CL{cl}"

        if cl not in cl_groups:
            print(f"⚠️ 알 수 없는 CL: {cl_raw} → CL2로 처리")
            cl = "CL2"

        cl_groups[cl].append(member)
        member["cl"] = cl  # store the canonical label on the member

    print(f"   CL별 분포: CL3({len(cl_groups['CL3'])}명), CL2({len(cl_groups['CL2'])}명), CL1({len(cl_groups['CL1'])}명)")

    updated_state = state.copy()
    updated_state.update({
        "messages": [HumanMessage(content="CL별 그룹화 완료")],
        "cl_groups": cl_groups
    })
    return updated_state

def headquarter_cl_normalization_submodule(state: HeadquarterModule9AgentState) -> HeadquarterModule9AgentState:
    """Step 3: normalize scores within each CL group of the headquarter.

    Every non-empty group in ``cl_groups`` is passed to
    ``normalize_cl_group`` and the results are concatenated. Summary
    statistics are printed only when at least one member was normalized.

    Args:
        state: Workflow state; ``cl_groups`` is read.

    Returns:
        A shallow copy of ``state`` with ``normalized_scores`` set (possibly
        to an empty list) and ``messages`` holding a single completion
        message.

    Raises:
        Exception: any error raised by ``normalize_cl_group`` propagates
            unchanged.
    """
    cl_groups = state["cl_groups"]

    print("🔄 본부 내 CL별 정규화 시작...")

    # Normalize each CL group (normalization is always applied)
    normalized_scores = []

    for cl, members in cl_groups.items():
        if members:
            print(f"\n📊 {cl} 정규화 처리:")
            normalized_scores.extend(normalize_cl_group(members, cl))

    if normalized_scores:
        # Summary statistics before and after normalization
        raw_scores = [m["manager_score"] for m in normalized_scores]
        norm_scores = [m["final_score"] for m in normalized_scores]

        print(f"\n📈 정규화 결과:")
        print(f"   원시점수: 평균 {statistics.mean(raw_scores):.2f}, 표준편차 {statistics.stdev(raw_scores) if len(raw_scores) > 1 else 0:.2f}")
        print(f"   정규화점수: 평균 {statistics.mean(norm_scores):.2f}, 표준편차 {statistics.stdev(norm_scores) if len(norm_scores) > 1 else 0:.2f}")
    else:
        # statistics.mean() raises on empty data; a headquarter with no
        # scored members is not an error, so skip the summary instead.
        print("⚠️ 정규화 대상 직원이 없습니다")

    updated_state = state.copy()
    updated_state.update({
        "messages": [HumanMessage(content=f"CL별 정규화 완료: {len(normalized_scores)}명")],
        "normalized_scores": normalized_scores
    })
    return updated_state

def headquarter_batch_storage_submodule(state: HeadquarterModule9AgentState) -> HeadquarterModule9AgentState:
    """Step 4: persist the normalized scores for the whole headquarter.

    Args:
        state: Workflow state; ``normalized_scores`` is read.

    Returns:
        A shallow copy of ``state`` with ``processed_count`` and
        ``failed_members`` taken from the batch update result, and
        ``messages`` holding a single completion message.

    Raises:
        Exception: any error raised by
            ``batch_update_final_evaluation_reports`` propagates unchanged.
    """
    normalized_scores = state["normalized_scores"]

    print("💾 배치 저장 시작...")

    # Run the batch update. Errors are left to propagate: a failure state
    # assembled here could never be returned once the exception is raised,
    # so no except handler is kept.
    update_result = batch_update_final_evaluation_reports(normalized_scores)

    updated_state = state.copy()
    updated_state.update({
        "messages": [HumanMessage(content=f"배치 저장 완료: 성공 {update_result['success_count']}건")],
        "processed_count": update_result["success_count"],
        "failed_members": update_result["failed_members"]
    })
    return updated_state

# ================================================================
# 본부 단위 워크플로우 생성
# ================================================================

def create_headquarter_module9_graph():
    """Build and compile the headquarter-level module 9 workflow graph.

    The four submodules are registered as nodes and chained strictly in
    sequence: START -> data collection -> CL grouping -> CL normalization ->
    batch storage -> END.
    """
    # Nodes in execution order
    stages = (
        ("headquarter_data_collection", headquarter_data_collection_submodule),
        ("headquarter_cl_grouping", headquarter_cl_grouping_submodule),
        ("headquarter_cl_normalization", headquarter_cl_normalization_submodule),
        ("headquarter_batch_storage", headquarter_batch_storage_submodule),
    )

    workflow = StateGraph(HeadquarterModule9AgentState)
    for stage_name, stage_fn in stages:
        workflow.add_node(stage_name, stage_fn)

    # Sequential edges: each stop on the route links to the next one
    route = [START, *(stage_name for stage_name, _ in stages), END]
    for origin, destination in zip(route, route[1:]):
        workflow.add_edge(origin, destination)

    return workflow.compile()

# ================================================================
# 실행 함수들
# ================================================================

def run_headquarter_module9_evaluation(headquarter_id: str, period_id: int = 4):
    """Run the module 9 CL normalization workflow for a single headquarter.

    Args:
        headquarter_id: Headquarter to process.
        period_id: Evaluation period; defaults to 4 (year-end).

    Returns:
        The final workflow state returned by the compiled graph, or ``None``
        when the workflow (or the reporting that follows it) raised.
    """
    print(f"🚀 본부 단위 모듈 9 CL별 정규화 실행 시작: {headquarter_id} (period_id: {period_id})")

    # Initial state
    state = HeadquarterModule9AgentState(
        messages=[HumanMessage(content=f"본부 {headquarter_id}: CL별 정규화 시작")],
        headquarter_id=headquarter_id,
        period_id=period_id,
        headquarter_members=[],
        cl_groups={},
        normalized_scores=[],
        processed_count=0,
        failed_members=[]
    )

    # Build and run the graph
    headquarter_module9_graph = create_headquarter_module9_graph()

    try:
        result = headquarter_module9_graph.invoke(state)

        print("✅ 본부 단위 모듈 9 CL별 정규화 실행 완료!")
        print("📊 결과:")
        for message in result['messages']:
            print(f"  - {message.content}")

        if result.get('processed_count'):
            print(f"🎯 처리 완료: {result['processed_count']}명")

        # Reported independently of processed_count: when every update fails
        # the count is 0, and that is exactly when this list matters most.
        if result.get('failed_members'):
            print(f"❌ 실패한 직원: {result['failed_members']}")

        return result

    except Exception as e:
        # Top-level boundary: report the failure and signal it with None so
        # that batch callers can carry on with the next headquarter.
        print(f"❌ 본부 단위 모듈 9 CL별 정규화 실행 실패: {e}")
        return None

def run_all_headquarters_module9(period_id: int = 4):
    """Run module 9 for every headquarter that has manager scores.

    Headquarters are discovered with a query over headquarters, teams,
    employees and temp evaluations (non-NULL ``manager_score``), then
    processed one by one in ``headquarter_id`` order.

    NOTE(review): the discovery query does not filter by ``period_id``; the
    parameter is only forwarded to the per-headquarter run.

    Args:
        period_id: Evaluation period passed to each run; defaults to 4.

    Returns:
        Dict mapping each ``headquarter_id`` to its run result (``None`` for
        a run that failed).
    """
    # Look up every headquarter id
    with engine.connect() as conn:
        hq_query = text("""
            SELECT DISTINCT h.headquarter_id, h.headquarter_name
            FROM headquarters h
            JOIN teams t ON h.headquarter_id = t.headquarter_id
            JOIN employees e ON t.team_id = e.team_id
            JOIN temp_evaluations te ON e.emp_no = te.TempEvaluation_empNo
            WHERE te.manager_score IS NOT NULL
            ORDER BY h.headquarter_id
        """)
        hq_rows = conn.execute(hq_query).fetchall()

    print(f"🚀 전체 본부 모듈 9 CL별 정규화 실행: {len(hq_rows)}개 본부")

    outcomes = {}
    processed_total = 0
    failed_total = 0

    for hq_row in hq_rows:
        hq_id, hq_name = hq_row.headquarter_id, hq_row.headquarter_name

        print(f"\n{'='*50}")
        print(f"본부 {hq_id} ({hq_name}) 처리 중...")

        outcome = run_headquarter_module9_evaluation(hq_id, period_id)
        outcomes[hq_id] = outcome

        if not outcome:
            continue

        processed_total += outcome.get('processed_count', 0)
        failed_total += len(outcome.get('failed_members', []))

    completed = sum(1 for outcome in outcomes.values() if outcome is not None)

    print(f"\n🎯 전체 결과:")
    print(f"   처리된 본부: {completed}/{len(hq_rows)}")
    print(f"   처리된 인원: {processed_total}명")
    print(f"   실패한 인원: {failed_total}명")

    return outcomes

# ================================================================
# 테스트 및 디버깅 함수들
# ================================================================

def get_all_headquarters_with_data(period_id: int = 4) -> List[str]:
    """List the ids of all headquarters that have evaluation data.

    A headquarter qualifies when at least one of its employees has a temp
    evaluation with a non-NULL ``manager_score``. Ids are returned in
    ascending ``headquarter_id`` order.

    NOTE(review): ``period_id`` is accepted but not used by the query.
    """
    with engine.connect() as conn:
        hq_query = text("""
            SELECT DISTINCT h.headquarter_id
            FROM headquarters h
            JOIN teams t ON h.headquarter_id = t.headquarter_id
            JOIN employees e ON t.team_id = e.team_id
            JOIN temp_evaluations te ON e.emp_no = te.TempEvaluation_empNo
            WHERE te.manager_score IS NOT NULL
            ORDER BY h.headquarter_id
        """)
        hq_rows = conn.execute(hq_query).fetchall()
        return [hq_row.headquarter_id for hq_row in hq_rows]

def test_headquarter_module9(headquarter_id: str = None, period_id: int = 4):
    """Smoke-test the module 9 CL normalization on one headquarter.

    When ``headquarter_id`` is not supplied, the first headquarter reported
    by ``get_all_headquarters_with_data`` is used; if there is none, a
    message is printed and ``None`` is returned.

    Returns:
        The result of ``run_headquarter_module9_evaluation``, or ``None``
        when no headquarter is available.
    """
    if headquarter_id:
        return run_headquarter_module9_evaluation(headquarter_id, period_id)

    candidates = get_all_headquarters_with_data(period_id)
    if not candidates:
        print("❌ 테스트할 본부가 없습니다")
        return

    headquarter_id = candidates[0]
    print(f"🧪 테스트 본부 자동 선택: {headquarter_id}")
    return run_headquarter_module9_evaluation(headquarter_id, period_id)

# ================================================================
# 실행 예시
# ================================================================

if __name__ == "__main__":
    # Banner: readiness line, key features, then the available commands.
    print(
        "🚀 본부 단위 모듈 9 CL별 정규화 준비 완료!",
        "\n🔥 주요 특징:",
        "✅ 본부 내 CL별 정규화 (무조건 정규화 적용)",
        "✅ CL별 목표: 평균 3.5점, CL별 표준편차 차등",
        "✅ final_evaluation_reports.score, cl_reason 업데이트",
        "✅ 본부 단위 배치 처리",
        sep="\n",
    )

    print(
        "\n실행 명령어:",
        "1. run_headquarter_module9_evaluation('HQ001', 4)     # 단일 본부 실행",
        "2. run_all_headquarters_module9(4)                   # 전체 본부 일괄 실행",
        "3. test_headquarter_module9()                        # 테스트 실행",
        sep="\n",
    )

    # The automatic test runs on execution; comment out the call to disable it.
    test_headquarter_module9()

LLM Client initialized: gpt-4o-mini
🚀 본부 단위 모듈 9 CL별 정규화 준비 완료!

🔥 주요 특징:
✅ 본부 내 CL별 정규화 (무조건 정규화 적용)
✅ CL별 목표: 평균 3.5점, CL별 표준편차 차등
✅ final_evaluation_reports.score, cl_reason 업데이트
✅ 본부 단위 배치 처리

실행 명령어:
1. run_headquarter_module9_evaluation('HQ001', 4)     # 단일 본부 실행
2. run_all_headquarters_module9(4)                   # 전체 본부 일괄 실행
3. test_headquarter_module9()                        # 테스트 실행
🧪 테스트 본부 자동 선택: 1
🚀 본부 단위 모듈 9 CL별 정규화 실행 시작: 1 (period_id: 4)
🔍 본부 데이터 수집 시작: 1
   본부 내 직원 수: 30명
📊 본부 내 CL별 그룹화 시작...
   CL별 분포: CL3(10명), CL2(13명), CL1(7명)
🔄 본부 내 CL별 정규화 시작...

📊 CL1 정규화 처리:
   CL1 그룹 (7명) 정규화 처리:
     정규화 적용: 평균 3.10 → 3.5, 표준편차 0.22 → 1.4
     E016: 3.00 → 2.85 (본부 내 CL1 정규화 (Z-Score: -0.46))
     E030: 2.90 → 2.20 (본부 내 CL1 정규화 (Z-Score: -0.93))
     E011: 3.10 → 3.50 (본부 내 CL1 정규화 (Z-Score: 0.00))
     E021: 3.30 → 4.80 (본부 내 CL1 정규화 (Z-Score: 0.93))
     E006: 3.20 → 4.15 (본부 내 CL1 정규화 (Z-Score: 0.46))
     E026: 3.40 → 5.00 (본부 내 CL1 정규화 (Z-Score: 1.39))
     E031: 2.