# In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import Dataset, DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.neighbors import NearestNeighbors
import matplotlib.dates as mdates
from torch_geometric.nn import GATv2Conv
from torch_geometric.data import Data
import warnings
warnings.filterwarnings('ignore')
import math
import gc


# ====================== Utility Functions ======================
# Price transformation function
def apply_price_transformation(data, inverse=False):
    """
    Log-transform the ``price`` column, or undo that transform.

    Forward (``inverse=False``): the untouched values are kept in a new
    ``original_price`` column and ``price`` becomes ``log1p(price)``.
    Inverse (``inverse=True``): ``price`` becomes ``expm1(price)``; any
    ``original_price`` column is left unchanged.

    The input frame is never modified; a transformed copy is returned.
    """
    result = data.copy()

    if inverse:
        print("Inverting log transformation for predictions")
        result['price'] = np.expm1(result['price'])  # exact inverse of log1p
        return result

    print("Applying log transformation to price data")
    result['original_price'] = result['price']
    # log1p keeps zero prices finite (log1p(0) == 0)
    result['price'] = np.log1p(result['price'])
    return result


# Create calculated features
def create_calculated_features(df):
    """
    Return a copy of ``df`` with derived listing features added.

    Added columns:
      * ``bedroom_ratio``   - bedrooms / accommodates (accommodates floored at 1);
                              only when both source columns exist
      * ``amenity_count``   - row sum of every column whose name contains ``has_``;
                              only when at least one such column exists
      * ``luxury_score``    - mean of whichever "luxury" amenity flags exist (0 if none)
      * ``essential_score`` - mean of whichever "essential" amenity flags exist (0 if none)

    Finally, every numeric column that has missing values is filled with that
    column's median. The input frame is not modified.
    """
    luxury_amenities = ('has_hot_water', 'has_hair_dryer', 'has_dedicated_workspace',
                        'has_tv', 'has_wifi', 'has_shampoo')
    essential_amenities = ('has_essentials', 'has_bed_linens', 'has_kitchen',
                           'has_smoke_alarm', 'has_heating')

    result = df.copy()

    def mean_flag_score(candidates):
        # Average of the candidate columns that are present; 0 when none are
        present = [name for name in candidates if name in result.columns]
        if not present:
            return 0
        return result[present].sum(axis=1) / len(present)

    if 'bedrooms' in result.columns and 'accommodates' in result.columns:
        # clip(lower=1) avoids dividing by listings that report zero guests
        result['bedroom_ratio'] = result['bedrooms'] / result['accommodates'].clip(lower=1)

    flag_columns = result.filter(like='has_').columns
    if len(flag_columns) > 0:
        result['amenity_count'] = result[flag_columns].sum(axis=1)

    result['luxury_score'] = mean_flag_score(luxury_amenities)
    result['essential_score'] = mean_flag_score(essential_amenities)

    # Median-impute numeric gaps (including any NaN produced above)
    numeric_with_gaps = [
        col for col in result.select_dtypes(include=['number']).columns
        if result[col].isna().any()
    ]
    for col in numeric_with_gaps:
        result[col] = result[col].fillna(result[col].median())

    return result


# Function to evaluate predictions
def evaluate_predictions(y_true, y_pred, print_results=True):
    """
    Evaluate predictions using RMSE, MAE, R² and MAPE.

    Args:
        y_true: ground-truth values; any array-like (list, Series, ndarray of
            shape (n,) or (n, 1)). Flattened to 1-D.
        y_pred: predicted values; same conventions as ``y_true``.
        print_results: if True, print a formatted summary.

    Returns:
        dict with keys ``'rmse'``, ``'mae'``, ``'r2'`` and ``'mape'``
        (MAPE in percent, with a 1e-8 offset in the denominator).

    Raises:
        ValueError: if the flattened inputs have different lengths.
    """
    # Flatten to 1-D float arrays. Without this, a (n, 1) column against a (n,)
    # vector broadcasts to (n, n) in the MAPE arithmetic below and silently yields
    # a wrong number, and plain Python lists don't support `-` at all.
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true and y_pred have different lengths: {y_true.size} vs {y_pred.size}"
        )

    # Calculate metrics
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    mae = mean_absolute_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)
    # 1e-8 guards against division by an exact zero target
    mape = np.mean(np.abs((y_true - y_pred) / (y_true + 1e-8))) * 100

    metrics = {
        'rmse': rmse,
        'mae': mae,
        'r2': r2,
        'mape': mape
    }

    if print_results:
        print("=== Model Evaluation ===")
        print(f"RMSE: {rmse:.2f}")
        print(f"MAE: {mae:.2f}")
        print(f"R²: {r2:.4f}")
        print(f"MAPE: {mape:.2f}%")

    return metrics


# Function to plot results
def plot_results(y_true, y_pred, history=None, output_dir=None):
    """
    Draw a 2x2 diagnostic figure for price predictions and show it.

    Panels: actual vs predicted (with Pearson correlation), error histogram
    (with mean/median markers), error vs actual, and absolute percentage-error
    histogram (with mean/median markers).

    Args:
        y_true: actual values; any array-like, flattened to 1-D.
        y_pred: predicted values; any array-like, flattened to 1-D.
        history: accepted for API compatibility; not plotted by this function.
        output_dir: if given, the figure is saved there as
            ``prediction_results.png`` (the directory is created if missing).
    """
    # Flatten so (n, 1) model outputs line up with (n,) targets; otherwise
    # `y_true - y_pred` broadcasts to (n, n) and np.corrcoef treats each row
    # as a separate variable.
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()

    plt.figure(figsize=(15, 10))

    # Plot 1: Actual vs Predicted
    plt.subplot(2, 2, 1)
    plt.scatter(y_true, y_pred, alpha=0.5)
    min_val = min(np.min(y_true), np.min(y_pred))
    max_val = max(np.max(y_true), np.max(y_pred))
    plt.plot([min_val, max_val], [min_val, max_val], 'r--')  # identity line
    plt.title('Actual vs Predicted Prices')
    plt.xlabel('Actual Price')
    plt.ylabel('Predicted Price')

    corr = np.corrcoef(y_true, y_pred)[0, 1]
    plt.annotate(f'Correlation: {corr:.4f}', xy=(0.05, 0.95), xycoords='axes fraction',
                 bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="gray", alpha=0.8))

    # Plot 2: Error Distribution
    plt.subplot(2, 2, 2)
    errors = y_true - y_pred
    plt.hist(errors, bins=50, alpha=0.7)
    plt.axvline(0, color='r', linestyle='--')
    plt.title('Error Distribution')
    plt.xlabel('Error (Actual - Predicted)')
    plt.ylabel('Frequency')

    mean_error = np.mean(errors)
    median_error = np.median(errors)
    plt.axvline(mean_error, color='g', linestyle='-', label=f'Mean: {mean_error:.2f}')
    plt.axvline(median_error, color='b', linestyle='-', label=f'Median: {median_error:.2f}')
    plt.legend()

    # Plot 3: Error vs Actual Price
    plt.subplot(2, 2, 3)
    plt.scatter(y_true, errors, alpha=0.5)
    plt.axhline(0, color='r', linestyle='--')
    plt.title('Error vs Actual Price')
    plt.xlabel('Actual Price')
    plt.ylabel('Error')

    # Plot 4: Percentage Error Distribution (1e-8 guards zero targets)
    plt.subplot(2, 2, 4)
    pct_errors = np.abs(errors / (y_true + 1e-8)) * 100
    plt.hist(pct_errors, bins=50, alpha=0.7)

    median_pct = np.median(pct_errors)
    mean_pct = np.mean(pct_errors)
    plt.axvline(median_pct, color='r', linestyle='--', label=f'Median: {median_pct:.2f}%')
    plt.axvline(mean_pct, color='g', linestyle='--', label=f'Mean: {mean_pct:.2f}%')

    plt.title('Percentage Error Distribution')
    plt.xlabel('Percentage Error')
    plt.ylabel('Frequency')
    plt.legend()

    plt.tight_layout()

    if output_dir:
        # savefig does not create missing directories
        os.makedirs(output_dir, exist_ok=True)
        out_path = os.path.join(output_dir, 'prediction_results.png')
        plt.savefig(out_path)
        print(f"Plot saved to {out_path}")

    plt.show()


# ====================== Neighbor Selection Functions ======================
def select_neighbors(data, k=5, feature_weight=0.3):
    """
    Select k nearest neighbors for each row based on geographic and feature similarity.

    Candidates are the geographically closest rows (Euclidean distance on raw
    latitude/longitude). Each candidate is scored as

        (1 - feature_weight) / (geo_distance + 1e-6) + feature_weight * max(0, cos_sim)

    where ``cos_sim`` is the cosine similarity of the standardized
    accommodates/bedrooms/bathrooms values (whichever are present), and the
    ``k`` best-scoring candidates are kept.

    Args:
        data: DataFrame with ``latitude``/``longitude`` columns and, optionally,
            ``accommodates``/``bedrooms``/``bathrooms``.
        k: number of neighbors per row.
        feature_weight: weight of the feature-similarity term.

    Returns:
        dict mapping each positional row index to a list of positional neighbor
        indices, best score first. A row is never listed as its own neighbor.
        Empty dict for an empty frame.
    """
    print(f"Selecting {k} neighbors for each listing...")

    neighbors_dict = {}
    n_rows = len(data)
    if n_rows == 0:
        # NearestNeighbors / StandardScaler cannot be fit on zero rows
        print(f"Selected neighbors for {len(neighbors_dict)} listings")
        return neighbors_dict

    coords = data[['latitude', 'longitude']].values

    features = ['accommodates', 'bedrooms', 'bathrooms']
    available_features = [f for f in features if f in data.columns]
    if available_features:
        scaler = StandardScaler()
        feature_values = scaler.fit_transform(data[available_features].fillna(0))
    else:
        print("Warning: No property features available for similarity calculation")
        feature_values = np.ones((len(coords), 1))
    feature_norms = np.linalg.norm(feature_values, axis=1)

    # One extra candidate because each row is normally returned as its own match
    nn_model = NearestNeighbors(n_neighbors=min(k + 1, n_rows))
    nn_model.fit(coords)
    distances, indices = nn_model.kneighbors(coords)

    for i, (cand_idx, cand_dist) in enumerate(zip(indices, distances)):
        # Remove the row itself by index rather than by "first result at distance 0":
        # with duplicate coordinates another row may come first and the row itself
        # would otherwise survive as its own neighbor.
        not_self = cand_idx != i
        cand_idx = cand_idx[not_self][:k]
        cand_dist = cand_dist[not_self][:k]

        scored = []
        for j, geo_distance in zip(cand_idx, cand_dist):
            norm_product = feature_norms[i] * feature_norms[j]
            if norm_product > 1e-8:
                feat_sim = np.dot(feature_values[i], feature_values[j]) / norm_product
            else:
                feat_sim = 0.0
            geo_weight = 1.0 / (geo_distance + 1e-6)  # inverse distance
            score = (1 - feature_weight) * geo_weight + feature_weight * max(0, feat_sim)
            scored.append((int(j), score))

        # Stable sort: equal scores keep geographic order
        scored.sort(key=lambda pair: pair[1], reverse=True)
        neighbors_dict[i] = [j for j, _ in scored[:k]]

    print(f"Selected neighbors for {len(neighbors_dict)} listings")
    return neighbors_dict


def collect_neighbor_price_history(data, listing_ids, neighbors_dict, seq_length=30):
    """
    Collect historical neighbor prices for every (listing, date) pair.

    ``neighbors_dict`` is read as: position of a listing in ``listing_ids`` ->
    list of positions (in ``listing_ids``) of its neighbors. Neighbor positions
    that are not valid positions in ``listing_ids`` are skipped.

    For each listing with an entry in ``neighbors_dict`` and each distinct date on
    which it has rows in ``data``, every neighbor contributes a length-``seq_length``
    vector. It holds that neighbor's prices strictly before the date, most recent
    first. If fewer than ``seq_length`` prices exist, it is padded at the end with
    the most recent price. If none exist, it is all zeros.

    Args:
        data: DataFrame with ``listing_id``, ``date`` and ``price`` columns.
        listing_ids: sequence of listing ids defining the positions used by
            ``neighbors_dict``.
        neighbors_dict: dict as described above.
        seq_length: history length per neighbor.

    Returns:
        dict mapping ``(listing_id, date)`` to an array of shape
        ``[num_valid_neighbors, seq_length]``.
    """
    print(f"Collecting price history for neighbors (sequence length: {seq_length})...")

    idx_to_listing_id = dict(enumerate(listing_ids))

    # Split the frame once (O(N log N)) instead of rescanning the full frame
    # for every listing, date and neighbor.
    history_by_listing = {
        lid: group.sort_values('date', kind='mergesort')
        for lid, group in data.groupby('listing_id', sort=False)
    }

    neighbor_histories = {}

    for idx, listing_id in enumerate(listing_ids):
        if idx not in neighbors_dict:
            continue
        own_rows = history_by_listing.get(listing_id)
        if own_rows is None:
            continue  # listing has no rows, hence no dates to describe

        neighbors = neighbors_dict[idx]
        # Group is date-sorted, so unique() yields ascending distinct dates
        for date in own_rows['date'].unique():
            neighbor_prices = []

            for neighbor_idx in neighbors:
                neighbor_id = idx_to_listing_id.get(neighbor_idx)
                if neighbor_id is None:
                    continue

                neighbor_rows = history_by_listing.get(neighbor_id)
                if neighbor_rows is None:
                    prices = np.array([])
                else:
                    prior = neighbor_rows[neighbor_rows['date'] < date]
                    # Most recent first, at most seq_length values
                    prices = prior['price'].to_numpy()[::-1][:seq_length]

                if len(prices) > 0:
                    padded = np.pad(prices, (0, max(0, seq_length - len(prices))),
                                    'constant', constant_values=prices[0])
                    neighbor_prices.append(padded[:seq_length])
                else:
                    neighbor_prices.append(np.zeros(seq_length))

            neighbor_histories[(listing_id, date)] = np.array(neighbor_prices)

    print(f"Collected price histories for {len(neighbor_histories)} listing-date combinations")
    return neighbor_histories


# Build enhanced spatial graph for GNN
def build_spatial_graph(train_data, test_data, k=5, feature_weight=0.3):
    """
    Build weighted graph edges from geographic and feature similarity.

    Node numbering: train rows are nodes ``0 .. len(train_data) - 1`` and test row
    ``t`` is node ``len(train_data) + t``.

    Edges:
      * each test node <-> its ``k`` geographically nearest train nodes (both directions);
      * when ``len(train_data) <= 5000``, each train node -> its nearest other
        train nodes (see the note at that step).

    Edge weight:
        (1 - feature_weight) / (geo_distance + 1e-6) + feature_weight * max(0, cos_sim)
    where geo_distance is the Euclidean distance on raw latitude/longitude and
    cos_sim is the cosine similarity of standardized accommodates/bedrooms/bathrooms
    (those present in both frames; the scaler is fit on train only).

    Returns:
        (edge_index, edge_attr): LongTensor of shape [2, E] and FloatTensor of
        shape [E, 1]. E may be 0.
    """
    train_coords = train_data[['latitude', 'longitude']].values
    test_coords = test_data[['latitude', 'longitude']].values
    n_train = len(train_coords)
    n_test = len(test_coords)

    print(f"Building spatial graph with {len(test_coords)} test listings and {k} nearest neighbors...")

    if n_train == 0:
        # Nothing to connect to; NearestNeighbors/StandardScaler cannot fit zero rows
        edge_index_tensor = torch.empty((2, 0), dtype=torch.long)
        edge_attr_tensor = torch.empty((0, 1), dtype=torch.float32)
        print(f"Created graph with {edge_index_tensor.shape[1]} edges")
        return edge_index_tensor, edge_attr_tensor

    features = ['accommodates', 'bedrooms', 'bathrooms']
    # Require each column in BOTH frames so scaler.transform(test) cannot KeyError
    available_features = [f for f in features
                          if f in train_data.columns and f in test_data.columns]

    if available_features:
        scaler = StandardScaler()
        train_features = scaler.fit_transform(train_data[available_features].fillna(0))
        test_features = scaler.transform(test_data[available_features].fillna(0))
    else:
        print("Warning: No property features available for similarity calculation")
        train_features = np.ones((n_train, 1))
        test_features = np.ones((n_test, 1))

    train_norms = np.linalg.norm(train_features, axis=1)
    test_norms = np.linalg.norm(test_features, axis=1)

    def combined_weight(feat_a, norm_a, feat_b, norm_b, distance):
        # Inverse geographic distance blended with clipped cosine feature similarity
        norm_product = norm_a * norm_b
        feat_sim = np.dot(feat_a, feat_b) / norm_product if norm_product > 1e-8 else 0.0
        geo_weight = 1.0 / (distance + 1e-6)
        return (1 - feature_weight) * geo_weight + feature_weight * max(0, feat_sim)

    edge_index = []
    edge_attr = []

    # Test <-> train edges (added in both directions so messages flow both ways)
    if n_test > 0:
        nn_model = NearestNeighbors(n_neighbors=min(k, n_train))
        nn_model.fit(train_coords)
        distances, indices = nn_model.kneighbors(test_coords)

        for test_idx, (nbr_indices, nbr_distances) in enumerate(zip(indices, distances)):
            test_node = n_train + test_idx
            for train_idx, distance in zip(nbr_indices, nbr_distances):
                weight = combined_weight(test_features[test_idx], test_norms[test_idx],
                                         train_features[train_idx], train_norms[train_idx],
                                         distance)
                edge_index.append([test_node, train_idx])
                edge_attr.append([weight])
                edge_index.append([train_idx, test_node])
                edge_attr.append([weight])

    # Train -> train edges, only for smaller training sets. NOTE: the query asks for
    # min(5, n_train - 1) results *including the point itself*, which is then skipped,
    # so a node typically gets one fewer train-train edge than requested.
    # Needs at least 2 training rows (n_neighbors must be >= 1).
    if 2 <= n_train <= 5000:
        train_nn = NearestNeighbors(n_neighbors=min(5, n_train - 1))
        train_nn.fit(train_coords)
        train_distances, train_indices = train_nn.kneighbors(train_coords)

        for train_idx, (nbr_indices, nbr_distances) in enumerate(zip(train_indices, train_distances)):
            for neighbor_idx, distance in zip(nbr_indices, nbr_distances):
                if train_idx == neighbor_idx:  # skip self-loops
                    continue
                weight = combined_weight(train_features[train_idx], train_norms[train_idx],
                                         train_features[neighbor_idx], train_norms[neighbor_idx],
                                         distance)
                edge_index.append([train_idx, neighbor_idx])
                edge_attr.append([weight])

    # reshape keeps the documented [2, E] / [E, 1] shapes even when E == 0
    edge_index_tensor = torch.tensor(edge_index, dtype=torch.long).reshape(-1, 2).t().contiguous()
    edge_attr_tensor = torch.tensor(edge_attr, dtype=torch.float32).reshape(-1, 1)

    print(f"Created graph with {edge_index_tensor.shape[1]} edges")

    return edge_index_tensor, edge_attr_tensor


# ====================== Model Definition ======================
class LightweightLSTMEncoder(nn.Module):
    """
    Encode a univariate sequence into a fixed-size vector.

    A single-layer (optionally bidirectional) LSTM reads the sequence. Its final
    hidden state (both directions concatenated when bidirectional) is linearly
    projected to ``output_dim``.

    Args:
        hidden_dim: LSTM hidden size per direction.
        output_dim: size of the returned embedding.
        bidirectional: whether the LSTM runs in both directions.

    Input:  tensor of shape [batch, seq_len] or [batch, seq_len, 1].
    Output: tensor of shape [batch, output_dim].
    """

    def __init__(self, hidden_dim=8, output_dim=64, bidirectional=True):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.bidirectional = bidirectional

        # One input channel per timestep (the scalar value, e.g. a price)
        self.lstm = nn.LSTM(
            input_size=1,
            hidden_size=hidden_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=bidirectional,
        )

        directions = 2 if bidirectional else 1
        self.projection = nn.Linear(hidden_dim * directions, output_dim)

    def forward(self, x):
        # Promote [batch, seq_len] to the [batch, seq_len, 1] layout the LSTM expects
        seq = x.unsqueeze(-1) if x.dim() == 2 else x

        _, (h_n, _) = self.lstm(seq)

        # h_n: [num_layers * directions, batch, hidden_dim]; with one layer the
        # last two entries are the forward and backward final states.
        if self.bidirectional:
            final_state = torch.cat((h_n[-2], h_n[-1]), dim=1)
        else:
            final_state = h_n[-1]

        return self.projection(final_state)


class NeighborPricePredictor(nn.Module):
    """
    GNN price model that fuses spatial, temporal, amenity and neighbor-history features.

    Feature sources (each mapped to ``hidden_dim``):
      * spatial  - node features through two GATv2 layers (second with a residual),
      * temporal - two-layer MLP with a residual connection,
      * amenity  - two-layer MLP with a residual connection,
      * neighbor - each neighbor slot's price history encoded by a shared
                   LightweightLSTMEncoder, then combined with learned softmax
                   weights over the ``num_neighbors`` slots.
    The four vectors are mixed with per-dimension softmax weights (plus a bias)
    and passed through a three-layer MLP head producing one value per node.

    Args:
        spatial_features_dim: width of ``data.x``.
        temporal_features_dim: width of ``data.temporal_x``.
        amenity_features_dim: width of ``data.amenity_x``.
        num_neighbors: number of neighbor history slots used.
        seq_length: stored for reference; the forward pass uses the
            history length of the input tensor.
        lstm_hidden_dim: hidden size of the neighbor LSTM per direction.
        hidden_dim: shared embedding width.
        dropout: dropout rate (the head's second dropout uses ``dropout + 0.1``).
        heads: number of GATv2 attention heads.
        edge_dim: width of ``data.edge_attr``.
    """
    def __init__(self, 
                 spatial_features_dim,
                 temporal_features_dim,
                 amenity_features_dim,
                 num_neighbors=5,
                 seq_length=30,
                 lstm_hidden_dim=8,  # Lightweight LSTM dimension
                 hidden_dim=64,
                 dropout=0.3,
                 heads=4,
                 edge_dim=1):
        super(NeighborPricePredictor, self).__init__()

        self.num_neighbors = num_neighbors
        self.seq_length = seq_length
        self.lstm_hidden_dim = lstm_hidden_dim

        # Multi-head attention sizing; gat_out_dim < hidden_dim when
        # hidden_dim is not divisible by heads (handled by dim_adjust below)
        self.h_dim = hidden_dim
        self.heads = heads
        self.head_dim = hidden_dim // heads

        # GAT layers for spatial features (heads are concatenated)
        gat_out_dim = self.head_dim * heads
        self.gat1 = GATv2Conv(spatial_features_dim, self.head_dim, heads=heads, edge_dim=edge_dim)
        self.gat2 = GATv2Conv(gat_out_dim, self.head_dim, heads=heads, edge_dim=edge_dim)

        self.bn1 = nn.BatchNorm1d(gat_out_dim)
        self.bn2 = nn.BatchNorm1d(gat_out_dim)

        # Temporal feature processing
        self.temporal_layer1 = nn.Linear(temporal_features_dim, hidden_dim)
        self.temporal_bn1 = nn.BatchNorm1d(hidden_dim)
        self.temporal_layer2 = nn.Linear(hidden_dim, hidden_dim)
        self.temporal_bn2 = nn.BatchNorm1d(hidden_dim)

        # Amenity feature processing
        self.amenity_layer1 = nn.Linear(amenity_features_dim, hidden_dim)
        self.amenity_bn1 = nn.BatchNorm1d(hidden_dim)
        self.amenity_layer2 = nn.Linear(hidden_dim, hidden_dim)
        self.amenity_bn2 = nn.BatchNorm1d(hidden_dim)

        # Shared encoder for every neighbor's price history
        self.neighbor_lstm = LightweightLSTMEncoder(
            hidden_dim=lstm_hidden_dim,
            output_dim=hidden_dim,
            bidirectional=True
        )

        # One learnable logit per neighbor slot (softmax-normalized in forward)
        self.neighbor_attention = nn.Parameter(torch.ones(num_neighbors, 1))

        # Per-dimension fusion logits for: spatial, temporal, amenity, neighbor
        self.fusion_weights = nn.Parameter(torch.ones(4, hidden_dim))
        self.fusion_bias = nn.Parameter(torch.zeros(hidden_dim))

        # Final prediction layers
        self.fc1 = nn.Linear(hidden_dim, hidden_dim * 2)
        self.fc1_bn = nn.BatchNorm1d(hidden_dim * 2)
        self.fc2 = nn.Linear(hidden_dim * 2, hidden_dim)
        self.fc2_bn = nn.BatchNorm1d(hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)

        self.dropout = nn.Dropout(dropout)
        self.dropout_heavy = nn.Dropout(dropout + 0.1)

        # Projects GAT output to hidden_dim when heads do not divide hidden_dim
        self.dim_adjust = None
        if gat_out_dim != hidden_dim:
            self.dim_adjust = nn.Linear(gat_out_dim, hidden_dim)

    def forward(self, data):
        """
        Args:
            data: graph object exposing ``x``, ``edge_index``, ``edge_attr``,
                ``temporal_x``, ``amenity_x`` and ``neighbor_histories``. The last is
                indexed as [nodes, neighbor_slot, time] and must have at least
                ``num_neighbors`` slots; extra slots are ignored.

        Returns:
            Tensor of shape [num_nodes, 1] with one prediction per node.
        """
        x, edge_index, edge_attr, temporal_x, amenity_x, neighbor_histories = (
            data.x, data.edge_index, data.edge_attr, data.temporal_x, 
            data.amenity_x, data.neighbor_histories
        )

        # Spatial features via GAT
        spatial_features = self.gat1(x, edge_index, edge_attr=edge_attr)
        spatial_features = F.elu(spatial_features)
        spatial_features = self.bn1(spatial_features)
        spatial_features = self.dropout(spatial_features)

        spatial_features_res = spatial_features
        spatial_features = self.gat2(spatial_features, edge_index, edge_attr=edge_attr)
        spatial_features = self.bn2(spatial_features)

        # Residual connection when shapes agree
        if spatial_features.shape == spatial_features_res.shape:
            spatial_features = spatial_features + spatial_features_res

        if self.dim_adjust is not None:
            spatial_features = self.dim_adjust(spatial_features)

        # Temporal features
        temporal_features = F.elu(self.temporal_layer1(temporal_x))
        temporal_features = self.temporal_bn1(temporal_features)
        temporal_features = self.dropout(temporal_features)
        temporal_features_res = temporal_features
        temporal_features = F.elu(self.temporal_layer2(temporal_features))
        temporal_features = self.temporal_bn2(temporal_features)
        temporal_features = temporal_features + temporal_features_res

        # Amenity features
        amenity_features = F.elu(self.amenity_layer1(amenity_x))
        amenity_features = self.amenity_bn1(amenity_features)
        amenity_features = self.dropout(amenity_features)
        amenity_features_res = amenity_features
        amenity_features = F.elu(self.amenity_layer2(amenity_features))
        amenity_features = self.amenity_bn2(amenity_features)
        amenity_features = amenity_features + amenity_features_res

        # Neighbor histories: encode all slots of all nodes in ONE LSTM call
        # ([N, K, T] -> [N*K, T]) instead of K separate calls. The encoder is shared
        # across slots, so this is mathematically equivalent to a per-slot loop.
        histories = neighbor_histories[:, :self.num_neighbors, :]
        num_nodes, num_slots, hist_len = histories.shape
        encoded = self.neighbor_lstm(histories.reshape(num_nodes * num_slots, hist_len))
        encoded = encoded.reshape(num_nodes, num_slots, encoded.shape[-1])

        # Softmax over slots: [K, 1] -> broadcast as [1, K, 1], then weighted sum
        normalized_attention = F.softmax(self.neighbor_attention, dim=0)
        neighbor_features = (encoded * normalized_attention.unsqueeze(0)).sum(dim=1)

        # Per-dimension softmax across the four feature sources
        normalized_weights = F.softmax(self.fusion_weights, dim=0)
        fused_features = (
            spatial_features * normalized_weights[0] +
            temporal_features * normalized_weights[1] +
            amenity_features * normalized_weights[2] +
            neighbor_features * normalized_weights[3] +
            self.fusion_bias
        )

        # Prediction head
        out = F.elu(self.fc1(fused_features))
        out = self.fc1_bn(out)
        out = self.dropout(out)

        out = F.elu(self.fc2(out))
        out = self.fc2_bn(out)
        out = self.dropout_heavy(out)

        price_prediction = self.fc3(out)

        return price_prediction


# ====================== Data Preparation Functions ======================
def _select_listing_neighbors(combined_data, num_neighbors):
    """Select neighbors among unique listings; returns (all_listings, neighbors_dict)."""
    # One row per listing: otherwise a listing's own dated rows (same coordinates,
    # distance 0) crowd out real neighbors, and select_neighbors' row positions
    # would not line up with positions in all_listings.
    listing_table = combined_data.drop_duplicates('listing_id').reset_index(drop=True)
    all_listings = listing_table['listing_id'].to_numpy()
    neighbors_dict = select_neighbors(listing_table, k=num_neighbors)
    return all_listings, neighbors_dict


def _build_neighbor_history_array(combined_data, all_listings, neighbors_dict,
                                  num_neighbors, seq_length):
    """Return raw neighbor prices as a float32 array [rows, num_neighbors, seq_length]."""
    histories = np.zeros((len(combined_data), num_neighbors, seq_length), dtype=np.float32)

    # Date-sorted (dates, prices) per listing, built once so each lookup is a
    # binary search instead of a full-frame scan.
    by_listing = {}
    for lid, group in combined_data.groupby('listing_id', sort=False):
        group = group.sort_values('date', kind='mergesort')
        by_listing[lid] = (group['date'].to_numpy(), group['price'].to_numpy())

    listing_pos = {lid: pos for pos, lid in enumerate(all_listings)}
    row_listings = combined_data['listing_id'].to_numpy()
    row_dates = combined_data['date'].to_numpy()

    for row, (lid, date) in enumerate(zip(row_listings, row_dates)):
        pos = listing_pos.get(lid)
        if pos is None or pd.isna(date):
            continue
        for n_idx, neighbor_pos in enumerate(neighbors_dict.get(pos, [])[:num_neighbors]):
            neighbor = by_listing.get(all_listings[neighbor_pos])
            if neighbor is None:
                continue
            dates, prices = neighbor
            # Number of neighbor observations strictly before this row's date
            n_prior = np.searchsorted(dates, date, side='left')
            if n_prior == 0:
                continue  # no history: slot stays zero
            # Most recent first, at most seq_length values
            window = prices[max(0, n_prior - seq_length):n_prior][::-1]
            histories[row, n_idx, :len(window)] = window
            if len(window) < seq_length:
                # Pad with the last value of the window (its oldest price)
                histories[row, n_idx, len(window):] = window[-1]

    return histories


def prepare_neighbor_data(train_data, val_data, spatial_features, temporal_features, 
                         amenity_features, num_neighbors=5, seq_length=30, device='cuda'):
    """
    Prepare a PyG ``Data`` object for the neighbor-based price prediction model.

    Rows of ``train_data`` become nodes ``0 .. len(train_data) - 1`` and rows of
    ``val_data`` follow. All feature/target scalers, including the neighbor-history
    scaler, are fit on training rows only.

    Neighbors are selected between unique listings (one row per ``listing_id``).
    For every row, each of up to ``num_neighbors`` neighbor slots holds that
    neighbor's prices strictly before the row's date, most recent first. Slots with
    fewer than ``seq_length`` prices are padded with the last value in the window;
    slots with no history are zero before scaling.

    Returns:
        (data, spatial_scaler, temporal_scaler, amenity_scaler, neighbor_scaler,
        target_scaler, neighbors_dict). ``neighbors_dict`` maps a listing's
        position among the unique listing ids (first-appearance order in
        train then val) to the positions of its neighbor listings.
    """
    print("Preparing data for neighbor-based price prediction...")

    spatial_scaler = StandardScaler()
    temporal_scaler = StandardScaler()
    amenity_scaler = StandardScaler()
    target_scaler = StandardScaler()

    # Fit scalers on training data only
    spatial_scaler.fit(train_data[spatial_features])
    temporal_scaler.fit(train_data[temporal_features])
    amenity_scaler.fit(train_data[amenity_features])
    target_scaler.fit(train_data['price'].values.reshape(-1, 1))

    X_train_spatial = spatial_scaler.transform(train_data[spatial_features]).astype(np.float32)
    X_val_spatial = spatial_scaler.transform(val_data[spatial_features]).astype(np.float32)

    X_train_temporal = temporal_scaler.transform(train_data[temporal_features]).astype(np.float32)
    X_val_temporal = temporal_scaler.transform(val_data[temporal_features]).astype(np.float32)

    X_train_amenity = amenity_scaler.transform(train_data[amenity_features]).astype(np.float32)
    X_val_amenity = amenity_scaler.transform(val_data[amenity_features]).astype(np.float32)

    y_train = target_scaler.transform(train_data['price'].values.reshape(-1, 1)).flatten().astype(np.float32)
    y_val = target_scaler.transform(val_data['price'].values.reshape(-1, 1)).flatten().astype(np.float32)

    # Node features: train rows first, then val rows
    X_combined_spatial = np.vstack([X_train_spatial, X_val_spatial])
    X_combined_temporal = np.vstack([X_train_temporal, X_val_temporal])
    X_combined_amenity = np.vstack([X_train_amenity, X_val_amenity])

    # Positional index 0..N-1 in the same train-then-val row order
    combined_data = pd.concat([train_data, val_data], ignore_index=True)

    all_listings, neighbors_dict = _select_listing_neighbors(combined_data, num_neighbors)

    edge_index, edge_attr = build_spatial_graph(
        train_data[['latitude', 'longitude'] + [f for f in ['accommodates', 'bedrooms', 'bathrooms'] if f in train_data.columns]], 
        val_data[['latitude', 'longitude'] + [f for f in ['accommodates', 'bedrooms', 'bathrooms'] if f in val_data.columns]], 
        k=num_neighbors
    )

    print("Collecting price histories from neighbors...")
    X_neighbor_histories = _build_neighbor_history_array(
        combined_data, all_listings, neighbors_dict, num_neighbors, seq_length
    )

    # Validation targets live in val_y; zeros here are placeholders
    y_combined = np.zeros(len(X_combined_spatial), dtype=np.float32)
    y_combined[:len(y_train)] = y_train

    # Scale histories per time step, fit on training rows only (consistent with
    # the other scalers, avoids leaking validation statistics)
    n_train = len(X_train_spatial)
    orig_shape = X_neighbor_histories.shape
    neighbor_scaler = StandardScaler()
    neighbor_scaler.fit(X_neighbor_histories[:n_train].reshape(-1, seq_length))
    X_neighbor_histories = neighbor_scaler.transform(
        X_neighbor_histories.reshape(-1, seq_length)
    ).reshape(orig_shape)

    data = Data(
        x=torch.FloatTensor(X_combined_spatial).to(device),
        edge_index=edge_index.to(device),
        edge_attr=edge_attr.to(device),
        temporal_x=torch.FloatTensor(X_combined_temporal).to(device),
        amenity_x=torch.FloatTensor(X_combined_amenity).to(device),
        neighbor_histories=torch.FloatTensor(X_neighbor_histories).to(device),
        y=torch.FloatTensor(y_combined.reshape(-1, 1)).to(device),
        train_mask=torch.zeros(len(X_combined_spatial), dtype=torch.bool).to(device),
        val_mask=torch.zeros(len(X_combined_spatial), dtype=torch.bool).to(device),
        val_y=torch.FloatTensor(y_val.reshape(-1, 1)).to(device),
    )

    data.train_mask[:n_train] = True
    data.val_mask[n_train:] = True

    print(f"Prepared data with {len(X_combined_spatial)} nodes")
    print(f"Train nodes: {data.train_mask.sum().item()}, Val nodes: {data.val_mask.sum().item()}")

    return data, spatial_scaler, temporal_scaler, amenity_scaler, neighbor_scaler, target_scaler, neighbors_dict


# ====================== Training and Evaluation Functions ======================
def train_model(data, model, epochs=50, lr=0.001, patience=10, device='cuda'):
    """
    Train the neighbor-based price prediction model with full-batch updates.

    Each epoch runs one forward/backward pass over the whole graph and computes
    the Huber loss on the ``data.train_mask`` nodes. It then evaluates the
    ``data.val_mask`` nodes against ``data.val_y``. ``ReduceLROnPlateau`` halves
    the learning rate after 5 epochs without validation improvement. Training
    stops early after ``patience`` epochs without improvement. The weights from
    the best validation epoch are restored before returning.

    Args:
        data: Graph data object exposing ``y``, ``train_mask``, ``val_mask`` and
            ``val_y``; it is passed unchanged to ``model(data)``.
        model: Module whose forward pass returns one prediction row per node.
        epochs: Maximum number of epochs.
        lr: Initial Adam learning rate.
        patience: Epochs without validation improvement before stopping.
        device: ``torch.device`` or device string (``'cuda'``, ``'cpu'``). It is
            only used to decide whether to release the CUDA cache each epoch.

    Returns:
        (model, history). ``model`` has the best-validation weights loaded.
        ``history`` maps ``train_loss``, ``val_loss``, ``val_rmse``, ``val_mae``
        and ``lr`` to per-epoch lists.
    """
    print("\n===== Training Neighbor-Based Price Prediction Model =====")

    # Accept plain strings as well as torch.device objects (the default is a str,
    # which has no `.type` attribute).
    device = torch.device(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    criterion = nn.HuberLoss(delta=1.0)
    # `verbose=` is deprecated in recent PyTorch; LR drops are reported below.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    best_val_loss = float('inf')
    best_model_state = None
    counter = 0

    history = {
        'train_loss': [],
        'val_loss': [],
        'val_rmse': [],
        'val_mae': [],
        'lr': []
    }

    for epoch in range(epochs):
        # ---- Training step (full batch; loss only on training nodes) ----
        model.train()
        optimizer.zero_grad()

        out = model(data)
        loss = criterion(out[data.train_mask], data.y[data.train_mask])

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # ---- Validation (metrics are on the model's output scale) ----
        model.eval()
        with torch.no_grad():
            val_out = model(data)[data.val_mask]
            val_y = data.val_y
            val_loss = criterion(val_out, val_y).item()
            val_rmse = torch.sqrt(F.mse_loss(val_out, val_y)).item()
            val_mae = F.l1_loss(val_out, val_y).item()

        train_loss = loss.item()
        current_lr = optimizer.param_groups[0]['lr']

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['val_rmse'].append(val_rmse)
        history['val_mae'].append(val_mae)
        history['lr'].append(current_lr)

        print(f"Epoch {epoch+1}/{epochs} - Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
              f"Val RMSE: {val_rmse:.4f}, Val MAE: {val_mae:.4f}")

        scheduler.step(val_loss)
        new_lr = optimizer.param_groups[0]['lr']
        if new_lr < current_lr:
            print(f"Reducing learning rate to {new_lr:.2e}")

        # ---- Early stopping / best-model tracking ----
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # state_dict() returns tensors that share storage with the live
            # parameters; clone them so later optimizer steps cannot overwrite
            # the snapshot.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            counter = 0
        else:
            counter += 1

        if counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

        if device.type == 'cuda':
            torch.cuda.empty_cache()

    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return model, history


def predict(model, data, target_scaler, device='cuda'):
    """
    Predict prices on the original price scale for the nodes in ``data.val_mask``.

    The model's raw outputs (presumably standardized log prices) are mapped back
    with ``target_scaler.inverse_transform``. ``np.expm1`` is then applied to undo
    the ``log1p`` price transform used earlier in the pipeline. ``device`` is
    accepted for API symmetry and is not used.

    Returns:
        numpy.ndarray whose rows correspond to the ``val_mask`` nodes.
    """
    model.eval()
    with torch.no_grad():
        scaled_log_prices = model(data)[data.val_mask].cpu().numpy()

    log_prices = target_scaler.inverse_transform(scaled_log_prices)
    return np.expm1(log_prices)


def evaluate_model(model, data, true_prices, target_scaler, device='cuda'):
    """
    Predict prices for the ``data.val_mask`` nodes and score them against ``true_prices``.

    Returns:
        (predictions, metrics): the array produced by ``predict`` and whatever
        ``evaluate_predictions`` returns for its flattened values.
    """
    predicted_prices = predict(model, data, target_scaler, device)
    scores = evaluate_predictions(true_prices, predicted_prices.flatten())
    return predicted_prices, scores


# ====================== Main Function ======================
def _read_listing_ids(path):
    """
    Read integer listing IDs from a text file that has one ID per line.

    Blank lines (for example a trailing empty line) are skipped instead of
    crashing on ``int('')``.
    """
    with open(path, 'r') as f:
        stripped = (line.strip() for line in f)
        return [int(line) for line in stripped if line]


def _fill_missing_values(df):
    """
    Impute NaNs column by column and return ``df``, which is modified in place.

    Numeric columns are filled with their median and all other columns with
    their most frequent value. A column that is entirely NaN has nothing to
    impute from, so it is left untouched and a warning is printed.
    """
    nan_columns = df.columns[df.isna().any()].tolist()
    if not nan_columns:
        return df

    print(f"Warning: Found NaN values in columns: {nan_columns}")
    print("Filling NaN values with column means/medians")

    for col in nan_columns:
        # pandas' dtype check also copes with extension dtypes (category, Int64,
        # ...), where np.issubdtype raises TypeError.
        if pd.api.types.is_numeric_dtype(df[col]):
            fill_value = df[col].median()
            if pd.isna(fill_value):
                print(f"Warning: column '{col}' is entirely NaN; leaving it unfilled")
                continue
        else:
            modes = df[col].mode()
            if modes.empty:
                print(f"Warning: column '{col}' is entirely NaN; leaving it unfilled")
                continue
            fill_value = modes.iloc[0]
        df[col] = df[col].fillna(fill_value)
    return df


def _plot_training_history(history, output_dir=None):
    """
    Plot the per-epoch curves from ``train_model``'s history in a 2x2 grid.

    The four panels show train/validation loss, validation RMSE, validation MAE
    and the learning rate (log scale). The figure is saved as
    ``training_history.png`` when ``output_dir`` is given, then shown.
    """
    plt.figure(figsize=(12, 8))

    plt.subplot(2, 2, 1)
    plt.plot(history['train_loss'], label='Train Loss')
    plt.plot(history['val_loss'], label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(2, 2, 2)
    plt.plot(history['val_rmse'], label='Validation RMSE')
    plt.title('Validation RMSE')
    plt.xlabel('Epoch')
    plt.ylabel('RMSE')

    plt.subplot(2, 2, 3)
    plt.plot(history['val_mae'], label='Validation MAE')
    plt.title('Validation MAE')
    plt.xlabel('Epoch')
    plt.ylabel('MAE')

    plt.subplot(2, 2, 4)
    plt.plot(history['lr'], label='Learning Rate')
    plt.title('Learning Rate')
    plt.xlabel('Epoch')
    plt.ylabel('LR')
    plt.yscale('log')

    plt.tight_layout()

    if output_dir:
        plt.savefig(os.path.join(output_dir, 'training_history.png'))
    plt.show()


def _build_test_results(test_df, predictions):
    """
    Build a per-row table of actual vs. predicted prices with error columns.

    The table has ``error`` (actual - predicted), ``abs_error`` and ``pct_error``
    columns. ``pct_error`` is the absolute error as a percentage of the actual
    price, with 1e-8 added to the denominator to avoid division by zero.
    """
    actual = test_df['original_price'].values
    predicted = np.asarray(predictions).flatten()
    error = actual - predicted
    return pd.DataFrame({
        'listing_id': test_df['listing_id'].values,
        'date': test_df['date'].values,
        'actual': actual,
        'predicted': predicted,
        'error': error,
        'abs_error': np.abs(error),
        'pct_error': np.abs(error / (actual + 1e-8)) * 100,
    })


def run_neighbor_based_prediction(train_path, train_ids_path, test_ids_path, output_dir=None,
                                  sample_size=None, num_neighbors=5, seq_length=30,
                                  lstm_hidden_dim=8, hidden_dim=64, epochs=50, lr=0.001):
    """
    Run the complete pipeline for training and evaluating the neighbor-based price model.

    Steps:
    - Load the CSV and the train/test listing-ID files, optionally sub-sampling listings.
    - Add calculated features and impute NaNs.
    - Split rows by listing ID and log-transform prices.
    - Hold out 20% of the training listings for validation.
    - Build graph data with ``prepare_neighbor_data`` and train ``NeighborPricePredictor``.
    - Evaluate on the validation and test listings and optionally save artifacts.

    Args:
        train_path: CSV with ``listing_id``, ``price`` and ``date`` columns.
        train_ids_path, test_ids_path: Text files with one listing ID per line.
        output_dir: Where plots, weights, scalers, predictions and a summary are
            written. Nothing is written when it is falsy.
        sample_size: If set, randomly keep ``int(0.75 * sample_size)`` train and
            ``int(0.25 * sample_size)`` test listings, capped at the number
            available. Useful for quick experiments.
        num_neighbors, seq_length, lstm_hidden_dim, hidden_dim: Model and data sizes.
        epochs, lr: Training settings.

    Returns:
        (model, spatial_scaler, temporal_scaler, amenity_scaler, neighbor_scaler,
        target_scaler, test_metrics), or ``None`` if any step raised. In that case
        the traceback is printed.
    """
    try:
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        # ---- Load data and listing-ID splits ----
        print("Loading data...")
        train_data = pd.read_csv(train_path)

        print("Loading train/test listing IDs...")
        train_listing_ids = _read_listing_ids(train_ids_path)
        test_listing_ids = _read_listing_ids(test_ids_path)

        print(f"Loaded {len(train_listing_ids)} train IDs and {len(test_listing_ids)} test IDs")

        if sample_size:
            print(f"Limiting to {sample_size} random listings for testing")
            np.random.seed(42)
            # Cap the draw size: choice(replace=False) raises if asked for more
            # IDs than exist.
            n_train = min(int(sample_size * 0.75), len(train_listing_ids))
            n_test = min(int(sample_size * 0.25), len(test_listing_ids))
            train_listing_ids = np.random.choice(train_listing_ids, n_train, replace=False).tolist()
            test_listing_ids = np.random.choice(test_listing_ids, n_test, replace=False).tolist()

        if 'date' in train_data.columns and not pd.api.types.is_datetime64_any_dtype(train_data['date']):
            train_data['date'] = pd.to_datetime(train_data['date'])

        # ---- Feature engineering and imputation ----
        print("Creating calculated features...")
        train_data = create_calculated_features(train_data)

        # NOTE(review): imputation statistics are computed over all rows, including
        # the test listings, which leaks a little test information into training.
        train_data = _fill_missing_values(train_data)

        # ---- Train/test split by listing ID ----
        train_df = train_data[train_data['listing_id'].isin(train_listing_ids)].copy()
        test_df = train_data[train_data['listing_id'].isin(test_listing_ids)].copy()

        print(f"Train data: {len(train_df)} rows, {train_df['listing_id'].nunique()} unique listings")
        print(f"Test data: {len(test_df)} rows, {test_df['listing_id'].nunique()} unique listings")

        # ---- Feature groups (restricted to columns that actually exist) ----
        present = set(train_df.columns)
        spatial_features = [f for f in ['latitude', 'longitude'] if f in present]
        temporal_features = [
            f for f in ['DTF_day_of_week', 'DTF_month', 'DTF_is_weekend',
                        'DTF_season_sin', 'DTF_season_cos']
            if f in present
        ]
        # All has_* columns, followed by basic property descriptors.
        amenity_features = [col for col in train_df.columns if col.startswith('has_')]
        basic_property_features = ['accommodates', 'bedrooms', 'bathrooms', 'essential_score',
                                   'luxury_score', 'amenity_count']
        amenity_features.extend(f for f in basic_property_features if f in present)

        if not amenity_features:
            print("No amenity features found, creating dummy feature")
            train_df['dummy_amenity'] = 1
            test_df['dummy_amenity'] = 1
            amenity_features = ['dummy_amenity']

        print(f"Using {len(spatial_features)} spatial features, {len(temporal_features)} temporal features, "
              f"and {len(amenity_features)} amenity features")

        # log1p prices; the untransformed values are kept in 'original_price'.
        train_df = apply_price_transformation(train_df)
        test_df = apply_price_transformation(test_df)

        # ---- Hold out 20% of training listings for validation ----
        unique_train_listings = train_df['listing_id'].unique()
        train_listings, val_listings = train_test_split(
            unique_train_listings, test_size=0.2, random_state=42
        )

        train_subset = train_df[train_df['listing_id'].isin(train_listings)].copy()
        val_subset = train_df[train_df['listing_id'].isin(val_listings)].copy()

        print(f"Train subset: {len(train_subset)} rows, {len(train_listings)} listings")
        print(f"Validation subset: {len(val_subset)} rows, {len(val_listings)} listings")

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {device}")

        gc.collect()
        if device.type == 'cuda':
            torch.cuda.empty_cache()

        # ---- Graph data, model and training ----
        data, spatial_scaler, temporal_scaler, amenity_scaler, neighbor_scaler, target_scaler, neighbors_dict = prepare_neighbor_data(
            train_subset, val_subset, spatial_features, temporal_features, amenity_features,
            num_neighbors=num_neighbors, seq_length=seq_length, device=device
        )

        model = NeighborPricePredictor(
            spatial_features_dim=len(spatial_features),
            temporal_features_dim=len(temporal_features),
            amenity_features_dim=len(amenity_features),
            num_neighbors=num_neighbors,
            seq_length=seq_length,
            lstm_hidden_dim=lstm_hidden_dim,
            hidden_dim=hidden_dim,
            dropout=0.3,
            heads=4,
            edge_dim=1
        ).to(device)

        model, history = train_model(
            data, model, epochs=epochs, lr=lr, patience=10, device=device
        )

        _plot_training_history(history, output_dir)

        # ---- Evaluation ----
        val_predictions, val_metrics = evaluate_model(
            model, data, val_subset['original_price'].values, target_scaler, device
        )

        # NOTE(review): this call refits prepare_neighbor_data's scalers on
        # (train_subset, test_df). Only the training target_scaler is reused
        # below. Confirm that the refit feature/neighbor scalers match the ones
        # used in training.
        test_data, _, _, _, _, _, _ = prepare_neighbor_data(
            train_subset, test_df, spatial_features, temporal_features, amenity_features,
            num_neighbors=num_neighbors, seq_length=seq_length, device=device
        )

        test_predictions, test_metrics = evaluate_model(
            model, test_data, test_df['original_price'].values, target_scaler, device
        )

        plot_results(test_df['original_price'].values, test_predictions.flatten(), output_dir=output_dir)

        # ---- Persist artifacts ----
        if output_dir:
            torch.save(model.state_dict(), os.path.join(output_dir, 'neighbor_model.pt'))
            torch.save({
                'spatial_scaler': spatial_scaler,
                'temporal_scaler': temporal_scaler,
                'amenity_scaler': amenity_scaler,
                'neighbor_scaler': neighbor_scaler,
                'target_scaler': target_scaler,
                'num_neighbors': num_neighbors,
                'seq_length': seq_length,
                'lstm_hidden_dim': lstm_hidden_dim
            }, os.path.join(output_dir, 'scalers.pt'))
            print(f"Model and scalers saved to {output_dir}")

            predictions_path = os.path.join(output_dir, 'test_predictions.csv')
            _build_test_results(test_df, test_predictions).to_csv(predictions_path, index=False)
            print(f"Test predictions saved to {predictions_path}")

            with open(os.path.join(output_dir, 'model_summary.txt'), 'w') as f:
                f.write("Neighbor-Based Price Prediction Model Summary\n")
                f.write("===========================================\n\n")
                f.write(f"Number of neighbors: {num_neighbors}\n")
                f.write(f"Sequence length: {seq_length}\n")
                f.write(f"LSTM hidden dimension: {lstm_hidden_dim}\n")
                f.write(f"Hidden dimension: {hidden_dim}\n")
                f.write(f"Learning rate: {lr}\n")
                f.write(f"Epochs: {epochs}\n\n")

                f.write("Feature counts:\n")
                f.write(f"  Spatial features: {len(spatial_features)}\n")
                f.write(f"  Temporal features: {len(temporal_features)}\n")
                f.write(f"  Amenity features: {len(amenity_features)}\n\n")

                f.write("Test Results:\n")
                for metric, value in test_metrics.items():
                    f.write(f"  {metric}: {value:.6f}\n")

                f.write("\nValidation Results:\n")
                for metric, value in val_metrics.items():
                    f.write(f"  {metric}: {value:.6f}\n")

        # ---- Console summary ----
        print("\n===== NEIGHBOR-BASED PRICE PREDICTION SUMMARY =====")
        print(f"Using {len(spatial_features)} spatial features, {len(temporal_features)} temporal features, "
              f"and {len(amenity_features)} amenity features")
        print(f"Number of neighbors: {num_neighbors}")
        print(f"LSTM hidden dimension: {lstm_hidden_dim}")

        print("\n=== Validation Metrics ===")
        for metric, value in val_metrics.items():
            print(f"{metric}: {value:.4f}")

        print("\n=== Test Metrics ===")
        for metric, value in test_metrics.items():
            print(f"{metric}: {value:.4f}")

        return model, spatial_scaler, temporal_scaler, amenity_scaler, neighbor_scaler, target_scaler, test_metrics

    except Exception as e:
        # Pipeline boundary: report the failure with its traceback and return None
        # so callers can test `if results:`.
        print(f"Error in neighbor-based price prediction: {str(e)}")
        import traceback
        traceback.print_exc()
        return None


def visualize_neighbor_attention(model, output_dir=None):
    """
    Plot the softmax-normalised ``model.neighbor_attention`` weights as a bar chart.

    There is one bar per weight, labelled "Neighbor 1", "Neighbor 2", ..., with
    its value printed above it. The figure is saved as ``neighbor_attention.png``
    when ``output_dir`` is given, then shown.
    """
    # NOTE(review): softmax runs over dim 0. This assumes the neighbour axis is
    # the first dimension of `neighbor_attention`; confirm against the model.
    weights = F.softmax(model.neighbor_attention, dim=0).detach().cpu().numpy().flatten()

    # Derive positions and labels from the flattened array so the ticks always
    # line up with the bars, whatever the parameter's original shape.
    positions = range(len(weights))

    plt.figure(figsize=(10, 6))

    plt.bar(positions, weights)
    plt.xticks(positions, [f"Neighbor {i+1}" for i in positions])
    plt.ylabel('Attention Weight')
    plt.title('Neighbor Attention Weights')

    # Annotate each bar with its value, slightly above the bar top.
    for i, weight in enumerate(weights):
        plt.text(i, weight + 0.01, f"{weight:.3f}", ha='center')

    plt.tight_layout()

    if output_dir:
        plt.savefig(os.path.join(output_dir, 'neighbor_attention.png'))

    plt.show()


def visualize_feature_importance(model, output_dir=None):
    """
    Bar-plot the mean softmax-normalised fusion weight for each feature family.

    Softmax is applied to ``model.fusion_weights`` along dim 0. Row ``i`` is
    averaged and attributed to the i-th family (Spatial, Temporal, Amenity,
    Neighbor). The Neighbor bar is highlighted in orange. The figure is saved as
    ``feature_importance.png`` when ``output_dir`` is given, then shown.
    """
    normalized = F.softmax(model.fusion_weights, dim=0).cpu().detach().numpy()

    type_names = ['Spatial', 'Temporal', 'Amenity', 'Neighbor']
    mean_per_type = [np.mean(normalized[idx]) for idx, _ in enumerate(type_names)]
    positions = range(len(type_names))

    plt.figure(figsize=(10, 6))

    bar_container = plt.bar(positions, mean_per_type)
    plt.xticks(positions, type_names)
    plt.ylabel('Average Weight')
    plt.title('Feature Type Importance')

    for pos, value in zip(positions, mean_per_type):
        plt.text(pos, value + 0.01, f"{value:.3f}", ha='center')

    # Index 3 is the 'Neighbor' family.
    bar_container[3].set_color('orange')

    plt.tight_layout()

    if output_dir:
        plt.savefig(os.path.join(output_dir, 'feature_importance.png'))

    plt.show()


# ====================== Main Execution ======================
if __name__ == "__main__":
    # ---- Input locations (thesis data subset) ----
    data_dir = r"C:\Users\mvk\Documents\DATA_school\thesis\Subset\top_price_changers_subset"
    train_path = data_dir + "\\topic2_transformed.csv"
    train_ids_path = data_dir + "\\train_ids.txt"
    test_ids_path = data_dir + "\\test_ids.txt"

    # ---- Output location ----
    output_dir = "./output/neighbor_price_prediction"
    os.makedirs(output_dir, exist_ok=True)

    # ---- Hyperparameters ----
    num_neighbors = 5      # neighbours whose price history feeds each listing
    seq_length = 30        # length of each neighbour price-history sequence
    lstm_hidden_dim = 16   # LSTM width, kept small (8 or 16)
    hidden_dim = 64        # width of the remaining layers
    epochs = 50            # upper bound on training epochs
    lr = 0.001             # initial learning rate

    results = run_neighbor_based_prediction(
        train_path=train_path,
        train_ids_path=train_ids_path,
        test_ids_path=test_ids_path,
        output_dir=output_dir,
        num_neighbors=num_neighbors,
        seq_length=seq_length,
        lstm_hidden_dim=lstm_hidden_dim,
        hidden_dim=hidden_dim,
        epochs=epochs,
        lr=lr,
        sample_size=None,  # an int gives a quick run on a listing subset; None uses everything
    )

    if results:
        model, *_, test_metrics = results

        # Learned attention / fusion weights
        visualize_neighbor_attention(model, output_dir)
        visualize_feature_importance(model, output_dir)

        # ---- Parameter budget ----
        total_params = sum(p.numel() for p in model.parameters())
        lstm_params = sum(p.numel() for name, p in model.named_parameters() if 'lstm' in name)
        lstm_share = lstm_params / total_params * 100
        size_lines = [
            f"Total parameters: {total_params:,}",
            f"LSTM parameters: {lstm_params:,} ({lstm_share:.2f}% of total)",
            f"Model is using {lstm_hidden_dim}-dimensional LSTM",
        ]

        print("\n===== Model Size Information =====")
        for line in size_lines:
            print(line)

        with open(os.path.join(output_dir, 'model_size.txt'), 'w') as f:
            f.writelines(f"{line}\n" for line in size_lines)

Loading data...
Loading train/test listing IDs...
Loaded 6291 train IDs and 1573 test IDs
Creating calculated features...
Train data: 1123327 rows, 6291 unique listings
Test data: 281142 rows, 1573 unique listings
Using 2 spatial features, 5 temporal features, and 26 amenity features
Applying log transformation to price data
Applying log transformation to price data
Train subset: 898711 rows, 5032 listings
Validation subset: 224616 rows, 1259 listings
Using device: cuda
Preparing data for neighbor-based price prediction...
Selecting 5 neighbors for each listing...
Selected neighbors for 1123327 listings
Building spatial graph with 224616 test listings and 5 nearest neighbors...
Created graph with 2246160 edges
Collecting price histories from neighbors...
