# In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch.cuda.amp import autocast, GradScaler
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.neighbors import NearestNeighbors
import matplotlib.dates as mdates
from torch_geometric.nn import GATv2Conv
from torch_geometric.data import Data
import warnings
import math
warnings.filterwarnings('ignore')

# ----- Core Functions from paste2 (Fast Implementation) -----

# 1. Price transformation function (from paste2)
def apply_price_transformation(train_data, inverse=False):
    """
    Log-transform the ``price`` column, or undo that transform.

    Parameters
    ----------
    train_data : pandas.DataFrame
        Frame with a ``price`` column. The input is left untouched; the
        work is done on a copy.
    inverse : bool, optional
        False (default): store the current prices in ``original_price`` and
        replace ``price`` with ``log1p(price)``.
        True: replace ``price`` with ``expm1(price)``.

    Returns
    -------
    pandas.DataFrame
        The transformed copy.
    """
    result = train_data.copy()

    if inverse:
        # expm1 is the exact inverse of log1p
        print("Inverting log transformation for predictions")
        result['price'] = np.expm1(result['price'])
        return result

    print("Applying log transformation to price data")
    raw_price = result['price']
    result['original_price'] = raw_price      # keep the untransformed values
    result['price'] = np.log1p(raw_price)     # log1p is defined for zero prices
    return result

# 2. Create calculated features (from paste2)
def create_calculated_features(df):
    """
    Derive extra columns from whichever source columns are present.

    Added columns (each only when its inputs exist, except the two scores,
    which are always added and default to 0):
      * ``bedroom_ratio``   - bedrooms / max(accommodates, 1)
      * ``amenity_count``   - row sum over every column whose name contains ``has_``
      * ``luxury_score``    - mean of the available "luxury" amenity flags
      * ``essential_score`` - mean of the available "essential" amenity flags
      * ``price_range_{7,14,30}d`` - rolling max minus rolling min

    Finally, NaNs in every numeric column are replaced by that column's median.
    The input frame is not modified; a new frame is returned.
    """
    out = df.copy()

    # Bedroom ratio (guard against a zero divisor by clipping at 1)
    if 'bedrooms' in out.columns and 'accommodates' in out.columns:
        out['bedroom_ratio'] = out['bedrooms'] / out['accommodates'].clip(lower=1)

    # Amenity count over all "has_" flag columns
    amenity_cols = out.filter(like='has_').columns
    if not amenity_cols.empty:
        out['amenity_count'] = out[amenity_cols].sum(axis=1)

    # Luxury / essential scores: fraction of the listed flags that are set
    score_specs = (
        ('luxury_score', ['has_hot_water', 'has_hair_dryer', 'has_dedicated_workspace',
                          'has_tv', 'has_wifi', 'has_shampoo']),
        ('essential_score', ['has_essentials', 'has_bed_linens', 'has_kitchen',
                             'has_smoke_alarm', 'has_heating']),
    )
    for score_name, candidates in score_specs:
        found = [name for name in candidates if name in out.columns]
        out[score_name] = out[found].sum(axis=1) / len(found) if found else 0

    # Price volatility: spread between rolling max and rolling min
    for days in (7, 14, 30):
        high, low = f'rolling_max_{days}d', f'rolling_min_{days}d'
        if high in out.columns and low in out.columns:
            out[f'price_range_{days}d'] = out[high] - out[low]

    # Median-impute any NaN left in numeric columns
    for name in out.select_dtypes(include=['number']).columns:
        column = out[name]
        if column.isnull().any():
            out[name] = column.fillna(column.median())

    return out

# 3. Build enhanced spatial graph for GNN (from paste2, but with k=5)
def build_enhanced_spatial_graph(train_data, test_data, k=5, feature_weight=0.3):
    """
    Build a bipartite k-NN graph linking every test listing to its nearest
    train listings, weighted by geographic and feature similarity.

    Node numbering: train rows are nodes ``0 .. len(train_data)-1`` and test
    rows follow at ``len(train_data) + i``. For every (test, neighbour) pair
    two directed edges are emitted back to back: test->train, then train->test.

    Parameters
    ----------
    train_data, test_data : pandas.DataFrame
        Must contain ``latitude`` and ``longitude``. ``accommodates``,
        ``bedrooms`` and ``bathrooms`` are used for the feature similarity
        when present in ``train_data``.
    k : int, optional
        Number of nearest train neighbours per test listing (capped at the
        number of train rows).
    feature_weight : float, optional
        Mixing factor alpha in ``(1 - alpha) * geo + alpha * max(0, cosine)``.

    Returns
    -------
    edge_index : torch.LongTensor, shape (2, num_edges)
    edge_attr : torch.FloatTensor, shape (num_edges, 1)
    """
    train_coords = train_data[['latitude', 'longitude']].values
    test_coords = test_data[['latitude', 'longitude']].values
    n_train, n_test = len(train_coords), len(test_coords)

    print(f"Building enhanced spatial graph with {len(test_coords)} test listings and {k} nearest neighbors...")

    # Nothing to connect: return correctly shaped empty tensors instead of
    # failing inside the neighbour search or on `.t()` of an empty list.
    if n_train == 0 or n_test == 0 or k <= 0:
        print("Created graph with 0 edges")
        return (torch.zeros((2, 0), dtype=torch.long),
                torch.zeros((0, 1), dtype=torch.float32))

    # Standardised property features used for the cosine similarity
    candidate_features = ['accommodates', 'bedrooms', 'bathrooms']
    available_features = [f for f in candidate_features if f in train_data.columns]

    if available_features:
        scaler = StandardScaler()
        train_features = scaler.fit_transform(train_data[available_features].fillna(0))
        test_features = scaler.transform(test_data[available_features].fillna(0))
    else:
        # Constant features: the similarity term becomes the same for all edges
        print("Warning: No property features available for similarity calculation")
        train_features = np.ones((n_train, 1))
        test_features = np.ones((n_test, 1))

    # k nearest train listings per test listing (named `knn` so the
    # module-level `nn` alias for torch.nn is not shadowed)
    knn = NearestNeighbors(n_neighbors=min(k, n_train))
    knn.fit(train_coords)
    distances, indices = knn.kneighbors(test_coords)  # both (n_test, n_neighbors)

    # Cosine similarity for all (test, neighbour) pairs at once
    train_norms = np.linalg.norm(train_features, axis=1)
    test_norms = np.linalg.norm(test_features, axis=1)
    dots = np.einsum('if,ikf->ik', test_features, train_features[indices])
    norm_products = test_norms[:, None] * train_norms[indices]
    feat_sim = np.divide(
        dots, norm_products,
        out=np.zeros_like(dots, dtype=float),
        where=norm_products > 1e-8,   # zero-norm vectors get similarity 0
    )

    # Inverse-distance weight; the epsilon only prevents division by zero.
    # NOTE(review): this is unbounded (about 1e6 for coincident coordinates),
    # so it can dominate the similarity term - confirm that is intended.
    geo_weight = 1.0 / (distances + 1e-6)

    # Combined weight: (1 - alpha) * geo_weight + alpha * positive cosine similarity
    combined = (1 - feature_weight) * geo_weight + feature_weight * np.maximum(feat_sim, 0.0)

    # Interleave forward/reverse edges in the same order a nested loop would
    n_neighbors = indices.shape[1]
    test_nodes = np.repeat(np.arange(n_test) + n_train, n_neighbors)
    train_nodes = indices.reshape(-1)
    sources = np.empty(2 * test_nodes.size, dtype=np.int64)
    targets = np.empty_like(sources)
    sources[0::2], sources[1::2] = test_nodes, train_nodes
    targets[0::2], targets[1::2] = train_nodes, test_nodes
    weights = np.repeat(combined.reshape(-1), 2).reshape(-1, 1)

    edge_index_tensor = torch.tensor(np.stack([sources, targets]), dtype=torch.long).contiguous()
    edge_attr_tensor = torch.tensor(weights, dtype=torch.float32)

    print(f"Created graph with {edge_index_tensor.shape[1]} edges")

    return edge_index_tensor, edge_attr_tensor

# 4. Function to evaluate predictions (from paste2)
def evaluate_gnn_predictions(y_true, y_pred, print_results=True):
    """
    Evaluate GNN predictions using multiple metrics.

    Parameters
    ----------
    y_true, y_pred : array-like
        Ground-truth and predicted values with the same number of elements.
        Shapes ``(n,)`` and ``(n, 1)`` may be mixed.
    print_results : bool, optional
        Print a short report when True (default).

    Returns
    -------
    dict
        Keys ``'rmse'``, ``'mae'``, ``'r2'`` and ``'mape'`` (percent).
    """
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    mae = mean_absolute_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)

    # Flatten before the element-wise MAPE: subtracting an (n,) array from an
    # (n, 1) array would broadcast to an (n, n) matrix and give a wrong mean.
    true_flat = np.asarray(y_true, dtype=float).reshape(-1)
    pred_flat = np.asarray(y_pred, dtype=float).reshape(-1)
    # Floor on |y_true| keeps the denominator away from zero for any sign.
    denominator = np.maximum(np.abs(true_flat), 1e-8)
    mape = np.mean(np.abs(true_flat - pred_flat) / denominator) * 100

    metrics = {
        'rmse': rmse,
        'mae': mae,
        'r2': r2,
        'mape': mape
    }

    if print_results:
        print("=== GNN Model Evaluation ===")
        print(f"RMSE: {rmse:.2f}")
        print(f"MAE: {mae:.2f}")
        print(f"R²: {r2:.4f}")
        print(f"MAPE: {mape:.2f}%")

    return metrics

# ----- Components from paste1 (Transformer-based architecture) -----

# TimeSeriesEncoder from paste1 (simplified version without sub-batching)
class TimeSeriesEncoder(nn.Module):
    """
    Transformer encoder over per-listing feature sequences.

    A linear layer lifts each timestep from ``input_dim`` to ``hidden_dim``;
    a stack of pre-norm ``nn.TransformerEncoderLayer`` blocks then encodes the
    sequence. Input and output are batch-first:
    ``[batch, seq_len, input_dim] -> [batch, seq_len, hidden_dim]``.
    """

    def __init__(self, input_dim, hidden_dim, num_heads=2, num_layers=1, dropout=0.1):
        super(TimeSeriesEncoder, self).__init__()
        self.embedding = nn.Linear(input_dim, hidden_dim)

        # Transformer encoder layer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim*2,
            dropout=dropout,
            batch_first=True,
            norm_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, x, src_key_padding_mask=None):
        """
        Encode ``x`` of shape ``[batch, seq_len, input_dim]``.

        ``src_key_padding_mask`` is an optional ``[batch, seq_len]`` tensor,
        True at padded positions. Sequences that are padded everywhere are
        returned as all-zero encodings.
        """
        hidden_dim = self.embedding.out_features

        if x.shape[0] == 0:
            # Keep the 3-D output contract for an empty batch so callers can
            # still index the sequence dimension.
            return x.new_zeros((0, x.shape[1], hidden_dim))

        # A row whose keys are all masked has nothing to attend to, which can
        # yield NaN from the attention softmax (depends on the PyTorch
        # version). Temporarily expose one position for such rows and zero
        # their output afterwards.
        fully_masked = None
        if (src_key_padding_mask is not None
                and src_key_padding_mask.dtype == torch.bool
                and src_key_padding_mask.shape[1] > 0):
            all_padded = src_key_padding_mask.all(dim=1)
            if all_padded.any():
                fully_masked = all_padded
                src_key_padding_mask = src_key_padding_mask.clone()
                src_key_padding_mask[fully_masked, 0] = False

        # Process entire batch at once
        encoded = self.transformer(
            self.embedding(x), src_key_padding_mask=src_key_padding_mask
        )

        if fully_masked is not None:
            encoded = encoded.masked_fill(fully_masked.view(-1, 1, 1), 0.0)

        return encoded

# CrossAttention from paste1 (simplified version without sub-batching)
class CrossAttention(nn.Module):
    """
    Fuse a spatial embedding with a temporal context via multi-head attention.

    The spatial embedding is the query and the temporal context is key/value.
    The attended result goes through a post-norm residual block and a
    feed-forward block, is concatenated with the *unprojected* temporal
    context, and is projected to ``2 * spatial_dim`` features.
    """

    def __init__(self, spatial_dim, temporal_dim, heads=2, dropout=0.1):
        super().__init__()

        # Projections are only needed when the two widths differ
        dims_match = spatial_dim == temporal_dim
        self.project_spatial = nn.Identity() if dims_match else nn.Linear(spatial_dim, spatial_dim)
        self.project_temporal = nn.Identity() if dims_match else nn.Linear(temporal_dim, spatial_dim)

        self.multihead_attn = nn.MultiheadAttention(
            embed_dim=spatial_dim,
            num_heads=heads,
            dropout=dropout,
            batch_first=True
        )

        # Post-attention normalisation, dropout and feed-forward network
        self.norm1 = nn.LayerNorm(spatial_dim)
        self.norm2 = nn.LayerNorm(spatial_dim)
        self.dropout = nn.Dropout(dropout)
        widened = spatial_dim * 2
        self.ffn = nn.Sequential(
            nn.Linear(spatial_dim, widened),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(widened, spatial_dim)
        )

        # Final projection of [spatial ; temporal]
        self.output_proj = nn.Linear(spatial_dim + temporal_dim, widened)

    def forward(self, spatial_features, temporal_context):
        """
        Return the fused representation, ``2 * spatial_dim`` features per row.

        An empty batch short-circuits to an empty ``(0, 2 * spatial_dim)``
        tensor. 2-D inputs are treated as single-token sequences.
        """
        if spatial_features.shape[0] == 0:
            return torch.zeros(
                (0, self.output_proj.out_features), device=spatial_features.device
            )

        spatial_is_flat = spatial_features.dim() == 2

        query_base = self.project_spatial(spatial_features)
        memory = self.project_temporal(temporal_context)

        # Attention needs a sequence axis on both sides
        if memory.dim() == 2:
            memory = memory.unsqueeze(1)
        query = query_base.unsqueeze(1) if query_base.dim() == 2 else query_base

        attended, _ = self.multihead_attn(query=query, key=memory, value=memory)
        if spatial_is_flat:
            attended = attended.squeeze(1)

        # Residual + norm, then feed-forward with a second residual + norm
        hidden = self.norm1(query_base + self.dropout(attended))
        hidden = self.norm2(hidden + self.dropout(self.ffn(hidden)))

        # Concatenate with the raw temporal context (dropping a length-1 axis)
        if temporal_context.dim() == 3:
            temporal_flat = temporal_context.squeeze(1)
        else:
            temporal_flat = temporal_context

        return self.output_proj(torch.cat([hidden, temporal_flat], dim=1))

# ListingGNN (hybrid of paste1 architecture with paste2 efficiency)
class ListingGNN(nn.Module):
    """
    Hybrid model: GATv2 over the listing graph + transformer over each node's
    price history, fused with cross-attention and mapped to one value per node.

    Parameters
    ----------
    spatial_dim : int
        Width of the node feature matrix ``data.x``.
    temporal_dim : int
        Accepted for interface compatibility; not used inside this class.
    price_history_dim : int
        Number of features per timestep in ``data.price_history``.
    hidden_dim : int, optional
        Shared hidden width. Should be divisible by ``heads``.
    heads : int, optional
        Number of attention heads used by the GAT, the encoder and the fusion.
    """

    def __init__(self, spatial_dim, temporal_dim, price_history_dim, hidden_dim=64, heads=4):
        super(ListingGNN, self).__init__()

        # Spatial component: heads are concatenated, giving
        # (hidden_dim // heads) * heads output features.
        self.gat = GATv2Conv(spatial_dim, hidden_dim // heads, heads=heads, edge_dim=1)

        # Temporal encoder (transformer-based)
        self.time_series_encoder = TimeSeriesEncoder(
            input_dim=price_history_dim,
            hidden_dim=hidden_dim,
            num_heads=heads,
            num_layers=1
        )

        # Equals the GAT output width only when hidden_dim is divisible by heads
        gat_output_dim = hidden_dim

        # Cross-attention between spatial and temporal embeddings
        self.cross_attention = CrossAttention(gat_output_dim, hidden_dim, heads=heads)

        # Output head
        self.output_mlp = nn.Sequential(
            nn.Linear(gat_output_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, data):
        """
        Predict one value per node.

        ``data`` must provide ``x``, ``edge_index`` and ``edge_attr``.
        ``price_history`` ([nodes, seq_len, features]) and
        ``price_history_mask`` ([nodes, seq_len], True = padding) are optional.
        Returns a ``[nodes, 1]`` tensor.
        """
        spatial_features = self.gat(data.x, data.edge_index, edge_attr=data.edge_attr)
        num_nodes = data.x.shape[0]

        price_history = getattr(data, 'price_history', None)
        if price_history is not None:
            padding_mask = getattr(data, 'price_history_mask', None)
            temporal_features = self.time_series_encoder(
                price_history, src_key_padding_mask=padding_mask
            )

            if padding_mask is not None:
                # Pick the encoding at the LAST valid timestep of every row.
                # The index is the highest unmasked position, not
                # "number of valid steps - 1": the latter is only right for
                # right-padded sequences and would select a padded slot when
                # the padding sits at the front. Rows with no valid step use
                # position 0.
                valid = ~padding_mask
                positions = torch.arange(valid.shape[1], device=valid.device)
                last_valid = (valid.long() * positions).max(dim=1).values

                batch_indices = torch.arange(num_nodes, device=data.x.device)
                temporal_context = temporal_features[batch_indices, last_valid]
            else:
                # No mask: use the last timestep
                temporal_context = temporal_features[:, -1]
        else:
            # No price history at all: neutral (zero) temporal context
            temporal_context = torch.zeros(
                (num_nodes, self.time_series_encoder.embedding.out_features),
                device=data.x.device
            )

        # Cross-attention between spatial and temporal embeddings
        fused_features = self.cross_attention(spatial_features, temporal_context)

        return self.output_mlp(fused_features)

# ----- Feature Extraction and Data Preparation (optimized) -----

# Extract basic features (modified from paste1)
def extract_basic_features(data, feature_groups):
    """
    Extract the spatial and property columns of ``data`` as one float matrix.

    Parameters
    ----------
    data : pandas.DataFrame
    feature_groups : dict
        May contain the keys ``'spatial'`` and ``'property'``, each mapping to
        a list of column names. Spatial columns come first in the result.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(len(data), n_selected_columns)`` with NaN
        replaced by 0.

    Raises
    ------
    ValueError
        If neither ``'spatial'`` nor ``'property'`` is present.
    """
    selected = [
        data[feature_groups[group]]
        for group in ('spatial', 'property')
        if group in feature_groups
    ]
    if not selected:
        raise ValueError(
            "feature_groups must contain a 'spatial' or 'property' entry"
        )

    # Cast to float explicitly: mixed int/bool/float columns would otherwise
    # give an object array, on which nan_to_num leaves NaN untouched.
    result = pd.concat(selected, axis=1).to_numpy(dtype=float)

    # `nan` must be passed by keyword; the second positional argument of
    # np.nan_to_num is `copy`.
    return np.nan_to_num(result, nan=0.0)

# Create simple price histories
def create_simple_price_histories(data, sequence_length=10):
    """
    Build one fixed-length price-history tensor per listing.

    For every ``listing_id`` with at least two rows, the rows are sorted by
    ``date`` and turned into per-timestep features
    ``[price, price - previous price, rolling_mean_7d/14d/30d if present]``.
    The last ``sequence_length`` timesteps are kept; shorter histories are
    zero-padded at the FRONT so the most recent step is always last.

    Parameters
    ----------
    data : pandas.DataFrame
        Needs the columns ``listing_id``, ``date`` and ``price``.
    sequence_length : int, optional
        Number of timesteps in each returned tensor.

    Returns
    -------
    dict
        ``listing_id -> torch.FloatTensor [sequence_length, n_features]``, or
        ``None`` for listings with fewer than two rows.
    """
    price_histories = {}

    # The optional rolling-mean columns are the same for every group
    rolling_cols = [
        f'rolling_mean_{window}d'
        for window in (7, 14, 30)
        if f'rolling_mean_{window}d' in data.columns
    ]

    for listing_id, group in data.groupby('listing_id'):
        if len(group) < 2:
            # Not enough rows to form a history
            price_histories[listing_id] = None
            continue

        ordered = group.sort_values('date')
        prices = ordered['price'].values

        columns = [
            prices.reshape(-1, 1),                                   # price(t)
            np.diff(prices, prepend=prices[0]).reshape(-1, 1),       # price(t) - price(t-1)
        ]
        columns.extend(ordered[col].values.reshape(-1, 1) for col in rolling_cols)

        feature_array = np.hstack(columns).astype(float)
        # Rolling statistics are commonly NaN at the start of a series; a NaN
        # here would propagate through every downstream computation.
        feature_array[np.isnan(feature_array)] = 0.0

        n_rows = len(feature_array)
        if n_rows < sequence_length:
            # Pad with zeros at the beginning
            window_values = np.zeros((sequence_length, feature_array.shape[1]))
            window_values[sequence_length - n_rows:] = feature_array
        else:
            # Keep the most recent `sequence_length` rows (the explicit start
            # index also behaves for sequence_length == 0)
            window_values = feature_array[n_rows - sequence_length:]

        price_histories[listing_id] = torch.tensor(window_values, dtype=torch.float32)

    return price_histories

# Extract temporal features
def extract_temporal_features(train_data, test_data, feature_groups, device):
    """
    Stack the temporal feature columns of train and test into one tensor.

    Parameters
    ----------
    train_data, test_data : pandas.DataFrame
    feature_groups : dict
        ``feature_groups['temporal']`` lists the columns to extract.
    device : str or torch.device
        Device on which the result is placed.

    Returns
    -------
    torch.Tensor
        float32, shape ``(len(train_data) + len(test_data), n_columns)``,
        train rows first, NaN replaced by 0. When no temporal columns are
        configured, a zero tensor with a single column is returned.
    """
    columns = feature_groups.get('temporal')
    if not columns:
        # Placeholder so downstream code always receives a 2-D tensor
        return torch.zeros((len(train_data) + len(test_data), 1), device=device)

    # Explicit float cast: mixed bool/int/float columns would otherwise
    # produce an object array that cannot be cleaned or turned into a tensor.
    combined = np.vstack([
        train_data[columns].to_numpy(dtype=float),
        test_data[columns].to_numpy(dtype=float),
    ])
    # `nan` must be a keyword; the second positional parameter is `copy`.
    combined = np.nan_to_num(combined, nan=0.0)

    return torch.tensor(combined, dtype=torch.float32).to(device)

# Extract amenity features
def extract_amenity_features(train_data, test_data, feature_groups, device):
    """
    Stack the amenity feature columns of train and test into one tensor.

    Parameters
    ----------
    train_data, test_data : pandas.DataFrame
    feature_groups : dict
        ``feature_groups['amenity']`` lists the columns to extract.
    device : str or torch.device
        Device on which the result is placed.

    Returns
    -------
    torch.Tensor
        float32, shape ``(len(train_data) + len(test_data), n_columns)``,
        train rows first, NaN replaced by 0. When no amenity columns are
        configured, a zero tensor with a single column is returned.
    """
    columns = feature_groups.get('amenity')
    if not columns:
        # Placeholder so downstream code always receives a 2-D tensor
        return torch.zeros((len(train_data) + len(test_data), 1), device=device)

    # Explicit float cast: a mix of boolean and numeric flag columns would
    # otherwise produce an object array that cannot be cleaned or converted.
    combined = np.vstack([
        train_data[columns].to_numpy(dtype=float),
        test_data[columns].to_numpy(dtype=float),
    ])
    # `nan` must be a keyword; the second positional parameter is `copy`.
    combined = np.nan_to_num(combined, nan=0.0)

    return torch.tensor(combined, dtype=torch.float32).to(device)

# Extract price history features
def extract_price_history_features(train_data, test_data, feature_groups, device):
    """
    Stack lagged / rolling price columns of train and test into one tensor.

    The columns come from ``feature_groups['price_history']`` when that entry
    is present and non-empty. Otherwise a default list of lag and rolling-mean
    columns is used, restricted to those found in ``train_data``. Missing
    values are filled with 0. If no column is selected, a zero tensor with one
    column and ``len(train_data) + len(test_data)`` rows is returned.
    """
    configured = feature_groups['price_history'] if 'price_history' in feature_groups else None

    if configured:
        selected = configured
    else:
        default_columns = [
            'price_lag_7d', 'price_lag_14d', 'price_lag_30d',
            'rolling_mean_7d', 'rolling_mean_14d', 'rolling_mean_30d'
        ]
        selected = [name for name in default_columns if name in train_data.columns]

    if not selected:
        total_rows = len(train_data) + len(test_data)
        return torch.zeros((total_rows, 1), device=device)

    # Train rows first, then test rows
    stacked = np.vstack([
        frame[selected].copy().fillna(0).values
        for frame in (train_data, test_data)
    ])

    return torch.FloatTensor(stacked).to(device)

# Prepare graph data (hybrid of paste1 and paste2)
def prepare_graph_data(train_data, test_data, feature_groups, device, sequence_length=10):
    """
    Assemble one PyG ``Data`` object whose nodes are all train rows followed
    by all test rows.

    Steps: standardise the price target, gather node features, build
    per-listing price histories, build the spatial k-NN graph, and attach
    masks plus the auxiliary feature tensors.

    Parameters
    ----------
    train_data, test_data : pandas.DataFrame
        Need ``listing_id``, ``date``, ``latitude``, ``longitude`` and the
        columns named in ``feature_groups``. ``train_data`` needs ``price``;
        it is optional in ``test_data``. Neither frame is modified.
    feature_groups : dict
        Lists of column names under ``'spatial'`` and ``'property'`` (node
        features), and optionally ``'temporal'``, ``'amenity'`` and
        ``'price_history'``.
    device : str or torch.device
    sequence_length : int, optional
        Length of each price-history sequence.

    Returns
    -------
    data : torch_geometric.data.Data
        ``val_mask`` marks the test rows.
    scalers : dict
        ``{'target': StandardScaler}`` fitted on the training prices.

    Raises
    ------
    ValueError
        If neither spatial nor property features are configured.
    """
    print("Preparing graph data...")

    # Work on copies: `price` is rescaled below and the caller's frames must
    # keep their original scale. Scaling in place would make any later call
    # with the same frame fit the scaler on already-standardised prices.
    train_data = train_data.copy()
    test_data = test_data.copy()
    test_has_price = 'price' in test_data.columns

    # Scale the target (price)
    scalers = {}
    target_scaler = StandardScaler()
    train_data['price'] = target_scaler.fit_transform(train_data[['price']]).ravel()
    if test_has_price:
        test_data['price'] = target_scaler.transform(test_data[['price']]).ravel()
    scalers['target'] = target_scaler

    # Node features = spatial + property columns
    spatial_features = feature_groups.get('spatial', [])
    property_features = feature_groups.get('property', [])
    basic_features = spatial_features + property_features

    if not basic_features:
        raise ValueError("No spatial or property features provided!")

    train_features = train_data[basic_features].fillna(0).values
    test_features = test_data[basic_features].fillna(0).values
    n_train, n_test = len(train_features), len(test_features)

    # Per-listing price histories.
    # NOTE(review): the histories are built from the same `price` column that
    # serves as the target (including the test rows), so every node's input
    # sequence can contain the value it is asked to predict. Confirm whether
    # this leakage is acceptable for the evaluation protocol.
    print("Creating price histories...")
    train_histories = create_simple_price_histories(train_data, sequence_length)
    # Without a price column there is nothing to build test histories from
    test_histories = (
        create_simple_price_histories(test_data, sequence_length)
        if test_has_price else {}
    )
    # On duplicate listing ids the test history wins
    all_histories = {**train_histories, **test_histories}

    # First non-empty history fixes the (seq_len, feature_dim) of the batch
    candidates = list(train_histories.values()) + list(test_histories.values())
    sample_history = next(
        (h for h in candidates if h is not None and h.shape[0] > 0), None
    )

    all_listing_ids = np.concatenate([
        train_data['listing_id'].values,
        test_data['listing_id'].values
    ])

    if sample_history is not None:
        seq_len, feature_dim = sample_history.shape
        batch_histories = torch.zeros((len(all_listing_ids), seq_len, feature_dim))

        # One row per node; nodes of the same listing share its history
        for i, lid in enumerate(all_listing_ids):
            history = all_histories.get(lid)
            if history is not None and history.shape[0] > 0:
                steps = min(history.shape[0], seq_len)
                batch_histories[i, :steps, :] = history[:steps]
    else:
        # Fallback if no valid histories found
        batch_histories = torch.zeros((len(all_listing_ids), 1, 1))

    # Graph structure. The full frames are passed (not just the spatial
    # columns) so the builder can see accommodates/bedrooms/bathrooms for its
    # feature-similarity term instead of falling back to constant features.
    print("Building spatial graph...")
    edge_index, edge_attr = build_enhanced_spatial_graph(train_data, test_data, k=5)

    # Train rows come first, test rows second
    train_mask = torch.zeros(n_train + n_test, dtype=torch.bool)
    train_mask[:n_train] = True

    val_mask = torch.zeros(n_train + n_test, dtype=torch.bool)
    val_mask[n_train:] = True

    # Targets (zeros stand in when the test frame has no price)
    train_y = train_data['price'].values
    test_y = test_data['price'].values if test_has_price else np.zeros(n_test)
    all_y = np.concatenate([train_y, test_y])

    all_features = np.vstack([train_features, test_features])

    # A timestep counts as padding only when ALL of its features are zero;
    # testing the feature sum would also hide real steps whose values cancel.
    price_history_mask = (batch_histories == 0).all(dim=-1)

    data = Data(
        x=torch.FloatTensor(all_features).to(device),
        edge_index=edge_index.to(device),
        edge_attr=edge_attr.to(device),
        y=torch.FloatTensor(all_y.reshape(-1, 1)).to(device),
        train_mask=train_mask.to(device),
        val_mask=val_mask.to(device),
        listing_ids=torch.LongTensor(all_listing_ids).to(device),
        price_history=batch_histories.to(device),
        price_history_mask=price_history_mask.to(device)
    )

    # Auxiliary feature tensors
    data.temporal_x = extract_temporal_features(train_data, test_data, feature_groups, device)
    data.amenity_x = extract_amenity_features(train_data, test_data, feature_groups, device)
    data.price_history_x = extract_price_history_features(train_data, test_data, feature_groups, device)

    return data, scalers

# ----- Optimized Training Function -----

def train_model(train_data, val_data, feature_groups, device='cuda', hidden_dim=64, 
               epochs=50, lr=0.001, sequence_length=10):
    """
    Train the hybrid GNN full-batch on the combined train/validation graph.

    Parameters
    ----------
    train_data, val_data : pandas.DataFrame
        Frames accepted by ``prepare_graph_data``. The reported RMSE/MAE apply
        ``expm1`` after un-scaling, i.e. they assume ``price`` holds
        log1p-transformed prices.
    feature_groups : dict
        Column groups forwarded to ``prepare_graph_data``.
    device : str or torch.device, optional
    hidden_dim : int, optional
    epochs : int, optional
        Maximum number of epochs; training stops early after 10 epochs
        without a validation-loss improvement.
    lr : float, optional
    sequence_length : int, optional

    Returns
    -------
    model : ListingGNN
        Model restored to the weights with the lowest validation loss.
    scalers : dict
        Scalers returned by ``prepare_graph_data``.
    """
    # Prepare data
    graph_data, scalers = prepare_graph_data(
        train_data, val_data, feature_groups, device, sequence_length
    )

    # Get input dimensions
    spatial_dim = graph_data.x.shape[1]
    price_history_dim = graph_data.price_history.shape[2] if hasattr(graph_data, 'price_history') else 1

    print(f"Input dimensions - Spatial: {spatial_dim}, Price history: {price_history_dim}")

    # Initialize model
    model = ListingGNN(
        spatial_dim=spatial_dim,
        temporal_dim=hidden_dim,
        price_history_dim=price_history_dim,
        hidden_dim=hidden_dim,
        heads=4
    ).to(device)

    # Initialize optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    criterion = nn.HuberLoss(delta=1.0)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, verbose=True
    )

    # Mixed precision only makes sense on CUDA; on CPU both the autocast
    # context and the gradient scaler are switched off and act as no-ops.
    use_amp = torch.device(device).type == 'cuda' and torch.cuda.is_available()
    scaler = GradScaler(enabled=use_amp)

    # Training loop
    best_val_loss = float('inf')
    best_model_state = None
    patience = 10
    counter = 0

    for epoch in range(epochs):
        # Training
        model.train()
        optimizer.zero_grad()

        with autocast(enabled=use_amp):
            # Process entire graph at once
            train_out = model(graph_data)[graph_data.train_mask]
            train_y = graph_data.y[graph_data.train_mask]
            loss = criterion(train_out, train_y)

        scaler.scale(loss).backward()
        # Gradients must be un-scaled before clipping, otherwise the norm
        # threshold is applied to gradients multiplied by the loss scale.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()

        # Validation
        model.eval()
        with torch.no_grad():
            with autocast(enabled=use_amp):
                val_out = model(graph_data)[graph_data.val_mask]
                val_y = graph_data.y[graph_data.val_mask]
                val_loss = criterion(val_out, val_y)

            val_loss_value = val_loss.item()

            # Back to the original price scale for the reported metrics
            # (cast to float32 first: autocast may have produced float16)
            val_pred_orig = np.expm1(
                scalers['target'].inverse_transform(val_out.float().cpu().numpy())
            )
            val_true_orig = np.expm1(
                scalers['target'].inverse_transform(val_y.float().cpu().numpy())
            )

            val_rmse = np.sqrt(mean_squared_error(val_true_orig, val_pred_orig))
            val_mae = mean_absolute_error(val_true_orig, val_pred_orig)

        # Print progress
        print(f"Epoch {epoch+1}/{epochs} - Loss: {loss.item():.4f}, Val Loss: {val_loss_value:.4f}, "
              f"RMSE: {val_rmse:.2f}, MAE: {val_mae:.2f}")

        # Learning rate scheduling
        scheduler.step(val_loss_value)

        # Early stopping
        if val_loss_value < best_val_loss:
            best_val_loss = val_loss_value
            # Clone every tensor: state_dict() returns references to the live
            # parameters, so a shallow dict copy would keep changing as
            # training continues and never hold the best weights.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            counter = 0
        else:
            counter += 1

        if counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

    # Load best model
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return model, scalers

# Prediction function
def predict_with_model(model, test_data, train_data, feature_groups, scalers, device, sequence_length=10):
    """
    Predict prices for ``test_data`` with a trained model.

    The graph is rebuilt from ``train_data`` and ``test_data``; the rows of
    ``test_data`` are the nodes selected by ``val_mask``. Outputs are
    un-scaled with ``scalers['target']`` and then passed through ``expm1``
    (assumes the model was trained on log1p-transformed prices - confirm
    against the training pipeline).

    Returns a NumPy array with one row per test row.
    """
    # Rebuild the graph; the scalers produced here are discarded in favour of
    # the ones supplied by the caller.
    graph, _ = prepare_graph_data(
        train_data, test_data, feature_groups, device, sequence_length
    )

    model.eval()
    with torch.no_grad():
        with autocast():
            scaled_output = model(graph)[graph.val_mask]

        # Undo the target standardisation, then the log transform
        log_prices = scalers['target'].inverse_transform(scaled_output.cpu().numpy())
        return np.expm1(log_prices)


# Main function to run the model with rolling window CV
def run_hybrid_gnn_with_rolling_window_cv(train_path, train_ids_path, test_ids_path, output_dir=None, 
                                       window_size=35, n_splits=5, sample_size=None, sequence_length=10,
                                       hidden_dim=64):
    """
    Run hybrid GNN model with rolling window cross-validation

    The last ``window_size`` dates of the (date-filtered) data are cut into
    ``n_splits`` consecutive test periods of ``window_size // n_splits`` days.
    For each period the model is trained on train-ID listings dated before the
    period and evaluated on test-ID listings dated inside it.

    Parameters:
    -----------
    train_path : str
        Path to the training CSV file
    train_ids_path : str
        Path to text file with training listing IDs (one integer per line;
        blank lines are ignored)
    test_ids_path : str
        Path to text file with test listing IDs (same format)
    output_dir : str, optional
        Directory to save results
    window_size : int, optional
        Size of the rolling window in days
    n_splits : int, optional
        Number of splits for time series cross-validation
    sample_size : int, optional
        Limit dataset to this number of random listings (for testing);
        70% are drawn from the train IDs and 30% from the test IDs, each
        capped at the number of IDs available
    sequence_length : int, optional
        Number of previous time steps to use in price history
    hidden_dim : int, optional
        Hidden dimension size for neural network

    Returns:
    --------
    dict or None
        Dictionary with keys 'overall_metrics', 'split_metrics',
        'daily_metrics', 'all_results', 'train_listings', 'test_listings'
        and 'config'; None when no split completed.
    """
    print(f"Processing dataset: {os.path.basename(train_path)}")
    
    # Create output directory if not exists
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    
    # Load training data
    print("Loading data...")
    train_data = pd.read_csv(train_path)

    # Load listing IDs for train/test split.
    # int() tolerates surrounding whitespace; blank lines (e.g. a trailing
    # empty line) are skipped instead of crashing on int('').
    print("Loading train/test listing IDs...")
    with open(train_ids_path, 'r') as f:
        train_listing_ids = [int(line) for line in f if line.strip()]
        
    with open(test_ids_path, 'r') as f:
        test_listing_ids = [int(line) for line in f if line.strip()]
    
    print(f"Loaded {len(train_listing_ids)} train IDs and {len(test_listing_ids)} test IDs")

    # Drop legacy price columns if they exist
    for col in ('price_lag_1d', 'simulated_price'):
        if col in train_data.columns:
            print(f"Dropping {col} column from the dataset")
            train_data = train_data.drop(col, axis=1)
    
    # For testing - take only a small sample of listings if specified
    if sample_size:
        print(f"Limiting to {sample_size} random listings for testing")
        np.random.seed(42)
        # Sampling without replacement raises if more items are requested than
        # exist, so cap each request at the number of available IDs.
        n_train_sample = min(int(sample_size * 0.7), len(train_listing_ids))
        n_test_sample = min(int(sample_size * 0.3), len(test_listing_ids))
        selected_train = np.random.choice(train_listing_ids, n_train_sample, replace=False)
        selected_test = np.random.choice(test_listing_ids, n_test_sample, replace=False)
        train_listing_ids = selected_train.tolist()
        test_listing_ids = selected_test.tolist()
    
    # Convert date column to datetime
    train_data['date'] = pd.to_datetime(train_data['date'])
    
    # Keep only rows dated 2023-07-08 through 2024-02-08 (inclusive)
    start_date = pd.to_datetime('2023-07-08')
    end_date = pd.to_datetime('2024-02-08')
    train_data = train_data[(train_data['date'] >= start_date) & (train_data['date'] <= end_date)]
    
    # Apply log transformation to price
    train_data = apply_price_transformation(train_data)
    
    # Create calculated features
    print("Creating calculated features...")
    train_data = create_calculated_features(train_data)
    
    # Check for NaN values in the dataset and fill them
    nan_columns = train_data.columns[train_data.isna().any()].tolist()
    if nan_columns:
        print(f"Warning: Found NaN values in columns: {nan_columns}")
        print("Filling NaN values with column means/medians")
        
        for col in nan_columns:
            if np.issubdtype(train_data[col].dtype, np.number):
                # Fill with median for numeric columns
                train_data[col] = train_data[col].fillna(train_data[col].median())
            else:
                # For non-numeric, fill with mode. An all-NaN column has no
                # mode; leave it untouched rather than fail on an empty result.
                modes = train_data[col].mode()
                if not modes.empty:
                    train_data[col] = train_data[col].fillna(modes.iloc[0])
    
    # Define feature groups using a dictionary structure (for flexibility)
    feature_groups = {
        'spatial': ['latitude', 'longitude'],
        'property': ['accommodates', 'bedrooms', 'bathrooms', 'essential_score', 'luxury_score', 'amenity_count', 'bedroom_ratio'],
        'amenity': [col for col in train_data.columns if col.startswith('has_')],
        'temporal': [
            'DTF_day_of_week', 'DTF_month', 'DTF_is_weekend',
            'DTF_season_sin', 'DTF_season_cos'
        ],
        'price_history': [
            'price_lag_7d', 'price_lag_14d', 'price_lag_30d',
            'rolling_mean_7d', 'rolling_max_7d', 'rolling_min_7d',
            'rolling_mean_14d', 'rolling_max_14d', 'rolling_min_14d',
            'rolling_mean_30d', 'rolling_max_30d', 'rolling_min_30d',
            'price_range_7d', 'price_range_14d', 'price_range_30d'
        ]
    }
    
    # Add additional temporal features if available
    additional_temporal = ['DTF_day', 'DTF_is_holiday', 'DTF_days_to_weekend', 
                         'DTF_days_since_start', 'DTF_days_to_end_month']
    for feat in additional_temporal:
        if feat in train_data.columns:
            feature_groups['temporal'].append(feat)
    
    # Ensure all feature groups only contain columns that exist in the dataset
    for group in feature_groups:
        feature_groups[group] = [f for f in feature_groups[group] if f in train_data.columns]
    
    # Per-row calendar dates and listing-ID membership are the same for every
    # split, so compute them once here instead of inside the CV loop.
    row_dates = train_data['date'].dt.date
    train_id_rows = train_data['listing_id'].isin(train_listing_ids)
    test_id_rows = train_data['listing_id'].isin(test_listing_ids)
    
    # Get unique dates and take the trailing window used for testing
    unique_dates = sorted(row_dates.unique())
    window_dates = unique_dates[-window_size:]
    
    # Define explicit test periods: consecutive chunks of period_len days.
    # A non-positive period length (n_splits > window_size, or n_splits <= 0)
    # yields no periods instead of degenerate ones.
    period_len = window_size // n_splits if n_splits > 0 else 0
    test_periods = []
    if period_len > 0:
        for i in range(n_splits):
            start_idx = i * period_len
            end_idx = start_idx + period_len
            if end_idx <= len(window_dates):
                test_periods.append((window_dates[start_idx], window_dates[end_idx-1]))
    else:
        print(f"Cannot build test periods from window_size={window_size} and n_splits={n_splits}")
    
    n_splits = len(test_periods)
    
    print(f"Created {n_splits} test periods:")
    for i, (test_start, test_end) in enumerate(test_periods):
        print(f"  Period {i+1}: {test_start} to {test_end}")
    
    # Storage for results
    cv_results = []
    all_predictions = []
    all_targets = []
    split_metrics = []
    
    # Initialize device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Print model configuration
    print("Using hybrid GNN with transformer-based temporal processing")
    print(f"Sequence length: {sequence_length}, Hidden dim: {hidden_dim}")
    
    # Run time series cross-validation using our explicit test periods
    for i, (test_start, test_end) in enumerate(test_periods):
        print(f"\n===== Split {i+1}/{n_splits} =====")
        
        # Define training period: everything before test_start
        train_end = pd.to_datetime(test_start) - pd.Timedelta(days=1)
        train_end_date = train_end.date()
        
        print(f"Training period: {unique_dates[0]} to {train_end_date}")
        print(f"Testing period: {test_start} to {test_end}")
        
        # Combine the date window with the listing-ID membership for each side
        train_rows = (row_dates <= train_end_date) & train_id_rows
        test_rows = (row_dates >= test_start) & (row_dates <= test_end) & test_id_rows
        
        split_train_data = train_data[train_rows].copy()
        split_test_data = train_data[test_rows].copy()
        
        print(f"Train data: {len(split_train_data)} rows, {len(split_train_data['listing_id'].unique())} unique listings")
        print(f"Test data: {len(split_test_data)} rows, {len(split_test_data['listing_id'].unique())} unique listings")
        
        # Check if we have enough data for this split
        if len(split_train_data) < 100 or len(split_test_data) < 10:
            print(f"Insufficient data for split {i+1}, skipping")
            continue
            
        # Split train data into train and validation by listing, so that a
        # listing's rows never appear on both sides
        unique_train_listings = split_train_data['listing_id'].unique()
        train_listings, val_listings = train_test_split(
            unique_train_listings, test_size=0.2, random_state=42
        )
        
        train_subset = split_train_data[split_train_data['listing_id'].isin(train_listings)].copy()
        val_subset = split_train_data[split_train_data['listing_id'].isin(val_listings)].copy()
        
        # Train hybrid GNN model. A failure in one split is reported and the
        # remaining splits still run.
        try:
            print(f"\n----- Training Hybrid GNN Model (Split {i+1}) -----")
            
            # Clear GPU memory before training
            torch.cuda.empty_cache()
            
            # Train the model
            model, scalers = train_model(
                train_subset, val_subset, feature_groups, 
                device=device, 
                hidden_dim=hidden_dim, 
                epochs=30,  # Reduced from 50 
                lr=0.001, 
                sequence_length=sequence_length
            )
            
            # Evaluate on test data
            print(f"\n----- Evaluating Hybrid GNN on Test Data (Split {i+1}) -----")
            test_predictions = predict_with_model(
                model, split_test_data, train_subset, feature_groups, scalers,
                device, sequence_length=sequence_length
            )
            test_pred_flat = test_predictions.flatten()
            
            # Get actual test values (original scale); 'price' itself holds
            # the log-transformed value at this point
            if 'original_price' in split_test_data.columns:
                test_actuals = split_test_data['original_price'].values
            else:
                test_actuals = split_test_data['price'].values
            
            # Evaluate predictions
            metrics = evaluate_gnn_predictions(test_actuals, test_pred_flat, print_results=True)
            
            print(f"Split {i+1} Results - RMSE: {metrics['rmse']:.4f}, MAE: {metrics['mae']:.4f}, R²: {metrics['r2']:.4f}")
            
            # Store results for this split
            residuals = test_actuals - test_pred_flat
            split_results = pd.DataFrame({
                'split': i,
                'date': split_test_data['date'],
                'listing_id': split_test_data['listing_id'],
                'price': test_actuals,
                'predicted': test_pred_flat,
                'error': residuals,
                'abs_error': np.abs(residuals),
                # 1e-8 guards against division by a zero price
                'pct_error': np.abs(residuals / (test_actuals + 1e-8)) * 100
            })
            
            cv_results.append(split_results)
            all_predictions.extend(test_pred_flat)
            all_targets.extend(test_actuals)
            
            # Save model for this split if output_dir is provided
            if output_dir:
                model_path = os.path.join(output_dir, f'hybrid_gnn_model_split_{i+1}.pt')
                torch.save(model.state_dict(), model_path)
                print(f"Model for split {i+1} saved to {model_path}")
            
            # Store metrics for this split
            split_metrics.append({
                'split': i,
                'rmse': metrics['rmse'],
                'mae': metrics['mae'],
                'r2': metrics['r2'],
                'mape': metrics['mape'],
                'n_samples': len(test_actuals)
            })
            
            # Clear memory after each split
            del model, scalers
            torch.cuda.empty_cache()
            
        except Exception as e:
            print(f"Error in split {i+1}: {str(e)}")
            import traceback
            traceback.print_exc()
            continue
    
    # Process results
    if not cv_results:
        print("No valid splits completed. Check your data and parameters.")
        return None
        
    all_results = pd.concat(cv_results, ignore_index=True)
    
    # Calculate overall metrics across every completed split
    all_targets_array = np.array(all_targets)
    all_predictions_array = np.array(all_predictions)
    
    overall_metrics = {
        'rmse': np.sqrt(mean_squared_error(all_targets_array, all_predictions_array)),
        'mae': mean_absolute_error(all_targets_array, all_predictions_array),
        'r2': r2_score(all_targets_array, all_predictions_array),
        'mape': np.mean(np.abs((all_targets_array - all_predictions_array) / (all_targets_array + 1e-8))) * 100
    }
    
    # Calculate daily metrics
    all_results['date_str'] = pd.to_datetime(all_results['date']).dt.strftime('%Y-%m-%d')
    
    daily_metrics = []
    for day, group in all_results.groupby('date_str'):
        y_true_day = group['price']
        y_pred_day = group['predicted']
        
        daily_metrics.append({
            'date': day,
            'rmse': np.sqrt(mean_squared_error(y_true_day, y_pred_day)),
            'mae': mean_absolute_error(y_true_day, y_pred_day),
            # R² is reported as NaN when all actual prices of the day are equal
            'r2': r2_score(y_true_day, y_pred_day) if len(set(y_true_day)) > 1 else np.nan,
            'mape': np.mean(np.abs((y_true_day - y_pred_day) / (y_true_day + 1e-8))) * 100,
            'n_samples': len(y_true_day)
        })
    
    daily_metrics_df = pd.DataFrame(daily_metrics)
    daily_metrics_df['date'] = pd.to_datetime(daily_metrics_df['date'])
    daily_metrics_df = daily_metrics_df.sort_values('date')
    
    split_metrics_df = pd.DataFrame(split_metrics)
    
    # Create a results dictionary
    evaluation_results = {
        'overall_metrics': overall_metrics,
        'split_metrics': split_metrics_df,
        'daily_metrics': daily_metrics_df,
        'all_results': all_results,
        'train_listings': len(train_listing_ids),
        'test_listings': len(test_listing_ids),
        'config': {
            'model_type': 'hybrid_transformer_gnn',
            'sequence_length': sequence_length,
            'hidden_dim': hidden_dim
        }
    }
    
    # Print summary
    print("\n===== HYBRID TRANSFORMER-GNN MODEL SUMMARY =====")
    print(f"Using {len(train_listing_ids)} listings for training and {len(test_listing_ids)} listings for testing")
    
    print("\n=== Overall Metrics ===")
    print(f"RMSE: {overall_metrics['rmse']:.4f}")
    print(f"MAE: {overall_metrics['mae']:.4f}")
    print(f"R²: {overall_metrics['r2']:.4f}")
    print(f"MAPE: {overall_metrics['mape']:.4f}%")
    
    print("\n=== Split Performance ===")
    print(split_metrics_df[['split', 'rmse', 'mae', 'r2', 'n_samples']].to_string(index=False))
    
    # Save results if output directory is provided
    if output_dir:
        # Save all results
        results_file = os.path.join(output_dir, 'hybrid_gnn_results.csv')
        all_results.to_csv(results_file, index=False)
        print(f"Results saved to {results_file}")
        
        # Save metrics
        metrics_file = os.path.join(output_dir, 'hybrid_gnn_metrics.csv')
        daily_metrics_df.to_csv(metrics_file, index=False)
        print(f"Daily metrics saved to {metrics_file}")
        
        # Plot and save visualizations
        plot_gnn_rolling_window_results(evaluation_results, output_dir)
        plot_listing_predictions(all_results, num_listings=10, output_dir=output_dir)
    
    return evaluation_results

# Function to plot rolling window results
def plot_gnn_rolling_window_results(evaluation_results, output_dir=None):
    """Draw a 2x2 summary figure for a rolling-window CV run.

    Panels: daily MAE over time, RMSE/MAE per CV split, actual vs predicted
    scatter coloured by split, and the error histogram.

    Reads the 'daily_metrics', 'all_results' and 'split_metrics' entries of
    ``evaluation_results``. When ``output_dir`` is given the figure is saved
    there as 'hybrid_gnn_results.png'; the figure is always shown.
    """
    sns.set_theme(style="whitegrid")

    per_day = evaluation_results['daily_metrics']
    predictions = evaluation_results['all_results']
    per_split = evaluation_results['split_metrics']

    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    fig.suptitle('Hybrid GNN Model Evaluation with Rolling Window CV', fontsize=16)

    # Row-major order: top-left, top-right, bottom-left, bottom-right
    ax_daily, ax_splits, ax_scatter, ax_errors = axes.ravel()

    # --- Top-left: MAE per day ---
    day_axis = pd.to_datetime(per_day['date'])
    sns.lineplot(x=day_axis, y=per_day['mae'], marker='o', ax=ax_daily)
    ax_daily.set(title='Mean Absolute Error by Day', xlabel='Date', ylabel='MAE')
    ax_daily.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
    ax_daily.xaxis.set_major_locator(mdates.MonthLocator())
    plt.setp(ax_daily.xaxis.get_majorticklabels(), rotation=45)

    # --- Top-right: error metrics per CV split ---
    split_ids = per_split['split']
    for column, legend_label in (('rmse', 'RMSE'), ('mae', 'MAE')):
        sns.lineplot(
            x=split_ids, y=per_split[column],
            marker='o', label=legend_label, ax=ax_splits,
        )
    ax_splits.set(title='Performance Across CV Splits', xlabel='CV Split', ylabel='Error Metric')
    ax_splits.legend()

    # --- Bottom-left: actual vs predicted, coloured by split ---
    actual = predictions['price']
    predicted = predictions['predicted']
    points = ax_scatter.scatter(
        actual, predicted, c=predictions['split'], alpha=0.6, cmap='viridis'
    )
    # Dashed identity line spanning the joint range of both series
    lower = min(actual.min(), predicted.min())
    upper = max(actual.max(), predicted.max())
    diagonal = [lower, upper]
    ax_scatter.plot(diagonal, diagonal, 'k--')
    ax_scatter.set(title='Actual vs Predicted (Colored by CV Split)', xlabel='Actual', ylabel='Predicted')
    plt.colorbar(points, ax=ax_scatter).set_label('CV Split')

    # --- Bottom-right: distribution of errors ---
    sns.histplot(predictions['error'], kde=True, ax=ax_errors)
    ax_errors.axvline(0, color='r', linestyle='--')
    ax_errors.set(title='Error Distribution', xlabel='Error (Actual - Predicted)', ylabel='Frequency')

    plt.tight_layout()
    # Leave headroom for the figure-level title
    plt.subplots_adjust(top=0.92)

    if output_dir:
        plt.savefig(os.path.join(output_dir, 'hybrid_gnn_results.png'))

    plt.show()

# Function to plot individual listing predictions
def plot_listing_predictions(all_results, num_listings=10, output_dir=None):
    """
    Plot actual vs predicted prices for a sample of individual listings

    Parameters
    ----------
    all_results : pandas.DataFrame
        Needs the columns 'listing_id', 'date', 'price' and 'predicted'.
    num_listings : int, optional
        Maximum number of listings to plot (one subplot each, two per row).
    output_dir : str, optional
        When given, the figure is saved there as 'listing_predictions.png'.

    Returns
    -------
    None
        Nothing is drawn when there is no listing to plot.
    """
    if num_listings <= 0:
        print("No listings requested for plotting")
        return None

    # Get unique listing IDs
    unique_listings = all_results['listing_id'].unique()
    
    # Prefer listings that have more than 3 data points
    listing_counts = all_results.groupby('listing_id').size()
    listings_with_multiple_points = listing_counts[listing_counts > 3].index.tolist()
    
    if len(listings_with_multiple_points) < num_listings:
        # Not enough well-populated listings: top up with other listings in
        # order of appearance, never adding the same listing twice.
        sample_listings = list(listings_with_multiple_points)
        already_chosen = set(sample_listings)
        for listing_id in unique_listings:
            if len(sample_listings) >= num_listings:
                break
            if listing_id not in already_chosen:
                already_chosen.add(listing_id)
                sample_listings.append(listing_id)
    else:
        # Fixed-seed local generator: reproducible selection without
        # reseeding NumPy's global random state.
        rng = np.random.RandomState(42)
        sample_listings = rng.choice(listings_with_multiple_points, num_listings, replace=False)
    
    n_plots = len(sample_listings)
    if n_plots == 0:
        print("No listings available to plot")
        return None
    
    # Size the grid for the listings actually selected; squeeze=False keeps
    # axes two-dimensional even for a single row.
    n_rows = math.ceil(n_plots / 2)
    fig, axes = plt.subplots(n_rows, 2, figsize=(15, 3 * n_rows), squeeze=False)
    axes = axes.flatten()
    
    for ax, listing_id in zip(axes, sample_listings):
        # Get data for this listing in chronological order
        listing_data = all_results[all_results['listing_id'] == listing_id].sort_values('date')
        dates = pd.to_datetime(listing_data['date'])
        
        # Plot actual and predicted prices
        ax.plot(dates, listing_data['price'], 'o-', label='Actual', color='blue')
        ax.plot(dates, listing_data['predicted'], 's--', label='Predicted', color='red')
        
        # Calculate RMSE for this listing
        rmse = np.sqrt(mean_squared_error(listing_data['price'], listing_data['predicted']))
        
        ax.set_title(f'Listing ID: {listing_id} (RMSE: {rmse:.2f})')
        ax.set_xlabel('Date')
        ax.set_ylabel('Price')
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d'))
        ax.tick_params(axis='x', rotation=45)
        ax.legend()
        ax.grid(True, alpha=0.3)
    
    # Hide any unused subplots (at most one, when n_plots is odd)
    for unused_ax in axes[n_plots:]:
        unused_ax.set_visible(False)
    
    plt.tight_layout()
    
    # Save plot if output_dir is provided
    if output_dir:
        plt.savefig(os.path.join(output_dir, 'listing_predictions.png'))
    
    plt.show()
    return None

if __name__ == "__main__":
    # Input files: the feature CSV plus one listing ID per line for each side
    # of the train/test split
    train_path = "train_up3.csv"
    train_ids_path = "train_ids.txt"
    test_ids_path = "test_ids.txt"
    
    # Output directory
    output_dir = "./output/hybrid_gnn"
    os.makedirs(output_dir, exist_ok=True)
    
    try:
        # Run with hybrid approach
        results = run_hybrid_gnn_with_rolling_window_cv(
            train_path=train_path,
            train_ids_path=train_ids_path,
            test_ids_path=test_ids_path,
            output_dir=output_dir,
            window_size=35,
            n_splits=5,
            sample_size=None,  # Use full dataset
            sequence_length=10,
            hidden_dim=64
        )
    except Exception as e:
        # Top-level boundary: report the failure with its traceback
        print(f"Error running hybrid transformer-based GNN model: {str(e)}")
        import traceback
        traceback.print_exc()
    else:
        # The run returns None when no CV split completed, so only report
        # success when results actually came back.
        if results is None:
            print("Hybrid transformer-based GNN model run finished without any completed splits.")
        else:
            print("Hybrid transformer-based GNN model training completed successfully!")