In [2]:
# %%
import sys
from pathlib import Path
import torch

# Locate the repository root: the closest directory (starting at the current
# working directory and walking upwards) that contains a `src/` folder.
nb_dir = Path.cwd()
repo_root = next(
    (candidate for candidate in (nb_dir, *nb_dir.parents) if (candidate / "src").is_dir()),
    None,
)
if repo_root is None:
    raise RuntimeError("Cannot find repo root containing src/")

# Put the repo root first on the import path so `from src... import ...` resolves.
sys.path.insert(0, str(repo_root))
print("Notebook dir:", nb_dir)
print("Repo root   :", repo_root)

# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print("device:", device)

Notebook dir: h:\GK-MRL-PhysicsConsistent-Inversion\GK-MRL-PhysicsConsistent-Inversion\notebooks
Repo root   : h:\GK-MRL-PhysicsConsistent-Inversion\GK-MRL-PhysicsConsistent-Inversion
device: cuda


In [3]:
# %%
import os
import numpy as np
from src.geo_constraints import DataPaths

# NOTE(review): absolute, machine-specific path. The first cell's printed output
# suggests this is `<repo_root>/data` on this machine; consider deriving it from
# `repo_root` so the notebook runs elsewhere.
DATA_ROOT = r"H:\GK-MRL-PhysicsConsistent-Inversion\GK-MRL-PhysicsConsistent-Inversion\data"
paths = DataPaths(DATA_ROOT)

constraints_npz = os.path.join(paths.processed_dir, "constraints.npz")

# Open the archive in a context manager so the underlying file handle is
# released as soon as the arrays are copied out (an open NpzFile keeps the
# .npz file open, which on Windows also blocks overwriting it).
# After this block `npz` is closed; use the extracted arrays below instead.
# SECURITY: allow_pickle=True executes arbitrary pickled code on load; keep it
# only for trusted, locally produced files.
# NOTE(review): the arrays read here look like plain numeric arrays, so
# allow_pickle=False is probably sufficient; confirm before changing.
with np.load(constraints_npz, allow_pickle=True) as npz:
    # `.astype(...)` materialises an in-memory copy, so the arrays remain valid
    # after the archive is closed. Shapes in the comments are the ones printed
    # by this cell.
    P_VOL = npz["P"].astype(np.float32)       # [4,150,200,200]
    C_VOL = npz["C"].astype(np.float32)       # [150,200,200]
    M_VOL = npz["M"].astype(np.float32)       # [150,200,200]
    ilines = npz["ilines"].astype(int)        # [150]
    xlines = npz["xlines"].astype(int)        # [200]
    twt_ms = npz["twt_ms"].astype(np.float32) # [200]
    dt_ms  = float(npz["dt_ms"])              # scalar

print("P_VOL:", P_VOL.shape)
print("C_VOL:", C_VOL.shape)
print("M_VOL:", M_VOL.shape)
print("ilines:", ilines[:3], "...", ilines[-3:])
print("xlines:", xlines[:3], "...", xlines[-3:])
print("dt_ms:", dt_ms, "twt_ms[0:5]:", twt_ms[:5])
print("processed_dir:", paths.processed_dir)

P_VOL: (4, 150, 200, 200)
C_VOL: (150, 200, 200)
M_VOL: (150, 200, 200)
ilines: [1 2 3] ... [148 149 150]
xlines: [1 2 3] ... [198 199 200]
dt_ms: 1.0 twt_ms[0:5]: [0. 1. 2. 3. 4.]
processed_dir: H:\GK-MRL-PhysicsConsistent-Inversion\GK-MRL-PhysicsConsistent-Inversion\data\processed


In [4]:
# %%
import segyio
import numpy as np

# SEG-Y file used as the network's seismic input.
# NOTE(review): the file name starts with "AI_FINAL", and the min/max printed
# for this cube are identical to the min/max printed later for the true
# acoustic-impedance cube. Please confirm this really is seismic data and not
# the impedance target (which would be target leakage).
# NOTE(review): absolute, machine-specific path.
SEIS_SGY = r"H:\GK-MRL-PhysicsConsistent-Inversion\GK-MRL-PhysicsConsistent-Inversion\data\ATTR_TIME_SGY\AI_FINAL_IL2_XL1_S0_XY25.sgy"
# Fail early, showing the offending path, if the file is missing.
assert os.path.isfile(SEIS_SGY), SEIS_SGY

def read_segy_cube_ilxl_t(path, ilines, xlines, T_expected=200):
    """
    Read a SEG-Y file into a dense float32 cube laid out as [IL, XL, T].

    Traces are located through their INLINE_3D / CROSSLINE_3D trace-header
    values, so the order of traces in the file does not matter. If the same
    (inline, crossline) pair occurs more than once, the last trace wins.

    Parameters
    ----------
    path : str
        SEG-Y file to read. It is opened with ``ignore_geometry=False``.
    ilines, xlines : sequence of int
        Inline / crossline numbers defining axes 0 and 1 of the output cube.
    T_expected : int or None, default 200
        Expected number of samples per trace. Pass ``None`` to skip the check.

    Returns
    -------
    cube : np.ndarray, float32, shape (len(ilines), len(xlines), T)
        Trace data. Positions with no matching trace are left as zeros.
    samples : np.ndarray, float32, shape (T,)
        Sample axis as reported by ``f.samples``.

    Raises
    ------
    AssertionError
        If ``T_expected`` is given and differs from the file's sample count.
    """
    import warnings  # local import: only needed for the missing-trace warning

    with segyio.open(path, "r", ignore_geometry=False) as f:
        f.mmap()
        samples = np.array(f.samples, dtype=np.float32)
        T = len(samples)
        if T_expected is not None:
            assert T == T_expected, f"T mismatch: segy T={T}, expected {T_expected}"

        # Build (inline, crossline) -> trace index from the trace headers.
        # `.tolist()` converts to plain Python ints in one pass instead of
        # calling int() on a NumPy scalar for every trace.
        ils = np.asarray(f.attributes(segyio.TraceField.INLINE_3D)[:]).tolist()
        xls = np.asarray(f.attributes(segyio.TraceField.CROSSLINE_3D)[:]).tolist()
        idx_map = {key: tr for tr, key in enumerate(zip(ils, xls))}

        ILn = len(ilines); XLn = len(xlines)
        cube = np.zeros((ILn, XLn, T), dtype=np.float32)

        missing = 0
        for i, il in enumerate(ilines):
            for j, xl in enumerate(xlines):
                tr = idx_map.get((int(il), int(xl)))
                if tr is None:
                    missing += 1
                    continue
                # Assignment into the float32 cube performs the dtype cast.
                cube[i, j, :] = f.trace[tr]

        print("Loaded cube:", cube.shape, "missing traces:", missing)
        if missing:
            # Zero-filled traces bias downstream statistics (min/max,
            # normalisation), so make the condition hard to overlook.
            warnings.warn(
                f"{missing} of {ILn * XLn} requested traces were not found in "
                f"{path!r}; the corresponding positions are zero-filled.",
                RuntimeWarning,
                stacklevel=2,
            )
        return cube, samples

# Load the input volume on the (ilines, xlines) grid taken from constraints.npz;
# 200 samples per trace are required.
SEIS, seis_samples = read_segy_cube_ilxl_t(SEIS_SGY, ilines, xlines, T_expected=200)
# Quick sanity check of the raw (un-normalised) amplitude range.
print("SEIS min/max:", float(SEIS.min()), float(SEIS.max()))

Loaded cube: (150, 200, 200) missing traces: 0
SEIS min/max: 4.318020820617676 9.578778266906738


In [5]:
# %%
import torch
from src.models.geo_cnn_multitask import GeoCNNMultiTask
from src.dataset_vie import StanfordVIEWellPatchDataset

# The dataset is instantiated only to obtain the normalisation statistics
# (mean/std); they must be identical to the ones used during training.
ds = StanfordVIEWellPatchDataset(paths, constraints_npz, patch_hw=4, use_masked_y=True, normalize=True)

# Checkpoint to evaluate, stored under <processed_dir>/checkpoints_multitask_final.
ckpt_dir = os.path.join(paths.processed_dir, "checkpoints_multitask_final")
ckpt_path = os.path.join(ckpt_dir, "best_joint.pt")  # or best_ai.pt
# Fail early, showing the offending path, if the checkpoint is missing.
assert os.path.isfile(ckpt_path), ckpt_path
print("ckpt:", ckpt_path)

def build_model():
    """Create a GeoCNNMultiTask with this notebook's fixed hyper-parameters on `device`."""
    # NOTE(review): in_channels=7 presumably corresponds to the four inputs
    # passed at inference (1 seismic + 4 P + 1 C + 1 M); confirm against the
    # model definition.
    hyper_params = dict(in_channels=7, base=32, t=200, n_facies=4)
    network = GeoCNNMultiTask(**hyper_params)
    return network.to(device)

def load_ckpt(model, ckpt_path):
    """
    Load weights from `ckpt_path` into `model` and switch it to eval mode.

    The checkpoint may be either a raw state dict or a dict that stores the
    state dict under the key "model_state". Tensors are mapped onto the
    module-level `device`.

    Loading is attempted with ``weights_only=True`` first, which restricts
    unpickling to tensors and plain containers and avoids the torch.load
    FutureWarning. If the checkpoint contains other pickled objects, the
    function falls back to a full unpickle and warns; that fallback is only
    appropriate for trusted, locally produced checkpoints.

    Parameters
    ----------
    model : torch.nn.Module
        Model whose parameters are overwritten in place.
    ckpt_path : str
        Path to the checkpoint file.

    Returns
    -------
    torch.nn.Module
        The same `model` instance, for convenience. Callers that ignore the
        return value keep working.

    Raises
    ------
    RuntimeError
        From ``load_state_dict(strict=True)`` when keys or shapes do not match.
    """
    import pickle
    import warnings

    try:
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
    except pickle.UnpicklingError:
        warnings.warn(
            f"{ckpt_path!r} could not be loaded with weights_only=True; "
            "falling back to full unpickling (trusted checkpoints only).",
            RuntimeWarning,
            stacklevel=2,
        )
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)

    state = ckpt["model_state"] if isinstance(ckpt, dict) and "model_state" in ckpt else ckpt
    # strict=True: any missing or unexpected key is an error, not silently ignored.
    model.load_state_dict(state, strict=True)
    model.eval()
    return model

def unpack_outputs(out):
    """
    Split a model output into ``(ai_pred, facies_logits)``.

    Accepted forms of `out`:
      * dict: the first key present among ("ai", "ai_pred", "y", "imp",
        "impedance") gives the impedance prediction, and the first among
        ("facies", "facies_logits", "logits", "facies_logit") gives the
        facies logits. A part with no matching key is returned as None.
      * list / tuple: element 0 is the impedance prediction, element 1 (when
        present) the facies logits.
      * anything else: treated as the impedance prediction itself.
    """
    if isinstance(out, dict):
        ai_names = ("ai", "ai_pred", "y", "imp", "impedance")
        facies_names = ("facies", "facies_logits", "logits", "facies_logit")
        ai_value = next((out[name] for name in ai_names if name in out), None)
        facies_value = next((out[name] for name in facies_names if name in out), None)
        return ai_value, facies_value

    if isinstance(out, (list, tuple)):
        second = out[1] if len(out) > 1 else None
        return out[0], second

    return out, None

# Build the network on `device` and load the checkpoint weights into it
# (load_ckpt also switches the model to eval mode).
model = build_model()
load_ckpt(model, ckpt_path)

print("Model loaded.")
# Normalisation statistics taken from the dataset object; they are used to
# normalise the input cube and to de-normalise the AI predictions.
print("AI mean/std:", ds.ai_mean, ds.ai_std)
print("Seis mean/std:", ds.seis_mean, ds.seis_std)

ckpt: H:\GK-MRL-PhysicsConsistent-Inversion\GK-MRL-PhysicsConsistent-Inversion\data\processed\checkpoints_multitask_final\best_joint.pt
Model loaded.
AI mean/std: 7.208985805511475 1.2618153095245361
Seis mean/std: 7.208985805511475 1.2618153095245361


  ckpt = torch.load(ckpt_path, map_location=device)


In [6]:
# %%
# Z-score the input cube with the dataset's seismic statistics; the 1e-12 term
# guards against division by zero when the std is 0.
# NOTE(review): the previous cell prints identical values for the seismic and
# AI mean/std; confirm that the seismic statistics are computed from the
# seismic data and not from the impedance.
SEIS_N = (SEIS - float(ds.seis_mean)) / (float(ds.seis_std) + 1e-12)
print("SEIS_N min/max:", float(SEIS_N.min()), float(SEIS_N.max()))

SEIS_N min/max: -2.2911157608032227 1.8780819177627563


In [7]:
# %%
import numpy as np
import torch

# Half-width of the lateral patch: each sample is an H x W window of traces
# centred on the target trace. patch_hw=4 is the value passed to the dataset
# constructor earlier in this notebook.
patch_hw = 4
H = W = 2*patch_hw + 1
IL, XL, T = SEIS_N.shape
assert (IL,XL,T) == (150,200,200)

# Output volumes. Facies is initialised to -1 so that traces for which the
# model returns no facies logits remain recognisable.
AI_pred_cube = np.empty((IL, XL, T), dtype=np.float32)
Facies_pred_cube = np.full((IL, XL, T), -1, dtype=np.int16)

batch_size = 64  # try 64/128; drop back to 32/16 if GPU memory runs out
print("Total traces:", IL*XL)

# =========================
# 1) Pad once (the main speed-up): edge-replicate `pad` traces on both lateral
#    axes so every trace, including those on the border, has a full H x W
#    neighbourhood. The time axis is not padded.
# =========================
pad = patch_hw

SEIS_pad = np.pad(SEIS_N, ((pad,pad),(pad,pad),(0,0)), mode="edge")   # [IL+2p, XL+2p, T]
C_pad    = np.pad(C_VOL,  ((pad,pad),(pad,pad),(0,0)), mode="edge")   # [IL+2p, XL+2p, T]
M_pad    = np.pad(M_VOL,  ((pad,pad),(pad,pad),(0,0)), mode="edge")   # [IL+2p, XL+2p, T]
P_pad    = np.pad(P_VOL,  ((0,0),(pad,pad),(pad,pad),(0,0)), mode="edge")  # [4, IL+2p, XL+2p, T]

# =========================
# 2) Pre-allocate the batch buffers (reused for every batch)
# =========================
x_batch = np.empty((batch_size, 1, H, W, T), dtype=np.float32)
p_batch = np.empty((batch_size, 4, H, W, T), dtype=np.float32)
c_batch = np.empty((batch_size, 1, H, W, T), dtype=np.float32)
m_batch = np.empty((batch_size, 1, H, W, T), dtype=np.float32)

# Mixed precision is only enabled on CUDA; results are converted back to
# float32 when they are de-normalised.
use_amp = (device.type == "cuda")


def amp_ctx(enabled=True):
    """Return an autocast context for the active device.

    Uses the device-agnostic `torch.autocast` API instead of the deprecated
    `torch.cuda.amp.autocast` / `torch.cpu.amp.autocast` entry points, while
    keeping the same `amp_ctx(enabled=...)` call signature.
    """
    return torch.autocast(device_type=device.type, enabled=enabled)

@torch.no_grad()
def run_full_infer_fast():
    """
    Run the model on every trace of the cube and fill the output volumes.

    Works entirely on module-level state: reads the padded input volumes
    (`SEIS_pad`, `P_pad`, `C_pad`, `M_pad`), the pre-allocated batch buffers
    (`x_batch`, `p_batch`, `c_batch`, `m_batch`), `model`, `device`,
    `use_amp`, `batch_size` and the de-normalisation statistics in `ds`, and
    writes into `AI_pred_cube` (de-normalised float32) and `Facies_pred_cube`
    (int16 class index; left untouched when the model returns no facies
    logits).

    Traces are visited in row-major (inline, crossline) order, so the global
    trace index g maps back to (g // XL, g % XL).

    Raises
    ------
    RuntimeError
        If the model outputs cannot be brought to shape [batch, T] (AI) or
        [batch, 4, T] (facies logits).
    """
    n_total = IL * XL
    b = 0      # number of patches currently staged in the batch buffers
    done = 0   # number of traces already written back to the output cubes

    # Inline-major scan keeps the memory access pattern cache friendly.
    for il in range(IL):
        for xl in range(XL):
            # The padded arrays have `pad` extra traces on each side, so the
            # H x W window starting at padded index (il, xl) is centred on the
            # original trace (il, xl).
            x_batch[b, 0] = SEIS_pad[il:il+H, xl:xl+W, :]
            p_batch[b]    = P_pad[:, il:il+H, xl:xl+W, :]
            c_batch[b, 0] = C_pad[il:il+H, xl:xl+W, :]
            m_batch[b, 0] = M_pad[il:il+H, xl:xl+W, :]
            b += 1

            # Keep staging until the batch is full or this is the final trace.
            is_last = (il == IL-1 and xl == XL-1)
            if b < batch_size and not is_last:
                continue

            xb = torch.from_numpy(x_batch[:b]).to(device, non_blocking=True)
            pb = torch.from_numpy(p_batch[:b]).to(device, non_blocking=True)
            cb = torch.from_numpy(c_batch[:b]).to(device, non_blocking=True)
            mb = torch.from_numpy(m_batch[:b]).to(device, non_blocking=True)

            # Device-agnostic autocast (replaces the deprecated
            # torch.cuda.amp.autocast call path).
            with torch.autocast(device_type=device.type, enabled=use_amp):
                out = model(xb, pb, cb, mb)

            ai_pred, facies_logits = unpack_outputs(out)

            # Bring the AI output to [b, T]. Reshaping with an explicit batch
            # size (instead of squeeze()) cannot drop the batch axis when b == 1.
            ai_pred = ai_pred.reshape(b, -1)
            ai_den = ai_pred.detach().float().cpu().numpy().astype(np.float32) * float(ds.ai_std) + float(ds.ai_mean)
            if ai_den.shape != (b, T):
                raise RuntimeError(f"Unexpected AI output shape: {tuple(ai_den.shape)}, expected {(b, T)}")

            # Facies: argmax in torch, convert to numpy at the end.
            fac_pred = None
            if facies_logits is not None:
                logits = facies_logits.detach()
                # Drop superfluous singleton dimensions.
                while logits.dim() > 2 and 1 in logits.shape:
                    logits = logits.squeeze()
                # Normalise the layout to [B, 4, T].
                if logits.dim() == 2:
                    # A batch of one lost its batch axis in the squeeze above.
                    if logits.shape[0] == 4:
                        logits = logits.unsqueeze(0)
                    elif logits.shape[1] == 4:
                        logits = logits.transpose(0,1).unsqueeze(0)
                    else:
                        raise RuntimeError(f"Unexpected facies logits shape (2D): {tuple(logits.shape)}")
                elif logits.dim() == 3:
                    if logits.shape[1] == 4:
                        pass
                    elif logits.shape[2] == 4:
                        logits = logits.transpose(1,2)
                    else:
                        raise RuntimeError(f"Unexpected facies logits shape (3D): {tuple(logits.shape)}")
                else:
                    raise RuntimeError(f"Unexpected facies logits dim: {logits.dim()} shape={tuple(logits.shape)}")

                fac_pred = torch.argmax(logits, dim=1).cpu().numpy().astype(np.int16)  # [b,T]
                if fac_pred.shape != (b, T):
                    raise RuntimeError(f"Unexpected facies prediction shape: {tuple(fac_pred.shape)}, expected {(b, T)}")

            # Write back: the b samples of this batch are the global trace
            # indices done .. done+b-1 in scan order.
            rows, cols = np.divmod(np.arange(done, done + b), XL)
            AI_pred_cube[rows, cols, :] = ai_den
            if fac_pred is not None:
                Facies_pred_cube[rows, cols, :] = fac_pred

            done += b
            # Report every 50 batches, and once more when everything is done.
            if (done // batch_size) % 50 == 0 or done == n_total:
                print(f"  {done}/{IL*XL} done")

            b = 0  # start staging the next batch

# Fill AI_pred_cube / Facies_pred_cube in place for the whole survey.
run_full_infer_fast()
print("Inference done:", AI_pred_cube.shape, Facies_pred_cube.shape)

Total traces: 30000


  with amp_ctx(enabled=use_amp):


  3200/30000 done
  6400/30000 done
  9600/30000 done
  12800/30000 done
  16000/30000 done
  19200/30000 done
  22400/30000 done
  25600/30000 done
  28800/30000 done
Inference done: (150, 200, 200) (150, 200, 200)


In [8]:
# %%
import os
import numpy as np

# All full-cube inversion products go under <processed_dir>/inversion_volume_fullcube.
OUT_DIR = os.path.join(paths.processed_dir, "inversion_volume_fullcube")
os.makedirs(OUT_DIR, exist_ok=True)

# Save both predicted volumes as .npy, keeping their in-memory dtypes
# (float32 for AI, int16 for facies).
ai_pred_npy  = os.path.join(OUT_DIR, "AI_pred_cube.npy")
fac_pred_npy = os.path.join(OUT_DIR, "Facies_pred_cube.npy")
np.save(ai_pred_npy, AI_pred_cube)
np.save(fac_pred_npy, Facies_pred_cube)
print("Saved npy:", ai_pred_npy)
print("Saved npy:", fac_pred_npy)

def write_segy_cube(filepath, cube, ilines, xlines, twt_ms):
    """
    Write a [IL, XL, T] cube to a new SEG-Y file, one trace per (inline, crossline).

    Traces are written inline-major (all crosslines of the first inline, then
    the next inline, ...). Each trace header receives its INLINE_3D and
    CROSSLINE_3D numbers. NaN samples are replaced by 0.0 and all data is
    written as float32.

    Parameters
    ----------
    filepath : str
        Destination file; it is created by `segyio.create`.
    cube : array-like, shape (len(ilines), len(xlines), len(twt_ms))
        Data to write.
    ilines, xlines : array-like of int
        Inline / crossline numbers for axes 0 and 1 of `cube`.
    twt_ms : array-like of float
        Sample axis, stored as `spec.samples`.

    Raises
    ------
    ValueError
        If `cube` is not 3-D or its shape does not match the three axes. The
        check runs before anything is written.
    """
    cube = np.asarray(cube)
    ilines = np.asarray(ilines)
    xlines = np.asarray(xlines)
    twt_ms = np.asarray(twt_ms, dtype=np.float32)

    # Validate up front: a cube larger than the axes would otherwise be
    # truncated silently, and a smaller one would fail half-way through writing.
    if cube.ndim != 3:
        raise ValueError(f"cube must be 3-D [IL, XL, T], got shape {cube.shape}")
    expected_shape = (len(ilines), len(xlines), len(twt_ms))
    if cube.shape != expected_shape:
        raise ValueError(
            f"cube shape {cube.shape} does not match (ilines, xlines, twt_ms) lengths {expected_shape}"
        )

    import segyio  # imported after validation so bad input fails fast

    spec = segyio.spec()
    spec.sorting = 2   # segyio trace sorting code for inline-sorted data
    spec.format = 5    # SEG-Y sample format code 5 = 4-byte IEEE float
    spec.ilines = ilines.tolist()
    spec.xlines = xlines.tolist()
    spec.samples = twt_ms.tolist()

    with segyio.create(filepath, spec) as f:
        tr = 0
        for i, il in enumerate(ilines):
            for j, xl in enumerate(xlines):
                trace = np.nan_to_num(cube[i, j, :].astype(np.float32), nan=0.0)
                f.trace[tr] = trace
                # Set both header words in a single header update.
                f.header[tr] = {
                    segyio.TraceField.INLINE_3D: int(il),
                    segyio.TraceField.CROSSLINE_3D: int(xl),
                }
                tr += 1
        f.flush()

# Export both volumes as SEG-Y on the same inline/crossline/time grid.
ai_pred_segy  = os.path.join(OUT_DIR, "AI_pred_cube.segy")
fac_pred_segy = os.path.join(OUT_DIR, "Facies_pred_cube.segy")

write_segy_cube(ai_pred_segy, AI_pred_cube, ilines, xlines, twt_ms)
# Facies class indices are cast to float32 for writing; any -1 entries
# ("no prediction") are written as -1.0.
write_segy_cube(fac_pred_segy, Facies_pred_cube.astype(np.float32), ilines, xlines, twt_ms)

print("Saved segy:")
print(" ", ai_pred_segy)
print(" ", fac_pred_segy)

Saved npy: H:\GK-MRL-PhysicsConsistent-Inversion\GK-MRL-PhysicsConsistent-Inversion\data\processed\inversion_volume_fullcube\AI_pred_cube.npy
Saved npy: H:\GK-MRL-PhysicsConsistent-Inversion\GK-MRL-PhysicsConsistent-Inversion\data\processed\inversion_volume_fullcube\Facies_pred_cube.npy
Saved segy:
  H:\GK-MRL-PhysicsConsistent-Inversion\GK-MRL-PhysicsConsistent-Inversion\data\processed\inversion_volume_fullcube\AI_pred_cube.segy
  H:\GK-MRL-PhysicsConsistent-Inversion\GK-MRL-PhysicsConsistent-Inversion\data\processed\inversion_volume_fullcube\Facies_pred_cube.segy


In [17]:
# %%
import os
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# ---------- 0) Input paths ----------
# Reference ("true") acoustic-impedance volume used to evaluate the prediction.
# NOTE(review): absolute, machine-specific path.
AI_TRUE_SGY = r"H:\GK-MRL-PhysicsConsistent-Inversion\GK-MRL-PhysicsConsistent-Inversion\data\ATTR_TIME_SGY\Acoustic_Impedance_TIME_dt1ms_XY25.sgy"
assert os.path.isfile(AI_TRUE_SGY), AI_TRUE_SGY

# ---------- 1) Read true cube ----------
# Same (ilines, xlines) grid and trace length as the predicted cube; the
# returned sample axis is not needed here.
AI_TRUE, _ = read_segy_cube_ilxl_t(AI_TRUE_SGY, ilines, xlines, T_expected=200)
print("AI_TRUE:", AI_TRUE.shape, "min/max:", float(AI_TRUE.min()), float(AI_TRUE.max()))
# nanmin/nanmax so that NaNs in the prediction do not hide its range.
print("AI_PRED:", AI_pred_cube.shape, "min/max:", float(np.nanmin(AI_pred_cube)), float(np.nanmax(AI_pred_cube)))

# ---------- 2) Output dir ----------
FIG_DIR = Path(OUT_DIR) / "figs"
FIG_DIR.mkdir(parents=True, exist_ok=True)
print("FIG_DIR:", FIG_DIR)

# ---------- 3) Helpers ----------
def ms_to_k(ms):
    """Return the index of the sample in the module-level `twt_ms` axis closest to `ms`."""
    # twt_ms: [T]; the first index wins when two samples are equally close.
    distance = np.abs(twt_ms - ms)
    return int(distance.argmin())

def plot_3panel(
    A_true,
    A_pred,
    title,
    fp,
    cmap_ai="viridis",
    cmap_diff="coolwarm",
    p_low=2,
    p_high=98,
    p_diff=98,
    diff_vmax=5,      # fixed symmetric range for the Diff panel; None = automatic
    xlabel="XL",
    ylabel="IL / Time",
):
    """
    Save a 1x3 figure: True | Pred | Diff (Pred - True).

    True and Pred share one colour scale taken from the `p_low` / `p_high`
    percentiles of `A_true` (NaNs ignored). The Diff panel uses a symmetric
    scale.

    Parameters
    ----------
    A_true, A_pred : 2-D array-like with identical shapes
        Images to compare; rows are drawn along y and columns along x.
    title : str
        Figure suptitle.
    fp : str
        Output image path (saved at 300 dpi).
    cmap_ai, cmap_diff : str
        Colormaps for the True/Pred panels and for the Diff panel.
    p_low, p_high : float
        Percentiles of `A_true` used as vmin / vmax of True and Pred.
    p_diff : float
        Percentile of |diff| used for the Diff range when `diff_vmax` is None.
    diff_vmax : float or None
        None  -> automatic range, +/- the `p_diff` percentile of |diff|
        float -> fixed symmetric range [-|diff_vmax|, +|diff_vmax|]
    xlabel, ylabel : str
        Axis labels applied to all three panels.

    Raises
    ------
    ValueError
        If `A_true` and `A_pred` do not have the same shape.
    """
    A_true = np.asarray(A_true, dtype=np.float32)
    A_pred = np.asarray(A_pred, dtype=np.float32)
    # Refuse silent NumPy broadcasting between differently shaped images.
    if A_true.shape != A_pred.shape:
        raise ValueError(f"shape mismatch: A_true {A_true.shape} vs A_pred {A_pred.shape}")
    diff = A_pred - A_true

    # ---- True / Pred colour range (robust percentiles of True) ----
    vmin = float(np.nanpercentile(A_true, p_low))
    vmax = float(np.nanpercentile(A_true, p_high))

    # ---- Diff colour range ----
    if diff_vmax is None:
        dmax = float(np.nanpercentile(np.abs(diff), p_diff))
    else:
        dmax = abs(float(diff_vmax))
    # A zero or non-finite range would give a degenerate colour scale
    # (e.g. a perfect prediction); fall back to a unit range.
    if not np.isfinite(dmax) or dmax <= 0.0:
        dmax = 1.0

    # (image, panel title, colormap, vmin, vmax, colourbar label)
    panels = (
        (A_true, "True", cmap_ai, vmin, vmax, "AI"),
        (A_pred, "Pred", cmap_ai, vmin, vmax, "AI"),
        (diff, "Diff (Pred − True)", cmap_diff, -dmax, dmax, "ΔAI"),
    )

    fig, axes = plt.subplots(1, 3, figsize=(14, 4), constrained_layout=True)
    try:
        for ax, (image, panel_title, cmap, lo, hi, cbar_label) in zip(axes, panels):
            im = ax.imshow(image, aspect="auto", vmin=lo, vmax=hi, cmap=cmap)
            ax.set_title(panel_title)
            ax.set_xlabel(xlabel)
            ax.set_ylabel(ylabel)
            cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
            cbar.set_label(cbar_label)

        fig.suptitle(title, fontsize=14)
        fig.savefig(fp, dpi=300)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)

# ---------- 4) Make figures ----------
saved = []

# Symmetric range of the Diff panel, shared by every figure below.
# None = automatic percentile range; a number = fixed +/- diff_fixed.
diff_fixed = 5

# (A) Time slices @ 40/100/160 ms: [IL, XL] maps at the nearest sample index.
slice_kwargs = dict(cmap_ai="viridis", cmap_diff="coolwarm", diff_vmax=diff_fixed)
for ms in [40, 100, 160]:
    k = ms_to_k(ms)
    fp = FIG_DIR / f"timeslice_{ms}ms_true_pred_diff.png"
    plot_3panel(
        AI_TRUE[:, :, k],
        AI_pred_cube[:, :, k],
        title=f"AI @ {ms} ms",
        fp=str(fp),
        **slice_kwargs,
    )
    saved.append(str(fp))

# Vertical sections are transposed so that time runs down the image rows.
section_kwargs = dict(cmap_ai="viridis", cmap_diff="seismic", diff_vmax=diff_fixed)

# (B) Inline = 100
inline_val = 100
ii = int(np.flatnonzero(ilines == inline_val)[0])
fp = FIG_DIR / f"inline_{inline_val}_true_pred_diff.png"
plot_3panel(
    AI_TRUE[ii].T,
    AI_pred_cube[ii].T,
    title=f"Inline {inline_val}",
    fp=str(fp),
    **section_kwargs,
)
saved.append(str(fp))

# (C) Xline = 50
xline_val = 50
jj = int(np.flatnonzero(xlines == xline_val)[0])
fp = FIG_DIR / f"xline_{xline_val}_true_pred_diff.png"
plot_3panel(
    AI_TRUE[:, jj].T,
    AI_pred_cube[:, jj].T,
    title=f"Xline {xline_val}",
    fp=str(fp),
    **section_kwargs,
)
saved.append(str(fp))

# Summary of everything written by this cell.
print("\nSaved figs:")
for p in saved:
    print(" ", p)

Loaded cube: (150, 200, 200) missing traces: 0
AI_TRUE: (150, 200, 200) min/max: 4.318020820617676 9.578778266906738
AI_PRED: (150, 200, 200) min/max: 4.1333112716674805 9.656217575073242
FIG_DIR: H:\GK-MRL-PhysicsConsistent-Inversion\GK-MRL-PhysicsConsistent-Inversion\data\processed\inversion_volume_fullcube\figs

Saved figs:
  H:\GK-MRL-PhysicsConsistent-Inversion\GK-MRL-PhysicsConsistent-Inversion\data\processed\inversion_volume_fullcube\figs\timeslice_40ms_true_pred_diff.png
  H:\GK-MRL-PhysicsConsistent-Inversion\GK-MRL-PhysicsConsistent-Inversion\data\processed\inversion_volume_fullcube\figs\timeslice_100ms_true_pred_diff.png
  H:\GK-MRL-PhysicsConsistent-Inversion\GK-MRL-PhysicsConsistent-Inversion\data\processed\inversion_volume_fullcube\figs\timeslice_160ms_true_pred_diff.png
  H:\GK-MRL-PhysicsConsistent-Inversion\GK-MRL-PhysicsConsistent-Inversion\data\processed\inversion_volume_fullcube\figs\inline_100_true_pred_diff.png
  H:\GK-MRL-PhysicsConsistent-Inversion\GK-MRL-Physic