In [None]:
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install h5py matplotlib numpy scipy tqdm einops

In [None]:
# Work from /content (presumably the Colab workspace root — confirm when running elsewhere).
%cd /content
import os
# Clone only when the checkout is missing, so re-running this cell does not fail
# on an already existing 'PDEBench' directory.
if not os.path.exists('PDEBench'):
    !git clone https://github.com/pdebench/PDEBench.git
# Stay inside the repository: the commands below edit its pyproject.toml, and
# later cells create a relative 'data' directory from the current directory.
%cd PDEBench
# Relax torchvision and torch requirements in pyproject.toml to match the installed version (compatible with Python 3.12)
# NOTE: these substitutions match the exact pinned strings; if the upstream pins
# change, sed leaves the file untouched without reporting an error.
!sed -i 's/torchvision~=0.14.1/torchvision/' pyproject.toml
!sed -i 's/torch~=1.13.0/torch/' pyproject.toml
# Relax Python version requirement
!sed -i 's/requires-python = .*/requires-python = ">=3.9"/' pyproject.toml
# Editable install of the patched checkout.
!pip install -e .

In [None]:
import os
import urllib.request
from tqdm import tqdm
import pandas as pd
import h5py

os.makedirs('data', exist_ok=True)

# File name to look up in the "Filename" column of the PDEBench URL table.
DATASET_FILENAME = '2D_diff-react_NA_NA.h5'
# Upper bound on the number of entries printed when listing the file structure.
MAX_STRUCTURE_ENTRIES = 20

# Load the URLs CSV to find the correct link
urls_df = pd.read_csv('/content/PDEBench/pdebench/data_download/pdebench_data_urls.csv')

# Search for the 2D Diffusion-Reaction dataset
file_row = urls_df[urls_df['Filename'] == DATASET_FILENAME]

if file_row.empty:
    # Fail fast: without a URL, DATASET_PATH would never be defined and every
    # later cell would die with an unrelated-looking NameError.
    raise RuntimeError(
        "Could not find the dataset URL in the metadata. Please check the 'urls_df' manually."
    )

DATASET_URL = file_row.iloc[0]['URL']
DATASET_PATH = os.path.join('data', file_row.iloc[0]['Filename'])

print(f"Found URL: {DATASET_URL}")
print(f"Target Path: {DATASET_PATH}")


class DownloadProgressBar(tqdm):
    """tqdm progress bar whose `update_to` has the signature of a urlretrieve reporthook."""

    def update_to(self, b=1, bsize=1, tsize=None):
        """Advance the bar to `b * bsize` bytes; `tsize` (if given) becomes the total."""
        if tsize is not None:
            self.total = tsize
        # tqdm.update expects an increment, so subtract what has been counted already.
        self.update(b * bsize - self.n)


def download_dataset():
    """Download DATASET_URL to DATASET_PATH.

    The payload is written to a sibling '.part' file and only moved onto
    DATASET_PATH once the transfer has finished, so an interrupted download
    never leaves a truncated file under the final name.
    """
    print("Downloading 2D Diffusion-Reaction dataset...")
    partial_path = DATASET_PATH + '.part'
    try:
        with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc="Downloading") as t:
            urllib.request.urlretrieve(DATASET_URL, partial_path, reporthook=t.update_to)
        os.replace(partial_path, DATASET_PATH)
    finally:
        # Only present here if the transfer or the rename failed.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print(f"Downloaded to {DATASET_PATH}")


def is_valid_hdf5(path):
    """Return True when h5py can open `path`, False when opening raises OSError."""
    try:
        with h5py.File(path, 'r'):
            return True
    except OSError:
        return False


def print_structure(h5_file, max_entries=MAX_STRUCTURE_ENTRIES):
    """Print at most `max_entries` groups/datasets of `h5_file`, depth first."""
    shown = 0

    def visit(name, obj):
        nonlocal shown
        if shown >= max_entries:
            print(f"  ... (listing truncated after {max_entries} entries)")
            # Any non-None return value makes visititems stop walking the file.
            return True
        if isinstance(obj, h5py.Dataset):
            print(f"  {name}: {obj.shape}")
        else:
            print(f"  {name}/")
        shown += 1
        return None

    h5_file.visititems(visit)


# Reuse an existing file only if it opens cleanly; otherwise fetch it (again).
if os.path.exists(DATASET_PATH) and is_valid_hdf5(DATASET_PATH):
    print(f"Dataset exists and is valid at {DATASET_PATH}")
else:
    if os.path.exists(DATASET_PATH):
        print("Detected corrupted/incomplete file. Deleting and re-downloading...")
        os.remove(DATASET_PATH)
    download_dataset()

# Proceed with inspecting the dataset
with h5py.File(DATASET_PATH, 'r') as f:
    print("\nDataset structure:")
    print_structure(f)

In [None]:
import numpy as np
import h5py
import matplotlib.pyplot as plt

# Load dataset: read the metadata and one sample, then close the file before plotting.
with h5py.File(DATASET_PATH, 'r') as f:
    keys = sorted(list(f.keys()))
    print(f"Root keys (first 5): {keys[:5]}")

    # Check if 'data' is the main key or if we have sample groups
    if 'data' in f:
        # Lazy handle: indexing below reads only sample 0 into memory.
        data_dset = f['data']
        print(f"Data shape: {data_dset.shape}")
        sample_data = data_dset[0]
        num_samples = data_dset.shape[0]
    else:
        # Assume structure is Group '0000' -> Dataset 'data'
        first_group = f[keys[0]]
        if 'data' not in first_group:
            raise ValueError(f"Could not find 'data' dataset in group {keys[0]}. Keys: {first_group.keys()}")

        # Shape of one sample: [time, x, y, c]
        sample_shape = first_group['data'].shape
        num_samples = len(keys)
        print(f"Data appears to be split across {num_samples} groups.")
        print(f"Single sample shape: {sample_shape}")

        # `[:]` copies the sample into a NumPy array that outlives the file handle.
        sample_data = first_group['data'][:]

print(f"  - Samples: {num_samples}")
print(f"  - Time steps: {sample_data.shape[0]}")
print(f"  - Spatial: {sample_data.shape[1]} x {sample_data.shape[2]}")
print(f"  - Channels (u, v): {sample_data.shape[3]}")

# Visualize a sample: one row per channel, one column per selected time step.
fig, axes = plt.subplots(2, 5, figsize=(15, 6))

# Select 5 evenly spaced time steps
time_steps = np.linspace(0, sample_data.shape[0]-1, 5, dtype=int)

# (row / channel index, row label, colormap)
channel_rows = [(0, 'Activator (u)', 'viridis'), (1, 'Inhibitor (v)', 'plasma')]

for i, t in enumerate(time_steps):
    for channel, row_label, cmap in channel_rows:
        ax = axes[channel, i]
        ax.imshow(sample_data[t, :, :, channel], cmap=cmap)
        # Hide ticks and frame individually. axis('off') would also hide the
        # y-label, so the row labels set below would never be drawn.
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_frame_on(False)
        if i == 0:
            ax.set_ylabel(row_label)
    axes[0, i].set_title(f't={t}')

plt.suptitle('2D Diffusion-Reaction: Time Evolution', fontsize=14)
plt.tight_layout()
plt.savefig('data_visualization.png', dpi=150)
plt.show()

print("\nThis is what we're learning to predict:")
print("Given initial state (t=0), predict future states (t=1,2,...,T)")

In [None]:
import torch

# Report whether PyTorch can see a CUDA device; warn first if it cannot.
if not torch.cuda.is_available():
    print("WARNING: GPU not detected. Please check 'Runtime > Change runtime type' and select T4 GPU.")
else:
    print(f"Success! GPU detected: {torch.cuda.get_device_name(0)}")
    print("You can now re-run the notebook cells to install libraries and download the data.")

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import h5py

class LazyDiffusionReactionDataset(Dataset):
    """One-step-ahead (frame t -> frame t+1) pairs read lazily from an HDF5 file.

    The file is expected to hold one top-level group per trajectory, named with
    digits only (e.g. '0000'), each containing a 'data' dataset whose first
    axis is time and whose per-frame layout is [H, W, C]. Only the two frames
    needed for an item are read from disk, so memory use stays small.

    Args:
        file_path: Path of the HDF5 file.
        split: 'train' selects the first `train_ratio` share of the sorted
            trajectory keys; any other value selects the remainder.
        train_ratio: Fraction of trajectories assigned to the training split.

    Raises:
        ValueError: If the file has no numeric groups, the requested split is
            empty, or the trajectories have fewer than two frames.
    """

    def __init__(self, file_path, split='train', train_ratio=0.8):
        self.file_path = file_path
        self.split = split

        # Open file just to read metadata
        with h5py.File(file_path, 'r') as f:
            # Keep only the numbered trajectory groups; sorting makes the split deterministic.
            numeric_keys = sorted(k for k in f.keys() if k.isdigit())
            self.keys = self._split_keys(numeric_keys, split, train_ratio)

            # Get dimensions from first sample (all trajectories are assumed
            # to have the same number of frames).
            sample0 = f[self.keys[0]]['data']
            self.time_steps = sample0.shape[0] - 1  # predicting t -> t+1

        if self.time_steps < 1:
            raise ValueError(
                f"Trajectories in {file_path!r} need at least 2 frames to form (t, t+1) pairs."
            )

        print(f"[{split}] Initialized with {len(self.keys)} trajectories")

    @staticmethod
    def _split_keys(keys, split, train_ratio):
        """Return the slice of `keys` belonging to `split`, rejecting empty results."""
        if not keys:
            raise ValueError("No numeric trajectory groups (e.g. '0000') found in the HDF5 file.")

        n_train = int(len(keys) * train_ratio)
        selected = keys[:n_train] if split == 'train' else keys[n_train:]

        if not selected:
            raise ValueError(
                f"Split '{split}' is empty for {len(keys)} trajectories with train_ratio={train_ratio}."
            )
        return selected

    def __len__(self):
        # Total samples = (Number of trajectories) * (Time steps per trajectory)
        return len(self.keys) * self.time_steps

    def __getitem__(self, idx):
        """Return `(x, y)`: frames t and t+1 as float tensors shaped [C, H, W]."""
        # Map flat index to (trajectory_id, time_step)
        traj_idx = idx // self.time_steps
        t_idx = idx % self.time_steps

        key = self.keys[traj_idx]

        # Open file ONLY for this specific read (saves RAM)
        with h5py.File(self.file_path, 'r') as f:
            # Slicing [t_idx : t_idx+2] reads just the two needed frames from disk.
            frames = f[key]['data'][t_idx : t_idx+2]

        x = frames[0]  # Input (t)
        y = frames[1]  # Target (t+1)

        # Convert to Torch: [H, W, C] -> [C, H, W]
        x = torch.from_numpy(x).permute(2, 0, 1).float()
        y = torch.from_numpy(y).permute(2, 0, 1).float()

        return x, y

# Build both dataset splits from the same HDF5 file.
train_dataset = LazyDiffusionReactionDataset(file_path=DATASET_PATH, split='train')
test_dataset = LazyDiffusionReactionDataset(file_path=DATASET_PATH, split='test')

# Wrap them in loaders. Only the training loader shuffles.
# Note: num_workers=0 is safer for HDF5 files to avoid read conflicts
train_loader = DataLoader(dataset=train_dataset, num_workers=0, shuffle=True, batch_size=32)
test_loader = DataLoader(dataset=test_dataset, num_workers=0, shuffle=False, batch_size=32)

# Smoke test: pull one batch through the whole loading path.
x, y = next(iter(train_loader))
print(f"\nBatch Loaded Successfully!")
print(f"Input Shape: {x.shape} (Batch, Channels, Height, Width)")

In [None]:
"""
FNO: Fourier Neural Operator

Instead of learning spatial convolution kernels (like CNN),
FNO learns the kernel directly in Fourier space. This makes it:
1. Resolution-independent (train on 64x64, test on 256x256)
2. Naturally captures global patterns (not just local)
3. Efficient for smooth PDE solutions
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

class SpectralConv2d(nn.Module):
    """Spectral convolution: FFT -> per-mode channel mixing -> inverse FFT.

    Only two bands of the half-spectrum returned by `rfft2` are transformed:
    the first `modes1` rows and the last `modes1` rows, each restricted to
    the first `modes2` columns. Every other coefficient is set to zero.
    """

    def __init__(self, in_channels, out_channels, modes1, modes2):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1  # retained modes along the second-to-last axis
        self.modes2 = modes2  # retained modes along the last axis

        # Initialisation scale applied to both weight tensors.
        self.scale = 1 / (in_channels * out_channels)

        # Complex-valued weights, one tensor per retained band.
        weight_shape = (in_channels, out_channels, modes1, modes2)
        self.weights1 = nn.Parameter(
            self.scale * torch.rand(*weight_shape, dtype=torch.cfloat)
        )
        self.weights2 = nn.Parameter(
            self.scale * torch.rand(*weight_shape, dtype=torch.cfloat)
        )

    def compl_mul2d(self, input, weights):
        """Contract the input-channel axis against `weights`, mode by mode.

        input:   [batch, in_channels, x, y]
        weights: [in_channels, out_channels, x, y]
        returns: [batch, out_channels, x, y]
        """
        return torch.einsum("bixy,ioxy->boxy", input, weights)

    def forward(self, x):
        n_batch = x.shape[0]
        height, width = x.size(-2), x.size(-1)
        m1, m2 = self.modes1, self.modes2

        # Real-input 2D FFT: the last axis holds width // 2 + 1 coefficients.
        spectrum = torch.fft.rfft2(x)

        # Start from an all-zero spectrum; only the two retained bands are filled in.
        filtered = torch.zeros(
            n_batch, self.out_channels, height, width // 2 + 1,
            dtype=torch.cfloat, device=x.device
        )

        # Band at the start of the row axis.
        filtered[:, :, :m1, :m2] = self.compl_mul2d(spectrum[:, :, :m1, :m2], self.weights1)

        # Band at the end of the row axis.
        filtered[:, :, -m1:, :m2] = self.compl_mul2d(spectrum[:, :, -m1:, :m2], self.weights2)

        # Back to the spatial domain at the original resolution.
        return torch.fft.irfft2(filtered, s=(height, width))

class FNO2d(nn.Module):
    """2D Fourier Neural Operator.

    Pipeline:
    1. Lift: append an (x, y) coordinate grid to the input channels and map
       them to `width` hidden channels with a linear layer.
    2. `n_layers` blocks, each summing a SpectralConv2d branch and a 1x1
       convolution branch, with GELU after every block except the last.
    3. Project: linear `width -> 128`, GELU, linear `128 -> out_channels`.
    """

    def __init__(
        self,
        modes1: int = 12,          # Fourier modes in height
        modes2: int = 12,          # Fourier modes in width
        width: int = 32,           # Hidden channel dimension
        in_channels: int = 2,      # Input: u and v fields
        out_channels: int = 2,     # Output: u and v fields (next timestep)
        n_layers: int = 4,         # Number of Fourier layers
    ):
        super().__init__()

        self.modes1 = modes1
        self.modes2 = modes2
        self.width = width
        self.n_layers = n_layers

        # Lifting layer; the two extra inputs are the grid coordinates.
        self.fc0 = nn.Linear(in_channels + 2, width)

        # Spectral branch of every block.
        self.spectral_convs = nn.ModuleList(
            [SpectralConv2d(width, width, modes1, modes2) for _ in range(n_layers)]
        )

        # Pointwise (1x1 convolution) branch of every block.
        self.conv_layers = nn.ModuleList(
            [nn.Conv2d(width, width, 1) for _ in range(n_layers)]
        )

        # Projection head: width -> 128 -> out_channels.
        self.fc1 = nn.Linear(width, 128)
        self.fc2 = nn.Linear(128, out_channels)

    def get_grid(self, shape, device):
        """Return a coordinate grid shaped [batch, size_x, size_y, 2].

        `shape` is unpacked as (batch, channels, size_x, size_y). The last
        axis of the result holds the x and y coordinates, each running
        linearly from 0 to 1 along its own spatial axis. The grid is repeated
        once per batch element.
        """
        batch_size, _, size_x, size_y = shape

        xs = torch.linspace(0, 1, size_x, device=device)
        ys = torch.linspace(0, 1, size_y, device=device)

        # 'ij' indexing keeps the first output axis aligned with size_x.
        mesh_x, mesh_y = torch.meshgrid(xs, ys, indexing='ij')

        coords = torch.stack([mesh_x, mesh_y], dim=-1)
        return coords.unsqueeze(0).repeat(batch_size, 1, 1, 1)

    def forward(self, x):
        """Map [batch, in_channels, H, W] to [batch, out_channels, H, W]."""
        coords = self.get_grid(x.shape, x.device)

        # Channels-last so the linear lifting layer acts on the channel axis.
        hidden = torch.cat([x.permute(0, 2, 3, 1), coords], dim=-1)  # [B, H, W, C+2]
        hidden = self.fc0(hidden).permute(0, 3, 1, 2)                # [B, width, H, W]

        final_block = self.n_layers - 1
        for block in range(self.n_layers):
            spectral_out = self.spectral_convs[block](hidden)
            pointwise_out = self.conv_layers[block](hidden)
            hidden = spectral_out + pointwise_out
            # No activation after the final block.
            if block != final_block:
                hidden = F.gelu(hidden)

        # Projection head, again applied channels-last.
        hidden = hidden.permute(0, 2, 3, 1)      # [B, H, W, width]
        hidden = F.gelu(self.fc1(hidden))
        hidden = self.fc2(hidden)                # [B, H, W, out_channels]
        return hidden.permute(0, 3, 1, 2)        # [B, out_channels, H, W]

In [None]:
# Prefer CUDA when PyTorch can see a GPU, otherwise run on the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(f"Using device: {device}")

# Build the network, then move its parameters to the chosen device.
model = FNO2d(
    modes1=12, modes2=12,            # Fourier modes kept per spatial axis
    width=32,                        # hidden channel dimension
    in_channels=2, out_channels=2,   # u and v in, u and v at the next timestep out
    n_layers=4,                      # number of Fourier layers
)
model = model.to(device)

n_params = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {n_params:,}")

# Shape check with a random batch; no gradients are needed for it.
x_test = torch.randn(4, 2, 128, 128).to(device)
with torch.no_grad():
    y_test = model(x_test)
print(f"Input shape:  {x_test.shape}")
print(f"Output shape: {y_test.shape}")
print("\nModel initialized successfully!")

In [None]:
import time
from tqdm.notebook import tqdm

def train_epoch(model, loader, optimizer, device):
    """Train `model` for one pass over `loader` and return the mean MSE.

    The returned value is averaged over all target elements seen in the epoch
    (each batch loss is weighted by its element count). This makes it directly
    comparable with `evaluate`, and a smaller final batch no longer carries
    the same weight as a full one.

    Args:
        model: Module mapping an input batch to a prediction shaped like the target.
        loader: Iterable yielding `(x, y)` tensor pairs.
        optimizer: Optimizer over `model`'s parameters; stepped once per batch.
        device: Device the batches are moved to before the forward pass.

    Returns:
        float: Mean squared error over every target element in the epoch.

    Raises:
        ValueError: If `loader` yields no data.
    """
    model.train()
    loss_sum = 0.0
    n_elements = 0

    for x, y in loader:
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()

        # Forward pass: predict next timestep
        pred = model(x)

        # Loss: Mean squared error between prediction and ground truth
        loss = F.mse_loss(pred, y)

        # Backward pass
        loss.backward()
        optimizer.step()

        # `loss` is a per-element mean; weight it back up to a sum so the
        # epoch average is taken over elements rather than over batches.
        batch_elements = y.numel()
        loss_sum += loss.item() * batch_elements
        n_elements += batch_elements

    if n_elements == 0:
        raise ValueError("train_epoch received an empty loader; there is nothing to train on.")

    return loss_sum / n_elements


def evaluate(model, loader, device):
    """Return the mean squared error of `model` over every element in `loader`.

    Squared errors are summed over all batches and divided by the number of
    target elements. Counting target elements keeps the mean correct when the
    model's output has a different channel count or rank than its input.

    Args:
        model: Module mapping an input batch to a prediction shaped like the target.
        loader: Iterable yielding `(x, y)` tensor pairs.
        device: Device the batches are moved to before the forward pass.

    Returns:
        float: Mean squared error over all target elements.

    Raises:
        ValueError: If `loader` yields no data.
    """
    model.eval()
    squared_error_sum = 0.0
    n_elements = 0

    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)

            pred = model(x)

            # Sum of squared errors for this batch
            squared_error_sum += F.mse_loss(pred, y, reduction='sum').item()
            # The sum above runs over target elements, so count those.
            n_elements += y.numel()

    if n_elements == 0:
        raise ValueError("evaluate received an empty loader; there is nothing to evaluate.")

    # Mean over all elements
    return squared_error_sum / n_elements


In [None]:
# Training hyperparameters.
epochs, learning_rate = 50, 1e-3

# AdamW with weight decay; the cosine schedule spans the whole run (T_max = epochs).
optimizer = torch.optim.AdamW(
    model.parameters(),
    weight_decay=1e-4,
    lr=learning_rate,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)


In [None]:
# Run summary printed before training starts.
print("="*60)
print("Training FNO on 2D Diffusion-Reaction")
print("="*60)
print(f"Epochs: {epochs}")
print(f"Learning rate: {learning_rate}")
print(f"Train samples: {len(train_dataset):,}")
print(f"Test samples: {len(test_dataset):,}")
print(f"Estimated time: ~{1.7 * epochs:.0f} minutes")
print("="*60)

# Per-epoch training loss, and (epoch, loss) pairs for the epochs that were evaluated.
train_losses = []
test_losses = []
best_test_loss = float('inf')

start_time = time.time()

for epoch in range(epochs):
    epoch_start = time.time()

    # One optimisation pass, then advance the learning-rate schedule.
    train_loss = train_epoch(model, train_loader, optimizer, device)
    scheduler.step()

    epoch_time = time.time() - epoch_start
    train_losses.append(train_loss)

    # The test set is evaluated on the first epoch and on every fifth epoch;
    # all other epochs only report the training loss.
    if epoch != 0 and (epoch + 1) % 5 != 0:
        print(f"Epoch {epoch+1:3d}/{epochs} ({epoch_time:.1f}s) | Train: {train_loss:.6f}")
        continue

    test_loss = evaluate(model, test_loader, device)
    test_losses.append((epoch, test_loss))

    # Checkpoint whenever the test loss improves.
    if test_loss < best_test_loss:
        best_test_loss = test_loss
        torch.save(model.state_dict(), 'best_fno_model.pt')

    print(f"Epoch {epoch+1:3d}/{epochs} ({epoch_time:.1f}s) | Train: {train_loss:.6f} | Test: {test_loss:.6f} | Best: {best_test_loss:.6f}")

total_time = time.time() - start_time
print("="*60)
print(f"Training complete in {total_time/60:.1f} minutes")
print(f"Best Test Loss: {best_test_loss:.6f}")
print(f"Best Test RMSE: {np.sqrt(best_test_loss):.6f}")
print("="*60)

# Plot training curves: the same data on a linear axis (left) and a log axis (right).
test_epochs, test_vals = zip(*test_losses)

curve_fig, (linear_ax, log_ax) = plt.subplots(1, 2, figsize=(10, 4))

panel_specs = (
    (linear_ax, 'MSE Loss', 'Training Progress'),
    (log_ax, 'MSE Loss (log scale)', 'Training Progress (Log Scale)'),
)
for panel_ax, panel_ylabel, panel_title in panel_specs:
    panel_ax.plot(train_losses, label='Train Loss', alpha=0.7)
    # Test loss exists only for the evaluated epochs, so draw it as points.
    panel_ax.scatter(test_epochs, test_vals, c='red', label='Test Loss', zorder=5)
    panel_ax.set_xlabel('Epoch')
    panel_ax.set_ylabel(panel_ylabel)
    panel_ax.legend()
    panel_ax.set_title(panel_title)
    panel_ax.grid(True, alpha=0.3)

log_ax.set_yscale('log')

plt.tight_layout()
plt.savefig('training_curves.png', dpi=150)
plt.show()
