In [4]:
import os
import math
import random
import time

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt

# ============================================================
# Configuration
# ============================================================
# Use the GPU when torch can see one, otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", DEVICE)

# Data root: the directory containing one subfolder per material
# (3C90, 3C94, 3E6, 3F4, 77, 78, N27, N30, N49, N87).
# "." works when the script is launched from inside that directory.
ROOT = "."

# Material folder names and the frequency labels that appear in the CSV
# file names ("<material>_<freq>_B.csv", "..._H.csv", "..._T.csv").
MATERIALS = [
    "3C90", "3C94", "3E6", "3F4", "77",
    "78", "N27", "N30", "N49", "N87",
]
FREQS = list(range(1, 8))  # 1 .. 7

# Windowing and optimisation hyper-parameters.
SEQ_LEN = 80      # consecutive samples per input window
STRIDE = 5        # handed to GlobalDataset, which stores it but builds stride-1 windows
BATCH_SIZE = 256  # lower to 128 when training on a slow CPU
EPOCHS = 80
LR = 1e-3
PATIENCE = 20     # epochs without val-loss improvement before early stopping


# ============================================================
# Helpers
# ============================================================

def set_seed(seed: int = 42):
    """Seed all random number generators used by this script.

    Seeds Python's ``random``, NumPy's global RNG and torch (CPU plus every
    CUDA device). It then switches cuDNN to deterministic mode with
    autotuning disabled, trading some speed for run-to-run reproducibility.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def load_1d(path: str, nrows: int = None) -> np.ndarray:
    """Read a header-less CSV file into a flat float32 NumPy array.

    ``nrows`` caps how many rows pandas parses (``None`` reads the whole
    file), which avoids reading far more data than needed. Cell values are
    flattened in row-major order, so a single-column file yields one element
    per row.
    """
    frame = pd.read_csv(path, header=None, nrows=nrows)
    return frame.to_numpy(dtype=np.float32).reshape(-1)


# ============================================================
# Dataset: all materials + all frequencies
# ============================================================

class GlobalDataset(Dataset):
    """
    Builds a single dataset over all materials and all frequencies.

    For each (material, freq) whose ``<mat>_<freq>_{B,H,T}.csv`` files exist
    under ``root/<mat>/``:
      - read T fully (short)
      - read only len(T) rows from B and H
      - truncate if B/H/T still mismatched
      - compute dB/dt with ``np.gradient`` using unit spacing (T is only used
        to bound the series length)
      - z-normalise B, H and dB/dt with that series' own mean/std
      - create sliding windows of length ``seq_len`` with stride 1.
        NOTE: ``stride`` is accepted and stored for API compatibility but is
        deliberately not used; every start position becomes a window.

    Series too short to yield a single window (length <= ``seq_len``) are
    skipped with a message.

    Each sample:
      x      : (seq_len, 3)  [B_norm, H_norm, dBdt_norm]
      mat_id : int in [0, n_materials)
      freq_id: int in [0, n_freqs)
      y      : scalar H_norm at index (start + seq_len), the step after x
      H_mean, H_std: per-series stats to denormalise y / predictions later

    Raises:
      ValueError: if ``seq_len < 1`` or no series produced any window.
    """

    def __init__(self, root, materials, freqs, seq_len=80, stride=5):
        super().__init__()
        if seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")
        self.seq_len = seq_len
        self.stride = stride  # stored only; windows are always built with stride 1

        inputs, targets = [], []
        mat_ids, freq_ids = [], []
        h_means, h_stds = [], []

        for m_idx, mat in enumerate(materials):
            for f_idx, freq in enumerate(freqs):
                series = self._load_series(root, mat, freq)
                if series is None:
                    continue
                B, H = series
                n = len(B)

                # No window fits, and np.gradient would fail for n < 2.
                if n <= seq_len:
                    print(
                        f"Skipping [{mat}, f={freq}]: series length {n} "
                        f"is too short for seq_len={seq_len}"
                    )
                    continue

                x, y, H_mean, H_std = self._make_windows(B, H, seq_len)
                count = len(y)

                inputs.append(x)
                targets.append(y)
                mat_ids.append(np.full(count, m_idx, dtype=np.int64))
                freq_ids.append(np.full(count, f_idx, dtype=np.int64))
                h_means.append(np.full(count, H_mean, dtype=np.float32))
                h_stds.append(np.full(count, H_std, dtype=np.float32))

                print(f"[{mat}, f={freq}] windows added: {count} (series length {n})")

        if not inputs:
            raise ValueError(
                f"GlobalDataset is empty: no (material, freq) series under "
                f"{root!r} produced a window of length {seq_len}"
            )

        self.inputs = np.concatenate(inputs, axis=0)      # (N, seq_len, 3) float32
        self.targets = np.concatenate(targets, axis=0)    # (N,) float32
        self.mat_ids = np.concatenate(mat_ids)            # (N,) int64
        self.freq_ids = np.concatenate(freq_ids)          # (N,) int64
        self.H_means = np.concatenate(h_means)            # (N,) float32
        self.H_stds = np.concatenate(h_stds)              # (N,) float32

        print(f"\nTotal samples in GlobalDataset: {len(self.inputs)}")

    @staticmethod
    def _load_series(root, mat, freq):
        """Return length-aligned float32 (B, H) arrays for one series, or None if files are missing."""
        B_path = os.path.join(root, mat, f"{mat}_{freq}_B.csv")
        H_path = os.path.join(root, mat, f"{mat}_{freq}_H.csv")
        T_path = os.path.join(root, mat, f"{mat}_{freq}_T.csv")

        if not (os.path.exists(B_path) and os.path.exists(H_path) and os.path.exists(T_path)):
            print(f"Skipping [{mat}, f={freq}]: missing B/H/T CSVs")
            return None

        # Read T fully (short), then only len(T) rows from B/H to limit I/O.
        T = load_1d(T_path)
        lenT = len(T)
        B = load_1d(B_path, nrows=lenT)
        H = load_1d(H_path, nrows=lenT)

        lenB, lenH = len(B), len(H)
        if not (lenB == lenH == lenT):
            min_len = min(lenB, lenH, lenT)
            print(
                f"Warning [{mat}, f={freq}]: "
                f"B/H/T lengths differ (B={lenB}, H={lenH}, T={lenT}) "
                f"-> truncating to {min_len}"
            )
            B = B[:min_len]
            H = H[:min_len]

        B = np.asarray(B, dtype=np.float32).reshape(-1)
        H = np.asarray(H, dtype=np.float32).reshape(-1)
        return B, H

    @staticmethod
    def _make_windows(B, H, seq_len):
        """Normalise one series and cut it into stride-1 (x, y) windows.

        Requires ``len(B) == len(H) > seq_len >= 1``. Returns
        ``(x, y, H_mean, H_std)``, where x has shape
        (len(B) - seq_len, seq_len, 3) and y has shape (len(B) - seq_len,).
        """
        dBdt = np.gradient(B).astype(np.float32)

        B_mean, B_std = float(B.mean()), float(B.std() + 1e-8)
        H_mean, H_std = float(H.mean()), float(H.std() + 1e-8)
        dB_mean, dB_std = float(dBdt.mean()), float(dBdt.std() + 1e-8)

        feats = np.stack(
            [(B - B_mean) / B_std, (H - H_mean) / H_std, (dBdt - dB_mean) / dB_std],
            axis=-1,
        ).astype(np.float32)  # (n, 3)
        Hn = feats[:, 1]

        # Window i covers feats[i : i + seq_len] and predicts Hn[i + seq_len].
        # Dropping the last row leaves exactly n - seq_len windows, so every
        # target index stays in range.
        windows = np.lib.stride_tricks.sliding_window_view(feats[:-1], seq_len, axis=0)
        x = np.ascontiguousarray(windows.transpose(0, 2, 1))  # (n - seq_len, seq_len, 3)
        y = Hn[seq_len:].copy()
        return x, y, H_mean, H_std

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        x = torch.from_numpy(self.inputs[idx])         # (L, 3)
        y = torch.tensor(self.targets[idx], dtype=torch.float32)
        mat_id = torch.tensor(self.mat_ids[idx], dtype=torch.long)
        freq_id = torch.tensor(self.freq_ids[idx], dtype=torch.long)
        h_mean = torch.tensor(self.H_means[idx], dtype=torch.float32)
        h_std = torch.tensor(self.H_stds[idx], dtype=torch.float32)
        return x, mat_id, freq_id, y, h_mean, h_std


# ============================================================
# Model: shared LSTM + (material, freq) embeddings
# ============================================================

class GlobalModel(nn.Module):
    """LSTM regressor shared by every material and frequency.

    The LSTM encodes an input window. Its output at the final time step is
    concatenated with learned embeddings of the material id and frequency
    id, and a linear head maps that vector to one scalar per sample.
    """

    def __init__(
        self,
        n_materials: int,
        n_freqs: int,
        input_dim: int = 3,
        lstm_hidden: int = 32,
        lstm_layers: int = 1,
        mat_emb_dim: int = 4,
        freq_emb_dim: int = 2,
    ):
        super().__init__()
        # Submodule creation order and attribute names determine both the
        # seeded weight-initialisation sequence and the state_dict keys of
        # saved checkpoints -- keep them as they are.
        self.lstm = nn.LSTM(input_dim, lstm_hidden, num_layers=lstm_layers, batch_first=True)
        self.mat_emb = nn.Embedding(n_materials, mat_emb_dim)
        self.freq_emb = nn.Embedding(n_freqs, freq_emb_dim)
        self.fc = nn.Linear(lstm_hidden + mat_emb_dim + freq_emb_dim, 1)

    def forward(self, x, mat_id, freq_id):
        """Predict one scalar per window.

        x is batch-first, (batch, seq_len, input_dim). mat_id and freq_id are
        integer id tensors with one entry per batch element. Returns a
        tensor of shape (batch,).
        """
        seq_out = self.lstm(x)[0]
        parts = (
            seq_out[:, -1, :],       # LSTM output at the last time step
            self.mat_emb(mat_id),    # (batch, mat_emb_dim)
            self.freq_emb(freq_id),  # (batch, freq_emb_dim)
        )
        return self.fc(torch.cat(parts, dim=-1)).squeeze(-1)


# ============================================================
# Training + evaluation
# ============================================================

def train_global_model():
    set_seed(42)

    dataset = GlobalDataset(ROOT, MATERIALS, FREQS, seq_len=SEQ_LEN, stride=STRIDE)
    n_total = len(dataset)
    n_train = int(0.8 * n_total)
    n_val = n_total - n_train
    train_ds, val_ds = random_split(
        dataset, [n_train, n_val],
        generator=torch.Generator().manual_seed(42),
    )

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=False)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

    model = GlobalModel(
        n_materials=len(MATERIALS),
        n_freqs=len(FREQS),
        input_dim=3,
        lstm_hidden=32,
        lstm_layers=1,
        mat_emb_dim=4,
        freq_emb_dim=2,
    ).to(DEVICE)

    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)

    best_val = float("inf")
    best_state = None
    no_improve = 0

    train_history = []
    val_history = []

    print("\n========== Training GLOBAL model ==========")
    for epoch in range(1, EPOCHS + 1):
        # ---- train
        model.train()
        train_losses = []
        for xb, mat_idb, freq_idb, yb, hmean_b, hstd_b in train_loader:
            xb = xb.to(DEVICE)
            mat_idb = mat_idb.to(DEVICE)
            freq_idb = freq_idb.to(DEVICE)
            yb = yb.to(DEVICE)

            optimizer.zero_grad()
            preds = model(xb, mat_idb, freq_idb)
            loss = criterion(preds, yb)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        # ---- val
        model.eval()
        val_losses = []
        with torch.no_grad():
            for xb, mat_idb, freq_idb, yb, hmean_b, hstd_b in val_loader:
                xb = xb.to(DEVICE)
                mat_idb = mat_idb.to(DEVICE)
                freq_idb = freq_idb.to(DEVICE)
                yb = yb.to(DEVICE)

                preds = model(xb, mat_idb, freq_idb)
                loss = criterion(preds, yb)
                val_losses.append(loss.item())

        train_loss = float(np.mean(train_losses))
        val_loss = float(np.mean(val_losses))
        train_history.append(train_loss)
        val_history.append(val_loss)

        print(f"Epoch {epoch:3d}/{EPOCHS}  train_loss={train_loss:.6f}  val_loss={val_loss:.6f}")

        if val_loss < best_val - 1e-6:
            best_val = val_loss
            best_state = model.state_dict()
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= PATIENCE:
                print(f"Early stopping at epoch {epoch}. Best val_loss={best_val:.6f}")
                break

    # load best weights
    if best_state is not None:
        model.load_state_dict(best_state)

    # ---- Evaluate on ALL samples in physical units
    full_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

    preds_norm_all = []
    targets_norm_all = []
    mat_ids_all = []
    freq_ids_all = []
    H_means_all = []
    H_stds_all = []

    model.eval()
    with torch.no_grad():
        for xb, mat_idb, freq_idb, yb, hmean_b, hstd_b in full_loader:
            xb = xb.to(DEVICE)
            mat_idb = mat_idb.to(DEVICE)
            freq_idb = freq_idb.to(DEVICE)

            preds = model(xb, mat_idb, freq_idb).cpu().numpy()
            y_np = yb.numpy()

            preds_norm_all.append(preds)
            targets_norm_all.append(y_np)
            mat_ids_all.append(mat_idb.cpu().numpy())
            freq_ids_all.append(freq_idb.cpu().numpy())
            H_means_all.append(hmean_b.numpy())
            H_stds_all.append(hstd_b.numpy())

    preds_norm_all = np.concatenate(preds_norm_all)
    targets_norm_all = np.concatenate(targets_norm_all)
    mat_ids_all = np.concatenate(mat_ids_all)
    freq_ids_all = np.concatenate(freq_ids_all)
    H_means_all = np.concatenate(H_means_all)
    H_stds_all = np.concatenate(H_stds_all)

    preds_phys_all = preds_norm_all * H_stds_all + H_means_all
    targets_phys_all = targets_norm_all * H_stds_all + H_means_all

    # global metrics
    global_rmse = float(np.sqrt(np.mean((preds_phys_all - targets_phys_all) ** 2)))
    rms_meas_global = float(np.sqrt(np.mean(targets_phys_all ** 2)))
    global_rel = global_rmse / (rms_meas_global + 1e-12) * 100.0

    print("\n===== GLOBAL METRICS =====")
    print(f"Global RMSE: {global_rmse:.4f} (H units)")
    print(f"Global Relative Error: {global_rel:.2f}%")

    # ---- per-material & per-frequency metrics
    rows = []
    for m_idx, mat in enumerate(MATERIALS):
        for f_idx, freq in enumerate(FREQS):
            mask = (mat_ids_all == m_idx) & (freq_ids_all == f_idx)
            if not np.any(mask):
                continue

            p = preds_phys_all[mask]
            t = targets_phys_all[mask]
            rmse = float(np.sqrt(np.mean((p - t) ** 2)))
            rms_meas = float(np.sqrt(np.mean(t ** 2)))
            rel = rmse / (rms_meas + 1e-12) * 100.0

            rows.append({
                "material": mat,
                "freq": freq,
                "n_samples": int(mask.sum()),
                "rmse": rmse,
                "rel_err": rel,
            })

    df = pd.DataFrame(rows)
    df.to_csv("global_results.csv", index=False)

    print("\n===== Per-material, per-frequency metrics =====")
    print(df.to_string(index=False, float_format=lambda x: f"{x:.4f}"))

    # ---- plots
    # 1) Train vs val loss
    plt.figure(figsize=(6, 4))
    epochs = np.arange(1, len(train_history) + 1)
    plt.plot(epochs, train_history, label="train")
    plt.plot(epochs, val_history, label="val")
    plt.xlabel("Epoch")
    plt.ylabel("MSE loss")
    plt.title("Global model: train vs val loss")
    plt.legend()
    plt.tight_layout()
    plt.savefig("global_train_val_loss.png", dpi=150)
    plt.close()

    # 2) Average RMSE per material
    mat_rmse = df.groupby("material")["rmse"].mean().reset_index()
    plt.figure(figsize=(8, 4))
    plt.bar(mat_rmse["material"], mat_rmse["rmse"])
    plt.xlabel("Material")
    plt.ylabel("RMSE (H units)")
    plt.title("Average RMSE per material (across frequencies)")
    plt.tight_layout()
    plt.savefig("global_rmse_by_material.png", dpi=150)
    plt.close()

    # save model
    torch.save(model.state_dict(), "global_lstm_model.pt")
    print("\nSaved global model to: global_lstm_model.pt")
    print("Saved metrics to: global_results.csv")
    print("Saved plots: global_train_val_loss.png, global_rmse_by_material.png")


# ============================================================
# Script entry point
# ============================================================
if __name__ == "__main__":
    # Train only when executed as a script / notebook cell, not on import.
    train_global_model()


Device: cpu
[3C90, f=1] windows added: 567 (series length 647)
[3C90, f=2] windows added: 1379 (series length 1459)
[3C90, f=3] windows added: 2172 (series length 2252)
[3C90, f=4] windows added: 3172 (series length 3252)
[3C90, f=5] windows added: 3189 (series length 3269)
[3C90, f=6] windows added: 3220 (series length 3300)
[3C90, f=7] windows added: 3010 (series length 3090)
[3C94, f=1] windows added: 1028 (series length 1108)
[3C94, f=2] windows added: 258 (series length 338)
[3C94, f=3] windows added: 761 (series length 841)
[3C94, f=4] windows added: 1979 (series length 2059)
[3C94, f=5] windows added: 1787 (series length 1867)
[3C94, f=6] windows added: 2020 (series length 2100)
[3C94, f=7] windows added: 2020 (series length 2100)
[3E6, f=1] windows added: 815 (series length 895)
[3E6, f=2] windows added: 2015 (series length 2095)
[3E6, f=3] windows added: 142 (series length 222)
[3E6, f=4] windows added: 1292 (series length 1372)
[3E6, f=5] windows added: 1720 (series length 18