In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

# 🔹 1. 実波形データセットの定義
class WaveformDataset(Dataset):
    """Thin ``Dataset`` wrapper that serves one waveform sample per index.

    Args:
        waveform_data: Container indexed along its first axis and exposing
            ``.shape``. The original inline note gives the layout as
            [N, 5, 1000, 7]; nothing in this class enforces that shape.
    """

    def __init__(self, waveform_data):
        # Stored by reference: no copy or conversion is made here.
        self.data = waveform_data

    def __len__(self):
        """Return the number of samples, i.e. the size of the leading axis."""
        num_samples = self.data.shape[0]
        return num_samples

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``."""
        sample = self.data[idx]
        return sample

In [None]:
# =======================================
# ✅ 正しい構成案：Forward PINN + Inverse CNN
# =======================================

import torch
import torch.nn as nn
import torch.nn.functional as F

# 1. Velocity Map Decoder (CNN) : 波形 → 速度マップ
class InverseVelocityDecoder(nn.Module):
    """CNN that decodes a waveform tensor into a 1-channel 70x70 velocity map.

    Pipeline: two 3-D convolutions with ReLU, adaptive average pooling down to
    a (1, 10, 10) volume, then two fully connected layers whose output is
    reshaped to [B, 1, 70, 70].
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are deliberately kept stable so
        # that state_dict keys and seeded initialisation do not change.
        self.conv1 = nn.Conv3d(
            in_channels=1, out_channels=16, kernel_size=(3, 5, 3), padding=1
        )
        self.conv2 = nn.Conv3d(
            in_channels=16, out_channels=32, kernel_size=3, padding=1
        )
        self.pool = nn.AdaptiveAvgPool3d(output_size=(1, 10, 10))
        self.fc = nn.Linear(in_features=32 * 10 * 10, out_features=1024)
        self.out = nn.Linear(in_features=1024, out_features=70 * 70)

    def forward(self, x):
        """Run the decoder.

        Args:
            x: Waveform batch; the original inline note gives [B, 5, 1000, 7].

        Returns:
            Tensor of shape [B, 1, 70, 70].
        """
        # Insert the single input channel expected by Conv3d: [B, 1, ...].
        volume = x[:, None]
        # conv1 uses kernel width 5 with padding 1 on its middle axis, so that
        # axis shrinks by 2; conv2 preserves the spatial size.
        for conv in (self.conv1, self.conv2):
            volume = F.relu(conv(volume))
        pooled = self.pool(volume)  # [B, 32, 1, 10, 10]
        flat = pooled.view(pooled.size(0), -1)
        hidden = F.relu(self.fc(flat))
        velocity = self.out(hidden)
        return velocity.view(-1, 1, 70, 70)

# 2. WaveNet: Forward PINN モデル u(x,z,t)
class WaveNet(nn.Module):
    """Fully connected network used as the forward PINN surrogate u(x, z, t).

    Args:
        in_dim: Number of input features per point (default 3).
        hidden: Width of the two hidden layers (default 128).
    """

    def __init__(self, in_dim=3, hidden=128):
        super().__init__()
        # Two tanh-activated hidden layers followed by a scalar linear head.
        # Layers are created in the same order as a literal Sequential so the
        # indices net.0 .. net.4 and seeded initialisation are unchanged.
        widths = (in_dim, hidden, hidden)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.Tanh())
        layers.append(nn.Linear(hidden, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, xzt):
        """Evaluate the network on ``xzt`` ([B, in_dim]) and return [B, 1]."""
        wavefield = self.net(xzt)
        return wavefield

# 3. VelocityMapFn: V(x,z) を連続関数に変換
class VelocityMapFn(nn.Module):
    """Expose a gridded velocity map as a continuous function V(x, z).

    The wrapped ``model`` turns a waveform batch into a velocity grid of shape
    [B, 1, H, W]; this module bilinearly interpolates that grid at arbitrary
    normalised coordinates.

    Args:
        model: Callable mapping ``waveform`` to a [B, 1, H, W] tensor.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, waveform, coords):
        """Sample the predicted velocity map at ``coords``.

        Args:
            waveform: Input forwarded unchanged to ``self.model``.
            coords: [P, 2] tensor of (x, z) positions in [-1, 1], where -1 and
                +1 are the centres of the first and last grid cells
                (``align_corners=True``). x indexes the last grid axis and z
                the second-to-last one.

        Returns:
            [P, 1] tensor of interpolated velocities.

        Batch pairing:
            * P == B: point ``i`` is evaluated on velocity map ``i``.
            * B == 1: every point is evaluated on the single map.
            Any other combination is rejected by ``grid_sample``.
        """
        v_grid = self.model(waveform)  # [B, 1, H, W]

        # grid_sample already works in the [-1, 1] normalised frame, so the
        # coordinates are used as-is. The previous (coords + 1) / 2 remapping
        # squeezed them into [0, 1] and only ever sampled one quadrant.
        coords = coords.to(dtype=v_grid.dtype)

        if v_grid.size(0) == 1 and coords.size(0) != 1:
            # One map, many points: lay the points out along the output
            # height axis -> grid [1, P, 1, 2].
            grid = coords.unsqueeze(0).unsqueeze(2)
        else:
            # One point per map -> grid [P, 1, 1, 2].
            grid = coords.unsqueeze(1).unsqueeze(1)

        sampled = F.grid_sample(v_grid, grid, align_corners=True)
        return sampled.reshape(-1, 1)

# 4. PDE残差計算（Forward PINN Loss）
def compute_pde_residual(u_model, v_fn, waveform, xzt):
    """Residual of the 2-D wave equation ``u_tt - v^2 (u_xx + u_zz)``.

    Args:
        u_model: Callable mapping ``xzt`` ([P, 3]) to the wavefield u ([P, 1]).
        v_fn: Callable ``v_fn(waveform, xz)`` returning the velocity at the
            (x, z) columns of ``xzt``; its result must broadcast against [P, 1].
        waveform: Passed through unchanged to ``v_fn``.
        xzt: [P, 3] tensor whose columns are read as (x, z, t).

    Returns:
        [P, 1] tensor with the point-wise PDE residual.

    Notes:
        * ``xzt`` is flagged ``requires_grad`` in place (side effect on the
          caller's tensor), as in the original implementation.
        * The derivatives are always computed with autograd enabled, so the
          function also works when the caller is inside ``torch.no_grad()``
          (e.g. a validation loop). The velocity lookup and the final
          combination still run in the caller's grad mode.
    """
    xzt.requires_grad_(True)

    def _grad_wrt_inputs(outputs):
        # d(outputs)/d(xzt); create_graph keeps the result differentiable so
        # higher-order derivatives and a later backward() remain possible.
        return torch.autograd.grad(
            outputs, xzt,
            grad_outputs=torch.ones_like(outputs),
            create_graph=True,
        )[0]

    # Force graph recording even under an outer no_grad(): autograd.grad
    # cannot differentiate tensors that were produced without a graph.
    with torch.enable_grad():
        u = u_model(xzt)
        first = _grad_wrt_inputs(u)
        u_x, u_z, u_t = first[:, 0:1], first[:, 1:2], first[:, 2:3]
        # Keep only the pure second derivatives (diagonal terms).
        u_xx = _grad_wrt_inputs(u_x)[:, 0:1]
        u_zz = _grad_wrt_inputs(u_z)[:, 1:2]
        u_tt = _grad_wrt_inputs(u_t)[:, 2:3]

    v = v_fn(waveform, xzt[:, :2])
    residual = u_tt - v**2 * (u_xx + u_zz)
    return residual

# ---------------------------------------
# 推論：CNN (波形 → V) に対し、Forward PINNで u(x,z,t) を再構成
# → PDE損失（物理整合）を通じてCNNを間接的に訓練する
# ---------------------------------------


In [None]:
import torch
import torch.nn.functional as F

# 🔹 座標サンプリング関数（x, z, t） ∈ [-1, 1]^3
def sample_coords(batch_size):
    """Draw ``batch_size`` random (x, z, t) points, uniform per component.

    Returns:
        Tensor of shape (batch_size, 3) on the default device, obtained by
        rescaling ``torch.rand`` samples from [0, 1) to [-1, 1).
    """
    unit_samples = torch.rand(batch_size, 3)
    # Affine map u -> 2u - 1.
    return unit_samples.mul(2.0).sub(1.0)

# # 🔹 学習ループ（Forward PINN + Inverse CNN）
# def train_forward_pinn(
#     waveform_encoder, velocity_map_fn, wave_model,
#     train_loader, optimizer, device,
#     epochs=100, best_model_path="best_model.pth"
# ):
#     best_loss = float('inf')
#     for epoch in range(epochs):
#         waveform_encoder.train()
#         velocity_map_fn.train()
#         wave_model.train()
#         total_loss = 0.0

#         for waveform_batch in train_loader:  # waveform_batch: [B, 5, 1000, 7]
#             waveform_batch = waveform_batch.to(device)
#             batch_size = waveform_batch.size(0)
#             xzt = sample_coords(batch_size=512).to(device)  # 固定サンプル数

#             # PDE損失を計算（Forward PINN）
#             residual = compute_pde_residual(wave_model, velocity_map_fn, waveform_batch, xzt)
#             loss_pde = (residual ** 2).mean()

#             optimizer.zero_grad()
#             loss_pde.backward()
#             optimizer.step()

#             total_loss += loss_pde.item()

#         avg_loss = total_loss / len(train_loader)
#         print(f"Epoch {epoch+1}/{epochs} | PDE Loss: {avg_loss:.6f}")

#         if avg_loss < best_loss:
#             best_loss = avg_loss
#             torch.save(waveform_encoder.state_dict(), best_model_path)
#             print(f"✅ Best model saved at epoch {epoch+1} with loss {best_loss:.6f}")

#     print("Training completed.")


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

# 🔹 2. 学習 + 検証ループ（val_loader付き）
def _as_waveform(batch):
    """Return the waveform tensor of a loader batch (first item of a tuple/list)."""
    if isinstance(batch, (list, tuple)):
        return batch[0]
    return batch


def train_forward_pinn_with_val(
    waveform_encoder, velocity_map_fn, wave_model,
    train_loader, val_loader, optimizer, device,
    epochs=100, best_model_path="best_forward_pinn_model.pth"
):
    """Train with the PDE-residual loss and validate after every epoch.

    Args:
        waveform_encoder: Module switched between train/eval mode here and
            whose ``state_dict`` is checkpointed.
        velocity_map_fn: Callable ``(waveform, xz) -> velocity`` handed to
            ``compute_pde_residual``.
        wave_model: Wavefield network handed to ``compute_pde_residual``.
        train_loader: Iterable of waveform batches (a bare tensor, or a
            tuple/list whose first element is the waveform tensor).
        val_loader: Same format as ``train_loader``.
        optimizer: Optimizer stepped once per training batch.
        device: Device the batches and sampled coordinates are moved to.
        epochs: Number of passes over ``train_loader``.
        best_model_path: File that receives the encoder weights whenever the
            mean validation loss improves.

    Returns:
        ``(train_losses, val_losses)``: per-epoch mean losses as float lists.
    """
    best_loss = float('inf')
    train_losses = []
    val_losses = []

    for epoch in range(epochs):
        waveform_encoder.train()
        wave_model.train()
        total_loss = 0.0

        for batch in train_loader:
            waveform_batch = _as_waveform(batch).to(device)
            # One collocation point per waveform: velocity_map_fn pairs
            # coords[i] with the i-th predicted velocity map, and grid_sample
            # rejects mismatched batch sizes (a fixed 512 only worked when
            # the loader batch size happened to be 512).
            xzt = sample_coords(batch_size=waveform_batch.size(0)).to(device)

            residual = compute_pde_residual(wave_model, velocity_map_fn, waveform_batch, xzt)
            loss_pde = (residual ** 2).mean()

            optimizer.zero_grad()
            loss_pde.backward()
            optimizer.step()

            total_loss += loss_pde.item()

        avg_train_loss = total_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        # --- Validation loop ---
        waveform_encoder.eval()
        wave_model.eval()
        val_loss = 0.0
        for batch in val_loader:
            waveform_batch = _as_waveform(batch).to(device)
            xzt = sample_coords(batch_size=waveform_batch.size(0)).to(device)
            # The PDE residual needs autograd for u_xx / u_zz / u_tt, so it
            # cannot be evaluated under torch.no_grad(). No backward() or
            # optimizer step happens here; the graph is dropped per batch.
            with torch.enable_grad():
                residual = compute_pde_residual(wave_model, velocity_map_fn, waveform_batch, xzt)
            loss_pde = (residual.detach() ** 2).mean()
            val_loss += loss_pde.item()

        avg_val_loss = val_loss / len(val_loader)
        val_losses.append(avg_val_loss)

        print(f"Epoch {epoch+1}/{epochs} | Train: {avg_train_loss:.6f} | Val: {avg_val_loss:.6f}")

        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            torch.save(waveform_encoder.state_dict(), best_model_path)
            print(f"✅ Best model saved at epoch {epoch+1} with val loss {best_loss:.6f}")

    print("Training completed.")
    return train_losses, val_losses

# 🔹 3. Loss可視化関数
def plot_losses(train_losses, val_losses):
    """Draw training and validation loss curves on one figure and show it.

    Args:
        train_losses: Sequence of per-epoch training losses.
        val_losses: Sequence of per-epoch validation losses.
    """
    _, ax = plt.subplots()
    curves = ((train_losses, 'Train Loss'), (val_losses, 'Val Loss'))
    for series, legend_name in curves:
        ax.plot(series, label=legend_name)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('PDE Loss')
    ax.legend()
    ax.set_title('Training vs Validation Loss')
    ax.grid(True)
    plt.show()

# 🔹 4. 速度マップ可視化関数
def visualize_velocity_map(velocity_tensor):
    """Render a velocity map as a colour image with a colourbar and show it.

    Args:
        velocity_tensor: Tensor that squeezes to a 2-D array; the original
            inline note gives [1, 1, 70, 70].
    """
    # Detach from autograd, move to host memory and drop singleton axes.
    grid_2d = velocity_tensor.detach().cpu().squeeze().numpy()

    fig, ax = plt.subplots(figsize=(6, 5))
    # origin='lower' places row 0 of the array at the bottom of the image.
    image = ax.imshow(grid_2d, cmap='viridis', origin='lower')
    fig.colorbar(image, ax=ax, label='Velocity (m/s)')
    ax.set_title('Predicted Velocity Map')
    ax.set_xlabel('X')
    ax.set_ylabel('Z')
    plt.show()


In [None]:
import matplotlib.pyplot as plt

# 予測用: モデルを推論モードに
# Put the encoder in inference mode before predicting.
waveform_encoder.eval()

n_show = 5  # maximum number of samples to display

with torch.no_grad():
    # NOTE(review): this unpacking assumes val_loader yields
    # (waveform, target) pairs. A loader built on WaveformDataset yields only
    # waveforms -- confirm which dataset backs val_loader.
    xb, yb = next(iter(val_loader))
    xb = xb.to(device)
    yb = yb.to(device)
    v_pred = waveform_encoder(xb)  # presumably [B, 1, 70, 70] -- TODO confirm

# Never request more rows than the batch actually contains.
n_rows = min(n_show, v_pred.size(0), yb.size(0))

# Plot: left column = ground truth, right column = prediction.
# squeeze=False keeps `axes` 2-D even when only one row is drawn, so the
# axes[i, j] indexing below is always valid.
fig, axes = plt.subplots(n_rows, 2, figsize=(8, n_rows * 3), squeeze=False)

for i in range(n_rows):
    # First channel of sample i, moved to host memory for matplotlib.
    pred_img = v_pred[i, 0].detach().cpu().numpy()
    gt_img   = yb[i, 0].detach().cpu().numpy()

    # Ground truth
    ax_gt = axes[i, 0]
    im_gt = ax_gt.imshow(gt_img, cmap="jet", aspect='auto')
    ax_gt.set_title(f"Ground Truth #{i}")
    fig.colorbar(im_gt, ax=ax_gt, fraction=0.046, pad=0.04)
    ax_gt.axis("off")

    # Prediction
    ax_pred = axes[i, 1]
    im_pred = ax_pred.imshow(pred_img, cmap="jet", aspect='auto')
    ax_pred.set_title(f"Prediction #{i}")
    fig.colorbar(im_pred, ax=ax_pred, fraction=0.046, pad=0.04)
    ax_pred.axis("off")

plt.tight_layout()
plt.show()