In [None]:
pip install rasterio geopandas

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
import matplotlib.pyplot as plt
import os
import joblib
import math
from tqdm import tqdm
import datetime
import rasterio # Import the rasterio library

# --- Configuration shared by the prediction helpers in this cell ---
MODEL_NAME = "GAT_Transformer_Prediction"  # label used in output directory and plot names
LOOKBACK = 24  # input window length; also sizes the model's positional-encoding table
BATCH_SIZE = 128  # DataLoader batch size used for inference
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Re-define necessary classes and functions from your training notebook ---
# These are crucial for the model and data loading/processing to work.

# Positional Encoding
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence.

    The table is precomputed for ``max_len`` positions and registered as the
    buffer ``pe`` with shape (1, max_len, d_model), so it is saved in the
    module's ``state_dict``. Even feature columns hold ``sin`` terms and odd
    columns hold ``cos`` terms. Odd ``d_model`` values are supported.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are ceil(d_model/2) even columns but only floor(d_model/2) odd
        # ones; trim the frequencies so an odd d_model does not raise.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x + pe`` for ``x`` shaped (batch, seq_len, d_model).

        Raises:
            ValueError: if ``seq_len`` exceeds the precomputed ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"Sequence length {seq_len} exceeds PositionalEncoding max_len {self.pe.size(1)}"
            )
        return x + self.pe[:, :seq_len]

# Graph Attention Layer
class GraphAttentionLayer(nn.Module):
    """Single-head graph attention over a fixed node graph.

    ``forward(h, adj)`` takes node features ``h`` of shape (B, N, F_in) and an
    (N, N) adjacency. Pair scores ``LeakyReLU(a^T [Wh_i || Wh_j])`` are
    computed for every node pair. Pairs where ``adj`` is not > 0 receive a
    -9e15 logit so softmax (over j) gives them ~0 weight. Dropout is applied
    to the attention weights. The result is ``ELU(attention @ Wh)`` with
    shape (B, N, F_out).
    """

    def __init__(self, in_features, out_features, dropout=0.2, alpha=0.2):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.dropout_val = dropout
        self.alpha = alpha
        # Shared node projection F_in -> F_out.
        self.W = nn.Parameter(torch.empty(in_features, out_features))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        # Scoring vector applied to concatenated [Wh_i || Wh_j] pairs.
        self.a = nn.Parameter(torch.empty(2 * out_features, 1))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)
        self.leakyrelu = nn.LeakyReLU(alpha)
        self.dropout_layer = nn.Dropout(dropout)

    def forward(self, h, adj):
        projected = h @ self.W
        pair_features = self._prepare_attention_input(projected)
        scores = self.leakyrelu((pair_features @ self.a).squeeze(-1))
        is_edge = adj.unsqueeze(0) > 0
        scores = torch.where(is_edge, scores, torch.full_like(scores, -9e15))
        weights = self.dropout_layer(torch.softmax(scores, dim=-1))
        return F.elu(weights @ projected)

    def _prepare_attention_input(self, Wh):
        """Return all (i, j) pairs as [Wh_i || Wh_j], shape (B, N, N, 2*F_out)."""
        batch, nodes, _ = Wh.size()
        left = Wh[:, :, None, :].expand(batch, nodes, nodes, -1)
        right = Wh[:, None, :, :].expand(batch, nodes, nodes, -1)
        return torch.cat((left, right), dim=-1)

# MultiHeadGraphAttention
class MultiHeadGraphAttention(nn.Module):
    """Run ``n_heads`` independent ``GraphAttentionLayer`` heads on the same input.

    With ``concat=True`` head outputs are joined on the feature axis
    (``out_dim = out_features_per_head * n_heads``). Otherwise they are
    averaged (``out_dim = out_features_per_head``).
    """

    def __init__(self, in_features, out_features_per_head, n_heads, dropout=0.2, alpha=0.2, concat=True):
        super().__init__()
        self.n_heads = n_heads
        self.concat = concat
        heads = []
        for _ in range(n_heads):
            heads.append(GraphAttentionLayer(in_features, out_features_per_head, dropout, alpha))
        self.attentions = nn.ModuleList(heads)
        if concat:
            self.out_dim = out_features_per_head * n_heads
        else:
            self.out_dim = out_features_per_head

    def forward(self, x, adj):
        """x: (B, N, F_in); adj: (N, N). Returns (B, N, out_dim)."""
        per_head = [head(x, adj) for head in self.attentions]
        if not self.concat:
            return torch.stack(per_head, dim=-1).mean(dim=-1)
        return torch.cat(per_head, dim=-1)

# GAT+Transformer Model
class SpiPredictorGATTransformer(nn.Module):
    """GAT over the time steps of a window, then a Transformer encoder, then an MLP head.

    ``forward(x, adj)`` expects ``x`` of shape (B, L, input_size) and an (L, L)
    adjacency between time steps. It returns one value per sample, shape
    (B, 1), read from the encoder output at the last time step.

    The encoder width is the GAT output width
    ``(d_model_gat // gat_heads) * gat_heads``. ``d_model_transformer`` is
    accepted for signature compatibility but is not used.
    """

    def __init__(self, input_size, d_model_gat=64, gat_heads=4,
                 d_model_transformer=64, transformer_heads=4, num_transformer_layers=2,
                 dim_feedforward_transformer=256, dropout_rate=0.2):
        super().__init__()
        self.dropout_rate = dropout_rate
        self.gat_input_proj = nn.Linear(input_size, d_model_gat)
        self.gat_attention = MultiHeadGraphAttention(
            in_features=d_model_gat,
            out_features_per_head=d_model_gat // gat_heads,
            n_heads=gat_heads,
            dropout=dropout_rate,
            concat=True,
        )
        embed_dim = self.gat_attention.out_dim
        # Table length must stay LOOKBACK + 10: the 'pe' buffer is part of saved checkpoints.
        self.pos_encoder = PositionalEncoding(embed_dim, max_len=LOOKBACK + 10)
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=transformer_heads,
                dim_feedforward=dim_feedforward_transformer,
                dropout=dropout_rate,
                batch_first=True,
            ),
            num_layers=num_transformer_layers,
        )
        head_layers = [
            nn.Linear(embed_dim, 64), nn.ReLU(), nn.Dropout(dropout_rate),
            nn.Linear(64, 32), nn.ReLU(), nn.Dropout(dropout_rate),
            nn.Linear(32, 1),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x, adj):
        hidden = F.elu(self.gat_input_proj(x))
        hidden = F.dropout(hidden, self.dropout_rate, training=self.training)
        hidden = self.gat_attention(hidden, adj)
        hidden = self.transformer_encoder(self.pos_encoder(hidden))
        return self.fc(hidden[:, -1, :])

# TimeSeriesDataset for GAT (Optimized Adjacency)
class TimeSeriesDatasetGAT(Dataset):
    """Dataset of feature windows that all share one adjacency matrix.

    Each item is ``(features[idx], single_adj_matrix, targets[idx])`` as
    float32 tensors. The same adjacency tensor is returned for every index.

    Scaling is decided separately for features and targets:

    * a scaler is passed            -> its ``transform`` is applied (never refit);
    * no scaler and ``is_train``    -> a new ``StandardScaler`` is fitted;
    * no scaler and not ``is_train``-> values are kept unscaled.

    Before the previous behaviour was fixed, passing a scaler together with
    ``is_train=True`` left the data unscaled while still storing that scaler.
    ``get_scalers()`` then returned scalers that did not match the data.

    Features are scaled per last-axis column: they are flattened to
    ``(-1, n_features)`` for the scaler and reshaped back afterwards. Targets
    are passed to the scaler as-is, so they must already be 2-D
    (e.g. ``(n_samples, 1)``).
    """

    def __init__(self, features, targets, single_adj_matrix, scaler_X=None, scaler_y=None, is_train=True):
        scaled_features, self.scaler_X = self._scale_features(features, scaler_X, is_train)
        scaled_targets, self.scaler_y = self._apply_scaler(targets, scaler_y, is_train)
        # torch.tensor copies, so later edits to the caller's arrays do not leak in.
        self.features = torch.tensor(np.asarray(scaled_features), dtype=torch.float32)
        self.targets = torch.tensor(np.asarray(scaled_targets), dtype=torch.float32)
        if isinstance(single_adj_matrix, torch.Tensor):
            self.single_adj_matrix = single_adj_matrix.to(dtype=torch.float32)
        else:
            self.single_adj_matrix = torch.tensor(single_adj_matrix, dtype=torch.float32)

    @staticmethod
    def _apply_scaler(values, scaler, is_train):
        """Return ``(values_out, scaler_out)`` following the class scaling rules."""
        if scaler is not None:
            return scaler.transform(values), scaler
        if is_train:
            fitted = StandardScaler()
            return fitted.fit_transform(values), fitted
        return values, None

    @classmethod
    def _scale_features(cls, features, scaler, is_train):
        """Scale an (..., n_features) array column-wise, preserving its shape."""
        features = np.asarray(features)
        if scaler is None and not is_train:
            return features, None
        flat = features.reshape(-1, features.shape[-1])
        scaled, scaler = cls._apply_scaler(flat, scaler, is_train)
        return np.asarray(scaled).reshape(features.shape), scaler

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.single_adj_matrix, self.targets[idx]

    def get_scalers(self):
        """Return ``(scaler_X, scaler_y)``; either may be ``None`` if no scaling was applied."""
        return self.scaler_X, self.scaler_y

# Evaluation Function for GAT-based models
def evaluate_model_gat_based(model, test_loader, scaler_y, device, model_name="GAT-based Model"):
    """Predict over ``test_loader`` and report MSE/RMSE/MAE/R² in original units.

    Each batch must be ``(X, adj, y_scaled)``. A DataLoader over
    ``TimeSeriesDatasetGAT`` stacks the dataset's shared adjacency into
    (B, L, L). Every slice is identical, so only the first is passed to the
    model. A plain (L, L) adjacency is also accepted.

    Args:
        model: module called as ``model(X, adj)``; it is switched to eval mode.
        test_loader: iterable of batches as described above.
        scaler_y: fitted target scaler used to invert predictions and
            targets, or ``None`` if values are already in original units.
        device: device the model lives on.
        model_name: label for the progress bar and printed summary.

    Returns:
        ``(preds_o, tgts_o, mse, rmse, r2, mae)``; ``preds_o``/``tgts_o`` are
        the row-wise stacked batch outputs.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()
    pred_batches, target_batches = [], []
    with torch.no_grad():
        for batch_X, batch_adj, batch_y_s in tqdm(test_loader, desc=f"Evaluating {model_name}"):
            # Slice before the transfer so only one (L, L) matrix is copied to the device.
            adj = batch_adj[0] if batch_adj.dim() == 3 else batch_adj
            outputs_s = model(batch_X.to(device), adj.to(device))
            pred_batches.append(outputs_s.cpu().numpy())
            target_batches.append(batch_y_s.numpy())
    if not pred_batches:
        raise ValueError(f"No batches to evaluate for {model_name}.")
    preds_s_np = np.vstack(pred_batches)
    tgts_s_np = np.vstack(target_batches)
    # Explicit None check: a scaler object's truthiness is not a meaningful signal.
    if scaler_y is not None:
        preds_o = scaler_y.inverse_transform(preds_s_np)
        tgts_o = scaler_y.inverse_transform(tgts_s_np)
    else:
        preds_o, tgts_o = preds_s_np, tgts_s_np
    mse = mean_squared_error(tgts_o, preds_o)
    rmse = np.sqrt(mse)
    r2 = r2_score(tgts_o, preds_o)
    mae = mean_absolute_error(tgts_o, preds_o)
    print(f"\nOverall Evaluation metrics ({model_name}):\n MSE: {mse:.4f}, RMSE: {rmse:.4f}, MAE: {mae:.4f}, R²: {r2:.4f}")
    return preds_o, tgts_o, mse, rmse, r2, mae

# Plotting and Reconstruction Functions
def plot_all_test_samples(targets_original, predictions_original, model_name, save_dir='visualizations'):
    """Scatter predictions against true values, annotate metrics, and save the figure.

    Draws the y=x reference line. When the targets vary, it also draws a
    least-squares fit line and reports the Pearson correlation. The PNG is
    written to ``<save_dir>/all_test_samples_<model_name>.png``. The figure is
    then shown and closed, so repeated calls (one per country) do not
    accumulate open figures. Empty input prints a message and returns.
    """
    y_true = np.asarray(targets_original).ravel()
    y_pred = np.asarray(predictions_original).ravel()
    if y_true.size == 0:
        print(f"No samples to plot for {model_name}.")
        return

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, f'all_test_samples_{model_name}.png')

    fig = plt.figure(figsize=(12, 8))
    plt.scatter(y_true, y_pred, alpha=0.3, color='blue', label='Samples')
    min_val = min(y_true.min(), y_pred.min())
    max_val = max(y_true.max(), y_pred.max())
    plt.plot([min_val, max_val], [min_val, max_val], 'r--', label='y=x')

    corr = np.nan
    # A regression line (and a correlation) is only defined when the targets vary.
    if y_true.size > 1 and y_true.max() > y_true.min():
        slope, intercept = np.polyfit(y_true, y_pred, 1)
        xs = np.sort(y_true)
        plt.plot(xs, slope * xs + intercept, 'g-', label=f'Regression: y={slope:.3f}x+{intercept:.3f}')
        corr = np.corrcoef(y_true, y_pred)[0, 1]

    plt.xlabel('True Values')
    plt.ylabel('Predictions')
    plt.title(f'Predictions vs True Values ({model_name}, r = {corr:.3f})')
    plt.legend()
    plt.grid(True)

    mse = mean_squared_error(y_true, y_pred)
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)
    stats_text = f"MSE: {mse:.4f}\nRMSE: {rmse:.4f}\nMAE: {mae:.4f}\nR²: {r2:.4f}\nCorr: {corr:.4f}"
    plt.figtext(0.15, 0.75, stats_text, bbox=dict(facecolor='white', alpha=0.8, boxstyle='round,pad=0.5'))
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.savefig(save_path)
    plt.show()
    plt.close(fig)

def reconstruct_rasters(predictions_flat, test_sample_indices, original_mask, spi_full_shape):
    """Scatter per-pixel predictions back onto full (time, H, W) rasters.

    Args:
        predictions_flat: predictions aligned with ``test_sample_indices``;
            shape (n,) or (n, k) — only the first column is used.
        test_sample_indices: sequence of ``(t, pixel_idx_in_mask)`` pairs.
            ``pixel_idx_in_mask`` indexes the non-zero cells of
            ``original_mask`` in the row-major order of ``np.where``.
        original_mask: 2-D mask whose non-zero cells are the valid pixels.
        spi_full_shape: shape of the full SPI cube; items 1 and 2 give H and W.

    Returns:
        ``(rasters, unique_times)``. ``rasters`` is float64 with shape
        (len(unique_times), H, W) and is NaN where no prediction was given.
        ``unique_times`` is the sorted list of distinct ``t`` values, and
        ``rasters[k]`` belongs to ``unique_times[k]``.
    """
    unique_test_time_indices_original = sorted({t for t, _ in test_sample_indices})
    height, width = spi_full_shape[1], spi_full_shape[2]
    reconstructed_spatial_rasters = np.full(
        (len(unique_test_time_indices_original), height, width), np.nan
    )
    n_samples = len(test_sample_indices)
    if n_samples == 0:
        return reconstructed_spatial_rasters, unique_test_time_indices_original

    valid_rows, valid_cols = np.where(original_mask)
    times = np.array([t for t, _ in test_sample_indices])
    pixel_ids = np.array([p for _, p in test_sample_indices])
    # unique times are sorted, so searchsorted gives each sample's raster slot
    # in O(log T), replacing an O(T) list.index call per sample.
    slots = np.searchsorted(np.array(unique_test_time_indices_original), times)
    preds = np.asarray(predictions_flat)
    values = preds.reshape(preds.shape[0], -1)[:n_samples, 0]
    # (t, pixel) pairs are expected to be unique; if duplicated, NumPy does
    # not guarantee which of the repeated assignments wins.
    reconstructed_spatial_rasters[slots, valid_rows[pixel_ids], valid_cols[pixel_ids]] = values
    return reconstructed_spatial_rasters, unique_test_time_indices_original

def visualize_raster_samples(real_rasters_for_test_timesteps, pred_rasters_for_test_timesteps,
                             mask_for_visualization, unique_test_time_indices, model_name,
                             dates_list, num_samples_to_plot=5, output_dir='visualizations',
                             seed=None):
    """Save side-by-side true / predicted / difference SPI maps for random time steps.

    Up to ``num_samples_to_plot`` distinct rasters are chosen at random
    (``seed`` makes the choice reproducible; ``None`` keeps using the global
    ``np.random`` state). ``unique_test_time_indices[k]`` gives the original
    time index of raster ``k``; that index is looked up in ``dates_list`` for
    the title and file name. Cells outside ``mask_for_visualization`` are
    hidden. Each figure is saved to ``output_dir``, shown, and then closed.
    """
    from matplotlib.colors import LinearSegmentedColormap

    os.makedirs(output_dir, exist_ok=True)
    prefix = f'raster_sample_{model_name.lower()}'
    colors = ['#730000', '#E60000', '#FFAA00', '#FCD37F', '#FFFF00', '#FFFFFF', '#DCF8FF', '#96D2FF', '#46A5FF', '#0000FF', '#000080']
    cmap_usdm_spi = LinearSegmentedColormap.from_list('usdm_spi', colors, N=256)
    vmin_spi, vmax_spi = -3, 3
    # Cast to bool so `~` is a logical NOT even for 0/1 integer masks.
    outside_mask = ~np.asarray(mask_for_visualization, dtype=bool)

    num_available_rasters = len(real_rasters_for_test_timesteps)
    if num_available_rasters == 0:
        print("No rasters to visualize.")
        return
    n_plot = min(num_samples_to_plot, num_available_rasters)
    if seed is None:
        plot_indices = np.random.choice(num_available_rasters, n_plot, replace=False)
    else:
        plot_indices = np.random.default_rng(seed).choice(num_available_rasters, n_plot, replace=False)

    for raster_idx in plot_indices:
        original_time_index = unique_test_time_indices[raster_idx]
        date_str = dates_list[original_time_index].strftime('%Y-%m')

        true_slice = real_rasters_for_test_timesteps[raster_idx]
        pred_slice = pred_rasters_for_test_timesteps[raster_idx]
        diff_slice = np.full_like(true_slice, np.nan)
        both_valid = ~np.isnan(true_slice) & ~np.isnan(pred_slice)
        diff_slice[both_valid] = true_slice[both_valid] - pred_slice[both_valid]
        diff_abs_max = np.nanmax(np.abs(diff_slice)) if both_valid.any() else 1.0
        if not diff_abs_max > 0:
            diff_abs_max = 1.0  # a symmetric colour range needs a non-zero width

        fig, axes = plt.subplots(1, 3, figsize=(18, 6))
        fig.suptitle(f'SPI Comparison ({model_name}) - Time: {date_str} (Original Index: {original_time_index})', fontsize=16)
        panels = (
            (true_slice, 'True SPI', cmap_usdm_spi, vmin_spi, vmax_spi, 'SPI Value'),
            (pred_slice, 'Predicted SPI', cmap_usdm_spi, vmin_spi, vmax_spi, 'SPI Value'),
            (diff_slice, 'Difference', 'RdBu_r', -diff_abs_max, diff_abs_max, 'Difference (True - Predicted)'),
        )
        for ax, (data, title, cmap, vmin, vmax, label) in zip(axes, panels):
            im = ax.imshow(np.ma.array(data, mask=outside_mask), cmap=cmap, vmin=vmin, vmax=vmax)
            fig.colorbar(im, ax=ax, label=label, fraction=0.046, pad=0.04)
            ax.set_title(title)
            ax.axis('off')

        plt.tight_layout(rect=[0, 0, 1, 0.95])
        save_filename = os.path.join(output_dir, f'{prefix}_time_{date_str.replace("-", "_")}.png')
        plt.savefig(save_filename, dpi=200)
        plt.show()
        plt.close(fig)

# --- NEW: Function to save raster as GeoTIFF ---
def save_raster_as_geotiff(raster_data, ref_geotiff_path, out_path):
    """Write a single-band float32 GeoTIFF that reuses a reference file's profile.

    The reference file's rasterio profile is copied, with dtype (float32),
    band count (1), compression (LZW) and nodata (-9999.0) overridden. NaN
    cells are written as the nodata value. This is a best-effort save: any
    exception is reported with a printed message instead of being raised.
    """
    nodata_val = -9999.0  # sentinel chosen to be unlikely as a real prediction
    try:
        with rasterio.open(ref_geotiff_path) as ref_ds:
            out_profile = ref_ds.profile
        filled = np.nan_to_num(raster_data, nan=nodata_val)
        out_profile.update(dtype=rasterio.float32, count=1, compress='lzw', nodata=nodata_val)
        with rasterio.open(out_path, 'w', **out_profile) as out_ds:
            out_ds.write(filled.astype(rasterio.float32), 1)
    except Exception as exc:
        print(f"ERROR: Could not save GeoTIFF {out_path}. Reason: {exc}")
    else:
        print(f"Successfully saved GeoTIFF: {out_path}")


# --- load_data_for_prediction (Modified for selecting time steps) ---
def load_data_for_prediction(imputed_file_path, lookback_period, selected_time_steps=None):
    """Build per-pixel input windows and SPI targets from an imputed ``.npz`` archive.

    The archive must contain ``NDVI_imputed``, ``SoilMoisture_imputed``,
    ``LST_imputed`` and ``SPI_original``, each indexed ``[t, row, col]``, plus
    a 2-D ``mask`` selecting the pixels to use.

    Target steps are every ``t >= lookback_period`` by default, or the
    in-range subset of ``selected_time_steps``. For each target step ``t``
    and each masked pixel whose SPI at ``t`` is not NaN, one sample is built:

    * X: the steps ``t - lookback_period .. t - 1``, with features ordered
      ``[NDVI, LST, SoilMoisture]``;
    * y: SPI at ``t``.

    Samples are ordered by time step, then by pixel position within the mask.

    Returns:
        X: array of shape (n_samples, lookback_period, 3). This is
            (0, lookback_period, 3) when there are no samples.
        y: array of shape (n_samples, 1).
        mask: the mask as a boolean array.
        sample_indices_output: list of ``(t, pixel_idx_in_mask)``, one per sample.
        original_data_info: dict with ``SPI_full_shape``, ``SPI_original_full``
            and ``time_steps_total``.
    """
    print(f"Loading data for prediction from: {imputed_file_path}")
    # The context manager closes the underlying zip file once the arrays are read.
    with np.load(imputed_file_path) as data_imputed:
        ndvi_data = data_imputed['NDVI_imputed']
        soil_moisture_data = data_imputed['SoilMoisture_imputed']
        lst_data = data_imputed['LST_imputed']
        spi_data = data_imputed['SPI_original']
        # Boolean indexing below needs a real bool mask; a 0/1 integer mask
        # would silently turn into integer (fancy) indexing.
        mask = data_imputed['mask'].astype(bool)

    print(f"Loaded imputed feature shapes: NDVI: {ndvi_data.shape}, SM: {soil_moisture_data.shape}, LST: {lst_data.shape}")
    print(f"Loaded SPI_original shape: {spi_data.shape}, Mask shape: {mask.shape}, Valid pixels in mask: {np.sum(mask)}")

    time_steps_total, _, _ = ndvi_data.shape
    # (T, P, 3) over the P masked pixels; replaces per-pixel, per-step Python loops.
    features_masked = np.stack(
        [ndvi_data[:, mask], lst_data[:, mask], soil_moisture_data[:, mask]], axis=-1
    )
    spi_masked = spi_data[:, mask]  # (T, P)

    time_steps_to_process = range(lookback_period, time_steps_total)
    if selected_time_steps is not None:
        time_steps_to_process = [t for t in selected_time_steps if t >= lookback_period and t < time_steps_total]
        print(f"Predicting for selected time steps: {time_steps_to_process}")
    else:
        print("Predicting for all valid time steps.")

    X_chunks, y_chunks, sample_indices_output = [], [], []
    for t in tqdm(time_steps_to_process, desc="Preparing prediction samples"):
        targets_t = spi_masked[t]
        valid_pixels = np.flatnonzero(~np.isnan(targets_t))
        if valid_pixels.size == 0:
            continue
        window = features_masked[t - lookback_period:t]  # (L, P, 3)
        X_chunks.append(window[:, valid_pixels, :].transpose(1, 0, 2))  # (n_valid, L, 3)
        y_chunks.append(targets_t[valid_pixels])
        sample_indices_output.extend((t, int(p)) for p in valid_pixels)

    if X_chunks:
        X = np.concatenate(X_chunks, axis=0)
        y = np.concatenate(y_chunks).reshape(-1, 1)
    else:
        X = np.empty((0, lookback_period, 3), dtype=features_masked.dtype)
        y = np.empty((0, 1), dtype=spi_masked.dtype)

    print(f"Final X shape for prediction: {X.shape}, y shape: {y.shape}, Samples: {len(sample_indices_output)}")
    original_data_info = {'SPI_full_shape': spi_data.shape, 'SPI_original_full': spi_data, 'time_steps_total': time_steps_total}
    return X, y, mask, sample_indices_output, original_data_info

# --- Date Generation Helper ---
def generate_dates(start_year, start_month, num_timesteps):
    """Return ``num_timesteps`` consecutive first-of-month datetimes.

    Starts at ``datetime(start_year, start_month, 1)`` and advances one
    calendar month per entry, rolling December into January of the next
    year. The start date is always built, so an invalid ``start_month``
    raises ``ValueError`` even when ``num_timesteps`` is 0.
    """
    months = []
    cursor = datetime.datetime(start_year, start_month, 1)
    while len(months) < num_timesteps:
        months.append(cursor)
        year_carry, month_index = divmod(cursor.month, 12)
        cursor = datetime.datetime(cursor.year + year_carry, month_index + 1, 1)
    return months

# --- MODIFIED: General Prediction and Evaluation Function ---
def _build_chain_adjacency(lookback):
    """Return a (lookback, lookback) float32 matrix linking each step to itself and its direct neighbours."""
    adj = torch.eye(lookback, dtype=torch.float32)
    if lookback > 1:
        ones = torch.ones(lookback - 1, dtype=torch.float32)
        adj = adj + torch.diag(ones, 1) + torch.diag(ones, -1)
    return adj


def _print_per_timestep_metrics(country_name, time_indices, true_rasters, pred_rasters, dates):
    """Print MSE/RMSE/MAE/R² per time step over the pixels valid in both rasters."""
    print(f"\n--- Metrics per Predicted Time Step for {country_name} ---")
    for time_index, true_raster, pred_raster in zip(time_indices, true_rasters, pred_rasters):
        date_str = dates[time_index].strftime('%Y-%m')
        valid = ~np.isnan(true_raster) & ~np.isnan(pred_raster)
        if not valid.any():
            print(f"Time Step: {date_str} (Index: {time_index}) - No valid overlapping pixels for metrics.")
            continue
        true_values = true_raster[valid]
        pred_values = pred_raster[valid]
        mse_step = mean_squared_error(true_values, pred_values)
        rmse_step = np.sqrt(mse_step)
        mae_step = mean_absolute_error(true_values, pred_values)
        r2_step = r2_score(true_values, pred_values)
        print(f"Time Step: {date_str} (Index: {time_index})")
        print(f"  MSE: {mse_step:.4f}, RMSE: {rmse_step:.4f}, MAE: {mae_step:.4f}, R²: {r2_step:.4f}")
    print("---------------------------------------------------\n")


def _save_prediction_geotiffs(recon_preds, time_indices, dates, ref_geotiff_path, geotiff_dir, safe_country_name):
    """Write one GeoTIFF per reconstructed step, or print a warning if the reference is unusable."""
    if not ref_geotiff_path:
        print("WARNING: No reference GeoTIFF path provided. Skipping GeoTIFF save.")
        return
    if not os.path.exists(ref_geotiff_path):
        print(f"WARNING: Reference GeoTIFF not found at '{ref_geotiff_path}'. Skipping GeoTIFF save.")
        return
    print(f"Saving predicted rasters as GeoTIFFs using reference: {ref_geotiff_path}")
    for time_index, raster in zip(time_indices, recon_preds):
        date_str = dates[time_index].strftime('%Y_%m')
        output_filepath = os.path.join(geotiff_dir, f"predicted_spi_{safe_country_name}_{date_str}.tif")
        save_raster_as_geotiff(raster, ref_geotiff_path, output_filepath)


def predict_and_evaluate_country(model_path, scaler_X_path, scaler_y_path,
                                 imputed_data_path_country,
                                 ref_geotiff_path,
                                 output_base_dir,
                                 country_name,
                                 data_start_year, data_start_month,
                                 selected_time_steps=None):
    """Run the trained GAT+Transformer model on one country's data and save all outputs.

    Steps:
      1. Load the scalers and model weights.
      2. Build the prediction windows.
      3. Predict and print overall metrics.
      4. Plot predicted vs true values.
      5. Reconstruct the rasters and save them to ``spi_predictions_<country>.npz``.
      6. Print metrics for each predicted time step.
      7. Write one GeoTIFF per step when ``ref_geotiff_path`` exists.
      8. Save sample raster comparison figures.

    Always returns None. Failing to load the scalers or model, or having no
    valid samples, prints a message and ends the run early. Errors raised
    while loading the data file propagate to the caller.
    """
    print(f"\n--- Running Prediction and Evaluation for {country_name} with {MODEL_NAME} ---")
    print(f"Using device: {DEVICE}")

    safe_country_name = country_name.lower().replace(' ', '_').replace(',', '')
    output_dir = os.path.join(output_base_dir, f"results_{safe_country_name}_{MODEL_NAME.lower().replace('+', '_')}")
    viz_dir = os.path.join(output_dir, "visualizations")
    geotiff_dir = os.path.join(output_dir, "geotiffs")
    os.makedirs(viz_dir, exist_ok=True)
    os.makedirs(geotiff_dir, exist_ok=True)
    print(f"Created output directories in: {output_dir}")

    # 1. Load the trained model and scalers.
    # NOTE: joblib.load / torch.load unpickle their input — only use trusted files.
    try:
        scaler_X = joblib.load(scaler_X_path)
        scaler_y = joblib.load(scaler_y_path)
    except Exception as e:
        print(f"Error loading scalers: {e}. Aborting.")
        return

    # Hyperparameters must match the checkpoint, otherwise load_state_dict fails.
    model = SpiPredictorGATTransformer(
        input_size=3, d_model_gat=64, gat_heads=4,
        d_model_transformer=64, transformer_heads=4, num_transformer_layers=2,
        dim_feedforward_transformer=256, dropout_rate=0.2
    ).to(DEVICE)
    try:
        model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    except Exception as e:
        print(f"Error loading model: {e}. Aborting.")
        return
    model.eval()

    # 2. Load and prepare data
    X_pred_all, y_true_all, data_mask, pred_sample_indices, orig_data_info = \
        load_data_for_prediction(imputed_data_path_country, LOOKBACK, selected_time_steps)
    if X_pred_all.shape[0] == 0:
        print(f"No valid samples for prediction in {country_name}. Exiting.")
        return

    dates_for_visualization = generate_dates(data_start_year, data_start_month, orig_data_info['time_steps_total'])

    # 3. Create the DataLoader over a chain adjacency of the lookback window.
    adj_matrix_template = _build_chain_adjacency(LOOKBACK)
    pred_dataset = TimeSeriesDatasetGAT(X_pred_all, y_true_all, adj_matrix_template, scaler_X, scaler_y, is_train=False)
    # Pinned host memory only benefits host-to-GPU copies; skip it on CPU.
    pred_loader = DataLoader(pred_dataset, batch_size=BATCH_SIZE, shuffle=False,
                             num_workers=2, pin_memory=DEVICE.type == 'cuda')

    # 4. Make predictions and evaluate (overall)
    preds_o, tgts_o, *_ = evaluate_model_gat_based(model, pred_loader, scaler_y, DEVICE, f"{MODEL_NAME} on {country_name}")

    # 5. Reconstruction and saving
    plot_all_test_samples(tgts_o, preds_o, f"{MODEL_NAME}_{country_name}", save_dir=viz_dir)
    recon_preds, unique_pred_t = reconstruct_rasters(preds_o, pred_sample_indices, data_mask, orig_data_info['SPI_full_shape'])
    np.savez_compressed(os.path.join(output_dir, f'spi_predictions_{safe_country_name}.npz'),
                        predictions=recon_preds, mask=data_mask, time_steps=unique_pred_t)

    full_spi_original_country = orig_data_info['SPI_original_full']
    true_pred_rs = np.array([full_spi_original_country[t] for t in unique_pred_t])

    _print_per_timestep_metrics(country_name, unique_pred_t, true_pred_rs, recon_preds, dates_for_visualization)

    # 6. Save each predicted raster as a GeoTIFF
    _save_prediction_geotiffs(recon_preds, unique_pred_t, dates_for_visualization,
                              ref_geotiff_path, geotiff_dir, safe_country_name)

    # 7. Visualize raster samples
    visualize_raster_samples(true_pred_rs, recon_preds, data_mask, unique_pred_t,
                             f"{MODEL_NAME}_{country_name}", dates_for_visualization, output_dir=viz_dir)

    print(f"Prediction and evaluation for {country_name} complete. Results in {output_dir}")

# --- Main execution block for prediction ---
if __name__ == "__main__":
    import traceback

    # Paths to your trained model and scalers
    MODEL_DIR = "/kaggle/input/gat-transformer-guinea-bissau/pytorch/trained-on-guinea-bissau/2/results_gat_transformer"
    MODEL_WEIGHTS_PATH = os.path.join(MODEL_DIR, "best_spi_gat_transformer_model.pt")
    SCALER_X_PATH = os.path.join(MODEL_DIR, "scaler_X_gat_transformer.pkl")
    SCALER_Y_PATH = os.path.join(MODEL_DIR, "scaler_y_gat_transformer.pkl")

    OUTPUT_BASE_DIR = "/kaggle/working"

    # Reference GeoTIFF per country; its rasterio profile is reused for the predicted rasters.
    REFERENCE_GEOTIFFS = {
        "Gambia": "/kaggle/input/ndvi-month/NDVI_Images/Gambia,_The/NDVI_Gambia,_The_2000_02.tif",
        "Guinea": "/kaggle/input/ndvi-month/NDVI_Images/Guinea/NDVI_Guinea_2000_02.tif",
        "Guinea-Bissau": "/kaggle/input/ndvi-month/NDVI_Images/Guinea-Bissau/NDVI_Guinea-Bissau_2000_02.tif",
        "Nigeria": "/kaggle/input/ndvi-month/NDVI_Images/Nigeria/NDVI_Nigeria_2000_02.tif",
        "Senegal": "/kaggle/input/ndvi-month/NDVI_Images/Senegal/NDVI_Senegal_2000_02.tif"
    }

    # Settings shared by every country; each country only adds its data file below.
    COMMON_COUNTRY_SETTINGS = {
        "start_year": 2000,
        "start_month": 2,
        "selected_timesteps": [270, 275, 280, 282, 285],
    }
    IMPUTED_DATA_PATHS = {
        "Senegal": "/kaggle/input/imputed-data/Senegal_combined_imputed.npz",
        "Gambia": "/kaggle/input/imputed-data/Gambia,_The_combined_imputed.npz",
        "Guinea-Bissau": "/kaggle/input/imputed-data/Guinea-Bissau_combined_imputed.npz",
        "Guinea": "/kaggle/input/imputed-data/Guinea_combined_imputed.npz",
    }
    COUNTRIES_CONFIG = {
        country: {"imputed_data_path": path, **COMMON_COUNTRY_SETTINGS}
        for country, path in IMPUTED_DATA_PATHS.items()
    }

    # --- Loop through countries and run predictions ---
    failed_countries = []
    for country_name, config in COUNTRIES_CONFIG.items():
        try:
            predict_and_evaluate_country(
                model_path=MODEL_WEIGHTS_PATH,
                scaler_X_path=SCALER_X_PATH,
                scaler_y_path=SCALER_Y_PATH,
                imputed_data_path_country=config["imputed_data_path"],
                ref_geotiff_path=REFERENCE_GEOTIFFS.get(country_name),
                output_base_dir=OUTPUT_BASE_DIR,
                country_name=country_name,
                data_start_year=config["start_year"],
                data_start_month=config["start_month"],
                selected_time_steps=config["selected_timesteps"]
            )
        except Exception:
            # Keep going so one missing or corrupt input does not abort the remaining countries.
            traceback.print_exc()
            failed_countries.append(country_name)

    if failed_countries:
        print(f"Prediction failed for: {', '.join(failed_countries)}")

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
import matplotlib.pyplot as plt
import os
import joblib
import math
from tqdm import tqdm
import datetime
import rasterio
from rasterio.plot import show
import geopandas as gpd

# --- Configuration (same values as the previous cell) ---
MODEL_NAME = "GAT_Transformer_Prediction"  # label used in result names
LOOKBACK = 24  # input window length; the model's positional table is LOOKBACK + 10 long
BATCH_SIZE = 128  # inference batch size
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- All class and function definitions up to the visualization part remain unchanged ---

# Positional Encoding
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence.

    The table is precomputed for ``max_len`` positions and registered as the
    buffer ``pe`` with shape (1, max_len, d_model), so it is saved in the
    module's ``state_dict``. Even feature columns hold ``sin`` terms and odd
    columns hold ``cos`` terms. Odd ``d_model`` values are supported.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are ceil(d_model/2) even columns but only floor(d_model/2) odd
        # ones; trim the frequencies so an odd d_model does not raise.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x + pe`` for ``x`` shaped (batch, seq_len, d_model).

        Raises:
            ValueError: if ``seq_len`` exceeds the precomputed ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"Sequence length {seq_len} exceeds PositionalEncoding max_len {self.pe.size(1)}"
            )
        return x + self.pe[:, :seq_len]

# Graph Attention Layer
class GraphAttentionLayer(nn.Module):
    """Single-head graph attention over a fixed node graph.

    ``forward(h, adj)`` takes node features ``h`` of shape (B, N, F_in) and an
    (N, N) adjacency. Pair scores ``LeakyReLU(a^T [Wh_i || Wh_j])`` are
    computed for every node pair. Pairs where ``adj`` is not > 0 receive a
    -9e15 logit so softmax (over j) gives them ~0 weight. Dropout is applied
    to the attention weights. The result is ``ELU(attention @ Wh)`` with
    shape (B, N, F_out).
    """

    def __init__(self, in_features, out_features, dropout=0.2, alpha=0.2):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.dropout_val = dropout
        self.alpha = alpha
        # Shared node projection F_in -> F_out.
        self.W = nn.Parameter(torch.empty(in_features, out_features))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        # Scoring vector applied to concatenated [Wh_i || Wh_j] pairs.
        self.a = nn.Parameter(torch.empty(2 * out_features, 1))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)
        self.leakyrelu = nn.LeakyReLU(alpha)
        self.dropout_layer = nn.Dropout(dropout)

    def forward(self, h, adj):
        projected = h @ self.W
        pair_features = self._prepare_attention_input(projected)
        scores = self.leakyrelu((pair_features @ self.a).squeeze(-1))
        is_edge = adj.unsqueeze(0) > 0
        scores = torch.where(is_edge, scores, torch.full_like(scores, -9e15))
        weights = self.dropout_layer(torch.softmax(scores, dim=-1))
        return F.elu(weights @ projected)

    def _prepare_attention_input(self, Wh):
        """Return all (i, j) pairs as [Wh_i || Wh_j], shape (B, N, N, 2*F_out)."""
        batch, nodes, _ = Wh.size()
        left = Wh[:, :, None, :].expand(batch, nodes, nodes, -1)
        right = Wh[:, None, :, :].expand(batch, nodes, nodes, -1)
        return torch.cat((left, right), dim=-1)

# MultiHeadGraphAttention
class MultiHeadGraphAttention(nn.Module):
    """Run ``n_heads`` independent ``GraphAttentionLayer`` heads on the same input.

    With ``concat=True`` head outputs are joined on the feature axis
    (``out_dim = out_features_per_head * n_heads``). Otherwise they are
    averaged (``out_dim = out_features_per_head``).
    """

    def __init__(self, in_features, out_features_per_head, n_heads, dropout=0.2, alpha=0.2, concat=True):
        super().__init__()
        self.n_heads = n_heads
        self.concat = concat
        heads = []
        for _ in range(n_heads):
            heads.append(GraphAttentionLayer(in_features, out_features_per_head, dropout, alpha))
        self.attentions = nn.ModuleList(heads)
        if concat:
            self.out_dim = out_features_per_head * n_heads
        else:
            self.out_dim = out_features_per_head

    def forward(self, x, adj):
        """x: (B, N, F_in); adj: (N, N). Returns (B, N, out_dim)."""
        per_head = [head(x, adj) for head in self.attentions]
        if not self.concat:
            return torch.stack(per_head, dim=-1).mean(dim=-1)
        return torch.cat(per_head, dim=-1)

# GAT+Transformer Model
class SpiPredictorGATTransformer(nn.Module):
    """SPI regressor built from several stages.

    The pipeline is: per-step linear projection, multi-head graph attention
    over the lookback steps, sinusoidal positional encoding, a Transformer
    encoder, and finally an MLP head on the last time step.

    Submodule attribute names (``gat_input_proj``, ``gat_attention``,
    ``pos_encoder``, ``transformer_encoder``, ``fc``) and their creation order
    must stay as they are. They form the state_dict keys of saved checkpoints,
    and ``PositionalEncoding`` registers a ``pe`` buffer whose length
    (``LOOKBACK + 10``) is part of the checkpoint.
    """

    def __init__(self, input_size, d_model_gat=64, gat_heads=4,
                 d_model_transformer=64, transformer_heads=4, num_transformer_layers=2,
                 dim_feedforward_transformer=256, dropout_rate=0.2):
        super().__init__()
        self.dropout_rate = dropout_rate
        # NOTE: d_model_transformer is accepted but not referenced; the encoder
        # width is the GAT output width (out_dim of the attention block).
        self.gat_input_proj = nn.Linear(input_size, d_model_gat)
        self.gat_attention = MultiHeadGraphAttention(
            in_features=d_model_gat,
            out_features_per_head=d_model_gat // gat_heads,
            n_heads=gat_heads,
            dropout=dropout_rate,
            concat=True,
        )
        width = self.gat_attention.out_dim
        self.pos_encoder = PositionalEncoding(width, max_len=LOOKBACK + 10)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=width,
            nhead=transformer_heads,
            dim_feedforward=dim_feedforward_transformer,
            dropout=dropout_rate,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_transformer_layers)
        head_layers = [
            nn.Linear(width, 64), nn.ReLU(), nn.Dropout(dropout_rate),
            nn.Linear(64, 32), nn.ReLU(), nn.Dropout(dropout_rate),
            nn.Linear(32, 1),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x, adj): # x: (B,L,F_in), adj: (L,L)
        """Predict one value per sequence from the encoder state of its final step."""
        hidden = F.elu(self.gat_input_proj(x))
        hidden = F.dropout(hidden, self.dropout_rate, training=self.training)
        hidden = self.gat_attention(hidden, adj)
        hidden = self.transformer_encoder(self.pos_encoder(hidden))
        return self.fc(hidden[:, -1, :])

# TimeSeriesDataset for GAT (Optimized Adjacency)
class TimeSeriesDatasetGAT(Dataset):
    """Dataset of ``(feature_sequence, adjacency, target)`` triples for the GAT models.

    Every item shares one adjacency matrix. Features are scaled per column over
    their last axis (they are reshaped to ``(-1, n_features)`` for the scaler).

    Scaling rules, applied the same way to features and targets:
      * a supplied (already fitted) scaler is always applied with ``transform``;
      * ``is_train=True`` with no scaler fits a new ``StandardScaler``;
      * ``is_train=False`` with no scaler leaves the data unscaled.
    Previously a scaler passed with ``is_train=True`` was returned by
    ``get_scalers`` but silently not applied.
    """

    def __init__(self, features, targets, single_adj_matrix, scaler_X=None, scaler_y=None, is_train=True):
        original_shape = features.shape
        flat_features = features.reshape(-1, original_shape[-1])
        scaled_features, self.scaler_X = self._fit_or_apply(flat_features, scaler_X, is_train)
        self.features = torch.tensor(scaled_features.reshape(original_shape), dtype=torch.float32)

        # Targets are passed to the scaler as-is, so they must already be 2-D (e.g. (n, 1)).
        scaled_targets, self.scaler_y = self._fit_or_apply(targets, scaler_y, is_train)
        self.targets = torch.tensor(scaled_targets, dtype=torch.float32)

        if isinstance(single_adj_matrix, torch.Tensor):
            self.single_adj_matrix = single_adj_matrix.to(dtype=torch.float32)
        else:
            self.single_adj_matrix = torch.tensor(single_adj_matrix, dtype=torch.float32)

    @staticmethod
    def _fit_or_apply(values, scaler, is_train):
        """Return ``(scaled_values, scaler)`` according to the class-level scaling rules."""
        if scaler is not None:
            return scaler.transform(values), scaler
        if is_train:
            fitted = StandardScaler()
            return fitted.fit_transform(values), fitted
        return values, None

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.single_adj_matrix, self.targets[idx]

    def get_scalers(self):
        """Return the ``(scaler_X, scaler_y)`` pair used (or ``None`` where unscaled)."""
        return self.scaler_X, self.scaler_y

# Evaluation Function for GAT-based models
def evaluate_model_gat_based(model, test_loader, scaler_y, device, model_name="GAT-based Model"):
    """Predict on every batch of ``test_loader`` and print MSE/RMSE/MAE/R².

    Args:
        model: network called as ``model(batch_X, adj)``.
        test_loader: yields ``(features, adjacency, scaled_targets)`` batches,
            e.g. from ``TimeSeriesDatasetGAT``.
        scaler_y: fitted target scaler used to map predictions and targets back
            to original units; ``None`` keeps the loader's units.
        device: torch device the model lives on.
        model_name: label for the progress bar and printed summary.

    Returns:
        ``(preds_original, targets_original, mse, rmse, r2, mae)``.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()
    pred_batches, target_batches = [], []
    with torch.no_grad():
        for batch_X, batch_adj, batch_y_s in tqdm(test_loader, desc=f"Evaluating {model_name}"):
            # Every sample carries the same adjacency, so only one copy goes to the device.
            adj = batch_adj[0].to(device)
            outputs_s = model(batch_X.to(device), adj)
            pred_batches.append(outputs_s.cpu().numpy())
            target_batches.append(batch_y_s.numpy())
    if not pred_batches:
        raise ValueError(f"No batches to evaluate for {model_name}.")
    preds_s_np = np.vstack(pred_batches)
    tgts_s_np = np.vstack(target_batches)
    if scaler_y is not None:
        preds_o = scaler_y.inverse_transform(preds_s_np)
        tgts_o = scaler_y.inverse_transform(tgts_s_np)
    else:
        preds_o = preds_s_np
        tgts_o = tgts_s_np
    mse = mean_squared_error(tgts_o, preds_o)
    rmse = np.sqrt(mse)
    r2 = r2_score(tgts_o, preds_o)
    mae = mean_absolute_error(tgts_o, preds_o)
    print(f"\nOverall Evaluation metrics ({model_name}):\n MSE: {mse:.4f}, RMSE: {rmse:.4f}, MAE: {mae:.4f}, R²: {r2:.4f}")
    return preds_o, tgts_o, mse, rmse, r2, mae

# Plotting and Reconstruction Functions
def plot_all_test_samples(targets_original, predictions_original, model_name, save_dir='visualizations'):
    """Scatter predictions against true values with a y=x line, a fitted line and metrics.

    The figure is saved to ``{save_dir}/all_test_samples_{model_name}.png``,
    shown, and then closed so that repeated calls do not accumulate open figures.
    """
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, f'all_test_samples_{model_name}.png')
    y_true = np.asarray(targets_original).ravel()
    y_pred = np.asarray(predictions_original).ravel()

    fig = plt.figure(figsize=(12, 8))
    plt.scatter(y_true, y_pred, alpha=0.3, color='blue', label='Samples')
    min_val = min(y_true.min(), y_pred.min())
    max_val = max(y_true.max(), y_pred.max())
    plt.plot([min_val, max_val], [min_val, max_val], 'r--', label='y=x')
    corr = np.nan
    if y_true.size > 1 and y_pred.size > 1:
        slope, intercept = np.polyfit(y_true, y_pred, 1)
        sorted_targets = np.sort(y_true)
        # ':+.3f' prints the intercept with its sign (avoids labels like "x+-0.123").
        plt.plot(sorted_targets, np.polyval([slope, intercept], sorted_targets), 'g-',
                 label=f'Regression: y={slope:.3f}x{intercept:+.3f}')
        corr = np.corrcoef(y_true, y_pred)[0, 1]
    plt.xlabel('True Values')
    plt.ylabel('Predictions')
    plt.title(f'Predictions vs True Values ({model_name}, r = {corr:.3f})')
    plt.legend()
    plt.grid(True)
    mse = mean_squared_error(y_true, y_pred)
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)
    stats_text = (f"MSE: {mse:.4f}\nRMSE: {rmse:.4f}\nMAE: {mae:.4f}\nR²: {r2:.4f}\nCorr: {corr:.4f}")
    plt.figtext(0.15, 0.75, stats_text, bbox=dict(facecolor='white', alpha=0.8, boxstyle='round,pad=0.5'))
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.savefig(save_path)
    plt.show()
    plt.close(fig)

def reconstruct_rasters(predictions_flat, test_sample_indices, original_mask, spi_full_shape):
    """Scatter per-pixel predictions back onto (time, H, W) rasters.

    Args:
        predictions_flat: indexable so that ``predictions_flat[i][0]`` is the
            prediction for sample ``i`` (e.g. an ``(n, 1)`` array).
        test_sample_indices: sequence of ``(time_index, pixel_idx_in_mask)``
            pairs aligned with ``predictions_flat``.
        original_mask: 2-D boolean mask; ``pixel_idx_in_mask`` indexes its
            True cells in ``np.where`` (row-major) order.
        spi_full_shape: shape of the full SPI cube; elements 1 and 2 give H, W.

    Returns:
        ``(rasters, unique_times)``. ``rasters`` has shape ``(len(unique_times), H, W)``
        and is NaN wherever no prediction was given. ``unique_times`` is the sorted
        list of distinct time indices.
    """
    unique_times = sorted({t for t, _ in test_sample_indices})
    # O(1) slot lookup instead of list.index() per sample.
    slot_of_time = {t: slot for slot, t in enumerate(unique_times)}
    rasters = np.full((len(unique_times), spi_full_shape[1], spi_full_shape[2]), np.nan)
    rows, cols = np.where(original_mask)
    for i, (t, pixel_idx_in_mask) in enumerate(test_sample_indices):
        rasters[slot_of_time[t], rows[pixel_idx_in_mask], cols[pixel_idx_in_mask]] = predictions_flat[i][0]
    return rasters, unique_times

def visualize_raster_samples(real_rasters_for_test_timesteps, pred_rasters_for_test_timesteps,
                             mask_for_visualization, unique_test_time_indices, model_name,
                             dates_list, ref_geotiff_path, shapefile_path, country_name,
                             num_samples_to_plot=5, output_dir='visualizations'):
    """Plot true, predicted and difference SPI maps for a random subset of time steps.

    For up to ``num_samples_to_plot`` randomly chosen rasters, draws a 1x3 figure
    (True SPI, Predicted SPI, True - Predicted), georeferenced to the bounds of
    ``ref_geotiff_path``. Where possible, the country border from ``shapefile_path``
    is overlaid. Pixels outside ``mask_for_visualization`` are hidden. Each figure
    is saved as a PNG in ``output_dir``, shown, and closed.

    Returns early with a message if the GeoTIFF/shapefile paths are invalid or
    there are no rasters. If the shapefile cannot be read (or geopandas is not
    installed), plotting continues without borders.
    """
    os.makedirs(output_dir, exist_ok=True)
    if not ref_geotiff_path or not os.path.exists(ref_geotiff_path) or not shapefile_path or not os.path.exists(shapefile_path):
        print("ERROR: Reference GeoTIFF or Shapefile path is invalid or file not found. Skipping visualization.")
        return
    with rasterio.open(ref_geotiff_path) as src:
        plot_extent = (src.bounds.left, src.bounds.right, src.bounds.bottom, src.bounds.top)
        raster_crs = src.crs
    country_border = None
    country_centroid = None
    try:
        # Imported here: the notebook's top-level imports do not include geopandas.
        # Without it the read below raised NameError, which was swallowed and no borders were drawn.
        import geopandas as gpd
        country_border = gpd.read_file(shapefile_path)
        if country_border.crs != raster_crs:
            country_border = country_border.to_crs(raster_crs)
        country_centroid = country_border.geometry.centroid.iloc[0]
    except Exception as e:
        print(f"Could not read/process shapefile: {e}. Borders will not be plotted.")
        country_border = None
    prefix = f'raster_sample_{model_name.lower()}'
    colors = ['#730000', '#E60000', '#FFAA00', '#FCD37F', '#FFFF00', '#FFFFFF', '#DCF8FF', '#96D2FF', '#46A5FF', '#0000FF', '#000080']
    cmap_usdm_spi = plt.cm.colors.LinearSegmentedColormap.from_list('usdm_spi', colors, N=256)
    vmin_spi, vmax_spi = -3, 3
    num_available_rasters = len(real_rasters_for_test_timesteps)
    if num_available_rasters == 0:
        print("No rasters to visualize.")
        return
    plot_indices = np.random.choice(num_available_rasters, min(num_samples_to_plot, num_available_rasters), replace=False)
    for raster_idx_in_test_set in plot_indices:
        original_time_index = unique_test_time_indices[raster_idx_in_test_set]
        date_str = dates_list[original_time_index].strftime('%Y-%m')
        true_raster_slice = real_rasters_for_test_timesteps[raster_idx_in_test_set]
        pred_raster_slice = pred_rasters_for_test_timesteps[raster_idx_in_test_set]
        diff_raster_slice = true_raster_slice - pred_raster_slice
        # Symmetric colour range for the difference panel; fall back to 1.0 if all NaN.
        diff_abs_max = np.nanmax(np.abs(diff_raster_slice)) if not np.all(np.isnan(diff_raster_slice)) else 1.0
        fig, axes = plt.subplots(1, 3, figsize=(20, 7))
        fig.suptitle(f'SPI Comparison for {country_name} - Time: {date_str} (Index: {original_time_index})', fontsize=18)
        plot_data = [
            {'ax': axes[0], 'data': true_raster_slice, 'title': 'True SPI', 'cmap': cmap_usdm_spi, 'vmin': vmin_spi, 'vmax': vmax_spi},
            {'ax': axes[1], 'data': pred_raster_slice, 'title': 'Predicted SPI', 'cmap': cmap_usdm_spi, 'vmin': vmin_spi, 'vmax': vmax_spi},
            {'ax': axes[2], 'data': diff_raster_slice, 'title': 'Difference', 'cmap': 'RdBu_r', 'vmin': -diff_abs_max, 'vmax': diff_abs_max}
        ]
        for p in plot_data:
            ax = p['ax']
            masked_data = np.ma.array(p['data'], mask=~mask_for_visualization)
            im = ax.imshow(masked_data, cmap=p['cmap'], vmin=p['vmin'], vmax=p['vmax'], extent=plot_extent)
            label = 'SPI Value' if 'SPI' in p['title'] else 'Difference (True - Predicted)'
            fig.colorbar(im, ax=ax, label=label, fraction=0.046, pad=0.04)
            ax.set_title(p['title'], fontsize=14)
            if country_border is not None:
                country_border.plot(ax=ax, edgecolor='black', facecolor='none', linewidth=1.5)
                ax.text(country_centroid.x, country_centroid.y, country_name, ha='center', va='center', fontsize=12, color='black', bbox=dict(facecolor='none', edgecolor='none'))
            ax.set_xlabel("Longitude"); ax.set_ylabel("Latitude"); ax.tick_params(axis='x', rotation=45)
        plt.tight_layout(rect=[0, 0, 1, 0.95])
        save_filename = os.path.join(output_dir, f'{prefix}_time_{date_str.replace("-", "_")}.png')
        plt.savefig(save_filename, dpi=200)
        plt.show()
        plt.close(fig)

def save_raster_as_geotiff(raster_data, ref_geotiff_path, out_path):
    """Write a 2-D array as a single-band LZW-compressed float32 GeoTIFF.

    The output copies the reference GeoTIFF's profile (georeferencing, size),
    so ``raster_data`` must match its dimensions. NaNs are written as the
    nodata value -9999.0. This is best-effort: any failure is printed and
    swallowed, and the function always returns None.
    """
    nodata_val = -9999.0
    try:
        with rasterio.open(ref_geotiff_path) as reference:
            out_profile = reference.profile
        filled = np.nan_to_num(raster_data, nan=nodata_val)
        out_profile.update(dtype=rasterio.float32, count=1, compress='lzw', nodata=nodata_val)
        with rasterio.open(out_path, 'w', **out_profile) as sink:
            sink.write(filled.astype(rasterio.float32), 1)
    except Exception as e:
        print(f"ERROR: Could not save GeoTIFF {out_path}. Reason: {e}")

def load_data_for_prediction(imputed_file_path, lookback_period, selected_time_steps=None):
    """Build per-pixel lookback feature sequences and SPI targets from an imputed ``.npz``.

    The archive must contain ``NDVI_imputed``, ``SoilMoisture_imputed``,
    ``LST_imputed``, ``SPI_original`` (each indexed as ``[time][mask]``, i.e.
    (T, H, W)) and a 2-D boolean ``mask``. A sample is emitted for every time
    step ``t`` in ``[lookback_period, T)`` (optionally restricted to, and
    ordered by, ``selected_time_steps``) and every masked pixel whose SPI at
    ``t`` is not NaN. Samples are ordered by ``t``, then by pixel.

    Returns:
        X: ``(n_samples, lookback_period, 3)`` with features ordered
           NDVI, LST, SoilMoisture over steps ``t - lookback_period .. t - 1``.
        y: ``(n_samples, 1)`` SPI targets at ``t``.
        mask: the loaded mask.
        sample_indices: list of ``(t, pixel_idx_in_mask)`` aligned with X/y.
        original_data_info: dict with ``SPI_full_shape``, ``SPI_original_full``
           and ``time_steps_total``.
    """
    print(f"Loading data for prediction from: {imputed_file_path}")
    # Context manager closes the NpzFile handle; arrays are read into memory on access.
    with np.load(imputed_file_path) as data_imputed:
        ndvi_data = data_imputed['NDVI_imputed']
        soil_moisture_data = data_imputed['SoilMoisture_imputed']
        lst_data = data_imputed['LST_imputed']
        spi_data = data_imputed['SPI_original']
        mask = data_imputed['mask']
    print(f"Loaded SPI_original shape: {spi_data.shape}, Mask shape: {mask.shape}, Valid pixels in mask: {np.sum(mask)}")
    time_steps_total = ndvi_data.shape[0]

    # (T, P, 3): all masked pixels at once, feature order NDVI, LST, SoilMoisture.
    features_masked = np.stack([ndvi_data[:, mask], lst_data[:, mask], soil_moisture_data[:, mask]], axis=-1)
    spi_masked = spi_data[:, mask]  # (T, P)

    if selected_time_steps is None:
        time_steps_to_process = range(lookback_period, time_steps_total)
    else:
        time_steps_to_process = [t for t in selected_time_steps if lookback_period <= t < time_steps_total]

    X_chunks, y_chunks, sample_indices_output = [], [], []
    for t in tqdm(time_steps_to_process, desc="Preparing prediction samples"):
        targets_t = spi_masked[t]
        valid_pixels = np.flatnonzero(~np.isnan(targets_t))
        if valid_pixels.size == 0:
            continue
        window = features_masked[t - lookback_period:t]  # (L, P, 3)
        X_chunks.append(window[:, valid_pixels].transpose(1, 0, 2))  # (n_valid, L, 3)
        y_chunks.append(targets_t[valid_pixels])
        sample_indices_output.extend((t, int(p)) for p in valid_pixels)

    if X_chunks:
        X = np.concatenate(X_chunks, axis=0)
        y = np.concatenate(y_chunks).reshape(-1, 1)
    else:
        X = np.empty((0, lookback_period, features_masked.shape[-1]), dtype=features_masked.dtype)
        y = np.empty((0, 1), dtype=spi_masked.dtype)
    print(f"Final X shape for prediction: {X.shape}, y shape: {y.shape}, Samples: {len(sample_indices_output)}")
    original_data_info = {'SPI_full_shape': spi_data.shape, 'SPI_original_full': spi_data, 'time_steps_total': time_steps_total}
    return X, y, mask, sample_indices_output, original_data_info

def generate_dates(start_year, start_month, num_timesteps):
    """Return ``num_timesteps`` consecutive first-of-month datetimes starting at start_year/start_month."""
    first = datetime.datetime(start_year, start_month, 1)
    base_month = first.year * 12 + (first.month - 1)

    def _month_start(offset):
        year, month_zero_based = divmod(base_month + offset, 12)
        return datetime.datetime(year, month_zero_based + 1, 1)

    return [_month_start(k) for k in range(num_timesteps)]

def predict_and_evaluate_country(model_path, scaler_X_path, scaler_y_path,
                                     imputed_data_path_country,
                                     ref_geotiff_path, shapefile_path,
                                     output_base_dir, country_name,
                                     data_start_year, data_start_month,
                                     selected_time_steps=None):
    """Run the trained model on one country's data and write metrics, GeoTIFFs and plots.

    Outputs go under ``{output_base_dir}/results_{safe_country}_{model_tag}/``:
    ``visualizations/``, ``geotiffs_predicted/`` and ``geotiffs_true/``. GeoTIFFs
    are written only when ``ref_geotiff_path`` exists. Returns None; prints a
    message and returns early if there are no valid samples.

    NOTE(review): ``joblib.load`` and ``torch.load`` unpickle their inputs; only
    point them at trusted artifacts.
    """
    print(f"\n--- Running Prediction and Evaluation for {country_name} ---")
    safe_country_name = country_name.lower().replace(' ', '_').replace(',', '')
    model_tag = MODEL_NAME.lower().replace('+', '_')
    output_dir = os.path.join(output_base_dir, f"results_{safe_country_name}_{model_tag}")
    viz_dir, geotiff_dir_pred, geotiff_dir_true = (
        os.path.join(output_dir, sub) for sub in ("visualizations", "geotiffs_predicted", "geotiffs_true")
    )
    for directory in (viz_dir, geotiff_dir_pred, geotiff_dir_true):
        os.makedirs(directory, exist_ok=True)

    scaler_X = joblib.load(scaler_X_path)
    scaler_y = joblib.load(scaler_y_path)
    model = SpiPredictorGATTransformer(input_size=3).to(DEVICE)
    state_dict = torch.load(model_path, map_location=DEVICE)
    model.load_state_dict(state_dict)
    model.eval()

    X_pred_all, y_true_all, data_mask, pred_sample_indices, orig_data_info = load_data_for_prediction(
        imputed_data_path_country, LOOKBACK, selected_time_steps)
    if X_pred_all.shape[0] == 0:
        print(f"No valid samples for prediction in {country_name}. Exiting.")
        return
    dates_for_visualization = generate_dates(data_start_year, data_start_month, orig_data_info['time_steps_total'])

    # Chain graph over the lookback window: each step links to itself and its immediate neighbours.
    steps = torch.arange(LOOKBACK)
    adj_matrix_template = ((steps[:, None] - steps[None, :]).abs() <= 1).to(torch.float32)

    pred_dataset = TimeSeriesDatasetGAT(X_pred_all, y_true_all, adj_matrix_template.cpu(), scaler_X, scaler_y, is_train=False)
    pred_loader = DataLoader(pred_dataset, batch_size=BATCH_SIZE, shuffle=False)
    preds_o, _tgts_o, *_metrics = evaluate_model_gat_based(model, pred_loader, scaler_y, DEVICE, f"{MODEL_NAME} on {country_name}")
    recon_preds, unique_pred_t = reconstruct_rasters(preds_o, pred_sample_indices, data_mask, orig_data_info['SPI_full_shape'])
    spi_full = orig_data_info['SPI_original_full']
    true_pred_rs = np.array([spi_full[t] for t in unique_pred_t])

    if ref_geotiff_path and os.path.exists(ref_geotiff_path):
        print(f"Saving predicted and true rasters as GeoTIFFs...")
        for time_index, pred_raster, true_raster in zip(unique_pred_t, recon_preds, true_pred_rs):
            date_str = dates_for_visualization[time_index].strftime('%Y_%m')
            save_raster_as_geotiff(pred_raster, ref_geotiff_path,
                                   os.path.join(geotiff_dir_pred, f"predicted_spi_{safe_country_name}_{date_str}.tif"))
            save_raster_as_geotiff(true_raster, ref_geotiff_path,
                                   os.path.join(geotiff_dir_true, f"true_spi_{safe_country_name}_{date_str}.tif"))

    visualize_raster_samples(
        real_rasters_for_test_timesteps=true_pred_rs, pred_rasters_for_test_timesteps=recon_preds,
        mask_for_visualization=data_mask, unique_test_time_indices=unique_pred_t,
        model_name=f"{MODEL_NAME}_{country_name}", dates_list=dates_for_visualization,
        ref_geotiff_path=ref_geotiff_path, shapefile_path=shapefile_path,
        country_name=country_name, output_dir=viz_dir
    )
    print(f"Prediction and evaluation for {country_name} complete.")


# --- Combined multi-country comparison figure (reads the per-country GeoTIFFs saved by predict_and_evaluate_country) ---
def _combined_result_paths(output_base_dir, model_name, country_name, date_str):
    """Return the (true, predicted) GeoTIFF paths that predict_and_evaluate_country writes for one date."""
    safe_country_name = country_name.lower().replace(' ', '_').replace(',', '')
    results_dir = os.path.join(output_base_dir, f"results_{safe_country_name}_{model_name.lower().replace('+', '_')}")
    true_path = os.path.join(results_dir, "geotiffs_true", f"true_spi_{safe_country_name}_{date_str}.tif")
    pred_path = os.path.join(results_dir, "geotiffs_predicted", f"predicted_spi_{safe_country_name}_{date_str}.tif")
    return true_path, pred_path


def create_combined_visualization_detailed(countries_config, country_shapefiles, ref_geotiffs, output_base_dir, model_name, date_str):
    """Plot true, predicted and difference SPI for all countries on one shared 1x3 figure.

    Pass 1 reads each country's true/predicted GeoTIFF for ``date_str`` (format
    ``YYYY_MM``) once and finds a shared symmetric range for the difference
    colour scale (at least [-1, 1]). Pass 2 draws every country's rasters, plus
    borders and labels where both the shapefile and the reference GeoTIFF
    (used for the CRS) exist. Countries with missing GeoTIFFs are skipped. The
    figure is saved to ``{output_base_dir}/combined_comparison_{date_str}.png``,
    shown, and closed.
    """
    # Imported here: the notebook's top-level imports include neither of these names.
    import geopandas as gpd
    from rasterio.plot import show

    print(f"\n--- Creating Detailed Combined Visualization for Time Step: {date_str} ---")
    fig, axes = plt.subplots(1, 3, figsize=(30, 10))

    colors_spi = ['#730000', '#E60000', '#FFAA00', '#FCD37F', '#FFFF00', '#FFFFFF', '#DCF8FF', '#96D2FF', '#46A5FF', '#0000FF', '#000080']
    cmap_spi = plt.cm.colors.LinearSegmentedColormap.from_list('usdm_spi', colors_spi, N=256)
    norm_spi = plt.Normalize(vmin=-3, vmax=3)
    cmap_diff = 'RdBu_r'

    country_keys = list(countries_config.keys())

    print("Pass 1: Calculating difference range...")
    loaded = {}
    max_abs_diff = 0.0
    for country_name in country_keys:
        true_path, pred_path = _combined_result_paths(output_base_dir, model_name, country_name, date_str)
        if not (os.path.exists(true_path) and os.path.exists(pred_path)):
            continue
        with rasterio.open(true_path) as t_src, rasterio.open(pred_path) as p_src:
            true_arr = t_src.read(1, masked=True)
            pred_arr = p_src.read(1, masked=True)
            true_transform, pred_transform = t_src.transform, p_src.transform
        diff = true_arr - pred_arr
        loaded[country_name] = (true_arr, pred_arr, diff, true_transform, pred_transform)
        if not np.all(np.ma.getmaskarray(diff)):
            max_abs_diff = max(max_abs_diff, float(np.ma.max(np.abs(diff))))

    max_abs_diff = max(max_abs_diff, 1.0)
    norm_diff = plt.Normalize(vmin=-max_abs_diff, vmax=max_abs_diff)

    print(f"Pass 2: Plotting data with diff range [{-max_abs_diff:.2f}, {max_abs_diff:.2f}]...")
    all_borders_list = []
    for country_name in country_keys:
        if country_name in loaded:
            true_arr, pred_arr, diff, true_transform, pred_transform = loaded[country_name]
            # Masked (nodata) cells render transparent, so overlapping country extents do not hide each other.
            show(true_arr, ax=axes[0], transform=true_transform, cmap=cmap_spi, norm=norm_spi)
            show(pred_arr, ax=axes[1], transform=pred_transform, cmap=cmap_spi, norm=norm_spi)
            show(diff, ax=axes[2], transform=true_transform, cmap=cmap_diff, norm=norm_diff)

        shapefile_path = country_shapefiles.get(country_name)
        ref_path = ref_geotiffs.get(country_name)
        if shapefile_path and os.path.exists(shapefile_path) and ref_path and os.path.exists(ref_path):
            with rasterio.open(ref_path) as ref_src:
                ref_crs = ref_src.crs
            gdf = gpd.read_file(shapefile_path).to_crs(ref_crs)
            all_borders_list.append(gdf)
            centroid = gdf.geometry.centroid.iloc[0]
            for ax in axes:
                ax.text(centroid.x, centroid.y, country_name, ha='center', va='center', fontsize=12, color='black', fontweight='bold')

    if all_borders_list:
        combined_borders = gpd.pd.concat(all_borders_list, ignore_index=True)
        for ax in axes:
            combined_borders.plot(ax=ax, edgecolor='black', facecolor='none', linewidth=1.5)

    plot_titles = ['True SPI', 'Predicted SPI', 'Difference (True - Predicted)']
    for ax, title in zip(axes, plot_titles):
        ax.set_title(title, fontsize=16)
        ax.set_xlabel('Longitude'); ax.set_ylabel('Latitude')
        ax.tick_params(axis='x', rotation=45); ax.grid(True, linestyle='--', alpha=0.6)
    axes[1].set_ylabel(''); axes[2].set_ylabel('')

    # One colorbar per panel, built from the shared norms so every country uses the same scale.
    spi_mappable = plt.cm.ScalarMappable(norm=norm_spi, cmap=cmap_spi)
    diff_mappable = plt.cm.ScalarMappable(norm=norm_diff, cmap=cmap_diff)
    fig.colorbar(spi_mappable, ax=axes[0], label='SPI Value', fraction=0.046, pad=0.04)
    fig.colorbar(spi_mappable, ax=axes[1], label='SPI Value', fraction=0.046, pad=0.04)
    fig.colorbar(diff_mappable, ax=axes[2], label='Difference (True - Predicted)', fraction=0.046, pad=0.04)

    fig.suptitle(f'Combined Regional SPI Comparison - {date_str.replace("_", "-")}', fontsize=20)
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    output_path = os.path.join(output_base_dir, f'combined_comparison_{date_str}.png')
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)
    print(f"Combined visualization saved to: {output_path}")


# --- Main execution block: per-country prediction, then combined multi-country plots ---
if __name__ == "__main__":
    MODEL_DIR = "/kaggle/input/gat-transformer-guinea-bissau/pytorch/trained-on-guinea-bissau/2/results_gat_transformer"
    MODEL_WEIGHTS_PATH = os.path.join(MODEL_DIR, "best_spi_gat_transformer_model.pt")
    SCALER_X_PATH = os.path.join(MODEL_DIR, "scaler_X_gat_transformer.pkl")
    SCALER_Y_PATH = os.path.join(MODEL_DIR, "scaler_y_gat_transformer.pkl")
    OUTPUT_BASE_DIR = "/kaggle/working"

    REFERENCE_GEOTIFFS = {
        "Gambia": "/kaggle/input/ndvi-month/NDVI_Images/Gambia,_The/NDVI_Gambia,_The_2000_02.tif",
        "Guinea": "/kaggle/input/ndvi-month/NDVI_Images/Guinea/NDVI_Guinea_2000_02.tif",
        "Guinea-Bissau": "/kaggle/input/ndvi-month/NDVI_Images/Guinea-Bissau/NDVI_Guinea-Bissau_2000_02.tif",
        "Senegal": "/kaggle/input/ndvi-month/NDVI_Images/Senegal/NDVI_Senegal_2000_02.tif"
    }
    COUNTRY_SHAPEFILES = {
        "Guinea": "/kaggle/input/shape-files-study-area/gadm41_GIN_0.shp",
        "Gambia": "/kaggle/input/shape-files-study-area/gadm41_GMB_0.shp",
        "Guinea-Bissau": "/kaggle/input/shape-files-study-area/gadm41_GNB_0.shp",
        "Senegal": "/kaggle/input/shape-files-study-area/gadm41_SEN_0.shp"
    }

    timesteps_to_predict = [270, 275, 280, 282, 285]
    COUNTRIES_CONFIG = {
        "Senegal": {"imputed_data_path": "/kaggle/input/imputed-data/Senegal_combined_imputed.npz", "start_year": 2000, "start_month": 2, "selected_timesteps": timesteps_to_predict},
        "Gambia": {"imputed_data_path": "/kaggle/input/imputed-data/Gambia,_The_combined_imputed.npz", "start_year": 2000, "start_month": 2, "selected_timesteps": timesteps_to_predict},
        "Guinea-Bissau": {"imputed_data_path": "/kaggle/input/imputed-data/Guinea-Bissau_combined_imputed.npz", "start_year": 2000, "start_month": 2, "selected_timesteps": timesteps_to_predict},
        "Guinea": {"imputed_data_path": "/kaggle/input/imputed-data/Guinea_combined_imputed.npz", "start_year": 2000, "start_month": 2, "selected_timesteps": timesteps_to_predict}
    }

    for country_name, config in COUNTRIES_CONFIG.items():
        predict_and_evaluate_country(
            model_path=MODEL_WEIGHTS_PATH,
            scaler_X_path=SCALER_X_PATH,
            scaler_y_path=SCALER_Y_PATH,
            imputed_data_path_country=config["imputed_data_path"],
            ref_geotiff_path=REFERENCE_GEOTIFFS.get(country_name),
            shapefile_path=COUNTRY_SHAPEFILES.get(country_name),
            output_base_dir=OUTPUT_BASE_DIR,
            country_name=country_name,
            data_start_year=config["start_year"],
            data_start_month=config["start_month"],
            selected_time_steps=config["selected_timesteps"]
        )

    if timesteps_to_predict:
        # Take the series start date from the config instead of a hard-coded 2000-02 so the
        # date strings match the GeoTIFF filenames written above. Assumes all countries
        # share one start date (true for COUNTRIES_CONFIG), since the combined plot looks
        # files up by date string.
        reference_config = next(iter(COUNTRIES_CONFIG.values()))
        dates = generate_dates(reference_config["start_year"], reference_config["start_month"],
                               max(timesteps_to_predict) + 1)
        for time_idx in timesteps_to_predict:
            date_to_plot_str = dates[time_idx].strftime('%Y_%m')
            create_combined_visualization_detailed(
                countries_config=COUNTRIES_CONFIG,
                country_shapefiles=COUNTRY_SHAPEFILES,
                ref_geotiffs=REFERENCE_GEOTIFFS,
                output_base_dir=OUTPUT_BASE_DIR,
                model_name=MODEL_NAME,
                date_str=date_to_plot_str
            )

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import rasterio
from scipy import stats
from sklearn.metrics import r2_score
import datetime

# --- Helper function to generate dates ---
def generate_dates(start_year, start_month, num_timesteps):
    """Return a list of ``num_timesteps`` month-start datetimes, beginning at start_year/start_month."""
    stamp = datetime.datetime(start_year, start_month, 1)
    out = []
    for _ in range(num_timesteps):
        out.append(stamp)
        # divmod rolls December (12) over to January of the next year.
        year_step, month_index = divmod(stamp.month, 12)
        stamp = stamp.replace(year=stamp.year + year_step, month=month_index + 1)
    return out

# --- Main function for pixel analysis ---
def _annotate_pixel_failure(ax_scatter, ax_timeseries, title, message, color='black'):
    """Write ``message`` in the centre of both axes and set the scatter title to ``title``."""
    for ax in (ax_scatter, ax_timeseries):
        # transAxes keeps the message centred regardless of any data limits already set.
        ax.text(0.5, 0.5, message, ha='center', va='center', color=color, transform=ax.transAxes)
    ax_scatter.set_title(title)


def plot_pixel_analysis(ax_scatter, ax_timeseries, country_name, coords, predictions_path, ref_geotiff_path, imputed_data_path, all_dates):
    """
    Draw actual-vs-forecasted SPI scatter and time-series plots for a single pixel.

    Parameters
    ----------
    ax_scatter, ax_timeseries : matplotlib Axes
        Axes receiving the actual-vs-forecast scatter and the time series.
    country_name : str
        Label used in titles and log messages.
    coords : tuple of float
        ``(lat, lon)`` of the pixel. It is converted to (row, col) with the
        reference GeoTIFF's ``index`` as (x=lon, y=lat); this assumes the
        raster CRS is geographic lon/lat (confirm).
    predictions_path : str
        ``.npz`` holding ``'predictions'`` (indexed ``[k, row, col]``) and
        ``'time_steps'`` (the dataset time index of each of the k predictions).
    ref_geotiff_path : str
        GeoTIFF whose grid is assumed to match both npz stacks.
    imputed_data_path : str
        ``.npz`` holding ``'SPI_original'`` indexed ``[time, row, col]``.
    all_dates : sequence of datetime
        Calendar date for every dataset time index.

    The function does not plot in three cases: a required file is missing,
    the pixel or any time step falls outside the data, or fewer than two
    non-NaN (actual, forecast) pairs remain. In each case it writes a message
    on both axes and returns.
    """
    print(f"--- Analyzing pixel for {country_name} at {coords} ---")
    lat, lon = coords

    # --- 1. Check if all required files exist ---
    required_paths = (predictions_path, ref_geotiff_path, imputed_data_path)
    if not all(os.path.exists(p) for p in required_paths):
        print(f"ERROR: Missing data files for {country_name}. Skipping analysis.")
        _annotate_pixel_failure(ax_scatter, ax_timeseries, f"{country_name}: Error", 'Data not found', color='red')
        return

    # --- 2. Load data and find the pixel index ---
    # `with` closes the npz file handles. Indexing an NpzFile reads the whole
    # array into memory, so the arrays remain usable after the block exits.
    with np.load(predictions_path) as predictions_data:
        predicted_rasters = predictions_data['predictions']
        predicted_time_indices = predictions_data['time_steps']
    with np.load(imputed_data_path) as original_data:
        true_rasters = original_data['SPI_original']

    with rasterio.open(ref_geotiff_path) as src:
        # rasterio's index() takes (x, y), i.e. (lon, lat).
        row, col = src.index(lon, lat)

    # A point outside the raster yields negative or too-large indices. Negative
    # ones would silently wrap to the opposite edge, so reject them explicitly.
    n_rows = min(predicted_rasters.shape[1], true_rasters.shape[1])
    n_cols = min(predicted_rasters.shape[2], true_rasters.shape[2])
    if not (0 <= row < n_rows and 0 <= col < n_cols):
        print(f"ERROR: Pixel {coords} maps to (row={row}, col={col}), outside the {n_rows}x{n_cols} raster for {country_name}.")
        _annotate_pixel_failure(ax_scatter, ax_timeseries, f"{country_name}: Error", 'Pixel outside raster', color='red')
        return

    # Time indices have the same wrap-around hazard against the SPI stack and date list.
    n_times = min(true_rasters.shape[0], len(all_dates))
    if np.any((predicted_time_indices < 0) | (predicted_time_indices >= n_times)):
        print(f"ERROR: Prediction time steps for {country_name} fall outside the {n_times} available time steps.")
        _annotate_pixel_failure(ax_scatter, ax_timeseries, f"{country_name}: Error", 'Time steps out of range', color='red')
        return

    # --- 3. Extract time series for the pixel, dropping NaN pairs ---
    forecasted_values = predicted_rasters[:, row, col]
    actual_values = true_rasters[predicted_time_indices, row, col]

    valid_mask = ~np.isnan(actual_values) & ~np.isnan(forecasted_values)
    actual_values = actual_values[valid_mask]
    forecasted_values = forecasted_values[valid_mask]
    dates_slice = [all_dates[t] for t, keep in zip(predicted_time_indices, valid_mask) if keep]

    if actual_values.size < 2:
        print(f"Not enough valid data points for pixel in {country_name}.")
        _annotate_pixel_failure(ax_scatter, ax_timeseries, f"{country_name}: Not enough data", 'Not enough data')
        return

    # --- 4. Create scatter plot ---
    r2 = r2_score(actual_values, forecasted_values)

    ax_scatter.plot(actual_values, forecasted_values, 'o', color='royalblue', alpha=0.6)
    lower = min(actual_values.min(), forecasted_values.min())
    upper = max(actual_values.max(), forecasted_values.max())
    line_1_1 = np.linspace(lower, upper, 100)
    ax_scatter.plot(line_1_1, line_1_1, 'k--', label='1:1 line')
    # scipy's linregress raises ValueError when all x values are identical,
    # so the fit line is only drawn when the actual values vary.
    if actual_values.max() > actual_values.min():
        slope, intercept, *_ = stats.linregress(actual_values, forecasted_values)
        ax_scatter.plot(line_1_1, slope * line_1_1 + intercept, 'g-', label='Fit line')

    ax_scatter.set_xlabel("Actual")
    ax_scatter.set_ylabel("Forecasted")

    title_text = f"{country_name} (Lat: {lat:.2f}, Lon: {lon:.2f})\n$R^2 = {r2:.4f}$"
    ax_scatter.set_title(title_text)

    ax_scatter.legend()
    ax_scatter.grid(True)
    ax_scatter.set_aspect('equal', 'box')

    # --- 5. Create time series plot ---
    ax_timeseries.plot(dates_slice, actual_values, label='Actual', color='royalblue')
    ax_timeseries.plot(dates_slice, forecasted_values, label='Forecasted', color='red')
    ax_timeseries.set_xlabel("Date")
    ax_timeseries.set_ylabel("Drought (SPI)")
    ax_timeseries.legend(loc='upper right')
    ax_timeseries.grid(True)
    plt.setp(ax_timeseries.get_xticklabels(), rotation=30, ha='right')


# =============================================================================
# == SCRIPT EXECUTION =========================================================
# =============================================================================

if __name__ == "__main__":

    # --- 1. Define paths and pixel locations (lat, lon) per country ---
    IMPUTED_DATA_PATHS = {
        "Gambia": "/kaggle/input/imputed-data/Gambia,_The_combined_imputed.npz",
        "Guinea-Bissau": "/kaggle/input/imputed-data/Guinea-Bissau_combined_imputed.npz"
    }

    REFERENCE_GEOTIFFS = {
        "Gambia": "/kaggle/input/ndvi-month/NDVI_Images/Gambia,_The/NDVI_Gambia,_The_2000_02.tif",
        "Guinea-Bissau": "/kaggle/input/ndvi-month/NDVI_Images/Guinea-Bissau/NDVI_Guinea-Bissau_2000_02.tif",
    }

    PREDICTION_NPZ_PATHS = {
        "Gambia": "/kaggle/input/gambia-predictions/results_gat_transformer/spi_gat_transformer_predictions.npz",
        "Guinea-Bissau": "/kaggle/input/gat-transformer-guinea-bissau/pytorch/trained-on-guinea-bissau/2/results_gat_transformer/spi_gat_transformer_predictions.npz"
    }

    PIXEL_LOCATIONS = {
        "Gambia": (13.4, -15.5),
        "Guinea-Bissau": (12.0, -15.0)
    }

    # --- 2. Setup plotting: one row (scatter | time series) per country ---
    # The row count follows PIXEL_LOCATIONS rather than a hard-coded 2, and
    # squeeze=False keeps `axes` 2-D even when only one country is listed.
    n_countries = len(PIXEL_LOCATIONS)
    fig, axes = plt.subplots(n_countries, 2, figsize=(12, 4.5 * n_countries), squeeze=False)
    fig.suptitle("Pixel-Level Analysis: Actual vs. Forecasted SPI", fontsize=16)

    # Monthly dates from Feb 2000 through Dec 2023 inclusive (287 months).
    DATA_START_YEAR, DATA_START_MONTH = 2000, 2
    DATA_END_YEAR, DATA_END_MONTH = 2023, 12
    n_months = (DATA_END_YEAR - DATA_START_YEAR) * 12 + (DATA_END_MONTH - DATA_START_MONTH) + 1
    all_dataset_dates = generate_dates(DATA_START_YEAR, DATA_START_MONTH, n_months)

    # --- 3. Loop and plot ---
    for i, (country_name, coords) in enumerate(PIXEL_LOCATIONS.items()):
        plot_pixel_analysis(
            axes[i, 0],
            axes[i, 1],
            country_name,
            coords,
            PREDICTION_NPZ_PATHS[country_name],
            REFERENCE_GEOTIFFS[country_name],
            IMPUTED_DATA_PATHS[country_name],
            all_dataset_dates
        )

    # Leave the top 5% of the figure for the suptitle.
    plt.tight_layout(rect=[0, 0, 1, 0.95])

    # --- 4. Save final figure ---
    output_dir = "/kaggle/working/"
    os.makedirs(output_dir, exist_ok=True)
    final_plot_path = os.path.join(output_dir, "pixel_analysis_summary_Gambia_GB.png")
    fig.savefig(final_plot_path, dpi=300)
    plt.show()

    print(f"\nPixel analysis summary saved to: {final_plot_path}")