<a href="https://colab.research.google.com/github/OLEDman926/OEFingerprint/blob/main/YS_model(01_20_26)_%EC%96%91%ED%83%9C%EC%9A%B1V1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#### Yejin Shim 고분자 응력필드 코드 (updated 01/20/2026)



---


#### 학습된 모델 셋업:

*   epoch=50 (both global/local); lr=1e-3;
*   dataset total=2400 (train=1920, val=240, test=240)
*   batch=5
*   save된 model weights directory: ./model/weights(01.20.26); output directory: ./model/out(01.20.26)

# 1. Google Drive load

In [None]:
import os
import math
import csv
import numpy as np
import datetime
import cv2
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
import scipy
from tensorflow.keras import layers
import time
from keras import activations
from IPython.display import display


import matplotlib as mpl
# Global matplotlib style applied to every figure drawn in this notebook.
mpl.rcParams.update({
    "font.family": "Liberation Sans",
    "font.size": 11,
    "axes.labelsize": 11,
    "axes.titlesize": 11,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "legend.fontsize": 10,
    # fonttype 42 embeds TrueType fonts (instead of Type 3) in PDF/PS exports.
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
})

# Mount Google Drive and move into the dataset folder; the relative paths used later
# (e.g. "./dataset/combined", "./model/...") resolve from there.
from google.colab import drive
drive.mount('/content/drive/')
# NOTE(review): the target is built from os.getcwd(), so this assumes the cell starts in /content
# (the Colab default); re-running it after the chdir would look for a nested, non-existent path.
os.chdir(os.getcwd() + "/drive/MyDrive/YS_고분자응력필드_dataset") #---- set this to the Google Drive location of the "YS_고분자응력필드_dataset" folder

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [None]:
# Display the current working directory (should be the dataset folder after the chdir in the setup cell).
os.getcwd()

'/content/drive/.shortcut-targets-by-id/1MI7GB2cSnswvIZvKi4PyjTyRmytr3OIQ/YS_고분자응력필드_dataset'

# 2. Preprocessing dataset

### Keywords:

*   geom: local/global 모델들 인풋
*   target01: local 모델 아웃풋 groundtruth - [0,1]로 normalized된 target
*   norm_stats: global 모델 아웃풋 - [min_log, max_log, eps] or [min_log, max_log]
*   mm: [min_log and max_log]

In [None]:
import os
import numpy as np
import tensorflow as tf

# =========================================================================================
# 1) Dataset 로드
# geom (input), target01 (local_output), norm_stats (global_output)
# ****norm_stats expected: [min_log, max_log, eps] or [min_log, max_log]
# *eps(=1e-12): log로 처음 변환할 때 raw=0값이 negative infinity로 발산하지 않도록 더했음
# =========================================================================================

def load_one_geom_target_stats(stem: str, root=".", target_hw=None):
    """Load one sample's geometry, normalized target and normalization stats.

    Reads ``<stem>-geom.npy``, ``<stem>-target.npy`` and ``<stem>-norm_stats.npy``
    from ``root`` (in that order) and casts each to float32.

    Args:
        stem: sample identifier shared by the three files.
        root: directory containing the files.
        target_hw: accepted for API compatibility; not used.

    Returns:
        (geom, target, stats): geom/target as stored (same shape), and ``stats``
        always as a 3-vector ``[min_log, max_log, eps]``. When the file holds only
        two values eps defaults to 1e-12; entries past the third are dropped.

    Raises:
        ValueError: geom/target shapes differ, or stats has fewer than 2 values.
    """
    def _read(suffix):
        return np.load(os.path.join(root, f"{stem}-{suffix}.npy")).astype(np.float32)

    geom = _read("geom")
    target = _read("target")
    raw_stats = _read("norm_stats")

    if geom.shape != target.shape:
        raise ValueError(f"shape mismatch: geom{geom.shape} vs target{target.shape} for stem={stem}")

    flat = np.asarray(raw_stats, dtype=np.float32).reshape(-1)
    if flat.size < 2:
        raise ValueError(f"norm_stats must have at least [min_log, max_log]; got {flat.shape}")

    if flat.size == 2:
        # No eps stored: use the same default offset that was added before taking log10.
        return geom, target, np.array([flat[0], flat[1], 1e-12], dtype=np.float32)
    return geom, target, flat[:3].astype(np.float32)


def load_dataset_geom_to_target(stems, root=".", target_hw=None):
    """Load every stem with ``load_one_geom_target_stats`` and stack the results.

    Stems that raise while loading are reported as ``Skipped <stem>: <error>``
    and left out.

    Returns:
        X: (N,H,W) geometry, Y: (N,H,W) target01, S: (N,3) [min_log, max_log, eps];
        all float32.

    Raises:
        RuntimeError: if no stem could be loaded.
    """
    geoms, targets, stats = [], [], []
    for stem in stems:
        try:
            sample = load_one_geom_target_stats(stem, root=root, target_hw=target_hw)
        except Exception as err:
            print(f"Skipped {stem}: {err}")
            continue
        geom, target, stat = sample
        geoms.append(geom)
        targets.append(target)
        stats.append(stat)

    if not geoms:
        raise RuntimeError("No samples loaded. Check stems/root/patterns.")

    def _stack(arrays):
        return np.stack(arrays, axis=0).astype(np.float32)

    return _stack(geoms), _stack(targets), _stack(stats)


# ===========================================================================================
# 2) 응력필드 reconstruction (inverse-normalize target01 using norm_stats [min_log, max_log, eps])
# ===========================================================================================

def reconstruct_stress_from_target(target01, norm_stats):
    """Invert the [0,1] log-normalization back to a stress field.

    stress = 10 ** (target01 * (max_log - min_log) + min_log) - eps

    Args:
        target01: (H,W) or (H,W,1) normalized field (a trailing size-1 axis is dropped).
        norm_stats: [min_log, max_log] or [min_log, max_log, eps]; eps defaults to 1e-12.

    Returns:
        float32 stress field.

    Raises:
        ValueError: if norm_stats has fewer than two values.
    """
    has_channel = target01.ndim == 3 and target01.shape[-1] == 1
    field = target01[..., 0] if has_channel else target01

    stats = np.array(norm_stats, dtype=np.float32).reshape(-1)
    if stats.size < 2:
        raise ValueError(f"norm_stats must have [min_log,max_log,(eps)]; got {stats.shape}")

    lo = float(stats[0])
    hi = float(stats[1])
    eps = 1e-12 if stats.size < 3 else float(stats[2])

    log_field = field.astype(np.float32) * (hi - lo) + lo
    return ((10.0 ** log_field) - eps).astype(np.float32)


# ============================================================
# 3) tf.data dataset 생성
#
#    ds_local: (geom, target01), ds_global (geom, mm)
#    ds_full : (geom, target01, mm)
# ============================================================

def make_tf_dataset_local(X, Y, stats, batch=8, shuffle=4096):
    """Build a batched, prefetched tf.data pipeline over (geom, target01, stats).

    A channel axis is appended to X and Y -> (N,H,W,1). Shuffling (before
    batching) only happens when ``shuffle > 1``.

    Returns:
        (ds_local, ds): ``ds_local`` yields (geom, target01) batches; ``ds`` is
        the full (geom, target01, stats) pipeline it was derived from.
    """
    tensors = (
        X[..., None].astype(np.float32),   # (N,H,W,1)
        Y[..., None].astype(np.float32),   # (N,H,W,1)
        stats.astype(np.float32),          # (N,3)
    )
    ds = tf.data.Dataset.from_tensor_slices(tensors)
    if shuffle and shuffle > 1:
        ds = ds.shuffle(int(shuffle))
    ds = ds.batch(int(batch)).prefetch(tf.data.AUTOTUNE)

    def _drop_stats(x, y, s):
        return x, y

    return ds.map(_drop_stats, num_parallel_calls=tf.data.AUTOTUNE), ds


def load_samples_local_plus_minmax(stems, root=".", target_hw=(128,128,1)):
    """Load samples for ``stems`` and keep only [min_log, max_log] from the stats.

    Returns:
        geom (N,H,W), target01 (N,H,W) and mm (N,2), all float32.
    """
    geom, target01, full_stats = load_dataset_geom_to_target(stems, root=root, target_hw=target_hw)
    minmax = full_stats[:, :2].astype(np.float32)  # drop the eps column
    return geom, target01, minmax

def make_tf_datasets(geom, target01, mm, batch=8, shuffle=4096):
    """Build the batched (geom, target01, mm) tf.data pipeline ("ds_full").

    geom and target01 gain a trailing channel axis -> (N,H,W,1); mm is only
    cast to float32. Shuffling (before batching) only happens when ``shuffle > 1``.
    """
    slices = (
        geom[..., None].astype(np.float32),
        target01[..., None].astype(np.float32),
        mm.astype(np.float32),
    )
    ds = tf.data.Dataset.from_tensor_slices(slices)
    if shuffle and shuffle > 1:
        ds = ds.shuffle(int(shuffle))
    return ds.batch(int(batch)).prefetch(tf.data.AUTOTUNE)

def split_two_model_datasets(ds_full):
    """Derive the two per-model views of a (geom, target01, mm) dataset.

    Returns:
        (ds_local, ds_global): datasets of (geom, target01) and (geom, mm).
    """
    def _local_view(x, y, mm):
        return x, y

    def _global_view(x, y, mm):
        return x, mm

    return tuple(
        ds_full.map(view, num_parallel_calls=tf.data.AUTOTUNE)
        for view in (_local_view, _global_view)
    )


# ========================
# 6) Main loop functions
# ========================

def make_splits(n, train_frac=0.8, val_frac=0.1, seed=42):
    """Randomly partition indices 0..n-1 into train/val/test index arrays.

    Split sizes are int(n*train_frac) and int(n*val_frac); the test split gets
    the remainder. The permutation is reproducible for a fixed ``seed``.
    """
    order = np.arange(n)
    np.random.default_rng(seed).shuffle(order)

    end_train = int(n * train_frac)
    end_val = end_train + int(n * val_frac)
    return order[:end_train], order[end_train:end_val], order[end_val:]

def build_two_model_datasets(
    all_ids,
    root=".",
    batch=8,
    shuffle=4096,
    seed=42,
    target_hw=(128,128,1),
):
    """Split ``all_ids``, load each split and build local/global tf.data views.

    Only the training split is shuffled (buffer ``shuffle``); val/test keep their order.

    Returns:
        dict with "ids", "local", "global" (each keyed by "train"/"val"/"test")
        and "shapes" (per-sample shapes of X, Y_local and minmax from the train split).
    """
    # (1) Map the random index split back onto the ids.
    train_ids, val_ids, test_ids = (
        [all_ids[i] for i in part] for part in make_splits(len(all_ids), seed=seed)
    )

    # (2) Load the arrays for every split (train, val, test in that order).
    loaded = {
        "train": load_samples_local_plus_minmax(train_ids, root=root, target_hw=target_hw),
        "val": load_samples_local_plus_minmax(val_ids, root=root, target_hw=target_hw),
        "test": load_samples_local_plus_minmax(test_ids, root=root, target_hw=target_hw),
    }

    # (3)+(4) Build the full pipelines and split them into local/global views.
    local_views, global_views = {}, {}
    for split, (geom, target01, mm) in loaded.items():
        buffer = shuffle if split == "train" else 1  # 1 = no shuffle
        full = make_tf_datasets(geom, target01, mm, batch=batch, shuffle=buffer)
        local_views[split], global_views[split] = split_two_model_datasets(full)

    geom_tr, target01_tr, mm_tr = loaded["train"]
    return {
        "ids": {"train": train_ids, "val": val_ids, "test": test_ids},
        "local": local_views,
        "global": global_views,
        "shapes": {"X": geom_tr.shape[1:], "Y_local": target01_tr.shape[1:], "minmax": mm_tr.shape[1:]},
    }


# 3. Model build & training

In [None]:
import os
import numpy as np
import tensorflow as tf

# ======================
# 1) Model architectures
# ======================

def _conv_block(x, filters, name):
  """Apply two 3x3 'same'-padded ReLU Conv2D layers named ``<name>_c1`` and ``<name>_c2``."""
  for idx in (1, 2):
    x = tf.keras.layers.Conv2D(filters, 3, padding="same", activation="relu", name=f"{name}_c{idx}")(x)
  return x

def _match_hw(x, ref, name):
  """Bilinearly resize ``x`` to the spatial size (H, W) of ``ref``.

  Wrapped in a Lambda layer called ``name`` so it can sit inside a functional
  model. Presumably used to absorb the 1-pixel mismatch that appears after
  upsampling when a pooling step floored an odd size.
  """
  def _resize_to_ref(pair):
    src, reference = pair
    return tf.image.resize(src, tf.shape(reference)[1:3], method="bilinear")

  resize_layer = tf.keras.layers.Lambda(_resize_to_ref, name=name)
  return resize_layer([x, ref])

def build_local_unet(input_shape, base=16):
  """Build the local U-Net that predicts the [0,1]-normalized log-stress field.

  Four encoder levels (base, 2x, 4x, 8x filters, each followed by 2x2 max
  pooling), a bottleneck with 16x base filters, and four decoder levels. Each
  decoder level upsamples, resizes to the matching encoder feature map,
  concatenates that skip connection and applies a conv block. A 1x1 sigmoid
  Conv2D produces one output channel in [0, 1] at the input's spatial size.

  NOTE(review): saved ``*.weights.h5`` files are loaded back into this exact
  graph, so layer creation order and names should not be changed casually.

  Args:
    input_shape: per-sample input shape, e.g. (H, W, 1).
    base: number of filters in the first encoder level.

  Returns:
    tf.keras.Model named "Local-U-Net".
  """
  inp = tf.keras.Input(shape=input_shape, name="geom_in")

  # Encoder: conv block, then 2x2 max-pool, at four levels.
  e1 = _conv_block(inp, base,   "enc1"); p1 = tf.keras.layers.MaxPool2D(2)(e1)
  e2 = _conv_block(p1,  base*2,   "enc2"); p2 = tf.keras.layers.MaxPool2D(2)(e2)
  e3 = _conv_block(p2,  base*4,   "enc3"); p3 = tf.keras.layers.MaxPool2D(2)(e3)
  e4 = _conv_block(p3,  base*8,   "enc4"); p4 = tf.keras.layers.MaxPool2D(2)(e4)

  b  = _conv_block(p4,  base*16, "bottleneck")

  # Decoder: 2x upsample, resize to the skip tensor's H/W (pooling floors odd sizes, so a
  # plain 2x upsample can be one pixel short), concatenate the skip, then a conv block.
  u4 = tf.keras.layers.UpSampling2D(2)(b);  u4 = _match_hw(u4, e4, "m4")
  d4 = _conv_block(tf.keras.layers.Concatenate()([u4, e4]), base*8, "dec4")

  u3 = tf.keras.layers.UpSampling2D(2)(d4); u3 = _match_hw(u3, e3, "m3")
  d3 = _conv_block(tf.keras.layers.Concatenate()([u3, e3]), base*4, "dec3")

  u2 = tf.keras.layers.UpSampling2D(2)(d3); u2 = _match_hw(u2, e2, "m2")
  d2 = _conv_block(tf.keras.layers.Concatenate()([u2, e2]), base*2, "dec2")

  u1 = tf.keras.layers.UpSampling2D(2)(d2); u1 = _match_hw(u1, e1, "m1")
  d1 = _conv_block(tf.keras.layers.Concatenate()([u1, e1]), base, "dec1")

  # 1x1 conv + sigmoid: single channel constrained to [0, 1], matching the normalized target.
  out = tf.keras.layers.Conv2D(1, 1, padding="same", activation="sigmoid", name="y_log_01")(d1)
  return tf.keras.Model(inp, out, name="Local-U-Net")


def build_global_minmax_model(input_shape, base=16, dropout=0.0):
  """Build the global model that predicts [min_log, max_log] for one geometry.

  A 4-stage conv encoder (base, 2x, 4x, 8x filters; 2x2 max pooling after the
  first three stages) is summarized by concatenated global average and global
  max pooling. The summary passes through Dense(128) -> optional Dropout ->
  Dense(64) and then splits into two heads:
    * ``min_log``: linear output;
    * ``range_log = softplus(raw_range)``: non-negative spread.
  ``max_log = min_log + range_log``, so the predicted max is never below the
  predicted min.

  NOTE(review): saved ``*.weights.h5`` files are loaded back into this exact
  graph, so layer creation order and names should not be changed casually.

  Args:
    input_shape: per-sample input shape, e.g. (H, W, 1).
    base: filters in the first conv stage.
    dropout: dropout rate after Dense(128); 0 disables the Dropout layer.

  Returns:
    tf.keras.Model named "Global-MinMax" with output shape (B, 2) = [min_log, max_log].
  """
  inp = tf.keras.Input(shape=input_shape, name="geom_in")

  x = tf.keras.layers.Conv2D(base,   3, padding="same", activation="relu")(inp)
  x = tf.keras.layers.Conv2D(base,   3, padding="same", activation="relu")(x)
  x = tf.keras.layers.MaxPool2D(2)(x)

  x = tf.keras.layers.Conv2D(base*2, 3, padding="same", activation="relu")(x)
  x = tf.keras.layers.Conv2D(base*2, 3, padding="same", activation="relu")(x)
  x = tf.keras.layers.MaxPool2D(2)(x)

  x = tf.keras.layers.Conv2D(base*4, 3, padding="same", activation="relu")(x)
  x = tf.keras.layers.Conv2D(base*4, 3, padding="same", activation="relu")(x)
  x = tf.keras.layers.MaxPool2D(2)(x)

  x = tf.keras.layers.Conv2D(base*8, 3, padding="same", activation="relu")(x)
  x = tf.keras.layers.Conv2D(base*8, 3, padding="same", activation="relu")(x)

  # Global summary: average and max pooling capture typical and extreme responses.
  gap = tf.keras.layers.GlobalAveragePooling2D()(x)
  gmp = tf.keras.layers.GlobalMaxPooling2D()(x)
  g = tf.keras.layers.Concatenate()([gap, gmp])

  g = tf.keras.layers.Dense(128, activation="relu")(g)
  if dropout and dropout > 0:
    g = tf.keras.layers.Dropout(dropout)(g)
  g = tf.keras.layers.Dense(64, activation="relu")(g)

  # Predict min and a softplus (non-negative) range instead of max directly, so max_log >= min_log.
  min_log = tf.keras.layers.Dense(1, activation="linear", name="min_log")(g)
  raw_range = tf.keras.layers.Dense(1, activation="linear", name="raw_range")(g)
  range_log = tf.keras.layers.Activation(tf.nn.softplus, name="range_log")(raw_range)
  max_log = tf.keras.layers.Add(name="max_log")([min_log, range_log])

  out = tf.keras.layers.Concatenate(name="minmax")([min_log, max_log])  # (B,2)
  return tf.keras.Model(inp, out, name="Global-MinMax")


# ==================
# 2) MSE loss function
# ==================
@tf.function
def plain_mse(y_true, y_pred):
  """Mean squared error averaged over every element (batch and spatial axes alike)."""
  squared_err = tf.square(tf.subtract(y_true, y_pred))
  return tf.reduce_mean(squared_err)


# ===============
# 3) Local trainer
# ===============

def train_local_unet(
    ds_train, ds_val, ds_test, input_shape,
    epochs=15, lr=1e-3,
    weights_dir="./model/weights/local"
):
  """Train the local U-Net on (geom, target01) batches with plain MSE.

  Uses a custom GradientTape loop with Adam. After every epoch the mean
  train/val batch loss is printed, and the weights are written to
  ``<weights_dir>/local_unet_ep<NNNN>.weights.h5``. The mean test loss is
  printed once after training.

  Returns:
    The trained Keras model.
  """
  os.makedirs(weights_dir, exist_ok=True)

  model = build_local_unet(input_shape=input_shape, base=16)
  optimizer = tf.keras.optimizers.Adam(lr)

  @tf.function
  def train_step(x, y):
    with tf.GradientTape() as tape:
      loss = plain_mse(y, model(x, training=True))
    variables = model.trainable_variables
    optimizer.apply_gradients(zip(tape.gradient(loss, variables), variables))
    return loss

  @tf.function
  def eval_step(x, y):
    return plain_mse(y, model(x, training=False))

  def _mean_loss(step_fn, dataset):
    # Average of per-batch losses over one full pass of the dataset.
    batch_losses = []
    for x, y in dataset:
      batch_losses.append(float(step_fn(x, y).numpy()))
    return np.mean(batch_losses)

  for epoch in range(1, epochs + 1):
    train_loss = _mean_loss(train_step, ds_train)
    val_loss = _mean_loss(eval_step, ds_val)
    print(f"[Local-U-Net] Ep {epoch:02d} | train={train_loss:.6f} | val={val_loss:.6f}")
    model.save_weights(os.path.join(weights_dir, f"local_unet_ep{epoch:04d}.weights.h5"))

  test_loss = _mean_loss(eval_step, ds_test)
  print(f"[Local-U-Net] TEST={test_loss:.6f}")
  return model


# ===============
# 4) Global trainer
# ===============

class _SaveGlobalWeightsEachEpoch(tf.keras.callbacks.Callback):
  """Keras callback that saves the model weights at the end of every epoch.

  Files are named ``global_minmax_ep<NNNN>.weights.h5`` with a 1-based epoch
  number and are written under ``weights_dir`` (created on construction).
  """

  def __init__(self, weights_dir="./model/weights/global"):
    super().__init__()
    self.weights_dir = weights_dir
    os.makedirs(weights_dir, exist_ok=True)

  def on_epoch_end(self, epoch, logs=None):
    # Keras passes a 0-based epoch index; file names use 1-based numbering.
    filename = "global_minmax_ep{:04d}.weights.h5".format(int(epoch) + 1)
    self.model.save_weights(os.path.join(self.weights_dir, filename))


def train_global_minmax(
    ds_train, ds_val, ds_test, input_shape,
    epochs=15, lr=1e-3,
    weights_dir="./model/weights/global"
):
  """Train the global min/max model on (geom, [min_log, max_log]) batches.

  Uses Keras ``compile``/``fit`` with Adam, MSE loss and an MAE metric.
  Weights are saved after every epoch via ``_SaveGlobalWeightsEachEpoch``,
  and the model is evaluated on ``ds_test`` after training.

  Returns:
    The trained Keras model.
  """
  os.makedirs(weights_dir, exist_ok=True)

  model = build_global_minmax_model(input_shape=input_shape, base=16, dropout=0.0)
  compile_kwargs = dict(
    optimizer=tf.keras.optimizers.Adam(lr),
    loss=tf.keras.losses.MeanSquaredError(),
    metrics=[tf.keras.metrics.MeanAbsoluteError(name="mae")],
  )
  model.compile(**compile_kwargs)

  saver = _SaveGlobalWeightsEachEpoch(weights_dir=weights_dir)
  model.fit(ds_train, validation_data=ds_val, epochs=epochs, verbose=1, callbacks=[saver])

  model.evaluate(ds_test, verbose=1)
  return model


# ==========================
# NOT currently using this
# ==========================

def train_local_and_global_from_bundles(
    bundles,
    epochs_local=15,
    epochs_global=15,
    lr=1e-3,
    weights_dir="./model/weights",
):
  """Train the local U-Net and the global min/max model from a bundles dict.

  ``bundles["local"][split]`` and ``bundles["global"][split]`` must exist for
  split in train/val/test. The input shape is read from the first local
  training batch. Both models write their per-epoch weight files into the
  same ``weights_dir``; the file-name prefixes differ, so they do not collide.

  Returns:
    (local_unet, global_model)
  """
  os.makedirs(weights_dir, exist_ok=True)

  split_order = ("train", "val", "test")
  local_sets = [bundles["local"][split] for split in split_order]
  global_sets = [bundles["global"][split] for split in split_order]

  # The models need the per-sample input shape at build time.
  first_geom, _ = next(iter(local_sets[0]))
  input_shape = tuple(first_geom.shape[1:])  # (H,W,1)

  local_unet = train_local_unet(
    *local_sets,
    input_shape=input_shape,
    epochs=epochs_local, lr=lr,
    weights_dir=weights_dir,
  )
  global_model = train_global_minmax(
    *global_sets,
    input_shape=input_shape,
    epochs=epochs_global, lr=lr,
    weights_dir=weights_dir,
  )
  return local_unet, global_model

In [None]:
import os
import numpy as np
import tensorflow as tf
from pathlib import Path

# ============================================================
# 0) Dataset 스캔: getting STEMS from files on disk
#    파일 포맷:
#      <stem>-geom.npy
#      <stem>-target.npy
#      <stem>-norm_stats.npy
# ============================================================

def list_dataset_stems(root="."):
    """Return the sorted stems that have all three dataset files in ``root``.

    A stem qualifies when ``<stem>-geom.npy``, ``<stem>-target.npy`` and
    ``<stem>-norm_stats.npy`` all exist. When none qualify, debug details are
    printed: the folder, per-suffix counts and the first 25 ``*.npy`` names.

    Raises:
        RuntimeError: if no complete stem is found.
    """
    folder = Path(root)
    suffixes = ("-geom.npy", "-target.npy", "-norm_stats.npy")
    found = {
        suffix: {path.name[:-len(suffix)] for path in folder.glob(f"*{suffix}")}
        for suffix in suffixes
    }
    geom, targ, stat = (found[suffix] for suffix in suffixes)

    stems = sorted(geom.intersection(targ, stat))
    if stems:
        return stems

    print(f"[DEBUG] folder: {folder.resolve()}")
    print(f"[DEBUG] counts: geom={len(geom)} target={len(targ)} norm_stats={len(stat)}")
    examples = sorted(path.name for path in folder.glob("*.npy"))[:25]
    print(f"[DEBUG] example files (first 25): {examples}")
    raise RuntimeError(
        f"No valid samples found in {folder.resolve()}. "
        "Need <stem>-geom.npy + <stem>-target.npy + <stem>-norm_stats.npy."
    )

# ============================================================
# 1) Stem에서 데이터 로드: (geom, target01, mm)
# ============================================================

def load_samples_geom_target_minmax( # X=geom; Y=target01; MM=[min_log, max_log]
    stems,
    root=".",
    geom_pat="{}-geom.npy",
    target_pat="{}-target.npy",
    minmax_pat="{}-norm_stats.npy",
):
    """Load and stack (geom, target01, [min_log, max_log]) for each stem.

    A stem is reported as ``Skipped <stem>: <error>`` and left out when:
      * a file is missing;
      * the geom and target shapes differ;
      * norm_stats has fewer than two values.
    Any eps entry in norm_stats is dropped.

    Returns:
        X (N,H,W), Y (N,H,W), MM (N,2), all float32.

    Raises:
        RuntimeError: if every stem was skipped.
    """
    def _load(pattern, stem):
        return np.load(os.path.join(root, pattern.format(stem))).astype(np.float32)

    def _load_sample(stem):
        geom = _load(geom_pat, stem)
        target = _load(target_pat, stem)
        stats = _load(minmax_pat, stem)
        if geom.shape != target.shape:
            raise ValueError(f"shape mismatch: geom{geom.shape} vs target{target.shape}")
        flat = np.array(stats, dtype=np.float32).reshape(-1)
        if flat.size < 2:
            raise ValueError(f"norm_stats must start with [min_log,max_log]; got shape {flat.shape}")
        return geom, target, flat[:2]

    samples = []
    for stem in stems:
        try:
            samples.append(_load_sample(stem))
        except Exception as err:
            print(f"Skipped {stem}: {err}")

    if not samples:
        raise RuntimeError("No samples loaded after filtering/skips.")

    X, Y, MM = (np.stack(column, axis=0).astype(np.float32) for column in zip(*samples))
    return X, Y, MM

# ============================================================
# 2) tf.data.Dataset 생성 후 local/global dataset으로 각각 split
# ============================================================

def make_tf_dataset_full(X, Y, MM, batch=8, shuffle=4096): # X=geom; Y=target01; MM=[min_log, max_log]
    """Build a batched, prefetched tf.data pipeline of (geom, target01, mm).

    Args:
        X, Y: (N,H,W) arrays; a channel axis is appended -> (N,H,W,1).
        MM: (N,2) [min_log, max_log].
        batch: batch size.
        shuffle: shuffle buffer size; falsy or <= 1 disables shuffling. The
            buffer is clipped to N. tf.data does not require this, but a buffer
            of N already gives a full uniform shuffle, so a larger one gains nothing.
    """
    geom = X[..., None].astype(np.float32)
    target = Y[..., None].astype(np.float32)
    minmax = MM.astype(np.float32)

    ds = tf.data.Dataset.from_tensor_slices((geom, target, minmax))
    if shuffle and int(shuffle) > 1:
        ds = ds.shuffle(min(int(shuffle), int(geom.shape[0])))
    return ds.batch(int(batch)).prefetch(tf.data.AUTOTUNE)

def split_local_global(ds_full):
    """Map a (geom, target01, mm) dataset to its (geom, target01) and (geom, mm) views."""
    autotune = tf.data.AUTOTUNE
    ds_local = ds_full.map(lambda geom, target, _mm: (geom, target), num_parallel_calls=autotune)
    ds_global = ds_full.map(lambda geom, _target, mm: (geom, mm), num_parallel_calls=autotune)
    return ds_local, ds_global

# ============================================================
# 3) Train/val/test split
# ============================================================

def split_stems(stems, train_frac=0.8, val_frac=0.1, seed=42):
    """Shuffle a copy of ``stems`` with ``seed`` and cut it into train/val/test lists.

    Sizes are int(n*train_frac) and int(n*val_frac); test gets the remainder.
    The caller's sequence is left untouched.
    """
    shuffled = list(stems)
    np.random.default_rng(seed).shuffle(shuffled)

    total = len(shuffled)
    end_train = int(total * train_frac)
    end_val = end_train + int(total * val_frac)
    return shuffled[:end_train], shuffled[end_train:end_val], shuffled[end_val:]

# ============================================================
# 4) Main train pipeline
# ============================================================

def main_train_pipeline( # X=geom; Y=target01; MM=[min_log, max_log]
    data_root=".",
    batch=8,
    shuffle=4096,
    seed=42,
    epochs_local=15,
    epochs_global=15,
    lr=1e-3,
    train_frac=0.8,
    val_frac=0.1,
    local_weights_dir="./model/weights/local",
    global_weights_dir="./model/weights/global",
):
    """Build the datasets, then train the local U-Net and the global min/max model.

    ``build_bundles_only`` handles data preparation: stem scan, reproducible
    split, array loading, tf.data construction and the input-shape probe. The
    training path and the test-only path therefore share one implementation
    and see identical splits for the same seed and fractions.

    Args:
        data_root: folder holding the ``<stem>-geom/-target/-norm_stats.npy`` files.
        batch: batch size for every split.
        shuffle: shuffle buffer for the training split (val/test are not shuffled).
        seed: RNG seed for the train/val/test split.
        epochs_local, epochs_global: number of epochs for each model.
        lr: Adam learning rate for both models.
        train_frac, val_frac: split fractions; test receives the remainder.
        local_weights_dir, global_weights_dir: folders for the per-epoch weight files
            (previously hard-coded inside this function).

    Returns:
        (local_unet, global_model, bundles): bundles has "ids", "local" and
        "global" dicts keyed by "train"/"val"/"test".
    """
    #-(1)~(5) scan, split, load, build tf.data pipelines and probe the input shape
    bundles, input_shape = build_bundles_only(
        data_root=data_root,
        batch=batch,
        shuffle=shuffle,
        seed=seed,
        train_frac=train_frac,
        val_frac=val_frac,
    )
    local_sets = bundles["local"]
    global_sets = bundles["global"]

    #-(6) Local U-Net training
    local_unet = train_local_unet(
        local_sets["train"], local_sets["val"], local_sets["test"],
        input_shape=input_shape,
        epochs=epochs_local,
        lr=lr,
        weights_dir=local_weights_dir,
    )
    #-(7) Global min/max model training
    global_model = train_global_minmax(
        global_sets["train"], global_sets["val"], global_sets["test"],
        input_shape=input_shape,
        epochs=epochs_global,
        lr=lr,
        weights_dir=global_weights_dir,
    )

    return local_unet, global_model, bundles


def build_bundles_only(
    data_root=".",
    batch=8,
    shuffle=4096,
    seed=42,
    train_frac=0.8,
    val_frac=0.1,
):
    """Prepare the train/val/test tf.data bundles without training anything.

    The function:
      1. scans ``data_root`` for complete samples;
      2. splits the stems reproducibly (``seed``, ``train_frac``, ``val_frac``);
      3. loads the arrays;
      4. builds batched local (geom, target01) and global (geom, mm) datasets,
         shuffling only the training split.

    Returns:
        (bundles, input_shape): bundles has "ids", "local" and "global" dicts
        keyed by "train"/"val"/"test"; input_shape is the per-sample shape of
        the first training batch, e.g. (H, W, 1).
    """
    stems_all = list_dataset_stems(data_root)
    print(f"Found {len(stems_all)} valid samples under: {os.path.abspath(data_root)}")

    split_names = ("train", "val", "test")
    stem_splits = dict(zip(
        split_names,
        split_stems(stems_all, train_frac=train_frac, val_frac=val_frac, seed=seed),
    ))
    stems_tr, stems_va, stems_te = (stem_splits[name] for name in split_names)
    print(f"Split: train={len(stems_tr)}, val={len(stems_va)}, test={len(stems_te)}")

    arrays = {
        name: load_samples_geom_target_minmax(stem_splits[name], root=data_root)
        for name in split_names
    }

    local_views, global_views = {}, {}
    for name, (geom, target, minmax) in arrays.items():
        buffer = shuffle if name == "train" else 1  # 1 = no shuffle for val/test
        full = make_tf_dataset_full(geom, target, minmax, batch=batch, shuffle=buffer)
        local_views[name], global_views[name] = split_local_global(full)

    # Per-sample input shape from the first training batch (needed to build the models).
    first_geom, _ = next(iter(local_views["train"]))
    input_shape = tuple(first_geom.shape[1:])
    print("Input shape:", input_shape)

    bundles = {"ids": stem_splits, "local": local_views, "global": global_views}
    return bundles, input_shape

# 3.1 Main (training entry point)

In [None]:
# ============================================================
# 5) Entry point
# ============================================================
if __name__ == "__main__":
    # Training entry point: builds the splits from ./dataset/combined and trains both models.
    # local_unet / global_model / bundles stay in the namespace for the test cells below.
    # NOTE(review): epochs are 1 here, while the setup notes at the top list epoch=50 for the
    # saved models — raise epochs_local / epochs_global to reproduce those.
    local_unet, global_model, bundles = main_train_pipeline(  # weights.h5 saved after every epoch
        data_root="./dataset/combined",
        batch=5,
        shuffle=6000,
        seed=123,
        epochs_local=1,
        epochs_global=1,
        lr=1e-3,
    )

Found 2400 valid samples under: /content/drive/.shortcut-targets-by-id/1MI7GB2cSnswvIZvKi4PyjTyRmytr3OIQ/YS_고분자응력필드_dataset/dataset/combined
Split: train=1920, val=240, test=240


# 4. Model test
### *Before testing, run all helper-function cells above (but not the training Main cell).*

In [None]:
import os
import glob
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm

# ============================================================
# (1) Dataset 유틸리티 함수들
# ============================================================

def list_dataset_stems(root="."):
    """Return stems that have geom, target and norm_stats ``.npy`` files in ``root``.

    Stems come out in the sorted order of their ``*-geom.npy`` paths.

    Raises:
        RuntimeError: if no complete stem is found.
    """
    geom_suffix = "-geom.npy"

    def _has_companions(stem):
        return all(
            os.path.exists(os.path.join(root, f"{stem}-{kind}.npy"))
            for kind in ("target", "norm_stats")
        )

    geom_paths = sorted(glob.glob(os.path.join(root, "*" + geom_suffix)))
    candidates = (os.path.basename(path)[:-len(geom_suffix)] for path in geom_paths)
    stems = [stem for stem in candidates if _has_companions(stem)]

    if not stems:
        raise RuntimeError(f"No valid stems found under: {os.path.abspath(root)}")
    return stems


def _stats_to_minmax_eps(norm_stats, default_eps=1e-12):
    s = np.array(norm_stats, dtype=np.float32).reshape(-1)
    if s.size < 2:
        raise ValueError("norm_stats must contain at least [min_log, max_log]")
    min_log = float(s[0])
    max_log = float(s[1])
    eps = float(s[2]) if s.size >= 3 else float(default_eps)
    return min_log, max_log, eps


def load_one(stem, root="."):
    """Load ``(geom, target01, norm_stats)`` for one stem as float32 arrays.

    norm_stats is returned as stored (not reshaped or padded).

    Raises:
        ValueError: if the geom and target shapes differ.
    """
    geom, target, stats = (
        np.load(os.path.join(root, f"{stem}-{kind}.npy")).astype(np.float32)
        for kind in ("geom", "target", "norm_stats")
    )
    if geom.shape != target.shape:
        raise ValueError(f"shape mismatch: geom{geom.shape} vs target{target.shape}")
    return geom, target, stats


# ============================================================
# (2) 예측한 응력필드 reconstruct / evaluate
# ============================================================

def recon_log_field(local01_hw, minmax):
    """Map a [0,1] normalized field back to log-stress: ``v * (max_log - min_log) + min_log``.

    Accepts (H,W) or (H,W,1); a trailing size-1 channel is dropped. Only the
    first two entries of ``minmax`` are used.
    """
    drop_channel = local01_hw.ndim == 3 and local01_hw.shape[-1] == 1
    field = local01_hw[..., 0] if drop_channel else local01_hw
    lo = float(minmax[0])
    span = float(minmax[1]) - lo
    return field.astype(np.float32) * span + lo


def eval_one(local_unet, global_model, geom_hw):
    """Run both models on a single (H,W) geometry.

    Returns:
        (y_pred01, mm_pred):
          * ``y_pred01``: the local model's normalized field as float32, with
            the channel axis removed;
          * ``mm_pred``: the first two global outputs [min_log, max_log] as float32.

    Raises:
        ValueError: if the global model returns fewer than two values.
    """
    batch = geom_hw[None, ..., None].astype(np.float32)  # (1,H,W,1)

    local_out = local_unet(batch, training=False).numpy()[0]
    if local_out.ndim == 3:
        y_pred01 = local_out[..., 0]
    else:
        y_pred01 = np.asarray(local_out).squeeze()

    global_out = global_model(batch, training=False).numpy()[0]
    mm_pred = np.asarray(global_out, dtype=np.float32).reshape(-1)
    if mm_pred.size < 2:
        raise ValueError("global_model must output [min_log, max_log]")

    return y_pred01.astype(np.float32), mm_pred[:2].astype(np.float32)


# ============================================================
# (3) Metrics (testing)
# ============================================================

def compute_test_metrics(local_unet, global_model, stems, root=".", max_items=None):
    """Evaluate both models on ``stems`` and summarize the per-sample MSEs.

    For each stem:
      * local MSE is the pixel-wise mean squared error between the predicted
        and ground-truth normalized fields;
      * global MSE is the mean squared error between the predicted and true
        [min_log, max_log].
    The stem and both min/max pairs are printed for every evaluated sample.

    A sample that raises (missing file, shape mismatch, model failure such as a
    ``None`` model) is reported as ``Skipped <stem>: <error>`` and skipped. It
    is no longer silently ignored, so the cause of an empty result is visible.

    Args:
        max_items: stop after this many successfully evaluated samples (None = all).

    Returns:
        dict with "n", "local_mse_mean", "local_mse_std", "global_mse_mean" and
        "global_mse_std".

    Raises:
        RuntimeError: if no sample could be evaluated.
    """
    local_mse_list = []
    global_mse_list = []
    n_done = 0

    for stem in stems:
        if max_items is not None and n_done >= int(max_items):
            break
        try:
            geom, y_true01, stats = load_one(stem, root=root)
            min_log, max_log, _ = _stats_to_minmax_eps(stats)

            y_pred01, mm_pred = eval_one(local_unet, global_model, geom)

            local_mse = float(np.mean((y_pred01 - y_true01) ** 2))
            mm_true = np.array([min_log, max_log], dtype=np.float32)
            global_mse = float(np.mean((mm_pred - mm_true) ** 2))
        except Exception as e:
            # Report and continue (best-effort over the test set), consistent with the loaders.
            print(f"Skipped {stem}: {e}")
            continue

        print(f"stem:{stem}")
        print(f"mm_pred:{mm_pred}, mm_true:{mm_true}")

        local_mse_list.append(local_mse)
        global_mse_list.append(global_mse)
        n_done += 1

    if n_done == 0:
        raise RuntimeError("No samples evaluated.")

    return {
        "n": n_done,
        "local_mse_mean": float(np.mean(local_mse_list)),
        "local_mse_std": float(np.std(local_mse_list)),
        "global_mse_mean": float(np.mean(global_mse_list)),
        "global_mse_std": float(np.std(global_mse_list)),
    }


# ============================================================
# (4) Visualization 함수들
# ============================================================

def plot_geom_local_recon(stem, geom, y_true01, y_pred01, minmax_true, minmax_pred, savepath=None):
    """Draw the geometry plus normalized and reconstructed log-stress (prediction vs ground truth).

    Layout (3x2 grid):
      row 0: input geometry, spanning both columns;
      row 1: normalized fields on a fixed [0, 1] color range (prediction | ground truth);
      row 2: log-stress reconstructed with each field's own [min_log, max_log]
             (the prediction uses the global model's min/max), on one shared color range.

    Args:
        stem: sample id, also used as the output file name.
        geom: (H,W) input geometry.
        y_true01, y_pred01: normalized fields (ground truth / prediction).
        minmax_true, minmax_pred: [min_log, max_log] used for each reconstruction.
        savepath: directory to write ``<stem>.png`` (dpi 180) into; it is created
            if needed. Nothing is saved when None.
    """
    log_true = recon_log_field(y_true01, minmax_true)
    log_pred = recon_log_field(y_pred01, minmax_pred)

    # Shared color limits so the two reconstructed panels are directly comparable.
    vminL = min(log_true.min(), log_pred.min())
    vmaxL = max(log_true.max(), log_pred.max())

    fig = plt.figure(figsize=(12, 9))
    gs = fig.add_gridspec(3, 2, hspace=0.25, wspace=0.15)

    ax0 = fig.add_subplot(gs[0, :])
    im0 = ax0.imshow(geom, cmap='binary', origin="upper")
    ax0.set_title("Geometry (input)")
    ax0.axis("off")
    plt.colorbar(im0, ax=ax0, fraction=0.02, pad=0.01)

    # (grid cell, image, vmin, vmax, title) for the four comparison panels.
    panels = [
        (gs[1, 0], y_pred01, 0, 1, "Normalized log-stress (prediction)"),
        (gs[1, 1], y_true01, 0, 1, "Normalized log-stress (groundtruth)"),
        (gs[2, 0], log_pred, vminL, vmaxL, "Reconstructed log-stress (prediction)"),
        (gs[2, 1], log_true, vminL, vmaxL, "Reconstructed log-stress (groundtruth)"),
    ]
    for cell, image, vmin, vmax, title in panels:
        ax = fig.add_subplot(cell)
        im = ax.imshow(image, vmin=vmin, vmax=vmax, origin="upper")
        ax.set_title(title)
        ax.axis("off")
        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.02)

    plt.tight_layout()

    if savepath is not None:
        os.makedirs(savepath, exist_ok=True)
        fig.savefig(os.path.join(savepath, f"{stem}.png"), dpi=180)

    plt.show()
    # Release the figure: this is called once per visualized sample (dozens per run), and
    # unreleased figures accumulate in memory (matplotlib warns past 20 open figures).
    plt.close(fig)



def visualize_examples(local_unet, global_model, stems, root=".", n=5, seed=0, figpath=None):
    """Plot up to ``n`` randomly chosen samples with ``plot_geom_local_recon``.

    The stem order is shuffled with ``seed``; a copy is shuffled, so the
    caller's sequence is untouched. A sample that fails to load, predict or
    plot is reported as ``Skipped <stem>: <error>`` and the next one is tried.

    Args:
        n: number of figures to draw; ``n <= 0`` draws nothing.
        figpath: directory where figures are saved (None = do not save).

    Raises:
        RuntimeError: if ``n > 0`` and no example could be visualized.
    """
    if n <= 0:
        return  # nothing requested (the old loop still drew one figure here)

    rng = np.random.default_rng(seed)
    stems = list(stems)
    rng.shuffle(stems)

    shown = 0
    for stem in stems:
        if shown >= n:
            break
        try:
            geom, y_true01, stats = load_one(stem, root=root)
            min_log, max_log, _ = _stats_to_minmax_eps(stats)

            y_pred01, mm_pred = eval_one(local_unet, global_model, geom)

            plot_geom_local_recon(
                stem=stem,
                geom=geom,
                y_true01=y_true01,
                y_pred01=y_pred01,
                minmax_true=[min_log, max_log],
                minmax_pred=mm_pred,
                savepath=figpath
            )
        except Exception as e:
            # Best-effort: report why this sample was skipped instead of hiding it.
            print(f"Skipped {stem}: {e}")
            continue
        shown += 1

    if shown == 0:
        raise RuntimeError("No examples visualized.")
    if shown < n:
        print(f"Only displayed {shown} examples (requested {n}).")

In [None]:
# ====================================================================
# 1) Main test pipeline
# load models if paths are given - otherwise uses the as-trained model
# ====================================================================

def main_test_only(
    data_root=".",
    local_model_path=None,
    global_model_path=None,
    stems_test=None,
    input_shape=(128,128,1),
    n_vis=10,
    seed=1,
    do_visualize=True,
    do_metrics=True,
    max_items_for_metrics=None,
    figpath=None,
    local_model=None,
    global_mm_model=None,
):
    """Evaluate and/or visualize the local and global models on test stems.

    Each model is resolved in this order:
      1. a weights path (``local_model_path`` / ``global_model_path``): a fresh
         model is built with ``input_shape`` and the weights are loaded;
      2. a model object passed in (``local_model`` / ``global_mm_model``);
      3. the "as-trained" ``local_unet`` / ``global_model`` objects left in
         this module's global namespace by the training entry point.
    Previously the fallback never happened and ``None`` models were passed on,
    which only surfaced as "No samples evaluated.".

    Args:
        data_root: folder with the ``<stem>-*.npy`` files.
        stems_test: stems to use; defaults to every complete stem under ``data_root``.
        n_vis, seed: number of samples to plot and the seed for choosing them.
        do_visualize, do_metrics: which stages to run.
        max_items_for_metrics: cap on evaluated samples (None = all).
        figpath: directory for saved figures (None = do not save).

    Returns:
        (stems_test, results): results is the metrics dict, or None when
        ``do_metrics`` is False.

    Raises:
        ValueError: if a stage is requested but a model could not be resolved.
    """
    # 1) stems
    if stems_test is None:
        stems_test = list_dataset_stems(data_root)

    # 2) models: weights path > explicit object > in-session trained model
    if local_model_path is not None:
        local_model = build_local_unet(input_shape=input_shape)
        local_model.load_weights(local_model_path)
    elif local_model is None:
        local_model = globals().get("local_unet")

    if global_model_path is not None:
        global_mm_model = build_global_minmax_model(input_shape=input_shape)
        global_mm_model.load_weights(global_model_path)
    elif global_mm_model is None:
        global_mm_model = globals().get("global_model")

    missing = [
        label for label, model in (("local", local_model), ("global", global_mm_model))
        if model is None
    ]
    if missing and (do_metrics or do_visualize):
        raise ValueError(
            f"No {' / '.join(missing)} model available: pass a weights path, pass the "
            "model object, or run the training entry point first."
        )

    # 3) metrics
    results = None
    if do_metrics:
        results = compute_test_metrics(
            local_model,
            global_mm_model,
            stems_test,
            root=data_root,
            max_items=max_items_for_metrics
        )
        print(
            f"[TEST METRICS] n={results['n']} | "
            f"local_mse={results['local_mse_mean']:.6g} ± {results['local_mse_std']:.6g} | "
            f"global_mse={results['global_mse_mean']:.6g} ± {results['global_mse_std']:.6g}"
        )

    # 4) visualization
    if do_visualize:
        visualize_examples(
            local_model,
            global_mm_model,
            stems_test,
            root=data_root,
            n=n_vis,
            seed=seed,
            figpath=figpath
        )

    return stems_test, results


# ------------------------------
# 2) Entry point
# ------------------------------
if __name__ == "__main__":
  # Reuse the split from the training entry point if it ran in this session; otherwise rebuild it
  # with the same seed (123) and fractions so the test stems match. `globals().get` avoids the
  # NameError that `bundles is None` raised when the training cell was skipped, which the
  # instructions above this section explicitly allow.
  if globals().get("bundles") is None:
    bundles, input_shape = build_bundles_only(
          data_root="./dataset/combined",
          batch=5,
          shuffle=6000,
          seed=123,
          train_frac=0.8,
          val_frac=0.1,
      )
  stems_test, metrics = main_test_only(
      data_root="./dataset/combined",
      ###-- Below: the already-trained (saved) models
      local_model_path='./model/weights(01.20.26)/local/local_unet_ep0050.weights.h5',  ##---if just using the as-trained models, put None instead
      global_model_path='./model/weights(01.20.26)/global/global_minmax_ep0050.weights.h5',
      stems_test=bundles["ids"]["test"],
      n_vis=50, #- number of test samples to visualize
      seed=1,
      do_visualize=True,
      do_metrics=True,
      max_items_for_metrics=None,
      figpath='./model/out' ##-- save directory for the stress-field (target01 / reconstructed) images
      ##-- The global model's predicted mm (=[min_log, max_log]) is printed per sample. This code does
      ##-- not write it to disk; the author saved it separately as a text file in figpath.
  )