# ========================================================================================
# JAX 1D-FDM + Newton inverse fitting (hout, hwi, hwo) - IMPROVED VERSION
# ========================================================================================
# 목적: 140/280/420W 각 전력 케이스별로 (hout, hwi, hwo) 3개 파라미터를 JAX로 피팅
# 
# 주요 개선사항:
#   1. Newton 수렴성 진단 기능 추가 (매 epoch마다 ||F||, ||dU|| 출력)
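#   2. Gradient 흐름 복구 (stop_gradient 제거 옵션) — note: the USE_STOP_GRADIENT toggle is only read in warm_start_update, which runs outside jax.value_and_grad, so it does not change the loss gradient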
#   3. 학습률 증가 (1e-8 → 1e-4)
#   4. 경계값 완화 (더 넓은 탐색 공간)
#   5. 파라미터별 gradient 크기 모니터링
#   6. 경계값 경고 시스템
# ========================================================================================

In [None]:
# -------------------------------------------
# 0) 라이브러리 로드
# -------------------------------------------
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
from jax import lax
import optax

# Double precision for the whole notebook. JAX defaults to float32
# (eps ~1.2e-7). With wall temperatures of a few hundred degrees, float32
# rounding alone likely leaves the Newton residual ||F|| above the 1e-4 / 1e-6
# thresholds checked later. It also makes exp(h) overflow for h > ~88 when
# initializing theta_raw. Must run before any JAX array is created.
jax.config.update("jax_enable_x64", True)

print(f"JAX version: {jax.__version__}")
print(f"JAX devices: {jax.devices()}")

In [None]:
# ------------------------------------------------------------------------------------
# 1) 데이터 로드
# ------------------------------------------------------------------------------------
def _load_json_array(path):
    """Load a JSON file and return its contents as a JAX array.

    The file is opened in a ``with`` block so the handle is always closed,
    including when parsing fails.
    """
    with open(path, encoding="utf-8") as fh:
        return jnp.array(json.load(fh))


tc_pos = _load_json_array('tc_pos.json')  # Thermocouple positions [m]
T_140 = _load_json_array('Temp_profile_140W.json')  # Temperature profile @ 140W [C]
T_280 = _load_json_array('Temp_profile_280W.json')  # Temperature profile @ 280W [C]
T_420 = _load_json_array('Temp_profile_420W.json')  # Temperature profile @ 420W [C]

powers = jnp.array([140.0, 280.0, 420.0])  # Applied power [W]
T_meas_cases = jnp.array([T_140, T_280, T_420])  # Measured temperature profiles [C], shape (3, n_TC)
num_cases, n_TC = T_meas_cases.shape

print(f"Number of cases: {num_cases}")
print(f"Number of thermocouples: {n_TC}")
print(f"TC positions [m]: {tc_pos}")

plt.figure(figsize=(10, 6))
for T_case, case_label in ((T_140, '140W'), (T_280, '280W'), (T_420, '420W')):
    plt.scatter(np.array(tc_pos), np.array(T_case), label=case_label, s=100, alpha=0.7)
plt.legend(fontsize=12)
plt.xlabel('z [m]', fontsize=12)
plt.ylabel('T [C]', fontsize=12)
plt.title('Measured Temperature Profiles', fontsize=14)
plt.grid(True, alpha=0.3)
plt.show()

In [None]:
# ------------------------------------------------------------------------------------
# 2) Geometry / Grid
# ------------------------------------------------------------------------------------
L = 430e-3  # Reactor length [m]
ID = 5.03e-3  # Inner diameter [m]
OD = 6.33e-3  # Outer diameter [m]

Ai = jnp.pi * ID**2 / 4  # Inner cross-sectional area [m2]
Ao = jnp.pi * OD**2 / 4  # Outer cross-sectional area [m2]
Aw = Ao - Ai  # Wall cross-sectional area [m2]
pri = jnp.pi * ID  # Inner perimeter [m]
pro = jnp.pi * OD  # Outer perimeter [m]

dz = 0.001  # Grid size [m]
# round() rather than int(): L/dz is a floating-point quotient and int()
# truncates. A ratio landing just below an integer (e.g. 0.3/0.1 ->
# 2.9999999999999996) would silently drop a node.
n_nodes = int(round(L / dz)) + 1  # Number of nodes
z = jnp.linspace(0, L, n_nodes)  # Node positions [m]
# The residual uses dz directly, so the spacing of z must equal dz.
if not np.isclose((n_nodes - 1) * dz, L):
    raise ValueError(
        f"L={L} m is not an integer multiple of dz={dz} m; "
        "the FDM residual assumes uniform node spacing dz"
    )

Awg = pri * dz  # Inner (wall-gas) surface area per grid cell [m2]
Aout = pro * dz  # Outer surface area per grid cell [m2]

# Nearest grid node for every thermocouple (assumes tc_pos is 1-D).
tc_idx = jnp.argmin(jnp.abs(z[None, :] - tc_pos[:, None]), axis=1)

print(f"Reactor length: {L*1000:.1f} mm")
print(f"Grid size: {dz*1000:.2f} mm")
print(f"Number of nodes: {n_nodes}")
print(f"TC indices: {tc_idx}")

In [None]:
# ------------------------------------------------------------------------------------
# 3) Material properties (He, KanthalD)
# ------------------------------------------------------------------------------------
path_he = '../data/He_property.csv'
path_kt = '../data/KanthalD_property.csv'

# Raw property tables (He and Kanthal D)
df_he, df_kt = (pd.read_csv(csv_path) for csv_path in (path_he, path_kt))
# Kanthal rows that actually carry a conductivity / heat-capacity value
df_kt_k = df_kt.dropna(subset=['k [W/m*K]'])
df_kt_cp = df_kt.dropna(subset=['Cp [kJ/kg*K]'])

Mw_he = 4.0026  # He molar weight [g/mol]

## Jax-friendly 1D interpolation
def interp1d(xq, x, y):
    """Linearly interpolate y(x) at the query points xq (thin wrapper over jnp.interp)."""
    query, knots, values = (jnp.asarray(arr) for arr in (xq, x, y))
    return jnp.interp(query, knots, values)

## Properties of He (simplified as constants)
def rho_he(T):
    """He mass density [kg/m3], a constant broadcast to the shape of T."""
    shape_ref = jnp.ones_like(T)
    return 0.16352 * shape_ref

def cp_he(T):
    """He molar heat capacity [J/mol/K], a constant broadcast to the shape of T."""
    shape_ref = jnp.ones_like(T)
    return 20.786 * shape_ref

def k_he(T):
    """He thermal conductivity [W/m/K], a constant broadcast to the shape of T."""
    shape_ref = jnp.ones_like(T)
    return 0.15531 * shape_ref

## Properties of kanthalD (simplified as constants)
def rho_kt(T):
    """Kanthal D mass density [kg/m3], a constant broadcast to the shape of T."""
    shape_ref = jnp.ones_like(T)
    return 7250 * shape_ref

def cp_kt(T):
    """Kanthal D specific heat [J/kg/K] (0.63 kJ/kg/K), a constant broadcast to the shape of T."""
    shape_ref = jnp.ones_like(T)
    return 0.63 * 1000 * shape_ref

def k_kt(T):
    """Kanthal D thermal conductivity [W/m/K], a constant broadcast to the shape of T."""
    shape_ref = jnp.ones_like(T)
    return 11 * shape_ref

## heat transfer coefficient of inner fluid
def h_wg(Tg):
    """Wall-to-gas heat transfer coefficient [W/m2/K] from a constant Nusselt number.

    h = Nu * k_he(Tg) / ID. Nu = 4.36 matches the textbook laminar,
    fully-developed, uniform-heat-flux circular-tube value (presumably the
    intent; confirm).
    """
    nusselt = 4.36
    conductivity = k_he(Tg)
    return nusselt * conductivity / ID

print("Material properties loaded successfully")

In [None]:
# ------------------------------------------------------------------------------------
# 4) Feed conditions / constant
# ------------------------------------------------------------------------------------
P = 101325.0  # Pressure [Pa]
Tamb = 25.0  # Ambient temperature [C]

Fv_std = 50.0  # Volumetric feed flow rate at standard conditions [mL/min]
# mL/min -> m3/s (x1e-6 / 60), then x density -> mass flow [kg/s]
Fw = Fv_std * 1e-6 / 60 * rho_he(0)
# [kg/s] / [g/mol] = [kmol/s]; x1000 -> molar flow [mol/s]
F = Fw / Mw_he * 1000.0

summary_lines = (
    f"Ambient temperature: {Tamb} °C",
    f"Feed flow rate: {Fv_std} mL/min",
    f"Molar flow rate: {F:.6f} mol/s",
)
for line in summary_lines:
    print(line)

In [None]:
# ------------------------------------------------------------------------------------
# 5) Residual function
#   - coeffs=(hout, hwi, hwo) + Pw
# ------------------------------------------------------------------------------------
def residual(U, coeffs, Pw):
    """Steady-state FDM residual of the coupled 1D wall/gas energy balance.

    Args:
        U: State vector of length 2*n_nodes. U[:n_nodes] is the wall
            temperature Tw and U[n_nodes:] the gas temperature Tg [C].
        coeffs: (hout, hwi, hwo). hout multiplies the outer-surface loss of
            interior wall nodes. hwi / hwo multiply the Robin loss through the
            wall cross-section at the first (z=0) / last (z=L) node.
        Pw: Electrical power dissipated uniformly along the wall [W].

    Returns:
        Residual vector of length 2*n_nodes: wall rows first, then gas rows.
        Rows carry mixed units: wall boundary rows [W], wall and gas interior
        rows [W/m], gas boundary rows [K].
    """
    hout, hwi, hwo = coeffs

    Tw = U[:n_nodes]
    Tg = U[n_nodes:]

    kw = k_kt(Tw)
    cpg = cp_he(Tg)
    hwg = h_wg(Tg)

    ## Joule heating
    Qelec = Pw * dz / L # heat deposited in one cell of length dz [W] (i.e. Pw/L per metre)

    ## Residual vectors, filled in-place (functionally) below
    rw = jnp.zeros((n_nodes,))
    rg = jnp.zeros((n_nodes,))

    ## Wall
    ### Robin BC for Tw: heat loss through the end cross-sections (area Aw) plus
    ### conduction to the neighbouring node; end nodes get no Joule heating term.
    rw = rw.at[0].set(hwi * Aw * (Tw[0] - Tamb) + kw[0] * Aw * (Tw[0] - Tw[1]) / dz)
    rw = rw.at[-1].set(hwo * Aw * (Tw[-1] - Tamb) + kw[-1] * Aw * (Tw[-1] - Tw[-2]) / dz)

    ### Interior conduction: face conductivity = arithmetic mean of neighbours,
    ### wflux is the axial heat flow across each face [W] (length n_nodes-1),
    ### Qcond its divergence per unit length [W/m] (length n_nodes-2).
    kw_half = 0.5 * (kw[1:] + kw[:-1])
    wflux = kw_half * Aw * (Tw[1:] - Tw[:-1]) / dz
    Qcond = (wflux[1:] - wflux[:-1]) / dz

    ### Wall-gas exchange and outer-surface loss per interior cell [W]
    Qwg = hwg[1:-1] * Awg * (Tw[1:-1] - Tg[1:-1])
    Qout = hout * Aout * (Tw[1:-1] - Tamb)

    ### Wall energy balance per unit length [W/m]
    rw = rw.at[1:-1].set(Qcond + (Qelec/dz) - (Qwg/dz) - (Qout/dz))

    ## Gas
    ### Assume no temperature difference between the gas and the wall at the front and back
    rg = rg.at[0].set(Tg[0] - Tw[0])
    rg = rg.at[-1].set(Tg[-1] - Tw[-1])

    ### Gas energy balance [W/m]: advection uses a backward difference
    ### (upwind if the gas flows toward +z; presumably the intent, confirm).
    ### F [mol/s] * cpg [J/mol/K] gives the heat-capacity flow rate [W/K].
    gflux = F * cpg[1:-1] * (Tg[1:-1] - Tg[:-2]) / dz
    rg = rg.at[1:-1].set((Qwg/dz) - gflux)

    return jnp.concatenate([rw, rg])

print("Residual function defined")

In [None]:
# ------------------------------------------------------------------------------------
# 6) Newton solver (JAX, scan)
# ------------------------------------------------------------------------------------
def newton_step(residual_fn, damping=1.0):
    """Build one damped Newton iteration usable as a ``lax.scan`` body.

    Args:
        residual_fn: Callable U -> r(U), with r the same length as U.
        damping: Step-length factor applied to the Newton update.

    Returns:
        ``step(U, _) -> (U_new, (res_norm, step_norm))``. ``res_norm`` is
        ||r(U)|| at the *incoming* iterate (before the update). ``step_norm``
        is ||dU|| of the undamped Newton direction.
    """
    def step(U, _):
        # Named `r` (not `F`) so it does not shadow the module-level molar flow F.
        r = residual_fn(U)
        J = jax.jacfwd(residual_fn)(U)  # dense Jacobian dr/dU
        dU = jnp.linalg.solve(J, -r)    # solve J dU = -r
        U_new = U + damping * dU

        # The norms are diagnostics only. The derivative of ||x|| is x/||x||,
        # which is 0/0 = NaN once the iteration hits an exact zero residual or
        # step. stop_gradient keeps these terms out of any backward pass.
        res_norm = jnp.linalg.norm(lax.stop_gradient(r))
        step_norm = jnp.linalg.norm(lax.stop_gradient(dU))

        return U_new, (res_norm, step_norm)
    return step

def newton_solve(residual_fn, iters=20, damping=1.0):
    """Return a solver ``U0 -> (U_final, res_hist, step_hist)`` doing a fixed number of Newton iterations.

    The iterations run inside ``lax.scan``. ``res_hist`` and ``step_hist``
    each have length ``iters`` (one entry per iteration).
    """
    body = newton_step(residual_fn, damping=damping)

    def solve(U0):
        U_final, histories = lax.scan(body, U0, xs=None, length=iters)
        res_hist, step_hist = histories
        return U_final, res_hist, step_hist

    return solve

print("Newton solver defined")

In [None]:
# ------------------------------------------------------------------------------------
# 7) Case solver + prediction
# ------------------------------------------------------------------------------------
def solve_case(coeffs, Pw, U0, iters=20, damping=1.0):
    """Solve the steady state for one power case, starting Newton from U0.

    Returns ``(U_star, res_hist, step_hist)`` from the fixed-iteration solver.
    """
    def case_residual(U):
        return residual(U, coeffs, Pw)

    solver = newton_solve(case_residual, iters=iters, damping=damping)
    U_star, res_hist, step_hist = solver(U0)
    return U_star, res_hist, step_hist

def predict_TC(U):
    """Return the wall temperature sampled at the thermocouple nodes ``tc_idx``."""
    wall_T = U[:n_nodes]
    sampled = wall_T[tc_idx]
    return sampled

print("Case solver defined")

In [None]:
# ------------------------------------------------------------------------------------
# 8) Fitting setup with IMPROVED diagnostics
#   - theta_raw shape: (num_case, 3), softplus로 양수화
#   - warm-start: U0_case를 step마다 업데이트
#   - OPTION: stop_gradient 제거하여 gradient 흐름 복구
# ------------------------------------------------------------------------------------

# Configuration options
USE_STOP_GRADIENT = False  # True: original behaviour, False: no stop_gradient (recommended); only read inside warm_start_update
LEARNING_RATE = 1e-4       # original: 1e-8, suggested: 1e-4 ~ 1e-3
WIDE_BOUNDS = True         # True: wide clip bounds for h, False: original bounds

def theta_phys(theta_raw):
    """Map unconstrained parameters to physically valid h values.

    softplus(theta_raw) + 1e-6 keeps h strictly positive. h is then clipped to
    the box selected by WIDE_BOUNDS, with columns (hout, hwi, hwo). Entries
    that get clipped receive zero gradient.
    """
    h = jax.nn.softplus(theta_raw) + 1e-6

    if not WIDE_BOUNDS:
        # Original (narrow) bounds
        return jnp.clip(h, jnp.array([1.0, 1.0, 1.0]), jnp.array([100, 10000, 10000]))

    # Wide bounds: larger search space
    return jnp.clip(h, jnp.array([0.1, 0.1, 0.1]), jnp.array([500, 50000, 50000]))

def softplus_inv(h):
    """Inverse of softplus, used to initialize theta_raw from physical h values.

    softplus^{-1}(h) = log(exp(h) - 1) = h + log(1 - exp(-h)).

    The second form is used because exp(h) overflows for large h (h > ~88 in
    float32, JAX's default dtype). For h = 100 the naive form returns inf,
    which theta_phys then clips to the upper bound with zero gradient.
    """
    h = jnp.asarray(h)
    return h + jnp.log(-jnp.expm1(-h))

# Initial parameter guess, one row per power case: (hout, hwi, hwo)
h_init = jnp.tile(jnp.array([30., 100., 100.]), (3, 1))

# Initial Newton state: wall and gas both at ambient temperature
Tw0 = Tamb * jnp.ones(n_nodes)
Tg0 = Tamb * jnp.ones(n_nodes)
U0 = jnp.concatenate([Tw0, Tg0])

U0_cases = jnp.stack([U0] * 3)  # (num_cases, 2*n_nodes)

# Trainable (unconstrained) parameters, shape (3, 3)
theta_raw = softplus_inv(h_init)

def case_loss(coeffs, Pw, U0, T_meas):
    """Thermocouple MSE for one power case, plus Newton diagnostics.

    Returns ``(loss, U_star, final_res, final_step)``. ``final_res`` and
    ``final_step`` are the last entries of the Newton residual-norm and
    step-norm histories.
    """
    U_star, res_hist, step_hist = solve_case(coeffs, Pw, U0, iters=20, damping=1.0)
    err = predict_TC(U_star) - T_meas
    return jnp.mean(err**2), U_star, res_hist[-1], step_hist[-1]

def total_loss(theta_raw, U0_cases, T_meas_cases, powers):
    """Sum of the per-case thermocouple MSE losses, vmapped over the power cases."""
    theta = theta_phys(theta_raw)

    def loss_only(theta_k, U0_k, T_k, Pw_k):
        return case_loss(theta_k, Pw_k, U0_k, T_k)[0]

    per_case = jax.vmap(loss_only)(theta, U0_cases, T_meas_cases, powers)
    return jnp.sum(per_case)

def warm_start_update(theta_raw, U0_cases, T_meas_cases, powers):
    """Solve every case and return the converged states as the next Newton initial guesses.

    USE_STOP_GRADIENT only decides whether the returned states are wrapped in
    ``lax.stop_gradient``. In this notebook the function is called outside
    ``jax.value_and_grad``, so the toggle does not affect the loss gradient.
    """
    theta = theta_phys(theta_raw)

    def solved_state(theta_k, U0_k, T_k, Pw_k):
        U_star = case_loss(theta_k, Pw_k, U0_k, T_k)[1]
        return lax.stop_gradient(U_star) if USE_STOP_GRADIENT else U_star

    return jax.vmap(solved_state)(theta, U0_cases, T_meas_cases, powers)

def get_newton_convergence(theta_raw, U0_cases, T_meas_cases, powers):
    """Return ``(res_norms, step_norms)``: the last Newton ||F|| and ||dU|| for every case."""
    theta = theta_phys(theta_raw)

    def last_norms(theta_k, U0_k, T_k, Pw_k):
        *_, res_last, step_last = case_loss(theta_k, Pw_k, U0_k, T_k)
        return res_last, step_last

    return jax.vmap(last_norms)(theta, U0_cases, T_meas_cases, powers)

_banner = "=" * 80
print(_banner)
print("Fitting setup complete")
print(f"  - stop_gradient: {USE_STOP_GRADIENT} (False = 개선 버전)")
print(f"  - Learning rate: {LEARNING_RATE}")
print(f"  - Wide bounds: {WIDE_BOUNDS}")
# Report the clip box theta_phys will use
print(
    "  - Bounds: hout[0.1-500], hwi[0.1-50000], hwo[0.1-50000]"
    if WIDE_BOUNDS
    else "  - Bounds: hout[1-100], hwi[1-10000], hwo[1-10000]"
)
print(_banner)

In [None]:
# ------------------------------------------------------------------------------------
# 9) Training Loop with ENHANCED diagnostics
# ------------------------------------------------------------------------------------
# Number of optimization epochs (kept small for diagnostic runs)
N_EPOCHS = 50

opt = optax.adam(learning_rate=LEARNING_RATE)
opt_state = opt.init(theta_raw)

# Compile once. Un-jitted, every call below would run op-by-op and re-trace
# the 20-iteration Newton scan (with its dense Jacobians) on every epoch.
loss_and_grad = jax.jit(jax.value_and_grad(total_loss))
newton_convergence_fn = jax.jit(get_newton_convergence)
warm_start_fn = jax.jit(warm_start_update)

# Tolerances used to flag parameters sitting (almost) on the clip bounds of theta_phys
PARAM_NAMES = ("hout", "hwi", "hwo")
if WIDE_BOUNDS:
    LOWER_FLAG, UPPER_FLAG = (0.11, 0.11, 0.11), (499, 49999, 49999)
else:
    LOWER_FLAG, UPPER_FLAG = (1.01, 1.01, 1.01), (99.9, 9999, 9999)

print("="*80)
print("TRAINING START: Enhanced Newton Convergence Diagnostics")
print("="*80)
print(f"Optimizer: Adam (lr={LEARNING_RATE})")
print(f"Epochs: {N_EPOCHS} (매 epoch 출력)")
print("="*80)

# Training history. Entry i holds values evaluated at the parameters used in epoch i.
history = {
    'loss': [],
    'params': [],
    'grads': [],
    'newton_res': [],
    'newton_step': []
}

for step in range(N_EPOCHS):
    # Loss, gradient, Newton diagnostics and converged states are all
    # evaluated at the SAME parameters, so each logged epoch is self-consistent.
    loss, grads = loss_and_grad(theta_raw, U0_cases, T_meas_cases, powers)
    res_norms, step_norms = newton_convergence_fn(theta_raw, U0_cases, T_meas_cases, powers)
    U_star_cases = warm_start_fn(theta_raw, U0_cases, T_meas_cases, powers)
    theta = theta_phys(theta_raw)

    history['loss'].append(float(loss))
    history['params'].append(np.array(theta))
    history['grads'].append(np.array(grads))
    history['newton_res'].append(np.array(res_norms))
    history['newton_step'].append(np.array(step_norms))

    # Detailed per-epoch report
    print(f'\n{"="*80}')
    print(f'EPOCH {step:04d} | Total Loss = {float(loss):.6e}')
    print(f'{"="*80}')

    for k, Pw_k in enumerate(map(float, powers)):
        hout_k, hwi_k, hwo_k = map(float, theta[k])
        grad_hout, grad_hwi, grad_hwo = map(float, grads[k])
        res_k, step_k = float(res_norms[k]), float(step_norms[k])

        print(f'\n  케이스 {k+1} ({Pw_k}W):')
        print(f'    파라미터:   hout={hout_k:.3f}, hwi={hwi_k:.3f}, hwo={hwo_k:.3f}')
        print(f'    Gradient:   ∇hout={grad_hout:.3e}, ∇hwi={grad_hwi:.3e}, ∇hwo={grad_hwo:.3e}')
        print(f'    Newton:     ||F||={res_k:.3e}, ||dU||={step_k:.3e}')

        # Bound check: clipped entries get zero gradient in theta_phys
        at_bound = [
            name
            for name, value, lo, hi in zip(PARAM_NAMES, (hout_k, hwi_k, hwo_k), LOWER_FLAG, UPPER_FLAG)
            if value <= lo or value >= hi
        ]
        if at_bound:
            print(f'    ⚠️  WARNING: 경계값에 도달 → {", ".join(at_bound)}')

        # Newton convergence warning
        if res_k > 1e-4:
            print(f'    ⚠️  WARNING: Newton 수렴 불량 (||F|| > 1e-4)')

        # Vanishing-gradient warning
        if max(abs(grad_hout), abs(grad_hwi), abs(grad_hwo)) < 1e-15:
            print(f'    ⚠️  WARNING: Gradient 소실 (max|∇| < 1e-15)')

    print(f'{"="*80}')

    # Adam step, then warm-start the next epoch's Newton solves from this epoch's solutions
    updates, opt_state = opt.update(grads, opt_state)
    theta_raw = optax.apply_updates(theta_raw, updates)
    U0_cases = U_star_cases

print(f'\n{"="*80}')
print("TRAINING COMPLETE")
print(f'{"="*80}')

In [None]:
# ------------------------------------------------------------------------------------
# 10) Training history visualization
# ------------------------------------------------------------------------------------
# 2x2 summary of the training history recorded by the training loop: total
# loss, hout per case, final Newton residual norm per case, and gradient norm
# (w.r.t. theta_raw) per case.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Loss history (log scale)
axes[0, 0].semilogy(history['loss'])
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Total Loss')
axes[0, 0].set_title('Loss History')
axes[0, 0].grid(True, alpha=0.3)

# Parameter evolution: params_array is (n_epochs, num_cases, 3); column 0 is hout
params_array = np.array(history['params'])
for k in range(3):
    axes[0, 1].plot(params_array[:, k, 0], label=f'Case {k+1} hout', marker='o', markersize=3)
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('hout')
axes[0, 1].set_title('hout Evolution')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Newton residual norm (last Newton iteration of each epoch), per case
newton_res_array = np.array(history['newton_res'])
for k in range(3):
    axes[1, 0].semilogy(newton_res_array[:, k], label=f'Case {k+1}', marker='o', markersize=3)
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('||F||')
axes[1, 0].set_title('Newton Residual Norm')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# Gradient magnitude: L2 norm over (hout, hwi, hwo) of each case's gradient row
grads_array = np.array(history['grads'])
for k in range(3):
    grad_norm = np.linalg.norm(grads_array[:, k, :], axis=1)
    axes[1, 1].semilogy(grad_norm, label=f'Case {k+1}', marker='o', markersize=3)
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('||∇L||')
axes[1, 1].set_title('Gradient Magnitude')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# ------------------------------------------------------------------------------------
# 11) Results visualization
# ------------------------------------------------------------------------------------
# Re-solve each case at the final parameters (Newton started from the
# warm-start states in U0_cases), report fit quality, and plot profiles plus
# the Newton convergence history.
theta = theta_phys(theta_raw)

print("\n" + "="*80)
print("FINAL RESULTS")
print("="*80)

for k, Pw_k in enumerate([140., 280., 420.]):
    coeffs = theta[k]  # (hout, hwi, hwo) for this case
    U_star, res_hist, step_hist = solve_case(coeffs, Pw_k, U0_cases[k], iters=20, damping=1.0)
    Tw = U_star[:n_nodes]
    Tg = U_star[n_nodes:]
    Tw_tc = predict_TC(U_star)

    # res_hist[-1] is ||F|| at the input of the last Newton iteration (before its update)
    print(f"\n케이스 {k+1} ({Pw_k}W):")
    print(f"  최적 파라미터: hout={float(coeffs[0]):.3f}, hwi={float(coeffs[1]):.3f}, hwo={float(coeffs[2]):.3f}")
    print(f"  Newton 수렴:   ||F||={float(res_hist[-1]):.3e}, ||dU||={float(step_hist[-1]):.3e}")
    print(f"  TC 위치 MSE:   {float(jnp.mean((Tw_tc - T_meas_cases[k])**2)):.3e}")

    # Temperature profile plot
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    # Left: model wall/gas profiles vs measured TC temperatures
    ax1.plot(np.array(z), np.array(Tw), label='Tw (model)', linewidth=2)
    ax1.plot(np.array(z), np.array(Tg), label='Tg (model)', linewidth=2, linestyle='--')
    ax1.scatter(np.array(tc_pos), np.array(T_meas_cases[k]), label='TC (measured)',
                s=100, c='red', marker='o', zorder=5)
    ax1.scatter(np.array(tc_pos), np.array(Tw_tc), label='Tw @ TC',
                s=80, c='blue', marker='x', linewidths=2, zorder=5)
    ax1.set_title(f'{Pw_k}W | hout={float(coeffs[0]):.2f}, hwi={float(coeffs[1]):.2f}, hwo={float(coeffs[2]):.2f}',
                  fontsize=12)
    ax1.legend(fontsize=10)
    ax1.set_xlabel('z [m]', fontsize=11)
    ax1.set_ylabel('T [°C]', fontsize=11)
    ax1.grid(True, alpha=0.3)

    # Right: Newton convergence per iteration, with a 1e-6 reference line
    ax2.semilogy(np.array(res_hist), label='||F|| (residual)', marker='o', markersize=4)
    ax2.semilogy(np.array(step_hist), label='||dU|| (step)', marker='s', markersize=4)
    ax2.axhline(1e-6, color='r', linestyle='--', alpha=0.5, label='Target (1e-6)')
    ax2.set_title(f'Newton Convergence History ({Pw_k}W)', fontsize=12)
    ax2.legend(fontsize=10)
    ax2.set_xlabel('Newton Iteration', fontsize=11)
    ax2.set_ylabel('Norm', fontsize=11)
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

print("\n" + "="*80)

In [None]:
# ------------------------------------------------------------------------------------
# 12) Diagnostic summary
# ------------------------------------------------------------------------------------
_rule = "=" * 80
loss_hist = history['loss']
loss_ratio = loss_hist[-1] / loss_hist[0]  # final / initial total loss

print("\n" + _rule)
print("진단 요약 (DIAGNOSTIC SUMMARY)")
print(_rule)

# 1. Training configuration
print("\n1. 학습 설정:")
print(f"   - stop_gradient 사용: {USE_STOP_GRADIENT}")
print(f"   - Learning rate: {LEARNING_RATE}")
bounds_label = 'Wide (0.1-50000)' if WIDE_BOUNDS else 'Narrow (1-10000)'
print(f"   - 경계값: {bounds_label}")

# 2. Loss reduction
print("\n2. 최종 손실:")
print(f"   - Initial loss: {loss_hist[0]:.6e}")
print(f"   - Final loss:   {loss_hist[-1]:.6e}")
print(f"   - Reduction:    {(1 - loss_ratio)*100:.2f}%")

# 3. Parameter change between the first and last recorded epochs
print("\n3. 파라미터 변화:")
first_params, last_params = history['params'][0], history['params'][-1]
for k in range(3):
    before, after = first_params[k], last_params[k]
    delta = np.abs(after - before)
    print(f"   케이스 {k+1} ({powers[k]}W):")
    for j, tag in enumerate(("hout: ", "hwi:  ", "hwo:  ")):
        print(f"     {tag}{before[j]:.3f} → {after[j]:.3f} (변화: {delta[j]:.3f})")

# 4. Newton convergence at the last recorded epoch
print("\n4. Newton 수렴 상태:")
final_newton_res = history['newton_res'][-1]
for k in range(3):
    r_k = final_newton_res[k]
    if r_k < 1e-6:
        status = "✅ 양호"
    elif r_k < 1e-4:
        status = "⚠️  주의"
    else:
        status = "❌ 불량"
    print(f"   케이스 {k+1}: ||F|| = {r_k:.3e} {status}")

# 5. Gradient magnitude at the last recorded epoch
print("\n5. Gradient 상태:")
final_grads = history['grads'][-1]
for k in range(3):
    grad_norm = np.linalg.norm(final_grads[k])
    if grad_norm > 1e-10:
        status = "✅ 정상"
    elif grad_norm > 1e-15:
        status = "⚠️  약함"
    else:
        status = "❌ 소실"
    print(f"   케이스 {k+1}: ||∇|| = {grad_norm:.3e} {status}")

print("\n" + _rule)
print("권장 사항:")
print(_rule)

if loss_ratio > 0.9:
    print("⚠️  손실 감소가 부족합니다 (<10%).")
    print("   → Learning rate를 증가시켜보세요 (현재: {:.0e} → 권장: {:.0e})".format(LEARNING_RATE, LEARNING_RATE*10))
    if USE_STOP_GRADIENT:
        print("   → stop_gradient를 False로 설정하세요")

if any(final_newton_res > 1e-4):
    print("⚠️  Newton solver 수렴이 불량합니다.")
    print("   → damping factor를 0.5~0.8로 줄여보세요")
    print("   → Newton iteration 횟수를 늘려보세요 (현재: 20 → 권장: 30)")

if np.max([np.linalg.norm(g) for g in final_grads]) < 1e-15:
    print("⚠️  Gradient가 소실되었습니다.")
    print("   → stop_gradient를 False로 설정하세요")
    print("   → Learning rate를 증가시켜보세요")

if loss_ratio < 0.5:
    print("✅ 학습이 양호하게 진행되었습니다 (>50% 손실 감소).")
    print("   → epoch 수를 늘려서 더 학습해보세요 (현재: 50 → 권장: 200~500)")

print("\n" + _rule)