# Our Goal

Our model is designed to compute:

$x_{t+1}$ from $x_t$

So we can directly predict the next state of our system given the current state.

Unlike the base model, we will use a single set of fixed parameters for the entire dataset, rather than generating new parameters for each datapoint.

# Our Data

Our model receives:
- $x_t$: Current state

And attempts to predict:
- $x_{t+1}$: The next state

We generate our data by:
- First establishing fixed physical parameters for the entire dataset
- Then for each run:
  - Generating a new, random initial state $x_0$
  - Stepping it forward `timesteps_per_run - 1` times with the true model
  - Using every pair of consecutive states $(x_t, x_{t+1})$ as one datapoint
  - Repeating with a new random initial state but the same fixed parameters

# Imports

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from typing import Dict, Tuple, List
import helper
import matplotlib.pyplot as plt
import os
import h5py
from tqdm import tqdm

# Class for generating data

In [None]:
class StateModelDataset(Dataset):
    """
    Dataset of (x_t, x_{t+1}) state pairs generated with one fixed parameter set.

    Every run starts from a fresh random state and is stepped forward with the
    same model matrix and forcing field; each pair of consecutive states in a
    run becomes one sample.
    """

    def __init__(self, num_runs: int, timesteps_per_run: int, nr: int, nc: int, dt: float, F: float, fixed_params: dict = None):
        """
        Initialize dataset with fixed parameters for all samples.

        Args:
            num_runs (int): Number of independent runs (random initial states) to simulate
            timesteps_per_run (int): Number of states per run; yields timesteps_per_run - 1 pairs
            nr (int): Number of rows in grid
            nc (int): Number of columns in grid
            dt (float): Time step size
            F (float): Forcing parameter
            fixed_params (dict): Optional pre-generated parameters to use instead of generating new ones
        """
        self.num_runs = num_runs
        self.timesteps_per_run = timesteps_per_run
        self.nr = nr
        self.nc = nc
        self.dt = dt
        self.F = F

        # Either use provided parameters or generate new ones
        self.fixed_params = fixed_params if fixed_params is not None else self._generate_params()
        # The model matrix depends only on the fixed parameters, so it is
        # built lazily on first use and reused for every timestep afterwards.
        self._M = None
        self.samples = self._generate_samples()

    def _generate_gaussian_field(self, n, nrv, ncv):
        """
        Generate a Gaussian field composed of n Gaussian functions.

        Args:
            n: Number of Gaussian bumps to superimpose
            nrv: Number of rows of the returned field
            ncv: Number of columns of the returned field

        Returns:
            np.ndarray of shape (nrv, ncv)
        """
        mux = np.random.choice(ncv, n)
        muy = np.random.choice(range(2, nrv - 2), n)
        sigmax = np.random.uniform(1, ncv/4, n)
        sigmay = np.random.uniform(1, nrv/4, n)

        cols = np.arange(ncv)           # x coordinates, broadcast along rows
        rows = np.arange(nrv)[:, None]  # y coordinates, broadcast along columns

        v = np.zeros((nrv, ncv))
        for i in range(n):
            y_term = (rows - muy[i])**2 / (2*sigmay[i]**2)
            # Three copies (centre shifted by 0, -ncv, +ncv) for a pseudo-periodic field in x
            gauss = sum(
                np.exp(-((cols - (mux[i] + shift))**2 / (2*sigmax[i]**2) + y_term))
                for shift in (0, -ncv, ncv)
            )
            v += gauss
        return v

    def _generate_circular_field(self, v):
        """
        Generate a circular field from gradient of input field.

        Returns:
            tuple: (-dv/d(axis 0), dv/d(axis 1)) as computed by np.gradient
        """
        grad_v_y, grad_v_x = np.gradient(v)
        return -grad_v_y, grad_v_x

    def _generate_params(self):
        """
        Generate fixed model parameters to be used for all samples.

        Returns:
            dict: Tensors keyed by 'KX', 'KY', 'DX_C', 'DY_C', 'DX_G', 'DY_G',
            'VX', 'VY', 'RAC' and 'f'.
        """
        # Base grid parameters - unit spacing
        DX_C = torch.ones(self.nr, self.nc + 1)
        DY_C = torch.ones(self.nr + 1, self.nc)
        DX_G = torch.ones(self.nr + 1, self.nc)
        DY_G = torch.ones(self.nr, self.nc + 1)
        RAC = torch.ones(self.nr, self.nc)

        # Random diffusivities, uniform in [0, 1)
        KX = torch.rand(self.nr, self.nc + 1)
        KY = torch.rand(self.nr + 1, self.nc)

        # Generate velocities using Gaussian field approach
        num_gauss = 16  # Number of Gaussian functions for velocity field
        gauss = self._generate_gaussian_field(num_gauss, self.nr + 1, self.nc + 1)
        VX_np, VY_np = self._generate_circular_field(gauss)

        # Trim the (nr+1, nc+1) field so VX matches KX's shape and VY matches KY's.
        # The velocity scale factor is 1 here (an earlier comment referred to a
        # factor of 100 in helper.py, but this code has always multiplied by 1).
        VX = torch.from_numpy(1 * VX_np[:-1, :]).float()
        VY = torch.from_numpy(1 * VY_np[:, :-1]).float()

        # Random forcing field: standard normal scaled by 1/sqrt(number of cells)
        f = torch.randn(self.nr * self.nc) / np.sqrt(self.nr * self.nc)

        return {
            'KX': KX,
            'KY': KY,
            'DX_C': DX_C,
            'DY_C': DY_C,
            'DX_G': DX_G,
            'DY_G': DY_G,
            'VX': VX,
            'VY': VY,
            'RAC': RAC,
            'f': f,
        }

    def _generate_samples(self):
        """
        Generate multiple samples using fixed parameters but with controlled initial states.

        Returns:
            list: (x_t, x_{t+1}) tensor pairs, num_runs * (timesteps_per_run - 1) in total
        """
        samples = []

        for run in range(self.num_runs):
            if run % 100 == 0:
                print(f"Generating world {run+1}/{self.num_runs}")

            # Initial state scaled by 1/sqrt(n) to maintain reasonable magnitudes
            x_current = torch.randn(self.nr * self.nc) / np.sqrt(self.nr * self.nc)

            # Roll the state forward, remembering every intermediate state
            world_states = [x_current]
            for _ in range(self.timesteps_per_run - 1):
                x_current = self._compute_next_state(x_current)
                world_states.append(x_current)

            # Consecutive states form the (input, target) pairs
            samples.extend(zip(world_states[:-1], world_states[1:]))

        return samples

    def _get_model_matrix(self):
        """Return the dense model matrix M, building it from the fixed parameters on first use."""
        if getattr(self, '_M', None) is None:
            # Convert PyTorch tensors to numpy arrays for helper function
            np_params = {
                key: tensor.cpu().detach().numpy() if torch.is_tensor(tensor) else tensor
                for key, tensor in self.fixed_params.items()
            }
            self._M = helper.make_M_2d_diffusion_advection_forcing(
                nr=self.nr,
                nc=self.nc,
                dt=self.dt,
                KX=np_params['KX'],
                KY=np_params['KY'],
                DX_C=np_params['DX_C'],
                DY_C=np_params['DY_C'],
                DX_G=np_params['DX_G'],
                DY_G=np_params['DY_G'],
                VX=np_params['VX'],
                VY=np_params['VY'],
                RAC=np_params['RAC'],
                F=self.F,
                cyclic_east_west=True,
                cyclic_north_south=False,
                M_is_sparse=False
            )
        return self._M

    def _compute_next_state(self, x_t):
        """
        Compute x(t+1) = M @ x(t) + F * f using fixed parameters.

        Args:
            x_t: Current state as a 1-D tensor

        Returns:
            torch.Tensor: Next state as a float32 tensor
        """
        M = self._get_model_matrix()

        forcing = self.fixed_params['f']
        f_np = forcing.cpu().detach().numpy() if torch.is_tensor(forcing) else forcing
        x_t_np = x_t.cpu().detach().numpy()

        result_np = M @ x_t_np + self.F * f_np
        return torch.from_numpy(result_np).float()

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]



# Class for Model

In [None]:
class StateModelNet(torch.nn.Module):
    """Two-layer fully connected network mapping a flattened state to a state of the same size."""

    def __init__(self, nr: int, nc: int):
        """
        Args:
            nr: Number of rows in the grid
            nc: Number of columns in the grid
        """
        super().__init__()
        n_cells = nr * nc
        self.nr, self.nc = nr, nc

        # The network only sees the state, so input and output widths coincide
        self.state_size = n_cells
        self.input_size = n_cells
        self.output_size = n_cells

        print(f"Input size: {self.input_size}")

        # Single hidden layer followed by a linear read-out
        hidden_units = 25
        self.fc1 = torch.nn.Linear(self.input_size, hidden_units)
        self.fc2 = torch.nn.Linear(hidden_units, self.output_size)
        self.activation = torch.nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Flatten each batch element, apply fc1 -> ReLU -> fc2 and return the result."""
        flat = x.view(x.size(0), -1)
        hidden = self.activation(self.fc1(flat))
        return self.fc2(hidden)




# Code to generate/load dataset

In [None]:
def generate_and_save_dataset(
    filename: str,
    num_train_runs: int,
    num_test_runs: int,
    timesteps_per_run: int,
    nr: int,
    nc: int,
    dt: float,
    F: float
) -> None:
    """
    Generate train and test datasets that share one parameter set and write them to an HDF5 file.

    File layout: 'model_matrix' dataset and metadata attributes at the top
    level, a 'shared_params' group with one dataset per parameter, and
    'train' / 'test' groups each holding 'samples/sample_<i>/{x_t, x_t_plus_1}'.

    Args:
        filename: Path of the HDF5 file to create (opened in 'w' mode)
        num_train_runs: Number of runs in the training dataset
        num_test_runs: Number of runs in the testing dataset
        timesteps_per_run: Number of states simulated per run
        nr: Number of rows in grid
        nc: Number of columns in grid
        dt: Time step size
        F: Forcing parameter
    """
    # A throwaway dataset is built only to draw the parameter set that both
    # the training and testing datasets will share.
    shared_params = StateModelDataset(1, 1, nr, nc, dt, F).fixed_params

    # Build the model matrix from the shared parameters for storage in the file
    matrix_param_keys = ('KX', 'KY', 'DX_C', 'DY_C', 'DX_G', 'DY_G', 'VX', 'VY', 'RAC')
    M = helper.make_M_2d_diffusion_advection_forcing(
        nr=nr,
        nc=nc,
        dt=dt,
        **{key: shared_params[key].numpy() for key in matrix_param_keys},
        F=F,
        cyclic_east_west=True,
        cyclic_north_south=False,
        M_is_sparse=False
    )

    # Both datasets reuse the same parameters; only the initial states differ
    train_dataset = StateModelDataset(num_train_runs, timesteps_per_run, nr, nc, dt, F, fixed_params=shared_params)
    test_dataset = StateModelDataset(num_test_runs, timesteps_per_run, nr, nc, dt, F, fixed_params=shared_params)

    metadata = (
        ('nr', nr),
        ('nc', nc),
        ('dt', dt),
        ('F', F),
        ('timesteps_per_run', timesteps_per_run),
    )

    with h5py.File(filename, 'w') as f:
        # Data shared between train and test lives at the top level
        f.create_dataset('model_matrix', data=M)
        for attr_name, attr_value in metadata:
            f.attrs[attr_name] = attr_value

        params_group = f.create_group('shared_params')
        for key, value in shared_params.items():
            params_group.create_dataset(key, data=value.numpy())

        # Create both split groups first, then fill them one after the other
        splits = [
            (f.create_group('train'), train_dataset, "Saving training data"),
            (f.create_group('test'), test_dataset, "Saving testing data"),
        ]
        for group, dataset, desc in splits:
            samples_group = group.create_group('samples')
            for i, (x_t, x_t_plus_1) in enumerate(tqdm(dataset, desc=desc)):
                sample_group = samples_group.create_group(f'sample_{i}')
                sample_group.create_dataset('x_t', data=x_t.numpy())
                sample_group.create_dataset('x_t_plus_1', data=x_t_plus_1.numpy())

def load_dataset(filename: str) -> Tuple[Dict, Dict]:
    """
    Load the train/test samples and the shared data written by generate_and_save_dataset.

    Args:
        filename: Path of the HDF5 file to read

    Returns:
        Two dictionaries. The first holds the sample list under 'train', the
        second under 'test'; both also carry 'metadata', 'shared_params' and
        'model_matrix' (the same objects in both dictionaries).
    """
    def _sample_order(name):
        # HDF5 lists members by name, which puts 'sample_10' before 'sample_2'.
        # Sorting on the numeric suffix restores the order the samples were saved in;
        # names without a numeric suffix are kept, after the numbered ones.
        suffix = name.rpartition('_')[2]
        return (0, int(suffix), name) if suffix.isdecimal() else (1, 0, name)

    def _read_split(group):
        samples_group = group['samples']
        ordered_names = sorted(samples_group.keys(), key=_sample_order)
        pairs = []
        for sample_name in tqdm(ordered_names, desc=f"Loading {group.name} data"):
            sample = samples_group[sample_name]
            x_t = torch.from_numpy(sample['x_t'][:]).float()
            x_t_plus_1 = torch.from_numpy(sample['x_t_plus_1'][:]).float()
            pairs.append((x_t, x_t_plus_1))
        return pairs

    with h5py.File(filename, 'r') as f:
        # Load top-level shared data
        metadata = {
            'nr': f.attrs['nr'],
            'nc': f.attrs['nc'],
            'dt': f.attrs['dt'],
            'F': f.attrs['F'],
            'timesteps_per_run': f.attrs['timesteps_per_run']
        }

        # Load the model matrix M from top level
        M = f['model_matrix'][:]

        # Load shared parameters as float32 tensors
        params_group = f['shared_params']
        shared_params = {
            key: torch.from_numpy(params_group[key][:]).float()
            for key in params_group.keys()
        }

        train_data = _read_split(f['train'])
        test_data = _read_split(f['test'])

    # Create common shared data dictionary
    shared_data = {
        'metadata': metadata,
        'shared_params': shared_params,
        'model_matrix': M
    }

    # Return dictionaries that combine shared data with specific train/test data
    return {
        'train': train_data,
        **shared_data
    }, {
        'test': test_data,
        **shared_data
    }


class SavedDataset(Dataset):
    """Map-style dataset that exposes an already-loaded, indexable collection of samples."""

    def __init__(self, data):
        # Keep a reference to the caller's container; nothing is copied
        self.data = data

    def __getitem__(self, idx):
        """Return the sample stored at position idx."""
        return self.data[idx]

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.data)
    
def create_dataloaders(train_data, test_data, batch_size=32, shuffle_train=True,
                       num_workers=0, pin_memory=None):
    """
    Create DataLoader objects for training and testing datasets.
    Now handles data without parameters.

    Args:
        train_data: Training dataset dictionary containing 'train' data
        test_data: Testing dataset dictionary containing 'test' data
        batch_size: Batch size for DataLoaders
        shuffle_train: Whether to shuffle training data
        num_workers: Number of DataLoader worker processes (0 loads in the main process)
        pin_memory: Whether to use pinned host memory. When None (default),
            pinning is enabled only if CUDA is available, which avoids
            requesting it on CPU-only machines.

    Returns:
        train_loader, test_loader: DataLoader objects
    """
    if pin_memory is None:
        pin_memory = torch.cuda.is_available()

    # Settings common to both loaders; only shuffling differs
    common_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    train_loader = DataLoader(
        SavedDataset(train_data['train']),
        shuffle=shuffle_train,
        **common_kwargs
    )

    # The test set is never shuffled, so evaluation order is reproducible
    test_loader = DataLoader(
        SavedDataset(test_data['test']),
        shuffle=False,
        **common_kwargs
    )

    return train_loader, test_loader

In [None]:
# Generate and save datasets
# NOTE: everything below is wrapped in a triple-quoted string, so none of it
# executes; remove the surrounding quotes to regenerate and reload the data.
# NOTE(review): the generation call writes 'state_transition_data_<train>_<test>.h5'
# while the example load reads 'state_model_data_one-world.h5' — confirm which
# file is intended before re-enabling this cell.
"""
num_train_runs = 2000
num_test_runs = 200
generate_and_save_dataset(
    filename=f'state_transition_data_{num_train_runs}_{num_test_runs}.h5',
    num_train_runs=num_train_runs,      # Number of different runs for training
    num_test_runs=num_test_runs,        # Number of different runs for testing
    timesteps_per_run=10,      # Number of timesteps per run
    nr=10,
    nc=10,
    dt=0.01,
    F=0.01
)

# Example usage


# Load dataset and create dataloaders
train_data, test_data = load_dataset('state_model_data_one-world.h5')
train_loader, test_loader = create_dataloaders(
    train_data, 
    test_data, 
    batch_size=32
)
"""

# Training and Evaluating the Model

In [None]:
def _mean_loss_over_loader(model, dataloader, criterion, device):
    """Return the sample-weighted mean of `criterion` over `dataloader`, without gradient tracking."""
    loss_sum = 0.0
    sample_count = 0
    with torch.no_grad():
        for x_t, x_t_plus_1 in dataloader:
            x_t = x_t.to(device)
            x_t_plus_1 = x_t_plus_1.to(device)
            n_in_batch = x_t.size(0)
            # The criterion returns a per-batch mean; weight it by the batch
            # size so that a smaller final batch is not over-counted.
            loss_sum += criterion(model(x_t), x_t_plus_1).item() * n_in_batch
            sample_count += n_in_batch
    return loss_sum / sample_count if sample_count else float('nan')


def train_and_evaluate_model(
    model: torch.nn.Module,
    train_dataloader: DataLoader,
    test_dataloader: DataLoader,
    num_epochs: int,
    learning_rate: float,
    device: str = 'cpu',
    run_number: int = None
):
    """
    Train `model` with Adam on an MSE loss and record train/test loss per epoch.

    Args:
        model: Network to train; moved to `device`
        train_dataloader: Yields (x_t, x_t_plus_1) batches used for optimization
        test_dataloader: Yields (x_t, x_t_plus_1) batches used for evaluation only
        num_epochs: Number of passes over the training data
        learning_rate: Learning rate passed to Adam
        device: Device on which the model and batches are placed
        run_number: Optional run index, used only in the progress messages

    Returns:
        tuple: (model, train_losses, test_losses). Each loss list has
        num_epochs + 1 entries; entry 0 is the loss before any training.
        Losses are means over samples (NaN if a loader yields no samples).
    """
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = torch.nn.MSELoss()

    run_str = f" (Run {run_number})" if run_number is not None else ""
    print(f"Training started{run_str}...")

    # Calculate initial losses before any training
    model.eval()
    initial_train_loss = _mean_loss_over_loader(model, train_dataloader, criterion, device)
    initial_test_loss = _mean_loss_over_loader(model, test_dataloader, criterion, device)
    train_losses = [initial_train_loss]
    test_losses = [initial_test_loss]

    print(f"Initial training loss: {initial_train_loss:.6f}")
    print(f"Initial test loss: {initial_test_loss:.6f}")

    for epoch in range(num_epochs):
        # Training phase: the loss is recorded per batch before each optimizer step
        model.train()
        train_loss_sum = 0.0
        train_sample_count = 0
        for x_t, x_t_plus_1 in train_dataloader:
            x_t = x_t.to(device)
            x_t_plus_1 = x_t_plus_1.to(device)

            optimizer.zero_grad()
            predicted = model(x_t)
            loss = criterion(predicted, x_t_plus_1)
            loss.backward()
            optimizer.step()

            n_in_batch = x_t.size(0)
            train_loss_sum += loss.item() * n_in_batch
            train_sample_count += n_in_batch

        avg_train_loss = train_loss_sum / train_sample_count if train_sample_count else float('nan')

        # Evaluation phase
        model.eval()
        avg_test_loss = _mean_loss_over_loader(model, test_dataloader, criterion, device)

        train_losses.append(avg_train_loss)
        test_losses.append(avg_test_loss)

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}")
            print(f"Training Loss: {avg_train_loss:.6f}")
            print(f"Test Loss: {avg_test_loss:.6f}")

    print(f"Training complete{run_str}!")
    return model, train_losses, test_losses

def simulate_state_evolution(model, M, x0, num_timesteps, F, f, device='cpu'):
    """
    Simulates state evolution using both the trained model and true model M.

    The neural network is applied repeatedly to its own output; the true
    trajectory follows x(t+1) = M @ x(t) + F * f. Both start from x0.

    Args:
        model: Trained neural network
        M: True model matrix
        x0: Initial state
        num_timesteps: Number of timesteps to simulate
        F: Forcing parameter
        f: Forcing field
        device: Device to run model on

    Returns:
        tuple: (neural_net_states, true_model_states), each a list of
        num_timesteps + 1 states whose first entry is x0 itself
    """
    model.eval()

    # Lists to store states over time
    neural_net_states = [x0]
    true_model_states = [x0]

    # Simulate using neural network (batch dimension added and removed per step)
    with torch.no_grad():
        current_state = torch.tensor(x0, dtype=torch.float32).to(device)
        for _ in range(num_timesteps):
            current_state = model(current_state.unsqueeze(0)).squeeze(0)
            neural_net_states.append(current_state.cpu().numpy())

    # Simulate using true model
    current_state = x0
    for _ in range(num_timesteps):
        current_state = M @ current_state + F * f
        true_model_states.append(current_state)

    return neural_net_states, true_model_states

def train_multiple_models(
    nr: int, 
    nc: int, 
    train_dataloader: DataLoader,
    test_dataloader: DataLoader,
    train_data: dict,
    num_runs: int = 10,
    num_epochs: int = 50,
    learning_rate: float = 0.001,
    device: str = 'cpu',
    num_timesteps_sim: int = 10
):
    """
    Train several freshly initialized networks and compare each with the true model.

    For every run this trains a new StateModelNet, plots its loss curves,
    rolls both the network and the true model forward from the first test
    sample, and plots the states at five evenly spaced timesteps.

    Args:
        nr: Number of rows in the grid
        nc: Number of columns in the grid
        train_dataloader: DataLoader for training data
        test_dataloader: DataLoader for test data
        train_data: Dictionary providing 'model_matrix', 'metadata' and 'shared_params'
        num_runs: Number of models to train
        num_epochs: Number of epochs per training run
        learning_rate: Learning rate for optimization
        device: Device to run training on
        num_timesteps_sim: Number of timesteps to simulate

    Returns:
        tuple: (all_models, all_train_losses, all_test_losses), one entry per run
    """
    # Ground-truth dynamics come from the loaded dataset dictionary
    M = train_data['model_matrix']
    metadata = train_data['metadata']
    f = train_data['shared_params']['f'].numpy()

    # Timesteps at 0, 1/4, 1/2, 3/4 and the end of the simulated horizon
    display_timesteps = [(quarter * num_timesteps_sim) // 4 for quarter in range(5)]

    all_models, all_train_losses, all_test_losses = [], [], []

    for i in range(num_runs):
        print(f"\nStarting model {i+1}/{num_runs}")

        trained_model, train_losses, test_losses = train_and_evaluate_model(
            model=StateModelNet(nr, nc),
            train_dataloader=train_dataloader,
            test_dataloader=test_dataloader,
            num_epochs=num_epochs,
            learning_rate=learning_rate,
            device=device,
            run_number=i+1
        )

        # Loss curves for this run; the x axis starts at 0 for the pre-training loss
        epoch_axis = range(0, num_epochs + 1)
        plt.figure(figsize=(10, 6))
        curves = (
            (train_losses, 'b-', 'Training Loss'),
            (test_losses, 'r-', 'Test Loss'),
        )
        for curve, line_format, curve_label in curves:
            plt.plot(epoch_axis, curve, line_format, label=curve_label)
        plt.title(f'Training and Testing Loss Over Time - Run {i+1}')
        plt.xlabel('Epoch')
        plt.ylabel('Loss (MSE)')
        plt.legend()
        plt.grid(True)
        plt.show()

        # Start the roll-out from the first input state of the first test batch
        x0 = next(iter(test_dataloader))[0][0].numpy()

        neural_net_states, true_model_states = simulate_state_evolution(
            trained_model, M, x0, num_timesteps=num_timesteps_sim, F=metadata['F'], f=f, device=device
        )

        # Keep only the timesteps selected for display
        trajectories_to_show = [
            [trajectory[t] for t in display_timesteps]
            for trajectory in (neural_net_states, true_model_states)
        ]

        helper.plot_multi_heatmap_time_evolution(
            saved_timesteps=display_timesteps,
            many_states_over_time=trajectories_to_show,
            nr=nr,
            nc=nc,
            titles=["Neural Network", "True Model"],
            big_title=f"State Evolution Comparison - Model {i+1}\nTimesteps {display_timesteps}",
            vmin=None,
            vmax=None
        )

        all_models.append(trained_model)
        all_train_losses.append(train_losses)
        all_test_losses.append(test_losses)

    return all_models, all_train_losses, all_test_losses

In [None]:
# Load the pre-generated dataset and wrap both splits in DataLoaders
train_data, test_data = load_dataset('state_transition_data_20000_2000.h5')
train_loader = DataLoader(SavedDataset(train_data['train']), batch_size=1000, shuffle=True)
test_loader = DataLoader(SavedDataset(test_data['test']), batch_size=1000, shuffle=False)

M = train_data['model_matrix']
metadata = train_data['metadata']
x0 = torch.randn(metadata['nr'] * metadata['nc'])

# Train multiple models
device = 'cuda' if torch.cuda.is_available() else 'cpu'
models, train_losses, test_losses = train_multiple_models(
    # Grid size is read from the file's metadata so the network always
    # matches the dimensions of the loaded states.
    nr=int(metadata['nr']),
    nc=int(metadata['nc']),
    train_dataloader=train_loader,
    test_dataloader=test_loader,
    train_data=train_data,  # Pass in the train_data dictionary
    num_runs=1,
    num_epochs=50,
    device=device,
    num_timesteps_sim=100
)

- Simulate Using Model (DONE)
- Include current parameters: describe how they work, how many weights we have vs parameters (DONE)
- Try 100 hidden units (DONE)
- Regularization
- Try purely linear network
- Make ReLU clearer
- How much smaller can we make it? - Vince

Previous problems: 
* Accidentally created different model for testing and training
* Parameters were too big, caused nasty divergence (heat temps exploding way too high)