Environment setup:
```
conda create -n earth2studio python=3.12 -y 
conda activate earth2studio
pip install uv
export UV_CACHE_DIR="/projectnb/eb-general/wade/uv_cache"
uv pip install "earth2studio @ git+https://github.com/NVIDIA/earth2studio.git@0.10.0"
uv pip install "earth2studio[fcn]"
uv pip install numpy matplotlib pandas xarray cartopy cmocean tqdm
uv pip install "makani @ git+https://github.com/NVIDIA/modulus-makani.git@28f38e3e929ed1303476518552c64673bbd6f722"
uv pip install earth2studio[sfno]
```


# Running inference with SFNO checkpoints

In [None]:
import os
import subprocess
from dotenv import load_dotenv

from earth2studio.io import ZarrBackend

# from earth2studio.run import deterministic
# from earth2studio.models.px import SFNO
from deterministic_update import deterministic
from SFNO_update import SFNO

import earth2studio.data as data
from earth2studio.models.auto import Package
from utils import filename_to_year, datetime_range, open_hdf5 # these aren't used in this script currently


from datetime import datetime, timedelta
import json
import xarray as xr
from typing import List
import shutil
import sys
import gc
import numpy as np
import time

import torch

  import pynvml  # type: ignore[import]
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# print versions of makani and earth2studio
import earth2studio
print(f"earth2studio version: {earth2studio.__version__}")
import makani
print(f"makani version: {makani.__version__}")


earth2studio version: 0.10.0rc0
makani version: 0.2.0


In [None]:
# Check if CUDA (GPU support) is available
is_available = torch.cuda.is_available()
print(f"Is CUDA available? {is_available}")

if is_available:
    # Get the number of available GPUs
    gpu_count = torch.cuda.device_count()
    print(f"Number of GPUs available: {gpu_count}")

    # Get the ID of the current GPU
    current_gpu = torch.cuda.current_device()
    print(f"Current GPU ID: {current_gpu}")

    # Get the name of the current GPU
    gpu_name = torch.cuda.get_device_name(current_gpu)
    print(f"Current GPU Name: {gpu_name}")

    print(f"Memory (VRAM):      {torch.cuda.get_device_properties(current_gpu).total_memory / 1e9:.2f} GB")
else:
    print("CUDA is not available. Running on CPU.")

time_start = time.time()


Is CUDA available? False
CUDA is not available. Running on CPU.


In [19]:
############# CONFIG FOR INFERENCE RUN #############

start_datetime = "2019-09-03T00:00:00" #"2022-09-24T00:00:00" # "2021_09_20T00:00:00" # 
variables_to_select = ['msl', 'u10m', 'v10m'] #['tcwv'] #Only save selected variables - it slows down inference SIGNIFICANTLY to save all 74 variables
experiment_number = 0 # which experiment directory to output to
n_steps = 4  # number of 6hr steps to forecast
epochs_to_run = [20] # List or array of epochs/checkpoint numbers to run inference on

# boring =  False
ema = False

# Create the inference name based on the start datetime and number of steps
inference_name = datetime.fromisoformat(start_datetime).strftime("%Y_%m_%dT%H")+'_nsteps'+str(n_steps)
data_create_fp = "/projectnb/eb-general/wade/sfno/inference_runs/Ian/Initialize_data/Initialize_"+inference_name+".nc"

# Calculate the final datetime based from the start datetime and number of steps
final_datetime = (datetime.fromisoformat(start_datetime) + timedelta(hours = int(n_steps*6))).isoformat() 

# Directories
results_out_dir = f"/projectnb/eb-general/wade/sfno/inference_runs/sandbox/Experiment{str(experiment_number)}/{final_datetime[:10].replace('-', '_')}/"

#################################################

In [17]:
data_create_fp, final_datetime, results_out_dir

('/projectnb/eb-general/wade/sfno/inference_runs/Ian/Initialize_data/Initialize_2019_09_03T00_nsteps4.nc',
 '2019-09-04T00:00:00',
 '/projectnb/eb-general/wade/sfno/inference_runs/sandbox/Experiment0/2019_09_04/')

In [20]:
if os.path.exists(data_create_fp):
    print(f"Data already preprocessed: {data_create_fp}")
else:
    sys.exit(f"Data not found use Create_Initial_Data.ipynb to create: {data_create_fp}")

# make this xarray into a dataarray file for earth2studio
initial_data = data.DataArrayFile(data_create_fp)

fine_tuning_start_epoch = 71 # the epoch where fine-tuning starts (important for correctly accessing the checkpoints)

time_1 = time.time()
print(f"Data loaded in {time_1 - time_start:.2f} seconds")

for n_epoch in epochs_to_run: #[70]: #[1,10,20,30,40,50,60,70,80,90]: #np.arange(1,91,5): #70,1):
    time_2 = time.time()

    if ema:
        results_out_fp = results_out_dir+f"EMA_Checkpoint{n_epoch}_{inference_name}.nc"
    else:
        results_out_fp =  results_out_dir+"Checkpoint"+str(n_epoch)+"_"+inference_name+'.nc' 
    
    # Check if the results file already exists
    if os.path.exists(results_out_fp):
        print(f"Results file {results_out_fp} already exists. Skipping to next epoch.")
        continue  # Skip the rest of the loop and go to the next iteration
    else:
        os.makedirs(os.path.dirname(results_out_fp), exist_ok=True)

        load_dotenv()

        if n_epoch < fine_tuning_start_epoch: # pre-fine-tuning phase epochs are numbered 1-70
            src_dir = "/projectnb/eb-general/shared_data/data/processed/FourCastNet_sfno/Checkpoints_SFNO/sfno_linear_74chq_sc3_layers8_edim384_dt6h_wstgl2/v0.1.0-seed999/"
            checkpoint_name = 'ckpt_mp0_epoch'+str(n_epoch)+'.tar'
        else:
            src_dir = "/projectnb/eb-general/shared_data/data/processed/FourCastNet_sfno/Checkpoints_SFNO/multistep_sfno_linear_74chq_sc3_layers8_edim384_dt6h_wstgl2/v0.1.0-seed999-multistep2/"
            n_epoch_multistep2 = n_epoch - (fine_tuning_start_epoch - 1) # fine-tuning phase epochs are numbered from 1-20
            checkpoint_name = 'ckpt_mp0_epoch'+str(n_epoch_multistep2)+'.tar'

        print(f"Loading model from {src_dir} with checkpoint {checkpoint_name}")
        # Load the model package from storage
        model_package = Package(src_dir, cache = False)
        model = SFNO.load_model(model_package, 
        checkpoint_name = checkpoint_name, EMA = ema
        )

        # Create the IO handler, store in memory
        io = ZarrBackend()
        
        with torch.no_grad():
            # run inference
            io = deterministic([start_datetime], n_steps, model, initial_data, io, 
            variables_list=variables_to_select
            )

        print(io.root.tree())

        # save results to netcdf
        # Open the Zarr group from the in-memory store using xarray
        ds = xr.open_zarr(io.root.store)
        
        # SANITY CHECKING...
        print("Dataset times:", ds["time"].values)
        print("Dataset dimensions:", {dim: ds.dims[dim] for dim in ds.dims})
        print("Lead times", ds["lead_time"].values)

        # Convert the 'time' coordinate in ds to datetime64 format
        ds["time"] = ds["time"].astype("datetime64[ns]")

        # Convert lead_time from nanoseconds to timedelta64[ns]
        base_time = ds["time"].values  # shape (n_time,)
        lead_timedelta = ds["lead_time"].values.astype("timedelta64[ns]")  # shape (n_lead_time,)
        # Broadcast to 2D: (time, lead_time)
        valid_timesteps = (base_time[:, None] + lead_timedelta[None, :]).flatten() 
        # Drop the old lead_time coordinate
        ds = ds.drop_vars("lead_time")

        # Assume ds has dimensions (time, lead_time, lat, lon) and only one time
        initial_time = str(ds["time"].values[0])  # Save the initial time as a string
        # Remove the time dimension by selecting the first (and only) time
        ds = ds.isel(time=0).drop_vars("time")
        # Add the initial time as a global attribute
        ds.attrs["initial_time"] = initial_time

        # Create valid_time by adding lead_timedelta to base_time
        ds = ds.rename({"lead_time": "valid_time"})
        # Assign valid_time as a coordinate
        ds = ds.assign_coords(valid_time=(("valid_time",), valid_timesteps))

        # only save the final time step
        if np.datetime64(final_datetime) in ds["valid_time"].values:
            ds = ds.sel(valid_time=[final_datetime])
            ds = ds[variables_to_select]
            ds.to_netcdf(results_out_fp, mode="w", format="NETCDF4")
            print(f"Results saved to {results_out_fp}")
        else:
            print(f"ERROR: final_datetime {final_datetime} not found in ds['valid_time']. No file saved.")


        #some cleanup
        torch.cuda.empty_cache()
        del model_package
        del model
        del io
        del ds
        gc.collect()
        time_3 = time.time()
        print(f"Epoch {n_epoch} done: {time_3 - time_2:.2f} seconds")


Data already preprocessed: /projectnb/eb-general/wade/sfno/inference_runs/Ian/Initialize_data/Initialize_2019_09_03T00_nsteps4.nc
Data loaded in 4765.38 seconds
Loading model from /projectnb/eb-general/shared_data/data/processed/FourCastNet_sfno/Checkpoints_SFNO/sfno_linear_74chq_sc3_layers8_edim384_dt6h_wstgl2/v0.1.0-seed999/ with checkpoint ckpt_mp0_epoch20.tar
[32m2025-12-18 14:05:07.821[0m | [1mINFO    [0m | [36mdeterministic_update[0m:[36mdeterministic[0m:[36m59[0m - [1mRunning simple workflow![0m
[32m2025-12-18 14:05:07.822[0m | [1mINFO    [0m | [36mdeterministic_update[0m:[36mdeterministic[0m:[36m66[0m - [1mInference device: cpu[0m
[32m2025-12-18 14:05:08.067[0m | [32m[1mSUCCESS [0m | [36mdeterministic_update[0m:[36mdeterministic[0m:[36m90[0m - [32m[1mFetched data from DataArrayFile[0m
[32m2025-12-18 14:05:08.144[0m | [1mINFO    [0m | [36mdeterministic_update[0m:[36mdeterministic[0m:[36m120[0m - [1mInference starting![0m


Running inference: 100%|██████████| 5/5 [08:22<00:00, 100.46s/it]

[32m2025-12-18 14:13:30.442[0m | [32m[1mSUCCESS [0m | [36mdeterministic_update[0m:[36mdeterministic[0m:[36m144[0m - [32m[1mInference complete[0m






Dataset times: ['2019-09-03T00:00:00.000000000']
Dataset dimensions: {'time': 1, 'lead_time': 5, 'lat': 721, 'lon': 1440}
Lead times [ 0  6 12 18 24]
Results saved to /projectnb/eb-general/wade/sfno/inference_runs/sandbox/Experiment0/2019_09_04/Checkpoint20_2019_09_03T00_nsteps4.nc
Epoch 20 done: 545.60 seconds


In [3]:
# Compare the output of my forecast to Becca's saved forecast
my_forecast ='/projectnb/eb-general/wade/sfno/inference_runs/sandbox/Experiment0/2022_09_29/Checkpoint70_2022_09_24T00_nsteps20.nc' 
beccas_forecast = '/projectnb/eb-general/wade/sfno/inference_runs/Ian/leadtime_fivedays/Checkpoint70_2022_09_24T00_nsteps20.nc'
my_ds = xr.open_dataset(my_forecast)
beccas_ds = xr.open_dataset(beccas_forecast)
print(my_ds)
print(beccas_ds)

<xarray.Dataset> Size: 12MB
Dimensions:     (valid_time: 1, lat: 721, lon: 1440)
Coordinates:
  * valid_time  (valid_time) datetime64[ns] 8B 2022-09-29
  * lat         (lat) float64 6kB 90.0 89.75 89.5 89.25 ... -89.5 -89.75 -90.0
  * lon         (lon) float64 12kB 0.0 0.25 0.5 0.75 ... 359.0 359.2 359.5 359.8
Data variables:
    msl         (valid_time, lat, lon) float32 4MB ...
    u10m        (valid_time, lat, lon) float32 4MB ...
    v10m        (valid_time, lat, lon) float32 4MB ...
Attributes:
    initial_time:  2022-09-24T00:00:00.000000000
<xarray.Dataset> Size: 12MB
Dimensions:     (valid_time: 1, lat: 721, lon: 1440)
Coordinates:
  * valid_time  (valid_time) datetime64[ns] 8B 2022-09-29
  * lat         (lat) float64 6kB 90.0 89.75 89.5 89.25 ... -89.5 -89.75 -90.0
  * lon         (lon) float64 12kB 0.0 0.25 0.5 0.75 ... 359.0 359.2 359.5 359.8
Data variables:
    msl         (valid_time, lat, lon) float32 4MB ...
    u10m        (valid_time, lat, lon) float32 4MB ...
    v10m

To check what epoch a checkpoint is (e.g. if epoch number is not in the filepath):

In [38]:
import torch
import os

dir='/projectnb/eb-general/shared_data/data/processed/FourCastNet_sfno/Checkpoints_SFNO/sfno_linear_74chq_sc3_layers8_edim384_dt6h_wstgl2/v0.1.0-seed999/training_checkpoints/' # step 1 of training (epochs 1-70)
dir2='/projectnb/eb-general/shared_data/data/processed/FourCastNet_sfno/Checkpoints_SFNO/multistep_sfno_linear_74chq_sc3_layers8_edim384_dt6h_wstgl2/v0.1.0-seed999-multistep2/training_checkpoints/' # step 2 of training (epochs 71-90

# List of files to check
files_to_check = [dir + "best_ckpt_mp0.tar", 
                  dir + "ckpt_mp0.tar",
                  dir2 +  "best_ckpt_mp0.tar",
                    dir2 + "ckpt_mp0.tar"
                    ]
                
for filename in files_to_check:
    # Load the checkpoint
    # map_location='cpu' allows you to inspect this even without a GPU
    # weights_only=False allows loading the full dictionary structure
    checkpoint = torch.load(filename, map_location='cpu', weights_only=False)
    epoch = checkpoint.get('epoch', 'N/A')
    
    print(f"{filename:<25}")
    print(f'epoch: {str(epoch)}')
    print()

/projectnb/eb-general/shared_data/data/processed/FourCastNet_sfno/Checkpoints_SFNO/sfno_linear_74chq_sc3_layers8_edim384_dt6h_wstgl2/v0.1.0-seed999/training_checkpoints/best_ckpt_mp0.tar
epoch: 70

/projectnb/eb-general/shared_data/data/processed/FourCastNet_sfno/Checkpoints_SFNO/sfno_linear_74chq_sc3_layers8_edim384_dt6h_wstgl2/v0.1.0-seed999/training_checkpoints/ckpt_mp0.tar
epoch: 70

/projectnb/eb-general/shared_data/data/processed/FourCastNet_sfno/Checkpoints_SFNO/multistep_sfno_linear_74chq_sc3_layers8_edim384_dt6h_wstgl2/v0.1.0-seed999-multistep2/training_checkpoints/best_ckpt_mp0.tar
epoch: 19

/projectnb/eb-general/shared_data/data/processed/FourCastNet_sfno/Checkpoints_SFNO/multistep_sfno_linear_74chq_sc3_layers8_edim384_dt6h_wstgl2/v0.1.0-seed999-multistep2/training_checkpoints/ckpt_mp0.tar
epoch: 20



the above output shows that the best_ckpt is the best ckpt *for each phase*, and ckpt_mp0 is the *final* ckpt for that phase
- checkpoint numbering resets for each phase so phase 2 is numbered 1-20


# Original script:

In [None]:
import os
import subprocess
from dotenv import load_dotenv

from earth2studio.io import ZarrBackend
from SFNO_update import SFNO
import earth2studio.data as data
from earth2studio.models.auto import Package
from utils import filename_to_year, datetime_range, open_hdf5
from deterministic_update import deterministic

from datetime import datetime, timedelta
import json
import xarray as xr
from typing import List
import shutil
import sys
import gc
import numpy as np
import time

import torch

# Check if CUDA (GPU support) is available
is_available = torch.cuda.is_available()
print(f"Is CUDA available? {is_available}")

if is_available:
    # Get the number of available GPUs
    gpu_count = torch.cuda.device_count()
    print(f"Number of GPUs available: {gpu_count}")

    # Get the ID of the current GPU
    current_gpu = torch.cuda.current_device()
    print(f"Current GPU ID: {current_gpu}")

    # Get the name of the current GPU
    gpu_name = torch.cuda.get_device_name(current_gpu)
    print(f"Current GPU Name: {gpu_name}")

    print(f"Memory (VRAM):      {torch.cuda.get_device_properties(current_gpu).total_memory / 1e9:.2f} GB")
else:
    print("CUDA is not available. Running on CPU.")

time_start = time.time()

############# Double check these before running the script #############
#select start datetime and n_steps, each n_step = 6hrs
start_datetime = "2021-09-20T00:00:00" # "2021_09_20T00:00:00"
variables_to_select = ['tcwv'] #Only save selected variables - it slows down inference SIGNIFICANTLY to save all 74 variables
experiment_number = 0
n_steps = 20  #number of 6hr steps to forecast

boring = False
ema = False

# Create the inference name based on the start datetime and number of steps
inference_name = datetime.fromisoformat(start_datetime).strftime("%Y_%m_%dT%H")+'_nsteps'+str(n_steps)
data_create_fp = "/projectnb/eb-general/wade/sfno/inference_runs/Ian/Initialize_data/Initialize_"+inference_name+".nc"

# Calculate the final datetime based from the start datetime and number of steps
final_datetime = (datetime.fromisoformat(start_datetime) + timedelta(hours = int(n_steps*6))).isoformat() 

# Directories
results_out_dir = f"/projectnb/eb-general/wade/sfno/inference_runs/sandbox/Experiment{str(experiment_number)}/{final_datetime[:10].replace('-', '_')}/"

############# Double check these before running the script #############


if os.path.exists(data_create_fp):
    print(f"Data already preprocessed: {data_create_fp}")
else:
    sys.exit(f"Data not found use Create_Initial_Data.ipynb to create: {data_create_fp}")

#make this xarray into a dataarray file for earth2studio
initial_data = data.DataArrayFile(data_create_fp)

time_1 = time.time()
print(f"Data loaded in {time_1 - time_start:.2f} seconds")


for n_epoch in np.arange(1,3): #70,1):
    time_2 = time.time()
    # if boring:
    #     # Create the final datetime string in the desired format
    #     
    # else:# Create the final datetime string in the desired format

    if ema:
        results_out_fp = results_out_dir+f"EMA_Checkpoint{n_epoch}_{inference_name}.nc"
    else:
        results_out_fp = results_out_fp = results_out_dir+"/Checkpoint"+str(n_epoch)+"_"+inference_name+'.nc' 
    
    # Check if the results file already exists
    if os.path.exists(results_out_fp):
        print(f"Results file {results_out_fp} already exists. Skipping to next epoch.")
        continue  # Skip the rest of the loop and go to the next iteration
    else:
        os.makedirs(os.path.dirname(results_out_fp), exist_ok=True)

        load_dotenv()  

        # Make temporary folder with all the metadata in it.
        src_dir = "/projectnb/eb-general/shared_data/data/processed/FourCastNet_sfno/Checkpoints_SFNO/sfno_linear_74chq_sc3_layers8_edim384_dt6h_wstgl2/v0.1.0-seed999/"
        # Load the model package from storage
        model_package = Package(src_dir, cache = False)
        model = SFNO.load_model(model_package, checkpoint_name = 'ckpt_mp0_epoch'+str(n_epoch)+'.tar', EMA = ema)

        # Create the IO handler, store in memory
        io = ZarrBackend()
        
        with torch.no_grad():
            # run inference
            io = deterministic([start_datetime], n_steps, model, initial_data, io, variables_list=variables_to_select)

        print(io.root.tree())


        # save results to netcdf
        # Open the Zarr group from the in-memory store using xarray
        ds = xr.open_zarr(io.root.store)

        # Convert the 'time' coordinate in ds to datetime64 format
        ds["time"] = ds["time"].astype("datetime64[ns]")

        # Convert lead_time from nanoseconds to timedelta64[ns]
        base_time = ds["time"].values  # shape (n_time,)
        lead_timedelta = ds["lead_time"].values.astype("timedelta64[ns]")  # shape (n_lead_time,)
        # Broadcast to 2D: (time, lead_time)
        valid_timesteps = (base_time[:, None] + lead_timedelta[None, :]).flatten() 
        # Drop the old lead_time coordinate
        ds = ds.drop_vars("lead_time")

        # Assume ds has dimensions (time, lead_time, lat, lon) and only one time
        initial_time = str(ds["time"].values[0])  # Save the initial time as a string
        # Remove the time dimension by selecting the first (and only) time
        ds = ds.isel(time=0).drop_vars("time")
        # Add the initial time as a global attribute
        ds.attrs["initial_time"] = initial_time

        # Create valid_time by adding lead_timedelta to base_time
        ds = ds.rename({"lead_time": "valid_time"})
        # Assign valid_time as a coordinate
        ds = ds.assign_coords(valid_time=(("valid_time",), valid_timesteps))

        # only save the final time step
        if np.datetime64(final_datetime) in ds["valid_time"].values:
            ds = ds.sel(valid_time=[final_datetime])
            ds = ds[variables_to_select]
            ds.to_netcdf(results_out_fp, mode="w", format="NETCDF4")
            print(f"Results saved to {results_out_fp}")
        else:
            print(f"ERROR: final_datetime {final_datetime} not found in ds['valid_time']. No file saved.")


        #some cleanup
        torch.cuda.empty_cache()
        del model_package
        del model
        del io
        del ds
        gc.collect()
        time_3 = time.time()
        print(f"Epoch {n_epoch} done: {time_3 - time_2:.2f} seconds")



#     for n_epoch in np.arange(36,71,1):
#         time_2 = time.time()
#         if boring:
#             # Create the final datetime string in the desired format
#             results_out_fp = "/barnes-engr-scratch2/C837824079/Experiment"+str(experiment_number)+"/Forecasts_Boring/"+final_datetime[:10].replace("-", "_")+"/Checkpoint"+str(n_epoch)+"_"+inference_name+'.nc'
#         else:# Create the final datetime string in the desired format
#             if ema:
#                 results_out_fp = f"/barnes-engr-scratch2/C837824079/Experiment{str(experiment_number)}/Forecast/EMA_9/Checkpoint{n_epoch}_{inference_name}.nc"         
#             else:
#                 results_out_fp = "/projectnb/eb-general/rbaiman/SFNO/Example_Inference/Example_Forecast/Checkpoint"+str(n_epoch)+"_"+inference_name+'.nc'

#         # Check if the results file already exists
#         if os.path.exists(results_out_fp):
#             print(f"Results file {results_out_fp} already exists. Skipping to next epoch.")
#             continue  # Skip the rest of the loop and go to the next iteration
#         else:
#             os.makedirs(os.path.dirname(results_out_fp), exist_ok=True)

#             load_dotenv()  

#             # Make temporary folder with all the metadata in it.
#             src_dir = "/projectnb/eb-general/shared_data/data/processed/FourCastNet_sfno/Checkpoints_SFNO/sfno_linear_74chq_sc3_layers8_edim384_dt6h_wstgl2/v0.1.0-seed999/"

#             # Load the model package from storage
#             model_package = Package(src_dir, cache = False)
#             model = SFNO.load_model(model_package, checkpoint_name = 'ckpt_mp0_epoch'+str(n_epoch)+'.tar', EMA = ema)

#             # Create the IO handler, store in memory
#             io = ZarrBackend()
            
#             with torch.no_grad():
#                 # run inference
#                 io = deterministic([start_datetime], n_steps, model, initial_data, io, variables_list=variables_to_select)

#             print(io.root.tree())


#             # save results to netcdf
#             # Open the Zarr group from the in-memory store using xarray
#             ds = xr.open_zarr(io.root.store)

#             # Convert the 'time' coordinate in ds to datetime64 format
#             ds["time"] = ds["time"].astype("datetime64[ns]")

#             # Convert lead_time from nanoseconds to timedelta64[ns]
#             base_time = ds["time"].values  # shape (n_time,)
#             lead_timedelta = ds["lead_time"].values.astype("timedelta64[ns]")  # shape (n_lead_time,)
#             # Broadcast to 2D: (time, lead_time)
#             valid_timesteps = (base_time[:, None] + lead_timedelta[None, :]).flatten() 
#             # Drop the old lead_time coordinate
#             ds = ds.drop_vars("lead_time")

#             # Assume ds has dimensions (time, lead_time, lat, lon) and only one time
#             initial_time = str(ds["time"].values[0])  # Save the initial time as a string
#             # Remove the time dimension by selecting the first (and only) time
#             ds = ds.isel(time=0).drop_vars("time")
#             # Add the initial time as a global attribute
#             ds.attrs["initial_time"] = initial_time

#             # Create valid_time by adding lead_timedelta to base_time
#             ds = ds.rename({"lead_time": "valid_time"})
#             # Assign valid_time as a coordinate
#             ds = ds.assign_coords(valid_time=(("valid_time",), valid_timesteps))

#             # only save the final time step
#             if np.datetime64(final_datetime) in ds["valid_time"].values:
#                 ds = ds.sel(valid_time=[final_datetime])
#                 ds = ds[variables_to_select]
#                 ds.to_netcdf(results_out_fp, mode="w", format="NETCDF4")
#                 print(f"Results saved to {results_out_fp}")
#             else:
#                 print(f"ERROR: final_datetime {final_datetime} not found in ds['valid_time']. No file saved.")


#             #some cleanup
#             torch.cuda.empty_cache()
#             del model_package
#             del model
#             del io
#             del ds
#             gc.collect()
#             time_3 = time.time()
#             print(f"Epoch {n_epoch} done: {time_3 - time_2:.2f} seconds")


# for n_epoch in np.arange(1,21,1):
#     time_2 = time.time()
#     # Create the final datetime string in the desired format
#     if boring:
#         # Create the final datetime string in the desired format
#         results_out_fp = "/barnes-engr-scratch2/C837824079/Experiment"+str(experiment_number)+"/Forecasts_Boring/"+final_datetime[:10].replace("-", "_")+"/Checkpoint"+str(n_epoch+70)+"_"+inference_name+'.nc'
#     else:# Create the final datetime string in the desired format
#         if ema:
#             results_out_fp = f"/barnes-engr-scratch2/C837824079/Experiment{str(experiment_number)}/Forecast/EMA_9/Checkpoint{n_epoch+70}_{inference_name}.nc"
#         else:
#             results_out_fp = "/projectnb/eb-general/rbaiman/SFNO/Example_Inference/Example_Forecast/Checkpoint"+str(n_epoch+70)+"_"+inference_name+'.nc'

    
#     # Check if the results file already exists
#     if os.path.exists(results_out_fp):
#         print(f"Results file {results_out_fp} already exists. Skipping to next epoch.")
#         continue  # Skip the rest of the loop and go to the next iteration
#     else:
#         os.makedirs(os.path.dirname(results_out_fp), exist_ok=True)

#         load_dotenv()  

#         # Make temporary folder with all the metadata in it.
#         src_dir = "/projectnb/eb-general/shared_data/data/processed/FourCastNet_sfno/Checkpoints_SFNO/multistep_sfno_linear_74chq_sc3_layers8_edim384_dt6h_wstgl2/v0.1.0-seed999-multistep2/"

#         # Load the model package from storage
#         model_package = Package(src_dir, cache = False)
#         model = SFNO.load_model(model_package, checkpoint_name = 'ckpt_mp0_epoch'+str(n_epoch)+'.tar', EMA = ema)

#         # Create the IO handler, store in memory
#         io = ZarrBackend()

#         print(f"Running inference for {inference_name}")
#         with torch.no_grad():
#             # run inference
#             io = deterministic([start_datetime], n_steps, model, initial_data, io, variables_list=variables_to_select)

#         # print(io.root.tree())

#         # save results to netcdf
#         # Open the Zarr group from the in-memory store using xarray
#         ds = xr.open_zarr(io.root.store)

#         # Convert the 'time' coordinate in ds to datetime64 format
#         ds["time"] = ds["time"].astype("datetime64[ns]")

#         # Convert lead_time from nanoseconds to timedelta64[ns]
#         base_time = ds["time"].values  # shape (n_time,)
#         lead_timedelta = ds["lead_time"].values.astype("timedelta64[ns]")  # shape (n_lead_time,)
#         # Broadcast to 2D: (time, lead_time)
#         valid_timesteps = (base_time[:, None] + lead_timedelta[None, :]).flatten() 
#         # Drop the old lead_time coordinate
#         ds = ds.drop_vars("lead_time")

#         # Assume ds has dimensions (time, lead_time, lat, lon) and only one time
#         initial_time = str(ds["time"].values[0])  # Save the initial time as a string
#         # Remove the time dimension by selecting the first (and only) time
#         ds = ds.isel(time=0).drop_vars("time")
#         # Add the initial time as a global attribute
#         ds.attrs["initial_time"] = initial_time

#         # Create valid_time by adding lead_timedelta to base_time
#         ds = ds.rename({"lead_time": "valid_time"})
#         # Assign valid_time as a coordinate
#         ds = ds.assign_coords(valid_time=(("valid_time",), valid_timesteps))

#         # only save the final time step
#         if np.datetime64(final_datetime) in ds["valid_time"].values:
#             ds = ds.sel(valid_time=[final_datetime])
#             ds = ds[variables_to_select]
#             ds.to_netcdf(results_out_fp, mode="w", format="NETCDF4")
#             print(f"Results saved to {results_out_fp}")
#         else:
#             print(f"ERROR: final_datetime {final_datetime} not found in ds['valid_time']. No file saved.")


#         #some cleanup
#         torch.cuda.empty_cache()
#         del model_package
#         del model
#         del io
#         del ds
#         gc.collect()
#         time_3 = time.time()
#         print(f"Epoch {n_epoch+70} done: {time_3 - time_2:.2f} seconds")

