In [None]:
import numpy as np
import os
from tqdm import tqdm

def impute_feature_array(feature_data, overall_mask, feature_name):
    """
    Imputes NaN values in a feature array using spatial means for each time step.

    For every time step, the mean of the non-NaN values found at masked pixels
    is written into the NaN positions that lie inside the mask. Pixels outside
    the mask are left untouched (any NaNs there are kept). When a time step has
    no usable masked value at all, 0.0 is used as the fill value.

    Args:
        feature_data (np.array): The raw feature data (time, height, width).
        overall_mask (np.array): A 2D mask (height, width); truthy entries mark the
                                 pixels used for mean calculation and imputation.
        feature_name (str): Name of the feature for printing progress.

    Returns:
        np.array: A copy of feature_data with NaNs imputed within the masked areas.
    """
    print(f"Starting imputation for {feature_name}...")
    imputed_data = np.copy(feature_data)  # Work on a copy; the input is never modified
    # Coerce to a genuine boolean mask: an integer 0/1 mask would otherwise be
    # interpreted by numpy as positional indices instead of a pixel selection.
    overall_mask = np.asarray(overall_mask, dtype=bool)

    for t in tqdm(range(imputed_data.shape[0]), desc=f"Imputing {feature_name}"):
        current_slice = imputed_data[t]  # View into imputed_data, shape (height, width)
        masked_values = current_slice[overall_mask]

        # Decide explicitly whether any valid value exists. Calling np.nanmean
        # on an all-NaN (or empty) selection would emit a RuntimeWarning first.
        if np.isnan(masked_values).all():
            mean_for_imputation = 0.0
        else:
            mean_for_imputation = np.nanmean(masked_values)
            # Still possible, e.g. when +inf and -inf are both present.
            if np.isnan(mean_for_imputation):
                mean_for_imputation = 0.0

        # Only NaNs that are inside the mask get filled.
        nan_locations_to_impute = np.isnan(current_slice) & overall_mask
        # Assigning through the view updates imputed_data[t] in place.
        current_slice[nan_locations_to_impute] = mean_for_imputation

    print(f"Imputation for {feature_name} complete.")
    return imputed_data

def main_offline_imputation(raw_data_path, output_data_path):
    """
    Loads raw data, performs offline imputation for features, and saves the processed data.

    The raw .npz archive must provide the arrays 'NDVI', 'SoilMoisture', 'LST'
    and 'SPI'. If raw_data_path does not exist, a small random dataset stored as
    "dummy_raw_data.npz" in the current directory is created (or reused) so the
    workflow can still be exercised end to end.

    The output archive contains 'NDVI_imputed', 'SoilMoisture_imputed',
    'LST_imputed', 'SPI_original' (SPI is not imputed) and 'mask'.

    Args:
        raw_data_path (str): Path to the raw .npz archive.
        output_data_path (str): Destination path of the imputed .npz archive.

    Returns:
        None. Returns early, without writing anything, when the mask selects no pixel.
    """
    if not os.path.exists(raw_data_path):
        print(f"Error: Raw data file not found at {raw_data_path}")
        # Fall back to a dummy dataset so the pipeline can be tested without real data.
        raw_data_path = "dummy_raw_data.npz"
        if not os.path.exists(raw_data_path):
            print("Creating dummy raw data for testing offline imputation...")
            T, H, W = 50, 10, 10  # Smaller for quick test
            dummy_ndvi = np.random.rand(T, H, W) * 2 - 0.5  # Values in [-0.5, 1.5)
            dummy_sm = np.random.rand(T, H, W)
            dummy_lst = np.random.rand(T, H, W) * 50 + 273.15  # Temperature range
            dummy_spi = np.random.randn(T, H, W) * 2

            # Introduce roughly 10% NaNs into features and SPI
            for arr in [dummy_ndvi, dummy_sm, dummy_lst, dummy_spi]:
                nan_indices = np.random.choice([True, False], size=arr.shape, p=[0.1, 0.9])
                arr[nan_indices] = np.nan

            # Make the first row of pixels NaN at every time step so that the
            # resulting mask is guaranteed to contain some False entries.
            dummy_spi[:, 0, :] = np.nan

            np.savez_compressed(raw_data_path, NDVI=dummy_ndvi, SoilMoisture=dummy_sm, LST=dummy_lst, SPI=dummy_spi)
            print(f"Using dummy raw data from {raw_data_path}")
        else:
            print(f"Using existing dummy raw data from {raw_data_path}")

    print(f"Loading raw data from: {raw_data_path}")
    # Use the archive as a context manager so the underlying file handle is
    # closed; indexing an NpzFile reads each array fully into memory.
    with np.load(raw_data_path) as data:
        original_ndvi = data['NDVI']
        original_sm = data['SoilMoisture']
        original_lst = data['LST']
        original_spi = data['SPI']  # SPI (target) is not imputed here, kept as is.

    print("Original data shapes:")
    print(f"NDVI: {original_ndvi.shape}")
    print(f"SoilMoisture: {original_sm.shape}")
    print(f"LST: {original_lst.shape}")
    print(f"SPI: {original_spi.shape}")

    # A pixel belongs to the mask unless its SPI is NaN at EVERY time step,
    # i.e. the mask keeps every pixel that has at least one valid SPI value.
    overall_mask = ~np.isnan(original_spi).all(axis=0)
    print(f"Overall mask shape: {overall_mask.shape}, Total valid pixels in mask: {np.sum(overall_mask)}")

    if not overall_mask.any():
        print("Error: The overall mask has no valid pixels. Check your SPI data or mask definition.")
        return

    # Perform imputation
    imputed_ndvi = impute_feature_array(original_ndvi, overall_mask, "NDVI")
    imputed_sm = impute_feature_array(original_sm, overall_mask, "SoilMoisture")
    imputed_lst = impute_feature_array(original_lst, overall_mask, "LST")

    # Verify no NaNs remain in the imputed features within the masked area
    for feature_arr, name in zip([imputed_ndvi, imputed_sm, imputed_lst], ["NDVI", "SM", "LST"]):
        has_nans_in_mask = np.isnan(feature_arr[:, overall_mask]).any()
        print(f"NaNs in imputed {name} (within mask): {has_nans_in_mask}")
        if has_nans_in_mask:
            print(f"Warning: Imputed {name} still contains NaNs within the masked area. Review imputation logic.")

    # Ensure the output directory exists (dirname is '' for a bare file name)
    output_dir = os.path.dirname(output_data_path)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
        print(f"Created output directory: {output_dir}")

    print(f"Saving imputed data to: {output_data_path}")
    np.savez_compressed(
        output_data_path,
        NDVI_imputed=imputed_ndvi,
        SoilMoisture_imputed=imputed_sm,
        LST_imputed=imputed_lst,
        SPI_original=original_spi,  # Save original SPI
        mask=overall_mask,          # Save the mask used for imputation and future loading
    )
    print("Offline imputation complete. Processed data saved.")

if __name__ == '__main__':
    # --- Configuration for the imputation script ---
    # Path to the raw .npz archive to impute; replace it with your own file.
    # The value below points at the Guinea-Bissau combined dataset on Kaggle.
    RAW_DATA_NPZ_PATH = "/kaggle/input/multivariate-time-series-12/time_series/Guinea-Bissau/Guinea-Bissau_combined.npz"
    
    # Define where the imputed data will be saved
    IMPUTED_DATA_NPZ_PATH = "/kaggle/working/Guinea-Bissau_combined_imputed.npz" # Saving in working directory

    # Example for Wuli data (if you have a separate file for it)
    # RAW_DATA_NPZ_PATH_WULI = "/path/to/your/Wuli_combined_raw.npz"
    # IMPUTED_DATA_NPZ_PATH_WULI = "/path/to/your/Wuli_combined_imputed.npz"

    # Runs the whole load -> impute -> save pipeline for the paths above.
    main_offline_imputation(RAW_DATA_NPZ_PATH, IMPUTED_DATA_NPZ_PATH)
    
    # If you want to process another dataset like Wuli, call main_offline_imputation again:
    # print("\nProcessing Wuli data (example)...")
    # main_offline_imputation(RAW_DATA_NPZ_PATH_WULI, IMPUTED_DATA_NPZ_PATH_WULI)

In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import math
import joblib

# --- Configuration ---
# Label embedded in log messages and in the names of output files/directories.
MODEL_NAME = "GAT_Transformer" # Reflects the changes
# Number of past time steps in each input sequence; the model also sizes its
# positional-encoding table from it (LOOKBACK + 10).
LOOKBACK = 24
# Mini-batch size — presumably for the DataLoaders; their construction is not
# visible in this part of the file (TODO confirm).
BATCH_SIZE = 128 
# Optimizer step size — presumably passed to the optimizer elsewhere (TODO confirm).
LEARNING_RATE = 0.001
# Upper bound on training epochs and early-stopping patience — presumably
# forwarded to train_model_gat_based by the main routine (TODO confirm).
NUM_EPOCHS = 70
PATIENCE = 20
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Updated load_and_preprocess_data for Sequence Models (from pre-imputed data) ---
def load_preprocessed_data_for_sequence(imputed_file_path, lookback_period):
    """
    Build per-pixel lookback sequences and SPI targets from a pre-imputed archive.

    The archive must contain 'NDVI_imputed', 'SoilMoisture_imputed',
    'LST_imputed', 'SPI_original' and 'mask' (the keys written by the offline
    imputation step). For every time step t >= lookback_period and every masked
    pixel whose SPI at t is not NaN, one sample is produced: the features of the
    lookback_period steps preceding t (t itself excluded) and the SPI value at t.

    Args:
        imputed_file_path (str): Path to the pre-imputed .npz archive.
        lookback_period (int): Number of past time steps per sample.

    Returns:
        tuple: (X, y, mask, valid_sample_indices, original_data_info) where
            X has shape (n_samples, lookback_period, 3) with the feature order
              [NDVI, LST, SoilMoisture];
            y has shape (n_samples, 1);
            mask is the 2D mask loaded from the archive;
            valid_sample_indices is a list of (t, pixel_idx_in_mask) tuples,
              ordered by t and then by pixel index;
            original_data_info is {'SPI': ..., 'NDVI': ...} with the full rasters.
    """
    print(f"--- {MODEL_NAME}: Loading pre-imputed data from: {imputed_file_path} ---")
    # Context manager closes the archive; each indexed array is read fully into memory.
    with np.load(imputed_file_path) as data_imputed:
        # These keys match the offline imputation script's output
        ndvi_data = data_imputed['NDVI_imputed']
        soil_moisture_data = data_imputed['SoilMoisture_imputed']
        lst_data = data_imputed['LST_imputed']
        spi_data = data_imputed['SPI_original']  # Original SPI for targets
        mask = data_imputed['mask']  # The mask used during imputation

    print(f"Loaded imputed feature shapes: NDVI: {ndvi_data.shape}, SM: {soil_moisture_data.shape}, LST: {lst_data.shape}")
    print(f"Loaded SPI_original shape: {spi_data.shape}, Mask shape: {mask.shape}, Valid pixels in mask: {np.sum(mask)}")

    time_steps, _, _ = ndvi_data.shape

    # Gather the masked pixels once: (time, n_masked_pixels, 3) in the feature
    # order NDVI, LST, SoilMoisture.
    features_masked = np.stack(
        [ndvi_data[:, mask], lst_data[:, mask], soil_moisture_data[:, mask]],
        axis=-1,
    )
    spi_masked = spi_data[:, mask]  # (time, n_masked_pixels)

    X_chunks, y_chunks, valid_sample_indices_output = [], [], []
    for t in range(lookback_period, time_steps):
        # TARGET NaN REMOVAL: only pixels with a valid SPI at time t yield a sample.
        valid_pixels = np.flatnonzero(~np.isnan(spi_masked[t]))
        if valid_pixels.size == 0:
            continue
        # Slice the whole window for all valid pixels at once instead of looping
        # over pixels and steps in Python: (lookback, n_valid, 3) -> (n_valid, lookback, 3).
        window = features_masked[t - lookback_period:t][:, valid_pixels]
        X_chunks.append(np.swapaxes(window, 0, 1))
        y_chunks.append(spi_masked[t, valid_pixels])
        valid_sample_indices_output.extend((t, pixel_idx) for pixel_idx in valid_pixels.tolist())

    if X_chunks:
        X = np.concatenate(X_chunks, axis=0)
        y = np.concatenate(y_chunks).reshape(-1, 1)
    else:
        # Keep well-formed shapes even when no sample exists; callers can still
        # test X.shape[0] == 0.
        X = np.empty((0, max(lookback_period, 0), 3), dtype=features_masked.dtype)
        y = np.empty((0, 1), dtype=spi_masked.dtype)

    print(f"Final X shape ({MODEL_NAME}): {X.shape}, y shape: {y.shape}, Valid samples: {len(valid_sample_indices_output)}")
    original_data_info = {'SPI': spi_data, 'NDVI': ndvi_data}
    return X, y, mask, valid_sample_indices_output, original_data_info

# --- Chronological Train-Test Split ---
def chronological_train_test_split(X, y, valid_sample_indices, test_ratio=0.2):
    """
    Split samples into train and test sets along the time axis.

    The distinct time steps that own at least one sample are sorted, and the
    last int(n * test_ratio) of them (at least one when more than one time step
    exists) form the test period. With zero or one distinct time step every
    sample goes to the training set.

    Args:
        X: Sample array, indexed along its first axis.
        y: Target array aligned with X.
        valid_sample_indices: Sequence of (t, pixel_idx) pairs, one per sample.
        test_ratio: Fraction of distinct time steps reserved for testing.

    Returns:
        tuple: (X_train, X_test, y_train, y_test, test_sample_indices) where
        test_sample_indices holds the (t, pixel_idx) pairs of the test samples.
    """
    distinct_times = sorted({entry[0] for entry in valid_sample_indices})
    n_times = len(distinct_times)

    if n_times > 1:
        # At least one time step is always held out once there are two or more.
        n_test_times = int(n_times * test_ratio) or 1
        cutoff_time = distinct_times[-n_test_times]
    elif distinct_times:
        # A single time step: place the cutoff after it so nothing is tested.
        cutoff_time = distinct_times[-1] + 1
    else:
        cutoff_time = 0

    train_rows, test_rows = [], []
    for row, (sample_time, _) in enumerate(valid_sample_indices):
        if sample_time < cutoff_time:
            train_rows.append(row)
        else:
            test_rows.append(row)

    X_train, X_test = X[train_rows], X[test_rows]
    y_train, y_test = y[train_rows], y[test_rows]
    print(f"X_train shape ({MODEL_NAME}): {X_train.shape}")
    print(f"X_test shape ({MODEL_NAME}): {X_test.shape}")

    test_sample_indices_output = [valid_sample_indices[row] for row in test_rows]
    return X_train, X_test, y_train, y_test, test_sample_indices_output

# --- TimeSeriesDataset for GAT (Optimized Adjacency) ---
class TimeSeriesDatasetGAT(Dataset):
    """
    Dataset yielding (feature sequence, adjacency matrix, target) triples.

    Scaling rules, applied independently to features and targets:
      * is_train=True and no scaler given  -> a new StandardScaler is fitted;
      * is_train=False and a scaler given  -> the given scaler is applied;
      * any other combination              -> data is stored unscaled and the
        scaler argument is kept as-is (so a scaler passed together with
        is_train=True is NOT applied).

    Feature scalers work per feature channel: the array is flattened to
    (-1, n_features) for scaling and restored to its original shape afterwards.
    The same adjacency matrix is returned for every sample.
    """

    def __init__(self, features, targets, single_adj_matrix, scaler_X=None, scaler_y=None, is_train=True):
        feature_shape = features.shape

        # --- features ---
        if is_train and scaler_X is None:
            scaler_X = StandardScaler()
            flat_scaled = scaler_X.fit_transform(features.reshape(-1, feature_shape[-1]))
            features = flat_scaled.reshape(feature_shape)
        elif not is_train and scaler_X is not None:
            flat_scaled = scaler_X.transform(features.reshape(-1, feature_shape[-1]))
            features = flat_scaled.reshape(feature_shape)
        self.scaler_X = scaler_X

        # --- targets ---
        if is_train and scaler_y is None:
            scaler_y = StandardScaler()
            targets = scaler_y.fit_transform(targets)
        elif not is_train and scaler_y is not None:
            targets = scaler_y.transform(targets)
        self.scaler_y = scaler_y

        self.features = torch.tensor(features, dtype=torch.float32)
        self.targets = torch.tensor(targets, dtype=torch.float32)

        # One shared, pre-computed adjacency matrix, always held as a float32 tensor.
        if isinstance(single_adj_matrix, torch.Tensor):
            self.single_adj_matrix = single_adj_matrix.to(dtype=torch.float32)
        else:
            self.single_adj_matrix = torch.tensor(single_adj_matrix, dtype=torch.float32)

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.single_adj_matrix, self.targets[idx]

    def get_scalers(self):
        """Return the (feature scaler, target scaler) pair held by this dataset."""
        return self.scaler_X, self.scaler_y

# --- Positional Encoding ---
class PositionalEncoding(nn.Module):
    """
    Adds fixed sinusoidal positional encodings to a batch-first sequence.

    Even feature columns hold sin(position * frequency) and odd columns the
    matching cos(...). The table is stored as the non-trainable buffer 'pe'
    with shape (1, max_len, d_model).

    Args:
        d_model (int): Feature size of the input; odd values are supported.
        max_len (int): Longest sequence length the table can serve.
    """

    def __init__(self, d_model, max_len=5000): # max_len can be LOOKBACK + buffer
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so trim the frequencies to the number of odd slots.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x): # x: (batch, seq_len, d_model)
        """
        Return x with the encodings of its first seq_len positions added.

        Raises:
            RuntimeError: If the sequence is longer than max_len, or if the
                feature size of x does not match d_model.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            # Fail with a clear message instead of an opaque broadcasting
            # error (or a silent broadcast when max_len == 1).
            raise RuntimeError(
                f"PositionalEncoding: sequence length {seq_len} exceeds max_len "
                f"{self.pe.size(1)} (x.shape={tuple(x.shape)})"
            )
        return x + self.pe[:, :seq_len]


# --- Graph Attention Layer & MultiHeadGraphAttention ---
class GraphAttentionLayer(nn.Module):
    """
    Single-head graph attention over batched node features.

    Node features are projected with W, a score is computed for every ordered
    node pair from the concatenated projections, pairs without an edge
    (adjacency entry <= 0) are masked out, and each node aggregates the
    projections of its neighbours weighted by the softmax of the scores.

    Args:
        in_features (int): Size of the incoming node features.
        out_features (int): Size of the projected node features.
        dropout (float): Dropout probability applied to the attention weights.
        alpha (float): Negative slope of the LeakyReLU used on the raw scores.
    """

    def __init__(self, in_features, out_features, dropout=0.2, alpha=0.2):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.dropout_val = dropout
        self.alpha = alpha

        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.a = nn.Parameter(torch.empty(size=(2 * out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)
        self.dropout_layer = nn.Dropout(self.dropout_val)

    def forward(self, h, adj): # h is (B, N, F_in), adj is (N,N) - single adj matrix passed from Dataset
        projected = torch.matmul(h, self.W)
        pair_features = self._prepare_attention_input(projected)
        raw_scores = self.leakyrelu(torch.matmul(pair_features, self.a).squeeze(-1))

        # raw_scores is (B, N, N); adj gets a leading axis so it broadcasts over the batch.
        # Non-edges receive a huge negative score and vanish in the softmax.
        masked_out = torch.ones_like(raw_scores) * -9e15
        edge_scores = torch.where(adj.unsqueeze(0) > 0, raw_scores, masked_out)

        weights = self.dropout_layer(F.softmax(edge_scores, dim=-1))
        aggregated = torch.matmul(weights, projected)
        return F.elu(aggregated)

    def _prepare_attention_input(self, Wh):
        """Return the (B, N, N, 2*F) tensor whose entry [b, i, j] is [Wh[b, i], Wh[b, j]]."""
        batch, nodes, _ = Wh.size()
        row_side = Wh.unsqueeze(2).expand(batch, nodes, nodes, -1)
        col_side = Wh.unsqueeze(1).expand(batch, nodes, nodes, -1)
        return torch.cat([row_side, col_side], dim=-1)

class MultiHeadGraphAttention(nn.Module):
    """
    Runs several GraphAttentionLayer heads in parallel on the same input.

    With concat=True the head outputs are concatenated along the feature axis
    (out_dim = out_features_per_head * n_heads); otherwise they are averaged
    (out_dim = out_features_per_head).
    """

    def __init__(self, in_features, out_features_per_head, n_heads, dropout=0.2, alpha=0.2, concat=True):
        super().__init__()
        self.n_heads = n_heads
        self.concat = concat

        heads = []
        for _ in range(n_heads):
            heads.append(GraphAttentionLayer(in_features, out_features_per_head, dropout, alpha))
        self.attentions = nn.ModuleList(heads)

        if concat:
            self.out_dim = out_features_per_head * n_heads
        else:
            self.out_dim = out_features_per_head

    def forward(self, x, adj): # x is (B,N,F_in), adj is (N,N)
        # Every head receives the same adjacency matrix.
        per_head = [head(x, adj) for head in self.attentions]
        if not self.concat:
            return torch.stack(per_head, dim=-1).mean(dim=-1)
        return torch.cat(per_head, dim=-1)

# --- GAT+Transformer Model ---
class SpiPredictorGATTransformer(nn.Module):
    """
    SPI regressor: graph attention over the time steps followed by a Transformer encoder.

    Pipeline: linear input projection -> multi-head graph attention (the nodes
    are the positions of the input sequence) -> sinusoidal positional encoding
    -> Transformer encoder -> fully connected head applied to the last
    sequence position, producing one value per sample.

    Args:
        input_size (int): Number of features per time step.
        d_model_gat (int): Width of the input projection fed to the GAT.
        gat_heads (int): Number of attention heads in the GAT; each head
            outputs d_model_gat // gat_heads features.
        d_model_transformer (int): Accepted for interface compatibility but
            not used; the Transformer width is the GAT output width.
        transformer_heads (int): Heads per Transformer encoder layer.
        num_transformer_layers (int): Number of Transformer encoder layers.
        dim_feedforward_transformer (int): Hidden size of the Transformer feed-forward blocks.
        dropout_rate (float): Dropout probability used throughout the model.
    """

    def __init__(self, input_size, d_model_gat=64, gat_heads=4, 
                 d_model_transformer=64, transformer_heads=4, num_transformer_layers=2, 
                 dim_feedforward_transformer=256, dropout_rate=0.2):
        super(SpiPredictorGATTransformer, self).__init__()
        self.dropout_rate = dropout_rate
        self.gat_input_proj = nn.Linear(input_size, d_model_gat)
        self.gat_attention = MultiHeadGraphAttention(
            in_features=d_model_gat,
            out_features_per_head=d_model_gat // gat_heads,  # Integer division: assumes divisibility
            n_heads=gat_heads, dropout=dropout_rate, concat=True
        )
        # All downstream widths follow the real GAT output size, so they stay
        # consistent even when d_model_gat is not a multiple of gat_heads.
        gat_output_dim = self.gat_attention.out_dim
        self.pos_encoder = PositionalEncoding(gat_output_dim, max_len=LOOKBACK + 10)
        transformer_encoder_layer = nn.TransformerEncoderLayer(
            d_model=gat_output_dim, nhead=transformer_heads, dim_feedforward=dim_feedforward_transformer,
            dropout=dropout_rate, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(transformer_encoder_layer, num_layers=num_transformer_layers)
        # Regression head: gat_output_dim -> 64 -> 32 -> 1
        self.fc = nn.Sequential(
            nn.Linear(gat_output_dim, 64),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(32, 1)
        )

    def forward(self, x, adj): # x: (B,L,F_in), adj: (L,L) - adj is NOT batched here
        """
        Predict one value per sample.

        Args:
            x: Input sequences, (batch, seq_len, input_size).
            adj: Adjacency over the sequence positions, (seq_len, seq_len).

        Returns:
            Tensor of shape (batch, 1).
        """
        x_proj = F.elu(self.gat_input_proj(x))
        x_gat_drop = F.dropout(x_proj, self.dropout_rate, training=self.training)
        x_gat_attended = self.gat_attention(x_gat_drop, adj)

        # If a batched adjacency (B, L, L) is passed by mistake, the mask inside
        # the GAT layers broadcasts to (1, B, L, L) and the output gains a
        # leading singleton axis. Drop it for ANY batch size, not only for full
        # batches, so the last (smaller) batch of an epoch works as well.
        if x_gat_attended.dim() == 4 and x_gat_attended.shape[0] == 1:
            x_gat_attended = x_gat_attended.squeeze(0)

        x_pos_encoded = self.pos_encoder(x_gat_attended)
        x_transformed = self.transformer_encoder(x_pos_encoded)
        x_last_step = x_transformed[:, -1, :]  # Representation of the final sequence position
        output = self.fc(x_last_step)
        return output

# --- Training Function for GAT-based models (handles single adjacency matrix) ---
def train_model_gat_based(model, train_loader, val_loader, criterion, optimizer, num_epochs, patience, device, model_save_name):
    """
    Train a GAT-based model with early stopping on the validation loss.

    Both loaders must yield (features, adjacency, targets) batches. The
    adjacency arrives stacked along the batch axis; since every sample carries
    the same matrix, only the first one of each batch is passed to the model.

    The weights are written to model_save_name whenever the validation loss
    improves. After training, the best weights saved DURING THIS CALL are
    loaded back into the model; a file left at that path by an earlier run is
    never loaded.

    Args:
        model: Module called as model(batch_X, adj).
        train_loader, val_loader: Iterables of batches supporting len().
        criterion: Loss function called as criterion(outputs, targets).
        optimizer: Optimizer updating the parameters of model.
        num_epochs (int): Maximum number of epochs.
        patience (int): Epochs without validation improvement before stopping.
        device: Device the batches are moved to.
        model_save_name (str): Path of the checkpoint file.

    Returns:
        tuple: (model, train_losses_history, val_losses_history), the histories
        holding one mean batch loss per completed epoch.

    Raises:
        ZeroDivisionError: If a loader yields no batches.
    """
    best_val_loss = float('inf')
    epochs_no_improve = 0
    checkpoint_saved = False  # True once this call has written model_save_name
    train_losses_history, val_losses_history = [], []

    for epoch in range(num_epochs):
        model.train()
        epoch_train_loss = 0
        for batch_X, batch_adj_single, batch_y in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]"):
            batch_X, batch_adj_single, batch_y = batch_X.to(device), batch_adj_single.to(device), batch_y.to(device)
            # The DataLoader stacks the per-sample adjacency into (B, N, N); all
            # copies are identical, so the first one is enough.
            outputs = model(batch_X, batch_adj_single[0])
            loss = criterion(outputs, batch_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_train_loss += loss.item()
        epoch_train_loss /= len(train_loader)
        train_losses_history.append(epoch_train_loss)

        model.eval()
        epoch_val_loss = 0
        with torch.no_grad():
            for batch_X, batch_adj_single, batch_y in val_loader:
                batch_X, batch_adj_single, batch_y = batch_X.to(device), batch_adj_single.to(device), batch_y.to(device)
                outputs = model(batch_X, batch_adj_single[0])
                epoch_val_loss += criterion(outputs, batch_y).item()
        epoch_val_loss /= len(val_loader)
        val_losses_history.append(epoch_val_loss)

        print(f"Epoch {epoch+1}: Train Loss={epoch_train_loss:.4f}, Val Loss={epoch_val_loss:.4f}")
        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), model_save_name)
            checkpoint_saved = True
        else:
            epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(f"Early stopping @ epoch {epoch+1}")
            break

    # Restore the best weights only if they come from this run. Checking merely
    # for the file's existence could load a stale checkpoint (for example when
    # every validation loss was NaN and nothing was saved).
    if checkpoint_saved and os.path.exists(model_save_name):
        model.load_state_dict(torch.load(model_save_name))
    return model, train_losses_history, val_losses_history

# --- Evaluation Function for GAT-based models ---
def evaluate_model_gat_based(model, test_loader, scaler_y, device, model_name="GAT-based Model"):
    """
    Run the model over test_loader and report regression metrics.

    Predictions and targets are collected in the loader's (scaled) space and,
    when scaler_y is truthy, mapped back with scaler_y.inverse_transform before
    the metrics are computed.

    Args:
        model: Module called as model(batch_X, adj).
        test_loader: Iterable of (features, stacked adjacency, targets) batches.
        scaler_y: Target scaler providing inverse_transform, or None.
        device: Device the inputs are moved to.
        model_name (str): Label used in the printed summary.

    Returns:
        tuple: (predictions, targets, mse, rmse, r2, mae).
    """
    model.eval()
    scaled_predictions = []
    scaled_targets = []
    with torch.no_grad():
        for batch_X, batch_adj_single, batch_y_s in test_loader:
            batch_X = batch_X.to(device)
            batch_adj_single = batch_adj_single.to(device)
            # Every sample carries the same adjacency; hand over the first one.
            batch_out = model(batch_X, batch_adj_single[0])
            scaled_predictions.append(batch_out.cpu().numpy())
            scaled_targets.append(batch_y_s.numpy())

    preds_o = np.vstack(scaled_predictions)
    tgts_o = np.vstack(scaled_targets)
    if scaler_y:
        preds_o = scaler_y.inverse_transform(preds_o)
        tgts_o = scaler_y.inverse_transform(tgts_o)

    mse = mean_squared_error(tgts_o, preds_o)
    rmse = np.sqrt(mse)
    r2 = r2_score(tgts_o, preds_o)
    mae = mean_absolute_error(tgts_o, preds_o)
    print(f"\nEvaluation metrics ({model_name}):\n MSE: {mse:.4f}, RMSE: {rmse:.4f}, MAE: {mae:.4f}, R²: {r2:.4f}")
    return preds_o, tgts_o, mse, rmse, r2, mae

# --- Plotting and Reconstruction (Definitions from Common Helper Functions section) ---
def plot_training_losses(train_losses, val_losses, model_name, save_dir='visualizations', y_limits=(0, 0.3)):
    """
    Plot training and validation loss curves, save the figure and show it.

    The final value of each curve is written next to its last point. The image
    is stored as 'training_losses_<model_name>.png' inside save_dir.

    Args:
        train_losses (list): Training loss per epoch.
        val_losses (list): Validation loss per epoch.
        model_name (str): Used in the title and the file name.
        save_dir (str): Directory for the image; created when missing.
        y_limits (tuple or None): (lower, upper) bounds of the y-axis. The
            default keeps the fixed 0 to 0.3 range; pass None to let
            matplotlib autoscale, e.g. when losses exceed 0.3.
    """
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, f'training_losses_{model_name}.png')
    fig = plt.figure(figsize=(10, 6))
    plt.plot(range(1, len(train_losses) + 1), train_losses, 'b-', label='Train Loss')
    plt.plot(range(1, len(val_losses) + 1), val_losses, 'r-', label='Validation Loss')

    # Annotate the last value of each curve at the x position of the final
    # training epoch. Empty histories (e.g. zero epochs) are skipped instead
    # of failing on the [-1] lookup.
    last_epoch = len(train_losses)
    if len(train_losses) > 0:
        last_train_loss = train_losses[-1]
        plt.text(last_epoch, last_train_loss, f' {last_train_loss:.4f}',
                 color='blue', verticalalignment='center')
    if len(val_losses) > 0:
        last_val_loss = val_losses[-1]
        plt.text(last_epoch, last_val_loss, f' {last_val_loss:.4f}',
                 color='red', verticalalignment='center')

    if y_limits is not None:
        plt.ylim(*y_limits)

    plt.xlabel('Epochs'); plt.ylabel('Loss (MSE)'); plt.title(f'Training & Validation Losses ({model_name})')
    plt.legend(); plt.grid(True); plt.savefig(save_path); plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

def plot_all_test_samples(targets_original, predictions_original, model_name, save_dir='visualizations'):
    """
    Scatter predictions against true values, save the figure and show it.

    The plot contains the samples, the identity line y=x and, when more than
    one sample is available, a least-squares regression line together with the
    Pearson correlation. A text box lists MSE, RMSE, MAE, R² and the
    correlation. The image is stored as 'all_test_samples_<model_name>.png'
    inside save_dir.

    Args:
        targets_original: Array of true values.
        predictions_original: Array of predicted values, aligned with the targets.
        model_name (str): Used in the title and the file name.
        save_dir (str): Directory for the image; created when missing.
    """
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    save_path = os.path.join(save_dir, f'all_test_samples_{model_name}.png')

    plt.figure(figsize=(12, 8))
    plt.scatter(targets_original, predictions_original, alpha=0.3, color='blue', label='Samples')

    # Identity line spanning the joint value range of targets and predictions.
    lower = min(targets_original.min(), predictions_original.min())
    upper = max(targets_original.max(), predictions_original.max())
    plt.plot([lower, upper], [lower, upper], 'r--', label='y=x')

    flat_targets = targets_original.ravel()
    flat_predictions = predictions_original.ravel()
    corr = np.nan
    if len(flat_targets) > 1 and len(flat_predictions) > 1:
        coeffs = np.polyfit(flat_targets, flat_predictions, 1)
        fitted_line = np.poly1d(coeffs)
        ordered_targets = np.sort(flat_targets)
        plt.plot(ordered_targets, fitted_line(ordered_targets), 'g-',
                 label=f'Regression: y={coeffs[0]:.3f}x+{coeffs[1]:.3f}')
        corr = np.corrcoef(flat_targets, flat_predictions)[0, 1]

    plt.xlabel('True Values')
    plt.ylabel('Predictions')
    plt.title(f'Predictions vs True Values ({model_name}, r = {corr:.3f})')
    plt.legend()
    plt.grid(True)

    mse = mean_squared_error(targets_original, predictions_original)
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(targets_original, predictions_original)
    r2 = r2_score(targets_original, predictions_original)
    stats_text = (f"MSE: {mse:.4f}\nRMSE: {rmse:.4f}\nMAE: {mae:.4f}\nR²: {r2:.4f}\nCorr: {corr:.4f}")
    plt.figtext(0.15, 0.75, stats_text, bbox=dict(facecolor='white', alpha=0.8, boxstyle='round,pad=0.5'))

    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.savefig(save_path)
    plt.show()

def reconstruct_rasters(predictions_flat, test_sample_indices, original_mask, spi_full_shape):
    """
    Scatter flat per-sample predictions back into spatial rasters.

    Args:
        predictions_flat: Indexable of per-sample predictions; element i is
            read as predictions_flat[i][0].
        test_sample_indices: Sequence of (t, pixel_idx_in_mask) pairs aligned
            with predictions_flat.
        original_mask: 2D boolean mask (height, width); pixel_idx_in_mask is a
            position within the True pixels in the order given by np.where.
        spi_full_shape: Shape of the full SPI array; entries 1 and 2 are used
            as raster height and width.

    Returns:
        tuple: (rasters, unique_times) where rasters has shape
        (len(unique_times), height, width) and is NaN wherever no prediction
        exists, and unique_times is the sorted list of distinct time steps.
    """
    unique_test_time_indices_original = sorted({t for t, _ in test_sample_indices})
    # Constant-time lookup from a time step to its raster slot; list.index()
    # inside the loop would make the reconstruction quadratic.
    slot_of_time = {t: slot for slot, t in enumerate(unique_test_time_indices_original)}

    reconstructed_spatial_rasters = np.full(
        (len(unique_test_time_indices_original), spi_full_shape[1], spi_full_shape[2]), np.nan
    )
    valid_pixel_row_coords, valid_pixel_col_coords = np.where(original_mask)

    for i, (original_t, pixel_idx_in_mask) in enumerate(test_sample_indices):
        h_coord = valid_pixel_row_coords[pixel_idx_in_mask]
        w_coord = valid_pixel_col_coords[pixel_idx_in_mask]
        reconstructed_spatial_rasters[slot_of_time[original_t], h_coord, w_coord] = predictions_flat[i][0]
    return reconstructed_spatial_rasters, unique_test_time_indices_original

def visualize_raster_samples(real_rasters_for_test_timesteps, pred_rasters_for_test_timesteps, 
                             mask_for_visualization, unique_test_time_indices, model_name,
                             num_samples_to_plot=5, output_dir='visualizations'):
    """
    Plot true, predicted and difference SPI rasters for random test time steps.

    Up to num_samples_to_plot time steps are drawn at random without
    replacement. For each one a three-panel figure (True SPI, Predicted SPI,
    True - Predicted) is saved as
    'raster_sample_<model_name lowercased>_time_<t>.png' in output_dir and shown.

    Args:
        real_rasters_for_test_timesteps: True SPI rasters, one per test time step.
        pred_rasters_for_test_timesteps: Predicted rasters aligned with the true ones.
        mask_for_visualization: 2D boolean mask; pixels where it is False are hidden.
        unique_test_time_indices: Original time index of each raster.
        model_name (str): Used in titles and file names.
        num_samples_to_plot (int): Maximum number of figures to produce.
        output_dir (str): Directory for the images; created when missing.
    """
    # Import the colormap class from its documented module rather than through
    # plt.cm.colors, which is reachable only as an implementation detail.
    from matplotlib.colors import LinearSegmentedColormap

    os.makedirs(output_dir, exist_ok=True)
    prefix = f'raster_sample_{model_name.lower()}'
    colors = ['#730000', '#E60000', '#FFAA00', '#FCD37F', '#FFFF00', '#FFFFFF', '#DCF8FF', '#96D2FF', '#46A5FF', '#0000FF', '#000080']
    cmap_usdm_spi = LinearSegmentedColormap.from_list('usdm_spi', colors, N=256)
    vmin_spi, vmax_spi = -3, 3

    num_available_rasters = len(real_rasters_for_test_timesteps)
    if num_available_rasters == 0:
        print("No rasters to visualize.")
        return

    plot_indices = np.random.choice(num_available_rasters, min(num_samples_to_plot, num_available_rasters), replace=False)
    hidden_pixels = ~mask_for_visualization
    for raster_idx_in_test_set in plot_indices:
        true_raster_slice = real_rasters_for_test_timesteps[raster_idx_in_test_set]
        pred_raster_slice = pred_rasters_for_test_timesteps[raster_idx_in_test_set]
        original_time = unique_test_time_indices[raster_idx_in_test_set]

        # The difference is defined only where both rasters hold a value.
        diff_raster_slice = np.full_like(true_raster_slice, np.nan)
        valid_pixels_for_diff = ~np.isnan(true_raster_slice) & ~np.isnan(pred_raster_slice)
        diff_raster_slice[valid_pixels_for_diff] = true_raster_slice[valid_pixels_for_diff] - pred_raster_slice[valid_pixels_for_diff]
        # Symmetric colour range around zero; fall back to 1.0 when nothing is comparable.
        diff_abs_max = np.nanmax(np.abs(diff_raster_slice)) if not np.all(np.isnan(diff_raster_slice)) else 1.0

        fig, axes = plt.subplots(1, 3, figsize=(18, 6))
        fig.suptitle(f'SPI Comparison ({model_name}) - Original Time Step: {original_time}', fontsize=16)

        # (raster, colormap, vmin, vmax, colorbar label, panel title)
        panels = [
            (true_raster_slice, cmap_usdm_spi, vmin_spi, vmax_spi, 'SPI Value', 'True SPI'),
            (pred_raster_slice, cmap_usdm_spi, vmin_spi, vmax_spi, 'SPI Value', 'Predicted SPI'),
            (diff_raster_slice, 'RdBu_r', -diff_abs_max, diff_abs_max, 'Difference (True - Predicted)', 'Difference'),
        ]
        for ax, (raster, cmap, vmin, vmax, bar_label, title) in zip(axes, panels):
            image = ax.imshow(np.ma.array(raster, mask=hidden_pixels), cmap=cmap, vmin=vmin, vmax=vmax)
            fig.colorbar(image, ax=ax, label=bar_label, fraction=0.046, pad=0.04)
            ax.set_title(title)
            ax.axis('off')

        plt.tight_layout(rect=[0, 0, 1, 0.95])
        save_filename = os.path.join(output_dir, f'{prefix}_time_{original_time}.png')
        plt.savefig(save_filename, dpi=200)
        plt.show()
        # Close each figure so a long loop does not accumulate open figures.
        plt.close(fig)


# --- Main Function (GAT+Transformer) ---
def _write_dummy_imputed_file(file_path, time_steps=50, height=10, width=10):
    """Write a small synthetic pre-imputed dataset to ``file_path`` as a compressed .npz.

    The archive holds the keys ``NDVI_imputed``, ``SoilMoisture_imputed``,
    ``LST_imputed``, ``SPI_original`` and ``mask``. It exists only so the
    workflow can be smoke-tested when the real pre-imputed file is missing.

    Args:
        file_path (str): Destination path of the .npz archive.
        time_steps (int): Size of the leading (time) axis of every raster stack.
        height (int): Raster height.
        width (int): Raster width.
    """
    shape = (time_steps, height, width)
    ndvi = np.random.rand(*shape)
    soil_moisture = np.random.rand(*shape)
    lst = np.random.rand(*shape) + 273.15
    spi = np.random.randn(*shape) * 2
    spi[np.random.rand(*shape) < 0.05] = np.nan  # sprinkle ~5% missing targets

    # A pixel is valid if it has at least one non-NaN SPI value over time.
    mask = ~np.isnan(spi).all(axis=0)
    spi[:, ~mask] = np.nan

    for feature in (ndvi, soil_moisture, lst):
        feature[:, ~mask] = np.nan
        for t_idx in range(time_steps):
            time_slice = feature[t_idx]  # view: writes land in `feature`
            gaps = np.isnan(time_slice) & mask
            if not gaps.any():
                continue  # nothing to fill; also avoids nanmean on an empty selection
            fill_value = np.nanmean(time_slice[mask])
            # Fall back to 0.0 when every masked pixel of the slice is NaN.
            time_slice[gaps] = 0.0 if np.isnan(fill_value) else fill_value

    np.savez_compressed(
        file_path,
        NDVI_imputed=ndvi,
        SoilMoisture_imputed=soil_moisture,
        LST_imputed=lst,
        SPI_original=spi,
        mask=mask,
    )


def main_gat_transformer():
    """Run the full GAT+Transformer pipeline: load data, train, evaluate, save artifacts.

    Steps, in order:
      1. Create ``results_<model>/visualizations``.
      2. Load the pre-imputed .npz file; if it is missing, generate a dummy
         file so the workflow can still be exercised.
      3. Split chronologically into train/validation and test sets, then split
         train/validation 80/20 with a fixed seed.
      4. Build the model, train it, and plot the training losses.
      5. If a test set exists, evaluate, reconstruct prediction rasters, save
         them and plot comparisons against the original SPI rasters.
      6. Dump the feature and target scalers with joblib.

    Returns early (after printing a message) when there are no usable samples
    or not enough data for a train/validation split.
    """
    print(f"--- Running {MODEL_NAME} Model ---")
    print(f"Using device: {DEVICE}")
    model_slug = MODEL_NAME.lower().replace('+', '_')  # '+' is not filename-safe
    output_dir = f"results_{model_slug}"
    viz_dir = os.path.join(output_dir, "visualizations")
    os.makedirs(viz_dir, exist_ok=True)

    imputed_file_path = "/kaggle/working/Guinea-Bissau_combined_imputed.npz"  # Adjust this path

    if not os.path.exists(imputed_file_path):
        print(f"ERROR: Pre-imputed data file not found at {imputed_file_path}. Please run the offline imputation script first.")
        print("Creating dummy pre-imputed data for workflow testing...")
        imputed_file_path = f"dummy_imputed_data_{model_slug}.npz"
        _write_dummy_imputed_file(imputed_file_path)
        print(f"Using dummy pre-imputed data from {imputed_file_path}")

    X_all, y_all, data_mask, valid_idx_all, orig_data_info = load_preprocessed_data_for_sequence(imputed_file_path, LOOKBACK)
    if X_all.shape[0] == 0:
        print("No valid samples. Exiting.")
        return
    X_tv, X_test, y_tv, y_test, test_idx_list = chronological_train_test_split(X_all, y_all, valid_idx_all, test_ratio=0.2)
    if X_tv.shape[0] == 0:
        print(f"No training/validation samples for {MODEL_NAME}. Exiting.")
        return

    # Single temporal adjacency matrix: each step is linked to itself and to
    # its immediate predecessor/successor (tridiagonal ones). It stays on the
    # CPU because it is only handed to the datasets.
    step_idx = torch.arange(LOOKBACK)
    adj_matrix_template = ((step_idx[:, None] - step_idx[None, :]).abs() <= 1).to(torch.float32)

    # Pinned memory and persistent workers only pay off when feeding a CUDA device.
    on_cuda = DEVICE.type == 'cuda'
    loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=on_cuda, persistent_workers=on_cuda)

    tv_dataset = TimeSeriesDatasetGAT(X_tv, y_tv, adj_matrix_template, is_train=True)
    scaler_X, scaler_y = tv_dataset.get_scalers()
    test_loader = None
    if X_test.shape[0] > 0:
        # The test set reuses the scalers obtained from the train/validation dataset.
        test_dataset = TimeSeriesDatasetGAT(X_test, y_test, adj_matrix_template, scaler_X, scaler_y, is_train=False)
        test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    train_s = int(0.8 * len(tv_dataset))
    val_s = len(tv_dataset) - train_s
    # int(0.8 * n) < n for every n >= 1, so val_s is only 0 when the dataset is
    # empty; a single check covers all degenerate cases.
    if train_s == 0 or val_s == 0:
        print("Not enough data for train/val split. Exiting.")
        return
    gen = torch.Generator().manual_seed(42)  # fixed seed for a reproducible split
    train_sub, val_sub = torch.utils.data.random_split(tv_dataset, [train_s, val_s], generator=gen)

    train_loader = DataLoader(train_sub, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_sub, shuffle=False, **loader_kwargs)

    model_input_size = X_tv.shape[-1]
    model = SpiPredictorGATTransformer(
        input_size=model_input_size, d_model_gat=64, gat_heads=4,
        d_model_transformer=64, transformer_heads=4, num_transformer_layers=2,
        dim_feedforward_transformer=256, dropout_rate=0.2
    ).to(DEVICE)
    print(f"\n{MODEL_NAME} Model Architecture:\n{model}")
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)

    model_save_path = os.path.join(output_dir, f"best_spi_{model_slug}_model.pt")
    trained_model, train_L, val_L = train_model_gat_based(model, train_loader, val_loader, criterion, optimizer, NUM_EPOCHS, PATIENCE, DEVICE, model_save_path)
    plot_training_losses(train_L, val_L, MODEL_NAME, save_dir=viz_dir)

    if test_loader is not None:
        preds_o, tgts_o, mse, rmse, r2, mae = evaluate_model_gat_based(trained_model, test_loader, scaler_y, DEVICE, MODEL_NAME)
        plot_all_test_samples(tgts_o, preds_o, MODEL_NAME, save_dir=viz_dir)
        spi_shape_o = orig_data_info['SPI'].shape
        recon_preds, unique_test_t = reconstruct_rasters(preds_o, test_idx_list, data_mask, spi_shape_o)
        np.savez_compressed(
            os.path.join(output_dir, f'spi_{model_slug}_predictions.npz'),
            predictions=recon_preds, mask=data_mask, time_steps=unique_test_t,
        )
        # Close the archive once the array has been read into memory.
        with np.load(imputed_file_path) as imputed_archive:
            full_spi_original = imputed_archive['SPI_original']
        true_test_rs = np.array([full_spi_original[t] for t in unique_test_t])
        visualize_raster_samples(true_test_rs, recon_preds, data_mask, unique_test_t, MODEL_NAME, output_dir=viz_dir)
    else:
        print(f"Test set for {MODEL_NAME} was empty. Skipping final evaluation.")

    joblib.dump(scaler_X, os.path.join(output_dir, f"scaler_X_{model_slug}.pkl"))
    joblib.dump(scaler_y, os.path.join(output_dir, f"scaler_y_{model_slug}.pkl"))
    print(f"{MODEL_NAME} model training complete. Results in {output_dir}")

# Run the pipeline only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main_gat_transformer()