# Prithvi WxC Gravity Wave: Model Fine-Tuning and Inference Using TerraTorch

In [1]:
!pip install -U git+https://github.com/romeokienzler/terratorch.git@201


In [2]:
!pip install -U albumentations # fix until https://github.com/IBM/terratorch/issues/164 is solved

In [None]:
!pip install -U git+https://github.com/romeokienzler/gravity-wave-finetuning.git


In [4]:
!pip install huggingface_hub

In [None]:
import terratorch # this import is needed to initialize TT's factories
from lightning.pytorch import Trainer
import os
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from terratorch.models.wxc_model_factory import WxCModelFactory
import torch.distributed as dist

In [6]:
# Rendezvous settings consumed by the env:// init method below.
os.environ.update({'MASTER_ADDR': 'localhost', 'MASTER_PORT': '12355'})

# Drop any group left over from an earlier run of this cell so re-running it works.
if dist.is_initialized():
    dist.destroy_process_group()

# A one-member group (rank 0 of 1) on the gloo backend.
# NOTE(review): presumably needed because downstream model/data code queries
# torch.distributed (rank / world size) — TODO confirm.
dist.init_process_group(backend='gloo', init_method='env://', world_size=1, rank=0)

In [None]:
# Fetch the gravity-wave checkpoint and its config.yaml from the Hugging Face
# model repo into the current working directory. The filenames are plain
# literals: the original f-strings had no placeholders.
hf_hub_download(
    repo_id="Prithvi-WxC/Gravity_wave_Parameterization",
    filename="magnet-flux-uvtp122-epoch-99-loss-0.1022.pt",
    local_dir=".",
)

hf_hub_download(
    repo_id="Prithvi-WxC/Gravity_wave_Parameterization",
    filename="config.yaml",
    local_dir=".",
)

In [None]:
# Fetch the data file (presumably NetCDF, given the .nc suffix) from the dataset
# repo of the same name into the working directory. ERA5DataModule below is
# pointed at '.'. The filename is a plain literal: the f-string had no placeholders.
hf_hub_download(
    repo_id="Prithvi-WxC/Gravity_wave_Parameterization",
    repo_type='dataset',
    filename="wxc_input_u_v_t_p_output_theta_uw_vw_era5_training_data_hourly_2015_constant_mu_sigma_scaling05.nc",
    local_dir=".",
)

In [9]:


from torchgeo.trainers import BaseTask
import torch.nn as nn

class WxCGravityWaveTask(BaseTask):
    """Lightning task wrapping a gravity-wave model built by a TerraTorch factory.

    Args:
        model_factory: Object whose ``build_model`` constructs the network; it is
            called from ``configure_models``.
        lr: Learning rate for the Adam optimizer. Defaults to ``1e-3``, which is
            also ``torch.optim.Adam``'s own default.
    """

    def __init__(self, model_factory, *, lr: float = 1e-3):
        # Both attributes are set before BaseTask.__init__ because that base
        # constructor is expected to call configure_models(), which needs the
        # factory. NOTE(review): confirm against the installed torchgeo version.
        self.model_factory = model_factory
        # Read by configure_optimizers(); without it, training raises AttributeError.
        self.learning_rate = lr
        super().__init__()

    def configure_optimizers(self):
        """Return an Adam optimizer over all task parameters."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def configure_models(self):
        """Build the gravity-wave backbone (no auxiliary decoders) into ``self.model``."""
        # NOTE(review): no checkpoint path is passed here; confirm whether
        # build_model loads the downloaded pretrained weights on its own.
        self.model = self.model_factory.build_model(backbone='gravitywave', aux_decoders=None)

In [None]:
from prithviwxc.gravitywave.datamodule import ERA5DataModule
# Wrap a fresh TerraTorch WxC model factory in the Lightning task.
# NOTE(review): the checkpoint downloaded earlier is never passed in here —
# confirm that the factory / build_model picks up pretrained weights.
task = WxCGravityWaveTask(model_factory=WxCModelFactory())

In [None]:
# One-epoch Lightning trainer; every other setting is left at Lightning's defaults.
trainer = Trainer(max_epochs=1)

# Both training and validation data are read from the working directory, where
# the dataset file was downloaded above.
dm = ERA5DataModule(
    train_data_path='.',
    valid_data_path='.',
)

# Echo the data module's class in the cell output as a quick sanity check.
type(dm)

In [15]:
# Run inference with the task on the data module.
# return_predictions=True keeps the outputs in `results`.
results = trainer.predict(
    model=task,
    datamodule=dm,
    return_predictions=True,
)

In [5]:
# Tear down the process group created in the setup cell. Guarded the same way
# as that cell, so running this cleanup twice (or before setup) does not raise.
if dist.is_initialized():
    dist.destroy_process_group()