In [3]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.neighbors import NearestNeighbors
from torch_geometric.nn import GATv2Conv
from torch_geometric.data import Data
from torch.cuda.amp import autocast, GradScaler
import warnings
warnings.filterwarnings('ignore')

# 1. Price transformation function
def apply_price_transformation(train_data, inverse=False):
    """
    Log-transform the ``price`` column, or undo a previous log transform.

    Parameters:
    -----------
    train_data : DataFrame
        Frame holding a ``price`` column. It is copied and never modified.
    inverse : bool
        False (default): copy the untouched prices into ``original_price`` and
        replace ``price`` with ``log1p(price)``.
        True: replace ``price`` with ``expm1(price)``. Any existing
        ``original_price`` column is carried over unchanged.

    Returns:
    --------
    DataFrame
        A new dataframe with the transformed ``price`` column.
    """
    result = train_data.copy()

    if inverse:
        print("Inverting log transformation for predictions")
        # expm1 is the exact inverse of log1p
        result['price'] = np.expm1(result['price'])
        return result

    print("Applying log transformation to price data")
    # Keep the raw prices next to the transformed ones
    result['original_price'] = result['price']
    # log1p(0) == 0, so zero prices stay finite (unlike a plain log)
    result['price'] = np.log1p(result['price'])
    return result

# 2. Create calculated features
def create_calculated_features(df):
    """
    Derive ratio, amenity-score and price-range columns from ``df``.

    Each derived column is added only when its source columns exist. The
    exceptions are ``luxury_score`` and ``essential_score``, which are always
    created and are 0 when none of their amenities are present. Finally, NaNs
    in every numeric column are replaced by that column's median.

    Parameters:
    -----------
    df : DataFrame
        Input listing data. It is copied and never modified.

    Returns:
    --------
    DataFrame
        A copy of ``df`` with the derived columns appended.
    """
    out = df.copy()

    def _mean_of_present(candidates):
        # Fraction of the listed amenity columns (that exist) which are set
        present = [name for name in candidates if name in out.columns]
        if not present:
            return 0
        return out[present].sum(axis=1) / len(present)

    # Bedrooms per guest; capacity is clipped to 1 to avoid dividing by zero
    if {'bedrooms', 'accommodates'}.issubset(out.columns):
        out['bedroom_ratio'] = out['bedrooms'] / out['accommodates'].clip(lower=1)

    # Every column whose name contains 'has_' counts as an amenity flag
    amenity_flags = out.filter(like='has_').columns
    if len(amenity_flags):
        out['amenity_count'] = out[amenity_flags].sum(axis=1)

    out['luxury_score'] = _mean_of_present(
        ['has_hot_water', 'has_hair_dryer', 'has_dedicated_workspace',
         'has_tv', 'has_wifi', 'has_shampoo']
    )
    out['essential_score'] = _mean_of_present(
        ['has_essentials', 'has_bed_linens', 'has_kitchen',
         'has_smoke_alarm', 'has_heating']
    )

    # Spread between the rolling max and min, for each window present
    for span in ('7d', '14d', '30d'):
        upper, lower = f'rolling_max_{span}', f'rolling_min_{span}'
        if upper in out.columns and lower in out.columns:
            out[f'price_range_{span}'] = out[upper] - out[lower]

    # Median-impute any numeric column that still has gaps
    for name in out.select_dtypes(include=['number']).columns:
        column = out[name]
        if column.isnull().any():
            out[name] = column.fillna(column.median())

    return out

# 3. Function to evaluate predictions
def evaluate_gnn_predictions(y_true, y_pred, print_results=True):
    """
    Evaluate GNN predictions with RMSE, MAE, R² and MAPE.

    Parameters:
    -----------
    y_true : array-like
        Ground-truth target values.
    y_pred : array-like
        Predicted values. Any shape with the same number of elements as
        ``y_true`` is accepted (e.g. an ``(n, 1)`` model output). Both inputs
        are flattened before scoring.
    print_results : bool
        If True, print a short report.

    Returns:
    --------
    dict
        Keys ``'rmse'``, ``'mae'``, ``'r2'`` and ``'mape'``. MAPE is a percentage.

    Raises:
    -------
    ValueError
        If ``y_true`` and ``y_pred`` hold a different number of values.
    """
    # Flatten both inputs. An (n,) target minus an (n, 1) prediction would
    # otherwise broadcast to an (n, n) matrix and silently corrupt MAPE.
    # Converting to plain arrays also drops pandas index alignment, which
    # could pair the wrong rows when subtracting two Series.
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true and y_pred must contain the same number of values, "
            f"got {y_true.size} and {y_pred.size}"
        )

    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    mae = mean_absolute_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)
    # The 1e-8 only guards against an exactly-zero target; targets close to
    # zero still dominate MAPE.
    mape = np.mean(np.abs((y_true - y_pred) / (y_true + 1e-8))) * 100

    metrics = {
        'rmse': rmse,
        'mae': mae,
        'r2': r2,
        'mape': mape
    }

    if print_results:
        print("=== GNN Model Evaluation ===")
        print(f"RMSE: {rmse:.2f}")
        print(f"MAE: {mae:.2f}")
        print(f"R²: {r2:.4f}")
        print(f"MAPE: {mape:.2f}%")

    return metrics

# 4. Create simple price histories
def create_simple_price_histories(data, sequence_length=30):
    """
    Build a fixed-length raw price history for every listing.

    Parameters:
    -----------
    data : DataFrame
        Must contain 'listing_id', 'date', 'price', 'DTF_is_weekend' and
        'DTF_month'.
    sequence_length : int
        Number of most recent rows kept per listing. Listings with fewer rows
        are skipped.

    Returns:
    --------
    dict
        Maps listing_id to a FloatTensor with columns
        [price, DTF_is_weekend, DTF_month], oldest row first.
    """
    histories = {}

    # A single groupby pass replaces a full-frame boolean mask per listing
    # (previously O(rows * listings)). sort=False keeps first-appearance
    # order, the same order data['listing_id'].unique() gave.
    for listing_id, listing_data in data.groupby('listing_id', sort=False):
        if len(listing_data) < sequence_length:
            continue

        recent = listing_data.sort_values('date').iloc[-sequence_length:]
        # Cast explicitly: a bool column mixed with ints would otherwise give
        # an object array, which torch.FloatTensor cannot convert.
        history = recent[['price', 'DTF_is_weekend', 'DTF_month']].to_numpy(dtype=np.float64)
        histories[listing_id] = torch.FloatTensor(history)

    return histories

# 5. Generate price history features
def generate_price_history_features(train_data, test_data=None, use_train_for_test=True,
                                    window_sizes=(7, 14, 30), exclude_current_price=True):
    """
    Generate lag and rolling price-history features for training and test data.

    Windows count rows within a listing, not calendar days. The ``*d`` suffixes
    are only exact when each listing has one row per day.

    Parameters:
    -----------
    train_data : DataFrame
        Training data with 'listing_id', 'date' and 'price' columns.
    test_data : DataFrame, optional
        Test data with 'listing_id', 'date' and 'price' columns.
    use_train_for_test : bool, optional
        If True, train and test rows are combined so each test row's features
        can draw on the same listing's training history.
    window_sizes : iterable of int, optional
        Window sizes (in rows) for the lag and rolling features.
    exclude_current_price : bool, optional
        If True (default), a row's rolling statistics use only that listing's
        earlier rows, so a feature never contains the price it helps predict.
        If False, the row's own price is inside its window, which leaks the
        target when 'price' is what the model predicts.

    Returns:
    --------
    tuple
        (train_data_with_features, test_data_with_features), each sorted by
        listing_id and date.
        - When rows were combined, train rows carry index labels
          0..len(train)-1 and test rows the labels after that.
        - Otherwise the test frame gets the feature columns filled with NaN.
        - The second element is None when ``test_data`` is None.
    """
    train_df = train_data.copy()
    test_df = test_data.copy() if test_data is not None else None

    if not pd.api.types.is_datetime64_any_dtype(train_df['date']):
        train_df['date'] = pd.to_datetime(train_df['date'])
    if test_df is not None and not pd.api.types.is_datetime64_any_dtype(test_df['date']):
        test_df['date'] = pd.to_datetime(test_df['date'])

    train_df = train_df.sort_values(['listing_id', 'date'])
    n_train = len(train_df)

    combine = test_df is not None and use_train_for_test
    if combine:
        # ignore_index labels train rows 0..n_train-1 and test rows n_train..
        # Those labels survive the sort below and identify each row's origin
        # when the frame is split back.
        data_for_features = pd.concat([train_df, test_df], ignore_index=True)
        data_for_features = data_for_features.sort_values(['listing_id', 'date'])
    else:
        data_for_features = train_df

    created_features = []

    # Lag features: the price `window` rows earlier within the same listing
    for window in window_sizes:
        feature_name = f'price_lag_{window}d'
        data_for_features[feature_name] = data_for_features.groupby('listing_id')['price'].shift(window)
        created_features.append(feature_name)

    # Series the rolling windows run over. Shifting by one row per listing
    # keeps each row's own price out of its statistics.
    if exclude_current_price:
        rolling_source = data_for_features.groupby('listing_id')['price'].shift(1)
    else:
        rolling_source = data_for_features['price']
    rolling_groups = rolling_source.groupby(data_for_features['listing_id'])

    for window in window_sizes:
        for stat in ('mean', 'max', 'min', 'std'):
            feature_name = f'rolling_{stat}_{window}d'
            data_for_features[feature_name] = rolling_groups.transform(
                lambda x, w=window, s=stat: getattr(x.rolling(w, min_periods=1), s)()
            )
            created_features.append(feature_name)

        # Price volatility (max - min)
        feature_name = f'price_range_{window}d'
        data_for_features[feature_name] = (
            data_for_features[f'rolling_max_{window}d'] - data_for_features[f'rolling_min_{window}d']
        )
        created_features.append(feature_name)

    # Fill gaps with the listing's own median, or 0 if the listing has no value
    for feature in created_features:
        if data_for_features[feature].isnull().any():
            data_for_features[feature] = data_for_features.groupby('listing_id')[feature].transform(
                lambda x: x.fillna(x.median() if x.notna().any() else 0)
            )

    if combine:
        # Split by origin label, not by position: after the sort, the first
        # n_train positions contain a mix of train and test rows.
        from_train = data_for_features.index < n_train
        return data_for_features[from_train].copy(), data_for_features[~from_train].copy()

    if test_df is not None:
        test_df = test_df.sort_values(['listing_id', 'date'])
        for feature in created_features:
            if feature in data_for_features.columns:
                # Placeholder columns; callers must populate them separately
                test_df[feature] = np.nan
        return data_for_features, test_df

    return data_for_features, None

# 6. Extract basic features
def extract_basic_features(data, feature_groups):
    """
    Build a float feature matrix from the 'spatial' and 'property' groups.

    Parameters
    ----------
    data : DataFrame
        Source rows.
    feature_groups : dict
        Maps group names to lists of column names. Only the 'spatial' and
        'property' entries are read, and names missing from ``data`` are
        ignored.

    Returns
    -------
    np.ndarray
        float64 matrix of shape (len(data), n_columns) with NaN replaced by
        0. A (len(data), 1) zero matrix is returned when no listed column
        exists.
    """
    features = []
    for group in ('spatial', 'property'):
        if group in feature_groups:
            features.extend(feature_groups[group])

    valid_features = [f for f in features if f in data.columns]
    if not valid_features:
        return np.zeros((len(data), 1))

    # Cast explicitly: mixing bool with int/float columns otherwise yields an
    # object array, which nan_to_num returns unchanged (NaNs would survive).
    feature_matrix = data[valid_features].to_numpy(dtype=np.float64, na_value=np.nan)
    return np.nan_to_num(feature_matrix, nan=0.0)

# 7. Extract temporal features
def extract_temporal_features(train_data, test_data, feature_groups, device):
    """
    Stack the 'temporal' feature group for train rows, then test rows.

    Parameters
    ----------
    train_data, test_data : DataFrame
        Row sources. Only columns present in both frames are used.
    feature_groups : dict
        ``feature_groups['temporal']`` lists the candidate column names.
    device : torch.device or str
        Device of the returned tensor.

    Returns
    -------
    torch.FloatTensor
        Shape (len(train_data) + len(test_data), n_columns), NaN replaced by
        0. A single zero column is returned when the group is missing or
        empty, or when none of its columns is in both frames.
    """
    n_rows = len(train_data) + len(test_data)
    if 'temporal' not in feature_groups or not feature_groups['temporal']:
        return torch.zeros(n_rows, 1).to(device)

    valid_features = [f for f in feature_groups['temporal']
                      if f in train_data.columns and f in test_data.columns]
    if not valid_features:
        return torch.zeros(n_rows, 1).to(device)

    # Cast explicitly: a bool/number mix would become an object array, which
    # nan_to_num skips and torch.FloatTensor rejects.
    combined = np.vstack([
        train_data[valid_features].to_numpy(dtype=np.float64, na_value=np.nan),
        test_data[valid_features].to_numpy(dtype=np.float64, na_value=np.nan),
    ])
    combined = np.nan_to_num(combined, nan=0.0)

    return torch.FloatTensor(combined).to(device)

# 8. Extract amenity features
def extract_amenity_features(train_data, test_data, feature_groups, device):
    """
    Stack the 'amenity' feature group for train rows, then test rows.

    Parameters
    ----------
    train_data, test_data : DataFrame
        Row sources. Only columns present in both frames are used.
    feature_groups : dict
        ``feature_groups['amenity']`` lists the candidate column names.
    device : torch.device or str
        Device of the returned tensor.

    Returns
    -------
    torch.FloatTensor
        Shape (len(train_data) + len(test_data), n_columns), NaN replaced by
        0. A single zero column is returned when the group is missing or
        empty, or when none of its columns is in both frames.
    """
    n_rows = len(train_data) + len(test_data)
    if 'amenity' not in feature_groups or not feature_groups['amenity']:
        return torch.zeros(n_rows, 1).to(device)

    valid_features = [f for f in feature_groups['amenity']
                      if f in train_data.columns and f in test_data.columns]
    if not valid_features:
        return torch.zeros(n_rows, 1).to(device)

    # Cast explicitly: bool amenity flags next to numeric counts would give an
    # object array, which nan_to_num skips and torch.FloatTensor rejects.
    combined = np.vstack([
        train_data[valid_features].to_numpy(dtype=np.float64, na_value=np.nan),
        test_data[valid_features].to_numpy(dtype=np.float64, na_value=np.nan),
    ])
    combined = np.nan_to_num(combined, nan=0.0)

    return torch.FloatTensor(combined).to(device)

# 9. Extract price history features
def extract_price_history_features(train_data, test_data, feature_groups, device):
    """
    Stack the price-history columns for train rows, then test rows.

    A price-history column is any column whose name starts with 'price_lag'
    or 'rolling' and that exists in both frames.

    Parameters
    ----------
    train_data, test_data : DataFrame
        Row sources.
    feature_groups : dict
        Unused; kept so the signature matches the other extractors.
    device : torch.device or str
        Device of the returned tensor.

    Returns
    -------
    torch.FloatTensor
        Shape (len(train_data) + len(test_data), n_columns), NaN replaced by
        0. A single zero column is returned when no such column exists.
    """
    # Require the column in both frames, like the other extractors. A column
    # present only in train previously raised KeyError on the test frame.
    price_history_cols = [col for col in train_data.columns
                          if col.startswith(('price_lag', 'rolling')) and col in test_data.columns]

    if not price_history_cols:
        return torch.zeros(len(train_data) + len(test_data), 1).to(device)

    combined = np.vstack([
        train_data[price_history_cols].to_numpy(dtype=np.float64, na_value=np.nan),
        test_data[price_history_cols].to_numpy(dtype=np.float64, na_value=np.nan),
    ])
    combined = np.nan_to_num(combined, nan=0.0)

    return torch.FloatTensor(combined).to(device)

# 10. Process listing history
def process_listing_history(listing_data, sequence_length=30):
    """
    Turn one listing's rows into a (steps, features) FloatTensor.

    Parameters
    ----------
    listing_data : DataFrame
        Rows of a single listing with 'date' and 'price' columns, and
        optionally both 'DTF_is_weekend' and 'DTF_month'.
    sequence_length : int
        Maximum number of most recent rows kept.

    Returns
    -------
    torch.FloatTensor
        Up to ``sequence_length`` rows, oldest first. Columns are
        [price, DTF_is_weekend, DTF_month] when both temporal columns exist,
        otherwise [price] only. A (1, 1) zero tensor is returned for fewer
        than two rows.
    """
    if len(listing_data) < 2:
        return torch.zeros((1, 1))

    columns = ['price']
    if 'DTF_is_weekend' in listing_data.columns and 'DTF_month' in listing_data.columns:
        columns += ['DTF_is_weekend', 'DTF_month']

    recent = listing_data.sort_values('date')[columns].iloc[-sequence_length:]
    # Cast explicitly: a bool weekend flag would otherwise produce an object
    # array that torch.FloatTensor cannot convert.
    return torch.FloatTensor(recent.to_numpy(dtype=np.float64))

In [4]:
# Keep utility functions unchanged
# apply_price_transformation, create_calculated_features, evaluate_gnn_predictions, etc.
# [...]

# Modified: Memory efficient graph building
def build_enhanced_spatial_graph(train_data, test_data, k=3, feature_weight=0.3):
    """
    Connect every test listing to its nearest training listings.

    Nodes are numbered train first (0 .. len(train_data)-1), then test
    (len(train_data) ..). For each (test, neighbour) pair two directed edges
    are emitted, test->train then train->test, with the same weight:

        (1 - feature_weight) / (geo_distance + 1e-6)
        + feature_weight * max(0, cosine_similarity(standardised features))

    ``geo_distance`` is NearestNeighbors' default (Euclidean) distance on raw
    (latitude, longitude). The similarity uses 'accommodates', 'bedrooms' and
    'bathrooms' where present in both frames.

    Note: the geographic term is not normalised. It grows without bound as the
    distance approaches 0, so it can dwarf the similarity term, which lies in
    [0, 1].

    Parameters
    ----------
    train_data, test_data : DataFrame
        Must contain 'latitude' and 'longitude'.
    k : int
        Neighbours per test listing, capped at len(train_data).
    feature_weight : float
        Mixing weight of the feature-similarity term.

    Returns
    -------
    (torch.LongTensor, torch.FloatTensor)
        edge_index of shape (2, E) and edge_attr of shape (E, 1), with
        E = 2 * len(test_data) * min(k, len(train_data)).
    """
    n_train = len(train_data)
    n_test = len(test_data)
    train_coords = train_data[['latitude', 'longitude']].values
    test_coords = test_data[['latitude', 'longitude']].values

    print(f"Building enhanced spatial graph with {len(test_coords)} test listings and {k} nearest neighbors...")

    if n_test == 0:
        # Nothing to connect, and NearestNeighbors.kneighbors rejects an empty query.
        edge_index_tensor = torch.zeros((2, 0), dtype=torch.long)
        edge_attr_tensor = torch.zeros((0, 1), dtype=torch.float32)
        print(f"Created graph with {edge_index_tensor.shape[1]} edges")
        return edge_index_tensor, edge_attr_tensor

    # Standardise the similarity features; fall back to a constant feature
    feature_cols = [f for f in ('accommodates', 'bedrooms', 'bathrooms')
                    if f in train_data.columns and f in test_data.columns]
    if feature_cols:
        scaler = StandardScaler()
        train_features = scaler.fit_transform(train_data[feature_cols].fillna(0))
        test_features = scaler.transform(test_data[feature_cols].fillna(0))
    else:
        train_features = np.ones((n_train, 1))
        test_features = np.ones((n_test, 1))

    # `knn`, not `nn`, to avoid shadowing the torch.nn import
    knn = NearestNeighbors(n_neighbors=min(k, n_train))
    knn.fit(train_coords)
    distances, indices = knn.kneighbors(test_coords)  # both (n_test, k_eff)

    # Cosine similarity between each test listing and each of its neighbours.
    # Every row is indexed by its own position; the old batched loop reused
    # rows 0..999 for all later batches.
    neighbour_features = train_features[indices]  # (n_test, k_eff, n_feat)
    dots = np.einsum('tf,tkf->tk', test_features, neighbour_features)
    norm_products = (np.linalg.norm(test_features, axis=1)[:, None]
                     * np.linalg.norm(neighbour_features, axis=2))
    well_defined = norm_products > 1e-8
    feat_sim = np.zeros_like(dots)
    feat_sim[well_defined] = dots[well_defined] / norm_products[well_defined]

    geo_weight = 1.0 / (distances + 1e-6)
    weights = (1 - feature_weight) * geo_weight + feature_weight * np.maximum(0.0, feat_sim)

    # Edge order: for each test node, for each neighbour, forward then reverse
    test_nodes = np.broadcast_to(np.arange(n_test)[:, None] + n_train, indices.shape)
    forward = np.stack([test_nodes, indices], axis=-1)
    backward = np.stack([indices, test_nodes], axis=-1)
    pairs = np.stack([forward, backward], axis=2).reshape(-1, 2)

    edge_index_tensor = torch.as_tensor(pairs, dtype=torch.long).t().contiguous()
    edge_attr_tensor = torch.as_tensor(np.repeat(weights.reshape(-1), 2)[:, None], dtype=torch.float32)

    print(f"Created graph with {edge_index_tensor.shape[1]} edges")

    return edge_index_tensor, edge_attr_tensor

# NEW: Memory-efficient TimeSeriesEncoder
class TimeSeriesEncoder(nn.Module):
    """
    Embed each time step linearly, then run a pre-norm Transformer encoder.

    Parameters
    ----------
    input_dim : int
        Features per time step.
    hidden_dim : int
        Model width; must be divisible by ``num_heads``.
    num_heads, num_layers, dropout :
        Passed to ``nn.TransformerEncoderLayer`` / ``nn.TransformerEncoder``.
    sub_batch_size : int
        Batches larger than this go through the transformer in chunks of this
        size. Under autograd the activations of every chunk are still kept,
        so the memory saving mainly applies at inference time.
    """

    # Class-level fallback so instances pickled before this attribute existed keep working
    sub_batch_size = 1000

    def __init__(self, input_dim, hidden_dim, num_heads=2, num_layers=1, dropout=0.1,
                 sub_batch_size=1000):
        super(TimeSeriesEncoder, self).__init__()
        if sub_batch_size < 1:
            raise ValueError(f"sub_batch_size must be >= 1, got {sub_batch_size}")

        self.embedding = nn.Linear(input_dim, hidden_dim)

        # Small transformer: 2x-wide feed-forward, pre-norm for stable training
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 2,
            dropout=dropout,
            batch_first=True,
            norm_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.sub_batch_size = sub_batch_size

    def forward(self, x, src_key_padding_mask=None):
        """
        Encode a batch of sequences.

        Parameters
        ----------
        x : Tensor
            (batch, seq_len, input_dim); batch_first=True fixes this layout.
        src_key_padding_mask : BoolTensor, optional
            (batch, seq_len); True marks padded steps to ignore.

        Returns
        -------
        Tensor
            (batch, seq_len, hidden_dim).
        """
        hidden_dim = self.embedding.out_features

        if x.shape[0] == 0:
            # Keep the same rank as a normal output so callers can still
            # index [:, -1]; the old (0, hidden_dim) result dropped seq_len.
            return x.new_zeros((0, *x.shape[1:-1], hidden_dim))

        x = self.embedding(x)

        chunk = self.sub_batch_size
        if x.shape[0] <= chunk:
            return self.transformer(x, src_key_padding_mask=src_key_padding_mask)

        outputs = []
        for start in range(0, x.shape[0], chunk):
            stop = start + chunk
            mask = None if src_key_padding_mask is None else src_key_padding_mask[start:stop]
            outputs.append(self.transformer(x[start:stop], src_key_padding_mask=mask))
        return torch.cat(outputs, dim=0)

# NEW: Memory-efficient CrossAttention
class CrossAttention(nn.Module):
    """
    Fuse per-node spatial embeddings with a temporal context vector.

    Spatial features are the attention query and the temporal context is the
    key/value. Two residual + LayerNorm sub-layers follow: attention, then a
    2x-wide ReLU feed-forward network. The result is concatenated with the raw
    ``temporal_context`` and projected to ``2 * spatial_dim`` features.

    Parameters
    ----------
    spatial_dim : int
        Width of ``spatial_features``. Also the attention embedding size, so it
        must be divisible by ``heads`` (an ``nn.MultiheadAttention`` requirement).
    temporal_dim : int
        Width (last dimension) of ``temporal_context``.
    heads : int
        Number of attention heads.
    dropout : float
        Used inside attention, after each residual sub-layer and in the FFN.
    """
    def __init__(self, spatial_dim, temporal_dim, heads=2, dropout=0.1):
        super(CrossAttention, self).__init__()
        # Bring both inputs to spatial_dim. When the widths already match,
        # both are identities; project_spatial is spatial->spatial, so it is
        # only a real Linear when the widths differ.
        self.project_spatial = nn.Linear(spatial_dim, spatial_dim) if spatial_dim != temporal_dim else nn.Identity()
        self.project_temporal = nn.Linear(temporal_dim, spatial_dim) if temporal_dim != spatial_dim else nn.Identity()
        
        # Attention with spatial as query and temporal as key/value (batch-first layout)
        self.multihead_attn = nn.MultiheadAttention(
            embed_dim=spatial_dim,
            num_heads=heads,
            dropout=dropout,
            batch_first=True
        )
        
        # Post-norm residual blocks: norm1 after attention, norm2 after the FFN
        self.norm1 = nn.LayerNorm(spatial_dim)
        self.norm2 = nn.LayerNorm(spatial_dim)
        self.dropout = nn.Dropout(dropout)
        self.ffn = nn.Sequential(
            nn.Linear(spatial_dim, spatial_dim * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(spatial_dim * 2, spatial_dim)
        )
        
        # The final concat is [fused spatial (spatial_dim), raw temporal_context
        # (temporal_dim)], hence the spatial_dim + temporal_dim input width.
        self.output_proj = nn.Linear(spatial_dim + temporal_dim, spatial_dim * 2)

    def forward(self, spatial_features, temporal_context):
        """
        Fuse the two inputs.

        Parameters
        ----------
        spatial_features : Tensor
            Usually (N, spatial_dim); it is treated as a length-1 query sequence.
        temporal_context : Tensor
            (N, temporal_dim), treated as a single key/value token, or
            (N, T, temporal_dim). NOTE(review): the final concat squeezes dim 1,
            so a 3-D context presumably only works with T == 1; confirm.

        Returns
        -------
        Tensor
            (N, 2 * spatial_dim) for 2-D inputs; (0, 2 * spatial_dim) for an
            empty batch.
        """
        # Check for empty inputs
        if spatial_features.shape[0] == 0:
            output_dim = self.output_proj.out_features
            return torch.zeros((0, output_dim), device=spatial_features.device)
        
        # Project features if needed
        spatial_proj = self.project_spatial(spatial_features)
        temporal_proj = self.project_temporal(temporal_context)
        
        # A 2-D context becomes a single key/value token. With one key the
        # softmax weight is always 1 (before attention dropout), so the
        # attention reduces to a learned projection of the temporal context.
        if temporal_proj.dim() == 2:
            # Add sequence dimension (treat as single token)
            temporal_proj = temporal_proj.unsqueeze(1)
        
        # Multi-head attention - spatial as query, temporal as key/value
        spatial_proj_seq = spatial_proj.unsqueeze(1) if spatial_proj.dim() == 2 else spatial_proj
        
        # Run attention in chunks of 1000 rows for large batches.
        # NOTE(review): a batch-1 temporal input is reused for every chunk, but
        # nn.MultiheadAttention does not broadcast the batch dimension, so that
        # case likely fails unless the spatial chunk is also batch 1; confirm.
        if spatial_proj_seq.shape[0] > 1000:
            outputs = []
            sub_batch_size = 1000
            
            for i in range(0, spatial_proj_seq.shape[0], sub_batch_size):
                end_idx = min(i + sub_batch_size, spatial_proj_seq.shape[0])
                
                # Extract sub-batches
                sub_spatial = spatial_proj_seq[i:end_idx]
                sub_temporal = temporal_proj[i:end_idx] if temporal_proj.shape[0] > 1 else temporal_proj
                
                # Process sub-batch
                sub_attn_output, _ = self.multihead_attn(
                    query=sub_spatial,
                    key=sub_temporal,
                    value=sub_temporal
                )
                outputs.append(sub_attn_output)
                
            attn_output = torch.cat(outputs, dim=0)
        else:
            attn_output, _ = self.multihead_attn(
                query=spatial_proj_seq,
                key=temporal_proj,
                value=temporal_proj
            )
        
        # Drop the length-1 query dimension added above for 2-D inputs
        if spatial_features.dim() == 2:
            attn_output = attn_output.squeeze(1)
        
        # Residual connection (on the projected spatial input) and normalization
        spatial_features = self.norm1(spatial_proj + self.dropout(attn_output))
        
        # FFN
        ffn_output = self.ffn(spatial_features)
        
        # Second residual connection
        spatial_features = self.norm2(spatial_features + self.dropout(ffn_output))
        
        # Concatenate with the raw (unprojected) temporal context
        final_output = torch.cat([
            spatial_features, 
            temporal_context.squeeze(1) if temporal_context.dim() == 3 else temporal_context
        ], dim=1)
        
        # (spatial_dim + temporal_dim) -> 2 * spatial_dim
        return self.output_proj(final_output)

# Simplified and memory-efficient GNN
class SimplifiedListingGNN(nn.Module):
    """
    GATv2 over the listing graph, fused with a Transformer price-history encoder.

    Parameters
    ----------
    spatial_dim : int
        Node feature width (``data.x``).
    temporal_dim : int
        Not used by the layers built here; kept for signature compatibility.
    price_history_dim : int
        Features per time step in ``data.price_history``.
    hidden_dim : int
        Width of the GAT heads and the time-series encoder.
    heads : int
        Heads for both the GAT layer and the time-series encoder.
    """
    def __init__(self, spatial_dim, temporal_dim, price_history_dim, hidden_dim=32, heads=2):
        super(SimplifiedListingGNN, self).__init__()
        
        # Spatial component
        self.gat = GATv2Conv(spatial_dim, hidden_dim, heads=heads, edge_dim=1)
        
        # Temporal encoder (transformer-based)
        self.time_series_encoder = TimeSeriesEncoder(
            input_dim=price_history_dim,
            hidden_dim=hidden_dim,
            num_heads=heads
        )
        
        # GATv2Conv concatenates its heads by default
        gat_output_dim = hidden_dim * heads
        
        # CrossAttention between spatial and temporal
        self.cross_attention = CrossAttention(gat_output_dim, hidden_dim)
        
        # CrossAttention outputs 2 * gat_output_dim features
        self.output_mlp = nn.Sequential(
            nn.Linear(gat_output_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
    
    def forward(self, data, batch_indices=None):
        """
        Predict one value per selected node.

        Parameters
        ----------
        data : torch_geometric.data.Data
            Needs ``x``, ``edge_index`` and ``edge_attr``. Optionally carries
            ``price_history`` (nodes, seq_len, price_history_dim) and
            ``price_history_mask`` (nodes, seq_len; True = padding). A missing
            attribute and an explicit None are both treated as absent.
        batch_indices : LongTensor, optional
            Node ids to predict for. Their features, histories and the edges
            with both endpoints in the batch are gathered on CPU, then moved
            to the model's device, so the full graph may stay on CPU. When
            omitted, the whole graph is used as-is and presumably must already
            be on the model's device.

        Returns
        -------
        Tensor
            (num_selected_nodes, 1).
        """
        model_device = next(self.parameters()).device
        history = getattr(data, 'price_history', None)
        history_mask = getattr(data, 'price_history_mask', None)
        context_dim = self.time_series_encoder.embedding.out_features

        if batch_indices is not None:
            # Index the (CPU) graph with CPU indices, then move only the slice
            cpu_indices = batch_indices.cpu()
            batch_x = data.x[cpu_indices].to(model_device)
            batch_edge_index, batch_edge_attr = self.filter_edges(data.edge_index, data.edge_attr, cpu_indices)

            spatial_features = self.gat(batch_x, batch_edge_index, edge_attr=batch_edge_attr)

            if history is not None:
                batch_history = history[cpu_indices].to(model_device)
                batch_mask = history_mask[cpu_indices].to(model_device) if history_mask is not None else None
                temporal_features = self.time_series_encoder(batch_history, src_key_padding_mask=batch_mask)
                temporal_context = self._last_valid_step(temporal_features, batch_mask)
            else:
                temporal_context = torch.zeros((len(cpu_indices), context_dim), device=model_device)
        else:
            spatial_features = self.gat(data.x, data.edge_index, edge_attr=data.edge_attr)

            if history is not None:
                temporal_features = self.time_series_encoder(history, src_key_padding_mask=history_mask)
                temporal_context = self._last_valid_step(temporal_features, history_mask)
            else:
                temporal_context = torch.zeros((data.x.shape[0], context_dim), device=data.x.device)

        fused_features = self.cross_attention(spatial_features, temporal_context)
        return self.output_mlp(fused_features)

    @staticmethod
    def _last_valid_step(temporal_features, mask):
        """Return each row's encoder output at its last unmasked step (the last step if ``mask`` is None)."""
        if mask is None:
            return temporal_features[:, -1]
        # Assumes right-padding: (#unmasked - 1) indexes the last real step.
        # Fully masked rows fall back to step 0.
        last = torch.clamp((~mask).sum(dim=1) - 1, min=0)
        rows = torch.arange(temporal_features.shape[0], device=temporal_features.device)
        return temporal_features[rows, last]
    
    def filter_edges(self, edge_index, edge_attr, node_indices):
        """
        Keep only edges whose two endpoints are in ``node_indices``, renumbered locally.

        Parameters
        ----------
        edge_index : LongTensor
            (2, E) global node ids.
        edge_attr : Tensor
            (E, D) per-edge attributes.
        node_indices : LongTensor
            (B,) global ids of the batch nodes. Assumed unique; with
            duplicates, which local id a repeated node maps to is unspecified.

        Returns
        -------
        (LongTensor, Tensor)
            (2, E') edge_index and (E', D) edge_attr on the model's device.
            Global id ``node_indices[i]`` becomes local id ``i``, and the
            original edge order is preserved.
        """
        model_device = next(self.parameters()).device
        endpoints = edge_index.cpu().long()
        nodes = node_indices.cpu().long()

        # remap[g] = local id of global node g, or -1 if g is not in the batch.
        # This replaces the per-edge Python loop and .item() calls with tensor ops.
        size = 0
        if endpoints.numel():
            size = int(endpoints.max()) + 1
        if nodes.numel():
            size = max(size, int(nodes.max()) + 1)
        remap = torch.full((size,), -1, dtype=torch.long)
        remap[nodes] = torch.arange(nodes.numel(), dtype=torch.long)

        local = remap[endpoints]           # (2, E)
        keep = (local >= 0).all(dim=0)     # (E,) both endpoints in the batch

        if not bool(keep.any()):
            return (torch.zeros((2, 0), dtype=edge_index.dtype, device=model_device),
                    torch.zeros((0, edge_attr.shape[1]), dtype=edge_attr.dtype, device=model_device))

        remapped_edge_index = local[:, keep].to(device=model_device, dtype=edge_index.dtype)
        filtered_edge_attr = edge_attr[keep.to(edge_attr.device)].to(model_device)
        return remapped_edge_index, filtered_edge_attr

# Memory-efficient graph data preparation: every tensor is built on CPU.
def _scale_price_column(train_df, test_df):
    """Standardise ``price`` in place using ``train_df`` statistics; return the fitted scaler.

    ``test_df`` is transformed only when it has a ``price`` column.
    """
    target_scaler = StandardScaler()
    train_df['price'] = target_scaler.fit_transform(train_df[['price']]).ravel()
    if 'price' in test_df.columns:
        test_df['price'] = target_scaler.transform(test_df[['price']]).ravel()
    return target_scaler


def _first_nonempty_history(*history_dicts):
    """Return the first history that is not None and has at least one row, searching dicts in order."""
    for histories in history_dicts:
        for hist in histories.values():
            if hist is not None and hist.shape[0] > 0:
                return hist
    return None


def _stack_price_histories(listing_ids, histories, sample_history):
    """Stack per-listing histories into a zero-padded CPU tensor with one slot per row.

    The result has shape ``(len(listing_ids), seq_len, feature_dim)``, where
    ``seq_len, feature_dim = sample_history.shape``. When ``sample_history`` is
    None it is ``(len(listing_ids), 1, 1)`` of zeros. Longer histories are
    truncated to their first ``seq_len`` steps. Rows whose listing has a
    missing, None or empty history stay all-zero.
    """
    n_rows = len(listing_ids)
    if sample_history is None:
        return torch.zeros((n_rows, 1, 1))

    seq_len, feature_dim = sample_history.shape
    # Rows usually repeat listing ids (several rows per listing), so fill one
    # slot per unique id and expand back to rows with a single gather. This gives
    # the same values as filling row by row.
    unique_ids, row_to_unique = np.unique(listing_ids, return_inverse=True)
    per_listing = torch.zeros((len(unique_ids), seq_len, feature_dim))
    for slot, lid in enumerate(unique_ids):
        history = histories.get(lid)
        if history is None or history.shape[0] == 0:
            continue
        n_steps = min(history.shape[0], seq_len)
        per_listing[slot, :n_steps, :] = history[:n_steps]
    return per_listing[torch.as_tensor(row_to_unique.reshape(-1), dtype=torch.long)]


def prepare_simplified_graph_data(train_data, test_data, feature_groups, device, sequence_length=10):
    """Build a PyG ``Data`` object over train + test rows, with every tensor on CPU.

    Train rows come first, then test rows. ``price`` is standardised with a
    ``StandardScaler`` fitted on ``train_data`` only.

    The input frames are copied before scaling, so the caller's DataFrames are
    not modified. Previously they were scaled in place, and a second call on
    the same frame re-fitted the scaler on already-scaled prices. The copies
    cost extra memory.

    Parameters
    ----------
    train_data, test_data : DataFrame
        Row-level data with a ``listing_id`` column; ``train_data`` needs
        ``price``. If ``test_data`` has no ``price``, its targets are zeros.
    feature_groups : dict
        Column groups forwarded to the feature-extraction helpers.
    device : any
        Unused; kept for interface compatibility (everything stays on CPU).
    sequence_length : int
        Forwarded to ``create_simple_price_histories``.

    Returns
    -------
    tuple
        ``(data, scalers, train_histories)``:

        - ``data`` has ``x``, ``edge_index``, ``edge_attr``, ``y``,
          ``train_mask``, ``val_mask``, ``listing_ids``, ``price_history``,
          ``price_history_mask``, ``temporal_x``, ``amenity_x`` and
          ``price_history_x``.
        - ``scalers['target']`` is the fitted price scaler.
        - ``train_histories`` is what ``create_simple_price_histories``
          returned for the (scaled) training frame.
    """
    print("Preparing graph data...")

    train_data = train_data.copy()
    test_data = test_data.copy()

    scalers = {'target': _scale_price_column(train_data, test_data)}

    train_features = extract_basic_features(train_data, feature_groups)
    test_features = extract_basic_features(test_data, feature_groups)

    print("Creating price histories...")
    train_histories = create_simple_price_histories(train_data, sequence_length)
    test_histories = create_simple_price_histories(test_data, sequence_length)

    all_listing_ids = np.concatenate([
        train_data['listing_id'].values,
        test_data['listing_id'].values,
    ])
    sample_history = _first_nonempty_history(train_histories, test_histories)
    # For listings present in both dicts, the test history takes precedence.
    batch_histories = _stack_price_histories(
        all_listing_ids, {**train_histories, **test_histories}, sample_history
    )

    print("Building spatial graph...")
    edge_index, edge_attr = build_enhanced_spatial_graph(train_data, test_data, k=3)  # Reduced neighbors

    n_train = len(train_features)
    n_total = n_train + len(test_features)
    train_mask = torch.zeros(n_total, dtype=torch.bool)
    train_mask[:n_train] = True
    val_mask = ~train_mask  # every non-training row is a validation/test row

    all_features = np.vstack([train_features, test_features])

    train_y = train_data['price'].values
    test_y = test_data['price'].values if 'price' in test_data.columns else np.zeros(len(test_features))
    all_y = np.concatenate([train_y, test_y])

    data = Data(
        x=torch.FloatTensor(all_features),
        edge_index=edge_index,
        edge_attr=edge_attr,
        y=torch.FloatTensor(all_y.reshape(-1, 1)),
        train_mask=train_mask,
        val_mask=val_mask,
        listing_ids=torch.LongTensor(all_listing_ids),
        price_history=batch_histories,
        # True where a time step is all-zero (padding / no history), for attention masking.
        price_history_mask=(batch_histories.sum(dim=-1) == 0),
    )

    # Additional feature blocks, also kept on CPU.
    data.temporal_x = extract_temporal_features(train_data, test_data, feature_groups, 'cpu')
    data.amenity_x = extract_amenity_features(train_data, test_data, feature_groups, 'cpu')
    data.price_history_x = extract_price_history_features(train_data, test_data, feature_groups, 'cpu')

    return data, scalers, train_histories

# Build a device-resident mini-batch from the CPU-resident graph tensors.
def create_batch_data(graph_data, batch_indices, device):
    """Gather the rows in ``batch_indices`` from ``graph_data`` and move them to ``device``.

    Indexing is done on CPU, where the graph tensors live; only the gathered
    slices are transferred.

    Parameters
    ----------
    graph_data : object
        Must expose ``x``, ``y``, ``price_history`` and ``price_history_mask``
        tensors whose first dimension is the row axis.
    batch_indices : sequence of int or 1-D integer array
        Row positions to gather. An empty sequence yields empty slices.
    device : torch.device, str or None
        Target device; ``None`` leaves everything on CPU.

    Returns
    -------
    dict
        Keys ``'x'``, ``'y'``, ``'price_history'``, ``'price_history_mask'``
        (gathered rows) and ``'indices'`` (the int64 index tensor).
    """
    # dtype=torch.long so an empty list still gives a valid index tensor
    # (torch.tensor([]) would be float32, which cannot be used for indexing).
    # torch.tensor always copies, so the result never aliases a caller's array.
    indices_tensor = torch.tensor(batch_indices, dtype=torch.long)

    batch = {
        'x': graph_data.x[indices_tensor],
        'y': graph_data.y[indices_tensor],
        'price_history': graph_data.price_history[indices_tensor],
        'price_history_mask': graph_data.price_history_mask[indices_tensor],
        'indices': indices_tensor,
    }

    if device is not None:
        batch = {key: value.to(device) for key, value in batch.items()}

    return batch

# Memory-efficient training: graph data stays on CPU, mini-batches go to the device.
def _evaluate_in_batches(model, graph_data, graph_obj, indices, batch_size, device, criterion, use_amp):
    """Score ``indices`` with ``model`` in mini-batches without tracking gradients.

    Returns ``(mean_loss, preds, targets)``:

    - ``mean_loss`` is the per-row average loss (batch losses weighted by
      batch length).
    - ``preds`` and ``targets`` are float32 NumPy arrays stacked along axis 0.

    ``indices`` must be non-empty.
    """
    loss_sum = 0.0
    preds, targets = [], []
    with torch.no_grad():
        for start in range(0, len(indices), batch_size):
            chunk = indices[start:start + batch_size]
            batch_data = create_batch_data(graph_data, chunk, device)

            with autocast(enabled=use_amp):
                out = model(graph_obj, batch_data['indices'])
                batch_loss = criterion(out, batch_data['y'])

            loss_sum += batch_loss.item() * len(chunk)
            # Cast to float32 before leaving the device: under autocast the output
            # can be float16, and inverse scaling + expm1 in float16 loses precision.
            preds.append(out.float().cpu().numpy())
            targets.append(batch_data['y'].float().cpu().numpy())

            del batch_data
            torch.cuda.empty_cache()

    return loss_sum / len(indices), np.vstack(preds), np.vstack(targets)


def train_simplified_gnn_model(train_data, val_data, feature_groups, device='cuda',
                               hidden_dim=32, epochs=30, lr=0.001, sequence_length=10,
                               batch_size=16384, grad_accum_steps=2):
    """Train ``SimplifiedListingGNN`` on ``train_data`` with early stopping on ``val_data``.

    Graph tensors are built on CPU by ``prepare_simplified_graph_data``. Only
    the edge tensors and each mini-batch (via ``create_batch_data``) are moved
    to ``device``. Mixed precision (autocast + GradScaler) is enabled only when
    ``device`` is a CUDA device.

    Parameters
    ----------
    train_data, val_data : DataFrame
        Rows for training and validation. ``price`` is expected to be
        log1p-transformed: validation metrics apply ``np.expm1`` after undoing
        the standard scaling.
    feature_groups : dict
        Forwarded to ``prepare_simplified_graph_data``.
    device : str or torch.device
        Device for the model and mini-batches.
    hidden_dim : int
        Used for both ``temporal_dim`` and ``hidden_dim`` of the model.
    epochs : int
        Maximum number of epochs. Training stops early after 10 epochs
        without a validation-loss improvement.
    lr : float
        Adam learning rate; ``ReduceLROnPlateau`` halves it after 5 stale epochs.
    sequence_length : int
        Forwarded to ``prepare_simplified_graph_data``.
    batch_size : int
        Rows per mini-batch, for both training and validation.
    grad_accum_steps : int
        Mini-batches whose gradients are accumulated per optimizer step
        (values below 1 are treated as 1).

    Returns
    -------
    tuple
        ``(model, scalers)``. The model holds the weights of the epoch with the
        lowest validation loss; ``scalers`` comes from data preparation.

    Raises
    ------
    ValueError
        If ``val_data`` yields no validation rows.
    """
    device = torch.device(device)
    # torch.cuda.amp only does anything on CUDA; on CPU it would just warn and disable itself.
    use_amp = device.type == 'cuda'
    grad_accum_steps = max(1, int(grad_accum_steps))

    graph_data, scalers, _ = prepare_simplified_graph_data(
        train_data, val_data, feature_groups, 'cpu', sequence_length
    )

    # Take the history width from the stacked tensor the model will actually
    # receive, not from the first dict entry (which may be None or differ).
    price_history_dim = graph_data.price_history.shape[-1]
    spatial_dim = graph_data.x.shape[1]

    print(f"Input dimensions - Spatial: {spatial_dim}, Price history: {price_history_dim}")

    model = SimplifiedListingGNN(
        spatial_dim=spatial_dim,
        temporal_dim=hidden_dim,
        price_history_dim=price_history_dim,
        hidden_dim=hidden_dim,
        heads=2,
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    criterion = nn.HuberLoss(delta=1.0)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, verbose=True
    )
    scaler = GradScaler(enabled=use_amp)

    train_indices = torch.where(graph_data.train_mask)[0].cpu().numpy()
    val_indices = torch.where(graph_data.val_mask)[0].cpu().numpy()
    if len(val_indices) == 0:
        raise ValueError("val_data produced no validation rows; cannot early-stop or report metrics")

    num_train_batches = int(np.ceil(len(train_indices) / batch_size))

    best_val_loss = float('inf')
    best_model_state = None
    patience = 10
    counter = 0

    # Edge tensors are moved to the device once; node data is moved per mini-batch.
    edge_index = graph_data.edge_index.to(device)
    edge_attr = graph_data.edge_attr.to(device)

    print(f"Training with {num_train_batches} mini-batches per epoch, batch size: {batch_size}")

    class GraphData:
        """View of ``graph_data`` whose ``edge_index``/``edge_attr`` are the device copies.

        All other attribute reads are forwarded to the wrapped object; ``to()`` is a no-op.
        """

        def __init__(self, data, edge_index, edge_attr):
            self.data = data
            self.edge_index = edge_index
            self.edge_attr = edge_attr

        def __getattr__(self, name):
            # Read 'data' via __dict__ so a partially initialised instance cannot recurse.
            data = self.__dict__.get('data')
            if data is not None and hasattr(data, name):
                return getattr(data, name)
            raise AttributeError(f"'GraphData' object has no attribute '{name}'")

        def to(self, device):
            return self

    graph_obj = GraphData(graph_data, edge_index, edge_attr)

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0

        np.random.shuffle(train_indices)

        optimizer.zero_grad()
        accumulated_batches = 0

        for batch_idx in range(num_train_batches):
            start_idx = batch_idx * batch_size
            batch_train_indices = train_indices[start_idx:start_idx + batch_size]

            batch_data = create_batch_data(graph_data, batch_train_indices, device)

            with autocast(enabled=use_amp):
                out = model(graph_obj, batch_data['indices'])
                y = batch_data['y'].to(out.device)
                # Divide so accumulated gradients average over the accumulation window.
                loss = criterion(out, y) / grad_accum_steps

            scaler.scale(loss).backward()
            total_loss += loss.item() * grad_accum_steps
            accumulated_batches += 1

            # Step after grad_accum_steps batches, and always on the last batch.
            if accumulated_batches == grad_accum_steps or batch_idx == num_train_batches - 1:
                scaler.unscale_(optimizer)  # clip on true (unscaled) gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                accumulated_batches = 0

            del batch_data
            torch.cuda.empty_cache()

        avg_loss = total_loss / max(num_train_batches, 1)

        model.eval()
        val_loss, val_preds, val_targets = _evaluate_in_batches(
            model, graph_data, graph_obj, val_indices, batch_size, device, criterion, use_amp
        )

        # Undo standard scaling, then the log1p transform, to report metrics in price units.
        val_pred_orig = np.expm1(scalers['target'].inverse_transform(val_preds))
        val_true_orig = np.expm1(scalers['target'].inverse_transform(val_targets))

        val_rmse = np.sqrt(mean_squared_error(val_true_orig, val_pred_orig))
        val_mae = mean_absolute_error(val_true_orig, val_pred_orig)

        print(f"Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}, Val Loss: {val_loss:.4f}, "
              f"RMSE: {val_rmse:.2f}, MAE: {val_mae:.2f}")

        scheduler.step(val_loss)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # state_dict() returns references to the live parameter tensors; a plain
            # dict.copy() would keep changing as training continues. Clone the tensors.
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            counter = 0
        else:
            counter += 1

        if counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

        torch.cuda.empty_cache()

    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return model, scalers

# Memory-efficient batched inference.
def predict_with_simplified_gnn(model, test_data, feature_groups, scalers, train_data, device, sequence_length=10, batch_size=256):
    """Predict prices on the original scale for every row of ``test_data``.

    A graph over ``train_data`` + ``test_data`` is rebuilt on CPU with
    ``prepare_simplified_graph_data``, and the test rows are scored in
    mini-batches of ``batch_size``. Model outputs are un-standardised with
    ``scalers['target']`` and then passed through ``np.expm1`` (the inverse of
    the log1p price transform). Mixed precision is used only on CUDA devices.

    Returns
    -------
    numpy.ndarray
        Predictions in test-row order, stacked along axis 0. Returns an empty
        ``(0, 1)`` array when ``test_data`` contributes no rows.
    """
    graph_data, _, _ = prepare_simplified_graph_data(
        train_data, test_data, feature_groups, 'cpu', sequence_length
    )

    device = torch.device(device)
    use_amp = device.type == 'cuda'

    # Test rows are the non-training rows (val_mask), in their original order.
    test_indices = torch.where(graph_data.val_mask)[0].cpu().numpy()
    if len(test_indices) == 0:
        return np.empty((0, 1))

    edge_index = graph_data.edge_index.to(device)
    edge_attr = graph_data.edge_attr.to(device)

    class GraphData:
        """View of ``graph_data`` whose ``edge_index``/``edge_attr`` are the device copies.

        All other attribute reads are forwarded to the wrapped object; ``to()`` is a no-op.
        """

        def __init__(self, data, edge_index, edge_attr):
            self.data = data
            self.edge_index = edge_index
            self.edge_attr = edge_attr

        def __getattr__(self, name):
            # Read 'data' via __dict__ so a partially initialised instance cannot recurse.
            data = self.__dict__.get('data')
            if data is not None and hasattr(data, name):
                return getattr(data, name)
            raise AttributeError(f"'GraphData' object has no attribute '{name}'")

        def to(self, device):
            return self

    graph_obj = GraphData(graph_data, edge_index, edge_attr)

    all_predictions = []
    model.eval()
    with torch.no_grad():
        for start in range(0, len(test_indices), batch_size):
            chunk = test_indices[start:start + batch_size]
            batch_data = create_batch_data(graph_data, chunk, device)

            with autocast(enabled=use_amp):
                batch_preds = model(graph_obj, batch_data['indices'])

            # Autocast outputs may be float16; upcast before the inverse transforms
            # so that expm1 is not computed in half precision.
            all_predictions.append(batch_preds.float().cpu().numpy())

            del batch_data
            torch.cuda.empty_cache()

    predictions = np.vstack(all_predictions)

    predictions_np = scalers['target'].inverse_transform(predictions)  # undo standard scaling
    predictions_orig = np.expm1(predictions_np)  # undo log1p

    return predictions_orig

# Rolling-window cross-validation driver for the memory-efficient GNN.
def _read_listing_ids(path):
    """Return the integer listing ids in the text file at *path*, one per line; blank lines are skipped."""
    with open(path, 'r') as f:
        return [int(line.strip()) for line in f if line.strip()]


def _fill_missing_values(df):
    """Fill NaNs per column: numeric columns with the median, other columns with the mode.

    Non-numeric columns with no mode (entirely missing) are left unchanged.
    Modifies and returns *df*.
    """
    nan_columns = df.columns[df.isna().any()].tolist()
    if not nan_columns:
        return df

    print(f"Warning: Found NaN values in columns: {nan_columns}")
    print("Filling NaN values with column medians (numeric) or modes (non-numeric)")

    for col in nan_columns:
        # is_numeric_dtype also accepts pandas extension dtypes, which np.issubdtype rejects.
        if pd.api.types.is_numeric_dtype(df[col]):
            df[col] = df[col].fillna(df[col].median())
        else:
            modes = df[col].mode()
            if not modes.empty:
                df[col] = df[col].fillna(modes.iloc[0])
    return df


def _build_feature_groups(df):
    """Return the GNN feature-group dict, keeping only columns that exist in *df*."""
    columns = set(df.columns)
    temporal = ['DTF_day_of_week', 'DTF_month', 'DTF_is_weekend', 'DTF_season_sin', 'DTF_season_cos']
    # Optional calendar features are appended only when the dataset has them.
    temporal += [feat for feat in ('DTF_day', 'DTF_is_holiday', 'DTF_days_to_weekend',
                                   'DTF_days_since_start', 'DTF_days_to_end_month')
                 if feat in columns]
    groups = {
        'spatial': ['latitude', 'longitude'],
        'property': ['accommodates', 'bedrooms', 'bathrooms', 'essential_score', 'luxury_score', 'amenity_count'],
        'amenity': [col for col in df.columns if col.startswith('has_')],
        'temporal': temporal,
    }
    return {name: [col for col in cols if col in columns] for name, cols in groups.items()}


def _make_test_periods(unique_dates, window_size, n_splits):
    """Split the last *window_size* dates into *n_splits* consecutive (start, end) periods.

    Each period spans ``window_size // n_splits`` dates. Trailing dates that do
    not fill a whole period are ignored. Returns ``[]`` when either argument is
    not positive or when ``window_size < n_splits`` (zero-length periods).
    """
    if window_size <= 0 or n_splits <= 0:
        return []
    period_len = window_size // n_splits
    if period_len == 0:
        return []

    window = unique_dates[-window_size:]
    periods = []
    for i in range(n_splits):
        start = i * period_len
        end = start + period_len
        if end <= len(window):
            periods.append((window[start], window[end - 1]))
    return periods


def _mape_percent(y_true, y_pred):
    """Mean absolute percentage error in percent; 1e-8 in the denominator avoids division by zero."""
    return np.mean(np.abs((y_true - y_pred) / (y_true + 1e-8))) * 100


def _daily_metrics(all_results):
    """Compute RMSE / MAE / R² / MAPE per ``date_str`` group of *all_results*, sorted by date.

    R² is NaN for days where every true price is identical.
    """
    rows = []
    for day, group in all_results.groupby('date_str'):
        y_true_day = group['price']
        y_pred_day = group['predicted']
        rows.append({
            'date': day,
            'rmse': np.sqrt(mean_squared_error(y_true_day, y_pred_day)),
            'mae': mean_absolute_error(y_true_day, y_pred_day),
            'r2': r2_score(y_true_day, y_pred_day) if len(set(y_true_day)) > 1 else np.nan,
            'mape': _mape_percent(y_true_day, y_pred_day),
            'n_samples': len(y_true_day),
        })
    daily_metrics_df = pd.DataFrame(rows)
    daily_metrics_df['date'] = pd.to_datetime(daily_metrics_df['date'])
    return daily_metrics_df.sort_values('date')


def run_simplified_gnn_with_rolling_window_cv(train_path, train_ids_path, test_ids_path, output_dir=None,
                                            window_size=35, n_splits=5, sample_size=None, sequence_length=10,
                                            batch_size=256, hidden_dim=32,
                                            start_date='2023-07-08', end_date='2024-02-08'):
    """Evaluate the simplified GNN with time-ordered, rolling-window cross-validation.

    Data preparation:

    - Load the CSV at ``train_path`` and the listing ids from ``train_ids_path``
      and ``test_ids_path``.
    - Keep rows whose ``date`` lies in [``start_date``, ``end_date``].
    - log1p-transform ``price``; the original stays in ``original_price``.
    - Add calculated features and fill NaNs.
    - Split the last ``window_size`` distinct dates into ``n_splits``
      consecutive test periods.

    For each period:

    - Train on rows dated before the period whose listing is a train id. 20%
      of those listings are held out for early stopping.
    - Evaluate on rows inside the period whose listing is a test id.
    - Skip splits with fewer than 100 train or 10 test rows. A split that
      raises is reported (with traceback) and skipped.

    Parameters
    ----------
    train_path : str
        CSV with at least ``date``, ``listing_id`` and ``price`` columns.
    train_ids_path, test_ids_path : str
        Text files with one integer listing id per line.
    output_dir : str, optional
        If given, each split's model ``state_dict`` is saved there.
    window_size, n_splits : int
        Number of trailing dates used for testing, and number of periods.
    sample_size : int, optional
        If set, randomly keep about 70% / 30% of this many train / test
        listings (seed 42), capped at the number available.
    sequence_length, batch_size, hidden_dim : int
        Forwarded to the training and prediction functions.
    start_date, end_date : str or datetime-like
        Inclusive date range kept from the CSV. The defaults reproduce the
        previously hard-coded window.

    Returns
    -------
    dict or None
        Keys ``overall_metrics``, ``split_metrics`` (DataFrame),
        ``daily_metrics`` (DataFrame), ``all_results`` (per-row DataFrame),
        ``train_listings``, ``test_listings`` and ``config``. Returns None if
        no split completed.
    """
    print(f"Processing dataset: {os.path.basename(train_path)}")

    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print("Loading data...")
    train_data = pd.read_csv(train_path)

    print("Loading train/test listing IDs...")
    train_listing_ids = _read_listing_ids(train_ids_path)
    test_listing_ids = _read_listing_ids(test_ids_path)

    print(f"Loaded {len(train_listing_ids)} train IDs and {len(test_listing_ids)} test IDs")

    # Drop legacy price columns if they exist.
    for col in ['price_lag_1d', 'simulated_price']:
        if col in train_data.columns:
            print(f"Dropping {col} column from the dataset")
            train_data = train_data.drop(col, axis=1)

    if sample_size:
        print(f"Limiting to {sample_size} random listings for testing")
        np.random.seed(42)
        # Cap the draw so replace=False cannot ask for more ids than exist.
        n_train_sample = min(int(sample_size * 0.7), len(train_listing_ids))
        n_test_sample = min(int(sample_size * 0.3), len(test_listing_ids))
        train_listing_ids = np.random.choice(train_listing_ids, n_train_sample, replace=False).tolist()
        test_listing_ids = np.random.choice(test_listing_ids, n_test_sample, replace=False).tolist()

    train_data['date'] = pd.to_datetime(train_data['date'])
    start_ts = pd.to_datetime(start_date)
    end_ts = pd.to_datetime(end_date)
    train_data = train_data[(train_data['date'] >= start_ts) & (train_data['date'] <= end_ts)]

    train_data = apply_price_transformation(train_data)

    print("Creating calculated features...")
    train_data = create_calculated_features(train_data)
    train_data = _fill_missing_values(train_data)

    feature_groups = _build_feature_groups(train_data)

    # .dt.date and isin are costly on large frames: compute them once, reuse in every split.
    row_dates = train_data['date'].dt.date
    is_train_listing = train_data['listing_id'].isin(train_listing_ids)
    is_test_listing = train_data['listing_id'].isin(test_listing_ids)

    unique_dates = sorted(row_dates.unique())
    test_periods = _make_test_periods(unique_dates, window_size, n_splits)
    n_splits = len(test_periods)

    print(f"Created {n_splits} test periods:")
    for i, (test_start, test_end) in enumerate(test_periods):
        print(f"  Period {i+1}: {test_start} to {test_end}")

    cv_results = []
    all_predictions = []
    all_targets = []
    split_metrics = []

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    if device.type == 'cuda':
        # Set the allocator option before any further CUDA call.
        # NOTE(review): it likely has no effect if an earlier notebook cell already
        # initialised CUDA — confirm against the PyTorch CUDA memory docs.
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
        torch.cuda.empty_cache()

    print("Using simplified GNN with transformer-based temporal processing")
    print(f"Sequence length: {sequence_length}, Batch size: {batch_size}, Hidden dim: {hidden_dim}")

    for i, (test_start, test_end) in enumerate(test_periods):
        print(f"\n===== Split {i+1}/{n_splits} =====")

        # Training covers everything strictly before the test period.
        train_end_date = (pd.to_datetime(test_start) - pd.Timedelta(days=1)).date()

        print(f"Training period: {unique_dates[0]} to {train_end_date}")
        print(f"Testing period: {test_start} to {test_end}")

        train_rows = (row_dates <= train_end_date) & is_train_listing
        test_rows = (row_dates >= test_start) & (row_dates <= test_end) & is_test_listing
        split_train_data = train_data[train_rows].copy()
        split_test_data = train_data[test_rows].copy()

        print(f"Train data: {len(split_train_data)} rows, {split_train_data['listing_id'].nunique()} unique listings")
        print(f"Test data: {len(split_test_data)} rows, {split_test_data['listing_id'].nunique()} unique listings")

        if len(split_train_data) < 100 or len(split_test_data) < 10:
            print(f"Insufficient data for split {i+1}, skipping")
            continue

        # Hold out 20% of training listings (not rows) for validation / early stopping.
        unique_train_listings = split_train_data['listing_id'].unique()
        train_listings, val_listings = train_test_split(
            unique_train_listings, test_size=0.2, random_state=42
        )

        train_subset = split_train_data[split_train_data['listing_id'].isin(train_listings)].copy()
        val_subset = split_train_data[split_train_data['listing_id'].isin(val_listings)].copy()

        try:
            print(f"\n----- Training Simplified GNN Model (Split {i+1}) -----")

            torch.cuda.empty_cache()

            model, scalers = train_simplified_gnn_model(
                train_subset, val_subset, feature_groups,
                device=device,
                hidden_dim=hidden_dim,
                epochs=30,
                lr=0.001,
                sequence_length=sequence_length,
                batch_size=batch_size,
            )

            print(f"\n----- Evaluating Simplified GNN on Test Data (Split {i+1}) -----")
            test_predictions = predict_with_simplified_gnn(
                model, split_test_data, feature_groups, scalers,
                train_subset, device, sequence_length=sequence_length,
                batch_size=batch_size,
            )

            # Compare against untransformed prices when they are available.
            test_actuals = (split_test_data['original_price'].values
                            if 'original_price' in split_test_data.columns
                            else split_test_data['price'].values)
            preds_flat = test_predictions.flatten()

            metrics = evaluate_gnn_predictions(test_actuals, preds_flat, print_results=True)

            print(f"Split {i+1} Results - RMSE: {metrics['rmse']:.4f}, MAE: {metrics['mae']:.4f}, R²: {metrics['r2']:.4f}")

            errors = test_actuals - preds_flat
            split_results = pd.DataFrame({
                'split': i,
                'date': split_test_data['date'],
                'listing_id': split_test_data['listing_id'],
                'price': test_actuals,
                'predicted': preds_flat,
                'error': errors,
                'abs_error': np.abs(errors),
                'pct_error': np.abs(errors / (test_actuals + 1e-8)) * 100,
            })

            cv_results.append(split_results)
            all_predictions.extend(preds_flat)
            all_targets.extend(test_actuals)

            if output_dir:
                model_path = os.path.join(output_dir, f'simplified_gnn_model_split_{i+1}.pt')
                torch.save(model.state_dict(), model_path)
                print(f"Model for split {i+1} saved to {model_path}")

            split_metrics.append({
                'split': i,
                'rmse': metrics['rmse'],
                'mae': metrics['mae'],
                'r2': metrics['r2'],
                'mape': metrics['mape'],
                'n_samples': len(test_actuals),
            })

            del model, scalers
            torch.cuda.empty_cache()

        except Exception as e:
            # Best effort: report the failing split and continue with the next one.
            print(f"Error in split {i+1}: {str(e)}")
            import traceback
            traceback.print_exc()
            continue

    if not cv_results:
        print("No valid splits completed. Check your data and parameters.")
        return None

    all_results = pd.concat(cv_results, ignore_index=True)

    all_targets_array = np.array(all_targets)
    all_predictions_array = np.array(all_predictions)

    overall_metrics = {
        'rmse': np.sqrt(mean_squared_error(all_targets_array, all_predictions_array)),
        'mae': mean_absolute_error(all_targets_array, all_predictions_array),
        'r2': r2_score(all_targets_array, all_predictions_array),
        'mape': _mape_percent(all_targets_array, all_predictions_array),
    }

    all_results['date_str'] = pd.to_datetime(all_results['date']).dt.strftime('%Y-%m-%d')
    daily_metrics_df = _daily_metrics(all_results)

    split_metrics_df = pd.DataFrame(split_metrics)

    evaluation_results = {
        'overall_metrics': overall_metrics,
        'split_metrics': split_metrics_df,
        'daily_metrics': daily_metrics_df,
        'all_results': all_results,
        'train_listings': len(train_listing_ids),
        'test_listings': len(test_listing_ids),
        'config': {
            'model_type': 'simplified_memory_efficient_gnn',
            'sequence_length': sequence_length,
            'batch_size': batch_size,
            'hidden_dim': hidden_dim,
        },
    }

    print("\n===== MEMORY-EFFICIENT TRANSFORMER-GNN MODEL SUMMARY =====")
    print(f"Using {len(train_listing_ids)} listings for training and {len(test_listing_ids)} listings for testing")

    print("\n=== Overall Metrics ===")
    print(f"RMSE: {overall_metrics['rmse']:.4f}")
    print(f"MAE: {overall_metrics['mae']:.4f}")
    print(f"R²: {overall_metrics['r2']:.4f}")
    print(f"MAPE: {overall_metrics['mape']:.4f}%")

    print("\n=== Split Performance ===")
    print(split_metrics_df[['split', 'rmse', 'mae', 'r2', 'n_samples']].to_string(index=False))

    return evaluation_results

if __name__ == "__main__":
    # Dataset location (machine-specific; edit to point at your copy).
    # A raw string cannot end in a backslash, hence the "\\" joins below.
    data_dir = r"C:\Users\mvk\Documents\DATA_school\thesis\Subset\top_price_changers_subset"
    train_path = data_dir + "\\train_up3.csv"
    train_ids_path = data_dir + "\\train_ids.txt"
    test_ids_path = data_dir + "\\test_ids.txt"

    output_dir = "./output/memory_efficient_gnn"
    os.makedirs(output_dir, exist_ok=True)

    try:
        results = run_simplified_gnn_with_rolling_window_cv(
            train_path=train_path,
            train_ids_path=train_ids_path,
            test_ids_path=test_ids_path,
            output_dir=output_dir,
            window_size=35,
            n_splits=5,
            sample_size=None,      # use the full dataset
            sequence_length=10,
            batch_size=16384,
            hidden_dim=32,
        )
        # The driver returns None when no split completed; don't report that as success.
        if results is None:
            print("Memory-efficient transformer-based GNN model finished without any completed split.")
        else:
            print("Memory-efficient transformer-based GNN model training completed successfully!")

    except Exception as e:
        print(f"Error running memory-efficient transformer-based GNN model: {str(e)}")
        import traceback
        traceback.print_exc()

Processing dataset: train_up3.csv
Loading data...
Loading train/test listing IDs...
Loaded 6291 train IDs and 1573 test IDs
Applying log transformation to price data
Creating calculated features...
Created 5 test periods:
  Period 1: 2024-01-05 to 2024-01-11
  Period 2: 2024-01-12 to 2024-01-18
  Period 3: 2024-01-19 to 2024-01-25
  Period 4: 2024-01-26 to 2024-02-01
  Period 5: 2024-02-02 to 2024-02-08
Using device: cuda
Using simplified GNN with transformer-based temporal processing
Sequence length: 10, Batch size: 16384, Hidden dim: 32

===== Split 1/5 =====
Training period: 2023-08-07 to 2024-01-04
Testing period: 2024-01-05 to 2024-01-11
Train data: 903142 rows, 6291 unique listings
Test data: 11011 rows, 1573 unique listings

----- Training Simplified GNN Model (Split 1) -----
Preparing graph data...
Creating price histories...
Building spatial graph...
Building enhanced spatial graph with 180551 test listings and 3 nearest neighbors...
Created graph with 1083306 edges
Input dime

KeyboardInterrupt: 