In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from torch_geometric.nn import SAGEConv
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from datetime import datetime, timedelta
from tqdm import tqdm
import math
from torch.utils.data import Dataset, DataLoader
from sklearn.neighbors import BallTree
from math import radians
from statsmodels.tsa.stattools import acf
import os
import warnings


class Time2VecEncoding(nn.Module):
    """Time2Vec encoding: one learned linear feature plus ``h_dim - 1`` sinusoidal features.

    Args:
        h_dim: Width of the produced encoding (1 linear + ``h_dim - 1`` periodic components).
        scale: Divisor applied to the raw time values before encoding.
    """
    def __init__(self, h_dim, scale=1):
        super(Time2VecEncoding, self).__init__()
        # Small random weights and zero biases for stability
        self.w0 = nn.parameter.Parameter(torch.randn(1, 1) * 0.1)
        self.b0 = nn.parameter.Parameter(torch.zeros(1))
        self.w = nn.parameter.Parameter(torch.randn(1, h_dim-1) * 0.1)
        self.b = nn.parameter.Parameter(torch.zeros(h_dim-1))
        self.f = torch.sin
        self.scale = scale

    def forward(self, time):
        """Encode time values.

        Args:
            time: Tensor of shape ``[batch]``, ``[batch, 1]`` or ``[batch, seq_len, 1]``.
                For the 3-D form only the last time step is encoded.

        Returns:
            Tensor of shape ``[batch, h_dim]``.
        """
        # Normalise the input to 2D: [batch_size, 1]
        if time.dim() == 1:
            time = time.unsqueeze(-1)
        elif time.dim() == 3:
            # Slicing the last step of [batch, seq_len, 1] already yields [batch, 1];
            # adding another axis here would make the concatenation below fail.
            time = time[:, -1, :]

        # Scale time
        time = time / self.scale

        # Linear and periodic components
        v1 = torch.matmul(time, self.w0) + self.b0  # Shape: [batch_size, 1]
        v2 = self.f(torch.matmul(time, self.w) + self.b)  # Shape: [batch_size, h_dim-1]

        return torch.cat([v1, v2], dim=1)  # Shape: [batch_size, h_dim]


class FeatureAttention(nn.Module):
    """Sigmoid-gated re-weighting of the input features.

    A two-layer bottleneck MLP (``input_dim -> input_dim // 2 -> input_dim``)
    produces one gate in (0, 1) per feature; the input is multiplied
    element-wise by those gates.
    """
    def __init__(self, input_dim):
        super().__init__()
        bottleneck_dim = input_dim // 2
        gate_layers = [
            nn.Linear(input_dim, bottleneck_dim),
            nn.ReLU(),
            nn.Linear(bottleneck_dim, input_dim),
            nn.Sigmoid(),
        ]
        self.attention = nn.Sequential(*gate_layers)

    def forward(self, x):
        """Return ``x`` scaled element-wise by its learned per-feature gates."""
        gates = self.attention(x)
        return x * gates


class ResidualBlock(nn.Module):
    """Skip-connection block: ``x + MLP(x)`` with layer normalisation.

    The inner stack is Linear -> LayerNorm -> ReLU -> Linear -> LayerNorm,
    all at width ``dim``, so input and output shapes match.
    """
    def __init__(self, dim):
        super().__init__()
        stack = [
            nn.Linear(dim, dim),
            nn.LayerNorm(dim),
            nn.ReLU(),
            nn.Linear(dim, dim),
            nn.LayerNorm(dim),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Return the input plus its transformed version."""
        transformed = self.layers(x)
        return x + transformed


class EnhancedSTRAP(nn.Module):
    """Enhanced implementation of ST-RAP model with stability improvements.

    Pipeline: feature attention -> property embedding (+ projected time
    features) -> bidirectional GRU over a length-1 sequence -> residual
    blocks -> optional two-layer SAGEConv branch -> MLP predictor.

    Args:
        input_dim: Number of input features per sample.
        temporal_dim: Width of the ``time_features`` tensor passed to ``forward``.
        hidden_dim: Width of the internal representations.
        num_gru_layers: Number of stacked GRU layers.
        dropout_rate: Dropout probability used by the embedding, GRU, residual
            stack, graph branch and predictor (the predictor's second dropout
            uses half of it).
    """
    def __init__(self, input_dim, temporal_dim, hidden_dim=256, num_gru_layers=3, dropout_rate=0.2):
        super(EnhancedSTRAP, self).__init__()

        # Kept so the functional dropout in the graph branch follows the
        # configured rate instead of a hard-coded constant.
        self.dropout_rate = dropout_rate

        # Input feature attention
        self.feature_attention = FeatureAttention(input_dim)

        # Feature embeddings with increased dimensionality
        self.property_embedding = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )

        # Time embedding modules. NOTE: `time_embedding` is not used by
        # `forward`; it is retained so existing state_dict keys stay valid.
        self.time_embedding = Time2VecEncoding(hidden_dim, scale=1000)
        self.time_projection = nn.Linear(temporal_dim, hidden_dim)

        # Temporal GRU layers with bidirectional processing
        self.gru = nn.GRU(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=num_gru_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout_rate if num_gru_layers > 1 else 0
        )

        # Reduce GRU output dimension (bidirectional doubles the output size)
        self.temporal_reduction = nn.Linear(hidden_dim * 2, hidden_dim)
        self.temporal_norm = nn.LayerNorm(hidden_dim)

        # Residual blocks for deeper representation
        self.residual_blocks = nn.Sequential(
            ResidualBlock(hidden_dim),
            nn.Dropout(dropout_rate),
            ResidualBlock(hidden_dim),
            nn.Dropout(dropout_rate)
        )

        # Using only SAGEConv for spatial modeling (more stable than GATConv)
        self.graph_conv1 = SAGEConv(hidden_dim, hidden_dim)
        self.graph_conv2 = SAGEConv(hidden_dim, hidden_dim)
        self.spatial_norm = nn.LayerNorm(hidden_dim)

        # Final prediction layers (input is temporal and spatial outputs concatenated)
        self.predictor = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate/2),
            nn.Linear(hidden_dim // 2, 1)
        )

    def forward(self, x, edge_index=None, time_features=None):
        """Predict one value per row of ``x``.

        Args:
            x: Feature tensor of shape ``[batch, input_dim]``.
            edge_index: Optional ``[2, num_edges]`` index tensor. It is used only
                when non-empty and every index lies in ``[0, batch)``; otherwise
                the spatial branch contributes zeros.
            time_features: Optional ``[batch, temporal_dim]`` tensor whose
                projection is added to the property embedding.

        Returns:
            Tensor of shape ``[batch]``.
        """
        batch_size = x.size(0)

        # Apply feature attention
        x = self.feature_attention(x)

        # Property embeddings
        property_emb = self.property_embedding(x)

        # Process time features if provided
        if time_features is not None:
            time_emb = self.time_projection(time_features)
            property_emb = property_emb + time_emb

        # Reshape for GRU (adding sequence dimension)
        temporal_input = property_emb.unsqueeze(1)  # Shape: [batch_size, 1, hidden_dim]

        # Process with bidirectional GRU
        temporal_output, _ = self.gru(temporal_input)  # Shape: [batch_size, 1, hidden_dim*2]

        # Reduce dimension from bidirectional output (last == only time step)
        temporal_output = self.temporal_reduction(temporal_output[:, -1])  # Shape: [batch_size, hidden_dim]
        temporal_output = self.temporal_norm(temporal_output)

        # Apply residual blocks
        temporal_output = self.residual_blocks(temporal_output)

        # Process with graph convolution if edge_index provided and valid
        spatial_output = torch.zeros_like(temporal_output)
        if edge_index is not None and edge_index.numel() > 0:
            try:
                # Validate edge_index against the number of rows in this batch
                edge_min, edge_max = edge_index.min(), edge_index.max()
                if edge_min >= 0 and edge_max < batch_size:
                    # First graph layer
                    spatial_output = self.graph_conv1(temporal_output, edge_index)
                    spatial_output = F.relu(spatial_output)
                    spatial_output = F.dropout(spatial_output, p=self.dropout_rate, training=self.training)

                    # Second graph layer
                    spatial_output = self.graph_conv2(spatial_output, edge_index)
                    spatial_output = self.spatial_norm(spatial_output)
                else:
                    print(f"Warning: Invalid edge_index values. Min: {edge_min}, Max: {edge_max}, Batch size: {batch_size}")
            except Exception as e:
                print(f"Error in graph convolution: {e}")
                # Best effort: if the graph convolution fails, proceed without it
                spatial_output = torch.zeros_like(temporal_output)

        # Combine temporal and spatial outputs
        combined = torch.cat([temporal_output, spatial_output], dim=1)

        # Final prediction
        output = self.predictor(combined).squeeze(-1)

        return output


def create_edge_index_efficient(data, k_neighbors=5, distance_threshold=2.0, chunk_size=1000):
    """
    Create an undirected spatial k-nearest-neighbour edge index using BallTree and chunking.

    Args:
        data: DataFrame with 'latitude' and 'longitude' columns (degrees).
        k_neighbors: Maximum number of neighbours linked to each point.
        distance_threshold: Maximum haversine distance in km for an edge.
        chunk_size: Number of points queried against the tree per batch.

    Returns:
        A ``torch.long`` tensor of shape ``[2, num_edges]`` holding each edge in
        both directions, without duplicates and sorted by (source, target).
        An empty ``[2, 0]`` tensor is returned when no edge qualifies or when
        any error occurs (the error is printed, not raised).
    """
    print("Creating spatial edge index efficiently...")

    earth_radius_km = 6371.0

    try:
        coords = data[['latitude', 'longitude']].values
        n_samples = len(coords)

        # Check for NaN or infinite values
        if np.isnan(coords).any() or np.isinf(coords).any():
            print("Warning: NaN or infinite values found in coordinates. Cleaning data...")
            coords = np.nan_to_num(coords, nan=0.0, posinf=90.0, neginf=-90.0)

        # BallTree's haversine metric works on radians
        coords_rad = np.radians(coords)
        tree = BallTree(coords_rad, metric='haversine')

        # One extra neighbour because a point is normally its own nearest neighbour
        k_query = min(k_neighbors + 1, n_samples)
        edge_chunks = []

        # Process in chunks to avoid memory issues
        for start in tqdm(range(0, n_samples, chunk_size)):
            stop = min(start + chunk_size, n_samples)
            distances, indices = tree.query(coords_rad[start:stop], k=k_query)
            distances_km = distances * earth_radius_km

            sources = np.arange(start, stop)[:, None]

            # Exclude the query point by index rather than by position: with
            # duplicated coordinates it is not guaranteed to sit in column 0.
            not_self = indices != sources
            # 1-based rank of each candidate among the non-self neighbours, so
            # at most k_neighbors are kept even when self is absent from the row.
            neighbor_rank = np.cumsum(not_self, axis=1)
            keep = not_self & (neighbor_rank <= k_neighbors) & (distances_km <= distance_threshold)

            src = np.broadcast_to(sources, indices.shape)[keep]
            dst = indices[keep]
            if src.size:
                edge_chunks.append(np.stack([src, dst], axis=1))

        if not edge_chunks:
            print("Warning: No valid edges found. Returning empty edge index.")
            return torch.zeros((2, 0), dtype=torch.long)

        # Add the reverse of every edge (undirected graph), then drop duplicates
        forward_edges = np.concatenate(edge_chunks, axis=0)
        all_edges = np.concatenate([forward_edges, forward_edges[:, ::-1]], axis=0)
        unique_edges = np.unique(all_edges, axis=0)

        edge_index_tensor = torch.from_numpy(np.ascontiguousarray(unique_edges.T)).long()

        print(f"Created edge index with shape {edge_index_tensor.shape}")
        return edge_index_tensor

    except Exception as e:
        # Deliberate best effort: callers fall back to a graph-free model
        print(f"Error creating edge index: {e}")
        print("Returning empty edge index.")
        return torch.zeros((2, 0), dtype=torch.long)


def calculate_distance(lat1, lon1, lat2, lon2):
    """Calculate the great-circle distance in km between two points (Haversine formula).

    Coordinates are in degrees. Each argument may be a scalar or an
    array-like; array-likes are broadcast against each other by NumPy.

    Returns:
        The distance in kilometres (a NumPy scalar for scalar inputs, an
        array for array inputs).
    """
    R = 6371  # Earth's radius in kilometers
    lat1, lon1, lat2, lon2 = (np.radians(value) for value in (lat1, lon1, lat2, lon2))
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = np.sin(dlat/2)**2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon/2)**2
    # Rounding can push `a` marginally outside [0, 1] (e.g. near-antipodal
    # points), which would make arcsin return NaN.
    a = np.clip(a, 0.0, 1.0)
    c = 2 * np.arcsin(np.sqrt(a))
    return R * c


def create_enhanced_spatial_features(df, k_neighbors=5, chunk_size=1000, n_jobs=-1):
    """Create enhanced spatial features with error handling.

    Works on a copy of ``df`` and adds, then z-score standardises:
    ``distance_to_center``, ``north_south``, ``east_west`` (relative to a
    fixed city centre) and ``knn_price_mean``, ``knn_price_std``,
    ``knn_price_median``, ``price_diff_from_neighbors`` (statistics of the
    ``price`` of each row's nearest neighbours by haversine distance).

    Args:
        df: DataFrame with 'latitude', 'longitude' and 'price' columns.
        k_neighbors: Number of neighbours used for the price statistics.
        chunk_size: Number of rows queried against the tree per batch.
        n_jobs: Unused; kept for signature compatibility.

    Returns:
        The augmented copy. If any step raises, the error is printed and the
        copy is returned with whatever columns had been added up to that point.
    """
    print("Creating enhanced spatial features...")

    # Make a copy to avoid modifying the original
    df = df.copy()

    # Paris coordinates
    city_center_lat, city_center_lon = 48.8566, 2.3522

    try:
        # Distance to city center (row-wise, so scalar inputs are passed)
        df['distance_to_center'] = df.apply(
            lambda row: calculate_distance(
                row['latitude'],
                row['longitude'],
                city_center_lat,
                city_center_lon
            ),
            axis=1
        )

        # Directional offsets from the city center
        df['north_south'] = df['latitude'] - city_center_lat
        df['east_west'] = df['longitude'] - city_center_lon

        # Check for NaN or infinite values
        coords = df[['latitude', 'longitude']].values
        if np.isnan(coords).any() or np.isinf(coords).any():
            print("Warning: NaN or infinite values found in coordinates. Cleaning data...")
            df['latitude'] = np.nan_to_num(df['latitude'].values, nan=city_center_lat)
            df['longitude'] = np.nan_to_num(df['longitude'].values, nan=city_center_lon)

        # BallTree's haversine metric works on radians; converted once and
        # sliced per chunk below.
        coords_rad = np.radians(df[['latitude', 'longitude']].values)
        tree = BallTree(coords_rad, metric='haversine')

        n_rows = len(df)
        k_to_use = min(k_neighbors + 1, n_rows)  # Ensure k is not larger than dataset
        price_values = df['price'].to_numpy(dtype=float)

        # KNN features; rows without any usable neighbour price keep 0
        knn_price_mean = np.zeros(n_rows)
        knn_price_std = np.zeros(n_rows)
        knn_price_median = np.zeros(n_rows)
        price_diff = np.zeros(n_rows)

        # NOTE(review): these four features are computed from `price`; if `price`
        # is also the prediction target downstream this leaks the target into
        # the inputs - confirm that is intended.
        for start in tqdm(range(0, n_rows, chunk_size), desc="Processing chunks for spatial features"):
            stop = min(start + chunk_size, n_rows)
            _, indices = tree.query(coords_rad[start:stop], k=k_to_use)

            # Column 0 is treated as the query point itself and skipped
            neighbor_prices = price_values[indices[:, 1:]]

            # Ignore NaN/Inf neighbour prices
            finite = np.isfinite(neighbor_prices)
            has_neighbors = finite.any(axis=1)
            if not has_neighbors.any():
                continue

            rows = start + np.flatnonzero(has_neighbors)
            usable = np.where(finite[has_neighbors], neighbor_prices[has_neighbors], np.nan)

            means = np.nanmean(usable, axis=1)
            knn_price_mean[rows] = means
            knn_price_std[rows] = np.nanstd(usable, axis=1)  # population std; 0 for a single value
            knn_price_median[rows] = np.nanmedian(usable, axis=1)
            price_diff[rows] = price_values[rows] - means

        # Add features to dataframe
        df['knn_price_mean'] = knn_price_mean
        df['knn_price_std'] = knn_price_std
        df['knn_price_median'] = knn_price_median
        df['price_diff_from_neighbors'] = price_diff

        # Standardize the new features
        spatial_features = ['distance_to_center', 'north_south', 'east_west',
                            'knn_price_mean', 'knn_price_std', 'knn_price_median',
                            'price_diff_from_neighbors']

        for col in spatial_features:
            # Replace NaN or infinite values
            df[col] = np.nan_to_num(df[col].values, nan=0.0)

            mean_val = df[col].mean()
            std_val = df[col].std()
            # Constant columns are left unscaled to avoid dividing by zero
            if std_val > 0:
                df[col] = (df[col] - mean_val) / std_val

    except Exception as e:
        # Deliberate best effort: the pipeline continues without the new features
        print(f"Error in creating spatial features: {e}")
        print("Proceeding with original features only.")

    return df


def calculate_error_autocorrelation(errors, max_lag=7):
    """Calculate autocorrelation of prediction errors with error handling.

    Args:
        errors: Array-like of prediction errors (NumPy array, list, Series, ...).
            NaN and infinite entries are dropped before the computation.
        max_lag: Highest lag to compute.

    Returns:
        Array with the autocorrelation at lags ``1..max_lag``. An all-zero
        array of length ``max_lag`` is returned when fewer than
        ``max_lag + 1`` finite values remain or when the computation fails
        (the failure is printed, not raised).
    """
    try:
        # Coerce first so plain lists support boolean-mask indexing too
        values = np.asarray(errors, dtype=float)

        # Remove NaN or infinite values
        clean_errors = values[np.isfinite(values)]

        if len(clean_errors) < max_lag + 1:
            print("Warning: Not enough valid error values for autocorrelation calculation.")
            return np.zeros(max_lag)

        # Calculate autocorrelation
        error_acf = acf(clean_errors, nlags=max_lag)

        # Return values excluding lag 0 (which is always 1)
        return error_acf[1:]
    except Exception as e:
        print(f"Error calculating autocorrelation: {e}")
        return np.zeros(max_lag)


def calculate_error_stability(all_results):
    """Calculate error stability metrics with error handling.

    Args:
        all_results: DataFrame with 'date_str', 'error' and 'abs_error' columns.

    Returns:
        Dict with:
            'daily_error_std': per-day standard deviation of 'error' (Series),
            'daily_mae': per-day mean of 'abs_error' (Series),
            'mae_stability_coefficient': std of the daily MAE divided by its
                mean, as a float. It is 0.0 when the ratio is undefined
                (fewer than two days, or a non-positive mean).
        On any error the message is printed and empty Series plus a 0.0
        coefficient are returned.
    """
    try:
        # Group by date
        grouped = all_results.groupby('date_str')

        # Standard deviation of errors for each day
        daily_error_std = grouped['error'].std()

        # MAE for each day
        daily_mae = grouped['abs_error'].mean()

        # Coefficient of variation of the daily MAE. With fewer than two days
        # the sample std is NaN, so fall back to 0.0 like the other
        # undefined cases instead of propagating NaN.
        mae_mean = daily_mae.mean() if len(daily_mae) > 0 else 1.0
        mae_std = daily_mae.std()
        if mae_mean > 0 and np.isfinite(mae_std):
            mae_stability = float(mae_std / mae_mean)
        else:
            mae_stability = 0.0

        return {
            'daily_error_std': daily_error_std,
            'daily_mae': daily_mae,
            'mae_stability_coefficient': mae_stability
        }
    except Exception as e:
        print(f"Error calculating error stability: {e}")
        # Return empty objects with the expected structure; the explicit dtype
        # avoids the default-dtype warning for empty Series on older pandas.
        return {
            'daily_error_std': pd.Series(dtype=float),
            'daily_mae': pd.Series(dtype=float),
            'mae_stability_coefficient': 0.0
        }


def run_day_by_day_enhanced_strap_prediction(train_path, test_path, features_to_drop=None, prediction_days=7, output_path=None):
    """Run enhanced ST-RAP model with day-by-day retraining for multiple days prediction with robust error handling"""
    print(f"Processing dataset for {prediction_days}-day prediction with enhanced ST-RAP model")
    
    # Default features to drop if none specified
    if features_to_drop is None:
        features_to_drop = []
    
    print(f"Features being dropped: {features_to_drop}")
    
    try:
        # Load training and test data
        print("Loading data...")
        train_data = pd.read_csv(train_path)
        test_data = pd.read_csv(test_path)
        
        # Convert date columns to datetime
        train_data['date'] = pd.to_datetime(train_data['date'])
        test_data['date'] = pd.to_datetime(test_data['date'])
        
        # Drop specified columns if they exist
        for col in features_to_drop:
            if col in train_data.columns:
                print(f"Dropping column: {col}")
                train_data = train_data.drop(col, axis=1)
            if col in test_data.columns:
                test_data = test_data.drop(col, axis=1)
        
        # Sort by date
        train_data = train_data.sort_values('date')
        test_data = test_data.sort_values('date')
        
        # Get unique dates in test set
        test_dates = test_data['date'].dt.date.unique()
        print(f"Test set contains {len(test_dates)} unique dates.")
        
        # Limit to specified prediction days
        if len(test_dates) > prediction_days:
            test_dates = test_dates[:prediction_days]
            print(f"Limited to first {prediction_days} days for prediction.")
        
        # Add enhanced spatial features
        train_data = create_enhanced_spatial_features(train_data, k_neighbors=5)
        test_data = create_enhanced_spatial_features(test_data, k_neighbors=5)
        
        # Create feature matrices    
        feature_cols = [col for col in train_data.columns 
                      if col not in ['listing_id', 'date', 'price']]
        
        # Handle NaN or infinite values in feature matrices
        for col in feature_cols:
            train_data[col] = np.nan_to_num(train_data[col].values, nan=0.0)
            test_data[col] = np.nan_to_num(test_data[col].values, nan=0.0)
        
        X_train = train_data[feature_cols].values
        y_train = train_data['price'].values
        
        # Initialize device
        use_cuda = torch.cuda.is_available()
        device = torch.device('cuda' if use_cuda else 'cpu')
        print(f"Using device: {device}")
        
        # Create a simpler model if on CPU to speed up training
        if not use_cuda:
            print("Running on CPU - using a more efficient model configuration")
            model = EnhancedSTRAP(
                input_dim=len(feature_cols),
                temporal_dim=5,
                hidden_dim=128,  # Smaller hidden dimension
                num_gru_layers=2,  # Fewer GRU layers
                dropout_rate=0.2
            ).to(device)
        else:
            model = EnhancedSTRAP(
                input_dim=len(feature_cols),
                temporal_dim=5,
                hidden_dim=256,
                num_gru_layers=3,
                dropout_rate=0.3
            ).to(device)
        
        # Define optimizer with learning rate scheduler
        optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-5)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=3, verbose=True
        )
        
        # Define loss function
        criterion = nn.HuberLoss(delta=1.0)  # Huber loss is more robust to outliers than MSE
        
        # Initialize results storage
        daily_results = []
        
        # Build edge index once for efficiency
        # Use a smaller k_neighbors value for stability
        edge_index = None
        try:
            print("Attempting to build edge index for graph structure...")
            # Create a subset of data for edge index to reduce memory usage
            if len(train_data) > 5000:
                edge_data = pd.concat([train_data, test_data]).sample(5000, random_state=42)
            else:
                edge_data = pd.concat([train_data, test_data])
                
            edge_index = create_edge_index_efficient(
                edge_data[['latitude', 'longitude']],
                k_neighbors=5,  # Reduced from 8 to 5 for stability
                distance_threshold=3.0
            )
            
            if edge_index.numel() > 0:
                edge_index = edge_index.to(device)
                print(f"Successfully built edge index with shape {edge_index.shape}")
            else:
                print("Edge index is empty. Graph convolution will be skipped.")
                edge_index = None
        except Exception as e:
            print(f"Failed to build edge index: {e}")
            print("Falling back to non-graph mode")
            edge_index = None
        
        # Process each day in the test set
        for day in tqdm(test_dates, desc="Processing days"):
            # Convert day to datetime for filtering
            day_dt = pd.to_datetime(day)
            
            # Get test data for the current day
            day_test = test_data[test_data['date'].dt.date == day]
            
            X_test_day = day_test[feature_cols].values
            y_test_day = day_test['price'].values
            
            # Convert to PyTorch tensors
            X_train_tensor = torch.FloatTensor(X_train).to(device)
            y_train_tensor = torch.FloatTensor(y_train).to(device)
            X_test_tensor = torch.FloatTensor(X_test_day).to(device)
            y_test_tensor = torch.FloatTensor(y_test_day).to(device)
            
            # Extract time features
            time_cols = ['DTF_day_of_week', 'DTF_month', 'DTF_is_weekend', 'DTF_season_sin', 'DTF_season_cos']
            if all(col in train_data.columns for col in time_cols):
                time_features_train = torch.FloatTensor(train_data[time_cols].values).to(device)
                time_features_test = torch.FloatTensor(day_test[time_cols].values).to(device)
            else:
                # If time columns don't exist, create dummy time features
                print("Warning: Time columns not found. Using dummy time features.")
                time_features_train = torch.zeros((len(X_train_tensor), 5), device=device)
                time_features_test = torch.zeros((len(X_test_tensor), 5), device=device)
            
            # Create DataLoader for training with variable batch size based on data size
            batch_size = min(64, len(X_train_tensor) // 10) if len(X_train_tensor) > 640 else 32
            train_dataset = torch.utils.data.TensorDataset(X_train_tensor, time_features_train, y_train_tensor)
            train_loader = torch.utils.data.DataLoader(
                train_dataset, 
                batch_size=batch_size,
                shuffle=True
            )
            
            # Train model with dynamic epochs based on dataset size
            model.train()
            epochs = 10 if use_cuda else 5  # Fewer epochs if on CPU
            patience = 5
            patience_counter = 0
            best_val_loss = float('inf')
            
            for epoch in range(epochs):
                total_loss = 0
                batch_count = 0
                
                # Training phase with error handling
                for batch_x, batch_time, batch_y in train_loader:
                    try:
                        # Forward pass
                        outputs = model(batch_x, edge_index, batch_time)
                        
                        # Check for NaN or Inf values
                        if torch.isnan(outputs).any() or torch.isinf(outputs).any():
                            print("Warning: NaN or Inf values in model outputs. Skipping batch.")
                            continue
                        
                        # Calculate loss
                        loss = criterion(outputs, batch_y)
                        
                        # Backward and optimize
                        optimizer.zero_grad()
                        loss.backward()
                        torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0)  # Increased clip value
                        optimizer.step()
                        
                        total_loss += loss.item()
                        batch_count += 1
                    except RuntimeError as e:
                        if "CUDA" in str(e):
                            print(f"CUDA error during training: {e}")
                            print("Attempting to continue with next batch...")
                            # Clear any stored gradients
                            optimizer.zero_grad()
                            torch.cuda.empty_cache()  # Try to free GPU memory
                            continue
                        else:
                            raise
                
                # Calculate average loss for the epoch
                avg_loss = total_loss/batch_count if batch_count > 0 else float('inf')
                
                # Validation phase - use a small subset of training data
                val_size = min(1000, len(X_train_tensor))
                val_indices = torch.randperm(len(X_train_tensor))[:val_size]
                val_x = X_train_tensor[val_indices]
                val_time = time_features_train[val_indices]
                val_y = y_train_tensor[val_indices]
                
                model.eval()
                with torch.no_grad():
                    try:
                        val_outputs = model(val_x, edge_index, val_time)
                        val_loss = criterion(val_outputs, val_y).item()
                    except Exception as e:
                        print(f"Error during validation: {e}")
                        val_loss = float('inf')
                model.train()
                
                # Update learning rate based on validation loss
                scheduler.step(val_loss)
                
                # Print epoch stats
                print(f"Day: {day}, Epoch: {epoch+1}/{epochs}, Train Loss: {avg_loss:.6f}, Val Loss: {val_loss:.6f}")
                
                # Early stopping
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    patience_counter = 0
                else:
                    patience_counter += 1
                    if patience_counter >= patience:
                        print(f"Early stopping at epoch {epoch+1}")
                        break
            
            # Evaluate on test data
            model.eval()
            with torch.no_grad():
                # Generate predictions in smaller batches to avoid OOM
                test_batch_size = min(64, len(X_test_tensor))
                predictions = []
                
                for i in range(0, len(X_test_tensor), test_batch_size):
                    try:
                        end_idx = min(i + test_batch_size, len(X_test_tensor))
                        batch_x = X_test_tensor[i:end_idx]
                        batch_time = time_features_test[i:end_idx]
                        batch_pred = model(batch_x, edge_index, batch_time)
                        predictions.append(batch_pred)
                    except Exception as e:
                        print(f"Error during prediction batch {i}-{end_idx}: {e}")
                        # Fill with zeros for the failed batch
                        dummy_pred = torch.zeros(end_idx - i, device=device)
                        predictions.append(dummy_pred)
                
                # Concatenate all prediction batches
                if predictions:
                    y_pred = torch.cat(predictions)
                    # Ensure the length matches
                    if len(y_pred) != len(y_test_tensor):
                        print(f"Warning: Prediction length mismatch. Expected {len(y_test_tensor)}, got {len(y_pred)}")
                        # Pad or truncate to match
                        if len(y_pred) < len(y_test_tensor):
                            y_pred = torch.cat([y_pred, torch.zeros(len(y_test_tensor) - len(y_pred), device=device)])
                        else:
                            y_pred = y_pred[:len(y_test_tensor)]
                    
                    test_loss = criterion(y_pred, y_test_tensor).item()
                else:
                    print("No valid predictions generated. Using zeros.")
                    y_pred = torch.zeros_like(y_test_tensor)
                    test_loss = float('inf')
            
            # Convert predictions to numpy
            y_pred_np = y_pred.cpu().numpy()
            
            # Store results for the day
            day_results_df = pd.DataFrame({
                'date': day_test['date'],
                'listing_id': day_test['listing_id'],
                'price': y_test_day,
                'predicted': y_pred_np,
                'error': y_test_day - y_pred_np,
                'abs_error': np.abs(y_test_day - y_pred_np),
                'pct_error': np.abs((y_test_day - y_pred_np) / (np.abs(y_test_day) + 1e-8)) * 100
            })
            
            daily_results.append(day_results_df)
            
            # Update training data with the current day's actual values
            # This simulates getting actual values at the end of each day
            # before predicting the next day
            X_train = np.concatenate([X_train, X_test_day])
            y_train = np.concatenate([y_train, y_test_day])
            
            # Instead of growing the full dataframe (which can be slow),
            # just update the necessary columns for the next iteration
            new_train_data = day_test.copy()
            train_data = pd.concat([train_data, new_train_data], ignore_index=True)
            
            # Free GPU memory
            if use_cuda:
                torch.cuda.empty_cache()
        
        # Combine all daily results
        all_results = pd.concat(daily_results, ignore_index=True)
        
        # Calculate overall metrics with error handling
        y_true = all_results['price'].values
        y_pred = all_results['predicted'].values
        
        # Remove any potential NaN or Inf values
        valid_mask = ~np.isnan(y_true) & ~np.isinf(y_true) & ~np.isnan(y_pred) & ~np.isinf(y_pred)
        y_true_clean = y_true[valid_mask]
        y_pred_clean = y_pred[valid_mask]
        
        # Only calculate metrics if we have valid predictions
        if len(y_true_clean) > 0:
            metrics = {
                'rmse': np.sqrt(mean_squared_error(y_true_clean, y_pred_clean)),
                'mae': mean_absolute_error(y_true_clean, y_pred_clean),
                'r2': r2_score(y_true_clean, y_pred_clean) if len(set(y_true_clean)) > 1 else 0,
                'mape': np.mean(np.abs((y_true_clean - y_pred_clean) / (np.abs(y_true_clean) + 1e-8))) * 100
            }
        else:
            print("Warning: No valid predictions for metric calculation.")
            metrics = {'rmse': float('inf'), 'mae': float('inf'), 'r2': 0, 'mape': float('inf')}
        
        # Calculate daily metrics with error handling
        daily_metrics = []
        for day_df in daily_results:
            try:
                day = day_df['date'].iloc[0]
                y_true_day = day_df['price'].values
                y_pred_day = day_df['predicted'].values
                
                # Remove any NaN or Inf values
                valid_mask = ~np.isnan(y_true_day) & ~np.isinf(y_true_day) & ~np.isnan(y_pred_day) & ~np.isinf(y_pred_day)
                y_true_day_clean = y_true_day[valid_mask]
                y_pred_day_clean = y_pred_day[valid_mask]
                
                if len(y_true_day_clean) > 0:
                    daily_metrics.append({
                        'date': day,
                        'rmse': np.sqrt(mean_squared_error(y_true_day_clean, y_pred_day_clean)),
                        'mae': mean_absolute_error(y_true_day_clean, y_pred_day_clean),
                        'r2': r2_score(y_true_day_clean, y_pred_day_clean) if len(set(y_true_day_clean)) > 1 else np.nan,
                        'mape': np.mean(np.abs((y_true_day_clean - y_pred_day_clean) / (np.abs(y_true_day_clean) + 1e-8))) * 100,
                        'n_samples': len(y_true_day_clean)
                    })
                else:
                    # Add placeholder metrics if no valid data
                    daily_metrics.append({
                        'date': day,
                        'rmse': np.nan,
                        'mae': np.nan,
                        'r2': np.nan,
                        'mape': np.nan,
                        'n_samples': 0
                    })
            except Exception as e:
                print(f"Error calculating metrics for day {day_df['date'].iloc[0]}: {e}")
        
        daily_metrics_df = pd.DataFrame(daily_metrics)
        
        # Create evaluation results dictionary
        evaluation_results = {
            'overall_metrics': metrics,
            'daily_metrics': daily_metrics_df,
            'all_results': all_results
        }
        
        # Add date_str column for grouping
        all_results['date_str'] = pd.to_datetime(all_results['date']).dt.strftime('%Y-%m-%d')

        # Calculate error autocorrelation
        error_autocorrelation = calculate_error_autocorrelation(all_results['error'].values)

        # Calculate error stability metrics
        error_stability = calculate_error_stability(all_results)

        # Add to evaluation results
        evaluation_results['error_autocorrelation'] = error_autocorrelation
        evaluation_results['error_stability'] = error_stability
        
        # Save results to CSV if output path is provided
        if output_path:
            try:
                # Make sure the directory exists
                os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
                
                # Include location data from test data if available
                if 'longitude' in test_data.columns and 'latitude' in test_data.columns:
                    location_data = test_data[['listing_id', 'longitude', 'latitude']].drop_duplicates()
                    results_with_location = all_results.merge(location_data, on='listing_id', how='left')
                    results_with_location.to_csv(output_path, index=False)
                    print(f"Results saved to {output_path} with location data")
                else:
                    all_results.to_csv(output_path, index=False)
                    print(f"Results saved to {output_path}")
                
                # Also save daily metrics
                metrics_path = output_path.replace('.csv', '_daily_metrics.csv')
                daily_metrics_df.to_csv(metrics_path, index=False)
                print(f"Daily metrics saved to {metrics_path}")
            except Exception as e:
                print(f"Error saving results: {e}")
        
        return evaluation_results
    
    except Exception as e:
        print(f"Critical error in prediction pipeline: {e}")
        # Return minimal valid output to avoid further errors
        return {
            'overall_metrics': {'rmse': float('inf'), 'mae': float('inf'), 'r2': 0, 'mape': float('inf')},
            'daily_metrics': pd.DataFrame(),
            'all_results': pd.DataFrame(),
            'error_autocorrelation': np.zeros(7),
            'error_stability': {
                'daily_error_std': pd.Series(),
                'daily_mae': pd.Series(),
                'mae_stability_coefficient': 0.0
            }
        }


def plot_enhanced_results(evaluation_results):
    """Plot prediction diagnostics: a 2x2 overview figure and a MAPE/sample-size chart.

    Args:
        evaluation_results: Mapping that must provide ``'daily_metrics'`` and
            ``'all_results'`` (DataFrames). The panels read the columns
            ``date``, ``mae``, ``rmse``, ``mape`` and ``n_samples`` from the
            daily frame and ``date``, ``price``, ``predicted`` and ``error``
            from the results frame; a panel whose columns are missing is
            left empty instead of aborting the whole figure.

    Returns:
        None. Figures are written to ``enhanced_strap_performance.png`` and
        ``enhanced_strap_mape_samples.png`` in the working directory and shown
        with ``plt.show()``.

    Any exception is reported on stdout and swallowed so that a plotting
    failure never takes down the surrounding pipeline. The input frames are
    not modified.
    """
    try:
        _apply_plot_style()

        # Extract data
        daily_metrics = evaluation_results['daily_metrics']
        all_results = evaluation_results['all_results']

        # Check if we have valid data to plot
        if len(daily_metrics) == 0 or len(all_results) == 0:
            print("No valid data for plotting. Skipping visualization.")
            return

        # Create a figure with multiple subplots
        fig, axes = plt.subplots(2, 2, figsize=(16, 14))

        _plot_daily_mae(axes[0, 0], daily_metrics)
        _plot_daily_rmse(axes[0, 1], daily_metrics)
        _plot_actual_vs_predicted(axes[1, 0], all_results)
        _plot_error_distribution(axes[1, 1], all_results)

        plt.tight_layout()

        try:
            plt.savefig('enhanced_strap_performance.png', dpi=300, bbox_inches='tight')
            print("Saved performance plot to 'enhanced_strap_performance.png'")
        except Exception as e:
            print(f"Could not save plot: {e}")

        plt.show()

        _plot_mape_and_samples(daily_metrics)

    except Exception as e:
        print(f"Error in plot_enhanced_results: {e}")
        print("Skipping visualization.")


def _apply_plot_style():
    """Select a seaborn whitegrid style, keeping the default style if none loads."""
    # Newer releases expose the prefixed name, older ones the plain name.
    for style_name in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid'):
        try:
            plt.style.use(style_name)
            return
        except Exception:
            # Styling is cosmetic: try the next name, then fall back to default.
            # `Exception` (not a bare except) lets KeyboardInterrupt through.
            continue


def _plot_daily_mae(ax, daily_metrics):
    """Draw the per-day MAE line and, with two or more points, a linear trend."""
    if not {'date', 'mae'}.issubset(daily_metrics.columns):
        return

    valid = daily_metrics['mae'].notna()
    if not valid.any():
        return

    dates = pd.to_datetime(daily_metrics.loc[valid, 'date'])
    mae = daily_metrics.loc[valid, 'mae']

    sns.lineplot(
        x=dates,
        y=mae,
        marker='o',
        linewidth=2,
        color='royalblue',
        ax=ax
    )

    # Add trend line if we have at least 2 points
    has_trend = len(mae) >= 2
    if has_trend:
        x_indices = np.arange(len(mae))
        z = np.polyfit(x_indices, mae.to_numpy(), 1)
        p = np.poly1d(z)
        ax.plot(dates, p(x_indices), "r--", alpha=0.8,
                label=f"Trend: {'increasing' if z[0] > 0 else 'decreasing'}")

    ax.set_title('Mean Absolute Error by Day', fontsize=14)
    ax.set_xlabel('Date', fontsize=12)
    ax.set_ylabel('MAE', fontsize=12)
    # Only the trend line carries a label; an empty legend just emits a warning.
    if has_trend:
        ax.legend()


def _plot_daily_rmse(ax, daily_metrics):
    """Draw the per-day RMSE line with an illustrative +/-10% band."""
    if not {'date', 'rmse'}.issubset(daily_metrics.columns):
        return

    valid = daily_metrics['rmse'].notna()
    if not valid.any():
        return

    # Parse the dates locally so this panel does not depend on the MAE panel
    # having run (it previously reused a variable defined only there).
    dates = pd.to_datetime(daily_metrics.loc[valid, 'date'])
    rmse_values = daily_metrics.loc[valid, 'rmse']

    sns.lineplot(
        x=dates,
        y=rmse_values,
        marker='o',
        linewidth=2,
        color='forestgreen',
        ax=ax
    )

    # Error band is a fixed 10% of the RMSE, for illustration only; it is not
    # a statistically derived confidence interval.
    rmse_std = rmse_values * 0.1
    ax.fill_between(
        dates,
        rmse_values - rmse_std,
        rmse_values + rmse_std,
        alpha=0.2,
        color='forestgreen'
    )

    ax.set_title('Root Mean Squared Error by Day', fontsize=14)
    ax.set_xlabel('Date', fontsize=12)
    ax.set_ylabel('RMSE', fontsize=12)


def _plot_actual_vs_predicted(ax, all_results):
    """Scatter actual vs. predicted prices, coloured by order of date appearance."""
    required = ['price', 'predicted', 'date']
    if not set(required).issubset(all_results.columns):
        return

    # Filter out rows with NaN values
    valid_results = all_results.dropna(subset=required)
    if len(valid_results) == 0:
        return

    # Integer code per calendar day, numbered in order of first appearance.
    date_codes, _ = pd.factorize(pd.to_datetime(valid_results['date']).dt.date)

    scatter = ax.scatter(
        valid_results['price'],
        valid_results['predicted'],
        c=date_codes,
        cmap='viridis',
        alpha=0.7,
        edgecolors='w',
        linewidths=0.2
    )

    min_val = min(valid_results['price'].min(), valid_results['predicted'].min())
    max_val = max(valid_results['price'].max(), valid_results['predicted'].max())

    # Perfect prediction line
    ax.plot([min_val, max_val], [min_val, max_val], 'k--', alpha=0.7, label='Perfect Prediction')

    # A degree-1 fit needs at least two distinct x values to be well defined.
    if len(valid_results) >= 2 and valid_results['price'].nunique() > 1:
        z = np.polyfit(valid_results['price'], valid_results['predicted'], 1)
        p = np.poly1d(z)
        x_reg = np.linspace(min_val, max_val, 100)
        ax.plot(x_reg, p(x_reg), color='red', linestyle='-', alpha=0.7,
                label=f'Regression Line (slope={z[0]:.2f})')

    ax.set_title('Actual vs Predicted Prices', fontsize=14)
    ax.set_xlabel('Actual Price', fontsize=12)
    ax.set_ylabel('Predicted Price', fontsize=12)
    ax.legend(loc='upper left')

    # Add colorbar for dates
    cbar = plt.colorbar(scatter, ax=ax)
    cbar.set_label('Date Progression')


def _plot_error_distribution(ax, all_results):
    """Draw a density histogram of errors within three standard deviations."""
    if 'error' not in all_results.columns:
        return

    # Recompute from price/predicted when available, but into a local Series:
    # the caller's DataFrame must not be modified by a plotting routine.
    if {'price', 'predicted'}.issubset(all_results.columns):
        errors = all_results['price'] - all_results['predicted']
    else:
        errors = all_results['error']

    # Remove outliers for better visualization (keep within 3 std deviations)
    mean_error = errors.mean()
    std_error = errors.std()
    lower_bound = mean_error - 3 * std_error
    upper_bound = mean_error + 3 * std_error

    # With a single row std is NaN, every comparison is False and the panel
    # is skipped.
    filtered_errors = errors[errors.between(lower_bound, upper_bound)]
    if len(filtered_errors) == 0:
        return

    sns.histplot(
        filtered_errors,
        kde=True,
        ax=ax,
        bins=min(30, len(filtered_errors) // 10 + 5),
        color='darkviolet',
        edgecolor='white',
        linewidth=0.5,
        stat='density'
    )

    # Add vertical lines for mean and median
    ax.axvline(filtered_errors.mean(), color='red', linestyle='--', alpha=0.7,
               label=f'Mean: {filtered_errors.mean():.3f}')
    ax.axvline(filtered_errors.median(), color='green', linestyle='--', alpha=0.7,
               label=f'Median: {filtered_errors.median():.3f}')
    ax.axvline(0, color='blue', linestyle='-', alpha=0.7, label='Zero Error')

    ax.set_title('Error Distribution (Outliers Filtered)', fontsize=14)
    ax.set_xlabel('Error (Actual - Predicted)', fontsize=12)
    ax.set_ylabel('Density', fontsize=12)
    ax.legend()


def _plot_mape_and_samples(daily_metrics):
    """Create, save and show the per-day sample-count bars with a MAPE line overlay."""
    if not {'date', 'n_samples', 'mape'}.issubset(daily_metrics.columns):
        return

    valid_metrics = daily_metrics.dropna(subset=['n_samples', 'mape'])
    if len(valid_metrics) == 0:
        return

    plt.figure(figsize=(12, 7))
    ax1 = plt.gca()
    ax2 = ax1.twinx()

    # Format dates for x-axis
    date_labels = pd.to_datetime(valid_metrics['date']).dt.strftime('%Y-%m-%d')

    # Sample size bars
    bars = ax1.bar(
        date_labels,
        valid_metrics['n_samples'],
        color='skyblue',
        alpha=0.7,
        edgecolor='navy',
        linewidth=1
    )

    # Add data labels on top of bars
    for bar in bars:
        height = bar.get_height()
        ax1.text(
            bar.get_x() + bar.get_width()/2.,
            height + 5,
            f'{int(height)}',
            ha='center',
            va='bottom',
            fontsize=9,
            rotation=0
        )

    ax1.set_xlabel('Date', fontsize=12)
    ax1.set_ylabel('Number of Samples', color='navy', fontsize=12)
    plt.setp(ax1.xaxis.get_majorticklabels(), rotation=45, ha='right')

    # MAPE line
    sns.lineplot(
        x=date_labels,
        y=valid_metrics['mape'],
        marker='o',
        markersize=8,
        color='crimson',
        linewidth=2,
        ax=ax2
    )

    # Add data labels to line points; the x position is the ordinal index
    # of each label on the string-labelled axis.
    for i, mape in enumerate(valid_metrics['mape']):
        ax2.annotate(
            f'{mape:.2f}%',
            (i, mape),
            textcoords="offset points",
            xytext=(0, 10),
            ha='center',
            fontsize=9,
            color='crimson'
        )

    ax2.set_ylabel('MAPE (%)', color='crimson', fontsize=12)
    ax2.tick_params(axis='y', colors='crimson')

    plt.title('Sample Size and Mean Absolute Percentage Error by Day', fontsize=14)
    plt.tight_layout()

    try:
        plt.savefig('enhanced_strap_mape_samples.png', dpi=300, bbox_inches='tight')
        print("Saved MAPE plot to 'enhanced_strap_mape_samples.png'")
    except Exception as e:
        print(f"Could not save MAPE plot: {e}")

    plt.show()


def print_detailed_summary(evaluation_results):
    """Write a sectioned, human-readable evaluation report to stdout.

    Args:
        evaluation_results: Mapping with the keys ``'overall_metrics'``,
            ``'daily_metrics'``, ``'error_autocorrelation'``,
            ``'error_stability'`` and ``'all_results'``.

    Returns:
        None. If ``all_results`` is empty a single notice is printed. Any
        exception raised while building the report is printed and swallowed,
        so output may stop part-way through.
    """
    try:
        overall, daily, error_autocorr, error_stability, all_results = (
            evaluation_results[key]
            for key in (
                'overall_metrics',
                'daily_metrics',
                'error_autocorrelation',
                'error_stability',
                'all_results',
            )
        )

        # Nothing to report on
        if len(all_results) == 0:
            print("No valid results to summarize.")
            return

        print("\n" + "="*80)
        print(f"{' ENHANCED ST-RAP MODEL EVALUATION ':=^80}")
        print("="*80)

        _summary_section(" OVERALL METRICS ", 35, 35)
        for label, key, unit in (
            ('RMSE:', 'rmse', ''),
            ('MAE:', 'mae', ''),
            ('R²:', 'r2', ''),
            ('MAPE:', 'mape', '%'),
        ):
            print(f"{label:<25} {overall[key]:.4f}{unit}")

        # Both columns are fetched up front, before the emptiness check
        signed = all_results['error'].dropna()
        absolute = all_results['abs_error'].dropna()
        if len(signed) > 0:
            _summary_error_distribution(signed, absolute)

        if len(daily) > 0:
            _summary_daily_performance(daily)

        if isinstance(error_autocorr, np.ndarray) and len(error_autocorr) > 0:
            _summary_section(" ERROR AUTOCORRELATION ", 35, 33)
            # Threshold used to flag a lag as significant: 1.96 / sqrt(row count)
            threshold = 1.96 / np.sqrt(len(all_results))
            for lag, acf_value in enumerate(error_autocorr, 1):
                flag = " *SIGNIFICANT*" if abs(acf_value) > threshold else ""
                print(f"  {'Lag ' + str(lag) + ':':<20} {acf_value:.4f}{flag}")

        if 'mae_stability_coefficient' in error_stability:
            _summary_section(" ERROR STABILITY ", 38, 37)
            print(f"  {'MAE Stability Coef:':<25} {error_stability['mae_stability_coefficient']:.4f}")
            print("  (Lower values indicate more consistent predictions across days)")

        # Per-listing ranking only when there are at least 10 result rows
        if len(all_results) >= 10:
            _summary_section(" PERFORMANCE ANALYSIS ", 35, 33)
            try:
                _summary_listing_ranking(all_results)
            except Exception as e:
                print(f"Error analyzing listing performance: {e}")

        print("\n" + "="*80)

    except Exception as e:
        print(f"Error in print_detailed_summary: {e}")
        print("Could not print complete summary due to errors.")


def _summary_section(title, left, right):
    """Print a blank line followed by ``title`` framed by dash runs."""
    print("\n" + "-"*left + title + "-"*right)


def _summary_percentiles(series):
    """Return the 5/25/50/75/95 percentile table of ``series`` keyed by label."""
    levels = (('5%', 0.05), ('25%', 0.25), ('50%', None), ('75%', 0.75), ('95%', 0.95))
    table = {}
    for label, q in levels:
        # The 50% entry uses median() rather than quantile(0.5)
        table[label] = series.median() if q is None else series.quantile(q)
    return table


def _summary_error_distribution(errors, abs_errors):
    """Print moments of the signed errors and percentile tables for both series."""
    skewness = errors.skew()
    kurt = errors.kurtosis()

    _summary_section(" ERROR DISTRIBUTION ", 35, 34)
    print(f"{'Mean Error:':<25} {errors.mean():.4f}")
    print(f"{'Median Error:':<25} {errors.median():.4f}")
    print(f"{'Error Std Dev:':<25} {errors.std():.4f}")
    shape = 'Symmetric' if abs(skewness) < 0.5 else 'Skewed'
    print(f"{'Error Skewness:':<25} {skewness:.4f} ({shape})")
    print(f"{'Error Kurtosis:':<25} {kurt:.4f}")

    # Both tables are computed before either is printed
    signed_table = _summary_percentiles(errors)
    absolute_table = _summary_percentiles(abs_errors)

    _summary_section(" ERROR PERCENTILES (SIGNED) ", 30, 30)
    for label, value in signed_table.items():
        print(f"{label + ':':<25} {value:.4f}")

    _summary_section(" ERROR PERCENTILES (ABSOLUTE) ", 30, 29)
    for label, value in absolute_table.items():
        print(f"{label + ':':<25} {value:.4f}")


def _summary_cv(series):
    """Return std/mean of ``series``, or infinity when the mean is zero."""
    mean_value = series.mean()
    if mean_value != 0:
        return series.std() / mean_value
    return float('inf')


def _summary_extremes(daily, column, precision, unit=""):
    """Print the min and max of ``daily[column]`` with the day each occurred."""
    series = daily[column]
    if series.isna().all():
        return
    low_at = series.idxmin()
    high_at = series.idxmax()
    print(f"  {'Min:':<20} {series.min():.{precision}f}{unit} (Day: {daily.loc[low_at, 'date']})")
    print(f"  {'Max:':<20} {series.max():.{precision}f}{unit} (Day: {daily.loc[high_at, 'date']})")


def _summary_daily_performance(daily):
    """Print the per-day metric table followed by MAE/RMSE/MAPE statistics."""
    _summary_section(" DAILY PERFORMANCE ", 35, 34)
    print(daily[['date', 'rmse', 'mae', 'r2', 'mape', 'n_samples']].to_string(index=False))

    # Statistics are skipped when no day has a usable MAE
    if len(daily['mae'].dropna()) == 0:
        return

    mae, rmse, mape = daily['mae'], daily['rmse'], daily['mape']
    mae_cv = _summary_cv(mae)
    rmse_cv = _summary_cv(rmse)

    _summary_section(" PERFORMANCE STATISTICS ", 33, 33)
    print("MAE:")
    print(f"  {'Average:':<20} {mae.mean():.4f}")
    print(f"  {'Std Dev:':<20} {mae.std():.4f}")
    verdict = 'Stable' if mae_cv < 0.1 else 'Variable'
    print(f"  {'CV (Stability):':<20} {mae_cv:.4f} ({verdict})")
    _summary_extremes(daily, 'mae', 4)

    print("\nRMSE:")
    print(f"  {'Average:':<20} {rmse.mean():.4f}")
    print(f"  {'Std Dev:':<20} {rmse.std():.4f}")
    print(f"  {'CV (Stability):':<20} {rmse_cv:.4f}")
    _summary_extremes(daily, 'rmse', 4)

    print("\nMAPE:")
    print(f"  {'Average:':<20} {mape.mean():.2f}%")
    print(f"  {'Std Dev:':<20} {mape.std():.2f}%")
    _summary_extremes(daily, 'mape', 2, "%")


def _summary_listing_ranking(all_results):
    """Print the listings with the lowest and highest mean absolute error."""
    aggregations = {'abs_error': 'mean', 'pct_error': 'mean'}
    ranking = all_results.groupby('listing_id').agg(aggregations).sort_values('abs_error')
    if len(ranking) == 0:
        return

    # Show up to five listings at each end of the ranking
    shown = min(5, len(ranking))
    print("\nBest Predicted Listings (Lowest Average Absolute Error):")
    print(ranking.head(shown).reset_index())

    print("\nWorst Predicted Listings (Highest Average Absolute Error):")
    print(ranking.tail(shown).reset_index())


# Script entry point: defines the input/output paths and begins the run.
# NOTE(review): this cell appears to be cut off — the `try:` below has no
# visible `except`/`finally` and the final comment stops mid-sentence.
# Confirm against the original notebook before relying on it.
if __name__ == "__main__":
    import os
    
    # Specify paths to your data
    train_path = r"C:\Users\mvk\Documents\DATA_school\thesis\Subset\top_price_changers_subset\train.csv"
    test_path = r"C:\Users\mvk\Documents\DATA_school\thesis\Subset\top_price_changers_subset\test_feb.csv"
    output_path = r"C:\Users\mvk\Documents\DATA_school\thesis\Output\NN\enhanced_strap_results.csv"
    
    try:
        # Make sure the output directory exists
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        
        # Define features to drop if needed (empty list means keep all features)

In [None]:
# Script entry point: train/evaluate the enhanced ST-RAP model on the configured
# data set, then print the summary and render the plots.
if __name__ == "__main__":
    import os

    # Specify paths to your data
    train_path = r"C:\Users\mvk\Documents\DATA_school\thesis\Subset\top_price_changers_subset\train.csv"
    test_path = r"C:\Users\mvk\Documents\DATA_school\thesis\Subset\top_price_changers_subset\test_feb.csv"
    output_path = r"C:\Users\mvk\Documents\DATA_school\thesis\Output\NN\enhanced_strap_results.csv"

    # The handler at the bottom previously had no matching `try`, which made
    # the whole cell a SyntaxError; the guarded region is restored here.
    try:
        # Make sure the output directory exists
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

        # Empty list means keep all features
        features_to_drop = []

        # Set CUDA options for better error handling
        os.environ['CUDA_LAUNCH_BLOCKING'] = '1'  # Makes CUDA errors more informative

        # Enable warnings to catch potential issues
        warnings.filterwarnings('always')

        print("Starting enhanced STRAP model training and evaluation...")

        # Run prediction with enhanced model
        results = run_day_by_day_enhanced_strap_prediction(
            train_path=train_path,
            test_path=test_path,
            features_to_drop=features_to_drop,
            prediction_days=7,
            output_path=output_path
        )

        # Print detailed summary
        print_detailed_summary(results)

        # Generate enhanced visualizations
        plot_enhanced_results(results)

        print("\nModel training and evaluation complete. Results and visualizations have been saved.")

    except Exception as e:
        # Top-level boundary: report the failure instead of dumping a traceback.
        print(f"Critical error in main execution: {e}")
        print("Please check your data paths and environment setup.")

Processing dataset for 7-day prediction with enhanced ST-RAP model
Features being dropped: []
Loading data...
Test set contains 7 unique dates.
Creating enhanced spatial features...


  return bound(*args, **kwds)
Processing chunks for spatial features: 100%|██████████| 1641/1641 [16:47<00:00,  1.63it/s]


Creating enhanced spatial features...


  return bound(*args, **kwds)
Processing chunks for spatial features: 100%|██████████| 56/56 [00:22<00:00,  2.54it/s]


Using device: cuda




Attempting to build edge index for graph structure...
Creating spatial edge index efficiently...


100%|██████████| 1696/1696 [07:26<00:00,  3.80it/s]


Successfully built edge index with shape torch.Size([2, 26651408])


Processing days:   0%|          | 0/7 [00:00<?, ?it/s]


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
