In [None]:
# ======================  CONFIG  ======================
CFG = dict(
    csv_path   = r"",           # path to the input CSV; must be set before running
    date_col   = "date",        # optional column parsed as datetime and used as the sorted index
    lookback   = 120,           # input window length fed to the encoder
    tau        = 25,            # horizon: each target lies `tau` steps after its window's last point
    pred_len   = 96,            # number of final points that are forecast and scored
    d_model    = 64,
    nhead      = 4,
    num_layers = 2,
    dropout    = 0.001,
    concept_hidden = [64, 32],  # hidden sizes of the concept MLP
    mlp_hidden     = [128, 64], # hidden sizes of the prediction-head MLP
    lr         = 0.0002,
    epochs     = 10,
    batch_size = 64,
    lamb_phys  = 1,             # overall scale of the physics residual
    lamb_con   = 1,             # weight of the concept-supervision loss
    lamb_trend = 1,             # weight of the trend loss (read by the training loop)
    lambs      = (1, 1, 1, 1, 1),  # per-term physics weights; only the first two are read
    device     = "cuda" if __import__("torch").cuda.is_available() else "cpu",
    plot_each  = True,
    show_epoch_mse = True,
    scale_method = "standard",  # "minmax" -> MinMaxScaler, anything else -> StandardScaler
)
# =======================================================
import warnings, numpy as np, pandas as pd, torch, \
       torch.nn as nn, torch.nn.functional as F, matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from tqdm import trange



def load_csv(path, date_col):
    """Read a CSV file into a DataFrame, optionally indexed by a datetime column.

    If ``date_col`` is truthy and names an existing column, that column is
    parsed with ``pd.to_datetime``, promoted to the index, and the frame is
    sorted by it. Otherwise the frame is returned exactly as read.
    """
    frame = pd.read_csv(path)
    index_by_date = bool(date_col) and date_col in frame.columns
    if not index_by_date:
        return frame
    frame[date_col] = pd.to_datetime(frame[date_col])
    return frame.set_index(date_col).sort_index()

def window_stack(arr, L):
    """Return all length-``L`` sliding windows of ``arr`` stacked on a new first axis.

    Windows advance one step at a time along axis 0. For a 1-D input the result
    has shape ``(len(arr) - L + 1, L)`` and row ``i`` equals ``arr[i : i + L]``.
    For an N-D input each window keeps the trailing axes, giving
    ``(len(arr) - L + 1, L, *arr.shape[1:])``.

    The result is a fresh, writable, C-contiguous copy rather than a view into
    ``arr``, so it is safe to pass to ``torch.from_numpy`` or to mutate.
    When ``arr`` is shorter than ``L`` an empty ``(0, L, ...)`` array of the
    same dtype is returned instead of raising.
    """
    arr = np.asarray(arr)
    n_win = len(arr) - L + 1
    if n_win <= 0:
        return np.empty((0, L) + arr.shape[1:], dtype=arr.dtype)
    # Vectorised replacement for a Python loop of slices + np.stack.
    # sliding_window_view appends the window axis last; move it back to axis 1
    # (a no-op for 1-D input), then copy to get an independent contiguous array.
    view = np.lib.stride_tricks.sliding_window_view(arr, L, axis=0)
    return np.moveaxis(view, -1, 1).copy()

def ensure_list(x):
    """Pass lists and tuples through untouched; wrap anything else in a one-item list."""
    if isinstance(x, (list, tuple)):
        return x
    return [x]

def build_mlp(in_dim, hidden, out_dim):
    """Build a feed-forward ``nn.Sequential`` of Linear layers with ReLU in between.

    ``hidden`` lists the widths of the hidden layers (may be empty, in which
    case the result is a single ``Linear(in_dim, out_dim)``). Every hidden
    Linear is followed by an in-place ReLU; the output Linear has no activation.
    """
    widths = [in_dim, *hidden]
    modules = []
    for d_in, d_out in zip(widths[:-1], widths[1:]):
        modules.extend((nn.Linear(d_in, d_out), nn.ReLU(inplace=True)))
    modules.append(nn.Linear(widths[-1], out_dim))
    return nn.Sequential(*modules)

def get_scaler():
    """Return a new, unfitted scaler chosen by ``CFG["scale_method"]``.

    ``"minmax"`` (case-insensitive) selects ``MinMaxScaler``; any other value
    falls back to ``StandardScaler``.
    """
    wants_minmax = CFG["scale_method"].lower() == "minmax"
    if wants_minmax:
        return MinMaxScaler()
    return StandardScaler()


class SeriesDataset(Dataset):
    """Map-style dataset of (input window, tau-step-ahead target) pairs.

    Sample ``idx`` pairs the window ``arr[idx : idx + L]`` with the scalar
    ``arr[idx + L - 1 + tau]``, i.e. the value ``tau`` steps after the
    window's last observation.
    """

    def __init__(self, arr, L, tau):
        # Every length-L window precomputed once as a float32 tensor
        # ((n_windows, L) for the 1-D series passed in by the training code).
        windows = window_stack(arr, L)
        self.win = torch.from_numpy(windows).float()
        self.tau = tau

    def __len__(self):
        # The final `tau` windows have no target inside the series.
        n_windows = len(self.win)
        return n_windows - self.tau

    def __getitem__(self, idx):
        future_window = self.win[idx + self.tau]
        return self.win[idx], future_window[-1]


class Encoder(nn.Module):
    """Transformer encoder over a univariate window.

    Each scalar step is lifted to ``d_model`` features by a linear projection,
    a learned positional embedding of length ``L`` is added, and the sequence
    runs through ``nlayers`` stacked ``nn.TransformerEncoderLayer`` blocks
    (batch-first). An input of shape ``(batch, L)`` yields ``(batch, L, d_model)``;
    the window length is expected to equal ``L`` so the positional embedding lines up.
    """

    def __init__(self, L, d_model, nhead, nlayers, dropout):
        super().__init__()
        # Learned positions, initialised from a standard normal.
        self.pos_emb = nn.Parameter(torch.randn(1, L, d_model))
        self.in_proj = nn.Linear(1, d_model)
        layer = nn.TransformerEncoderLayer(
            d_model, nhead, dropout=dropout, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(layer, nlayers)

    def forward(self, x):
        tokens = self.in_proj(x[..., None])
        return self.transformer(tokens + self.pos_emb)

class ConceptLayer(nn.Module):
    """MLP mapping the encoder's final position to 5 concept scores.

    Hidden widths come from ``CFG["concept_hidden"]``. The 5 outputs match the
    5 columns produced by ``hard_concepts`` and consumed by ``PINNHead``.
    """

    def __init__(self, d_model):
        super().__init__()
        hidden_sizes = ensure_list(CFG["concept_hidden"])
        self.mlp = build_mlp(d_model, hidden_sizes, 5)

    def forward(self, z):
        # Only the representation at the last index of dim 1 is used
        # (assumes z is (batch, seq, d_model) as produced by Encoder).
        last_step = z[:, -1]
        return self.mlp(last_step)

class PINNHead(nn.Module):
    """MLP turning 5 concept scores into one scalar prediction per sample.

    Hidden widths come from ``CFG["mlp_hidden"]``.
    """

    def __init__(self):
        super().__init__()
        widths = ensure_list(CFG["mlp_hidden"])
        self.fc = build_mlp(5, widths, 1)

    def forward(self, c):
        # Drop the trailing singleton output dimension: (..., 1) -> (...).
        return torch.squeeze(self.fc(c), dim=-1)



def physics_residual(last_x, C, head):
    """Weighted regulariser on the head output and the concept tensor.

    Combines two terms:
      * ``r1``: mean squared difference between ``head(C)`` and ``last_x``.
        The training loop passes each window's final observed value, so this
        pulls predictions towards that value.
      * ``r2``: mean squared magnitude of the concepts ``C``.

    The terms are weighted by the first two entries of ``CFG["lambs"]``, and the
    sum is scaled by ``CFG["lamb_phys"]``. Note that ``head`` is re-applied to
    ``C`` here.
    """
    gap_sq = (head(C) - last_x).pow(2)
    r1 = gap_sq.mean()
    r2 = C.pow(2).mean()
    w1, w2 = CFG["lambs"][:2]
    weighted = w1 * r1 + w2 * r2
    return CFG["lamb_phys"] * weighted



def hard_concepts(win, tau):
    """Hand-crafted concept targets from the mean of each row's last ``tau`` values.

    With ``m`` the per-row mean of ``win[:, -tau:]``, returns a ``(batch, 5)``
    tensor whose columns are ``[m > 0, m < 0, |m|, exp(-m), sin(m)]``; the two
    indicator columns are cast to float.

    Note: with ``tau == 0`` the slice ``win[:, -0:]`` covers the whole row.
    """
    recent_mean = win[:, -tau:].mean(dim=1)
    columns = (
        (recent_mean > 0).float(),
        (recent_mean < 0).float(),
        recent_mean.abs(),
        (-recent_mean).exp(),
        recent_mean.sin(),
    )
    return torch.stack(columns, dim=1)



def train_series(arr_std):
    """Fit an Encoder -> ConceptLayer -> PINNHead stack on one normalised series.

    Training pairs come from ``SeriesDataset``: a ``CFG["lookback"]``-point
    window and the value ``CFG["tau"]`` steps after its last point. Each batch
    minimises

        data MSE
        + physics_residual(last input value, concepts, head)
        + CFG["lamb_con"]   * MSE(concepts, hard_concepts(window, tau))
        + CFG["lamb_trend"] * trend MSE

    with AdamW for ``CFG["epochs"]`` epochs.

    Args:
        arr_std: 1-D normalised series.

    Returns:
        (enc, con, head): the trained modules, on ``CFG["device"]`` and still
        in training mode.

    Raises:
        ValueError: if the series is too short to form one (window, target) pair.
    """
    dev, L, tau = CFG["device"], CFG["lookback"], CFG["tau"]
    batch_size = CFG["batch_size"]
    # Not every config defines this weight; default to 1 instead of a KeyError.
    lamb_trend = CFG.get("lamb_trend", 1)

    n_samples = len(arr_std) - L + 1 - tau
    if n_samples <= 0:
        raise ValueError(
            f"series of length {len(arr_std)} is too short for "
            f"lookback={L} and tau={tau}"
        )
    ds = SeriesDataset(arr_std, L, tau)
    # With drop_last=True and fewer samples than one batch the loader would yield
    # no batches at all (and the epoch report would divide by zero), so a partial
    # batch is only dropped when at least one full batch exists.
    dl = DataLoader(
        ds, batch_size=batch_size, shuffle=True, drop_last=n_samples >= batch_size
    )

    enc = Encoder(L, CFG["d_model"], CFG["nhead"], CFG["num_layers"], CFG["dropout"]).to(dev)
    con = ConceptLayer(CFG["d_model"]).to(dev)
    head = PINNHead().to(dev)
    opt = torch.optim.AdamW(
        [*enc.parameters(), *con.parameters(), *head.parameters()], lr=CFG["lr"]
    )

    for ep in trange(CFG["epochs"], desc="epochs", leave=False):
        mse_sum, n_sum = 0.0, 0
        for win, y in dl:
            win, y = win.to(dev), y.to(dev)
            last_x = win[:, -1]  # last observed value of each window
            C = con(enc(win))
            y_hat = head(C)
            data = F.mse_loss(y_hat, y)

            phys = physics_residual(last_x, C, head)
            con_loss = F.mse_loss(C, hard_concepts(win, tau))
            # NOTE(review): (y_hat - last_x) - (y - last_x) == y_hat - y, so this
            # term is mathematically the data MSE again; it only re-weights `data`.
            t_loss = F.mse_loss(y_hat - last_x, y - last_x)

            loss = data + phys + CFG["lamb_con"] * con_loss + lamb_trend * t_loss
            opt.zero_grad()
            loss.backward()
            opt.step()

            mse_sum += (y_hat.detach() - y).pow(2).sum().item()
            n_sum += y.numel()
        if CFG["show_epoch_mse"]:
            print(f"Epoch {ep+1:3d}/{CFG['epochs']}: Train MSE (norm) = {mse_sum / n_sum:.6f}")
    return enc, con, head



def predict_last96(arr_std, enc, con, head, scaler):
    """Forecast the final ``CFG["pred_len"]`` points of a normalised 1-D series.

    Each target ``arr_std[t]`` is predicted from the ``CFG["lookback"]``-long
    window that ends ``CFG["tau"]`` steps earlier. This is the same
    window -> target pairing that ``SeriesDataset`` uses for training.

    The modules are put in eval mode (disabling dropout) for the forward pass
    and restored to their previous train/eval mode afterwards.

    If the series cannot supply ``pred_len`` aligned windows, the evaluated span
    is shortened to what is available and a warning is issued.

    Returns:
        (raw_true, raw_pred, norm_true, norm_pred): 1-D arrays of equal length.
        The ``raw_*`` arrays are mapped back through ``scaler.inverse_transform``.

    Raises:
        ValueError: if not even one window/target pair fits in the series.
    """
    dev, L, P, tau = CFG["device"], CFG["lookback"], CFG["pred_len"], CFG["tau"]
    all_wins = window_stack(arr_std, L)
    # Window i covers arr_std[i : i + L] and forecasts arr_std[i + L - 1 + tau],
    # so only windows with index < stop have their target inside the series.
    # The target of window stop - 1 is the last point of arr_std.
    stop = len(all_wins) - tau
    if stop <= 0:
        raise ValueError(
            f"series of length {len(arr_std)} is too short for lookback={L} and tau={tau}"
        )
    if P > stop:
        warnings.warn(
            f"only {stop} aligned windows available; evaluating {stop} of {P} requested points"
        )
        P = stop
    wins = all_wins[stop - P : stop]
    y_true_norm = arr_std[len(arr_std) - P :]

    modules = (enc, con, head)
    prev_modes = [m.training for m in modules]
    for m in modules:
        m.eval()  # dropout must be off for deterministic forecasts
    try:
        with torch.no_grad():
            x = torch.from_numpy(wins).float().to(dev)
            y_pred_norm = head(con(enc(x))).cpu().numpy().ravel()
    finally:
        for m, mode in zip(modules, prev_modes):
            m.train(mode)

    y_pred_raw = scaler.inverse_transform(y_pred_norm.reshape(-1, 1)).ravel()
    y_true_raw = scaler.inverse_transform(y_true_norm.reshape(-1, 1)).ravel()
    return y_true_raw, y_pred_raw, y_true_norm, y_pred_norm



def main():
    """Train and evaluate one model per numeric column of ``CFG["csv_path"]``.

    For every column:
      1. The final ``pred_len`` points are held out. The scaler is fitted and the
         model trained on the preceding prefix only, so the scored points are
         never seen during fitting.
      2. The held-out points are forecast with ``predict_last96`` on the full
         normalised series.
      3. Normalised MSE/MAE are printed, optionally with a plot in raw units.

    Averages over all processed columns are printed at the end. Columns that are
    too short, or whose training/prediction raises RuntimeError/ValueError, are
    skipped with a warning.
    """
    if not CFG["csv_path"]:
        raise ValueError('CFG["csv_path"] is empty; set it to the CSV file to process.')
    df = load_csv(CFG["csv_path"], CFG["date_col"])

    L, P, tau = CFG["lookback"], CFG["pred_len"], CFG["tau"]
    # The training prefix (all but the last P points) must give at least one full
    # batch of (window, tau-ahead target) pairs: (n - P) - L + 1 - tau >= batch_size.
    # That also guarantees enough windows to forecast all P held-out points.
    min_len = P + L - 1 + tau + CFG["batch_size"]

    test_mse_list, test_mae_list = [], []

    for col in df.columns:
        # Non-numeric entries become NaN and are dropped instead of crashing astype.
        arr = pd.to_numeric(df[col], errors="coerce").dropna().to_numpy(dtype="float32")
        if len(arr) < min_len:
            warnings.warn(f"{col}: skipped, {len(arr)} numeric points < {min_len} required")
            continue
        n_train = len(arr) - P

        # Fit normalisation on the training prefix only, then apply it everywhere.
        scaler = get_scaler().fit(arr[:n_train].reshape(-1, 1))
        arr_norm = scaler.transform(arr.reshape(-1, 1)).ravel()

        try:
            enc, con, head = train_series(arr_norm[:n_train])
            y_true_raw, y_pred_raw, y_true_norm, y_pred_norm = predict_last96(
                arr_norm, enc, con, head, scaler
            )
        except (RuntimeError, ValueError) as e:
            warnings.warn(f"{col}: {e}")
            continue

        mse = float(np.mean((y_pred_norm - y_true_norm) ** 2))
        mae = float(np.mean(np.abs(y_pred_norm - y_true_norm)))
        test_mse_list.append(mse)
        test_mae_list.append(mae)
        print(f"{col}: Test MSE (norm) = {mse:.6f}, MAE (norm) = {mae:.6f}")

        if CFG["plot_each"]:
            plt.figure(figsize=(8, 3))
            plt.plot(y_true_raw, label="True")
            plt.plot(y_pred_raw, label="Pred")
            plt.title(f"{col} – Last {CFG['pred_len']} steps")
            plt.legend()
            plt.tight_layout()
            plt.show()

    if test_mse_list:
        print("\n----- Overall Normalised Test Results -----")
        print(f"Series processed : {len(test_mse_list)}")
        print(f"Average MSE (norm): {np.mean(test_mse_list):.6f}")
        print(f"Average MAE (norm): {np.mean(test_mae_list):.6f}")
    else:
        print("No valid series processed – please check the CSV and CFG settings.")

if __name__ == "__main__":
    main()