In [1]:
# -*- coding: utf-8 -*-
"""GNN Model.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1pSO-VEyn5Cywjw9sXKn2fjXdVbI3hAsH
"""

# -*- coding: utf-8 -*-
"""Untitled2.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/193StgLnr4doKklAxwBiQsVX3njEfb1oa
"""


# Location of the raw weather workbook.
# NOTE(review): this is an absolute, user-specific Windows/OneDrive path, so the
# notebook is only runnable on the original author's machine as written.
# Consider pointing it at a path relative to a configurable data directory.
DATA_PATH = r"C:\Users\hu4227mo-s\OneDrive - Lund University\updated_data (4).xlsx"

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import time
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import os
from datetime import datetime, timedelta
import warnings
warnings.filterwarnings('ignore')

# Set device
# Module-level device used by the training, evaluation and visualization
# functions below when moving the model and batches.
device = torch.device('cpu')  # Using CPU as specified

# ============================================================
# MODIFIED DATA LOADING AND PROCESSING FOR TARGETED SEASONAL ANALYSIS
# ============================================================

def load_data(file_path):
    """
    Load weather data, add cyclical time features, keep the five study stations, and fill gaps.

    Args:
        file_path: Path (str or path-like) to the data file. A name ending in ``.xlsx``
            (case-insensitive) is read with openpyxl; anything else is read as an
            ISO-8859-1 encoded CSV.

    Returns:
        DataFrame ordered by ``timestamp`` with ``hour_sin``/``hour_cos``/``day_sin``/``day_cos``
        added. When a ``station_id`` column is present the frame is restricted to the five
        mapped stations and gains ``region``, ``region_code``, ``elevation`` and
        ``elevation_norm``. Missing numeric values are linearly interpolated within each
        station. An empty DataFrame is returned if the file cannot be read.

    Raises:
        KeyError: if the data has neither a ``timestamp`` nor a ``DATE`` column.
    """
    path_text = str(file_path)
    try:
        # Determine file type and read
        if path_text.lower().endswith('.xlsx'):
            df = pd.read_excel(path_text, engine='openpyxl')
        else:
            df = pd.read_csv(path_text, encoding='ISO-8859-1')
    except Exception as e:
        # Deliberate best-effort: callers receive an empty frame instead of an exception.
        print(f"Error loading data: {e}")
        return pd.DataFrame()  # Return empty DataFrame

    # Convert timestamp
    if 'timestamp' in df.columns:
        df['timestamp'] = pd.to_datetime(df['timestamp'])
    elif 'DATE' in df.columns:
        df['timestamp'] = pd.to_datetime(df['DATE'])
        df = df.rename(columns={'DATE': 'date_original'})
    else:
        raise KeyError("Input data must contain a 'timestamp' or 'DATE' column")

    # Stable sort so rows sharing a timestamp keep their file order (deterministic output).
    df = df.sort_values(by='timestamp', kind='mergesort')

    # Add hour of day feature - sine/cosine encoding for cyclical pattern
    hour_angle = 2 * np.pi * df['timestamp'].dt.hour / 24
    df['hour_sin'] = np.sin(hour_angle)
    df['hour_cos'] = np.cos(hour_angle)

    # Add day of year feature - sine/cosine encoding for cyclical pattern
    day_angle = 2 * np.pi * df['timestamp'].dt.dayofyear / 365.25
    df['day_sin'] = np.sin(day_angle)
    df['day_cos'] = np.cos(day_angle)

    # One table for both region and elevation so the station keys cannot drift apart
    # (a misspelled key previously left one station without an elevation).
    station_profiles = {
        'San Francisco': ('coastal', 16),      # elevation in meters
        'Mammoth Lakes': ('mountain', 2500),
        'Palm Spring': ('desert', 146),
        'Fresno': ('valley', 99),
        'LA Downtown': ('urban', 93),
    }
    region_mapping = {station: profile[0] for station, profile in station_profiles.items()}
    elevation_mapping = {station: profile[1] for station, profile in station_profiles.items()}

    has_station_column = 'station_id' in df.columns
    if has_station_column:
        # Add region information
        df['region'] = df['station_id'].map(region_mapping)

        # Filter to keep only one station per region (5 stations total).
        # .copy() so the column assignments below do not write into a slice of the raw frame.
        kept_stations = list(region_mapping.keys())
        df = df[df['station_id'].isin(kept_stations)].copy()

        # Convert region to numerical encoding (codes follow order of first appearance)
        region_to_num = {region: i for i, region in enumerate(df['region'].unique())}
        df['region_code'] = df['region'].map(region_to_num)

        # Add elevation as a numerical topographic feature
        df['elevation'] = df['station_id'].map(elevation_mapping)

        # Min-max normalize elevation; a single distinct elevation would divide by zero.
        lowest = df['elevation'].min()
        span = df['elevation'].max() - lowest
        if span > 0:
            df['elevation_norm'] = (df['elevation'] - lowest) / span
        else:
            df['elevation_norm'] = 0.0

        print(f"Filtered data to {len(kept_stations)} stations: {kept_stations}")

    # Interpolate missing values along the time dimension. Only numeric columns are
    # touched, and each station is filled from its own time series: rows of different
    # stations are interleaved after the timestamp sort, so a frame-wide interpolation
    # would fill a gap from another station's readings.
    numeric_cols = [col for col in df.select_dtypes(include=[np.number]).columns
                    if col != 'station_id']
    if numeric_cols and not df.empty:
        if has_station_column:
            filled = df.groupby('station_id')[numeric_cols].transform(
                lambda values: values.interpolate(method='linear', limit_direction='both')
            )
        else:
            filled = df[numeric_cols].interpolate(method='linear', limit_direction='both')
        df[numeric_cols] = filled

    return df

def extract_target_days(df):
    """
    Pick one calendar day per season and return that day's rows from the latest year available.

    Args:
        df: DataFrame with a datetime ``timestamp`` column.

    Returns:
        Dict mapping season name ('Spring', 'Summer', 'Fall', 'Winter') to the rows of the
        target day in the most recent year that contains it. Seasons without any matching
        rows are left out (a warning is printed instead).
    """
    # (season, month, day): April 15, July 20, October 10, January 15
    calendar_targets = (
        ('Spring', 4, 15),
        ('Summer', 7, 20),
        ('Fall', 10, 10),
        ('Winter', 1, 15),
    )

    stamp_parts = df['timestamp'].dt
    selected = {}
    for season, month, day in calendar_targets:
        same_day = df[(stamp_parts.month == month) & (stamp_parts.day == day)]
        if same_day.empty:
            print(f"WARNING: No data found for {season} target day")
            continue

        # Keep only the most recent year that has this calendar day
        years = same_day['timestamp'].dt.year
        latest_year = years.max()
        latest_rows = same_day[years == latest_year]
        selected[season] = latest_rows
        print(f"Found {len(latest_rows)} records for {season} target day ({month}/{day}/{latest_year})")

    return selected

def prepare_seasonal_training_data(df, target_days):
    """
    Build one training frame per season from the years before that season's target day.

    Args:
        df: Full DataFrame with a datetime ``timestamp`` column.
        target_days: Dict mapping season name to the DataFrame of target-day rows.

    Returns:
        Dict mapping season name to the rows of ``df`` whose year is strictly earlier than
        the year of the first row of that season's target data. Seasons whose target
        frame is empty are omitted.
    """
    record_years = df['timestamp'].dt.year
    history_by_season = {}

    for season, day_frame in target_days.items():
        if not day_frame.empty:
            # The first row's year identifies the target year
            cutoff_year = day_frame['timestamp'].iloc[0].year

            # Everything from earlier calendar years counts as history
            history = df[record_years < cutoff_year]
            history_by_season[season] = history
            print(f"{season} training set: {len(history)} samples from all historical data")

    return history_by_season

def normalize_features(train_df, val_df, feature_cols):
    """
    Standardize the feature columns of both frames with statistics from the training frame.

    Args:
        train_df: Training DataFrame; the scaler is fitted on its ``feature_cols``.
        val_df: Validation DataFrame, transformed with the training statistics.
        feature_cols: List of column names to scale.

    Returns:
        (train_norm, val_norm, scaler): two DataFrames holding only the scaled feature
        columns (original indexes preserved) and the fitted StandardScaler.
    """
    # Fit on training data only so no validation statistics leak into the scaling
    scaler = StandardScaler()
    scaler.fit(train_df[feature_cols])

    def _scaled_copy(frame):
        # Scale the selected columns and restore the column labels and row index
        return pd.DataFrame(
            scaler.transform(frame[feature_cols]),
            columns=feature_cols,
            index=frame.index,
        )

    train_norm = _scaled_copy(train_df)
    val_norm = _scaled_copy(val_df)
    return train_norm, val_norm, scaler

class WeatherDataset(Dataset):
    """
    Dataset for weather forecasting with sliding window approach.
    Input: sequence of weather data
    Output: next time step(s) temperature

    Each item is ``((X, static_features), y)`` where
      * ``X`` is a float tensor ``[features, stations, seq_length]``,
      * ``static_features`` is ``[stations, 2]`` (``region_code``, ``elevation_norm``),
      * ``y`` is ``[stations, forecast_horizon]`` taken from ``Temperature_C``.

    Only timestamps present for every station are used. A window is valid when its
    ``seq_length + forecast_horizon`` consecutive timestamps are exactly one hour apart.
    If no such window exists the dataset falls back to any ``seq_length`` consecutive
    timestamps and pads a short target window by repeating its last timestamp.
    """
    def __init__(self, df, station_ids, feature_cols, seq_length=24, forecast_horizon=24):
        """
        Args:
            df: DataFrame with weather data
            station_ids: List of station IDs
            feature_cols: List of feature columns to use as input
            seq_length: Length of input sequence (in hours)
            forecast_horizon: How many hours ahead to predict
        """
        self.df = df
        self.station_ids = station_ids
        self.feature_cols = feature_cols
        self.seq_length = seq_length
        self.forecast_horizon = forecast_horizon
        self.n_stations = len(station_ids)

        # Group data by station_id for faster access
        self.station_data = {station: df[df['station_id'] == station].sort_values('timestamp')
                             for station in station_ids}
        # Per-station lookup arrays, built on first use by _station_arrays()
        self._station_cache = {}

        # Get unique timestamps across all stations
        all_timestamps = sorted(df['timestamp'].unique())

        # Keep timestamps that have data for all stations. Set membership replaces a
        # boolean-mask scan of every station frame for every timestamp.
        per_station_times = [set(pd.DatetimeIndex(self.station_data[station]['timestamp']))
                             for station in station_ids]
        self.timestamps = [
            ts for ts in all_timestamps
            if pd.notna(ts) and all(pd.Timestamp(ts) in seen for seen in per_station_times)
        ]

        # Find valid window starting indices in one pass: count how many consecutive
        # one-hour steps end at each position. A window starting at i is valid when
        # the window-1 steps after it are all hourly, which is exactly the slice that
        # __getitem__ reads.
        window = seq_length + forecast_horizon
        one_hour = pd.Timedelta(hours=1)
        stamps = [pd.Timestamp(ts) for ts in self.timestamps]
        valid_idx = []
        hourly_run = 0
        for pos in range(len(stamps)):
            if pos > 0 and stamps[pos] - stamps[pos - 1] == one_hour:
                hourly_run += 1
            else:
                hourly_run = 0
            if hourly_run >= window - 1:
                valid_idx.append(pos - (window - 1))

        self.valid_indices = valid_idx

        if len(self.valid_indices) == 0:
            print(f"WARNING: No valid continuous windows found. Using relaxed requirements.")
            # Fall back to allowing any windows with at least input sequence length
            self.valid_indices = list(range(len(self.timestamps) - seq_length))
            self.fallback_mode = True
            print(f"Found {len(self.valid_indices)} windows with relaxed continuity requirements")
        else:
            self.fallback_mode = False
            print(f"Created dataset with {len(self.valid_indices)} valid continuous windows")

    def __len__(self):
        # At least 1 so a DataLoader can be built even without valid windows;
        # __getitem__ then returns an all-zero dummy sample.
        return max(1, len(self.valid_indices))

    def _station_arrays(self, station_id):
        """
        Return cached lookup data for one station:
        (unique timestamp index, feature matrix [rows, features], temperature vector,
        (region_code, elevation_norm) read from the station's first row).
        """
        cached = self._station_cache.get(station_id)
        if cached is None:
            station_df = self.station_data[station_id]
            first_row = station_df.iloc[0]
            static = (first_row['region_code'], first_row['elevation_norm'])
            # First row per timestamp, matching the original "first match" lookup
            unique_rows = station_df.drop_duplicates(subset='timestamp', keep='first')
            cached = (
                pd.DatetimeIndex(unique_rows['timestamp']),
                unique_rows[list(self.feature_cols)].to_numpy(dtype=float),
                unique_rows['Temperature_C'].to_numpy(dtype=float),
                static,
            )
            self._station_cache[station_id] = cached
        return cached

    def __getitem__(self, idx):
        if len(self.valid_indices) == 0:
            # Return dummy data if no valid indices
            X = np.zeros((len(self.feature_cols), self.n_stations, self.seq_length))
            y = np.zeros((self.n_stations, self.forecast_horizon))
            static_features = np.zeros((self.n_stations, 2))  # region_code and elevation
            return (torch.FloatTensor(X), torch.FloatTensor(static_features)), torch.FloatTensor(y)

        # Wrap around so any non-negative idx maps to a real window
        start_idx = self.valid_indices[idx % len(self.valid_indices)]
        target_start = start_idx + self.seq_length

        # Slicing truncates at the end of the list, which covers both the normal and
        # the fallback mode (where fewer than forecast_horizon timestamps may remain).
        input_timestamps = self.timestamps[start_idx:target_start]
        output_timestamps = list(self.timestamps[target_start:target_start + self.forecast_horizon])

        # Handle potential shortfall in output window
        if len(output_timestamps) < self.forecast_horizon:
            # Pad with repetition of last timestamp if needed
            last_time = output_timestamps[-1] if len(output_timestamps) > 0 else input_timestamps[-1]
            output_timestamps += [last_time] * (self.forecast_horizon - len(output_timestamps))

        # [features, stations, time]
        X = np.zeros((len(self.feature_cols), self.n_stations, self.seq_length))
        # [stations, forecast_horizon]
        y = np.zeros((self.n_stations, self.forecast_horizon))
        # Static features for each station: region_code and elevation
        static_features = np.zeros((self.n_stations, 2))

        input_index = pd.DatetimeIndex(input_timestamps)
        output_index = pd.DatetimeIndex(output_timestamps)

        for s_idx, station_id in enumerate(self.station_ids):
            ts_index, features, temperatures, static = self._station_arrays(station_id)
            static_features[s_idx, 0] = static[0]
            static_features[s_idx, 1] = static[1]

            # Input sequence: positions with no row for this station keep their zeros
            rows = ts_index.get_indexer(input_index)
            found = rows >= 0
            station_X = X[:, s_idx, :]  # view into X
            station_X[:, found] = features[rows[found]].T

            # Target sequence (temperature only); missing rows keep the zero value
            rows = ts_index.get_indexer(output_index)
            found = rows >= 0
            y[s_idx, found] = temperatures[rows[found]]

        return (torch.FloatTensor(X), torch.FloatTensor(static_features)), torch.FloatTensor(y)
# ============================================================
# TEMPORAL FUSION TRANSFORMER IMPLEMENTATION
# ============================================================
class TemporalSelfAttention(nn.Module):
    """
    Multi-head scaled dot-product self-attention over the time axis.
    Simplified from the original TFT paper.

    Args:
        d_model: Size of the input and output feature dimension.
        n_heads: Number of attention heads; must divide ``d_model`` evenly.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``n_heads`` is not positive or does not divide ``d_model``.
    """
    def __init__(self, d_model, n_heads=2, dropout=0.1):
        super(TemporalSelfAttention, self).__init__()
        # Fail at construction time; otherwise the mismatch only surfaces as an
        # opaque shape error inside forward().
        if n_heads <= 0 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)

        self.out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, projected, batch_size, seq_length):
        # [batch, seq, d_model] -> [batch, heads, seq, head_dim]
        return projected.view(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """
        Args:
            x: Tensor of shape [batch, seq_length, d_model].

        Returns:
            Tensor of shape [batch, seq_length, d_model].
        """
        batch_size, seq_length, _ = x.size()

        # Linear projections, split into heads
        queries = self._split_heads(self.query(x), batch_size, seq_length)
        keys = self._split_heads(self.key(x), batch_size, seq_length)
        values = self._split_heads(self.value(x), batch_size, seq_length)

        # Attention scores, scaled by sqrt(head_dim)
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attention = F.softmax(scores, dim=-1)
        attention = self.dropout(attention)

        # Apply attention to values and merge the heads back together
        out = torch.matmul(attention, values)
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)

        # Final linear layer
        return self.out(out)

class GatedResidualNetwork(nn.Module):
    """
    Gated Residual Network as described in the TFT paper (simplified, fewer layers).

    A two-layer feed-forward branch is blended with a (possibly projected) copy of the
    input through a learned sigmoid gate, followed by layer normalization.

    Args:
        input_size: Size of the last dimension of the input.
        hidden_size: Width of the hidden layer.
        output_size: Size of the last dimension of the output.
        dropout: Dropout probability applied after the hidden activation.
    """
    def __init__(self, input_size, hidden_size, output_size, dropout=0.1):
        super(GatedResidualNetwork, self).__init__()
        self.input_size, self.output_size, self.hidden_size = input_size, output_size, hidden_size

        # The residual path needs a projection only when the sizes differ
        needs_projection = input_size != output_size
        self.skip_layer = nn.Linear(input_size, output_size) if needs_projection else None

        # Main layers
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        # The gate sees the raw input concatenated with the transformed branch
        self.gate = nn.Linear(input_size + output_size, output_size)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(output_size)

    def forward(self, x):
        """Return the gated, layer-normalized mix of the transformed and residual paths."""
        # Main branch: fc1 -> ELU -> dropout -> fc2
        transformed = self.fc2(self.dropout(F.elu(self.fc1(x))))

        # Residual path (projected when the sizes differ)
        residual = x if self.skip_layer is None else self.skip_layer(x)

        # Gate values in (0, 1) decide how much of each path is kept
        mix = torch.sigmoid(self.gate(torch.cat((x, transformed), dim=-1)))
        blended = mix * transformed + (1 - mix) * residual

        return self.layer_norm(blended)

class VariableSelectionNetwork(nn.Module):
    """
    Variable Selection Network for TFT (simplified, fewer layers).

    One GRN scores the variables from the flattened input; one GRN per variable embeds
    that variable. The output is the softmax-weighted sum of the per-variable embeddings.

    Args:
        input_size_per_var: Length of each variable's input vector.
        num_vars: Number of variables.
        hidden_size: Hidden width of every GRN.
        output_size: Embedding size produced for each variable.
        dropout: Dropout probability passed to every GRN.
    """
    def __init__(self, input_size_per_var, num_vars, hidden_size, output_size, dropout=0.1):
        super(VariableSelectionNetwork, self).__init__()
        self.input_size_per_var = input_size_per_var
        self.num_vars = num_vars
        self.hidden_size = hidden_size
        self.output_size = output_size

        # Scores all variables at once from the flattened input
        self.weight_grn = GatedResidualNetwork(
            input_size=input_size_per_var * num_vars,
            hidden_size=hidden_size,
            output_size=num_vars,
            dropout=dropout,
        )

        # One embedding network per variable
        per_variable_grns = []
        for _ in range(num_vars):
            per_variable_grns.append(
                GatedResidualNetwork(
                    input_size=input_size_per_var,
                    hidden_size=hidden_size,
                    output_size=output_size,
                    dropout=dropout,
                )
            )
        self.var_grns = nn.ModuleList(per_variable_grns)

    def forward(self, x):
        """
        Args:
            x: Tensor of shape [batch_size, num_vars, input_size_per_var].

        Returns:
            (outputs, var_weights): ``outputs`` is [batch_size, output_size];
            ``var_weights`` is [batch_size, num_vars, 1] and sums to 1 over the variables.
        """
        n_samples = x.size(0)

        # Selection weights from the flattened input
        selection_logits = self.weight_grn(x.view(n_samples, -1))
        var_weights = F.softmax(selection_logits, dim=-1).unsqueeze(-1)  # [batch_size, num_vars, 1]

        # Embed every variable with its own GRN -> [batch_size, num_vars, output_size]
        embedded = torch.stack(
            [self.var_grns[var_idx](x[:, var_idx]) for var_idx in range(self.num_vars)],
            dim=1,
        )

        # Weighted combination -> [batch_size, output_size]
        outputs = (embedded * var_weights).sum(dim=1)

        return outputs, var_weights

class TemporalFusionTransformer(nn.Module):
    """
    Simplified Temporal Fusion Transformer for temperature forecasting.

    The input window of every station is compressed by a variable selection network
    into one context vector, combined with a static embedding, unrolled for
    ``forecast_horizon`` steps through an LSTM and self-attention, and the first output
    unit of each step is used as that step's temperature forecast.

    Args:
        num_features: Number of time-varying input features.
        num_stations: Number of stations per sample.
        hidden_size: Hidden width used throughout the network.
        num_heads: Number of attention heads (must divide ``hidden_size``).
        dropout: Dropout probability.
        forecast_horizon: Number of future steps to predict.
        seq_length: Length of the input window. Previously hard-coded to 24; the default
            keeps the original behavior.

    Raises:
        ValueError: if ``forecast_horizon`` or ``seq_length`` is smaller than 1.
    """
    def __init__(self, num_features, num_stations, hidden_size=64, num_heads=2, dropout=0.1,
                 forecast_horizon=24, seq_length=24):
        super(TemporalFusionTransformer, self).__init__()
        if forecast_horizon < 1 or seq_length < 1:
            raise ValueError("forecast_horizon and seq_length must both be at least 1")
        self.num_features = num_features
        self.num_stations = num_stations
        self.hidden_size = hidden_size
        self.forecast_horizon = forecast_horizon  # Store forecast horizon
        self.seq_length = seq_length

        # Static variable processing (region_code, elevation)
        self.static_var_processor = GatedResidualNetwork(
            input_size=2,
            hidden_size=hidden_size,
            output_size=hidden_size,
            dropout=dropout
        )

        # Variable selection for time-varying features
        self.temporal_var_selection = VariableSelectionNetwork(
            input_size_per_var=seq_length,  # Sequence length per feature
            num_vars=num_features,
            hidden_size=hidden_size,
            output_size=hidden_size,
            dropout=dropout
        )

        # LSTM encoder
        self.lstm_encoder = nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,
            batch_first=True
        )

        # Temporal self-attention
        self.self_attention = TemporalSelfAttention(
            d_model=hidden_size,
            n_heads=num_heads,
            dropout=dropout
        )

        # Final output layers. fc2 keeps its original [hidden, forecast_horizon] shape so
        # existing checkpoints still load; only its first output unit is read in forward().
        self.fc1 = nn.Linear(hidden_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, forecast_horizon)

    def forward(self, inputs):
        """
        Args:
            inputs: Pair ``(temporal_features, static_features)`` with shapes
                [batch, features, stations, seq_length] and [batch, stations, 2].

        Returns:
            Tensor of shape [batch, stations, forecast_horizon].

        Raises:
            ValueError: if the time dimension of ``temporal_features`` differs from ``seq_length``.
        """
        # Unpack inputs
        temporal_features, static_features = inputs
        batch_size = temporal_features.size(0)

        if temporal_features.size(-1) != self.seq_length:
            raise ValueError(
                f"Expected input windows of length {self.seq_length}, "
                f"got {temporal_features.size(-1)}"
            )

        # [batch, features, stations, time] -> [batch*stations, features, time]
        temporal_features = temporal_features.permute(0, 2, 1, 3)
        temporal_features = temporal_features.reshape(batch_size * self.num_stations, self.num_features, -1)

        # Static features: [batch, stations, static_dims] -> [batch*stations, static_dims]
        static_features = static_features.reshape(batch_size * self.num_stations, -1)

        # Process static features
        static_embeddings = self.static_var_processor(static_features)

        # Process temporal features with variable selection
        temporal_embeddings, temporal_weights = self.temporal_var_selection(temporal_features)

        # Repeat the context vector once per forecast step:
        # [batch*stations, hidden] -> [batch*stations, forecast_horizon, hidden]
        temporal_embeddings = temporal_embeddings.unsqueeze(1).expand(-1, self.forecast_horizon, -1)

        # Add static embeddings to each timestep
        temporal_embeddings = temporal_embeddings + static_embeddings.unsqueeze(1)

        # LSTM encoding
        lstm_out, _ = self.lstm_encoder(temporal_embeddings)

        # Self-attention
        attention_out = self.self_attention(lstm_out)

        # Final prediction
        outputs = F.relu(self.fc1(attention_out))
        outputs = self.fc2(outputs)

        # One value per unrolled step (first output unit). With the default 24/24
        # configuration this equals the original outputs[:, -24:, 0].
        forecast = outputs[:, :, 0]

        # Reshape back to [batch, stations, horizon]
        forecast = forecast.reshape(batch_size, self.num_stations, -1)

        return forecast

# ============================================================
# TRAINING AND EVALUATION FUNCTIONS
# ============================================================
def train_model(model, train_loader, val_loader, learning_rate=0.001, epochs=20, patience=5):
    """
    Train the model with early stopping based on validation loss.

    Args:
        model: Module whose forward accepts the ``inputs`` element of each batch.
        train_loader: Iterable of ``(inputs, targets)`` training batches.
        val_loader: Iterable of ``(inputs, targets)`` validation batches.
        learning_rate: Initial Adam learning rate.
        epochs: Maximum number of epochs.
        patience: Number of epochs without validation improvement before stopping.

    Returns:
        (model, train_losses, val_losses): the model restored to the weights of its best
        validation epoch (when one was recorded) and the per-epoch average losses.

    Side effects:
        Moves the model to the module-level ``device`` and writes the best weights to
        ``best_tft_model.pth`` (best effort; a failed write is only reported).
    """
    import copy  # local import: only needed to snapshot the best weights

    def _to_device(batch_inputs):
        # The default collate function turns tuple samples into lists, so both
        # container types must be handled in training and in validation.
        if isinstance(batch_inputs, (list, tuple)):
            moved = [item.to(device) if isinstance(item, torch.Tensor) else item
                     for item in batch_inputs]
            return tuple(moved) if isinstance(batch_inputs, tuple) else moved
        return batch_inputs.to(device)

    model.to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # `verbose` is deprecated in recent PyTorch releases; learning-rate changes are
    # reported manually after scheduler.step() below instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=2
    )

    best_val_loss = float('inf')
    best_state = None  # in-memory copy of the best weights seen in this run
    patience_counter = 0
    train_losses = []
    val_losses = []
    epoch_times = []

    print(f"Starting training for {epochs} epochs with patience {patience}...")
    total_start_time = time.time()

    for epoch in range(epochs):
        epoch_start_time = time.time()
        # Training
        model.train()
        train_loss = 0
        train_batches = 0

        batch_times = []
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            batch_start = time.time()
            inputs = _to_device(inputs)
            targets = targets.to(device)

            # Forward pass
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # Backward pass
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_batches += 1

            batch_time = time.time() - batch_start
            batch_times.append(batch_time)

            # Print progress every 10 batches
            if (batch_idx + 1) % 10 == 0:
                print(f"  Batch {batch_idx+1}/{len(train_loader)}, Loss: {loss.item():.4f}, Time: {batch_time:.3f}s")

        avg_train_loss = train_loss / max(1, train_batches)
        train_losses.append(avg_train_loss)

        avg_batch_time = sum(batch_times) / len(batch_times) if batch_times else 0

        # Validation
        val_start_time = time.time()
        model.eval()
        val_loss = 0
        val_batches = 0

        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs = _to_device(inputs)
                targets = targets.to(device)

                # Forward pass
                outputs = model(inputs)
                loss = criterion(outputs, targets)

                val_loss += loss.item()
                val_batches += 1

        avg_val_loss = val_loss / max(1, val_batches)
        val_losses.append(avg_val_loss)

        val_time = time.time() - val_start_time
        epoch_end_time = time.time()
        epoch_time = epoch_end_time - epoch_start_time
        epoch_times.append(epoch_time)

        # Calculate estimated time remaining
        avg_epoch_time = sum(epoch_times) / len(epoch_times)
        remaining_epochs = epochs - (epoch + 1)
        est_time_remaining = avg_epoch_time * remaining_epochs

        # Print detailed progress
        print(f"Epoch {epoch+1}/{epochs}")
        print(f"  Train Loss: {avg_train_loss:.4f} ({train_batches} batches, avg batch time: {avg_batch_time:.3f}s)")
        print(f"  Val Loss: {avg_val_loss:.4f} (validation time: {val_time:.2f}s)")
        print(f"  Epoch Time: {epoch_time:.2f}s, Est. Remaining: {est_time_remaining/60:.2f} minutes")
        print(f"  Elapsed Time: {(epoch_end_time - total_start_time)/60:.2f} minutes")

        # Learning rate scheduler (report reductions ourselves)
        previous_lr = optimizer.param_groups[0]['lr']
        scheduler.step(avg_val_loss)
        current_lr = optimizer.param_groups[0]['lr']
        if current_lr < previous_lr:
            print(f"  Learning rate reduced from {previous_lr:.6g} to {current_lr:.6g}")

        # Early stopping
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            # Snapshot in memory so restoring never depends on the file written below
            best_state = copy.deepcopy(model.state_dict())
            # Save best model
            try:
                torch.save(best_state, 'best_tft_model.pth')
                print(f"  Saved best model with val loss: {best_val_loss:.4f}")
            except Exception as e:
                print(f"  Error saving model: {e}")
        else:
            patience_counter += 1
            print(f"  No improvement for {patience_counter}/{patience} epochs")
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    total_time = time.time() - total_start_time
    print(f"Training completed in {total_time/60:.2f} minutes ({total_time:.1f} seconds)")
    if epoch_times:  # empty when epochs == 0
        print(f"Average epoch time: {sum(epoch_times)/len(epoch_times):.2f} seconds")

    # Restore the best weights of THIS run. Loading from disk could silently pick up a
    # stale checkpoint from an earlier run when no epoch improved (e.g. NaN losses).
    if best_state is not None:
        model.load_state_dict(best_state)
    else:
        print("Error loading best model: no improving epoch was recorded; keeping current weights")

    return model, train_losses, val_losses

def evaluate_model(model, data_loader, station_ids, regions):
    """
    Evaluate the model and calculate metrics.

    Args:
        model: Trained module returning predictions shaped [batch, stations, horizon].
        data_loader: Iterable of ``(inputs, targets)`` batches with targets of the same shape.
        station_ids: Station identifiers in the order of the station axis.
        regions: Dict mapping station id to region name ('Unknown' is used when missing).

    Returns:
        (rmse, mae, r2, station_metrics) where ``station_metrics`` maps each station id to
        a dict with keys 'region', 'rmse', 'mae' and 'r2'.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    def _to_device(batch_inputs):
        # Default collation delivers tuple samples as lists; handle both containers.
        if isinstance(batch_inputs, (list, tuple)):
            moved = [item.to(device) if isinstance(item, torch.Tensor) else item
                     for item in batch_inputs]
            return tuple(moved) if isinstance(batch_inputs, tuple) else moved
        return batch_inputs.to(device)

    model.eval()
    batch_preds = []
    batch_targets = []

    with torch.no_grad():
        for inputs, targets in data_loader:
            # Forward pass
            outputs = model(_to_device(inputs))

            # Move to CPU for further processing
            batch_preds.append(outputs.cpu().numpy())
            batch_targets.append(targets.cpu().numpy())

    if not batch_preds:
        raise ValueError("evaluate_model received an empty data_loader; nothing to evaluate")

    # Concatenate predictions and targets -> [samples, stations, horizon]
    all_preds = np.concatenate(batch_preds, axis=0)
    all_targets = np.concatenate(batch_targets, axis=0)

    # Calculate overall metrics
    flat_preds = all_preds.flatten()
    flat_targets = all_targets.flatten()
    rmse = np.sqrt(mean_squared_error(flat_targets, flat_preds))
    mae = mean_absolute_error(flat_targets, flat_preds)
    r2 = r2_score(flat_targets, flat_preds)

    print(f"Overall Metrics - RMSE: {rmse:.4f}, MAE: {mae:.4f}, R²: {r2:.4f}")

    # Calculate metrics by station/region. The arrays are already concatenated, so the
    # station axis is sliced directly (iterating them row by row yields 2-D arrays that
    # cannot be indexed with three indices).
    station_metrics = {}
    for i, station in enumerate(station_ids):
        station_pred = all_preds[:, i, :].flatten()
        station_target = all_targets[:, i, :].flatten()

        station_rmse = np.sqrt(mean_squared_error(station_target, station_pred))
        station_mae = mean_absolute_error(station_target, station_pred)
        station_r2 = r2_score(station_target, station_pred)

        region = regions.get(station, 'Unknown')
        station_metrics[station] = {
            'region': region,
            'rmse': station_rmse,
            'mae': station_mae,
            'r2': station_r2
        }

        print(f"Station {station} ({region}) - "
              f"RMSE: {station_rmse:.4f}, MAE: {station_mae:.4f}, R²: {station_r2:.4f}")

    return rmse, mae, r2, station_metrics

def visualize_predictions(model, data_loader, station_ids, regions, season):
    """
    Visualize predictions for each station.

    Plots actual vs. predicted temperature for the first sample of the first batch, one
    subplot per station, saves the figure to ``{season}_predictions.png`` and shows it.

    Args:
        model: Trained module returning predictions shaped [batch, stations, horizon].
        data_loader: DataLoader yielding ``(inputs, targets)`` batches.
        station_ids: Station identifiers in the order of the station axis.
        regions: Dict mapping station id to region name ('Unknown' is used when missing).
        season: Label used in the subplot titles and the output file name.

    Returns:
        None. Problems while producing predictions are printed, not raised.
    """
    def _to_device(batch_inputs):
        # Default collation delivers tuple samples as lists; handle both containers.
        if isinstance(batch_inputs, (list, tuple)):
            moved = [item.to(device) if isinstance(item, torch.Tensor) else item
                     for item in batch_inputs]
            return tuple(moved) if isinstance(batch_inputs, tuple) else moved
        return batch_inputs.to(device)

    model.eval()
    if len(data_loader) == 0:
        print("No data available for visualization")
        return

    # Get predictions for the first batch only
    outputs = None
    targets = None
    try:
        # no_grad is required: .numpy() refuses tensors that still track gradients
        with torch.no_grad():
            for inputs, batch_targets in data_loader:
                outputs = model(_to_device(inputs)).cpu().numpy()
                targets = batch_targets.cpu().numpy()
                break
    except Exception as e:
        print(f"Error processing visualization data: {e}")
        return

    # Check if we have data to plot
    if outputs is None:
        print("No data was loaded from the dataloader")
        return

    # Create subplots for each station
    fig, axes = plt.subplots(len(station_ids), 1, figsize=(12, 3*len(station_ids)))
    if len(station_ids) == 1:
        axes = [axes]

    # x axis follows the actual forecast length instead of a fixed 24 hours
    hours = np.arange(targets.shape[-1])

    for i, station in enumerate(station_ids):
        ax = axes[i]

        # Plot actual vs predicted
        ax.plot(hours, targets[0, i, :], 'b-', label='Actual')
        ax.plot(hours, outputs[0, i, :], 'r--', label='Predicted')

        ax.set_title(f"{station} ({regions.get(station, 'Unknown')}) - {season}")
        ax.set_xlabel('Hour of Day')
        ax.set_ylabel('Temperature (°C)')
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    # Save before show(): with non-inline backends the figure may be empty after show()
    fig.savefig(f"{season}_predictions.png")
    plt.show()
    plt.close(fig)


def analyze_topographic_performance(station_metrics, regions):
    """
    Analyze model performance across different topographic regions.

    Station metrics are grouped by region, averaged, printed, and drawn as a
    bar chart of average RMSE that is saved to 'region_performance.png' and
    then displayed.

    Args:
        station_metrics: Mapping from station id to a dict containing at least
            the keys 'rmse', 'mae' and 'r2'.
        regions: Mapping from station id to region label; stations missing
            from it are grouped under 'Unknown'.

    Returns:
        Dict mapping each region to {'avg_rmse', 'avg_mae', 'avg_r2'}.
        An empty dict is returned (and nothing is plotted) when
        ``station_metrics`` is empty.
    """
    # Guard first: inside the aggregation loop this check could never trigger,
    # because the loop body does not run for empty input.
    if not station_metrics:
        print("No station metrics available for analysis")
        return {}

    # Group metrics by region
    region_metrics = {}
    for station, metrics in station_metrics.items():
        region = regions.get(station, 'Unknown')
        region_metrics.setdefault(region, []).append(metrics)

    # Calculate average metrics by region
    region_avg_metrics = {}
    for region, metrics_list in region_metrics.items():
        avg_rmse = np.mean([m['rmse'] for m in metrics_list])
        avg_mae = np.mean([m['mae'] for m in metrics_list])
        avg_r2 = np.mean([m['r2'] for m in metrics_list])

        region_avg_metrics[region] = {
            'avg_rmse': avg_rmse,
            'avg_mae': avg_mae,
            'avg_r2': avg_r2
        }

        print(f"Region {region} - Avg RMSE: {avg_rmse:.4f}, Avg MAE: {avg_mae:.4f}, Avg R²: {avg_r2:.4f}")

    # Create bar chart comparing regions (separate name so the `regions`
    # argument is not shadowed)
    region_names = list(region_avg_metrics.keys())
    rmse_values = [region_avg_metrics[r]['avg_rmse'] for r in region_names]

    fig = plt.figure(figsize=(10, 6))
    bars = plt.bar(region_names, rmse_values)

    # Add styling
    plt.title('RMSE by Topographic Region', fontsize=16)
    plt.ylabel('RMSE (°C)', fontsize=14)
    plt.xlabel('Region', fontsize=14)
    plt.xticks(rotation=45)
    plt.grid(True, alpha=0.3)

    # Add value labels on bars
    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height + 0.1,
                 f'{height:.2f}',
                 ha='center', va='bottom', fontsize=12)

    plt.tight_layout()
    # Save before show(): with notebook backends show() can release the figure,
    # after which savefig would write an empty canvas.
    fig.savefig('region_performance.png')
    plt.show()
    plt.close(fig)

    return region_avg_metrics

def analyze_seasonal_performance(seasonal_results):
    """
    Compare model performance across different seasons.

    Draws a grouped bar chart of RMSE and MAE per season, saves it to
    'seasonal_performance.png' and then displays it.

    Args:
        seasonal_results: Mapping from season name to a dict containing at
            least the keys 'rmse' and 'mae'.

    Returns:
        Tuple ``(seasons, rmse_values, mae_values)`` of parallel lists in the
        mapping's iteration order. Three empty lists are returned (and nothing
        is plotted) when ``seasonal_results`` is empty.
    """
    # Bail out before touching the results at all.
    if not seasonal_results:
        print("No seasonal results available for analysis")
        return [], [], []

    seasons = list(seasonal_results.keys())
    rmse_values = [results['rmse'] for results in seasonal_results.values()]
    mae_values = [results['mae'] for results in seasonal_results.values()]

    # Create grouped bar chart
    x = np.arange(len(seasons))
    width = 0.35

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.bar(x - width/2, rmse_values, width, label='RMSE')
    ax.bar(x + width/2, mae_values, width, label='MAE')

    ax.set_title('Model Performance by Season', fontsize=16)
    ax.set_ylabel('Error (°C)', fontsize=14)
    ax.set_xticks(x)
    ax.set_xticklabels(seasons)
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Add value labels
    for i, v in enumerate(rmse_values):
        ax.text(i - width/2, v + 0.1, f'{v:.2f}', ha='center', fontsize=10)

    for i, v in enumerate(mae_values):
        ax.text(i + width/2, v + 0.1, f'{v:.2f}', ha='center', fontsize=10)

    fig.tight_layout()
    # Save before show(): with notebook backends show() can release the figure,
    # after which savefig would write an empty canvas.
    fig.savefig('seasonal_performance.png')
    plt.show()
    plt.close(fig)

    return seasons, rmse_values, mae_values

# ============================================================
# MAIN EXECUTION
# ============================================================

if __name__ == "__main__":
    # End-to-end run: load data, build one unified dataset over all stations,
    # train, evaluate, plot, and write summary/station-level CSVs.
    print("California Weather Forecasting with Temporal Fusion Transformer")
    print("=" * 70)

    # Fixed seed so the train/val/test split, weight init and shuffling are
    # repeatable across kernel restarts.
    SEED = 42
    torch.manual_seed(SEED)
    np.random.seed(SEED)

    df = load_data(DATA_PATH)
    print(f"Loaded data with {len(df)} records")

    # load_data returns an empty DataFrame when the file cannot be read; stop
    # here with a clear message instead of failing later on a missing column.
    if df.empty:
        raise RuntimeError(f"No data could be loaded from {DATA_PATH}")

    # Extract target days (one per season) - keep this for evaluation
    target_days = extract_target_days(df)

    # Instead of preparing training data by season, use all data
    print("\nPreparing unified dataset from all historical data")

    # Get all stations that appear in the dataset
    all_stations = df['station_id'].unique()
    print(f"Using {len(all_stations)} stations: {all_stations}")
    # One list reused for the dataset, evaluation and plotting so the station
    # order is guaranteed identical everywhere.
    station_list = list(all_stations)

    # Create region mapping from each station's own rows. Zipping the unique
    # station ids with the unique region values is only correct when every
    # station has a distinct region; otherwise stations are mislabelled/dropped.
    regions = (
        df.drop_duplicates(subset='station_id')
          .set_index('station_id')['region']
          .to_dict()
    )

    # Define feature columns to use (unchanged)
    feature_cols = [
        'Temperature_C',
        'HourlyRelativeHumidity',
        'HourlyStationPressure',
        'hour_sin', 'hour_cos',
        'day_sin', 'day_cos'
    ]

    # Create a single unified dataset
    full_dataset = WeatherDataset(
        df=df,
        station_ids=station_list,
        feature_cols=feature_cols,
        seq_length=24,
        forecast_horizon=24
    )

    # Split into train/val/test (70% / 15% / remainder)
    dataset_size = len(full_dataset)
    train_size = int(dataset_size * 0.7)
    val_size = int(dataset_size * 0.15)
    test_size = dataset_size - train_size - val_size

    # NOTE(review): if WeatherDataset yields overlapping sliding windows, a
    # random split places near-identical windows in train and test; a
    # chronological split would give a stricter estimate - confirm.
    train_dataset, val_dataset, test_dataset = random_split(
        full_dataset,
        [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(SEED)
    )

    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)  # For visualization

    # Create and train model
    model = TemporalFusionTransformer(
        num_features=len(feature_cols),
        num_stations=len(station_list),
        hidden_size=64,
        num_heads=2,
        dropout=0.1
    )

    # Train model
    print("Training unified model on all historical data...")
    model, train_losses, val_losses = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        learning_rate=0.001,
        epochs=20,
        patience=5
    )

    # Plot training curves
    plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.title('Training and Validation Loss - All Data')
    plt.xlabel('Epoch')
    plt.ylabel('Loss (MSE)')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig('all_data_training_curve.png')
    plt.close()

    # Evaluate model
    print("Evaluating unified model...")
    rmse, mae, r2, station_metrics = evaluate_model(
        model=model,
        data_loader=test_loader,
        station_ids=station_list,
        regions=regions
    )

    # Visualize predictions
    visualize_predictions(
        model=model,
        data_loader=test_loader,
        station_ids=station_list,
        regions=regions,
        season="All Seasons"
    )

    # Analyze performance by topography
    region_metrics = analyze_topographic_performance(station_metrics, regions)

    # Save final results to CSV
    results_df = pd.DataFrame([{
        'Dataset': 'All Data',
        'RMSE': rmse,
        'MAE': mae,
        'R2': r2
    }])

    results_df.to_csv('unified_results.csv', index=False)
    print("Summary results saved to unified_results.csv")

    # Generate detailed station-level results
    station_results = [
        {
            'Station': station,
            'Region': metrics['region'],
            'RMSE': metrics['rmse'],
            'MAE': metrics['mae'],
            'R2': metrics['r2']
        }
        for station, metrics in station_metrics.items()
    ]

    station_results_df = pd.DataFrame(station_results)
    station_results_df.to_csv('station_results.csv', index=False)
    print("Station-level results saved to station_results.csv")

California Weather Forecasting with Temporal Fusion Transformer
Filtered data to 5 stations: ['San Francisco', 'Mammoth Lakes', 'Palm Spring', 'Fresno', 'LA Downtown']
Loaded data with 268746 records
Found 209 records for Spring target day (4/15/2024)
Found 169 records for Summer target day (7/20/2024)
Found 165 records for Fall target day (10/10/2024)
Found 216 records for Winter target day (1/15/2024)

Preparing unified dataset from all historical data
Using 5 stations: ['Palm Spring' 'LA Downtown' 'San Francisco' 'Fresno' 'Mammoth Lakes']
Found 216133 windows with relaxed continuity requirements
Training unified model on all historical data...
Starting training for 20 epochs with patience 5...
  Batch 10/9456, Loss: 0.0374, Time: 0.101s
  Batch 20/9456, Loss: 0.0408, Time: 0.087s


KeyboardInterrupt: 