In [7]:
import plotly.graph_objects as go
import os
import json
import glob
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
import pandas as pd

from Server.model_lstm import LSTMRegressor

ROUND = 1
ROUND_DIR = os.path.join("Rounds", f"round_{ROUND:04d}")

GLOBAL_JSON = os.path.join(ROUND_DIR, "global.json")
GLOBAL_PT   = os.path.join(ROUND_DIR, "global.pt")
UPDATES_DIR = os.path.join(ROUND_DIR, "updates")

assert os.path.exists(GLOBAL_JSON), f"missing: {GLOBAL_JSON}"
assert os.path.exists(GLOBAL_PT), f"missing: {GLOBAL_PT}"

device = "cpu"


In [2]:
with open(GLOBAL_JSON, "r", encoding="utf-8") as f:
    meta = json.load(f)

cfg = meta["config"]

global_model = LSTMRegressor(
    input_size=cfg["input_size"],
    hidden_size=cfg["hidden_size"],
    num_layers=cfg["num_layers"],
    output_size=cfg["output_size"],
    dropout=cfg.get("dropout", 0.0),
).to(device)

global_sd = torch.load(GLOBAL_PT, map_location=device)
global_model.load_state_dict(global_sd, strict=True)

cfg


{'input_size': 1,
 'hidden_size': 64,
 'num_layers': 1,
 'output_size': 1,
 'dropout': 0.0,
 'seq_len': 10}

In [3]:
update_jsons = sorted(glob.glob(os.path.join(UPDATES_DIR, "client_*.json")))
print("found update metas:", len(update_jsons))
update_jsons[:5]


found update metas: 0


[]

In [4]:
def load_client_update(meta_path: str):
    with open(meta_path, "r", encoding="utf-8") as f:
        m = json.load(f)
    pt_path = m["weights_path"]
    if not os.path.exists(pt_path):
        raise FileNotFoundError(pt_path)
    sd = torch.load(pt_path, map_location=device)
    return m, sd

client_updates = []
for mp in update_jsons:
    m, sd = load_client_update(mp)
    client_updates.append((m, sd))

[(m["client_id"], m.get("local_loss", None), m["n_samples"]) for m, _ in client_updates]


[]

In [5]:
def state_dict_diff_stats(base_sd: dict, new_sd: dict):
    keys = list(base_sd.keys())
    per_key = []
    total_sq = 0.0
    total_n  = 0

    for k in keys:
        b = base_sd[k].detach().cpu().float()
        n = new_sd[k].detach().cpu().float()
        d = (n - b).reshape(-1)
        sq = float((d*d).sum().item())
        nn = d.numel()
        total_sq += sq
        total_n  += nn

        per_key.append({
            "key": k,
            "l2": float(np.sqrt(sq)),
            "rmse": float(np.sqrt(sq / max(nn, 1))),
            "max_abs": float(d.abs().max().item()) if nn > 0 else 0.0,
            "numel": nn
        })

    total_l2 = float(np.sqrt(total_sq))
    total_rmse = float(np.sqrt(total_sq / max(total_n, 1)))
    per_key_sorted = sorted(per_key, key=lambda x: x["l2"], reverse=True)

    return {
        "total_l2": total_l2,
        "total_rmse": total_rmse,
        "per_key_sorted": per_key_sorted
    }


In [6]:
summaries = []
for meta_c, sd_c in client_updates:
    stats = state_dict_diff_stats(global_sd, sd_c)
    summaries.append({
        "client_id": meta_c["client_id"],
        "n_samples": meta_c["n_samples"],
        "local_loss": meta_c.get("local_loss", None),
        "total_l2": stats["total_l2"],
        "total_rmse": stats["total_rmse"],
        "top_keys": stats["per_key_sorted"][:5],
    })

summaries


[]

In [7]:
for s in summaries:
    print(f"client {s['client_id']} | n={s['n_samples']} | loss={s['local_loss']} | ΔL2={s['total_l2']:.6f} | ΔRMSE={s['total_rmse']:.6f}")
    for tk in s["top_keys"]:
        print("  ", tk["key"], "| l2=", f"{tk['l2']:.6f}", "| rmse=", f"{tk['rmse']:.6f}", "| max_abs=", f"{tk['max_abs']:.6f}")
    print()


In [8]:
def flatten_params(sd: dict):
    arrs = []
    for k, v in sd.items():
        arrs.append(v.detach().cpu().float().reshape(-1))
    return torch.cat(arrs).numpy()

global_vec = flatten_params(global_sd)

# 비교할 client_id 선택
target_client_id = summaries[0]["client_id"] if len(summaries) > 0 else None
target_client_sd = None
for m, sd in client_updates:
    if m["client_id"] == target_client_id:
        target_client_sd = sd
        break

if target_client_sd is None:
    print("no client updates to plot")
else:
    client_vec = flatten_params(target_client_sd)

    plt.figure(figsize=(10,4))
    plt.hist(global_vec, bins=60, alpha=0.6, density=True, label="global")
    plt.hist(client_vec, bins=60, alpha=0.6, density=True, label=f"client {target_client_id}")
    plt.title("Parameter distribution (all params flattened)")
    plt.legend()
    plt.show()

    # 차이 분포도
    diff_vec = client_vec - global_vec
    plt.figure(figsize=(10,4))
    plt.hist(diff_vec, bins=60, alpha=0.8, density=True)
    plt.title(f"Parameter difference distribution (client {target_client_id} - global)")
    plt.show()


no client updates to plot


In [9]:
if target_client_sd is None:
    pass
else:
    stats = state_dict_diff_stats(global_sd, target_client_sd)
    top = stats["per_key_sorted"][:15]
    keys = [t["key"] for t in top][::-1]
    vals = [t["l2"] for t in top][::-1]

    plt.figure(figsize=(10,6))
    plt.barh(keys, vals)
    plt.title(f"Top-15 parameter tensors by L2 change (client {target_client_id})")
    plt.tight_layout()
    plt.show()


In [10]:
TEST_CSV = r"C:\Users\admin\OneDrive - 중앙대학교\Federated Learning\csv\Global Model Data.csv"
FEATURE_COLS = ["year"]          # 네가 학습에 쓴 feature_cols와 동일하게
TARGET_COL = "chloride"          # 네가 학습에 쓴 target_col과 동일하게
SEQ_LEN = cfg["seq_len"]         # global.json에 저장된 seq_len

df = pd.read_csv(TEST_CSV)

# 결측 제거 (최소 안전장치)
df = df.dropna(subset=FEATURE_COLS + [TARGET_COL]).reset_index(drop=True)

features = df[FEATURE_COLS].to_numpy(dtype=np.float32)
targets = df[TARGET_COL].to_numpy(dtype=np.float32)

def make_windows(features, targets, seq_len):
    if targets.ndim == 1:
        targets = targets.reshape(-1, 1)

    N, F = features.shape
    M = N - seq_len
    if M <= 0:
        raise ValueError(f"Not enough rows: N={N}, seq_len={seq_len}")

    X = np.zeros((M, seq_len, F), dtype=np.float32)
    y = np.zeros((M, 1), dtype=np.float32)

    for i in range(M):
        X[i] = features[i:i+seq_len]
        y[i] = targets[i+seq_len]
    return X, y

X, y = make_windows(features, targets, seq_len=SEQ_LEN)

X_t = torch.from_numpy(X).to(device)
y_t = torch.from_numpy(y).to(device)

global_model.eval()
with torch.no_grad():
    pred_global = global_model(X_t).detach().cpu().numpy().reshape(-1)

if target_client_sd is None:
    print("no client model to compare")
else:
    client_model = LSTMRegressor(
        input_size=cfg["input_size"],
        hidden_size=cfg["hidden_size"],
        num_layers=cfg["num_layers"],
        output_size=cfg["output_size"],
        dropout=cfg.get("dropout", 0.0),
    ).to(device)
    client_model.load_state_dict(target_client_sd, strict=True)
    client_model.eval()

    with torch.no_grad():
        pred_client = client_model(X_t).detach().cpu().numpy().reshape(-1)

    y_true = y.reshape(-1)

    plt.figure(figsize=(10,4))
    plt.plot(y_true, label="true")
    plt.plot(pred_global, label="global_pred")
    plt.plot(pred_client, label=f"client{target_client_id}_pred")
    plt.title("Predictions on same test windowed data")
    plt.legend()
    plt.tight_layout()
    plt.show()


no client model to compare


In [3]:
# ===== [GLOBAL MODEL CHECK] =====
print("\n[CHECK] Global model summary")

total_params = 0
for name, p in global_model.named_parameters():
    mean = p.data.mean().item()
    std  = p.data.std().item()
    print(f"{name:35s} | shape={tuple(p.shape)} | mean={mean:+.4e} | std={std:+.4e}")
    total_params += p.numel()

print(f"\n[CHECK] Total parameters: {total_params}")

# 파라미터 하나 찍어보기 (완전 랜덤인지 확인용)
for name, p in global_model.named_parameters():
    print(f"\n[CHECK] Sample value from {name}: {p.view(-1)[0].item():+.6f}")
    break
# ===== [END CHECK] =====



[CHECK] Global model summary
lstm.weight_ih_l0                   | shape=(256, 1) | mean=+1.3954e-03 | std=+7.4295e-02
lstm.weight_hh_l0                   | shape=(256, 64) | mean=-2.9717e-04 | std=+7.3733e-02
lstm.bias_ih_l0                     | shape=(256,) | mean=-8.8668e-04 | std=+7.5714e-02
lstm.bias_hh_l0                     | shape=(256,) | mean=-8.1919e-03 | std=+7.3768e-02
fc.weight                           | shape=(1, 64) | mean=+3.4557e-03 | std=+7.2109e-02
fc.bias                             | shape=(1,) | mean=+3.6012e-02 | std=+nan

[CHECK] Total parameters: 17217

[CHECK] Sample value from lstm.weight_ih_l0: +0.021841


  std  = p.data.std().item()


In [9]:
def plot_weight_hist_all(model, max_points=200_000, bins=80, title="Global model weight distribution"):
    vec = torch.cat([p.detach().cpu().float().reshape(-1) for p in model.parameters()])
    n = vec.numel()

    if n > max_points:
        idx = torch.randperm(n)[:max_points]
        v = vec[idx].numpy()
        subtitle = f"sampled {max_points}/{n}"
    else:
        v = vec.numpy()
        subtitle = f"n={n}"

    fig = go.Figure()
    fig.add_trace(go.Histogram(x=v, nbinsx=bins, name="weights"))
    fig.update_layout(
        title=f"{title} ({subtitle})",
        xaxis_title="parameter value",
        yaxis_title="count",
        template="plotly_white",
        bargap=0.02
    )
    fig.show()

plot_weight_hist_all(global_model, max_points=200_000, bins=80, title=f"Global weights | round {ROUND:04d}")

In [10]:
def plot_weight_hist_tensor(model, tensor_name, bins=80):
    sd = model.state_dict()
    if tensor_name not in sd:
        raise KeyError(f"not found: {tensor_name}\navailable example: {list(sd.keys())[:10]} ...")

    v = sd[tensor_name].detach().cpu().float().reshape(-1).numpy()

    fig = go.Figure()
    fig.add_trace(go.Histogram(x=v, nbinsx=bins, name=tensor_name))
    fig.update_layout(
        title=f"{tensor_name} distribution | round {ROUND:04d} | n={v.size}",
        xaxis_title="parameter value",
        yaxis_title="count",
        template="plotly_white",
        bargap=0.02
    )
    fig.show()

plot_weight_hist_tensor(global_model, "lstm.weight_ih_l0", bins=80)

In [8]:
# ===== CSV 로드 =====
csv_path = r"C:\Users\admin\OneDrive - 중앙대학교\Federated Learning\csv\Global Model Data.csv"
df = pd.read_csv(csv_path)

years = df["year"].values.astype(np.float32)
chloride_true = df["chloride"].values.astype(np.float32)

# LSTM 입력 시퀀스
SEQ_LEN = cfg["seq_len"]

X = []
for i in range(len(years) - SEQ_LEN):
    X.append(years[i:i+SEQ_LEN])

X = torch.tensor(X).unsqueeze(-1)

# 예측
global_model.eval()
with torch.no_grad():
    y_pred = global_model(X).cpu().numpy().flatten()

x_plot = years[SEQ_LEN:]

# Plotly 그래프
fig = go.Figure()

fig.add_trace(go.Scatter(
    x=x_plot,
    y=y_pred,
    mode="lines",
    name="Global model prediction",
    line=dict(color="red", width=3)
))

fig.add_trace(go.Scatter(
    x=years,
    y=chloride_true,
    mode="markers",
    name="Observed data",
    marker=dict(color="black", size=6)
))

fig.update_layout(
    title="Global LSTM: Year vs Chloride",
    xaxis_title="Year",
    yaxis_title="Chloride",
    template="plotly_white"
)

fig.show()

In [11]:
list(global_model.state_dict().keys())

['lstm.weight_ih_l0',
 'lstm.weight_hh_l0',
 'lstm.bias_ih_l0',
 'lstm.bias_hh_l0',
 'fc.weight',
 'fc.bias']