In [None]:
import numpy as np, pandas as pd
import torch 
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

# -----------------------------
# 0) Reproducibility
# -----------------------------
# Seed PyTorch's global RNG so weight initialisation and shuffled batching repeat across runs.
torch.manual_seed(0)
# Dedicated NumPy Generator; every random draw of the synthetic well data goes through `rng`.
rng = np.random.default_rng(42)

# -----------------------------
# 1) Heat flux function (same form, now fed with t in Myr)
# -----------------------------
def heat_flux(beta, t):
    """Evaluate the series-based heat-flux curve for control parameter ``beta``.

    Computes ``0.8 * (1 + 2 * S)`` where
    ``S = sum_{n=1..100} beta/(n*pi) * sin(n*pi/beta) * exp(-n**2 * t / tau)``
    and multiplies the result by the fixed scale ``1e-3 * 60 * 697``.

    Args:
        beta: Scalar or array-like control parameter (must be non-zero).
        t: Scalar or array-like time, in the same units as ``tau``
            (the caller passes Myr). Broadcast against ``beta``.

    Returns:
        NumPy scalar or array with the broadcast shape of ``beta`` and ``t``.

    NOTE(review): the series has the form used in rift/stretching thermal
    models; confirm the constants (tau, 0.8 and the final scale) against the
    intended physics.
    """
    tau = 62.8       # decay constant; rescale here if t is not expressed in Myr
    n_terms = 100    # number of series terms retained

    beta_arr = np.asarray(beta, dtype=float)
    t_arr = np.asarray(t, dtype=float)

    # Put the series index on a new leading axis so that all terms are
    # evaluated at once and then reduced with a single sum.
    out_ndim = np.broadcast(beta_arr, t_arr).ndim
    n = np.arange(1, n_terms + 1, dtype=float).reshape((-1,) + (1,) * out_ndim)

    coeff = beta_arr / (n * np.pi) * np.sin(n * np.pi / beta_arr)
    t_sum = (coeff * np.exp(-(n ** 2) * t_arr / tau)).sum(axis=0)
    t_over_tm = 0.8 * (1 + 2 * t_sum)  # 0.8 ~ baseline scale

    # Scaling kept from the original snippet.
    return t_over_tm * 1e-3 * 60 * 697

# -----------------------------
# 2) Synthetic multi-well data (WELLS wells, geological time 0..260 Myr)
# -----------------------------
# NOTE: every random quantity below is drawn from `rng` in a fixed order; reordering the
# draws changes the generated dataset even with the same seed.
WELLS = 5
N = 500                                 # time samples per well
t_myr = np.linspace(0.0, 260.0, N)       # geological time axis (Myr), shared by all wells

LITHO_NAMES = ["Shale", "Chalk", "Limestone", "Anhydrite", "Quartzite", "Dolomite"]
# Integer code per lithology name (position in LITHO_NAMES); used as the categorical feature.
LITHO_TO_INT = {name: i for i, name in enumerate(LITHO_NAMES)}

# `rows` collects one time-series DataFrame per well; `layer_rows` one dict per layer.
rows, layer_rows = [], []

for w in range(WELLS):
    # Static per-well attributes (constant along the time axis)
    litho_name = rng.choice(LITHO_NAMES)
    litho = LITHO_TO_INT[litho_name]
    porosity = rng.uniform(0.0, 0.7)
    k = rng.uniform(1.0, 6.0)              # W/mK
    swit = rng.uniform(-20.0, 2.0)         # °C
    age = rng.uniform(0.0, 260.0)          # Myr

    # Radiogenic heat (per-layer) -> aggregate totals
    n_layers = rng.integers(2, 7)          # upper bound exclusive: 2..6 layers
    rad_layers = rng.uniform(0.1, 4.0, n_layers)
    rad_heat_total = float(rad_layers.sum())
    rad_heat_mean  = float(rad_layers.mean())
    for li, rh in enumerate(rad_layers):
        layer_rows.append({"well_id": w, "layer_id": li, "radiogenic_heat": float(rh)})

    # Target from heat_flux over geological time
    beta = rng.uniform(1.5, 3.0)           # per-well control parameter
    q_base = heat_flux(beta, t_myr)        # (N,)

    # Small static effects to differentiate wells: constant per-well offsets plus
    # independent Gaussian noise for every time sample.
    q = (
        q_base
        + 0.05*(k - 3.5)
        - 0.06*(porosity - 0.2)
        + 0.005*(swit + 10.0)
        + 0.0008*(age - 130.0)
        + 0.02*(rad_heat_total - 8.0)
        + rng.normal(0, 0.05, N)           # small noise
    )

    # Scalar values are broadcast to all N rows of this well's frame.
    dfw = pd.DataFrame({
        "well_id": w,
        "time_myr": t_myr,
        "q": q.astype(float),
        "lithology": litho,
        "lithology_name": litho_name,
        "thermal_conductivity": k,
        "porosity": porosity,
        "swit": swit,
        "age": age,
        "rad_heat_total": rad_heat_total,
        "rad_heat_mean": rad_heat_mean,
        "layers_count": n_layers,
    })
    rows.append(dfw)

# Long-format table: one row per (well, time sample), ordered by well then time.
df = pd.concat(rows).reset_index(drop=True).sort_values(["well_id","time_myr"]).reset_index(drop=True)
# Companion table: one row per (well, layer) with that layer's radiogenic heat.
well_layers = pd.DataFrame(layer_rows).sort_values(["well_id", "layer_id"]).reset_index(drop=True)

# -----------------------------
# 3) Columns config (no calendar/date features)
# -----------------------------
# Min-max scale the geological time axis into [0, 1] and expose it as a dynamic
# feature, so the model can tell where on the Myr axis each step lies.
time_axis = df["time_myr"]
t_lo, t_hi = time_axis.min(), time_axis.max()
df["time_myr_norm"] = (time_axis - t_lo) / (t_hi - t_lo)

# Column roles consumed by the dataset and the model.
TARGET_COL = "q"
DYN_COLS = [
    "q",               # autoregressive input
    "time_myr_norm",   # time position
]
STATIC_CONT_COLS = [
    "thermal_conductivity",
    "porosity",
    "swit",
    "age",
    "rad_heat_total",
]
STATIC_CAT_COLS = ["lithology"]
USE_WELL_ID_EMB = True

# -----------------------------
# 4) Per-well chronological split (by geological time)
# -----------------------------
TRAIN_FRAC, VAL_FRAC = 0.70, 0.15


def split_per_well(df_all):
    """Label every row as train/val/test, chronologically within each well.

    Each ``well_id`` group is sorted by ``time_myr``; the first ``TRAIN_FRAC``
    of its rows become ``"train"``, the next ``VAL_FRAC`` ``"val"`` and the
    remainder ``"test"``.

    Args:
        df_all: DataFrame with at least ``well_id`` and ``time_myr`` columns.

    Returns:
        A new DataFrame (fresh 0..n-1 index) with an added ``split`` column;
        groups appear in order of first occurrence, rows sorted by time.
    """
    parts = []
    for _, g in df_all.groupby("well_id", sort=False):
        g = g.sort_values("time_myr").copy()
        n = len(g)
        n_tr = int(n * TRAIN_FRAC)
        n_va = int(n * (TRAIN_FRAC + VAL_FRAC))

        # Label by position rather than by index label, so the result stays
        # correct even when the incoming index contains duplicate labels.
        labels = np.empty(n, dtype=object)
        labels[:n_tr] = "train"
        labels[n_tr:n_va] = "val"
        labels[n_va:] = "test"
        g["split"] = labels
        parts.append(g)

    if not parts:
        # No groups: pd.concat([]) would raise, so return an empty labelled frame.
        empty = df_all.copy()
        empty["split"] = pd.Series(dtype=object)
        return empty.reset_index(drop=True)

    return pd.concat(parts).reset_index(drop=True)

df_splits = split_per_well(df)

# -----------------------------
# 5) Normalize (fit on train only)
# -----------------------------
# Mean/std are estimated on training rows only and then applied to every split.
is_train_row = df_splits["split"] == "train"
train_dyn = df_splits.loc[is_train_row, DYN_COLS]
train_stat = df_splits.loc[is_train_row, STATIC_CONT_COLS]

dyn_mean = train_dyn.mean()
# A zero or undefined std would break the division; fall back to 1.0 (no scaling).
dyn_std = train_dyn.std().replace(0, 1.0).fillna(1.0)
stat_mean = train_stat.mean()
stat_std = train_stat.std().replace(0, 1.0).fillna(1.0)

df_norm = df_splits.copy()
df_norm[DYN_COLS] = df_norm[DYN_COLS].sub(dyn_mean, axis="columns").div(dyn_std, axis="columns")
df_norm[STATIC_CONT_COLS] = (
    df_norm[STATIC_CONT_COLS].sub(stat_mean, axis="columns").div(stat_std, axis="columns")
)

# Sizes needed by the embedding layers.
num_wells = df_norm["well_id"].nunique()
cat_card = {col: int(df_norm[col].max()) + 1 for col in STATIC_CAT_COLS}

# -----------------------------
# 6) Dataset (sequence windows; sorted by time_myr)
# -----------------------------
# Number of consecutive past time samples that form one input window.
WIN = 168
# Number of target steps taken right after each window.
# NOTE: the model is constructed with a literal horizon=1; keep the two in sync.
HORIZON = 1

class MultiWellTS(torch.utils.data.Dataset):
    """Sliding-window samples built per well from a long-format table.

    For every well (rows sorted by ``time_myr``) and every position ``j``, a
    sample consists of the ``win`` rows of ``DYN_COLS`` preceding ``j`` as the
    input and ``TARGET_COL`` at rows ``j .. j+horizon-1`` as the target. A
    sample is kept only when all of its target rows belong to
    ``target_split``; the input window may reach back into earlier splits.

    Args:
        df_all: DataFrame holding ``well_id``, ``time_myr``, ``split``, the
            dynamic, static-continuous and static-categorical columns.
        win: Input window length in rows.
        horizon: Number of target rows per sample.
        target_split: One of ``"train"``, ``"val"``, ``"test"``.

    Raises:
        ValueError: If ``target_split`` is not a recognised split name.
    """

    def __init__(self, df_all, win, horizon, target_split: str):
        if target_split not in {"train", "val", "test"}:
            raise ValueError(
                f"target_split must be 'train', 'val' or 'test', got {target_split!r}"
            )
        self.samples = []
        self.win, self.h = win, horizon
        for w, g in df_all.groupby("well_id", sort=False):
            g = g.sort_values("time_myr")
            Xdyn = g[DYN_COLS].to_numpy(np.float32)
            y = g[TARGET_COL].to_numpy(np.float32)
            tmyr = g["time_myr"].to_numpy(np.float32)
            # Static features are read from the well's first row.
            scont = g[STATIC_CONT_COLS].iloc[0].to_numpy(np.float32)
            scat = {c: int(g[c].iloc[0]) for c in STATIC_CAT_COLS}
            wid = int(w)
            in_split = (g["split"] == target_split).to_numpy()
            for j in range(self.win, len(g) - self.h + 1):
                # Every target step must lie in the requested split; checking
                # only row j would let targets from the next split leak in
                # whenever horizon > 1.
                if not in_split[j:j + self.h].all():
                    continue
                x_win = Xdyn[j - self.win:j]
                y_next = y[j:j + self.h]
                if np.isnan(x_win).any() or np.isnan(y_next).any():
                    continue
                t_val = tmyr[j]  # time (Myr) of the first target step
                self.samples.append((x_win, y_next, scont, scat, wid, t_val))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return (x_window, y, static_cont, static_cat, well_id, time_myr) tensors."""
        x, y, scont, scat, wid, t_val = self.samples[idx]
        return (torch.from_numpy(x),
                torch.from_numpy(y),
                torch.from_numpy(scont),
                torch.tensor([scat[c] for c in STATIC_CAT_COLS], dtype=torch.long),
                torch.tensor(wid, dtype=torch.long),
                torch.tensor(t_val, dtype=torch.float32))  # time in Myr

# One windowed dataset per split, built in train -> val -> test order.
train_ds, val_ds, test_ds = (
    MultiWellTS(df_norm, WIN, HORIZON, target_split=split_name)
    for split_name in ("train", "val", "test")
)

# Only the training loader shuffles and drops the last incomplete batch.
train_loader = DataLoader(train_ds, batch_size=128, shuffle=True, drop_last=True)
val_loader = DataLoader(val_ds, batch_size=128, shuffle=False)
test_loader = DataLoader(test_ds, batch_size=128, shuffle=False)

print(f"Windows per split -> train: {len(train_ds)}, val: {len(val_ds)}, test: {len(test_ds)}")
# Stop early if any split ended up without windows.
assert all(len(split_ds) > 0 for split_ds in (train_ds, val_ds, test_ds))

# Number of dynamic input features per time step.
F_DYN = len(DYN_COLS)

# -----------------------------
# 7) Model: GRU + static conditioning
# -----------------------------
class GRUWithStatics(nn.Module):
    """GRU sequence encoder whose last state is fused with static well features.

    The dynamic window goes through a GRU; continuous statics, categorical
    embeddings and an optional well-id embedding are concatenated and passed
    through a two-layer MLP. A linear head maps the concatenation of both
    representations to ``horizon`` outputs.
    """

    def __init__(self, f_dyn, hidden=192, layers=1, dropout=0.1,
                 cat_card=None, cat_emb_dim=8, static_cont_dim=2,
                 use_well_emb=True, num_wells=1, well_emb_dim=8, horizon=1):
        super().__init__()

        # The dropout value is forwarded to the GRU only when it is stacked.
        gru_dropout = dropout if layers > 1 else 0.0
        self.gru = nn.GRU(f_dyn, hidden, num_layers=layers, batch_first=True,
                          dropout=gru_dropout)

        # One embedding table per categorical feature, in `cat_card` order.
        cardinalities = cat_card or {}
        self.cat_embs = nn.ModuleDict({
            feat: nn.Embedding(n_classes, cat_emb_dim)
            for feat, n_classes in cardinalities.items()
        })
        total_cat_dim = cat_emb_dim * len(cardinalities)

        self.use_well_emb = use_well_emb
        self.well_emb = nn.Embedding(num_wells, well_emb_dim) if use_well_emb else None
        well_dim = well_emb_dim if use_well_emb else 0

        static_in = static_cont_dim + total_cat_dim + well_dim
        self.static_mlp = nn.Sequential(
            nn.Linear(static_in, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
        )

        self.head = nn.Linear(hidden + 64, horizon)

    def forward(self, x_dyn, x_stat_cont, x_stat_cat, well_id):
        """Return predictions of shape (B, horizon) for a batch of windows."""
        h_seq, _ = self.gru(x_dyn)           # (B, T, H)
        h_last = h_seq[:, -1, :]             # (B, H)

        # Column i of x_stat_cat feeds the i-th embedding table.
        pieces = [x_stat_cont]
        cat_vecs = [
            emb(x_stat_cat[:, col])
            for col, emb in enumerate(self.cat_embs.values())
        ]
        if cat_vecs:
            pieces.append(torch.cat(cat_vecs, dim=-1))
        if self.use_well_emb:
            pieces.append(self.well_emb(well_id))

        static_feat = self.static_mlp(torch.cat(pieces, dim=-1))
        return self.head(torch.cat([h_last, static_feat], dim=-1))

# Prefer the GPU when one is visible to PyTorch.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Model hyper-parameters gathered in one place, then unpacked into the constructor.
model_kwargs = dict(
    f_dyn=F_DYN,
    hidden=192,
    layers=1,
    dropout=0.1,
    cat_card=cat_card,
    cat_emb_dim=8,
    static_cont_dim=len(STATIC_CONT_COLS),
    use_well_emb=USE_WELL_ID_EMB,
    num_wells=num_wells,
    well_emb_dim=8,
    horizon=1,
)
model = GRUWithStatics(**model_kwargs).to(device)

# AdamW with decoupled weight decay; mean absolute error as the training objective.
opt = torch.optim.AdamW(params=model.parameters(), lr=1e-3, weight_decay=1e-2)
loss_fn = nn.L1Loss()

# -----------------------------
# 8) Train + simple early stop
# -----------------------------
# Per-epoch mean losses, kept for the training-history plot.
train_hist, val_hist = [], []
# Early-stopping state: best validation loss so far, a CPU copy of the
# corresponding weights, and the number of consecutive epochs without improvement.
best_val = float("inf")
best = None
patience = 5
noimp = 0

for epoch in range(1, 31):
    # ---- training pass ----
    model.train()
    tr_losses = []
    for xb, yb, scont, scat, wid, _t in train_loader:
        xb, yb, scont, scat, wid = (
            tensor.to(device) for tensor in (xb, yb, scont, scat, wid)
        )
        opt.zero_grad()
        pred = model(xb, scont, scat, wid)
        loss = loss_fn(pred, yb)
        loss.backward()
        # Cap the global gradient norm at 1.0 before the optimiser step.
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        tr_losses.append(loss.item())

    # ---- validation pass (no gradients) ----
    model.eval()
    va_losses = []
    with torch.no_grad():
        for xb, yb, scont, scat, wid, _t in val_loader:
            xb, yb, scont, scat, wid = (
                tensor.to(device) for tensor in (xb, yb, scont, scat, wid)
            )
            va_losses.append(loss_fn(model(xb, scont, scat, wid), yb).item())

    tr, va = float(np.mean(tr_losses)), float(np.mean(va_losses))
    train_hist.append(tr)
    val_hist.append(va)
    print(f"Epoch {epoch:02d} | train MAE={tr:.4f} | val MAE={va:.4f}")

    # An epoch counts as an improvement only if it beats the best by more than 1e-6.
    if va + 1e-6 < best_val:
        best_val = va
        best = {name: weights.cpu().clone() for name, weights in model.state_dict().items()}
        noimp = 0
    else:
        noimp += 1
        if noimp >= patience:
            print("Early stopping.")
            break

# Restore the best checkpoint, if any epoch produced one.
if best is not None:
    model.load_state_dict({name: weights.to(device) for name, weights in best.items()})

# -----------------------------
# 9) Test (collect preds + metadata)
# -----------------------------
model.eval()
preds, truths, well_ids, t_list = [], [], [], []
with torch.no_grad():
    for xb, yb, scont, scat, wid, t_val in test_loader:
        batch_pred = model(
            xb.to(device), scont.to(device), scat.to(device), wid.to(device)
        )
        # Move everything back to host memory and collect it as NumPy arrays.
        preds.append(batch_pred.cpu().numpy())
        truths.append(yb.cpu().numpy())
        well_ids.append(wid.cpu().numpy())
        t_list.append(t_val.cpu().numpy())

# Stack the per-batch (B, 1) arrays and drop the trailing horizon axis.
preds = np.concatenate(preds, axis=0).squeeze(-1)
truths = np.concatenate(truths, axis=0).squeeze(-1)
well_ids = np.concatenate(well_ids)
tvals = np.concatenate(t_list)   # time in Myr (float)

test_err = preds - truths
mae = np.mean(np.abs(test_err))
rmse = np.sqrt(np.mean(test_err ** 2))
print(f"Test MAE={mae:.3f} | RMSE={rmse:.3f}  (normalized units)")

# Errors scale linearly, so multiplying by the training std of q gives original q-units.
q_scale = float(dyn_std['q'])
print("Approx in original q-units:",
      f"MAE≈{mae*q_scale:.3f}, RMSE≈{rmse*q_scale:.3f}")

# -----------------------------
# 10) Visualizations (geologic time)
# -----------------------------
# (A) Train vs Val MAE per epoch
fig, ax = plt.subplots(figsize=(7, 3))
ax.plot(train_hist, label="train MAE")
ax.plot(val_hist, label="val MAE")
ax.set_title("Training history")
ax.set_xlabel("Epoch")
ax.set_ylabel("MAE")
ax.legend()
fig.tight_layout()
plt.show()

# (B) Test time series vs Myr for the well with the most test windows
well_to_plot = int(pd.Series(well_ids).mode()[0])
mask = well_ids == well_to_plot
df_plot = pd.DataFrame({
    "time_myr": tvals[mask],
    "y_true": truths[mask],
    "y_pred": preds[mask],
})
df_plot = df_plot.sort_values("time_myr")
tail = df_plot.tail(1000)   # at most the last 1000 points

fig, ax = plt.subplots(figsize=(10, 3))
ax.plot(tail["time_myr"], tail["y_true"], label="true")
ax.plot(tail["time_myr"], tail["y_pred"], label="pred")
ax.set_title(f"Test predictions vs truth (well {well_to_plot}, tail)")
ax.set_xlabel("Time (Myr)")
ax.set_ylabel("Normalized q")
ax.legend()
fig.tight_layout()
plt.show()

# (C) Scatter of predictions against truth, with the identity line for reference
minv = float(min(truths.min(), preds.min()))
maxv = float(max(truths.max(), preds.max()))
fig, ax = plt.subplots(figsize=(4, 4))
ax.scatter(truths, preds, s=10, alpha=0.5)
ax.plot([minv, maxv], [minv, maxv])
ax.set_title("Predicted vs True (test)")
ax.set_xlabel("True")
ax.set_ylabel("Predicted")
fig.tight_layout()
plt.show()

# (D) Histogram of residuals
res = preds - truths
fig, ax = plt.subplots(figsize=(7, 3))
ax.hist(res, bins=40)
ax.set_title("Residuals (pred - true)")
ax.set_xlabel("Residual")
ax.set_ylabel("Count")
fig.tight_layout()
plt.show()

# (E) Mean absolute error per well
abs_err = np.abs(preds - truths)
per_well = pd.DataFrame(
    [(int(w), float(np.mean(abs_err[well_ids == w]))) for w in np.unique(well_ids)],
    columns=["well_id", "MAE"],
).sort_values("well_id")

fig, ax = plt.subplots(figsize=(6, 3))
ax.bar(per_well["well_id"].astype(str), per_well["MAE"])
ax.set_title("Per-well MAE (test)")
ax.set_xlabel("well_id")
ax.set_ylabel("MAE (normalized)")
fig.tight_layout()
plt.show()


ModuleNotFoundError: No module named 'torch'