# VAE(Variational Autoencoder) Training

- VAEモデルを単一カテゴリデータ(Normal)で学習します
- 推論には 「latent_dim」 が必要になります

### （事前）CUDA Version 確認

In [None]:
import torch, platform
print("torch version :", torch.__version__)
print("cuda in torch :", torch.version.cuda)
print("cuda available:", torch.cuda.is_available())


## 1. 初期設定

- 処理する画像サイズ 「IMAGE_SIZE」 と、検出対象とする画像カテゴリ 「CATEGORY」を指定してください
- 作成された「RUN_DIR」に、学習済モデルが作成されます

In [None]:
from pathlib import Path
from datetime import datetime, timezone, timedelta
import yaml, optuna, torch, shutil
from collections import defaultdict
from torchvision.transforms.functional import to_tensor

from torch import nn
import numpy as np 
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, f1_score
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from anomalib.data import Folder
from typing import Sequence
from collections.abc import Sequence

# -------- User settings --------
IMAGE_SIZE = 256                       # image side length (edit as needed)
IMAGE_THRESHOLD = 0.5                  # anomaly decision threshold (edit as needed)
CATEGORY = "VisA_pipe_fryum"           # image category to inspect (edit as needed)
DATA_ROOT = Path("/workspace/data")    # data folder (contains train/ and test/)
# -------------------------------

# Each run gets its own directory named after the current time in UTC+9.
JST = timezone(timedelta(hours=9))
timestamp = datetime.now(JST).strftime("%Y%m%d_%H%M%S")
OUTPUT_DIR = Path("/workspace/models", "VAE", CATEGORY)
RUN_DIR = OUTPUT_DIR.joinpath(timestamp)

# Only the checkpoint/ and pytorch/ sub-folders are created up front;
# temp/, param/ and logs/ are just path constants at this point.
for _sub in ("checkpoint", "pytorch"):
    RUN_DIR.joinpath(_sub).mkdir(parents=True, exist_ok=True)

TEMP_DIR = RUN_DIR.joinpath("temp")
PARAM_DIR = RUN_DIR.joinpath("param")
LOG_DIR = RUN_DIR.joinpath("logs")
print("RUN DIR:", RUN_DIR)

# -------- DataModule定義 --------
from torchvision import transforms as T
from torchvision.transforms import functional as Ftv

def build_datamodule(batch_size: int = 32) -> Folder:
    """Build the anomalib ``Folder`` datamodule for the configured ``CATEGORY``.

    No ``image_size`` / transform argument is passed (kept that way for API
    compatibility, per the original note).
    NOTE(review): the original comment says Folder applies its own
    Albumentations -> ToTensorV2 pipeline and yields torch tensors; confirm
    against the installed anomalib version.

    Args:
        batch_size: Used for both the train and the eval batch size.

    Returns:
        A ``Folder`` instance rooted at ``DATA_ROOT``; ``setup()`` is left to
        the caller.
    """
    lower_exts = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")
    # Accept both lower- and upper-case spellings of every extension.
    all_exts = lower_exts + tuple(ext.upper() for ext in lower_exts)

    folder_kwargs = dict(
        name=CATEGORY,
        root=DATA_ROOT,
        normal_dir=f"train/{CATEGORY}",
        # Anomalies are supplied for evaluation only (not used for training).
        abnormal_dir=f"test/{CATEGORY}/anomaly",
        normal_test_dir=f"test/{CATEGORY}/normal",
        train_batch_size=batch_size,
        eval_batch_size=batch_size,
        # To avoid shared memory, additionally pass
        # train_num_workers=0 and eval_num_workers=0.
        extensions=all_exts,
    )
    return Folder(**folder_kwargs)
    
# -------- VAEモデル定義 --------
class ConvVAE(LightningModule):
    """Convolutional variational autoencoder scored by reconstruction error.

    * Input : 3 x img_size x img_size float image in the 0-1 range
    * Latent: ``latent_dim`` (searched with Optuna elsewhere in this file)

    Args:
        latent_dim: Size of the latent vector.
        lr: Learning rate given to Adam in ``configure_optimizers``.
        img_size: Side length every input is resized to in ``_prepare_x``.
            The encoder halves the resolution three times and the decoder
            doubles it three times, so ``img_size`` should be a multiple of 8
            for the reconstruction to match the input size.
    """

    def __init__(self, latent_dim: int = 128, lr: float = 1e-3, img_size: int = 256):
        super().__init__()
        self.save_hyperparameters()

        # Encoder: three stride-2 convolutions, 3 -> 32 -> 64 -> 128 channels.
        chs = [3, 32, 64, 128]
        enc = []
        for cin, cout in zip(chs, chs[1:]):
            enc += [nn.Conv2d(cin, cout, 4, 2, 1), nn.ReLU(inplace=True)]
        self.encoder = nn.Sequential(*enc)

        # Derive the flattened feature size by probing the encoder once.
        with torch.no_grad():
            dummy = torch.zeros(1, 3, img_size, img_size)
            enc_out = self.encoder(dummy)
        c, h, w = enc_out.shape[1:]
        flat_dim = c * h * w

        self._enc_shape = (c, h, w)           # needed again when decoding

        self.mu = nn.Linear(flat_dim, latent_dim)
        self.logvar = nn.Linear(flat_dim, latent_dim)

        self.fc_dec = nn.Linear(latent_dim, flat_dim)
        dec = [
            nn.ConvTranspose2d(c, 64, 4, 2, 1), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, 4, 2, 1), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, 3, 4, 2, 1), nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*dec)
        self._img_size = img_size

    # ----- forward path -----
    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for ``x``."""
        feats = self.encoder(x).flatten(1)
        return self.mu(feats), self.logvar(feats)

    @staticmethod
    def reparameterize(mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps`` drawn from N(0, I)."""
        std = torch.exp(0.5 * logvar)
        return mu + torch.randn_like(std) * std

    def decode(self, z):
        """Map latent vectors ``z`` back to images (Sigmoid output)."""
        c, h, w = self._enc_shape
        # Separate name for the feature map so the height `h` is not shadowed.
        feats = self.fc_dec(z).view(z.size(0), c, h, w)
        return self.decoder(feats)

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for the image batch ``x``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

    # ---------- image tensor preparation ----------
    def _prepare_x(self, batch):
        """Turn a batch in one of several layouts into a model-ready tensor.

        Accepted inputs: a tensor, an ndarray, a dict (its ``"image"`` entry,
        else its first value), a list/tuple (only element 0 is used), or an
        object whose class name ends with ``ImageBatch`` / ``ImageItem``
        (presumably anomalib containers - verify against the installed
        version). Anything else is handed to torchvision's ``to_tensor``.

        Returns:
            A contiguous 4-D tensor resized to ``img_size`` x ``img_size``.
            uint8 data is scaled to 0-1; float data whose maximum exceeds 1.5
            is assumed to be 0-255 and scaled as well. The caller's tensor is
            never modified in place.
        """
        if isinstance(batch, (list, tuple)):
            img = batch[0]
        elif isinstance(batch, dict):
            img = batch.get("image", next(iter(batch.values())))
        else:
            img = batch

        cls_name = img.__class__.__name__
        if cls_name.endswith("ImageBatch"):
            img = list(img)
        elif cls_name.endswith("ImageItem"):
            if hasattr(img, "tensor"):
                img = img.tensor
            elif hasattr(img, "data"):
                img = img.data
            elif hasattr(img, "image"):
                img = img.image

        # A plain sequence of items: prepare each one and concatenate.
        if isinstance(img, Sequence) and not isinstance(img, (torch.Tensor, np.ndarray)):
            tensors = [self._prepare_x([sub]) for sub in img]
            return torch.cat(tensors, dim=0)

        if isinstance(img, torch.Tensor):
            x = img
        elif isinstance(img, np.ndarray):
            if img.dtype == object:
                x = torch.stack([to_tensor(el) for el in img], dim=0)
            else:
                x = torch.from_numpy(img)
        else:
            x = to_tensor(img)

        if x.dtype == torch.uint8:
            x = x.float() / 255
        elif x.is_floating_point() and x.max() > 1.5:
            # Out-of-place on purpose: `x` may alias the caller's batch
            # tensor, and an in-place div_ would rescale it a second time on
            # the next pass over the same data.
            x = x / 255

        if x.ndim == 3:
            x = x.unsqueeze(0)
        elif x.ndim == 4 and x.shape[-1] in (1, 3):
            # Treated as channels-last (B,H,W,C) -> (B,C,H,W).
            x = x.permute(0, 3, 1, 2).contiguous()

        if x.shape[-2:] != (self._img_size, self._img_size):
            x = F.interpolate(x, size=(self._img_size, self._img_size),
                              mode="bilinear", align_corners=False)

        return x.contiguous()

    def _vae_loss(self, x_hat, x, mu, logvar):
        """Mean-squared reconstruction error plus the KL term weighted by 0.0005."""
        recon = F.mse_loss(x_hat, x, reduction="mean")
        kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
        return recon + kl * 0.0005

    def _shared_step(self, batch, stage):
        """Run one forward pass, log ``{stage}_loss`` and return ``(loss, x, x_hat)``."""
        x = self._prepare_x(batch)
        x_hat, mu, logvar = self(x)
        loss = self._vae_loss(x_hat, x, mu, logvar)
        self.log(f"{stage}_loss", loss, prog_bar=False)
        return loss, x, x_hat

    def training_step(self, batch, _):
        """Return the VAE loss for one training batch."""
        loss, _, _ = self._shared_step(batch, "train")
        return loss

    def validation_step(self, batch, _):
        """Buffer per-image reconstruction errors and 0/1 labels for AUROC.

        Labels (0 = normal, 1 = anomaly) come from the image paths for dict
        batches (a path containing "anomaly" or "defect" counts as anomalous),
        from the second element for ``(img, label)`` pairs, and default to
        all-normal otherwise.
        """
        # --- pull out the images and labels ---------------------------
        if isinstance(batch, dict):
            img_raw = batch["image"]
            path_list = batch.get("image_path")
            if isinstance(path_list, (str, Path)):
                path_list = [path_list]

            if path_list is None:
                # No paths available: label the whole batch as normal. The
                # vector is sized to the batch so it always lines up with the
                # per-image errors below.
                y = torch.zeros(len(img_raw), device=self.device)
            else:
                y = torch.tensor(
                    [1.0 if ("anomaly" in str(p) or "defect" in str(p)) else 0.0
                     for p in path_list],
                    device=self.device,
                    dtype=torch.float32,
                )

        elif isinstance(batch, (list, tuple)) and len(batch) == 2:
            img_raw, y = batch                         # (img, label)
            y = y.float().to(self.device)

        else:                                          # no label info -> treat as normal
            img_raw = batch
            # len() gives the batch length (also for ImageBatch-like objects).
            y = torch.zeros(len(img_raw), device=self.device)

        # _shared_step only needs the images.
        loss, x, x_hat = self._shared_step(img_raw, "val")

        errs = F.mse_loss(x_hat, x, reduction="none").mean([1, 2, 3])

        # --- accumulate until the epoch ends --------------------------
        if not hasattr(self, "_val_buf"):
            self._val_buf = defaultdict(list)
        self._val_buf["errs"].append(errs.detach())
        self._val_buf["labels"].append(y.detach())

    def on_validation_epoch_end(self):
        """Compute and log ``val_AUROC`` from the buffered errors, then reset."""
        if not hasattr(self, "_val_buf") or len(self._val_buf["errs"]) == 0:
            return

        errs = torch.cat(self._val_buf["errs"])
        labels = torch.cat(self._val_buf["labels"]).view(-1)

        # --- AUROC (normal=0, anomaly=1) ------------------------------
        if labels.unique().numel() < 2:
            # AUROC is undefined with a single class; report chance level.
            auroc = torch.tensor(0.5, device=self.device)
        else:
            auroc = torch.tensor(
                roc_auc_score(labels.cpu(), errs.cpu()),
                device=self.device
            )

        self.log("val_AUROC", auroc, prog_bar=True)

        # --- debug output ---------------------------------------------
        n_norm = int((labels == 0).sum())
        n_anom = int((labels == 1).sum())
        self.print(
            f"[VAL] epoch={self.current_epoch}  normals={n_norm}  "
            f"anomalies={n_anom}  AUROC={auroc.item():.4f}"
        )

        self._val_buf.clear()

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with ``hparams.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

# ------- Tensor化 (共通関数) -------
def _to_tensor(img):
    """PIL / np.ndarray / torch.Tensor いずれも Tensor 化 (0-1)"""
    import numpy as np
    from torchvision.transforms.functional import to_tensor

    if hasattr(img, "tensor"):
        img = img.tensor
    elif hasattr(img, "image"):
        img = img.image
        
    if isinstance(img, torch.Tensor):
        return img.float() / (255 if img.dtype == torch.uint8 else 1)

    if isinstance(img, np.ndarray):
        return torch.from_numpy(img).float() / (255 if img.dtype != np.float32 else 1)

    return to_tensor(img)           # PIL など

# ------- 画像&ラベル取出 (共通関数) -------
def extract_xy(
    batch,
    device: torch.device | str = "cpu",
    *,
    return_counts: bool = False,
    return_paths:  bool = False,
):
    """
    ImageBatch / list[ImageItem] / dict のいずれが来ても以下を返す共通関数
        x : Tensor[B,3,H,W]  (float32 0–1, device 移動済み)
        y : Tensor[B]        (0:normal, 1:anomaly, float32, device 移動済み)
        [counts]             n_norm, n_anom   ※ return_counts=True の時
        [paths]              list[str|Path]   ※ return_paths =True の時

    Parameters
    ----------
    batch : dict | ImageBatch | Sequence
        DataLoader から受け取るバッチ
    device : torch.device | str
        `.to(device)` 先
    return_counts : bool, default False
        True のとき (n_norm, n_anom) を返す
    return_paths : bool, default False
        True のとき 画像パスの list を返す
    """
    # ------ anomalib >=0.8 形式の dict -------
    if isinstance(batch, dict):
        x = batch["image"]
        y = batch["label"].float()
        paths = batch.get("image_path", [""] * len(y))

    # ------ ImageBatch (list-like) または list[ImageItem] -------
    else:
        items: Sequence = list(batch)
        x = torch.stack([_to_tensor(getattr(it, "tensor", it)) for it in items])
        # ― ラベル取得 ―
        if hasattr(items[0], "label"):
            y = torch.tensor([float(it.label) for it in items])
        else:
            paths = [str(getattr(it, "image_path", "")) for it in items]
            y = torch.tensor([1.0 if ("anomaly" in p or "defect" in p) else 0.0 for p in paths])

    # ------ 後処理／戻り値組み立て -------
    x, y = x.to(device), y.to(device)
    out  = [x, y]

    if return_counts:
        n_norm = int((y == 0).sum())
        n_anom = int((y == 1).sum())
        out.extend([n_norm, n_anom])

    if return_paths:
        out.append(paths)

    return tuple(out) if len(out) > 1 else out[0]


## 2. ハイパーパラメータ探索 (Optuna)

- ここでは、ハイパーパラメータの自動探索を行います
- 探索方法は、以下より選択
  - 正常画像のみ利用した MSE の最小化(完全教師無し)
  - 正常画像と異常画像を両方利用する AUROC の最大化(教師ありプロセス)
- 探索を行わずに手動でハイパーパラメータを調整する場合は、ここをスキップして「3. 学習」へ

In [None]:
from pytorch_lightning import Trainer
import optuna, torch, yaml, types

from sklearn.metrics import roc_auc_score
import torch.nn.functional as F

# -------- GPU availability --------
GPU_OK = torch.cuda.is_available()

# -------- Whether anomaly images take part in the search --------
#   True : anomalies are used   / objective: test AUROC / no validation loop
#   False: normal images only   / objective: val MSE    / validation loop runs
EVAL_WITH_TEST_ANOM = True

# -------- Search space (edit as needed) --------
N_TRIALS = 1                      # number of Optuna trials
LATENT_OPTS = [64, 128, 256]      # candidate latent dimensions
LR_RANGE = (1e-4, 5e-3)           # learning-rate range (low, high)

def _to_tensor(item):
    for attr in ("tensor", "data", "image"):
        t = getattr(item, attr, None)
        if t is not None:
            return t
    raise AttributeError("ImageItem に tensor/data/image 属性がありません")

def _evaluate_on_test(model, dm):
    """Score every test batch and return the AUROC of the reconstruction error.

    The model is switched to eval mode; per-image MSE between input and
    reconstruction is used as the anomaly score. A one-line summary with the
    number of normal / anomalous images is printed.
    """
    model.eval()
    score_chunks = []
    label_chunks = []
    total_normal = 0
    total_anomalous = 0

    with torch.no_grad():
        for batch in dm.test_dataloader():
            images, targets, batch_normal, batch_anomalous = extract_xy(
                batch, model.device, return_counts=True
            )
            images = model._prepare_x(images)
            recon = model(images)[0]
            per_image = F.mse_loss(recon, images, reduction="none").mean([1, 2, 3])
            score_chunks.append(per_image.cpu())
            label_chunks.append(targets.cpu())
            total_normal += batch_normal
            total_anomalous += batch_anomalous

    scores = torch.cat(score_chunks).numpy()
    targets_all = torch.cat(label_chunks).numpy()
    auroc = roc_auc_score(targets_all, scores)

    print(f"[TEST] normals={total_normal}  anomalies={total_anomalous}  AUROC={auroc:.4f}")
    return auroc
    
def _evaluate_on_val(model, dm):
    """Return the mean per-image reconstruction MSE over the validation loader.

    The model is switched to eval mode. The printed image count only includes
    images labelled 0 (the validation split is expected to hold normals only).
    """
    model.eval()
    per_image_chunks = []
    normal_count = 0

    with torch.no_grad():
        for batch in dm.val_dataloader():
            images, _, batch_normal, _ = extract_xy(
                batch, model.device, return_counts=True
            )
            images = model._prepare_x(images)
            recon = model(images)[0]
            per_image = F.mse_loss(recon, images, reduction="none").mean([1, 2, 3])
            per_image_chunks.append(per_image.cpu())
            normal_count += batch_normal

    mse_mean = torch.cat(per_image_chunks).mean().item()
    print(f"[VAL] normals={normal_count}  mean-MSE={mse_mean:.6f}")
    return mse_mean

# -------- Objective --------
def objective(trial: optuna.Trial):
    """Optuna objective: train a ConvVAE for 5 epochs and return its score.

    Samples ``latent_dim`` from ``LATENT_OPTS`` and ``lr`` (log scale) from
    ``LR_RANGE``. The study maximises the returned value:

    * ``EVAL_WITH_TEST_ANOM`` true  -> test AUROC (validation loop disabled)
    * ``EVAL_WITH_TEST_ANOM`` false -> negative validation MSE
    """
    latent_dim = trial.suggest_categorical("latent_dim", LATENT_OPTS)
    lr = trial.suggest_float("lr", *LR_RANGE, log=True)

    # ---------- DataModule ----------
    dm = build_datamodule(batch_size=32)
    dm.setup()

    # ---------- Model ----------
    # Use the configured image size so the search evaluates the same
    # architecture that is trained afterwards.
    model = ConvVAE(latent_dim=latent_dim, lr=lr, img_size=IMAGE_SIZE)

    # ---------- Trainer (options shared by both modes) ----------
    trainer_kwargs = dict(
        logger=False,
        default_root_dir=TEMP_DIR / "optuna",
        accelerator="gpu" if GPU_OK else "cpu",
        max_epochs=5,
    )

    if EVAL_WITH_TEST_ANOM:
        # Scored on the test split, so the validation loop is switched off.
        trainer = Trainer(
            **trainer_kwargs,
            limit_val_batches=0,
            enable_progress_bar=False,
        )
        trainer.fit(model, train_dataloaders=dm.train_dataloader())
        return _evaluate_on_test(model, dm)       # AUROC (higher is better)

    # Scored on the validation split; the validation loop runs during fit.
    trainer = Trainer(**trainer_kwargs, enable_progress_bar=True)
    trainer.fit(
        model,
        train_dataloaders=dm.train_dataloader(),
        val_dataloaders=dm.val_dataloader(),
    )
    # _evaluate_on_val already prints the mean MSE, so it is not printed again.
    mse = _evaluate_on_val(model, dm)             # MSE (lower is better)
    return -mse                                   # maximize(-MSE) == minimize(MSE)

# -------- Run the search --------
TEMP_DIR.mkdir(parents=True, exist_ok=True)
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=N_TRIALS, show_progress_bar=False)

print("BEST :", study.best_params)
# In MSE mode the objective returns -MSE, so flip the sign back for display.
best_metric = study.best_value if EVAL_WITH_TEST_ANOM else -study.best_value
print("Best AUROC:" if EVAL_WITH_TEST_ANOM else "Best MSE:", best_metric)

# -------- Save results --------
PARAM_DIR.mkdir(parents=True, exist_ok=True)
search_space = dict(
    search_space=dict(
        latent_dim=list(LATENT_OPTS),
        # yaml.safe_dump cannot represent tuples, so store the range as a list.
        lr=list(LR_RANGE),
        eval_with_test_anom=EVAL_WITH_TEST_ANOM,
    ),
)
with open(PARAM_DIR / "search_space.yaml", "w", encoding="utf-8") as f:
    yaml.safe_dump(search_space, f)
with open(PARAM_DIR / "best_params.yaml", "w", encoding="utf-8") as f:
    yaml.safe_dump(study.best_params, f)
study.trials_dataframe().to_csv(PARAM_DIR / "trials.csv", index=False)

# ---------- Clean up TEMP_DIR ----------
# Comment this section out when debugging.
import shutil, gc
if TEMP_DIR.exists():
    # On Windows, lingering handles can make the removal fail; run GC first.
    gc.collect()
    shutil.rmtree(TEMP_DIR, ignore_errors=True)


## 3. 学習

- 「2. ハイパーパラメータ探索(Optuna)」 を行った場合は、保存されたベストパラメータで学習します
- 自動探索を行わない場合は、ハイパーパラメータ 「MANUAL_PARAMS」を手動で調整してください

In [None]:
from pytorch_lightning.callbacks import LearningRateMonitor
import numpy as np, torch, yaml
from pytorch_lightning import Trainer

# -------- GPU availability --------
GPU_OK = torch.cuda.is_available()

# -------- Training settings --------
MAX_EPOCHS = 1            # number of training epochs (★ adjust as needed)

# -------- ★ Manual parameters (used when best_params.yaml does not exist) --------
MANUAL_PARAMS = dict(
    latent_dim    = 128,
    learning_rate = 1e-3,
)

# -------- Load best_params.yaml --------
best_params_path = PARAM_DIR / "best_params.yaml"

if best_params_path.exists():
    with open(best_params_path, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    print("▶ Using best_params.yaml:", cfg)
else:
    cfg = MANUAL_PARAMS.copy()
    print("▶ Using manual params:", cfg)

# The Optuna search saves the learning rate under "lr" (the name passed to
# trial.suggest_float) while MANUAL_PARAMS uses "learning_rate"; accept both
# so a tuned value is not silently replaced by the 1e-3 default.
learning_rate = cfg.get("learning_rate", cfg.get("lr", 1e-3))

# -------- Model & DataModule --------
model = ConvVAE(
    latent_dim = cfg["latent_dim"],
    lr         = learning_rate,
    img_size   = IMAGE_SIZE,
)

dm = build_datamodule(batch_size=32)
dm.setup()

# -------- Lightning training (GPU first, CPU fallback) --------
tb_logger = TensorBoardLogger(save_dir=RUN_DIR / "logs", name="final")

used_gpu = None
for use_gpu in (GPU_OK, False):
    try:
        trainer = Trainer(
            default_root_dir    = TEMP_DIR / "train",
            accelerator         = "gpu" if use_gpu else "cpu",
            max_epochs          = MAX_EPOCHS,
            log_every_n_steps   = 1,
            enable_progress_bar = False,
            logger              = tb_logger,
            callbacks           = [LearningRateMonitor(logging_interval="step")],
        )
        # The dataloaders are passed directly instead of the datamodule;
        # per the original note this works around a prepare_data check
        # problem seen with Folder.
        trainer.fit(
            model,
            train_dataloaders = dm.train_dataloader(),
            val_dataloaders   = dm.val_dataloader(),
        )
    except RuntimeError as e:
        if "cudaGetDeviceCount" in str(e):
            print("⚠️ CUDA 初期化エラー → CPU でリトライ")
            continue
        raise
    used_gpu = use_gpu
    break
else:
    # Neither attempt finished: do not go on to save an untrained model.
    raise RuntimeError("Training did not complete on GPU or CPU")

if not used_gpu:
    print("⚠️ GPU 使用不可 → CPU で学習完了")

# -------- Save checkpoint --------
ckpt_path = RUN_DIR / "checkpoint" / "best.ckpt"
trainer.save_checkpoint(ckpt_path)
print("✓ Checkpoint :", ckpt_path)

# -------- Save inference weights (.pth) --------
state_path = RUN_DIR / "pytorch" / "model.pth"
torch.save(model.state_dict(), state_path)
print("✓ Weights :", state_path)

# ===== min/max of the raw reconstruction error on the training data =====
import torch.nn.functional as F

print("Computing image_min and image_max ...")
model.eval()
scores = []
# Inference only: no_grad avoids building autograd graphs for every batch.
with torch.no_grad():
    for batch in dm.train_dataloader():
        x = model._prepare_x(batch).to(model.device)  # (B,3,IMAGE_SIZE,IMAGE_SIZE)
        x_hat, _, _ = model(x)
        errs = F.mse_loss(x_hat, x, reduction="none").mean(dim=[1, 2, 3])
        scores.extend(errs.cpu().tolist())

image_min = float(min(scores))
image_max = float(max(scores))
print(f"Computed image_min: {image_min}, image_max: {image_max}")

# ------- Write metadata that the .pth file does not carry (json) -------
import json

meta_path = RUN_DIR / "pytorch" / "meta.json"
meta = {
    "latent_dim"    : cfg["latent_dim"],
    "learning_rate" : learning_rate,
    "image_size"    : IMAGE_SIZE,
    "weights_pth" : state_path.name,
    "image_threshold"     : IMAGE_THRESHOLD,  # set manually
    "image_threshold_auto": IMAGE_THRESHOLD,  # may be overwritten after testing
    "raw_image_min"       : image_min,
    "raw_image_max"       : image_max,
}

with open(meta_path, "w", encoding="utf-8") as f:
    json.dump(meta, f, indent=2, ensure_ascii=False)

print("✓ Meta JSON :", meta_path)

# ---------- Clean up TEMP_DIR ----------
# Comment this section out when debugging.
import shutil, gc
if TEMP_DIR.exists():
    # On Windows, lingering handles can make the removal fail; run GC first.
    gc.collect()
    shutil.rmtree(TEMP_DIR, ignore_errors=True)


## 4. 検出性能テスト

In [None]:
from pathlib import Path
import yaml, torch, pandas as pd, shutil
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, f1_score

from torchvision.utils import save_image
import torchvision.transforms.functional as Ftv
from torchvision.utils import make_grid
import cv2

# Imported here so this cell does not depend on imports made by other cells.
import json
import numpy as np

# ---------- Paths / device ----------
state_path      = Path(RUN_DIR / "pytorch" / "model.pth")
best_param_path = PARAM_DIR / "best_params.yaml"
result_root     = RUN_DIR / "test_result"
device          = "cuda" if torch.cuda.is_available() else "cpu"

# ---------- Resolve latent_dim ----------
# 1) cfg["latent_dim"]  2) best_params.yaml  3) infer from the .pth file
# `cfg` only exists when the training cell ran in this session; look it up
# defensively so the fallbacks below stay reachable.
_cfg = globals().get("cfg") or {}

if _cfg.get("latent_dim") is not None:
    latent_dim = int(_cfg["latent_dim"])
    src_info   = "cfg['latent_dim']"

elif best_param_path.exists():
    with open(best_param_path, encoding="utf-8") as f:
        _bp = yaml.safe_load(f)
    latent_dim = int(_bp["latent_dim"])
    src_info   = "best_params.yaml"

else:
    sd = torch.load(state_path, map_location="cpu")
    for k in ("mu.weight", "encoder.mu.weight", "mu.linear.weight"):
        if k in sd:
            latent_dim = sd[k].shape[0]          # out_features
            src_info   = f"model.pth ({k})"
            break
    else:
        raise RuntimeError("❌ latent_dim を特定できません")

print(f"▶ latent_dim = {latent_dim}  (from {src_info})")

# ---------- Model ----------
# img_size must match the value used for training: the Linear layer sizes
# depend on it, so a mismatch makes the strict state-dict load fail.
model = ConvVAE(latent_dim=latent_dim, img_size=IMAGE_SIZE).to(device)
model.load_state_dict(torch.load(state_path, map_location=device), strict=True)
model.eval()

# ---------- DataModule ----------
dm = build_datamodule(); dm.setup("test")

img_dir = result_root / "images"
(img_dir / "normal").mkdir(parents=True, exist_ok=True)
(img_dir / "anomaly").mkdir(parents=True, exist_ok=True)

# ---------- Pass 1: global min / max of the error maps ----------
# Only the running extremes are kept; the maps themselves are not stored.
g_min, g_max = float("inf"), float("-inf")
with torch.no_grad():
    for batch in dm.test_dataloader():
        x, _ = extract_xy(batch, device)
        x  = model._prepare_x(x)
        xh, _, _ = model(x)
        diff = (x - xh).abs().mean(1, keepdim=True)   # (B,1,H,W)
        g_min = min(g_min, diff.min().item())
        g_max = max(g_max, diff.max().item())

rng = max(g_max - g_min, 1e-6)

errs, labels, paths = [], [], []

# ---------- Pass 2: write the PNGs ----------
alpha = 0.5                     # overlay blending factor
with torch.no_grad():
    for batch in dm.test_dataloader():
        x, y, pths = extract_xy(batch, device, return_paths=True)
        x  = model._prepare_x(x)
        xh, _, _  = model(x)
        diff = (x - xh).abs().mean(1, keepdim=True)   # (B,1,H,W)

        for img_in, img_out, dmap, fname, gt in zip(
                x.cpu(), xh.cpu(), diff.cpu(), pths, y.int()):

            # ------ 0-1 normalisation (shared across the whole test set) ------
            d_norm = ((dmap - g_min) / rng).clamp(0, 1)          # (1,H,W)

            # ------ Gaussian blur (5x5) ------
            d_np = d_norm.squeeze(0).numpy()
            d_np = cv2.GaussianBlur(d_np, (5, 5), 0)

            # ------ Jet colour map -> Tensor(3,H,W) ------
            cmap = cv2.applyColorMap((d_np * 255).astype(np.uint8),
                                     cv2.COLORMAP_JET)           # BGR, H,W,3
            cmap = torch.from_numpy(cmap[:, :, ::-1].copy())      # -> RGB
            cmap = cmap.permute(2, 0, 1).float() / 255.0          # 3,H,W

            # ------ alpha blending ------
            overlay = alpha * cmap + (1 - alpha) * img_in

            # ------ side by side: Input | Overlay | Recon ------
            trio = make_grid([img_in, overlay, img_out], nrow=3)

            subdir = "anomaly" if gt.item() == 1 else "normal"
            save_image(trio, img_dir / subdir / Path(fname).name)

        # buffers for the metrics
        errs.append(F.mse_loss(xh, x, reduction="none").mean([1, 2, 3]).cpu())
        labels.append(y.cpu())
        paths.extend([Path(p).name for p in pths])

errs   = torch.cat(errs)
labels = torch.cat(labels)
auroc  = roc_auc_score(labels.numpy(), errs.numpy())
f1     = f1_score(labels.numpy(), (errs > errs.mean()).numpy())  # F1 at the mean-error threshold

# image_threshold derived automatically from the test scores
threshold = errs.mean().item()
print(f"\n── image_threshold used: {threshold:.4f}\n")

# ---------- Build the CSV rows ----------
records = []
for file, score, lab in zip(paths, errs.numpy(), labels.numpy()):
    pred = "anomaly" if score > threshold else "normal"
    gt   = "anomaly" if lab == 1 else "normal"
    print(f"{file:40s} | score={score:7.4f} | pred={pred:7s} | label={gt}")
    records.append(dict(file=file, score=float(score), pred=pred, label=gt))

# ---------- Save the CSV ----------
result_root.mkdir(parents=True, exist_ok=True)
csv_path = result_root / "predictions.csv"
pd.DataFrame(records).to_csv(csv_path, index=False)
print(f"\n✓ predictions.csv saved to {csv_path}\n")

# ---------- Update meta.json (image_threshold_auto) ----------
meta_path = Path(RUN_DIR) / "pytorch" / "meta.json"
with open(meta_path, "r", encoding="utf-8") as f:
    meta = json.load(f)
meta["image_threshold_auto"] = float(threshold)
with open(meta_path, "w", encoding="utf-8") as f:
    json.dump(meta, f, indent=2, ensure_ascii=False)
print(f"✓ threshold_auto updated to {meta_path}")

# ---------- Summary ----------
print("\n==========  EVALUATION  ==========")
print(f"Images tested : {len(dm.test_data)}")
print(f"AUROC         : {auroc:7.4f}")
print(f"Best F1       : {f1:7.4f}")
print("===================================\n")
