# Prithvi WxC Gravity Wave: Model Fine Tuning and Inference using TerraTorch

In [None]:
!pip install terratorch==0.99.9 huggingface_hub PrithviWxC

In [None]:
!pip install  git+https://github.com/romeokienzler/gravity-wave-finetuning.git

In [None]:
import terratorch # this import is needed to initialize TT's factories
from lightning.pytorch import Trainer
import os
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from terratorch.models.wxc_model_factory import WxCModelFactory
from terratorch.tasks.wxc_task import WxCTask
import torch.distributed as dist

In [None]:
# Rendezvous settings read by the 'env://' init method below.
os.environ.update({'MASTER_ADDR': 'localhost', 'MASTER_PORT': '12355'})

# Drop any process group that is still initialized (e.g. from an earlier run
# of this cell) before creating a new one.
if dist.is_initialized():
    dist.destroy_process_group()

# Single-rank group (rank 0 of world size 1) on the 'gloo' backend.
dist.init_process_group(backend='gloo', init_method='env://', rank=0, world_size=1)

In [None]:
# Download the `.pt` weights file and `config.yaml` from the Hugging Face Hub
# model repository into the current working directory.
weights_repo = "ibm-nasa-geospatial/Prithvi-WxC-1.0-2300m-gravity-wave-parameterization"
hf_hub_download(
    repo_id=weights_repo,
    filename="magnet-flux-uvtp122-epoch-99-loss-0.1022.pt",
    local_dir=".",
)
hf_hub_download(
    repo_id=weights_repo,
    filename="config.yaml",
    local_dir=".",
)

In [None]:
# Download the NetCDF data file from the Hugging Face Hub dataset repository
# into the current working directory.
hf_hub_download(
    repo_id="ibm-nasa-geospatial/gravity-wave-parameterization",
    repo_type="dataset",
    filename="wxc_input_u_v_t_p_output_theta_uw_vw_era5_training_data_hourly_2015_constant_mu_sigma_scaling05.nc",
    local_dir=".",
)

In [None]:
from prithviwxc.gravitywave.datamodule import ERA5DataModule

# Channel counts reused by the placeholder scaler tensors below.
n_dynamic_channels = 1280
n_static_channels = 3

# Keyword arguments handed to the task as `model_args`.
model_args = {
    "in_channels": n_dynamic_channels,
    "input_size_time": 1,
    "n_lats_px": 64,
    "n_lons_px": 128,
    "patch_size_px": [2, 2],
    "mask_unit_size_px": [8, 16],
    "mask_ratio_inputs": 0.5,
    "embed_dim": 2560,
    "n_blocks_encoder": 12,
    "n_blocks_decoder": 2,
    "mlp_multiplier": 4,
    "n_heads": 16,
    "dropout": 0.0,
    "drop_path": 0.05,
    "parameter_dropout": 0.0,
    "residual": "none",
    "masking_mode": "both",
    "decoder_shifting": False,
    "positional_encoding": "absolute",
    "checkpoint_encoder": [3, 6, 9, 12, 15, 18, 21, 24],
    "checkpoint_decoder": [1, 3],
    "in_channels_static": n_static_channels,
    # Placeholder scalers: all-zero means and all-one sigmas, kept as int64 tensors.
    "input_scalers_mu": torch.zeros(n_dynamic_channels, dtype=torch.int64),
    "input_scalers_sigma": torch.ones(n_dynamic_channels, dtype=torch.int64),
    "input_scalers_epsilon": 0,
    "static_input_scalers_mu": torch.zeros(n_static_channels, dtype=torch.int64),
    "static_input_scalers_sigma": torch.ones(n_static_channels, dtype=torch.int64),
    "static_input_scalers_epsilon": 0,
    "output_scalers": torch.zeros(n_dynamic_channels, dtype=torch.int64),
    # Optional encoder/decoder settings, currently disabled:
    #"encoder_hidden_channels_multiplier" : [1, 2, 4, 8],
    #"encoder_num_encoder_blocks" : 4,
    #"decoder_hidden_channels_multiplier" : [(16, 8), (12, 4), (6, 2), (3, 1)],
    #"decoder_num_decoder_blocks" : 4,
    "aux_decoders": "unetpincer",
    "backbone": "prithviwxc",
    # NOTE(review): this is the string "True", not a bool; confirm the factory expects that.
    "skip_connection": "True",
}

# Build the task in 'eval' mode; it is used for the prediction run.
task = WxCTask("WxCModelFactory", model_args=model_args, mode="eval")

In [None]:
# One epoch at most, and only a single batch when predicting.
trainer_options = {"max_epochs": 1, "limit_predict_batches": 1}
trainer = Trainer(**trainer_options)

# Training and validation data are both looked up in the current directory.
dm = ERA5DataModule(train_data_path=".", valid_data_path=".")
type(dm)

In [None]:
# Run prediction with the 'eval'-mode task and keep the returned predictions.
results = trainer.predict(
    model=task,
    datamodule=dm,
    return_predictions=True,
)
results

In [None]:
# Rebuild the task from the same model arguments, this time in 'train' mode.
task = WxCTask(
    "WxCModelFactory",
    model_args=model_args,
    mode="train",
)

In [None]:
# Fine-tune with the 'train'-mode task. `Trainer.fit` returns None, so the
# metrics recorded during the run are read back from the trainer afterwards.
trainer.fit(model=task, datamodule=dm)
results2 = trainer.callback_metrics
results2

In [None]:
# Release the process group created earlier. The guard keeps this cell from
# raising when it is re-run after the group has already been destroyed.
if dist.is_initialized():
    dist.destroy_process_group()