In [None]:
# Use the %pip magic (not !pip) so the packages are installed into the
# environment of the running kernel; -q keeps the cell output short.
%pip install -q cartopy torch-geometric

In [None]:
import os
import gc
import torch
import torch.nn as nn
import numpy as np
import xarray as xr
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from torch_geometric.loader import DataLoader as PyGDataLoader
from torch_geometric.nn import GATConv
from torch_geometric.data import Data, Dataset
import matplotlib.pyplot as plt
from tqdm import tqdm
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from scipy.stats import pearsonr

# Allocator setting for PyTorch's CUDA memory manager, passed through the
# environment. NOTE(review): presumably this has to be in place before the
# first CUDA allocation to take effect - confirm.
os.environ.update(PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True')

# Pick the compute device: the GPU when PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# SphericalDataset class
class SphericalDataset(Dataset):
    """One-step-ahead graph samples built from a gridded ``tas`` field.

    The NetCDF file is expected to provide ``lat``, ``lon`` and ``time``
    coordinates and a ``tas`` variable (assumed to be laid out as
    ``(time, lat, lon)`` - TODO confirm against the input file). Every grid
    cell becomes a graph node with index ``i * nlon + j``; sample ``idx`` uses
    the field at timestep ``idx`` as node features and the field at timestep
    ``idx + 1`` as the target.

    Attributes:
        scaler: Fitted ``StandardScaler`` when ``normalize=True``, otherwise
            ``None``.
        edge_index: ``(2, num_edges)`` long tensor shared by all samples.
        grid_points: ``(nlat * nlon, 2)`` array of ``(lon, lat)`` per node.
    """

    def __init__(self, file_path, transform=None, normalize=True):
        """Load the file, build the grid graph and optionally standardise ``tas``.

        Args:
            file_path: Path handed to ``xarray.open_dataset``.
            transform: Optional callable applied to every ``Data`` sample.
            normalize: When True, standardise ``tas`` per grid cell over time.
        """
        self.transform = transform
        # Defined up front so callers can test `dataset.scaler is None`
        # instead of hitting AttributeError when normalize=False.
        self.scaler = None
        print(f"Loading data from {file_path}")
        self.ds = xr.open_dataset(file_path)
        self.lats = self.ds.lat.values
        self.lons = self.ds.lon.values
        self.times = self.ds.time.values
        self.tas = self.ds.tas.values
        self.nlat = len(self.lats)
        self.nlon = len(self.lons)
        lon_grid, lat_grid = np.meshgrid(self.lons, self.lats)
        self.grid_points = np.stack([lon_grid.flatten(), lat_grid.flatten()], axis=1)
        # Node coordinates never change between samples, so build the tensor
        # once. Like edge_index, it is shared by every sample that is returned.
        self._pos_tensor = torch.tensor(self.grid_points, dtype=torch.float)
        self.edge_index = self._create_graph_connections()
        if normalize:
            self._normalize_features()
        print(f"Dataset created with {len(self.times)} timesteps, grid size: {self.nlat}x{self.nlon}")

    def _normalize_features(self):
        """Standardise ``self.tas`` in place and keep the fitted scaler.

        The field is viewed as ``(n_times, n_nodes)``, so the scaler learns one
        mean and standard deviation per grid cell, fitted on every timestep in
        the file.
        """
        data_reshaped = self.tas.reshape(len(self.times), -1)
        self.scaler = StandardScaler()
        self.scaler.fit(data_reshaped)
        normalized_data = self.scaler.transform(data_reshaped)
        self.tas = normalized_data.reshape(self.tas.shape)
        print("Data normalized")

    def _create_graph_connections(self):
        """Build the directed edge list of the latitude-longitude grid.

        Each node is linked to its northern and southern neighbours where they
        exist (no wrap across the first/last latitude row) and to its western
        and eastern neighbours with wrap-around in longitude.

        Returns:
            A contiguous ``(2, num_edges)`` ``torch.long`` tensor.
        """
        edges = []
        for i in range(self.nlat):
            for j in range(self.nlon):
                node_idx = i * self.nlon + j
                neighbors = []
                if i > 0:
                    neighbors.append((i - 1) * self.nlon + j)
                if i < self.nlat - 1:
                    neighbors.append((i + 1) * self.nlon + j)
                # The modulo closes the grid in the longitude direction.
                west_j = (j - 1) % self.nlon
                neighbors.append(i * self.nlon + west_j)
                east_j = (j + 1) % self.nlon
                neighbors.append(i * self.nlon + east_j)
                for neighbor in neighbors:
                    edges.append([node_idx, neighbor])
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
        return edge_index

    def __len__(self):
        """Number of (input, target) pairs: one fewer than the timesteps."""
        return len(self.times) - 1

    def __getitem__(self, idx):
        """Return the sample whose input is timestep ``idx``.

        Negative indices count from the end, as for a sequence.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        n_samples = len(self)
        requested = idx
        if idx < 0:
            idx += n_samples
        # Without this check a negative index would silently pair the last
        # field with the first one as its "next step" target.
        if not 0 <= idx < n_samples:
            raise IndexError(
                f"index {requested} is out of range for a dataset with {n_samples} samples"
            )
        x = self.tas[idx].reshape(-1, 1).astype(np.float32)
        y = self.tas[idx + 1].reshape(-1, 1).astype(np.float32)
        data = Data(x=torch.tensor(x, dtype=torch.float), y=torch.tensor(y, dtype=torch.float),
                    edge_index=self.edge_index, pos=self._pos_tensor)
        if self.transform:
            data = self.transform(data)
        return data

# SphericalGNN class
class SphericalGNN(nn.Module):
    """Stack of graph-attention layers between a linear encoder and decoder.

    Node features are embedded to ``hidden_dim``, passed through GATConv
    layers (each followed by batch norm, ELU and dropout) and projected to
    ``output_dim``. All but the last attention layer use ``heads`` heads, so
    they emit ``hidden_dim * heads`` channels; the last one uses a single head
    and emits ``hidden_dim`` channels.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=3, heads=4, dropout=0.1):
        super().__init__()
        self.num_layers = num_layers
        self.embedding = nn.Linear(input_dim, hidden_dim)

        multi_head_width = hidden_dim * heads
        # (input channels, number of heads) for every attention layer, in order:
        # first layer, `num_layers - 2` middle layers, single-head last layer.
        layer_specs = [(hidden_dim, heads)]
        layer_specs.extend([(multi_head_width, heads)] * (num_layers - 2))
        layer_specs.append((multi_head_width, 1))

        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        for in_channels, n_heads in layer_specs:
            self.convs.append(GATConv(in_channels, hidden_dim, heads=n_heads, dropout=dropout))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim * n_heads))

        self.output_layer = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.ELU()

    def forward(self, data):
        """Map a graph (``data.x``, ``data.edge_index``) to per-node outputs."""
        node_features = data.x
        edge_index = data.edge_index
        hidden = self.embedding(node_features)
        for layer in range(self.num_layers):
            hidden = self.convs[layer](hidden, edge_index)
            hidden = self.batch_norms[layer](hidden)
            hidden = self.dropout(self.act(hidden))
        return self.output_layer(hidden)

# Evaluation function (fixed)
def memory_efficient_evaluate(model, test_loader, criterion, dataset, device='cuda'):
    """Evaluate ``model`` on ``test_loader`` and compute error and correlation metrics.

    Batches are moved to ``device`` one at a time and their predictions and
    targets are copied to host NumPy arrays immediately. All metrics except
    ``test_loss`` are computed after undoing the dataset's standardisation
    (skipped when the dataset has no scaler, i.e. the data was never
    normalised).

    Args:
        model: Module called as ``model(batch)``; must return one value per node.
        test_loader: Iterable of batches exposing ``.to(device)`` and ``.y``.
            Every graph must contain ``dataset.nlat * dataset.nlon`` nodes.
        criterion: Loss called as ``criterion(outputs, batch.y)`` on the raw
            model outputs.
        dataset: Object providing ``nlat``, ``nlon``, ``lats``, ``lons`` and,
            optionally, a fitted ``scaler`` with ``inverse_transform``.
        device: Device the batches are moved to.

    Returns:
        dict with scalar metrics (``test_loss``, ``mse``, ``rmse``, ``mae``,
        ``wmae``, ``avg_scc``, ``avg_tcc``, ``acc``, ``rmse_grad``, ``ece``),
        ``(nlat, nlon)`` maps (``abs_errors``, ``tcc_map``), the
        ``(n_samples, n_nodes)`` arrays ``predictions`` and ``targets``, and
        the ``lats`` / ``lons`` coordinates.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()
    test_losses = []
    all_predictions = []
    all_targets = []

    test_pbar = tqdm(test_loader, desc='Evaluating')

    with torch.no_grad():
        for batch in test_pbar:
            batch = batch.to(device)
            outputs = model(batch)
            loss = criterion(outputs, batch.y)
            loss_value = loss.item()
            all_predictions.append(outputs.cpu().numpy())
            all_targets.append(batch.y.cpu().numpy())
            test_losses.append(loss_value)
            test_pbar.set_postfix({'loss': f'{loss_value:.4f}'})
            # Drop device references before the next batch is loaded.
            del outputs, loss, batch
            torch.cuda.empty_cache()

    if not all_predictions:
        raise ValueError('test_loader yielded no batches; nothing to evaluate')

    # Grid information
    nlat, nlon = dataset.nlat, dataset.nlon
    n_nodes = nlat * nlon
    lats = dataset.lats
    lons = dataset.lons

    # Each batch is (batch_nodes, 1) with the nodes of one graph stored
    # contiguously, so the concatenation reshapes to (n_samples, n_nodes).
    predictions = np.concatenate(all_predictions).reshape(-1, n_nodes)
    targets = np.concatenate(all_targets).reshape(-1, n_nodes)

    # Undo the per-node standardisation. A dataset built with normalize=False
    # has no fitted scaler and its values are already in physical units.
    scaler = getattr(dataset, 'scaler', None)
    if scaler is not None:
        predictions = scaler.inverse_transform(predictions)
        targets = scaler.inverse_transform(targets)

    # cos(latitude) area weights, one per node in the flattened (lat, lon) order.
    weights = np.cos(np.deg2rad(lats))[:, np.newaxis] * np.ones((1, nlon))
    weights = weights.flatten()

    # Standard metrics
    errors = predictions - targets
    abs_err = np.abs(errors)
    avg_test_loss = np.mean(test_losses)
    mse = np.mean(errors ** 2)
    rmse = np.sqrt(mse)
    mae = np.mean(abs_err)

    # Weighted MAE
    wmae = np.sum(weights * abs_err, axis=1) / np.sum(weights)
    wmae = np.mean(wmae)

    # Spatial Correlation Coefficient (SCC): one correlation per sample.
    scc = [pearsonr(predictions[t], targets[t])[0] for t in range(predictions.shape[0])]
    avg_scc = np.mean(scc)

    # Temporal Correlation Coefficient (TCC): one correlation per node,
    # computed once and reused for both the average and the map.
    tcc_values = np.array([pearsonr(predictions[:, i], targets[:, i])[0]
                           for i in range(n_nodes)])
    tcc_valid = ~np.isnan(tcc_values)
    avg_tcc = np.mean(tcc_values[tcc_valid])

    # Anomaly Correlation Coefficient (ACC); anomalies are relative to the
    # per-node mean over the evaluated samples.
    anomalies_true = targets - np.mean(targets, axis=0, keepdims=True)
    anomalies_pred = predictions - np.mean(predictions, axis=0, keepdims=True)
    acc_num = np.sum(anomalies_true * anomalies_pred, axis=1)
    acc_den = np.sqrt(np.sum(anomalies_true**2, axis=1) * np.sum(anomalies_pred**2, axis=1))
    with np.errstate(divide='ignore', invalid='ignore'):
        acc_ratio = acc_num / acc_den
    # Samples with a zero denominator are excluded from the average.
    acc = np.mean(acc_ratio, where=acc_den > 0)

    # RMSE of Gradients (latitude direction). The fields are put back on the
    # (lat, lon) grid first: differencing the flattened node axis would compare
    # longitude neighbours and pair the end of one row with the start of the next.
    grad_true = np.diff(targets.reshape(-1, nlat, nlon), axis=1)
    grad_pred = np.diff(predictions.reshape(-1, nlat, nlon), axis=1)
    rmse_grad = np.sqrt(np.mean((grad_true - grad_pred)**2))

    # Energy Conservation Error: relative error of the area-weighted field sum.
    energy_true = np.sum(weights * targets, axis=1)
    energy_pred = np.sum(weights * predictions, axis=1)
    ece = np.mean(np.abs(energy_true - energy_pred) / np.abs(energy_true))

    # Reshape for spatial maps; undefined correlations are shown as 0.
    abs_errors = np.mean(abs_err, axis=0).reshape(nlat, nlon)
    tcc_map = np.where(tcc_valid, tcc_values, 0).reshape(nlat, nlon)

    print(f'Test Loss: {avg_test_loss:.4f}')
    print(f'MSE: {mse:.4f} K², RMSE: {rmse:.4f} K, MAE: {mae:.4f} K')
    print(f'Weighted MAE: {wmae:.4f} K')
    print(f'Average Spatial Correlation: {avg_scc:.4f}')
    print(f'Average Temporal Correlation: {avg_tcc:.4f}')
    print(f'Anomaly Correlation: {acc:.4f}')
    print(f'RMSE of Gradients: {rmse_grad:.4f} K/node')
    print(f'Energy Conservation Error: {ece:.4f}')

    results = {
        'test_loss': avg_test_loss,
        'mse': mse,
        'rmse': rmse,
        'mae': mae,
        'wmae': wmae,
        'avg_scc': avg_scc,
        'avg_tcc': avg_tcc,
        'acc': acc,
        'rmse_grad': rmse_grad,
        'ece': ece,
        'abs_errors': abs_errors,
        'tcc_map': tcc_map,
        'predictions': predictions,
        'targets': targets,
        'lats': lats,
        'lons': lons
    }

    return results

# Visualization function
def plot_evaluation_results(results, output_dir='/kaggle/working/plots'):
    """Write the four evaluation figures as PNG files into ``output_dir``.

    Reads ``lats``, ``lons``, ``abs_errors``, ``tcc_map``, ``predictions`` and
    ``targets`` from ``results`` and saves ``error_map.png``, ``tcc_map.png``,
    ``anomaly_series.png`` and ``error_histogram.png``. The directory is
    created when missing; nothing is returned.
    """
    os.makedirs(output_dir, exist_ok=True)

    lats = results['lats']
    lons = results['lons']
    error_field = results['abs_errors']
    correlation_field = results['tcc_map']
    predictions = results['predictions']
    targets = results['targets']

    projection = ccrs.PlateCarree()

    def save_and_close(filename):
        # Write the current figure to disk, then release it.
        plt.savefig(os.path.join(output_dir, filename), bbox_inches='tight', dpi=300)
        plt.close()

    def draw_map(field, filename, title, colorbar_label, **mesh_options):
        # One filled lat/lon map with coastlines and borders drawn on top.
        plt.figure(figsize=(12, 6))
        ax = plt.axes(projection=projection)
        ax.add_feature(cfeature.COASTLINE)
        ax.add_feature(cfeature.BORDERS, linestyle=':')
        mesh = ax.pcolormesh(lons, lats, field, transform=projection, **mesh_options)
        plt.colorbar(mesh, label=colorbar_label)
        ax.set_title(title)
        save_and_close(filename)

    draw_map(error_field, 'error_map.png',
             'Spatial Distribution of Prediction Errors',
             'Mean Absolute Error (K)',
             cmap='Reds', vmin=0)
    draw_map(correlation_field, 'tcc_map.png',
             'Temporal Correlation of Predictions',
             'Temporal Correlation Coefficient',
             cmap='viridis', vmin=-1, vmax=1)

    # Area-weighted (cos latitude) global means, one value per sample, shown
    # as departures from their own average.
    weights = np.outer(np.cos(np.deg2rad(lats)), np.ones(len(lons))).flatten()
    weight_total = np.sum(weights)
    global_mean_true = np.sum(targets * weights, axis=1) / weight_total
    global_mean_pred = np.sum(predictions * weights, axis=1) / weight_total
    series = (
        (global_mean_true - np.mean(global_mean_true), 'True Anomalies'),
        (global_mean_pred - np.mean(global_mean_pred), 'Predicted Anomalies'),
    )

    plt.figure(figsize=(10, 4))
    for values, label in series:
        plt.plot(values, label=label, alpha=0.7)
    plt.xlabel('Timestep (Month)')
    plt.ylabel('Global Mean Temperature Anomaly (K)')
    plt.title('Global Temperature Anomalies (1850-2014)')
    plt.legend()
    save_and_close('anomaly_series.png')

    # Histogram of every per-node, per-sample error.
    plt.figure(figsize=(8, 4))
    plt.hist((predictions - targets).flatten(), bins=50, density=True, alpha=0.7)
    plt.xlabel('Prediction Error (K)')
    plt.ylabel('Density')
    plt.title('Distribution of Prediction Errors')
    save_and_close('error_histogram.png')

# Main function to run evaluation
def run_evaluation(dataset_path, model_path, batch_size=8, test_size=0.15, device='cuda'):
    """Load the data and a trained checkpoint, evaluate, plot and save the results.

    Args:
        dataset_path: NetCDF file passed to ``SphericalDataset``.
        model_path: Checkpoint readable by ``torch.load``. Either a dict with
            ``'model_state_dict'`` and an optional ``'config'`` dict, or a bare
            state dict.
        batch_size: Number of graphs per evaluation batch.
        test_size: Held-out fraction handed to ``train_test_split``
            (``random_state=42``).
        device: Device used for the model and the batches.

    Returns:
        The metrics dict from ``memory_efficient_evaluate`` plus
        ``'test_indices'``: the dataset index behind each row of
        ``predictions`` / ``targets``. The same dict is written to
        ``/kaggle/working/evaluation_results.pt``.
    """
    # Local import: Subset is only needed here.
    from torch.utils.data import Subset

    gc.collect()
    torch.cuda.empty_cache()

    dataset = SphericalDataset(dataset_path, normalize=True)

    num_samples = len(dataset)
    indices = list(range(num_samples))
    _, test_indices = train_test_split(indices, test_size=test_size, random_state=42)
    # train_test_split returns the held-out indices shuffled. Sorting them keeps
    # the evaluated samples in dataset order, so the per-sample series plotted
    # against "Timestep" is chronological.
    test_indices = sorted(test_indices)

    # Subset builds each graph on demand instead of holding every test sample
    # in a Python list for the whole evaluation.
    test_loader = PyGDataLoader(
        Subset(dataset, test_indices),
        batch_size=batch_size,
        shuffle=False
    )

    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
    checkpoint = torch.load(model_path, map_location=device)

    # Missing hyper-parameters fall back to these defaults, whether the whole
    # 'config' entry is absent or only some of its keys.
    default_config = {
        'hidden_dim': 32,
        'num_layers': 2,
        'attention_heads': 2,
        'dropout': 0.2
    }
    config = {**default_config, **(checkpoint.get('config') or {})}

    model = SphericalGNN(
        input_dim=1,
        hidden_dim=config['hidden_dim'],
        output_dim=1,
        num_layers=config['num_layers'],
        heads=config['attention_heads'],
        dropout=config['dropout']
    ).to(device)

    # Accept a checkpoint that is itself the state dict.
    state_dict = checkpoint.get('model_state_dict', checkpoint)
    model.load_state_dict(state_dict)
    model.eval()

    criterion = nn.MSELoss()

    results = memory_efficient_evaluate(model, test_loader, criterion, dataset, device)
    results['test_indices'] = np.asarray(test_indices)
    plot_evaluation_results(results)

    torch.save(results, '/kaggle/working/evaluation_results.pt')

    return results

In [None]:
# Input files under /kaggle/input: the NetCDF data file and the saved checkpoint.
dataset_path = '/kaggle/input/cimp6-fraction/cmip6_tas_cesm2_historical.nc'
model_path = '/kaggle/input/gnn-model/spherical_gnn_model.pt'

# Evaluate on the held-out 15% of the samples, 8 graphs per batch.
results = run_evaluation(
    dataset_path=dataset_path,
    model_path=model_path,
    batch_size=8,
    test_size=0.15,
    device=device,
)