In [1]:
import numpy as np
import os 
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

In [2]:
class AttentionBlock(nn.Module):
    def __init__(self, dim, cond_dim):
        super().__init__()
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim + cond_dim, dim)
        self.v = nn.Linear(dim + cond_dim, dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, cond):
        B, C, H, W = x.shape
        x_flat = x.view(B, C, -1).permute(0, 2, 1)  # (B, N, C)
        cond = cond.unsqueeze(1).repeat(1, x_flat.size(1), 1)  # (B, N, cond_dim)

        q = self.q(x_flat)
        k = self.k(torch.cat([x_flat, cond], dim=-1))
        v = self.v(torch.cat([x_flat, cond], dim=-1))

        attn = torch.softmax(q @ k.transpose(-2, -1) / (C ** 0.5), dim=-1)
        out = attn @ v
        out = self.proj(out)
        return out.permute(0, 2, 1).view(B, C, H, W)


In [3]:
class AttentionUNet(nn.Module):
    def __init__(self, cond_dim=128):
        super().__init__()
        self.cond_proj = nn.Linear(5 * 1000 * 70, cond_dim)

        self.enc1 = nn.Conv2d(1, 64, 3, padding=1)
        self.attn1 = AttentionBlock(64, cond_dim)
        self.enc2 = nn.Conv2d(64, 128, 3, padding=1)
        self.attn2 = AttentionBlock(128, cond_dim)
        self.bottleneck = nn.Conv2d(128, 256, 3, padding=1)

        self.up1_conv = nn.Conv2d(256, 128, 3, padding=1)
        self.up2_conv = nn.Conv2d(128, 64, 3, padding=1)
        self.out = nn.Conv2d(64, 1, 3, padding=1)

    def forward(self, x, t, cond):
        B = cond.size(0)
        cond = self.cond_proj(cond.view(B, -1)).float()

        x1 = F.relu(self.enc1(x))        # [B, 64, 70, 70]
        x1 = self.attn1(x1, cond)

        x2 = F.avg_pool2d(x1, 2)         # [B, 64, 35, 35]
        x2 = F.relu(self.enc2(x2))       # [B, 128, 35, 35]
        x2 = self.attn2(x2, cond)

        x3 = F.avg_pool2d(x2, 2)         # [B, 128, 17, 17]
        x3 = F.relu(self.bottleneck(x3)) # [B, 256, 17, 17]

        u1 = F.interpolate(x3, size=x2.shape[2:], mode='bilinear', align_corners=False)
        u1 = F.relu(self.up1_conv(u1)) + x2

        u2 = F.interpolate(u1, size=x1.shape[2:], mode='bilinear', align_corners=False)
        u2 = F.relu(self.up2_conv(u2)) + x1

        return self.out(u2)  # → [B, 1, 70, 70]


In [5]:
def linear_beta_schedule(timesteps):
    beta_start = 1e-2
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)

def get_alphas(betas):
    alphas = 1. - betas
    alphas_cumprod = torch.cumprod(alphas, axis=0)
    sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
    sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - alphas_cumprod)
    return sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod

def q_sample(x_start, t, noise, sqrt_alpha_cum, sqrt_1m_alpha_cum):
    return (
        sqrt_alpha_cum[t][:, None, None, None] * x_start +
        sqrt_1m_alpha_cum[t][:, None, None, None] * noise
    )


In [25]:
import torch
from torch.utils.data import Dataset
import numpy as np
import os

class OpenFWIDataset(Dataset):
    def __init__(self, data_dir, split="train"):
        assert split in ["train", "val"]
        self.waves = np.load(os.path.join(data_dir, f"{split}_waves.npy"))  # e.g., (N, C, T, W)
        self.vels = np.load(os.path.join(data_dir, f"{split}_vels.npy"))    # e.g., (N, 1, H, W)

        # normalize (optional)
        self.waves = self.waves.astype(np.float32)
        self.vels = self.vels.astype(np.float32)
        self.waves /= 60.0
        self.vels /= 4500.0
        

    def __len__(self):
        return len(self.waves)

    def __getitem__(self, idx):
        wave = torch.from_numpy(self.waves[idx])  # shape: (C, T, W)
        vel  = torch.from_numpy(self.vels[idx])   # shape: (1, H, W)
        return wave, vel


In [26]:
from torch.utils.data import DataLoader

train_set = OpenFWIDataset("../dataset_one_batch", split="train")
val_set   = OpenFWIDataset("../dataset_one_batch", split="val")

train_loader = DataLoader(train_set, batch_size=8, shuffle=True)
val_loader   = DataLoader(val_set, batch_size=8, shuffle=False)

# 1バッチ確認
for xb, yb in train_loader:
    print(f"wave: {xb.shape}, vel: {yb.shape}")  # 例: (8, 5, 1000, 70), (8, 1, 70, 70)
    break


wave: torch.Size([8, 5, 1000, 70]), vel: torch.Size([8, 1, 70, 70])


In [8]:
# --- 訓練設定 ---
T = 1000
EPOCHS = 200
BATCH_SIZE = 8
SAVE_PATH = "ddpm_model2.pt"
device = "cuda" if torch.cuda.is_available() else "cpu"

# --- 拡散スケジュール ---
betas = linear_beta_schedule(T).to(device)
sqrt_alpha_cum, sqrt_1m_alpha_cum = get_alphas(betas)

# --- データ / モデル ---
train_ds = OpenFWIDataset("../dataset_one_batch", split="train")
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)



In [27]:
model = AttentionUNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)



In [28]:
# train_ddpm_best.py



def train():
    # 1. 設定
    device = "cuda" if torch.cuda.is_available() else "cpu"
    T = 1000
    EPOCHS = 100
    BATCH_SIZE = 8
    SAVE_PATH = "best_ddpm_model.pt"

    # 2. 拡散スケジュール
    betas = linear_beta_schedule(T).to(device)
    sqrt_alpha_cum, sqrt_1m_alpha_cum = get_alphas(betas)

    # 3. データローダー
    train_ds = OpenFWIDataset("../dataset_one_batch", split="train")
    val_ds   = OpenFWIDataset("../dataset_one_batch", split="val")
    train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    val_dl   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False)

    # 4. モデル＆最適化
    model = AttentionUNet().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=5, factor=0.5
    )
    best_val_loss = float("inf")

    # 5. 学習ループ
    for epoch in range(1, EPOCHS + 1):
        # --- train ---
        model.train()
        total_train = 0.0
        for seismic, velocity in train_dl:
            seismic, velocity = seismic.to(device), velocity.to(device)
            B = velocity.size(0)

            # ランダムタイムステップ＆ノイズサンプリング
            t = torch.randint(0, T, (B,), device=device).long()
            noise = torch.randn_like(velocity)
            x_t = sqrt_alpha_cum[t][:, None, None, None] * velocity \
                + sqrt_1m_alpha_cum[t][:, None, None, None] * noise

            # 条件flatten
            cond = seismic.view(B, -1)

            # 予測＆損失計算（MSE）
            pred_noise = model(x_t, t, cond)
            loss = F.mse_loss(pred_noise, noise)

            # 逆伝播
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            total_train += loss.item() * B

        avg_train = total_train / len(train_ds)

        # --- validation ---
        model.eval()
        total_val_mse = 0.0
        total_val_mae = 0.0
        with torch.no_grad():
            for seismic, velocity in val_dl:
                seismic, velocity = seismic.to(device), velocity.to(device)
                B = velocity.size(0)

                # 同様のノイズ注入
                t = torch.randint(0, T, (B,), device=device).long()
                noise = torch.randn_like(velocity)
                x_t = sqrt_alpha_cum[t][:, None, None, None] * velocity \
                    + sqrt_1m_alpha_cum[t][:, None, None, None] * noise

                cond = seismic.view(B, -1)
                pred_noise = model(x_t, t, cond)

                # Val MSE on noise
                total_val_mse += F.mse_loss(pred_noise, noise, reduction="sum").item()

                # Val MAE on reconstructed velocity
                pred_vel = (x_t - sqrt_1m_alpha_cum[t][:, None, None, None] * pred_noise) \
                           / sqrt_alpha_cum[t][:, None, None, None]
                total_val_mae += F.l1_loss(pred_vel, velocity, reduction="sum").item()

        avg_val_mse = total_val_mse / len(val_ds)
        avg_val_mae = total_val_mae / len(val_ds)

        # --- ベストモデル更新（MSE基準）---
        improved = ""
        if avg_val_mse < best_val_loss:
            best_val_loss = avg_val_mse
            torch.save(model.state_dict(), SAVE_PATH)
            improved = "  <-- best so far"

        # LRスケジューラ更新
        scheduler.step(avg_val_mse)

        # ログ出力
        print(
            f"[Epoch {epoch:02d}] "
            f"Train MSE: {avg_train:.4f}  "
            f"Val MSE: {avg_val_mse:.4f}  "
            f"Val MAE: {avg_val_mae:.4f} {improved}"
        )

    print(f"\nTraining complete. Best validation MSE: {best_val_loss:.4f}")
    print(f"Best model saved to: {SAVE_PATH}")

if __name__ == "__main__":
    train()



[Epoch 01] Train MSE: 0.9998  Val MSE: 4848.5299  Val MAE: 900701.8325   <-- best so far
[Epoch 02] Train MSE: 0.8306  Val MSE: 1749.3080  Val MAE: 421651.6265   <-- best so far
[Epoch 03] Train MSE: 0.2630  Val MSE: 1056.9027  Val MAE: 301549.6487   <-- best so far
[Epoch 04] Train MSE: 0.1340  Val MSE: 446.6248  Val MAE: 134014.5612   <-- best so far
[Epoch 05] Train MSE: 0.0871  Val MSE: 310.5471  Val MAE: 96886.7266   <-- best so far
[Epoch 06] Train MSE: 0.0749  Val MSE: 300.2826  Val MAE: 107455.8274   <-- best so far
[Epoch 07] Train MSE: 0.0604  Val MSE: 259.7316  Val MAE: 110581.9413   <-- best so far
[Epoch 08] Train MSE: 0.0591  Val MSE: 409.6258  Val MAE: 89791.4080 
[Epoch 09] Train MSE: 0.0480  Val MSE: 192.8938  Val MAE: 70977.1182   <-- best so far
[Epoch 10] Train MSE: 0.0510  Val MSE: 231.7322  Val MAE: 67900.0591 
[Epoch 11] Train MSE: 0.0489  Val MSE: 172.1675  Val MAE: 84148.8545   <-- best so far
[Epoch 12] Train MSE: 0.0468  Val MSE: 198.1336  Val MAE: 66431.6303

In [36]:
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import trange

# 0. GPU設定
device = "cuda" if torch.cuda.is_available() else "cpu"

# 1. データ＆モデル読み込み
val_ds = OpenFWIDataset("../dataset_one_batch", split="val")
wave, gt_vel = val_ds[0]                    # wave: [C,T,W], gt_vel: [1,H,W]
wave = wave.unsqueeze(0).to(device)         # [1,C,T,W]
cond = wave.view(1, -1)                     # [1, C*T*W]
gt_vel = gt_vel.unsqueeze(0).to(device)     # [1,1,H,W]

model = AttentionUNet().to(device)
model.load_state_dict(torch.load("best_ddpm_model.pt", map_location=device))
model.eval()

# 2. 拡散スケジュール
T = 1000
betas = linear_beta_schedule(T).to(device)    # β₁…β_T
alphas = 1 - betas                           # α₁…α_T
alpha_cum = torch.cumprod(alphas, dim=0)      # ᾱ_t
sqrt_alpha_cum = alpha_cum.sqrt()             
sqrt_1m_alpha_cum = (1 - alpha_cum).sqrt()

# 3. 逆拡散サンプリング関数
@torch.no_grad()
def p_sample(x, t, cond):
    """
    x: [B,1,H,W]            ノイズ付加済みサンプル x_t
    t: [B]                  タイムステップ
    cond: [B, C*T*W]        flattenされた seismic
    """
    eps_pred = model(x, t, cond)           # ノイズ予測
    beta_t = betas[t]
    alpha_t = alphas[t]
    abar_t = alpha_cum[t]
    # abar_prev は t=0 のとき1.0
    abar_prev = torch.cat([torch.tensor([1.0], device=device), alpha_cum[:-1]])[t]

    # 平均 μ_t(x_t)
    coef = beta_t / sqrt_1m_alpha_cum[t]
    mean = (x - coef.view(-1,1,1,1) * eps_pred) / torch.sqrt(alpha_t).view(-1,1,1,1)

    # 分散 β̂_t = β_t * (1−ᾱ_{t−1})/(1−ᾱ_t)
    var = beta_t * (1 - abar_prev) / (1 - abar_t)
    if (t > 0).any():
        noise = torch.randn_like(x)
        return mean + torch.sqrt(var).view(-1,1,1,1) * noise
    else:
        return mean

# 4. 全ステップ逆拡散ループ
x = torch.randn_like(gt_vel)  # x_T をノイズで初期化
for ti in trange(T-1, -1, -1, desc="Reverse Sampling"):
    t = torch.full((1,), ti, dtype=torch.long, device=device)
    x = p_sample(x, t, cond)

# 5. 可視化（[0,1]→元スケールへ）
pred = x.clamp(0,1).squeeze().cpu().numpy() * 4500
gt   = gt_vel.squeeze().cpu().numpy() * 4500

fig, axs = plt.subplots(1,2,figsize=(10,4))
axs[0].imshow(gt,   cmap="viridis"); axs[0].set_title("Ground Truth")
axs[1].imshow(pred, cmap="viridis"); axs[1].set_title("DDPM Full Sampling")
plt.tight_layout()
plt.show()
