In [1]:
# =========================
# TRAINING CONFIG
# =========================
# Single place for every tunable the later cells read: data source and split,
# model windowing, optimiser settings, staged-training schedule, loss weights
# and the checkpoint location.
CONFIG = {
    "data": {
        # Raw string with single backslashes throughout. The previous value
        # mixed r"..." with escaped "\\", which leaves doubled separators in
        # the actual path.
        "csv_path":      r"C:\Users\rcper\Desktop\SGN_NILM2\refit_data\CLEAN_House2.csv",
        "appliance_col": "Appliance1",
        "max_rows":      750_000,
        "resample_rule": "30s",
        # Both values are cumulative boundaries on the row index:
        # train = [0, 0.70), val = [0.70, 0.85), test = [0.85, 1.0]
        # i.e. 70% train, 15% val, 15% test.
        "train_split":   0.70,
        "val_split":     0.85,
    },
    "model": {
        "backbone_kind": "tcn",
        "win_len":       256,
        "stride":        32,
    },
    "optim": {
        "batch_size": 32,
        "lr":         1e-3,
        "patience":   10,
        "min_delta":  0.0,
        "use_scheduler": False,
    },
    "staged_training": {
        "epochs_reg":   5,   # regressor pretrain
        "epochs_cls":   3,   # classifier pretrain
        "epochs_joint": 60,  # joint finetune
        "on_sample_prob_reg": 0.7,  # oversampling ONLY for reg batches
    },
    "loss": {
        "delta_watts":     50.0,  # Huber delta in Watts
        "alpha_on":        1.0,   # ON regression on gated power
        "alpha_reg_raw":   1.0,   # ON regression on raw reg head
        "alpha_off":       0.02,  # tiny leak penalty when OFF
        "beta_cls":        1.5,   # BCE weight
        "focal_gamma":     2.0,   # focal BCE focusing
        "pos_weight_cap":  8.0,   # cap for neg/pos weighting
    },
    "checkpoint": {
        "ckpt_path": "sgn_best.pt"
    },
}


In [2]:
# =========================
# IMPORTS
# =========================
import random

import numpy as np
import torch

from dataset import Seq2PointWindows
from training import main_train_staged
from loss import delta_from_watts
from inference import (
    infer_seq2point_timeline_all,
    infer_seq2point_timeline_all_with_hard,
    tune_hysteresis_for_mae,
    smape,
)
from sgnNet import SGN
from refit_dataloader import load_house_csv

# plotting utils (decoupled)
from plot import plot_training_curves, plot_reg_vs_true, plot_soft, plot_hard


In [3]:
# =========================
# LOAD DATA
# =========================
# Read the configured house CSV once, then carve the returned series into
# three contiguous segments (train / val / test) in row order.
data_cfg = CONFIG["data"]

mains_all, target_all, ts_all = load_house_csv(
    data_cfg["csv_path"],
    appliance_col=data_cfg["appliance_col"],
    max_rows=data_cfg["max_rows"],
    resample_rule=data_cfg["resample_rule"],
)
print(f"Loaded {len(mains_all)} samples for {CONFIG['data']['appliance_col']}")

# Split boundaries: both fractions are cumulative positions in the series.
n = len(mains_all)
train_end, val_end = (
    int(data_cfg[name] * n) for name in ("train_split", "val_split")
)

train_rows = slice(None, train_end)
val_rows = slice(train_end, val_end)
test_rows = slice(val_end, None)

mains_train, target_train = mains_all[train_rows], target_all[train_rows]
mains_val, target_val = mains_all[val_rows], target_all[val_rows]
mains_test, target_test = mains_all[test_rows], target_all[test_rows]


Loaded 750000 samples for Appliance1


In [4]:
# =========================
# PREP DATA STATS & LOSS PARAMS
# =========================
# Derive the two data-dependent loss parameters handed to training:
# the Huber delta in scaled units and the capped positive-class weight.
WIN_LEN, STRIDE = (CONFIG["model"][key] for key in ("win_len", "stride"))
loss_cfg = CONFIG["loss"]

# These window datasets exist only so the scaling / imbalance stats can be read.
window_kwargs = dict(win_len=WIN_LEN, stride=STRIDE)
train_ds = Seq2PointWindows(mains_train, target_train, train=True, **window_kwargs)
val_ds = Seq2PointWindows(mains_val, target_val, train=False, **window_kwargs)

# Convert the Huber delta from Watts using the training set's target scale.
delta = delta_from_watts(loss_cfg["delta_watts"], train_ds.target_scale)

# Class imbalance: neg/pos ratio, clipped at the configured cap.
# The 1e-9 keeps the ratio finite when no window is ON.
pos_rate = float(train_ds.onoff.mean() + 1e-9)
neg_rate = 1.0 - pos_rate
pos_weight = min(neg_rate / pos_rate, loss_cfg["pos_weight_cap"])
print(f"ON rate ~ {pos_rate:.4f} → pos_weight={pos_weight:.2f}, delta(huber)={delta:.5f} (scaled)")


ON rate ~ 0.1246 → pos_weight=7.03, delta(huber)=0.53763 (scaled)


In [None]:
# =========================
# TRAIN (STAGED)
# =========================
# Seed the Python, NumPy and torch RNGs before training so that a fresh
# "Restart Kernel -> Run All" starts from the same random state.
# Set CONFIG["seed"] to override the default.
# NOTE(review): GPU kernels may still be non-deterministic; this only pins the
# random state that seeding controls.
SEED = CONFIG.get("seed", 42)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Run the staged schedule (the log prints REG, CLS and JOINT phases) with
# every hyper-parameter taken from CONFIG, plus the data-derived `delta` and
# `pos_weight` computed in the previous cell.
results = main_train_staged(
    mains_train, target_train,
    mains_val,   target_val,
    # windowing
    win_len=CONFIG["model"]["win_len"],
    stride=CONFIG["model"]["stride"],
    # optimisation
    batch_size=CONFIG["optim"]["batch_size"],
    lr=CONFIG["optim"]["lr"],
    # stage lengths
    epochs_reg=CONFIG["staged_training"]["epochs_reg"],
    epochs_cls=CONFIG["staged_training"]["epochs_cls"],
    epochs_joint=CONFIG["staged_training"]["epochs_joint"],
    on_sample_prob_reg=CONFIG["staged_training"]["on_sample_prob_reg"],
    kind=CONFIG["model"]["backbone_kind"],
    # early stopping / checkpointing
    patience=CONFIG["optim"]["patience"],
    min_delta=CONFIG["optim"]["min_delta"],
    ckpt_path=CONFIG["checkpoint"]["ckpt_path"],
    # loss
    delta_huber=delta,
    focal_gamma=CONFIG["loss"]["focal_gamma"],
    pos_weight=pos_weight,
    alpha_on=CONFIG["loss"]["alpha_on"],
    alpha_reg_raw=CONFIG["loss"]["alpha_reg_raw"],
    alpha_off=CONFIG["loss"]["alpha_off"],
    beta_cls=CONFIG["loss"]["beta_cls"],
    use_scheduler=CONFIG["optim"]["use_scheduler"],
    plot=False,  # plotting handled later via plot.py
)

print("Best Val MAE (W):", results["best_val_mae"])


[REG] Epoch 001 | Train 0.0904 | ValLoss 1.7913 | ValMAE 48.87
[REG] Epoch 002 | Train 0.0566 | ValLoss 1.5800 | ValMAE 47.32
[REG] Epoch 003 | Train 0.0527 | ValLoss 2.1673 | ValMAE 50.26
[REG] Epoch 004 | Train 0.0505 | ValLoss 1.7905 | ValMAE 45.66
[REG] Epoch 005 | Train 0.0469 | ValLoss 1.9570 | ValMAE 45.60
[CLS] Epoch 001 | Train 0.1383 | ValLoss 0.3824 | ValMAE 32.16
[CLS] Epoch 002 | Train 0.0718 | ValLoss 0.3512 | ValMAE 27.28
[CLS] Epoch 003 | Train 0.0638 | ValLoss 0.3346 | ValMAE 22.18
[JOINT] Epoch 001 | Train 0.1558 | ValLoss 0.3353 | ValMAE 35.83
[JOINT] Epoch 002 | Train 0.1485 | ValLoss 0.3167 | ValMAE 33.58
[JOINT] Epoch 003 | Train 0.1468 | ValLoss 0.3205 | ValMAE 28.97
[JOINT] Epoch 004 | Train 0.1380 | ValLoss 0.2990 | ValMAE 27.28
[JOINT] Epoch 005 | Train 0.1367 | ValLoss 0.3293 | ValMAE 27.16
[JOINT] Epoch 006 | Train 0.1376 | ValLoss 0.3042 | ValMAE 27.93
[JOINT] Epoch 007 | Train 0.1326 | ValLoss 0.2984 | ValMAE 29.95
[JOINT] Epoch 008 | Train 0.1239 | ValLos

In [None]:
# =========================
# LOAD BEST MODEL & RUN INFERENCE
# =========================
# Restore the checkpoint written during training, tune the hysteresis
# thresholds on the validation split, then score the test split with both the
# soft-gated and the hard-gated power estimates.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ckpt = torch.load(CONFIG["checkpoint"]["ckpt_path"], map_location=device)

model = SGN(in_ch=1, hid=64, kind=CONFIG["model"]["backbone_kind"], out_len=1).to(device)
model.load_state_dict(ckpt["model"])
# A freshly constructed module is in training mode. Switch to eval explicitly
# so that any dropout / batch-norm layers behave deterministically at inference
# time (harmless if the inference helpers already do this).
model.eval()
stats = ckpt["stats"]

# Use the gate threshold stored with the checkpoint when available.
DEFAULT_GATE_TAU = 0.75
gate_tau = stats.get("gate_tau", DEFAULT_GATE_TAU) if isinstance(stats, dict) else DEFAULT_GATE_TAU

# Validation inference for hysteresis tuning (thresholds are chosen on val
# only and then applied unchanged to test).
pow_val_soft, reg_val_w, prob_val = infer_seq2point_timeline_all(model, mains_val, stats, device, gate_tau=gate_tau)
y_val = target_val
t_on, t_off, min_hold, best_val_mae = tune_hysteresis_for_mae(reg_val_w, prob_val, y_val, gate_floor=0.0)
print(f"[HYSTERESIS][VAL] best → t_on={t_on:.2f}, t_off={t_off:.2f}, min_hold={min_hold} | MAE={best_val_mae:.2f} W")

# Test inference (soft + hard)
power_soft_w, reg_w, prob, gate_hard, power_hard_w = infer_seq2point_timeline_all_with_hard(
    model, mains_test, stats, device,
    gate_tau=gate_tau,
    t_on=t_on, t_off=t_off, min_hold=min_hold,
    gate_floor=0.0
)
y_true = target_test

# Metrics
# MAE: mean absolute error between prediction and ground truth.
mae_soft = float(np.mean(np.abs(power_soft_w - y_true)))
mae_hard = float(np.mean(np.abs(power_hard_w - y_true)))
# SAE: |sum(pred) - sum(true)| / sum(true); the 1e-6 guards an all-zero target.
energy_true = float(y_true.sum())
sae_soft = float(abs(power_soft_w.sum() - energy_true) / (energy_true + 1e-6))
sae_hard = float(abs(power_hard_w.sum() - energy_true) / (energy_true + 1e-6))
print(f"[TEST][SOFT] MAE={mae_soft:.2f} W | SAE={sae_soft:.4f} | sMAPE={smape(y_true, power_soft_w):.2f}%")
print(f"[TEST][HARD] MAE={mae_hard:.2f} W | SAE={sae_hard:.4f} | sMAPE={smape(y_true, power_hard_w):.2f}%")


In [None]:
# =========================
# PLOTS
# =========================
# Toggle individual figures and set how much of the test timeline is drawn.
PLOT_CFG = dict(
    enable_training_curve=True,
    show_regression=True,
    show_soft=True,
    show_hard=True,
    test_plot_len=80_000,
    on_threshold_for_plot=15.0,  # for binary true_on curve in hard plot
)

# Never request more samples than the test series holds.
N = min(PLOT_CFG["test_plot_len"], len(y_true))

# Training history first, then the three test-set views.
plot_training_curves(results, enable=PLOT_CFG["enable_training_curve"])
plot_reg_vs_true(y_true, reg_w, N=N, show=PLOT_CFG["show_regression"])
plot_soft(y_true, power_soft_w, N=N, show=PLOT_CFG["show_soft"])
plot_hard(
    y_true, power_hard_w, gate_hard,
    on_threshold=PLOT_CFG["on_threshold_for_plot"],
    N=N,
    show=PLOT_CFG["show_hard"],
)
