In [None]:
!pip install torch-geometric

Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m19.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-geometric
Successfully installed torch-geometric-2.6.1


In [None]:
import os
import numpy as np
import xarray as xr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops
from torch_geometric.nn import GCNConv
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd
from sklearn.metrics import mean_squared_error, mean_absolute_error


# Use the GPU when torch reports one is available, otherwise fall back to the CPU.
# Every later cell moves the model and batches onto this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Define Constants
INPUT_DAYS = 60   # Number of historical time steps fed to the model per sample
OUTPUT_DAYS = 14  # Forecasting horizon; (10,14,44) are presumably the horizons that were tried - confirm
NUM_VARIABLES = 10 # Must equal len(VARIABLE_NAMES); sizes the encoder input and decoder output channels
GRID_SIZE = 14  # 14x14 grid (196 regions)
# Dataset variable names, in the channel order used for model inputs, targets and RMSE reports.
VARIABLE_NAMES = ['z500', 't850', 'e', 'evavt', 'lai_lv', 'pev', 'swvl2', 'swvl3', 't2m', 'tp']

Using device: cuda


In [None]:
# Data Loading and Processing
class ClimateDataset(Dataset):
    """Sliding-window dataset over a NetCDF file opened with xarray.

    Each sample is a pair of consecutive windows along the ``time`` dimension:
    ``input_days`` steps of history followed by ``output_days`` steps to predict.
    Variables listed in ``VARIABLE_NAMES`` are stacked on a new leading axis, so
    both tensors start with axes (variable, time); the remaining axes are whatever
    non-time dimensions each variable has in the file (presumably the spatial
    grid - confirm against the data).

    Args:
        filename: Path to the dataset file passed to ``xr.open_dataset``.
        input_days: Number of time steps in the input window.
        output_days: Number of time steps in the target window.

    Raises:
        KeyError: If any name in ``VARIABLE_NAMES`` is absent from the file.
    """

    def __init__(self, filename, input_days=INPUT_DAYS, output_days=OUTPUT_DAYS):
        self.data = xr.open_dataset(filename)
        # Fail at construction time rather than on the first batch of training.
        missing = [var for var in VARIABLE_NAMES if var not in self.data]
        if missing:
            raise KeyError(f"Variables missing from {filename}: {missing}")
        self.input_days = input_days
        self.output_days = output_days
        self.num_timesteps = len(self.data.time)
        # Every start index whose input + target window fits inside the series;
        # the range is empty when the series is shorter than one full window.
        window = input_days + output_days
        self.valid_indices = list(range(self.num_timesteps - window + 1))

    def __len__(self):
        """Return the number of complete (input, target) windows available."""
        return len(self.valid_indices)

    def _stack_window(self, start, stop):
        """Stack all variables over time steps [start, stop) into one float32 tensor."""
        stacked = np.stack(
            [self.data[var].isel(time=slice(start, stop)).values for var in VARIABLE_NAMES],
            axis=0,
        )
        return torch.as_tensor(stacked, dtype=torch.float32)

    def __getitem__(self, idx):
        """Return ``{'input': ..., 'target': ...}`` for the idx-th window.

        Raises:
            IndexError: If ``idx`` is outside the range of valid windows.
        """
        # valid_indices only holds starts whose full window fits, so no extra
        # bounds check is needed here (and returning None would break the
        # DataLoader's default collation anyway).
        start_idx = self.valid_indices[idx]
        split_idx = start_idx + self.input_days
        end_idx = split_idx + self.output_days
        return {
            'input': self._stack_window(start_idx, split_idx),
            'target': self._stack_window(split_idx, end_idx),
        }

In [None]:
# Graph Construction
def create_grid_graph(height, width):
    """Build the edge index of a 4-neighbour grid graph with self-loops.

    Nodes are numbered row-major (``node = h * width + w``). Each pair of
    horizontally or vertically adjacent cells is connected in BOTH directions,
    because message-passing layers treat ``edge_index`` as directed: with only
    right/down edges a cell would never receive information from its right or
    lower neighbour. One self-loop per node is appended at the end.

    Args:
        height: Number of grid rows.
        width: Number of grid columns.

    Returns:
        A contiguous ``torch.long`` tensor of shape (2, num_edges). For an empty
        grid (``height * width == 0``) the shape is (2, 0).
    """
    edge_list = []
    for h in range(height):
        for w in range(width):
            node_idx = h * width + w
            if w < width - 1:  # right neighbour, both directions
                edge_list.append((node_idx, node_idx + 1))
                edge_list.append((node_idx + 1, node_idx))
            if h < height - 1:  # lower neighbour, both directions
                edge_list.append((node_idx, node_idx + width))
                edge_list.append((node_idx + width, node_idx))

    # Self-loops for every node, including isolated ones in a 1x1 grid.
    edge_list.extend((node_idx, node_idx) for node_idx in range(height * width))

    if not edge_list:
        return torch.empty((2, 0), dtype=torch.long)
    return torch.tensor(edge_list, dtype=torch.long).t().contiguous()

# Graph Neural Network Layers
class GraphCastLayer(MessagePassing):
    """Message-passing layer: linear projection followed by sum aggregation.

    Node features are projected from ``in_channels`` to ``out_channels`` and the
    projected features are then propagated along ``edge_index`` with 'add'
    aggregation.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')
        self.lin = nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        """Project node features ``x`` and propagate them over ``edge_index``."""
        projected = self.lin(x)
        return self.propagate(edge_index, x=projected)

In [None]:
class GraphCast(nn.Module):
    """Conv encoder -> graph convolution -> self-attention over grid cells -> conv decoder.

    Pipeline (B = batch, C = 128 hidden channels, G = GRID_SIZE, N = G * G):
      1. The input is reshaped to (B, NUM_VARIABLES * INPUT_DAYS, G, G) and
         encoded by a 3x3 convolution to (B, C, G, G).
      2. Every grid cell becomes a graph node; a GCN layer mixes neighbouring cells.
      3. Multi-head self-attention runs over the N cells of each sample. The
         attribute is called ``temporal_attention`` but the sequence axis here is
         the grid cells, not time; the name is kept so saved checkpoints load.
      4. A 3x3 convolution decodes to NUM_VARIABLES channels, and that single
         map is repeated OUTPUT_DAYS times, i.e. every forecast day is identical.

    The grid ``edge_index`` is a non-persistent buffer: it follows ``.to(device)``
    with the rest of the module but is not written to ``state_dict``, so
    checkpoints saved by earlier versions of this class still load.
    """

    def __init__(self):
        super().__init__()
        self.hidden_dim = 128
        self.encoder = nn.Conv2d(NUM_VARIABLES * INPUT_DAYS, self.hidden_dim, kernel_size=3, padding=1)
        self.graph_layer = GCNConv(self.hidden_dim, self.hidden_dim)
        self.temporal_attention = nn.MultiheadAttention(
            embed_dim=self.hidden_dim, num_heads=8, batch_first=True
        )
        self.decoder = nn.Conv2d(self.hidden_dim, NUM_VARIABLES, kernel_size=3, padding=1)
        # Edge index of ONE grid; forward() tiles it across the batch.
        self.register_buffer(
            'edge_index', create_grid_graph(GRID_SIZE, GRID_SIZE), persistent=False
        )

    def _batched_edge_index(self, batch_size, num_nodes):
        """Tile the single-grid edge index so sample b uses node ids offset by b * num_nodes."""
        offsets = torch.arange(batch_size, device=self.edge_index.device) * num_nodes
        # (2, 1, E) + (1, B, 1) -> (2, B, E) -> (2, B * E)
        return (self.edge_index.unsqueeze(1) + offsets.view(1, -1, 1)).reshape(2, -1)

    def forward(self, x):
        """Forecast all variables from a window of history.

        Args:
            x: Tensor whose first axis is the batch and whose remaining axes can
               be reshaped to (NUM_VARIABLES * INPUT_DAYS, GRID_SIZE, GRID_SIZE).

        Returns:
            Tensor of shape (B, NUM_VARIABLES, OUTPUT_DAYS, GRID_SIZE, GRID_SIZE).
        """
        batch_size = x.size(0)
        num_nodes = GRID_SIZE * GRID_SIZE

        x = x.reshape(batch_size, NUM_VARIABLES * INPUT_DAYS, GRID_SIZE, GRID_SIZE)
        x = self.encoder(x)  # (B, C, G, G)

        # Channels-last, then flatten every cell of every sample into one node list.
        x_graph = x.permute(0, 2, 3, 1).reshape(-1, self.hidden_dim)  # (B * N, C)
        # The node list spans the whole batch, so the edge index must too;
        # otherwise only the first sample would exchange messages with neighbours.
        edge_index = self._batched_edge_index(batch_size, num_nodes)
        x_graph = self.graph_layer(x_graph, edge_index)

        x_seq = x_graph.reshape(batch_size, num_nodes, self.hidden_dim)  # (B, N, C)
        x_seq, _ = self.temporal_attention(x_seq, x_seq, x_seq)

        # Move channels back in front of the spatial axes BEFORE reshaping; a bare
        # reshape of (B, N, C) to (B, C, G, G) would interleave cells and channels.
        x = x_seq.permute(0, 2, 1).reshape(batch_size, self.hidden_dim, GRID_SIZE, GRID_SIZE)
        x = self.decoder(x)  # (B, NUM_VARIABLES, G, G)

        # Broadcast the single predicted map across the forecast horizon.
        x = x.unsqueeze(2).repeat(1, 1, OUTPUT_DAYS, 1, 1)
        return x


In [None]:
# Training and Evaluation Functions
def train_one_epoch(model, train_loader, optimizer, criterion):
    """Run one optimisation pass over ``train_loader``.

    Each batch is a dict with 'input' and 'target' tensors, which are moved to
    the global ``device`` before use.

    Returns:
        The sum of per-batch losses divided by the number of batches.
    """
    model.train()
    batch_losses = []
    for batch in tqdm(train_loader, desc="Training"):
        inputs = batch['input'].to(device)
        targets = batch['target'].to(device)

        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())
    return sum(batch_losses) / len(train_loader)

def evaluate(model, val_loader, criterion):
    """Compute the mean per-batch loss over ``val_loader`` without gradient tracking.

    The model is switched to eval mode; batches are dicts with 'input' and
    'target' tensors, moved to the global ``device`` before use.

    Returns:
        The sum of per-batch losses divided by the number of batches.
    """
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation"):
            inputs = batch['input'].to(device)
            targets = batch['target'].to(device)
            batch_losses.append(criterion(model(inputs), targets).item())
    return sum(batch_losses) / len(val_loader)

# Train the Model
def train_model(model, train_loader, val_loader, num_epochs=30, lr=0.001,
                checkpoint_path='best_model.pth'):
    """Train ``model`` with Adam and Smooth-L1 loss, checkpointing the best epoch.

    After every epoch the validation loss is computed; whenever it improves on
    the best value seen so far, the model's ``state_dict`` is written to
    ``checkpoint_path`` (overwriting any previous file).

    Args:
        model: Module to train; expected to already be on the global ``device``.
        train_loader: Loader yielding dicts with 'input' and 'target' tensors.
        val_loader: Loader of the same format used for model selection.
        num_epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.
        checkpoint_path: Where the best ``state_dict`` is saved.

    Returns:
        The lowest validation loss observed (``inf`` if ``num_epochs`` is 0).
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.SmoothL1Loss()  # Huber loss for better robustness
    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        train_loss = train_one_epoch(model, train_loader, optimizer, criterion)
        val_loss = evaluate(model, val_loader, criterion)
        print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), checkpoint_path)
    return best_val_loss

In [None]:
# Run the Training Pipeline
# NOTE(review): the paths look like a Colab-mounted Google Drive, but no mount call
# is visible in this notebook - confirm the drive is mounted before running.
train_file = '/content/drive/MyDrive/drought_data/processed/drought_train_normalized_imp.nc'
val_file = '/content/drive/MyDrive/drought_data/processed/drought_val_normalized_imp.nc'
# Both datasets use the default INPUT_DAYS / OUTPUT_DAYS window lengths.
train_dataset = ClimateDataset(train_file)
val_dataset = ClimateDataset(val_file)
# Only the training windows are shuffled; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)
# `model` and `val_loader` are reused by the evaluation cells further down.
model = GraphCast().to(device)
# train_model writes the best weights to 'best_model.pth' in the working directory.
best_val_loss = train_model(model, train_loader, val_loader)
print(f"Best Validation Loss: {best_val_loss:.4f}")

In [None]:
def calculate_rmse(model_path, val_loader):
    """Compute the per-variable RMSE of a saved GraphCast checkpoint.

    A fresh ``GraphCast`` is built, the checkpoint is loaded onto the global
    ``device``, and squared errors are accumulated over every element of every
    batch. The RMSE is the square root of the total squared error divided by the
    total element count, so a smaller final batch is weighted correctly (the
    mean of per-batch RMSEs is not the dataset RMSE).

    Args:
        model_path: Path to a ``state_dict`` file compatible with ``GraphCast``.
        val_loader: Loader yielding dicts with 'input' and 'target' tensors whose
            axis 1 indexes the variables in ``VARIABLE_NAMES`` order.

    Returns:
        Dict mapping each variable name to its RMSE as a Python float.

    Raises:
        ValueError: If ``val_loader`` yields no data.
    """
    model = GraphCast().to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only runtime.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    num_vars = len(VARIABLE_NAMES)
    squared_error_sum = np.zeros(num_vars, dtype=np.float64)
    elements_per_var = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Evaluating RMSE"):
            inputs = batch['input'].to(device)
            targets = batch['target'].to(device)
            outputs = model(inputs)

            # Accumulate in float64 so long sums do not lose precision.
            squared_error = (outputs.double() - targets.double()) ** 2
            # Bring the variable axis to the front and flatten everything else.
            per_var = squared_error.transpose(0, 1).reshape(squared_error.size(1), -1)
            squared_error_sum += per_var.sum(dim=1)[:num_vars].cpu().numpy()
            elements_per_var += per_var.size(1)

    if elements_per_var == 0:
        raise ValueError("val_loader yielded no data; cannot compute RMSE")

    total_rmse = {
        var_name: float(np.sqrt(squared_error_sum[i] / elements_per_var))
        for i, var_name in enumerate(VARIABLE_NAMES)
    }

    # Print RMSE values
    print("\nRMSE for each variable:")
    for var_name, rmse in total_rmse.items():
        print(f"{var_name}: {rmse:.4f}")

    return total_rmse

# Compute RMSE from Best Model
# NOTE(review): this loads a checkpoint named "...10days" while OUTPUT_DAYS is 14 in this
# notebook, and it is not the 'best_model.pth' written by train_model - confirm this is the
# checkpoint intended for the val_loader built above.
calculate_rmse('/content/drive/MyDrive/Best_models/graphcast_30epoch_10days.pth', val_loader)


Evaluating RMSE: 100%|██████████| 252/252 [00:08<00:00, 31.24it/s]


RMSE for each variable:
z500: 0.2718
t850: 0.3970
e: 1.1004
evavt: 2.1707
lai_lv: 0.3907
pev: 1.1536
swvl2: 0.3972
swvl3: 0.3699
t2m: 0.0868
tp: 2.1686





{'z500': np.float32(0.27178913),
 't850': np.float32(0.3969592),
 'e': np.float32(1.100355),
 'evavt': np.float32(2.17069),
 'lai_lv': np.float32(0.39069992),
 'pev': np.float32(1.1536137),
 'swvl2': np.float32(0.39717546),
 'swvl3': np.float32(0.36988273),
 't2m': np.float32(0.086847365),
 'tp': np.float32(2.1685805)}

In [None]:
import os
import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt
from pathlib import Path


def load_normalization_parameters(climatology_dir="/content/drive/MyDrive/climatology"):
    """Read the four climatology CSVs and return them as lookup dictionaries.

    Each CSV must have a ``param`` column plus a ``mean`` or ``sigma`` column
    (matching the file name). All four files are read before any of them is
    converted.

    Args:
        climatology_dir: Directory containing ``era5_mean.csv``,
            ``era5_sigma.csv``, ``lra5_mean.csv`` and ``lra5_sigma.csv``.

    Returns:
        Dict with keys 'era5_mean', 'era5_sigma', 'lra5_mean', 'lra5_sigma'
        (each mapping parameter name -> value) and 'variable_mapping'
        (dataset variable name -> parameter name used in the CSVs).
    """
    # (result key, CSV file name, value column)
    sources = (
        ('era5_mean', "era5_mean.csv", "mean"),
        ('era5_sigma', "era5_sigma.csv", "sigma"),
        ('lra5_mean', "lra5_mean.csv", "mean"),
        ('lra5_sigma', "lra5_sigma.csv", "sigma"),
    )

    frames = [
        pd.read_csv(os.path.join(climatology_dir, file_name))
        for _, file_name, _ in sources
    ]

    params = {}
    for (result_key, _, value_column), frame in zip(sources, frames):
        params[result_key] = dict(zip(frame["param"], frame[value_column]))

    params['variable_mapping'] = {
        'z500': 'z-500', 't850': 't-850', 't2m': 't2m', 'swvl3': 'swvl3',
        'tp': 'tp', 'pev': 'pev', 'swvl2': 'swvl2', 'e': 'e',
        'evavt': 'evavt', 'lai_lv': 'lai_lv'
    }
    return params

def get_normalization_params_for_variable(var_name, norm_params):
    """Look up the (mean, sigma) pair for one dataset variable.

    Lookup order:
      1. the mapped name (via 'variable_mapping') in the ERA5 tables,
      2. the mapped name in the LRA5 tables,
      3. the raw ``var_name`` in the LRA5 tables.
    Membership is tested against the mean table only; the sigma table is then
    indexed with the same key.

    Args:
        var_name: Variable name as used in the dataset.
        norm_params: Dict with keys 'era5_mean', 'era5_sigma', 'lra5_mean',
            'lra5_sigma' and 'variable_mapping'.

    Returns:
        ``(mean, sigma)``, or ``(None, None)`` after printing a warning when no
        entry is found.
    """
    era5 = (norm_params['era5_mean'], norm_params['era5_sigma'])
    lra5 = (norm_params['lra5_mean'], norm_params['lra5_sigma'])
    aliases = norm_params['variable_mapping']

    # (key, (mean table, sigma table)) candidates in priority order.
    candidates = []
    if var_name in aliases:
        alias = aliases[var_name]
        candidates.append((alias, era5))
        candidates.append((alias, lra5))
    candidates.append((var_name, lra5))

    for key, (means, sigmas) in candidates:
        if key in means:
            return means[key], sigmas[key]

    print(f"Warning: No normalization parameters found for {var_name}")
    return None, None

def denormalize_tensor(normalized_data, variable_names, norm_params=None):
    """Undo z-score normalisation per variable: ``value * sigma + mean``.

    Axis 1 of ``normalized_data`` indexes variables, in the order given by
    ``variable_names``. Channels for which no statistics are found (and any
    channels beyond ``len(variable_names)``) are returned unchanged rather than
    zeroed, so a missing entry cannot silently turn both predictions and
    targets into zeros and report a perfect score.

    Args:
        normalized_data: ``torch.Tensor`` or array-like with at least 2 axes.
        variable_names: Variable name for each index along axis 1.
        norm_params: Output of ``load_normalization_parameters``; loaded from
            the default directory when ``None``.

    Returns:
        A new object of the same kind as the input: a tensor with the input's
        dtype and device, or a NumPy array. The input is not modified.
    """
    if norm_params is None:
        norm_params = load_normalization_parameters()

    is_tensor = torch.is_tensor(normalized_data)
    if is_tensor:
        data_np = normalized_data.detach().cpu().numpy()
    else:
        data_np = np.asarray(normalized_data)

    # Start from a copy (the CPU tensor view shares memory with the input).
    denormalized_data = data_np.copy()

    for i, var_name in enumerate(variable_names):
        mean, sigma = get_normalization_params_for_variable(var_name, norm_params)
        if mean is not None and sigma is not None:
            denormalized_data[:, i] = data_np[:, i] * sigma + mean

    if is_tensor:
        return torch.tensor(
            denormalized_data, device=normalized_data.device, dtype=normalized_data.dtype
        )
    return denormalized_data

def compute_denormalized_rmse(model, val_loader, climatology_dir):
    """Compute per-variable RMSE after mapping predictions and targets back through the climatology statistics.

    Squared errors are accumulated over every element of every batch and the
    RMSE is the square root of total squared error over total element count, so
    a smaller final batch is weighted correctly (the mean of per-batch RMSEs is
    not the dataset RMSE).

    Args:
        model: Trained module, already on the global ``device``.
        val_loader: Loader yielding dicts with 'input' and 'target' tensors whose
            axis 1 indexes the variables in ``VARIABLE_NAMES`` order.
        climatology_dir: Directory holding the normalisation CSVs.

    Returns:
        ``(total_rmse, overall_rmse)``: a dict of variable name -> RMSE (Python
        floats), and the unweighted mean of those values. The variables are
        reported on different scales, so the overall mean is dominated by the
        largest-magnitude ones.

    Raises:
        ValueError: If ``val_loader`` yields no data.
    """
    model.eval()
    norm_params = load_normalization_parameters(climatology_dir)

    num_vars = len(VARIABLE_NAMES)
    squared_error_sum = np.zeros(num_vars, dtype=np.float64)
    elements_per_var = 0

    with torch.no_grad():
        for batch in val_loader:
            inputs = batch['input'].to(device)
            targets = batch['target'].to(device)
            outputs = denormalize_tensor(model(inputs), VARIABLE_NAMES, norm_params)
            targets = denormalize_tensor(targets, VARIABLE_NAMES, norm_params)

            # Subtract in float64: denormalised values can be large relative to the error.
            squared_error = (outputs.double() - targets.double()) ** 2
            # Bring the variable axis to the front and flatten everything else.
            per_var = squared_error.transpose(0, 1).reshape(squared_error.size(1), -1)
            squared_error_sum += per_var.sum(dim=1)[:num_vars].cpu().numpy()
            elements_per_var += per_var.size(1)

    if elements_per_var == 0:
        raise ValueError("val_loader yielded no data; cannot compute RMSE")

    total_rmse = {
        var_name: float(np.sqrt(squared_error_sum[i] / elements_per_var))
        for i, var_name in enumerate(VARIABLE_NAMES)
    }

    print("\nDenormalized RMSE Per Variable:")
    for var_name, rmse in total_rmse.items():
        print(f"  {var_name}: {rmse:.6f}")

    overall_rmse = np.mean(list(total_rmse.values()))
    print(f"\nOverall Mean RMSE: {overall_rmse:.6f}")

    return total_rmse, overall_rmse

# Compute Denormalized RMSE for GraphCast
# Relies on `model` and `val_loader` created in the training cell; on a fresh kernel that
# cell must run first.
# NOTE(review): the checkpoint name says "44days" while OUTPUT_DAYS is 14 in this notebook,
# and the earlier RMSE cell loads a different ("10days") checkpoint - confirm which
# checkpoint/horizon pair these numbers are meant to describe.
best_model_path = "/content/drive/MyDrive/Best_models/graphcast_30epoch_44days.pth"
model.load_state_dict(torch.load(best_model_path, map_location=device))
compute_denormalized_rmse(model, val_loader, climatology_dir="/content/drive/MyDrive/climatology")



Denormalized RMSE Per Variable:
  z500: 103.674919
  t850: 7.000451
  e: 0.000958
  evavt: 0.000464
  lai_lv: 0.318253
  pev: 0.005300
  swvl2: 0.068230
  swvl3: 0.067477
  t2m: 17.887964
  tp: 0.006203

Overall Mean RMSE: 12.903023


({'z500': np.float32(103.67492),
  't850': np.float32(7.0004506),
  'e': np.float32(0.0009583947),
  'evavt': np.float32(0.00046412705),
  'lai_lv': np.float32(0.3182531),
  'pev': np.float32(0.0052996804),
  'swvl2': np.float32(0.06823022),
  'swvl3': np.float32(0.06747739),
  't2m': np.float32(17.887964),
  'tp': np.float32(0.0062031457)},
 np.float32(12.903023))