[波形データ] (5, 1000, 7)
   ↓
Encoder CNN / Transformer
   ↓
[速度マップ] (1, 70, 70)        ← この出力に対して、以下の物理制約（PINN損失）を課す
   ↓
VelocityMapFn (連続関数化)
   ↓
PINN損失 (波動方程式 ∂²u/∂t² = V²(∂²u/∂x² + ∂²u/∂z²))

🧩 1. 波形 → 速度マップモデル

In [None]:
class WaveformEncoder(nn.Module):
    """Regress a dense velocity map from a batch of waveform recordings.

    The input is treated as a single-channel 3-D volume: two ``Conv3d`` +
    ReLU layers extract features, a global average pool collapses every
    spatial/temporal axis into one 32-dimensional vector per sample, and a
    two-layer MLP decodes that vector into the flattened map.

    Args:
        hidden_dim: Width of the MLP hidden layer (default 1024).
        map_size: ``(height, width)`` of the predicted map (default 70 x 70).
    """

    def __init__(self, hidden_dim=1024, map_size=(70, 70)):
        super().__init__()
        self.map_size = tuple(map_size)
        map_h, map_w = self.map_size

        # kernel 5 with padding 1 on the middle axis shortens that axis by 2;
        # harmless here because the global pool below removes all extents.
        self.conv1 = nn.Conv3d(1, 16, kernel_size=(3, 5, 3), padding=1)
        self.conv2 = nn.Conv3d(16, 32, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc = nn.Linear(32, hidden_dim)
        self.out = nn.Linear(hidden_dim, map_h * map_w)

    def forward(self, x):
        """Map waveforms ``[B, 5, 1000, 7]`` to velocity maps ``[B, 1, H, W]``."""
        batch = x.size(0)
        x = x.unsqueeze(1)  # add the channel axis -> [B, 1, 5, 1000, 7]
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.pool(x).flatten(1)  # [B, 32]
        x = F.relu(self.fc(x))
        x = self.out(x)  # [B, H * W]
        return x.view(batch, 1, *self.map_size)


🧮 2. 速度マップ → 連続関数化（grid_sample）

In [None]:
class VelocityMapFn(nn.Module):
    """Expose a predicted velocity grid as a continuous function V(x, z).

    The wrapped ``model`` turns waveforms into a ``[B, 1, H, W]`` grid, which
    is then bilinearly interpolated at arbitrary query points with
    ``F.grid_sample`` so the result stays differentiable.

    Args:
        model: Callable/module mapping a waveform batch to ``[B, 1, H, W]``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, waveform, coords):
        """Sample the predicted velocity at ``coords``.

        Args:
            waveform: Input batch for ``self.model``.
            coords: ``[N, 2]`` points in normalised ``[-1, 1]`` coordinates.
                Following ``grid_sample``'s convention, ``coords[:, 0]``
                runs along the grid's last (width) axis and ``coords[:, 1]``
                along its row (height) axis.

        Returns:
            ``[N, 1]`` tensor of interpolated velocities. When the model
            yields one grid per point (``B == N``) point *i* is sampled from
            grid *i*; when it yields a single grid (``B == 1``) all ``N``
            points are sampled from that grid.
        """
        v_grid = self.model(waveform)  # [B, 1, H, W]

        # grid_sample itself works in [-1, 1] (with align_corners=True, -1 and
        # +1 hit the centres of the corner pixels), so the coordinates are
        # passed through unchanged. Rescaling them to [0, 1] would restrict
        # sampling to one quadrant of the map.
        if v_grid.size(0) == 1 and coords.size(0) != 1:
            grid = coords[None, :, None, :]  # [1, N, 1, 2]: N points, one map
        else:
            grid = coords[:, None, None, :]  # [B, 1, 1, 2]: one point per map

        sampled = F.grid_sample(v_grid, grid, align_corners=True)
        return sampled.reshape(-1, 1)


📐 3. 波動場ネットワーク u(x,z,t)

In [None]:
class WaveNet(nn.Module):
    """Fully connected tanh network approximating the wavefield u(x, z, t).

    Args:
        in_dim: Number of input coordinates per point (3 for x, z, t).
        hidden: Width of both hidden layers.
    """

    def __init__(self, in_dim=3, hidden=128):
        super().__init__()
        # Layers are created in forward order so that seeded initialisation
        # and the ``net.<index>`` state-dict keys stay stable.
        input_layer = nn.Linear(in_dim, hidden)
        hidden_layer = nn.Linear(hidden, hidden)
        output_layer = nn.Linear(hidden, 1)
        self.net = nn.Sequential(
            input_layer,
            nn.Tanh(),
            hidden_layer,
            nn.Tanh(),
            output_layer,
        )

    def forward(self, xzt):
        """Evaluate u at ``xzt`` (``[B, in_dim]``) and return ``[B, 1]``."""
        wavefield = self.net(xzt)
        return wavefield


📘 4. PDE残差（PINN損失）

In [None]:
def compute_pde_residual(u_model, v_fn, waveform, xzt):
    """Return the acoustic wave-equation residual at the sampled points.

    Evaluates ``u_tt - V**2 * (u_xx + u_zz)`` with automatic differentiation,
    where ``u = u_model(xzt)`` and ``V = v_fn(waveform, xzt[:, :2])``.

    Args:
        u_model: Callable mapping ``xzt`` (``[B, 3]``) to the wavefield ``[B, 1]``.
        v_fn: Callable ``(waveform, xz)`` returning the velocity at each point.
        waveform: Passed through unchanged to ``v_fn``.
        xzt: ``[B, 3]`` tensor with columns (x, z, t). It is switched to
            ``requires_grad=True`` in place.

    Returns:
        The ``[B, 1]`` residual, still attached to the autograd graph.
    """
    xzt.requires_grad_(True)
    u = u_model(xzt)  # [B, 1]

    def _ddx(output):
        # Derivative of `output` w.r.t. every input column; the graph is kept
        # so the result can be differentiated again and back-propagated.
        (jac,) = torch.autograd.grad(
            outputs=output,
            inputs=xzt,
            grad_outputs=torch.ones_like(output),
            create_graph=True,
        )
        return jac

    first = _ddx(u)
    u_x = first[:, 0:1]
    u_z = first[:, 1:2]
    u_t = first[:, 2:3]

    # Only the diagonal second derivatives are needed.
    u_xx = _ddx(u_x)[:, 0:1]
    u_zz = _ddx(u_z)[:, 1:2]
    u_tt = _ddx(u_t)[:, 2:3]

    # Velocity V(x, z) looked up from the spatial columns only.
    v = v_fn(waveform, xzt[:, :2])
    laplacian = u_xx + u_zz
    return u_tt - v**2 * laplacian


🔁 5. 学習ループ（PINN + data loss）

In [None]:
# Build the three modules; velocity_map_fn wraps waveform_encoder, so the
# encoder's weights are the only trainable parameters it contributes.
waveform_encoder = WaveformEncoder().to(device)
velocity_map_fn = VelocityMapFn(waveform_encoder).to(device)
wave_model = WaveNet().to(device)
optimizer = torch.optim.Adam(
    [*waveform_encoder.parameters(), *wave_model.parameters()],
    lr=1e-3,
)

for epoch in range(epochs):
    waveform = next_waveform_batch().to(device)  # [B, 5, 1000, 7]
    xzt = sample_coords(batch_size=512).to(device)  # [B, 3], each in [-1, 1]
    # NOTE(review): velocity_map_fn hands both tensors to F.grid_sample, so the
    # waveform batch size has to be compatible with the 512 sampled points —
    # confirm against next_waveform_batch.

    # Optional supervised term: MSE against the ground-truth velocity map.
    # NOTE(review): this runs the encoder a second time on the same batch
    # (velocity_map_fn calls it as well), and ground_truth_velocity is
    # presumably the target for this same batch — verify.
    v_pred = waveform_encoder(waveform)
    loss_data = F.mse_loss(v_pred, ground_truth_velocity.to(device))

    # Physics term: mean squared wave-equation residual at the sampled points.
    residual = compute_pde_residual(wave_model, velocity_map_fn, waveform, xzt)
    loss_pde = residual.pow(2).mean()

    # The PDE term is down-weighted relative to the data term.
    loss = loss_data + 0.01 * loss_pde

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
