# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import random
import time
from tqdm import tqdm  # プログレスバー表示用

# <<< torchvision をインポート >>>
import torchvision
import torchvision.transforms as transforms


# --- 設定クラス ---
class Config:
    """Constants and hyper-parameters shared by the model, sampler and training loop."""

    # Compute device: CUDA when it is available, CPU otherwise.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- Lattice ---
    # MNIST images (28x28) are resized to LATTICE_SIZE x LATTICE_SIZE.
    LATTICE_SIZE = 32
    # 0: medium, 1: background cell (type 1), 2: digit cell (type 2)
    NUM_CELL_TYPES = 3
    # ID/type assigned to digit pixels when building MNIST targets.
    MNIST_TARGET_CELL_TYPE = 2

    # --- Cells ---
    # Upper bound on cell IDs (maximum number of cells + 1, counting medium
    # ID 0). Needed by the Neural Hamiltonian for one-hot encoding; with MNIST
    # data the actual number of cells varies.
    MAX_CELLS = 15
    # Target cell volume (V*); not necessarily tied to the MNIST targets.
    TARGET_VOLUME = 25
    # Strength of the volume constraint term (lambda_v).
    LAMBDA_VOLUME = 2.0

    # --- Contact energies J(type_a, type_b) ---
    # Example setting: background (1) and digit (2) cells each stick to their
    # own kind and repel each other.
    J = torch.tensor(
        [
            # medium, background, digit
            [0, 4, 4],  # medium row
            [4, 2, 8],  # background row (background-digit is repulsive)
            [4, 8, 2],  # digit row
        ],
        dtype=torch.float32,
        device=DEVICE,
    )

    # --- Neural Hamiltonian (NH) architecture ---
    NH_EMBED_DIM = 16
    NH_HIDDEN_DIMS = [16, 32]
    NH_KERNEL_SIZE = 3
    NH_POOL_RATES = [2, 1]
    NH_MLP_DIM = 64

    # --- Training ---
    BATCH_SIZE = 8
    LEARNING_RATE = 1e-4
    TRAINING_STEPS = 5000
    MCMC_STEPS_PER_UPDATE = 50
    MCMC_PARALLEL_FLIPS = 100
    MCMC_TEMP = 1.0
    REGULARIZATION_LAMBDA = 0.01
    PERSISTENT_CHAIN_PROB_RESET = 0.05
    # Weight of the analytical Hamiltonian in the combined energy.
    CLOSURE_WEIGHT_ANALYTICAL = 1.0
    # Weight of the neural Hamiltonian (raised to favour shape learning).
    CLOSURE_WEIGHT_NEURAL = 5.0

    # --- MNIST data ---
    MNIST_DATA_PATH = "./mnist_data"  # where the MNIST files are stored
    MNIST_THRESHOLD = 0.5  # binarisation threshold for MNIST pixels

    # --- Visualisation ---
    VISUALIZE_EVERY = 100
    NUM_VIZ_SAMPLES = 4


# Instantiate the configuration used throughout the file.
cfg = Config()

# --- MNIST dataset preparation ---
# Pre-processing: resize to the lattice size and convert to a tensor.
# Thresholding into a binary mask happens later, when targets are built.
transform = transforms.Compose(
    [
        transforms.Resize((cfg.LATTICE_SIZE, cfg.LATTICE_SIZE)),
        transforms.ToTensor(),  # image -> float tensor in [0, 1]
    ]
)

# Load the MNIST training split (downloading it if it is missing).
# The broad `except` is deliberate: any failure here (no network, bad path,
# ...) is reported and the file keeps loading; `train()` checks `train_iter`.
try:
    trainset = torchvision.datasets.MNIST(
        root=cfg.MNIST_DATA_PATH, train=True, download=True, transform=transform
    )
    # Batched, shuffled loader over the training split.
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=cfg.BATCH_SIZE, shuffle=True, num_workers=2
    )
    # Keep an iterator so batches can be pulled one at a time with `next`.
    train_iter = iter(trainloader)
    print(f"MNIST dataset loaded/downloaded successfully from {cfg.MNIST_DATA_PATH}")
except Exception as e:
    print(f"Error loading MNIST dataset: {e}")
    # No fallback data source exists in this file, so say what really happens.
    print(
        "Proceeding without MNIST data. `train` will abort because no training batches are available."
    )
    # Define both names so later references fail on an explicit None check
    # rather than with a NameError.
    trainloader = None
    train_iter = None


# --- ヘルパー関数 (変更なし) ---
def get_neighbors(lattice):
    """Return the four nearest-neighbour values of every site, with wrap-around.

    Args:
        lattice: Tensor whose last two dimensions are the lattice axes
            (callers in this file pass shape (B, H, W)).

    Returns:
        Long tensor with a new dimension of size 4 inserted at dim 1. Entry k
        is `lattice` rolled along its last two dimensions by the k-th shift in
        ((0, 1), (0, -1), (1, 0), (-1, 0)); `torch.roll` wraps around, which
        gives periodic boundaries.
    """
    # The previous version also built a circularly padded float copy that was
    # never read; the rolls below already implement the periodic neighbourhood.
    shifts = ((0, 1), (0, -1), (1, 0), (-1, 0))
    rolled = [torch.roll(lattice, shifts=shift, dims=(-2, -1)) for shift in shifts]
    return torch.stack(rolled, dim=1).long()


def get_boundary_sites(lattice, neighbor_ids):
    """Find non-medium sites that have at least one neighbour with a different ID.

    Args:
        lattice: (B, H, W) tensor of cell IDs; ID 0 is excluded from the result.
        neighbor_ids: (B, 4, H, W) neighbour IDs for every site.

    Returns:
        A tuple `(boundary_mask, boundary_indices)`: a (B, H, W) boolean mask,
        and a list with one (N_b, 2) index tensor per sample holding the
        (row, col) coordinates of that sample's boundary sites.
    """
    centre_ids = lattice.unsqueeze(1).expand(-1, 4, -1, -1)
    has_foreign_neighbour = (centre_ids != neighbor_ids).any(dim=1)
    # Medium sites (ID 0) are never reported as boundary sites.
    boundary_mask = has_foreign_neighbour & (lattice != 0)
    boundary_indices = [
        sample_mask.nonzero(as_tuple=False) for sample_mask in boundary_mask
    ]
    return boundary_mask, boundary_indices


def get_cell_volumes(lattice, max_cells):
    """Count how many lattice sites each ID occupies.

    Args:
        lattice: (B, H, W) tensor of IDs in [0, max_cells).
        max_cells: number of distinct IDs (width of the one-hot encoding).

    Returns:
        Float tensor of shape (B, max_cells); column k is the site count of ID k.
    """
    # One indicator channel per ID, kept in channel-last layout: (B, H, W, C).
    id_indicator = F.one_hot(lattice.long(), num_classes=max_cells).float()
    # Summing over the two spatial axes leaves the per-ID site counts.
    return id_indicator.sum(dim=[1, 2])


# --- 解析的ハミルトニアン (Analytical Hamiltonian) ---
# <<< MNISTターゲットの場合、細胞IDとタイプのマッピングが重要になる >>>
class AnalyticalHamiltonian(nn.Module):
    """Hand-written CPM energy: type-dependent contact term plus a volume term.

    Args:
        config: object providing `J`, `DEVICE`, `TARGET_VOLUME`,
            `LAMBDA_VOLUME`, `MAX_CELLS` and `MNIST_TARGET_CELL_TYPE`.
    """

    def __init__(self, config):
        super().__init__()
        self.cfg = config
        # Registered as a non-persistent buffer so that `module.to(...)` moves
        # the contact matrix together with the module, while `state_dict()`
        # stays unchanged (no new checkpoint key).
        self.register_buffer("J", config.J.to(config.DEVICE), persistent=False)
        self.target_volume = config.TARGET_VOLUME
        self.lambda_volume = config.LAMBDA_VOLUME
        self.max_cells = config.MAX_CELLS

    def map_ids_to_types(self, lattice):
        """Map cell IDs to cell types element-wise (works for any tensor shape).

        Simplified mapping used by this demo: ID 0 -> type 0 (medium),
        ID == `MNIST_TARGET_CELL_TYPE` -> type 2 (digit), every other non-zero
        ID -> type 1 (background). The mapping is tied to how the MNIST
        targets are generated.
        """
        types = torch.zeros_like(lattice, dtype=torch.long, device=lattice.device)
        target_id = self.cfg.MNIST_TARGET_CELL_TYPE
        types[(lattice != 0) & (lattice != target_id)] = 1  # background type
        types[lattice == target_id] = 2  # digit type
        return types.long()

    def _volume_penalty(self, cell_volumes):
        """Return lambda_v * (mean volume of present cells - V*)^2 per sample.

        Args:
            cell_volumes: (B, max_cells) site counts per ID; column 0 is the
                medium and is ignored.

        Returns:
            (B,) tensor; 0 for samples that contain no non-medium cell.
        """
        cell_only = cell_volumes[:, 1:]
        num_present = (cell_only > 0).sum(dim=1)
        # Absent IDs have volume 0, so the plain sum equals the sum over the
        # present cells; dividing by the present count gives their mean.
        mean_volume = cell_only.sum(dim=1) / num_present.clamp(min=1)
        penalty = self.lambda_volume * (mean_volume - self.target_volume) ** 2
        return torch.where(num_present > 0, penalty, torch.zeros_like(penalty))

    def forward(self, lattice):
        """Compute the analytical energy of each lattice in a (B, H, W) batch.

        Returns:
            (B,) tensor: contact energy plus volume penalty.
        """
        # 1. Contact energy: J(type_i, type_j) over neighbouring sites that
        #    carry different IDs.
        neighbor_ids = get_neighbors(lattice)  # (B, 4, H, W)
        neighbor_types = self.map_ids_to_types(neighbor_ids)
        lattice_types = self.map_ids_to_types(lattice).unsqueeze(1)
        contact_J_values = self.J[
            lattice_types.expand_as(neighbor_types), neighbor_types
        ]
        is_different_id = lattice.unsqueeze(1) != neighbor_ids
        # Each unordered neighbour pair is visited from both sides, hence / 2.
        contact_energy = (
            torch.sum(contact_J_values * is_different_id, dim=(1, 2, 3)) / 2.0
        )

        # 2. Volume term (simplified): penalise the mean cell volume of each
        #    sample rather than each cell individually.
        cell_volumes = get_cell_volumes(lattice, self.max_cells)
        volume_penalty = self._volume_penalty(cell_volumes)

        return contact_energy + volume_penalty


# --- Neural Hamiltonian コンポーネント (変更なし) ---
class NHLayer(nn.Module):
    """Two stacked Conv2d layers, each followed by a SiLU activation.

    Both convolutions use `padding = kernel_size // 2`.
    """

    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        pad = kernel_size // 2
        first_conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=pad)
        second_conv = nn.Conv2d(out_channels, out_channels, kernel_size, padding=pad)
        self.conv_block = nn.Sequential(
            first_conv,
            nn.SiLU(),
            second_conv,
            nn.SiLU(),
        )

    def forward(self, features):
        """Apply conv -> SiLU -> conv -> SiLU to `features`."""
        return self.conv_block(features)


class NeuralHamiltonian(nn.Module):
    """CNN + MLP that maps a lattice of cell IDs to one scalar energy per sample.

    Args:
        config: object providing `MAX_CELLS`, `LATTICE_SIZE`, `NH_EMBED_DIM`,
            `NH_HIDDEN_DIMS`, `NH_POOL_RATES`, `NH_KERNEL_SIZE`, `NH_MLP_DIM`.
    """

    def __init__(self, config):
        super().__init__()
        self.cfg = config
        self.max_cells = config.MAX_CELLS
        self.lattice_size = config.LATTICE_SIZE

        # 1x1 convolution that embeds the one-hot ID channels.
        self.initial_embed = nn.Conv2d(
            config.MAX_CELLS, config.NH_EMBED_DIM, kernel_size=1
        )

        # Convolutional trunk; track channels and spatial side length so the
        # MLP input width can be derived afterwards.
        stages = []
        channels = config.NH_EMBED_DIM
        side = config.LATTICE_SIZE
        for hidden_dim, pool_rate in zip(config.NH_HIDDEN_DIMS, config.NH_POOL_RATES):
            stages.append(NHLayer(channels, hidden_dim, config.NH_KERNEL_SIZE))
            channels = hidden_dim
            if pool_rate > 1:
                stages.append(nn.MaxPool2d(pool_rate))
                side //= pool_rate
        self.nh_network = nn.Sequential(*stages)

        flattened_dim = channels * side * side
        self.mlp = nn.Sequential(
            nn.Linear(flattened_dim, config.NH_MLP_DIM),
            nn.SiLU(),
            nn.Linear(config.NH_MLP_DIM, 1),
        )

    def forward(self, lattice):
        """Return a (B,) energy tensor for a (B, H, W) lattice batch."""
        num_samples = lattice.shape[0]
        # IDs outside [0, max_cells - 1] are clamped into range before encoding.
        ids = lattice.long().clamp(0, self.max_cells - 1)
        one_hot = (
            F.one_hot(ids, num_classes=self.max_cells).permute(0, 3, 1, 2).float()
        )
        features = self.nh_network(self.initial_embed(one_hot))
        # reshape (not view): the trunk output need not be contiguous.
        flat_features = features.reshape(num_samples, -1)
        return self.mlp(flat_features).squeeze(-1)


# --- 統合モデル (Closure) (変更なし) ---
class NeuralCPM(nn.Module):
    """Combined energy model: weighted sum of analytical and neural Hamiltonians.

    Args:
        config: object forwarded to both Hamiltonians; also provides
            `CLOSURE_WEIGHT_ANALYTICAL` and `CLOSURE_WEIGHT_NEURAL`.
    """

    def __init__(self, config):
        super().__init__()
        self.cfg = config
        self.analytical_h = AnalyticalHamiltonian(config)
        self.neural_h = NeuralHamiltonian(config)
        # Mixing weights of the two energy terms.
        self.w_s = config.CLOSURE_WEIGHT_ANALYTICAL
        self.w_nn = config.CLOSURE_WEIGHT_NEURAL

    def forward(self, lattice):
        """Return `w_s * E_analytical + w_nn * E_neural` for each sample."""
        analytical_term = self.w_s * self.analytical_h(lattice)
        neural_term = self.w_nn * self.neural_h(lattice)
        return analytical_term + neural_term


# --- 近似サンプラー (ApproxPCPM - 論文 Algorithm 2) (変更なし) ---
def approx_pcpm_sampler(model, initial_states, num_steps, num_parallel_flips, temp):
    """Run an approximate, parallel-proposal Metropolis sampler on a lattice batch.

    In every step, for each sample that has boundary sites,
    `num_parallel_flips` boundary sites are drawn (with replacement) and each
    proposes to copy the ID of a random neighbour that differs from it. Every
    proposal is scored independently against the sample's current energy and
    accepted with probability `min(1, exp(-dE / temp))`; all accepted flips
    are then written back together.

    Args:
        model: callable mapping a (N, H, W) lattice batch to (N,) energies.
        initial_states: (B, H, W) integer tensor of cell IDs; not modified.
        num_steps: number of sampler steps.
        num_parallel_flips: proposals per sample and step.
        temp: temperature dividing the energy difference.

    Returns:
        Detached (B, H, W) tensor with the final states.
    """
    current_states = initial_states.clone()
    # Create index/random tensors on the states' own device rather than on a
    # global default.
    device = current_states.device
    batch_size = current_states.shape[0]

    for _ in range(num_steps):
        with torch.no_grad():
            current_energy = model(current_states)
        neighbor_ids = get_neighbors(current_states)
        _, boundary_indices_list = get_boundary_sites(current_states, neighbor_ids)

        # All per-sample lists below are appended together inside one loop, so
        # entry k of each list always refers to the same sample. Samples
        # without boundary sites are skipped everywhere at once.
        flip_sites = []
        flip_new_ids = []
        flip_batch_idx = []
        proposed_states = []
        original_energies = []
        for b in range(batch_size):
            boundary_idx_b = boundary_indices_list[b]
            num_boundary = boundary_idx_b.shape[0]
            if num_boundary == 0:
                continue

            picks = torch.randint(
                0, num_boundary, (num_parallel_flips,), device=device
            )
            sites = boundary_idx_b[picks]
            y_coords, x_coords = sites[:, 0], sites[:, 1]
            site_ids = current_states[b, y_coords, x_coords]  # (P,)
            neigh_at_sites = neighbor_ids[b][:, y_coords, x_coords]  # (4, P)
            new_ids = _pick_random_differing_neighbor(site_ids, neigh_at_sites).to(
                current_states.dtype
            )

            # One copy of the current lattice per proposal, each with exactly
            # one site changed.
            states_b = current_states[b].unsqueeze(0).repeat(num_parallel_flips, 1, 1)
            proposal_idx = torch.arange(num_parallel_flips, device=device)
            states_b[proposal_idx, y_coords, x_coords] = new_ids

            flip_sites.append(sites)
            flip_new_ids.append(new_ids)
            flip_batch_idx.append(
                torch.full((num_parallel_flips,), b, dtype=torch.long, device=device)
            )
            proposed_states.append(states_b)
            original_energies.append(current_energy[b].repeat(num_parallel_flips))

        if not proposed_states:
            continue

        all_sites = torch.cat(flip_sites, dim=0)
        all_new_ids = torch.cat(flip_new_ids, dim=0)
        all_batch_idx = torch.cat(flip_batch_idx, dim=0)
        all_original_energies = torch.cat(original_energies, dim=0)
        with torch.no_grad():
            all_proposed_energies = model(torch.cat(proposed_states, dim=0))

        # Metropolis acceptance, evaluated independently per proposal.
        delta_energies = all_proposed_energies - all_original_energies
        accept_prob = torch.exp(-delta_energies / temp).clamp(max=1.0)
        accepted = torch.rand_like(accept_prob) < accept_prob
        if accepted.any():
            accepted_sites = all_sites[accepted]
            current_states[
                all_batch_idx[accepted], accepted_sites[:, 0], accepted_sites[:, 1]
            ] = all_new_ids[accepted]

    return current_states.detach()


def _pick_random_differing_neighbor(site_ids, neighbor_ids_at_sites):
    """For each site pick, uniformly at random, a neighbour ID that differs from it.

    Args:
        site_ids: (P,) current IDs at the chosen sites.
        neighbor_ids_at_sites: (4, P) neighbour IDs of those sites.

    Returns:
        (P,) long tensor of proposed IDs; a site with no differing neighbour
        keeps its own ID.
    """
    own_ids = site_ids.long()
    candidates = neighbor_ids_at_sites.long()
    differs = candidates != own_ids.unsqueeze(0)
    # Arg-max over i.i.d. uniform scores restricted to the valid neighbours is
    # a uniform draw among them; invalid entries get a score no valid one has.
    scores = torch.rand(differs.shape, device=differs.device).masked_fill(
        ~differs, -1.0
    )
    choice = scores.argmax(dim=0)
    picked = candidates.gather(0, choice.unsqueeze(0)).squeeze(0)
    return torch.where(differs.any(dim=0), picked, own_ids)


# --- データ生成/取得関数 ---
def generate_initial_config(batch_size, lattice_size, num_cells, device=None):
    """Create random initial lattices with single-site cells seeded inside a disc.

    Sites inside the disc of radius `lattice_size // 4` around the lattice
    centre are visited in row-major order; each becomes a new cell with
    probability 0.1 until the `num_cells` IDs (1..num_cells, shuffled per
    sample) are used up. Everything else stays medium (0).

    Args:
        batch_size: number of lattices to create.
        lattice_size: side length of each square lattice.
        num_cells: number of distinct non-zero IDs available per lattice.
        device: device for the result; defaults to `cfg.DEVICE`.

    Returns:
        Long tensor of shape (batch_size, lattice_size, lattice_size) holding
        cell IDs only (no type information).
    """
    if device is None:
        device = cfg.DEVICE
    lattices = torch.zeros(
        batch_size, lattice_size, lattice_size, dtype=torch.long, device=device
    )
    center = lattice_size // 2
    radius_sq = (lattice_size // 4) ** 2
    # The set of disc sites is the same for every sample, so compute it once.
    disc_sites = [
        (y, x)
        for y in range(lattice_size)
        for x in range(lattice_size)
        if (y - center) ** 2 + (x - center) ** 2 < radius_sq
    ]
    for b in range(batch_size):
        # MCMC samples can carry many different IDs, so vary them here too.
        cell_ids = list(range(1, num_cells + 1))
        random.shuffle(cell_ids)
        next_id = 0
        for y, x in disc_sites:
            # The random draw happens for every disc site, even once the IDs
            # are exhausted, so the RNG stream does not depend on exhaustion.
            if random.random() < 0.1 and next_id < len(cell_ids):
                lattices[b, y, x] = cell_ids[next_id]
                next_id += 1
    return lattices


def create_target_lattice_from_mnist(images, threshold, target_cell_type):
    """Turn a batch of MNIST images into target lattice states.

    Args:
        images (torch.Tensor): image batch of shape (B, 1, H, W), values in [0, 1].
        threshold (float): pixels strictly above this value count as digit.
        target_cell_type (int): ID written at digit pixels.

    Returns:
        torch.Tensor: long tensor of shape (B, H, W) on the same device as
        `images`; digit pixels hold `target_cell_type`, all other pixels hold
        1 (background). No medium (0) sites are produced.
    """
    num_images, _, height, width = images.shape
    # Binarise and drop the channel axis: True marks digit pixels, (B, H, W).
    is_digit = images.gt(threshold).squeeze(1)

    # Start from an all-background lattice, then stamp the digit pixels.
    target_lattice = torch.ones(
        num_images, height, width, dtype=torch.long, device=images.device
    )
    target_lattice.masked_fill_(is_digit, target_cell_type)
    return target_lattice


def get_training_batch(current_iter, dataloader):
    """Fetch the next MNIST batch and convert it into target lattices.

    When `current_iter` is exhausted, a fresh iterator over `dataloader` is
    created, stored in the module-level `train_iter`, and used for this batch.

    Args:
        current_iter: iterator yielding `(images, labels)` pairs.
        dataloader: iterable used to restart iteration.

    Returns:
        Target lattice tensor built by `create_target_lattice_from_mnist`.
    """
    global train_iter  # rebound below when the iterator is restarted

    exhausted = object()
    batch = next(current_iter, exhausted)
    if batch is exhausted:
        # End of the data: start over with a new iterator.
        print("DataLoader reached end, restarting iterator.")
        train_iter = iter(dataloader)
        batch = next(train_iter)
    images, _ = batch  # labels are not used

    images = images.to(cfg.DEVICE)
    return create_target_lattice_from_mnist(
        images, cfg.MNIST_THRESHOLD, cfg.MNIST_TARGET_CELL_TYPE
    )


# --- 学習ループ ---
def train():
    """Train the NeuralCPM energy model with persistent-chain contrastive learning.

    Each step takes MNIST-derived target lattices as positive samples, advances
    persistent MCMC chains to obtain negative samples, and minimises
    `E(pos) - E(neg)` plus an L2 penalty on both energies. Every
    `cfg.VISUALIZE_EVERY` steps the targets and chains are plotted; the loss
    curve is shown at the end. Returns early if MNIST could not be loaded.
    """
    global train_iter  # rebound by get_training_batch when the data restarts
    if train_iter is None:
        print("MNIST DataLoader not available. Training cannot proceed.")
        return

    print(f"使用デバイス: {cfg.DEVICE}")
    model = NeuralCPM(cfg).to(cfg.DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=cfg.LEARNING_RATE)

    # Persistent MCMC chains start from random configurations.
    mcmc_chains = generate_initial_config(
        cfg.BATCH_SIZE, cfg.LATTICE_SIZE, cfg.MAX_CELLS - 1
    )

    losses = []
    # Two rows: targets on top, MCMC samples below. squeeze=False keeps `axes`
    # two-dimensional even for a single column.
    fig, axes = plt.subplots(
        2,
        cfg.NUM_VIZ_SAMPLES,
        figsize=(cfg.NUM_VIZ_SAMPLES * 3, 6),
        squeeze=False,
    )
    plt.ion()
    # `matplotlib.cm.get_cmap` is deprecated/removed in recent Matplotlib;
    # `pyplot.get_cmap` accepts the same (name, lut) arguments. Built once.
    cmap_pos = plt.get_cmap("gray", 3)  # 0: medium (unused), 1: background, 2: digit
    cmap_mcmc = plt.get_cmap("tab20", cfg.MAX_CELLS)

    start_time = time.time()
    for step in tqdm(range(cfg.TRAINING_STEPS), desc="Training Steps"):
        # 1. Positive samples (built from MNIST).
        x_pos = get_training_batch(train_iter, trainloader)

        # 2. Negative samples (MCMC continued from the persistent chains).
        x_neg = approx_pcpm_sampler(
            model,
            mcmc_chains,
            cfg.MCMC_STEPS_PER_UPDATE,
            cfg.MCMC_PARALLEL_FLIPS,
            cfg.MCMC_TEMP,
        )
        mcmc_chains = x_neg

        # Optional chain reset: restart a random subset of chains.
        reset_mask = (
            torch.rand(cfg.BATCH_SIZE, device=cfg.DEVICE)
            < cfg.PERSISTENT_CHAIN_PROB_RESET
        )
        num_reset = reset_mask.sum().item()
        if num_reset > 0:
            new_chains = generate_initial_config(
                num_reset, cfg.LATTICE_SIZE, cfg.MAX_CELLS - 1
            )
            # Reset a copy: writing into `x_neg` itself would replace this
            # step's negative samples with un-sampled random lattices before
            # their energy is computed.
            mcmc_chains = x_neg.clone()
            mcmc_chains[reset_mask] = new_chains

        # 3. Energies.
        e_pos = model(x_pos)
        e_neg = model(x_neg.detach())

        # 4. Loss: contrastive term plus energy-magnitude regulariser.
        loss_mle = e_pos.mean() - e_neg.mean()
        loss_reg = cfg.REGULARIZATION_LAMBDA * (e_pos**2 + e_neg**2).mean()
        loss = loss_mle + loss_reg

        # 5. Backpropagation and optimisation.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        # --- Visualisation ---
        if step % cfg.VISUALIZE_EVERY == 0 or step == cfg.TRAINING_STEPS - 1:
            print(
                f"\nステップ: {step}, 損失: {loss.item():.4f}, E_pos 平均: {e_pos.mean().item():.4f}, E_neg 平均: {e_neg.mean().item():.4f}"
            )

            viz_samples_neg = mcmc_chains[: cfg.NUM_VIZ_SAMPLES].cpu().numpy()
            viz_samples_pos = x_pos[: cfg.NUM_VIZ_SAMPLES].cpu().numpy()
            # Never index past the samples that actually exist.
            num_shown = min(
                cfg.NUM_VIZ_SAMPLES, len(viz_samples_pos), len(viz_samples_neg)
            )

            for i in range(num_shown):
                # Top row: positive sample (target).
                ax_pos = axes[0, i]
                ax_pos.clear()
                ax_pos.imshow(
                    viz_samples_pos[i],
                    cmap=cmap_pos,
                    vmin=0,
                    vmax=cfg.NUM_CELL_TYPES - 1,
                    interpolation="nearest",
                )
                ax_pos.set_title(f"Target {i}")
                ax_pos.axis("off")

                # Bottom row: negative sample (MCMC chain).
                ax_neg = axes[1, i]
                ax_neg.clear()
                ax_neg.imshow(
                    viz_samples_neg[i],
                    cmap=cmap_mcmc,
                    vmin=0,
                    vmax=cfg.MAX_CELLS - 1,
                    interpolation="nearest",
                )
                ax_neg.set_title(f"MCMC {i}")
                ax_neg.axis("off")

            fig.suptitle(f"ターゲット vs MCMCサンプル @ ステップ {step}")
            plt.tight_layout(rect=[0, 0.03, 1, 0.95])
            plt.draw()
            plt.pause(0.1)

    end_time = time.time()
    print(f"\n学習終了 ({end_time - start_time:.2f}秒)")

    # Final loss curve.
    plt.ioff()
    plt.figure()
    plt.plot(losses)
    plt.title("学習損失の推移")
    plt.xlabel("学習ステップ")
    plt.ylabel("損失")
    plt.grid(True)
    plt.show()


# --- Main entry point ---
if __name__ == "__main__":
    # Start training only when the file is executed as a script.
    train()