In [1]:
import torch
import torch.nn as nn

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.fft import fft, ifft, fftfreq
import h5py
import os
from tqdm import tqdm
import json
from dataclasses import dataclass
from typing import List, Tuple, Dict, Optional
import warnings

@dataclass
class BurgersConfig:
    """Configuration for Burgers equation simulation"""
    # Domain
    L: float = 2 * np.pi
    periodic: bool = True

    # Viscosity values
    viscosities: List[float] = None

    # Spatial discretization
    Nx_train: int = 256
    Nx_test: int = 512

    # Temporal discretization
    T_final: float = 1.0
    cfl_factor: float = 0.4

    # Training data
    n_train: int = 500
    n_val: int = 100
    n_test: int = 100

    # Initial condition distribution
    fourier_ratio: float = 0.7
    step_ratio: float = 0.15
    gaussian_ratio: float = 0.15

    def __post_init__(self):
        if self.viscosities is None:
            self.viscosities = [3e-4, 8e-4, 2e-3, 5e-3]

        # Validate ratios
        total = self.fourier_ratio + self.step_ratio + self.gaussian_ratio
        if abs(total - 1.0) > 1e-10:
            raise ValueError(f"IC ratios must sum to 1.0, got {total}")

class BurgersDataGenerator:
    """Generate high-fidelity training data for Burgers' equation"""

    def __init__(self, config: BurgersConfig):
        self.config = config

    def generate_initial_condition(self, ic_type: str, Nx: int, seed: int) -> np.ndarray:
        """Generate initial condition based on type"""
        np.random.seed(seed)
        x = np.linspace(0, self.config.L, Nx, endpoint=False)

        if ic_type == 'fourier':
            return self._generate_fourier_ic(x)
        elif ic_type == 'step':
            return self._generate_step_ic(x)
        elif ic_type == 'gaussian':
            return self._generate_gaussian_ic(x)
        else:
            raise ValueError(f"Unknown IC type: {ic_type}")

    def _generate_fourier_ic(self, x: np.ndarray) -> np.ndarray:
        """Random Fourier series IC"""
        K = np.random.randint(1, 7)  # K ∈ [1,6]
        u = np.zeros_like(x)

        for k in range(1, K+1):
            a_k = np.random.uniform(0.1, 1.0)
            phi_k = np.random.uniform(0, 2*np.pi)
            u += a_k * np.sin(k * x + phi_k)

        return u

    def _generate_step_ic(self, x: np.ndarray) -> np.ndarray:
        """Smoothed step function IC"""
        x0 = np.random.uniform(0.2 * self.config.L, 0.8 * self.config.L)
        w = np.random.uniform(0.1, 0.3)
        A = np.random.uniform(0.5, 1.5)

        return A * np.tanh((x - x0) / w)

    def _generate_gaussian_ic(self, x: np.ndarray) -> np.ndarray:
        """Gaussian IC"""
        x0 = np.random.uniform(0.2 * self.config.L, 0.8 * self.config.L)
        sigma = np.random.uniform(0.2, 0.6)
        A = np.random.uniform(0.5, 1.5)

        return A * np.exp(-(x - x0)**2 / (2 * sigma**2))

    def solve_burgers_spectral(self, u0: np.ndarray, nu: float, Nx: int) -> Tuple[np.ndarray, np.ndarray, Dict]:
        """
        Solve Burgers equation using pseudo-spectral method with ETDRK4 and 2/3 dealiasing
        """
        # Setup
        x = np.linspace(0, self.config.L, Nx, endpoint=False)
        dx = self.config.L / Nx

        # Wavenumbers with 2/3 dealiasing
        k = fftfreq(Nx, dx/(2*np.pi))
        k_max = (2/3) * Nx//2
        dealias = np.abs(k) <= k_max

        # ETDRK4 coefficients
        dt_base = self.config.cfl_factor * min(dx / (np.max(np.abs(u0)) + 1e-12),
                                              dx**2 / (nu + 1e-12))

        # Time stepping
        t = 0.0
        u = u0.copy()
        trajectory = [u.copy()]
        times = [t]

        # Diagnostic info
        diagnostics = {
            'cfl_numbers': [],
            'energy': [0.5 * np.sum(u**2) * dx],
            'max_u': [np.max(np.abs(u))]
        }

        while t < self.config.T_final:
            # Adaptive timestep
            max_u = np.max(np.abs(u))
            dt_convective = dx / (max_u + 1e-12)
            dt_diffusive = dx**2 / (nu + 1e-12)
            dt = self.config.cfl_factor * min(dt_convective, dt_diffusive)

            # Don't overshoot final time
            if t + dt > self.config.T_final:
                dt = self.config.T_final - t

            # ETDRK4 step
            u = self._etdrk4_step(u, dt, nu, k, dealias)

            t += dt
            trajectory.append(u.copy())
            times.append(t)

            # Diagnostics
            cfl_conv = max_u * dt / dx
            cfl_diff = nu * dt / dx**2
            diagnostics['cfl_numbers'].append({'convective': cfl_conv, 'diffusive': cfl_diff})
            diagnostics['energy'].append(0.5 * np.sum(u**2) * dx)
            diagnostics['max_u'].append(max_u)

        return np.array(trajectory), np.array(times), diagnostics

    def _etdrk4_step(self, u: np.ndarray, dt: float, nu: float, k: np.ndarray, dealias: np.ndarray) -> np.ndarray:
        """Single ETDRK4 step for Burgers equation"""
        # Fourier transform
        u_hat = fft(u)

        # Linear operator (diffusion)
        L = -nu * k**2

        # ETDRK4 coefficients
        E = np.exp(dt * L)
        E2 = np.exp(dt * L / 2)

        # Handle division by zero for L = 0
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            phi1 = np.where(L == 0, dt, (E - 1) / L)
            phi2 = np.where(L == 0, dt/2, (E2 - 1) / L)
            phi3 = np.where(L == 0, dt/2, (E2 - 1) / L)

        # Nonlinear term: -u * du/dx
        def N(u_hat):
            u_real = np.real(ifft(u_hat))
            du_dx = np.real(ifft(1j * k * u_hat))
            nonlinear = -u_real * du_dx
            N_hat = fft(nonlinear)
            # Apply dealiasing
            N_hat[~dealias] = 0
            return N_hat

        # ETDRK4 stages
        N1 = N(u_hat)
        a = E2 * u_hat + phi2 * N1
        N2 = N(a)
        b = E2 * u_hat + phi3 * N2
        N3 = N(b)
        c = E2 * a + phi2 * (2*N3 - N1)
        N4 = N(c)

        # Final update
        u_hat_new = E * u_hat + phi1 * (N1 + 2*N2 + 2*N3 + N4) / 6

        return np.real(ifft(u_hat_new))

    def generate_dataset(self, output_dir: str = "burgers_dataset") -> Dict:
        """Generate complete dataset for CNN/TCN training"""

        os.makedirs(output_dir, exist_ok=True)

        # Determine IC types for each trajectory
        def get_ic_distribution(n_total):
            n_fourier = int(n_total * self.config.fourier_ratio)
            n_step = int(n_total * self.config.step_ratio)
            n_gaussian = n_total - n_fourier - n_step

            ic_types = (['fourier'] * n_fourier +
                       ['step'] * n_step +
                       ['gaussian'] * n_gaussian)
            np.random.shuffle(ic_types)
            return ic_types

        # Generate seeds and parameters
        np.random.seed(42)  # For reproducibility

        datasets = {}

        for split, n_traj in [('train', self.config.n_train),
                             ('val', self.config.n_val),
                             ('test', self.config.n_test)]:

            print(f"Generating {split} dataset ({n_traj} trajectories)...")

            # Use different Nx for test split
            Nx = self.config.Nx_test if split == 'test' else self.config.Nx_train

            # Get IC types
            ic_types = get_ic_distribution(n_traj)

            # Generate unique seeds for this split
            base_seed = {'train': 1000, 'val': 2000, 'test': 3000}[split]
            seeds = list(range(base_seed, base_seed + n_traj))

            # Randomly assign viscosities
            viscosities = np.random.choice(self.config.viscosities, n_traj)

            # For test set, ensure we have at least one unseen viscosity
            if split == 'test':
                # Add an OOD viscosity
                ood_nu = 1e-3  # Between our training values
                viscosities[0] = ood_nu

            # Generate trajectories
            X_all = []
            y_all = []
            metadata = []

            for i, (ic_type, seed, nu) in enumerate(tqdm(zip(ic_types, seeds, viscosities),
                                                        total=n_traj,
                                                        desc=f"{split.capitalize()}")):

                # Generate initial condition
                u0 = self.generate_initial_condition(ic_type, Nx, seed)

                # Solve PDE
                trajectory, times, diagnostics = self.solve_burgers_spectral(u0, nu, Nx)

                # Create sliding window samples
                for t_idx in range(len(trajectory) - 1):
                    X_all.append(trajectory[t_idx])
                    y_all.append(trajectory[t_idx + 1])

                    metadata.append({
                        'trajectory_id': i,
                        'time_step': t_idx,
                        'time': times[t_idx],
                        'viscosity': nu,
                        'ic_type': ic_type,
                        'ic_seed': seed,
                        'Nx': Nx,
                        'max_cfl_conv': max([d['convective'] for d in diagnostics['cfl_numbers']]),
                        'max_cfl_diff': max([d['diffusive'] for d in diagnostics['cfl_numbers']])
                    })

            X = np.array(X_all)
            y = np.array(y_all)

            # Store dataset
            datasets[split] = {
                'X': X,
                'y': y,
                'metadata': metadata
            }

            print(f"{split}: {X.shape[0]} samples, shape {X.shape}")

        # Compute global normalization statistics from training data
        X_train = datasets['train']['X']
        y_train = datasets['train']['y']
        all_data = np.concatenate([X_train.flatten(), y_train.flatten()])

        norm_stats = {
            'mean': float(np.mean(all_data)),
            'std': float(np.std(all_data))
        }

        print(f"Normalization stats - Mean: {norm_stats['mean']:.6f}, Std: {norm_stats['std']:.6f}")

        # Save datasets
        for split in datasets:
            # Save as HDF5
            with h5py.File(os.path.join(output_dir, f"{split}.h5"), 'w') as f:
                f.create_dataset('X', data=datasets[split]['X'])
                f.create_dataset('y', data=datasets[split]['y'])

                # Save metadata as JSON strings
                metadata_json = [json.dumps(m) for m in datasets[split]['metadata']]
                f.create_dataset('metadata', data=metadata_json, dtype=h5py.string_dtype())

        # Save configuration and normalization stats
        config_dict = {
            'config': {
                'L': self.config.L,
                'viscosities': self.config.viscosities,
                'Nx_train': self.config.Nx_train,
                'Nx_test': self.config.Nx_test,
                'T_final': self.config.T_final,
                'cfl_factor': self.config.cfl_factor,
                'n_train': self.config.n_train,
                'n_val': self.config.n_val,
                'n_test': self.config.n_test,
                'fourier_ratio': self.config.fourier_ratio,
                'step_ratio': self.config.step_ratio,
                'gaussian_ratio': self.config.gaussian_ratio
            },
            'normalization': norm_stats
        }

        with open(os.path.join(output_dir, 'config.json'), 'w') as f:
            json.dump(config_dict, f, indent=2)

        # Generate quality check plots
        self._generate_quality_plots(datasets, output_dir, norm_stats)

        return datasets, norm_stats

    def _generate_quality_plots(self, datasets: Dict, output_dir: str, norm_stats: Dict):
        """Generate quality check visualizations"""

        fig, axes = plt.subplots(2, 2, figsize=(12, 10))

        # Plot 1: Sample trajectories
        ax = axes[0, 0]
        train_data = datasets['train']

        # Find a few representative trajectories
        for traj_id in [0, 1, 2]:
            traj_samples = [m for m in train_data['metadata'] if m['trajectory_id'] == traj_id]
            if traj_samples:
                times = [m['time'] for m in traj_samples]
                energies = []

                for i, m in enumerate(traj_samples):
                    idx = next(j for j, meta in enumerate(train_data['metadata']) if meta == m)
                    u = train_data['X'][idx]
                    energy = 0.5 * np.sum(u**2) * (2*np.pi / len(u))
                    energies.append(energy)

                ax.plot(times, energies, 'o-', alpha=0.7, label=f'Traj {traj_id}, ν={traj_samples[0]["viscosity"]:.1e}')

        ax.set_xlabel('Time')
        ax.set_ylabel('Energy')
        ax.set_title('Energy Decay (Quality Check)')
        ax.legend()
        ax.grid(True, alpha=0.3)

        # Plot 2: CFL number distribution
        ax = axes[0, 1]
        cfl_conv = [m['max_cfl_conv'] for m in train_data['metadata']]
        cfl_diff = [m['max_cfl_diff'] for m in train_data['metadata']]

        ax.hist(cfl_conv, bins=20, alpha=0.5, label='Convective CFL', density=True)
        ax.hist(cfl_diff, bins=20, alpha=0.5, label='Diffusive CFL', density=True)
        ax.axvline(1.0, color='red', linestyle='--', alpha=0.7, label='CFL = 1')
        ax.set_xlabel('CFL Number')
        ax.set_ylabel('Density')
        ax.set_title('CFL Number Distribution')
        ax.legend()
        ax.grid(True, alpha=0.3)

        # Plot 3: Data distribution
        ax = axes[1, 0]
        all_X = np.concatenate([datasets[split]['X'].flatten() for split in datasets])
        ax.hist(all_X, bins=50, alpha=0.7, density=True)
        ax.axvline(norm_stats['mean'], color='red', linestyle='--', label=f"Mean = {norm_stats['mean']:.3f}")
        ax.axvline(norm_stats['mean'] - norm_stats['std'], color='orange', linestyle=':', label=f"±1σ")
        ax.axvline(norm_stats['mean'] + norm_stats['std'], color='orange', linestyle=':', alpha=0.7)
        ax.set_xlabel('Value')
        ax.set_ylabel('Density')
        ax.set_title('Data Distribution')
        ax.legend()
        ax.grid(True, alpha=0.3)

        # Plot 4: Sample initial conditions
        ax = axes[1, 1]
        x_train = np.linspace(0, 2*np.pi, self.config.Nx_train, endpoint=False)

        # Show one of each IC type
        ic_examples = {}
        for m in train_data['metadata'][:50]:  # Check first 50
            if m['time_step'] == 0 and m['ic_type'] not in ic_examples:
                idx = next(j for j, meta in enumerate(train_data['metadata']) if meta == m)
                ic_examples[m['ic_type']] = train_data['X'][idx]

        for ic_type, u0 in ic_examples.items():
            ax.plot(x_train, u0, label=f'{ic_type.capitalize()} IC', linewidth=2)

        ax.set_xlabel('x')
        ax.set_ylabel('u(x, t=0)')
        ax.set_title('Sample Initial Conditions')
        ax.legend()
        ax.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, 'quality_checks.png'), dpi=150, bbox_inches='tight')
        plt.close()

        print(f"Quality check plots saved to {output_dir}/quality_checks.png")

# Example usage and data loading utilities
class BurgersDataLoader:
    """Utility for loading and preprocessing the generated dataset"""

    def __init__(self, dataset_dir: str):
        self.dataset_dir = dataset_dir

        # Load configuration and normalization stats
        with open(os.path.join(dataset_dir, 'config.json'), 'r') as f:
            self.config_data = json.load(f)

        self.norm_stats = self.config_data['normalization']

    def load_split(self, split: str) -> Tuple[np.ndarray, np.ndarray, List[Dict]]:
        """Load a specific split of the dataset"""
        with h5py.File(os.path.join(self.dataset_dir, f"{split}.h5"), 'r') as f:
            X = f['X'][:]
            y = f['y'][:]
            metadata = [json.loads(m.decode()) for m in f['metadata'][:]]

        return X, y, metadata

    def normalize(self, data: np.ndarray) -> np.ndarray:
        """Apply normalization"""
        return (data - self.norm_stats['mean']) / self.norm_stats['std']

    def denormalize(self, data: np.ndarray) -> np.ndarray:
        """Remove normalization"""
        return data * self.norm_stats['std'] + self.norm_stats['mean']

    def get_cnn_format(self, X: np.ndarray) -> np.ndarray:
        """Convert to CNN format: [batch, channels=1, spatial]"""
        return X[:, np.newaxis, :]

    def get_tcn_format(self, X: np.ndarray, window_size: int = 3) -> np.ndarray:
        """Convert to TCN format with temporal stacking"""
        # This would require modifying the data generation to create temporal windows
        # For now, return CNN format
        return self.get_cnn_format(X)

# Main execution
if __name__ == "__main__":
    # Create configuration with specified parameters
    config = BurgersConfig(
        viscosities=[3e-4, 8e-4, 2e-3, 5e-3],
        Nx_train=256,
        Nx_test=512,
        T_final=1.0,
        n_train=500,
        n_val=100,
        n_test=100,
        fourier_ratio=0.7,
        step_ratio=0.15,
        gaussian_ratio=0.15
    )

    # Generate dataset
    generator = BurgersDataGenerator(config)
    datasets, norm_stats = generator.generate_dataset("burgers_dataset")

    print("\nDataset generation complete!")
    print(f"Training samples: {datasets['train']['X'].shape[0]}")
    print(f"Validation samples: {datasets['val']['X'].shape[0]}")
    print(f"Test samples: {datasets['test']['X'].shape[0]}")
    print(f"Normalization - Mean: {norm_stats['mean']:.6f}, Std: {norm_stats['std']:.6f}")

    # Example of loading and using the data
    print("\nExample usage:")
    loader = BurgersDataLoader("burgers_dataset")
    X_train, y_train, metadata_train = loader.load_split('train')

    # Normalize for training
    X_train_norm = loader.normalize(X_train)
    y_train_norm = loader.normalize(y_train)

    # Convert to CNN format
    X_cnn = loader.get_cnn_format(X_train_norm)
    print(f"CNN format shape: {X_cnn.shape}")  # [batch, 1, Nx]

Generating train dataset (500 trajectories)...


Train: 100%|██████████| 500/500 [00:20<00:00, 24.59it/s]


train: 70166 samples, shape (70166, 256)
Generating val dataset (100 trajectories)...


Val: 100%|██████████| 100/100 [00:03<00:00, 25.94it/s]


val: 14050 samples, shape (14050, 256)
Generating test dataset (100 trajectories)...


Test: 100%|██████████| 100/100 [00:09<00:00, 10.34it/s]


test: 27203 samples, shape (27203, 512)
Normalization stats - Mean: 0.010787, Std: 0.805432
Quality check plots saved to burgers_dataset/quality_checks.png

Dataset generation complete!
Training samples: 70166
Validation samples: 14050
Test samples: 27203
Normalization - Mean: 0.010787, Std: 0.805432

Example usage:
CNN format shape: (70166, 1, 256)


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import h5py
import json
import os
from tqdm import tqdm
from typing import Dict, List, Tuple, Optional
import time
from dataclasses import dataclass

@dataclass
class TrainingConfig:
    """Configuration for CNN training"""
    # Model architecture
    hidden_channels: List[int] = None
    kernel_sizes: List[int] = None
    activation: str = 'relu'
    dropout_rate: float = 0.1

    # Training parameters
    batch_size: int = 64
    learning_rate: float = 1e-3
    num_epochs: int = 100
    weight_decay: float = 1e-5

    # Learning rate scheduling
    lr_patience: int = 10
    lr_factor: float = 0.5

    # Early stopping
    early_stopping_patience: int = 20

    # Device
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Logging
    log_interval: int = 100
    save_dir: str = 'cnn_models'

    def __post_init__(self):
        if self.hidden_channels is None:
            self.hidden_channels = [16, 32, 64, 32, 16]
        if self.kernel_sizes is None:
            self.kernel_sizes = [7, 5, 5, 5, 7]

class BurgersDataset(Dataset):
    """PyTorch Dataset for your generated Burgers equation data"""

    def __init__(self, X: np.ndarray, y: np.ndarray, metadata: List[Dict],
                 normalize_fn=None, device='cpu'):
        self.X = torch.FloatTensor(X).to(device)
        self.y = torch.FloatTensor(y).to(device)
        self.metadata = metadata
        self.normalize_fn = normalize_fn

        # Convert to CNN format: [batch, channels=1, spatial]
        if len(self.X.shape) == 2:
            self.X = self.X.unsqueeze(1)  # Add channel dimension
        if len(self.y.shape) == 2:
            self.y = self.y.unsqueeze(1)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

class BurgersCNN(nn.Module):
    """1D CNN for Burgers equation surrogate modeling"""

    def __init__(self, config: TrainingConfig):
        super().__init__()
        self.config = config

        # Build encoder-decoder architecture
        layers = []
        in_channels = 1  # Single field u(x,t)

        # Encoder-decoder with skip connections
        for i, (out_channels, kernel_size) in enumerate(zip(config.hidden_channels, config.kernel_sizes)):
            layers.extend([
                nn.Conv1d(in_channels, out_channels, kernel_size,
                         padding=kernel_size//2, bias=True),
                self._get_activation(),
                nn.BatchNorm1d(out_channels),
                nn.Dropout1d(config.dropout_rate) if i > 0 else nn.Identity(),
            ])
            in_channels = out_channels

        # Final output layer
        layers.append(nn.Conv1d(in_channels, 1, kernel_size=1, padding=0))

        self.network = nn.Sequential(*layers)

        # Initialize weights
        self._init_weights()

    def _get_activation(self):
        if self.config.activation == 'relu':
            return nn.ReLU(inplace=True)
        elif self.config.activation == 'gelu':
            return nn.GELU()
        elif self.config.activation == 'swish':
            return nn.SiLU()
        else:
            return nn.ReLU(inplace=True)

    def _init_weights(self):
        """Initialize weights with Xavier/Glorot initialization"""
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Forward pass"""
        return self.network(x)

class BurgersTrainer:
    """Training pipeline for Burgers CNN using your generated data"""

    def __init__(self, config: TrainingConfig, dataset_dir: str = "burgers_dataset"):
        self.config = config
        self.dataset_dir = dataset_dir

        # Create save directory
        os.makedirs(config.save_dir, exist_ok=True)

        # Load normalization stats from your generated data
        with open(os.path.join(dataset_dir, '/content/sample_data/anscombe.json'), 'r') as f:
            dataset_config = json.load(f)

        self.norm_stats = dataset_config['normalization']
        print(f"Loaded normalization stats - Mean: {self.norm_stats['mean']:.6f}, Std: {self.norm_stats['std']:.6f}")

        # Load datasets
        self.datasets = self._load_datasets()

        # Create data loaders
        self.dataloaders = self._create_dataloaders()

        # Initialize model
        self.model = BurgersCNN(config).to(config.device)

        # Loss function and optimizer
        self.criterion = nn.MSELoss()
        self.optimizer = optim.Adam(self.model.parameters(),
                                   lr=config.learning_rate,
                                   weight_decay=config.weight_decay)

        # Learning rate scheduler
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', factor=config.lr_factor,
            patience=config.lr_patience, verbose=True
        )

        # Training history
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'learning_rates': [],
            'epoch_times': []
        }

        print(f"Model initialized with {sum(p.numel() for p in self.model.parameters()):,} parameters")
        print(f"Using device: {config.device}")

    def _load_datasets(self) -> Dict[str, Tuple[np.ndarray, np.ndarray, List[Dict]]]:
        """Load your generated datasets"""
        datasets = {}

        for split in ['train', 'val', 'test']:
            with h5py.File(os.path.join(self.dataset_dir, f"{split}.h5"), 'r') as f:
                X = f['X'][:]
                y = f['y'][:]
                metadata = [json.loads(m.decode()) for m in f['metadata'][:]]

            # Normalize data
            X_norm = (X - self.norm_stats['mean']) / self.norm_stats['std']
            y_norm = (y - self.norm_stats['mean']) / self.norm_stats['std']

            datasets[split] = (X_norm, y_norm, metadata)
            print(f"Loaded {split}: {X.shape[0]} samples, shape {X.shape}")

        return datasets

    def _create_dataloaders(self) -> Dict[str, DataLoader]:
        """Create PyTorch data loaders"""
        dataloaders = {}

        for split in ['train', 'val', 'test']:
            X, y, metadata = self.datasets[split]

            dataset = BurgersDataset(X, y, metadata, device=self.config.device)

            shuffle = (split == 'train')
            batch_size = self.config.batch_size if split != 'test' else min(64, len(X))

            dataloaders[split] = DataLoader(
                dataset, batch_size=batch_size, shuffle=shuffle,
                num_workers=0, pin_memory=False  # Set to 0 since data is already on device
            )

        return dataloaders

    def train_epoch(self) -> float:
        """Train for one epoch"""
        self.model.train()
        total_loss = 0.0
        num_batches = 0

        pbar = tqdm(self.dataloaders['train'], desc='Training', leave=False)
        for batch_idx, (inputs, targets) in enumerate(pbar):
            # Forward pass
            outputs = self.model(inputs)
            loss = self.criterion(outputs, targets)

            # Backward pass
            self.optimizer.zero_grad()
            loss.backward()

            # Gradient clipping (helps with stability)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            self.optimizer.step()

            # Update metrics
            total_loss += loss.item()
            num_batches += 1

            # Update progress bar
            pbar.set_postfix({'loss': f'{loss.item():.6f}'})

            # Log intermediate results
            if batch_idx % self.config.log_interval == 0:
                pbar.set_description(f'Training (batch {batch_idx}/{len(self.dataloaders["train"])})')

        return total_loss / num_batches

    def validate_epoch(self) -> float:
        """Validate for one epoch"""
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for inputs, targets in tqdm(self.dataloaders['val'], desc='Validating', leave=False):
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)

                total_loss += loss.item()
                num_batches += 1

        return total_loss / num_batches

    def train(self):
        """Full training loop"""
        print(f"\nStarting training for {self.config.num_epochs} epochs...")
        print("=" * 60)

        best_val_loss = float('inf')
        patience_counter = 0

        for epoch in range(self.config.num_epochs):
            epoch_start = time.time()

            # Train and validate
            train_loss = self.train_epoch()
            val_loss = self.validate_epoch()

            # Update learning rate
            self.scheduler.step(val_loss)
            current_lr = self.optimizer.param_groups[0]['lr']

            # Record history
            epoch_time = time.time() - epoch_start
            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history['learning_rates'].append(current_lr)
            self.history['epoch_times'].append(epoch_time)

            # Print progress
            print(f"Epoch {epoch+1:3d}/{self.config.num_epochs} | "
                  f"Train Loss: {train_loss:.6f} | "
                  f"Val Loss: {val_loss:.6f} | "
                  f"LR: {current_lr:.2e} | "
                  f"Time: {epoch_time:.1f}s")

            # Early stopping and model saving
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0

                # Save best model
                self.save_model('best_model.pth', epoch, val_loss)
                print(f"  → New best model saved (val_loss: {val_loss:.6f})")
            else:
                patience_counter += 1

            # Early stopping check
            if patience_counter >= self.config.early_stopping_patience:
                print(f"\nEarly stopping triggered after {epoch+1} epochs")
                break

            # Save checkpoint every 10 epochs
            if (epoch + 1) % 10 == 0:
                self.save_model(f'checkpoint_epoch_{epoch+1}.pth', epoch, val_loss)

        print("=" * 60)
        print(f"Training completed! Best validation loss: {best_val_loss:.6f}")

        # Plot training history
        self.plot_training_history()

    def save_model(self, filename: str, epoch: int, val_loss: float):
        """Save model checkpoint"""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'val_loss': val_loss,
            'config': self.config,
            'norm_stats': self.norm_stats,
            'history': self.history
        }

        torch.save(checkpoint, os.path.join(self.config.save_dir, filename))

    def load_model(self, filename: str):
        """Load model checkpoint"""
        checkpoint = torch.load(os.path.join(self.config.save_dir, filename),
                               map_location=self.config.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.history = checkpoint['history']

        print(f"Model loaded from epoch {checkpoint['epoch']} with val_loss: {checkpoint['val_loss']:.6f}")

        return checkpoint

    def test_model(self) -> Dict[str, float]:
        """Test the best model"""
        # Load best model
        self.load_model('best_model.pth')

        self.model.eval()
        test_metrics = {'mse': 0.0, 'mae': 0.0, 'max_error': 0.0}
        num_samples = 0

        with torch.no_grad():
            for inputs, targets in tqdm(self.dataloaders['test'], desc='Testing'):
                outputs = self.model(inputs)

                # Denormalize for physical interpretation
                outputs_denorm = outputs * self.norm_stats['std'] + self.norm_stats['mean']
                targets_denorm = targets * self.norm_stats['std'] + self.norm_stats['mean']

                # Compute metrics
                mse = torch.mean((outputs_denorm - targets_denorm)**2)
                mae = torch.mean(torch.abs(outputs_denorm - targets_denorm))
                max_err = torch.max(torch.abs(outputs_denorm - targets_denorm))

                test_metrics['mse'] += mse.item() * inputs.size(0)
                test_metrics['mae'] += mae.item() * inputs.size(0)
                test_metrics['max_error'] = max(test_metrics['max_error'], max_err.item())

                num_samples += inputs.size(0)

        # Average metrics
        test_metrics['mse'] /= num_samples
        test_metrics['mae'] /= num_samples
        test_metrics['rmse'] = np.sqrt(test_metrics['mse'])

        print("\nTest Results:")
        print(f"  RMSE: {test_metrics['rmse']:.6f}")
        print(f"  MAE:  {test_metrics['mae']:.6f}")
        print(f"  Max Error: {test_metrics['max_error']:.6f}")

        return test_metrics

    def plot_training_history(self):
        """Plot training and validation loss"""
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

        epochs = range(1, len(self.history['train_loss']) + 1)

        # Loss plot
        ax1.plot(epochs, self.history['train_loss'], 'b-', label='Training Loss')
        ax1.plot(epochs, self.history['val_loss'], 'r-', label='Validation Loss')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.set_title('Training and Validation Loss')
        ax1.legend()
        ax1.grid(True, alpha=0.3)
        ax1.set_yscale('log')

        # Learning rate plot
        ax2.plot(epochs, self.history['learning_rates'], 'g-', label='Learning Rate')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Learning Rate')
        ax2.set_title('Learning Rate Schedule')
        ax2.legend()
        ax2.grid(True, alpha=0.3)
        ax2.set_yscale('log')

        plt.tight_layout()
        plt.savefig(os.path.join(self.config.save_dir, 'training_history.png'),
                   dpi=150, bbox_inches='tight')
        plt.show()

        print(f"Training history plot saved to {self.config.save_dir}/training_history.png")

    def visualize_predictions(self, num_samples: int = 5):
        """Visualize model predictions vs ground truth"""
        self.load_model('best_model.pth')
        self.model.eval()

        # Get some test samples
        test_dataset = self.datasets['test']
        X_test, y_test, metadata_test = test_dataset

        # Select random samples
        indices = np.random.choice(len(X_test), num_samples, replace=False)

        fig, axes = plt.subplots(num_samples, 1, figsize=(12, 3*num_samples))
        if num_samples == 1:
            axes = [axes]

        with torch.no_grad():
            for i, idx in enumerate(indices):
                # Get data
                x_sample = torch.FloatTensor(X_test[idx]).unsqueeze(0).unsqueeze(0).to(self.config.device)
                y_true = y_test[idx]

                # Predict
                y_pred = self.model(x_sample).squeeze().cpu().numpy()

                # Denormalize
                y_true_denorm = y_true * self.norm_stats['std'] + self.norm_stats['mean']
                y_pred_denorm = y_pred * self.norm_stats['std'] + self.norm_stats['mean']
                x_denorm = X_test[idx] * self.norm_stats['std'] + self.norm_stats['mean']

                # Plot
                x_grid = np.linspace(0, 2*np.pi, len(y_true_denorm))

                axes[i].plot(x_grid, y_true_denorm, 'b-', linewidth=2, label='Ground Truth', alpha=0.8)
                axes[i].plot(x_grid, y_pred_denorm, 'r--', linewidth=2, label='CNN Prediction')
                axes[i].plot(x_grid, x_denorm, 'g:', linewidth=1, label='Input', alpha=0.6)

                # Add metadata info
                meta = metadata_test[idx]
                axes[i].set_title(f"Sample {i+1}: ν={meta['viscosity']:.1e}, "
                                f"t={meta['time']:.3f}, IC: {meta['ic_type']}")
                axes[i].set_xlabel('x')
                axes[i].set_ylabel('u(x,t)')
                axes[i].legend()
                axes[i].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(os.path.join(self.config.save_dir, 'predictions_visualization.png'),
                   dpi=150, bbox_inches='tight')
        plt.show()

        print(f"Predictions visualization saved to {self.config.save_dir}/predictions_visualization.png")

# Main execution
if __name__ == "__main__":
    # Configure training
    config = TrainingConfig(
        # Model architecture
        hidden_channels=[32, 64, 128, 64, 32],
        kernel_sizes=[7, 5, 5, 5, 7],
        activation='gelu',
        dropout_rate=0.1,

        # Training parameters
        batch_size=128,
        learning_rate=1e-3,
        num_epochs=200,
        weight_decay=1e-5,

        # Learning rate scheduling
        lr_patience=15,
        lr_factor=0.7,

        # Early stopping
        early_stopping_patience=30,

        # Other settings
        log_interval=50,
        save_dir='burgers_cnn_models'
    )

    # Initialize trainer (uses your generated data)
    trainer = BurgersTrainer(config, dataset_dir="burgers_dataset")

    print(f"Dataset info:")
    print(f"  Training samples: {len(trainer.datasets['train'][0]):,}")
    print(f"  Validation samples: {len(trainer.datasets['val'][0]):,}")
    print(f"  Test samples: {len(trainer.datasets['test'][0]):,}")

    # Train the model
    trainer.train()

    # Test the trained model
    test_metrics = trainer.test_model()

    # Create visualizations
    trainer.visualize_predictions(num_samples=6)

    print("\nTraining complete! Check the 'burgers_cnn_models' directory for:")
    print("  - best_model.pth (trained model)")
    print("  - training_history.png (loss curves)")
    print("  - predictions_visualization.png (sample predictions)")

TypeError: list indices must be integers or slices, not str