# Sundial Foundation Model within CTF4Science
This notebook runs the Sundial foundation model within the CTF4Science framework, following the zero-shot quickstart guide provided in the official [GitHub repository](https://github.com/thuml/Sundial/blob/main/examples/quickstart_zero_shot_generation.ipynb).

In [1]:
import torch
from transformers import AutoModelForCausalLM
from ctf4science.data_module import load_dataset, parse_pair_ids, get_applicable_plots, get_prediction_timesteps, get_training_timesteps, load_validation_dataset, get_validation_prediction_timesteps
from ctf4science.eval_module import evaluate, save_results

import pickle
import os
import time
import numpy as np

model_name = 'sundial'

# Choose the compute backend: Apple MPS if present, otherwise CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Load the pretrained Sundial checkpoint from the Hugging Face Hub.
# trust_remote_code=True lets transformers run the model code hosted in the checkpoint repo.
model = AutoModelForCausalLM.from_pretrained('thuml/sundial-base-128m', trust_remote_code=True).to(device)

  from .autonotebook import tqdm as notebook_tqdm


## Lorenz Dataset

In [2]:
# Run configuration for the Lorenz experiments.
# Set validation=True to load data through load_validation_dataset instead of load_dataset.
validation = False
# Number of sampled forecast trajectories requested from model.generate per series.
num_samples = 40

dataset_name = 'ODE_Lorenz'

# Output folder named after the dataset; created now so it exists before results are written.
path_fig = dataset_name + "/"
os.makedirs(path_fig, exist_ok=True)

Run the forecast and evaluation for all pair IDs (1–9).

In [3]:
# Run Sundial zero-shot on every Lorenz pair ID, evaluate the mean forecast and
# pickle the per-pair results into `path_fig`.
#
# Task handling per pair ID (as implemented below):
#   * 2, 4  -> reconstruction: condition on the first `recon_ctx` training steps,
#              generate the rest, and prepend the context to the output.
#   * 8, 9  -> burn-in: condition on `init_data`, generate the remaining steps,
#              and prepend the burn-in to the output.
#   * other -> standard forecast from the full first training trajectory.
RECON_PAIRS = (2, 4)
BURNIN_PAIRS = (8, 9)
SEED = 0  # re-applied per pair so each pair's sampled forecasts are reproducible

# Guard against running this cell without the config cell's makedirs.
os.makedirs(path_fig, exist_ok=True)
execution_time = time.time()

for pair_id in range(1, 10):

    print(f"Processing pair_id: {pair_id}")

    if validation:
        # NOTE(review): `val_data` is loaded but never used, and `evaluate` below is
        # called identically in both modes -- confirm it scores against the intended
        # split when validation=True.
        train_data, val_data, init_data = load_validation_dataset(dataset_name, pair_id=pair_id)
        forecast_length = get_validation_prediction_timesteps(dataset_name, pair_id).shape[0]
    else:
        train_data, init_data = load_dataset(dataset_name, pair_id=pair_id)
        forecast_length = get_prediction_timesteps(dataset_name, pair_id).shape[0]

    if pair_id in RECON_PAIRS:
        recon_ctx = 200
        # Reconstruction
        print(f"> Reconstruction task, using {recon_ctx} context length")
        train_mat = train_data[0][0:recon_ctx, :]
        # The context window is prepended to the output later, so only generate the remainder.
        forecast_length = forecast_length - recon_ctx
    elif pair_id in BURNIN_PAIRS:
        # Burn-in - Parametric Generalisation
        print(f"> Burn-in matrix of size {init_data.shape[0]}, using {forecast_length} forecast length")
        train_mat = init_data
        forecast_length = forecast_length - init_data.shape[0]
    else:
        # Standard prediction
        print(f"> Standard prediction task, using {forecast_length} forecast length")
        train_mat = train_data[0]

    # train_mat rows are time steps; transpose so each row handed to the model is one series.
    _input_data = torch.tensor(train_mat, dtype=torch.float32).to(device).T

    # Re-seed so this pair's samples do not depend on which pairs ran before it.
    torch.manual_seed(SEED)
    # pred_data: (n_series, num_samples, forecast_length)
    pred_data = model.generate(_input_data, max_new_tokens=forecast_length, num_samples=num_samples)

    # Point forecast = mean over sampled trajectories, transposed back to (time, n_series).
    mean_forecast = pred_data.cpu().numpy().mean(axis=1).T
    if pair_id in RECON_PAIRS + BURNIN_PAIRS:
        # Prepend the conditioning window so the matrix spans the full requested horizon.
        pred_mat = np.concatenate([train_mat, mean_forecast], axis=0)
    else:
        pred_mat = mean_forecast

    # Evaluate the performance (mean prediction over samples)
    results = evaluate(dataset_name, pair_id, pred_mat)

    print(f"> Prediction matrix shape: {pred_mat.shape}")
    print(f"> Results: {results}")

    # Save results; the context manager closes the file even if pickling fails.
    with open(os.path.join(path_fig, f"pair_{pair_id}_results.pkl"), "wb") as f:
        pickle.dump({
            'model_name': model_name,
            'dataset_name': dataset_name,
            'pair_id': pair_id,
            'pred_mat': pred_mat,
            'pred_shape': pred_data.shape,
            'results': results
        }, f)

    print(' ')

execution_time = time.time() - execution_time
# Convert to HH:MM:SS format
hours, remainder = divmod(int(execution_time), 3600)
minutes, seconds = divmod(remainder, 60)

print(f"> Total execution time: {hours:02d}:{minutes:02d}:{seconds:02d}")

Processing pair_id: 1
> Standard prediction task, using 1000 forecast length


This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length (10000). Depending on the model, you may observe exceptions, performance degradation, or nothing at all.


> Prediction matrix shape: (1000, 3)
> Results: {'short_time': 60.09410006061692, 'long_time': -25.73333333333334}
 
Processing pair_id: 2
> Reconstruction task, using 200 context length
> Prediction matrix shape: (10000, 3)
> Results: {'reconstruction': 45.60054650054103}
 
Processing pair_id: 3
> Standard prediction task, using 1000 forecast length
> Prediction matrix shape: (1000, 3)
> Results: {'long_time': -39.733333333333334}
 
Processing pair_id: 4
> Reconstruction task, using 200 context length
> Prediction matrix shape: (10000, 3)
> Results: {'reconstruction': 55.899638539899364}
 
Processing pair_id: 5
> Standard prediction task, using 1000 forecast length
> Prediction matrix shape: (1000, 3)
> Results: {'long_time': -41.866666666666674}
 
Processing pair_id: 6
> Standard prediction task, using 1000 forecast length
> Prediction matrix shape: (1000, 3)
> Results: {'short_time': 27.544332725946173, 'long_time': 46.4}
 
Processing pair_id: 7
> Standard prediction task, using 100

## Kuramoto-Sivashinsky Dataset

In [None]:
# Run configuration for the Kuramoto-Sivashinsky experiments.
# Set validation=True to load data through load_validation_dataset instead of load_dataset.
validation = False
# Number of sampled forecast trajectories requested from model.generate per series.
num_samples = 40

dataset_name = 'PDE_KS'

# Output folder named after the dataset; created now so it exists before results are written.
path_fig = dataset_name + "/"
os.makedirs(path_fig, exist_ok=True)

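Run the forecast and evaluation for all pair IDs (1–9), sending the spatial points to the model in batches so the accelerator's memory is not exceeded.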

In [None]:
from tqdm.auto import tqdm  # notebook-aware progress bar; falls back to a text bar

# Run Sundial zero-shot on every KS pair ID, evaluate the mean forecast and
# pickle the per-pair results into `path_fig`.
#
# Task handling per pair ID (as implemented below):
#   * 2, 4  -> reconstruction: condition on the first `recon_ctx` training steps,
#              generate the rest, and prepend the context to the output.
#   * 8, 9  -> burn-in: condition on `init_data`, generate the remaining steps,
#              and prepend the burn-in to the output.
#   * other -> standard forecast from the full first training trajectory.
RECON_PAIRS = (2, 4)
BURNIN_PAIRS = (8, 9)
SEED = 0  # re-applied per pair so each pair's sampled forecasts are reproducible

# Spatial points sent to model.generate per call. If the accelerator is too
# small, lower this so the input is processed in more, smaller batches.
batch_size = 400

# Guard against running this cell without the config cell's makedirs.
os.makedirs(path_fig, exist_ok=True)
execution_time = time.time()

for pair_id in range(1, 10):

    print(f"Processing pair_id: {pair_id}")

    if validation:
        # NOTE(review): `val_data` is loaded but never used, and `evaluate` below is
        # called identically in both modes -- confirm it scores against the intended
        # split when validation=True.
        train_data, val_data, init_data = load_validation_dataset(dataset_name, pair_id=pair_id)
        forecast_length = get_validation_prediction_timesteps(dataset_name, pair_id).shape[0]
    else:
        train_data, init_data = load_dataset(dataset_name, pair_id=pair_id)
        forecast_length = get_prediction_timesteps(dataset_name, pair_id).shape[0]

    if pair_id in RECON_PAIRS:
        recon_ctx = 1000
        # Reconstruction
        print(f"> Reconstruction task, using {recon_ctx} context length")
        train_mat = train_data[0][0:recon_ctx, :]
        # The context window is prepended to the output later, so only generate the remainder.
        forecast_length = forecast_length - recon_ctx
    elif pair_id in BURNIN_PAIRS:
        # Burn-in - Parametric Generalisation
        print(f"> Burn-in matrix of size {init_data.shape[0]}, using {forecast_length} forecast length")
        train_mat = init_data
        forecast_length = forecast_length - init_data.shape[0]
    else:
        # Standard prediction
        print(f"> Standard prediction task, using {forecast_length} forecast length")
        train_mat = train_data[0]

    # train_mat rows are time steps; each column is handed to the model as its own series.
    spatial_dim = train_mat.shape[-1]
    # (spatial_dim, num_samples, forecast_length), filled one batch of spatial points at a time.
    pred_data = np.zeros((spatial_dim, num_samples, forecast_length), dtype=np.float32)

    # Re-seed so this pair's samples do not depend on which pairs ran before it.
    torch.manual_seed(SEED)
    for i in tqdm(range(0, spatial_dim, batch_size), desc=f"Processing pair_id {pair_id} in batches"):
        # Transpose the (time, batch) slice so each row is one spatial point's series.
        _input_data = torch.tensor(train_mat[:, i : (i + batch_size)], dtype=torch.float32).to(device).T
        batch_samples = model.generate(_input_data, max_new_tokens=forecast_length, num_samples=num_samples)
        pred_data[i : (i + batch_size)] = batch_samples.cpu().numpy()

    # Point forecast = mean over sampled trajectories, transposed back to (time, spatial_dim).
    mean_forecast = pred_data.mean(axis=1).T
    if pair_id in RECON_PAIRS + BURNIN_PAIRS:
        # Prepend the conditioning window so the matrix spans the full requested horizon.
        pred_mat = np.concatenate([train_mat, mean_forecast], axis=0)
    else:
        pred_mat = mean_forecast

    # Evaluate the performance (mean prediction over samples)
    results = evaluate(dataset_name, pair_id, pred_mat)

    print(f"> Prediction matrix shape: {pred_mat.shape}")
    print(f"> Results: {results}")

    # Save results; the context manager closes the file even if pickling fails.
    with open(os.path.join(path_fig, f"pair_{pair_id}_results.pkl"), "wb") as f:
        pickle.dump({
            'model_name': model_name,
            'dataset_name': dataset_name,
            'pair_id': pair_id,
            'pred_mat': pred_mat,
            'pred_shape': pred_data.shape,
            'results': results
        }, f)

    print(' ')

execution_time = time.time() - execution_time

# Convert to HH:MM:SS format
hours, remainder = divmod(int(execution_time), 3600)
minutes, seconds = divmod(remainder, 60)

print(f"> Total execution time: {hours:02d}:{minutes:02d}:{seconds:02d}")

Processing pair_id: 1
> Standard prediction task, using 1000 forecast length


Processing pair_id 1 in batches: 100%|██████████| 3/3 [00:01<00:00,  2.07it/s]


> Prediction matrix shape: (1000, 1024)
> Results: {'short_time': 0.07413805063144485, 'long_time': 0.00030750364858889156}
 
Processing pair_id: 2
> Reconstruction task, using 1000 context length


Processing pair_id 2 in batches: 100%|██████████| 3/3 [00:07<00:00,  2.46s/it]


> Prediction matrix shape: (10000, 1024)
> Results: {'reconstruction': 4.60880708712782}
 
Processing pair_id: 3
> Standard prediction task, using 1000 forecast length


Processing pair_id 3 in batches: 100%|██████████| 3/3 [00:01<00:00,  2.57it/s]


> Prediction matrix shape: (1000, 1024)
> Results: {'long_time': 7.144790827862124e-05}
 
Processing pair_id: 4
> Reconstruction task, using 1000 context length


Processing pair_id 4 in batches: 100%|██████████| 3/3 [00:07<00:00,  2.40s/it]


> Prediction matrix shape: (10000, 1024)
> Results: {'reconstruction': 5.763553021359103}
 
Processing pair_id: 5
> Standard prediction task, using 1000 forecast length


Processing pair_id 5 in batches: 100%|██████████| 3/3 [00:01<00:00,  2.50it/s]


> Prediction matrix shape: (1000, 1024)
> Results: {'long_time': 0.0004692158975361238}
 
Processing pair_id: 6
> Standard prediction task, using 1000 forecast length


Processing pair_id 6 in batches: 100%|██████████| 3/3 [00:01<00:00,  2.37it/s]


> Prediction matrix shape: (1000, 1024)
> Results: {'short_time': 0.03821716311058765, 'long_time': 0.0001854924966071536}
 
Processing pair_id: 7
> Standard prediction task, using 1000 forecast length


Processing pair_id 7 in batches: 100%|██████████| 3/3 [00:01<00:00,  2.92it/s]


> Prediction matrix shape: (1000, 1024)
> Results: {'short_time': -0.06756134693686189, 'long_time': 0.000428945588781815}
 
Processing pair_id: 8
> Burn-in matrix of size 100, using 1000 forecast length


Processing pair_id 8 in batches: 100%|██████████| 3/3 [00:01<00:00,  2.87it/s]


> Prediction matrix shape: (1000, 1024)
> Results: {'short_time': -49.711427031931876}
 
Processing pair_id: 9
> Burn-in matrix of size 100, using 1000 forecast length


Processing pair_id 9 in batches: 100%|██████████| 3/3 [00:01<00:00,  2.93it/s]

> Prediction matrix shape: (1000, 1024)
> Results: {'short_time': -45.43654667699752}
 
> Total execution time: 00:00:31



