In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm


In [None]:
class SeismicDataset(Dataset):
    """Paired (waveform, velocity map) samples loaded from two ``.npy`` files.

    Both arrays are loaded fully into memory with ``np.load``. Sample ``idx``
    is ``(waves[idx], vels[idx])``, each converted to a float32 tensor.

    Args:
        waves_path: path to the ``.npy`` array of recorded waveforms; the
            first axis indexes samples.
        vels_path: path to the ``.npy`` array of velocity maps; the first
            axis indexes samples.

    Raises:
        ValueError: if the two arrays hold a different number of samples.
    """

    def __init__(self, waves_path, vels_path):
        self.waves = np.load(waves_path)
        self.vels = np.load(vels_path)
        # __len__ only looks at `waves`; a length mismatch would otherwise show
        # up as an IndexError mid-epoch (or silently drop velocity maps).
        if len(self.waves) != len(self.vels):
            raise ValueError(
                f"waves ({len(self.waves)} samples) and vels ({len(self.vels)} samples) "
                "must contain the same number of samples"
            )

    def __len__(self):
        return len(self.waves)

    def __getitem__(self, idx):
        x = torch.from_numpy(self.waves[idx]).float()
        y = torch.from_numpy(self.vels[idx]).float()
        return x, y

# Build the train/validation datasets with SeismicDataset (defined above).
DATA_DIR = "../dataset_one_batch"
BATCH_SIZE = 8

train_ds = SeismicDataset(os.path.join(DATA_DIR, "train_waves.npy"), os.path.join(DATA_DIR, "train_vels.npy"))
val_ds = SeismicDataset(os.path.join(DATA_DIR, "val_waves.npy"), os.path.join(DATA_DIR, "val_vels.npy"))

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
# Fixed order keeps per-epoch validation losses comparable.
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)


In [None]:
class VelocityMapCNN(nn.Module):
    """Map a stack of shot gathers to a single-channel square velocity map.

    Input:  (B, 1, n_src, T, n_rec) waveforms, e.g. (B, 1, 5, 1000, 70).
    Output: (B, 1, out_size, out_size).

    The two 3x3x3 convolutions use padding=1, so they keep the input's
    spatial extent. Flattening their output directly would give
    32 * n_src * T * n_rec features (11.2M for 5 x 1000 x 70), which does not
    match a dense layer sized 32 * n_src * n_rec. An adaptive average pool
    therefore collapses the time axis to 1 and fixes the source/receiver axes,
    so the dense layer always sees exactly 32 * n_sources * n_receivers
    features, whatever T is.

    Args:
        n_sources: pooled size of the source axis (default 5).
        n_receivers: pooled size of the receiver axis (default 70).
        out_size: side length of the square output map (default 70).
    """

    def __init__(self, n_sources=5, n_receivers=70, out_size=70):
        super().__init__()
        self.out_size = out_size
        self.encoder = nn.Sequential(
            nn.Conv3d(1, 16, kernel_size=(3, 3, 3), padding=1), nn.ReLU(),
            nn.Conv3d(16, 32, kernel_size=(3, 3, 3), padding=1), nn.ReLU(),
        )
        # Time axis -> 1, so the flattened size is 32 * n_sources * n_receivers.
        self.pool = nn.AdaptiveAvgPool3d((n_sources, 1, n_receivers))
        self.fc = nn.Linear(32 * n_sources * n_receivers, out_size * out_size)

    def forward(self, x):  # x: (B, 1, n_src, T, n_rec)
        x = self.encoder(x)
        x = self.pool(x)
        x = torch.flatten(x, start_dim=1)
        x = self.fc(x)
        return x.view(-1, 1, self.out_size, self.out_size)


In [None]:
def solve_wave_2d(v, source_pos, nt=1000, dx=10.0, dz=10.0, dt=0.001, f0=10.0, t0=0.05):
    """
    Simulate surface shot gathers with a 2D constant-density acoustic FDTD solver.

    Explicit scheme, second order in time and space:
        u[n+1] = 2 u[n] - u[n-1] + (v dt)^2 * laplacian(u[n]) + s[n] * delta_src
    where s is a Ricker wavelet added directly (unscaled) at the source cell.
    Cells just outside the H x W grid are held at zero (Dirichlet boundary);
    there is no absorbing layer, so waves reflect off every edge.

    Each entry of ``source_pos`` is simulated as an independent shot (shots run
    in parallel along a tensor dimension). Every shot is recorded at each x
    position on the z = 0 row. All updates are out-of-place, so gradients flow
    back to ``v``.

    Args:
        v: velocity tensor. Accepted shapes:
            (H, W) or (1, H, W) -> one model (for 3D input only v[0] is used);
            (B, C, H, W)        -> a batch of models (only channel 0 is used).
            Axis -2 is x (spacing dx), axis -1 is z (spacing dz).
            NOTE(review): confirm this matches the dataset's (depth, x) layout.
            Units must match dx/dz/dt: with the defaults (metres, seconds)
            v is in m/s. The explicit scheme is stable only if
            max(v) * dt * sqrt(1/dx**2 + 1/dz**2) <= 1.
        source_pos: sequence of (x_idx, z_idx) grid indices, one per shot.
        nt: number of time samples to simulate and record.
        dx, dz: grid spacing along x and z.
        dt: time step.
        f0: Ricker peak frequency (Hz).
        t0: Ricker time shift (s).

    Returns:
        Receiver traces of shape (n_src, nt, H) for 2D/3D ``v`` — (5, nt, 70)
        for five shots on a 70 x 70 grid — or (B, n_src, nt, H) for 4D ``v``.

    Raises:
        ValueError: if ``v`` is not 2D/3D/4D or a source lies outside the grid.
    """
    if v.dim() == 2:
        vel, batched = v.unsqueeze(0), False
    elif v.dim() == 3:
        vel, batched = v[:1], False
    elif v.dim() == 4:
        vel, batched = v[:, 0], True
    else:
        raise ValueError(f"v must be 2D, 3D or 4D, got shape {tuple(v.shape)}")

    n_models, H, W = vel.shape
    n_src = len(source_pos)
    dtype, device = vel.dtype, vel.device

    if n_src == 0 or nt <= 0:
        empty = vel.new_zeros((n_models, n_src, max(nt, 0), H))
        return empty if batched else empty[0]

    # One-hot source cell per shot: (n_src, H, W); built once, reused each step.
    src_mask = torch.zeros((n_src, H, W), dtype=dtype, device=device)
    for shot, (sx, sz) in enumerate(source_pos):
        if not (0 <= sx < H and 0 <= sz < W):
            raise ValueError(f"source {(sx, sz)} lies outside the {H}x{W} grid")
        src_mask[shot, sx, sz] = 1.0

    # Ricker wavelet sampled at t = 0, dt, ..., (nt - 1) dt.
    t = torch.arange(nt, dtype=dtype, device=device) * dt
    arg = (np.pi * f0 * (t - t0)) ** 2
    ricker = (1 - 2 * arg) * torch.exp(-arg)

    # (v dt)^2 with a singleton shot axis so it broadcasts over shots.
    c2 = ((vel * dt) ** 2).unsqueeze(1)                 # (n_models, 1, H, W)

    # Only two time levels are kept: (n_models, n_src, H, W) each.
    u_prev = torch.zeros((n_models, n_src, H, W), dtype=dtype, device=device)
    u_cur = torch.zeros_like(u_prev)
    traces = []
    for it in range(nt):
        traces.append(u_cur[..., 0])                     # z = 0 row: (n_models, n_src, H)
        # Zero padding provides the Dirichlet ghost cells around the grid.
        p = F.pad(u_cur, (1, 1, 1, 1))
        lap = (p[..., 2:, 1:-1] - 2 * u_cur + p[..., :-2, 1:-1]) / dx**2 \
            + (p[..., 1:-1, 2:] - 2 * u_cur + p[..., 1:-1, :-2]) / dz**2
        u_prev, u_cur = u_cur, 2 * u_cur - u_prev + c2 * lap + ricker[it] * src_mask

    receivers = torch.stack(traces, dim=2)               # (n_models, n_src, nt, H)
    return receivers if batched else receivers[0]



In [None]:
# Physics-guided training: the CNN predicts a velocity map, solve_wave_2d
# re-simulates the shot gathers from it, and the loss is the misfit between
# simulated and recorded gathers. True velocity maps are not used here.
# NOTE(review): the CNN output is unconstrained; velocities large enough to
# break the explicit solver's stability bound will make the simulation blow up.
torch.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VelocityMapCNN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# (x_idx, z_idx) grid positions of the shots, one per recorded gather.
SOURCE_POSITIONS = [(0, 0), (17, 0), (34, 0), (52, 0), (69, 0)]
N_EPOCHS = 10


def simulate_gathers(vel_batch, nt):
    """Run solve_wave_2d on every map of a (B, 1, H, W) velocity batch.

    Each map is passed as (1, H, W), the single-model shape solve_wave_2d
    documents, and the per-sample (n_src, nt, n_rec) results are stacked
    into a (B, n_src, nt, n_rec) tensor.
    """
    return torch.stack([
        solve_wave_2d(vel, source_pos=SOURCE_POSITIONS, nt=nt)
        for vel in vel_batch
    ])


train_losses, val_losses = [], []
best_val_loss = float('inf')

for epoch in range(N_EPOCHS):
    model.train()
    total_loss = 0.0
    for wave, _vel_true in tqdm(train_loader, desc=f"epoch {epoch}", leave=False):
        # assumes wave batches are (B, n_src, nt, n_rec), as implied by the
        # model's (B, 1, 5, 1000, 70) input — TODO confirm against the data
        wave = wave.to(device)
        vel_pred = model(wave.unsqueeze(1))          # model expects a channel axis
        sim_wave = simulate_gathers(vel_pred, nt=wave.shape[-2])
        loss = F.mse_loss(sim_wave, wave)            # same (B, n_src, nt, n_rec) shape
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    train_losses.append(total_loss / len(train_loader))

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for wave, _ in val_loader:
            wave = wave.to(device)
            vel_pred = model(wave.unsqueeze(1))
            sim_wave = simulate_gathers(vel_pred, nt=wave.shape[-2])
            val_loss += F.mse_loss(sim_wave, wave).item()
    val_loss /= len(val_loader)
    val_losses.append(val_loss)

    # Keep the checkpoint with the lowest validation misfit.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "best_model.pth")

    print(f"Epoch {epoch}: Train Loss={train_losses[-1]:.4f}, Val Loss={val_losses[-1]:.4f}")


In [None]:
# Per-epoch training vs. validation misfit, drawn on an explicit Axes.
fig, ax = plt.subplots()
ax.plot(train_losses, label='Train Loss')
ax.plot(val_losses, label='Validation Loss')
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Loss Trend")
ax.legend()
ax.grid(True)
plt.show()


In [None]:
# Reload the best checkpoint and compare true vs. predicted velocity for one
# validation sample. map_location lets a checkpoint saved on GPU load on CPU.
model.load_state_dict(torch.load("best_model.pth", map_location=device))
model.eval()

wave, true_vel = next(iter(val_loader))
wave = wave.unsqueeze(1).to(device)
with torch.no_grad():  # inference only; no autograd graph needed
    pred_vel = model(wave).cpu()

idx = 0
panels = (
    (true_vel[idx, 0], "True Velocity Map"),   # assumes vels batches are (B, 1, H, W) — TODO confirm
    (pred_vel[idx, 0], "Predicted Velocity Map"),
)
fig, axs = plt.subplots(1, 2, figsize=(10, 4))
for ax, (vel_map, title) in zip(axs, panels):
    im = ax.imshow(vel_map, cmap='jet')
    ax.set_title(title)
    # Without a colorbar the jet colours carry no readable values.
    fig.colorbar(im, ax=ax)
plt.show()
