### Quick and Dirty: Learn to Predict 10 Steps Ahead from the First 200

- Create Loaders
- Create Model
- Train
- Evaluate

#### Create Loaders / Get Data

In [None]:
import os
import random
import torch
import h5py

# Root directory of the raw HDF5 data files. Set the DQFNO_DATA_DIR environment
# variable to point elsewhere instead of editing this user-specific default.
path = os.environ.get("DQFNO_DATA_DIR", "/Users/anthonypoole/data/local_dqfno/raw_256x256")

def get_list_files(dir, rand=True):
    """Return the names of the visible entries in ``dir``.

    Hidden entries (names starting with ``"."``, e.g. macOS ``.DS_Store``) are
    skipped: callers in this notebook open every returned name with h5py, and
    such files would make that fail.

    Args:
        dir: Directory to list. (The name shadows the ``dir`` builtin but is
            kept so existing keyword callers keep working.)
        rand: If True (default), return the names shuffled with the global
            ``random`` RNG; otherwise return them sorted so the order is
            reproducible across runs and filesystems.

    Returns:
        A new list of entry names (not full paths).
    """
    # os.listdir returns entries in arbitrary, filesystem-dependent order, so
    # sort first; this also makes shuffling reproducible under random.seed().
    files = sorted(name for name in os.listdir(dir) if not name.startswith("."))
    if rand:
        random.shuffle(files)
    return files

def get_pf_steps(curr_epoch, total_epoch, max_pf_step, *, initial_phase=3, final_epochs=2):
    """Return the number of push-forward steps to train with at ``curr_epoch``.

    Curriculum (the defaults reproduce the original hard-coded schedule):

    * ``curr_epoch < initial_phase``: always 1 (one-step prediction only).
    * ``curr_epoch >= total_epoch - final_epochs``: a uniform random integer in
      ``[1, max_pf_step]`` drawn from the global ``random`` RNG, to simulate
      multi-step prediction.
    * In between: a linear ramp
      ``1 + int((max_pf_step - 1) * (curr_epoch - initial_phase) / (final_start - initial_phase))``.
      Because ``curr_epoch < final_start`` in this phase, the ramp stays below
      ``max_pf_step`` (e.g. 1, 1, 2, 3, 4 for epochs 3..7 when
      ``total_epoch=10, max_pf_step=5``).

    The initial phase takes precedence: for short runs where it overlaps the
    final phase, the overlapping epochs return 1.

    Args:
        curr_epoch: Zero-based index of the current epoch.
        total_epoch: Total number of training epochs.
        max_pf_step: Largest number of push-forward steps to use.
        initial_phase: Number of leading epochs restricted to one step.
        final_epochs: Number of trailing epochs that use a random step count.

    Returns:
        The number of push-forward steps (an ``int``).
    """
    final_start = total_epoch - final_epochs
    if curr_epoch < initial_phase:
        return 1
    if curr_epoch >= final_start:
        return random.randint(1, max_pf_step)
    # Here initial_phase <= curr_epoch < final_start, so the denominator is > 0.
    # Keep multiply-then-divide order so exact integer ratios are not lost to
    # float rounding (e.g. 49 * (1/49) < 1.0).
    return 1 + int((max_pf_step - 1) * (curr_epoch - initial_phase) / (final_start - initial_phase))

def get_chunk(file_path, chunk_size, chunk_index, *, allow_partial=True):
    """Load one contiguous chunk of rows from an HDF5 file as tensors.

    Reads rows ``[chunk_index * chunk_size, (chunk_index + 1) * chunk_size)``
    along the first axis of the ``density``, ``omega``, ``phi`` and ``gamma_n``
    datasets.

    Args:
        file_path: Path to the HDF5 file.
        chunk_size: Number of rows per chunk; must be positive.
        chunk_index: Zero-based index of the chunk to read; must be >= 0.
        allow_partial: If True (default, original behaviour), a final chunk
            that runs past the end of ``gamma_n`` is returned truncated. If
            False, such a chunk raises ``IndexError`` instead, so callers never
            receive fewer than ``chunk_size`` rows.

    Returns:
        ``(state, gn)``. ``state`` stacks density/omega/phi and is reshaped to
        ``(1, 1, T, 3, H, W)``, where ``T`` is the number of rows read. ``gn``
        is ``gamma_n`` with a new leading axis of size 1. (Assumes the three
        field datasets are 3-D ``(rows, H, W)``, which the 4-D permute below
        requires, and that ``gamma_n`` is 1-D; TODO confirm.)

    Raises:
        ValueError: If ``chunk_size`` is not positive.
        IndexError: If ``chunk_index`` is negative, the chunk starts at or past
            the end of ``gamma_n``, or (with ``allow_partial=False``) it ends
            past it.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}.")
    if chunk_index < 0:
        # A negative start would silently slice relative to the end of the file.
        raise IndexError("Chunk index is out of range.")

    start = chunk_index * chunk_size
    end = start + chunk_size

    with h5py.File(file_path, 'r') as f:
        n_rows = f['gamma_n'].shape[0]
        if start >= n_rows or (not allow_partial and end > n_rows):
            raise IndexError("Chunk index is out of range.")

        n = torch.from_numpy(f['density'][start:end])
        e = torch.from_numpy(f['omega'][start:end])
        p = torch.from_numpy(f['phi'][start:end])
        gn = torch.from_numpy(f['gamma_n'][start:end]).unsqueeze(0)

    # (3, T, H, W) -> (T, 3, H, W) -> (1, 1, T, 3, H, W)
    state = torch.stack((n, e, p)).permute(1, 0, 2, 3).unsqueeze(0).unsqueeze(0)
    return (state, gn)

#### Create Model

In [None]:
from src.models.dqfno import DQFNO

# Experiment settings: data location, DQFNO architecture, optimiser and
# training schedule, all read by the model and training cells.
config = dict(
    data_dir=path,
    modes=[[16, 16], [32, 32], [8, 8]],
    in_channels=1,
    out_channels=1,
    hidden_channels=8,  # Change back to 128 if needed
    n_layers=4,
    dx=1.0,  # Assign a proper value
    derived_type='direct',
    device='cpu',
    lr=0.001,
    weight_decay=0.01,
    losses=['lp', 'h1', 'derived'],
    loss_weights=[0.4, 0.4, 0.2],
    n_epochs=10,
    chunk_size=200,
    max_pf_step=5,  # Maximum push-forward steps
)

# Build the model from the architecture-related subset of the config.
model = DQFNO(**{
    key: config[key]
    for key in ('modes', 'in_channels', 'out_channels', 'hidden_channels',
                'n_layers', 'dx', 'derived_type')
})

model.to(config['device'])

DQFNO(
  (positional_embedding): GridEmbedding2D()
  (lifting): ChannelMLP(
    (fcs): ModuleList(
      (0): Conv1d(3, 3, kernel_size=(1,), stride=(1,))
      (1): Conv1d(3, 8, kernel_size=(1,), stride=(1,))
    )
  )
  (fno_blocks): FNOBlocks(
    (spectral_convs): ModuleList(
      (0-3): 4 x SpectralConv(
        (weight): ParameterList(
            (0): Parameter containing: [torch.complex64 of size 2x8x8x16x9]
            (1): Parameter containing: [torch.complex64 of size 2x8x8x32x17]
            (2): Parameter containing: [torch.complex64 of size 2x8x8x8x5]
        )
      )
    )
    (conv3ds): ModuleList(
      (0-3): 4 x Conv3d(8, 8, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    )
  )
  (projection): ChannelMLP(
    (fcs): ModuleList(
      (0): Conv1d(8, 16, kernel_size=(1,), stride=(1,))
      (1): Conv1d(16, 1, kernel_size=(1,), stride=(1,))
    )
  )
  (derived_module): DerivedMLP()
)

#### Train

In [None]:
import torch
from scripts.utils import create_run_directory, initialize_model, get_loss_object, plot_and_save_loss
from src.losses.custom_losses import MultiTaskLoss
from src.losses.data_losses import LpLoss, H1Loss
from src.data.data_utils import get_data_loader, get_test_loader, PushForwardDataSet, push_forward

device = config['device']

optimizer = torch.optim.AdamW(model.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])


def selector_state(y_pred, y):
    """Pick the first element (state) of prediction and target for a loss."""
    return y_pred[0], y[0]


def selector_derived(y_pred, y):
    """Pick the second element of prediction and target for a loss."""
    return y_pred[1], y[1]


def _to_device(batch, device):
    """Return a tuple with every tensor in ``batch`` moved to ``device``."""
    return tuple(t.to(device) for t in batch)


# Each loss must pair with exactly one weight; zip() would silently truncate.
if len(config['losses']) != len(config['loss_weights']):
    raise ValueError("config['losses'] and config['loss_weights'] must have the same length.")

losses = []
selectors = []
for loss_name in config['losses']:
    if loss_name == 'lp':
        losses.append(LpLoss(d=4, p=2, reduction='mean'))
        selectors.append(selector_state)
    elif loss_name == 'h1':
        losses.append(H1Loss(d=2))
        selectors.append(selector_state)
    elif loss_name == 'derived':
        losses.append(torch.nn.MSELoss())
        selectors.append(selector_derived)
    else:
        # Silently skipping would misalign losses with config['loss_weights'].
        raise ValueError(f"Unknown loss {loss_name!r} in config['losses'].")

loss_obj = MultiTaskLoss(
    loss_functions=losses,
    scales=config['loss_weights'],
    multi_output=True,
    input_selectors=selectors
)

for epoch in range(config['n_epochs']):
    running_loss = 0.0
    n_trained = 0
    # Loop through all files in the data directory
    for file in get_list_files(config['data_dir']):
        pf_steps = get_pf_steps(epoch, config['n_epochs'], config['max_pf_step'])
        file_path = os.path.join(config['data_dir'], file)

        # Load input (chunk 0) and target (chunk `pf_steps`, one chunk per model
        # application) before any forward pass, so short files are skipped cheaply.
        try:
            input_data = get_chunk(file_path, config['chunk_size'], chunk_index=0)
            target_data = get_chunk(file_path, config['chunk_size'], chunk_index=pf_steps)
        except IndexError:
            print(f"Skipping {file}: too short for {pf_steps} push-forward step(s).")
            continue
        if target_data[0].shape != input_data[0].shape:
            # Truncated final chunk: its length cannot match the prediction.
            print(f"Skipping {file}: target chunk {pf_steps} is truncated.")
            continue

        input_data = _to_device(input_data, device)
        target_data = _to_device(target_data, device)

        optimizer.zero_grad()

        # Iteratively push the model forward, feeding each output back in.
        prediction = input_data
        for _ in range(pf_steps):
            prediction = model(prediction)

        loss = loss_obj(prediction, target_data)
        loss.backward()
        optimizer.step()
        running_loss += float(loss)
        n_trained += 1

    # Report the mean per-file loss so epochs are comparable when files are skipped.
    mean_loss = running_loss / n_trained if n_trained else float('nan')
    print(f"Epoch {epoch+1}/{config['n_epochs']} Loss: {mean_loss}")


torch.Size([1, 1, 200, 3, 256, 256]) torch.Size([1, 200])


: 