# Prithvi WxC Gravity Wave: Model Fine-Tuning and Inference Using TerraTorch

In [1]:
#!pip install -U ../../.

In [2]:
#!pip install -U albumentations # workaround until https://github.com/IBM/terratorch/issues/164 is resolved

In [3]:
#!pip install -U git+https://github.com/romeokienzler/gravity-wave-finetuning.git
#!pip install -U -e ../../../gravity-wave-finetuning/

In [4]:
#!pip install huggingface_hub

In [1]:
import terratorch # this import is needed to initialize TT's factories
from lightning.pytorch import Trainer
import os
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from terratorch.models.wxc_model_factory import WxCModelFactory
import torch.distributed as dist

  from .autonotebook import tqdm as notebook_tqdm
INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.20 (you have 1.4.10). Upgrade using: pip install --upgrade albumentations
  @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)


In [2]:
# Rendezvous settings read by the 'env://' init method used below.
os.environ.update({'MASTER_ADDR': 'localhost', 'MASTER_PORT': '12355'})

# Drop any process group left over from a previous execution of this cell,
# so the cell can be re-run without restarting the kernel.
if dist.is_initialized():
    dist.destroy_process_group()

# Create a one-process group (rank 0 of world size 1) on the gloo backend.
dist.init_process_group(backend='gloo', init_method='env://', rank=0, world_size=1)

In [None]:
# Download the model checkpoint file from the Hugging Face Hub into the working directory.
hf_hub_download(
    filename="magnet-flux-uvtp122-epoch-99-loss-0.1022.pt",
    repo_id="Prithvi-WxC/Gravity_wave_Parameterization",
    local_dir=".",
)

# Download config.yaml from the same repository, also into the working directory.
hf_hub_download(
    filename="config.yaml",
    repo_id="Prithvi-WxC/Gravity_wave_Parameterization",
    local_dir=".",
)

In [None]:
# Download the NetCDF data file from the *dataset* repository of the same name
# (repo_type='dataset') into the working directory.
hf_hub_download(
    filename="wxc_input_u_v_t_p_output_theta_uw_vw_era5_training_data_hourly_2015_constant_mu_sigma_scaling05.nc",
    repo_id="Prithvi-WxC/Gravity_wave_Parameterization",
    repo_type='dataset',
    local_dir=".",
)

In [3]:


from torchgeo.trainers import BaseTask
import torch.nn as nn

class WxCGravityWaveTask(BaseTask):
    """Task wrapping the gravity-wave model produced by a model factory.

    Args:
        model_factory: Object exposing ``build_model(backbone=..., aux_decoders=...)``;
            it is called from ``configure_models`` to create ``self.model``.
        learning_rate: Learning rate passed to Adam in ``configure_optimizers``.
            Defaults to 1e-3, which is also ``torch.optim.Adam``'s own default.

    NOTE(review): this class defines neither ``forward`` nor ``predict_step``;
    confirm that the base class provides whatever ``Trainer.predict`` needs.
    """

    def __init__(self, model_factory, learning_rate=1e-3):
        # These attributes are assigned *before* super().__init__() on purpose:
        # presumably BaseTask.__init__ calls configure_models(), which reads
        # self.model_factory — TODO confirm against torchgeo's BaseTask.
        self.model_factory = model_factory
        # Previously never assigned, so configure_optimizers() raised AttributeError.
        self.learning_rate = learning_rate
        super().__init__()

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters, using ``self.learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def configure_models(self):
        """Build the 'gravitywave' backbone via the factory and store it as ``self.model``."""
        self.model = self.model_factory.build_model(backbone='gravitywave', aux_decoders=None)

In [4]:
from prithviwxc.gravitywave.datamodule import ERA5DataModule

# Instantiate the task with a fresh TerraTorch WxC model factory; the recorded
# output shows the checkpoint weights being loaded during this call.
task = WxCGravityWaveTask(model_factory=WxCModelFactory())

Loading weights from magnet-flux-uvtp122-epoch-99-loss-0.1022.pt
Loaded weights


In [5]:
trainer = Trainer(max_epochs=1)

# Both data paths point at the current directory, where the dataset file was downloaded.
dm = ERA5DataModule(valid_data_path='.', train_data_path='.')

# NOTE(review): the recorded run of this cell ended in `ValueError: Expected a parent`.
# This looks like a mix of `lightning.pytorch` and `pytorch_lightning` classes between
# the Trainer, the task and the datamodule — confirm which base classes each one uses.
results = trainer.predict(task, datamodule=dm, return_predictions=True)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


ValueError: Expected a parent

In [5]:
# Release the process group created earlier in the notebook. The guard lets this
# cell be re-run (or run on a fresh kernel) without raising when no group is active.
if dist.is_initialized():
    dist.destroy_process_group()