# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from torch_geometric.nn import SAGEConv
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from datetime import datetime, timedelta
from tqdm import tqdm
import math
from torch.utils.data import Dataset, DataLoader
from sklearn.neighbors import BallTree
from math import radians
import os
import pickle
import hashlib


class NewListingPricePredictor(nn.Module):
    """Feed-forward price regressor for listings without price history.

    Property features and amenity features are embedded by two separate
    sub-networks, concatenated along the feature axis, and reduced to one
    value per sample by a small MLP head.

    Args:
        input_dim: Width of the property feature vector.
        amenities_dim: Width of the amenity feature vector.
        hidden_dim: Width of the property embedding. The amenity embedding
            uses ``hidden_dim // 2``.
    """

    def __init__(self, input_dim, amenities_dim=20, hidden_dim=128):
        super().__init__()

        # Stored only so that shape mismatches are easy to inspect
        self.input_dim = input_dim
        self.amenities_dim = amenities_dim

        # Property feature embedding
        self.property_embedding = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2)
        )

        # Amenities-specific embedding (half the width of the property one)
        self.amenities_embedding = nn.Sequential(
            nn.Linear(amenities_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1)
        )

        # Prediction head; its input is the concatenation of both embeddings
        self.predictor = nn.Sequential(
            nn.Linear(hidden_dim + hidden_dim // 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, property_features, amenities_features, *extra_inputs, **kwargs):
        """Predict one price per row.

        Args:
            property_features: Tensor of shape ``(batch, input_dim)``.
            amenities_features: Tensor of shape ``(batch, amenities_dim)``.
            *extra_inputs: Accepted and ignored. Callers in this file pass
                location, temporal, neighbourhood and edge-index tensors
                positionally; without this catch-all those calls raise
                ``TypeError``.
            **kwargs: Accepted and ignored, for the same reason.

        Returns:
            Tensor of shape ``(batch,)``.
        """
        # No graph convolution here: only the two embedded inputs are used
        property_emb = self.property_embedding(property_features)
        amenities_emb = self.amenities_embedding(amenities_features)

        # Combine embeddings along the feature axis
        combined = torch.cat([property_emb, amenities_emb], dim=1)

        # squeeze(-1) drops only the trailing unit dim, so batch size 1 stays 1-D
        return self.predictor(combined).squeeze(-1)

def calculate_distance(lat1, lon1, lat2, lon2):
    """Calculate the great-circle distance between points (Haversine formula).

    All coordinates are in decimal degrees. Each argument may be a scalar or
    an array-like; array arguments are combined using NumPy broadcasting.

    Args:
        lat1: Latitude of the first point(s).
        lon1: Longitude of the first point(s).
        lat2: Latitude of the second point(s).
        lon2: Longitude of the second point(s).

    Returns:
        Distance in kilometres: a NumPy float for scalar input, an ndarray
        for array input.
    """
    earth_radius_km = 6371

    # np.radians (unlike math.radians) works element-wise on arrays
    lat1, lon1, lat2, lon2 = (
        np.radians(np.asarray(value, dtype=float))
        for value in (lat1, lon1, lat2, lon2)
    )

    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = np.sin(dlat / 2) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon / 2) ** 2

    # Rounding can push `a` marginally outside [0, 1] (e.g. near-antipodal
    # points), which would make sqrt/arcsin return NaN
    a = np.clip(a, 0.0, 1.0)

    c = 2 * np.arcsin(np.sqrt(a))
    return earth_radius_km * c

def get_cache_path(data_path, prefix="spatial_features"):
    """Return the pickle cache location associated with ``data_path``.

    The file lives in a ``cache`` directory next to the data file (the
    directory is created when missing) and is named
    ``<prefix>_<first 10 hex chars of md5(data_path)>.pkl``. Only the path
    string is hashed, not the file contents.
    """
    # Short fingerprint of the path string keeps cache files apart
    path_digest = hashlib.md5(data_path.encode()).hexdigest()
    short_digest = path_digest[:10]

    # Sibling "cache" folder, created on demand
    parent_dir = os.path.dirname(data_path)
    cache_dir = os.path.join(parent_dir, "cache")
    os.makedirs(cache_dir, exist_ok=True)

    cache_file = f"{prefix}_{short_digest}.pkl"
    return os.path.join(cache_dir, cache_file)

def create_enhanced_spatial_features(df, k_neighbors=10, chunk_size=1000, cache_path=None):
    """
    Create enhanced spatial features for new listings prediction

    Working on a copy of ``df``, this function adds:
    1. Distance to city center (``distance_to_center``, km before scaling)
    2. North-south and east-west offsets from the center
    3. K-nearest neighbor price statistics (``knn_price_*``) and
       ``price_diff_from_neighbors``
    4. Neighborhood aggregate price statistics (``price_mean``, ``price_std``,
       ``price_median``, ``price_min``, ``price_max``, ``price_count``), only
       when ``neighbourhood_cleansed_encoded`` is present

    The columns of groups 1-3 are z-score standardized in place; the
    neighborhood aggregates are left unscaled.

    Args:
        df: DataFrame with at least ``latitude``, ``longitude`` and ``price``
        k_neighbors: Number of neighbors to use for KNN features
        chunk_size: Number of rows per nearest-neighbor query (bounds the
            size of the intermediate neighbor matrices)
        cache_path: Optional pickle file used to load/store the result

    Returns:
        DataFrame with enhanced spatial features. When the neighborhood merge
        runs, the result comes from ``DataFrame.merge`` and therefore no
        longer carries the index of ``df``.

    Raises:
        ValueError: If features must be computed and ``chunk_size < 1``.
    """
    # Check if cache exists and load from it if available
    if cache_path and os.path.exists(cache_path):
        print(f"Loading spatial features from cache: {cache_path}")
        try:
            # NOTE(security): unpickling runs arbitrary code; only point
            # cache_path at files this pipeline wrote itself.
            with open(cache_path, 'rb') as f:
                spatial_features_df = pickle.load(f)

            # NOTE(review): only the row count is validated. A cache built
            # with a different k_neighbors or different data of equal length
            # is accepted as-is.
            if len(spatial_features_df) == len(df):
                print("Cache loaded successfully")
                return spatial_features_df
            else:
                print("Cache size mismatch. Recalculating features.")
        except Exception as e:
            # Best effort: any unreadable cache just triggers a recompute
            print(f"Error loading cache: {e}. Recalculating features.")

    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")

    print("Creating enhanced spatial features...")
    # Make a copy to avoid modifying the original
    enhanced_df = df.copy()
    n_rows = len(enhanced_df)

    # Paris coordinates
    city_center_lat, city_center_lon = 48.8566, 2.3522

    # Distance to city center. Iterating the two columns avoids the per-row
    # Series construction of DataFrame.apply(axis=1) and copes with empty input.
    enhanced_df['distance_to_center'] = np.array(
        [
            calculate_distance(lat, lon, city_center_lat, city_center_lon)
            for lat, lon in zip(enhanced_df['latitude'], enhanced_df['longitude'])
        ],
        dtype=float,
    )

    enhanced_df['north_south'] = enhanced_df['latitude'] - city_center_lat
    enhanced_df['east_west'] = enhanced_df['longitude'] - city_center_lon

    # KNN feature buffers; rows without any neighbor keep the zero default
    knn_price_mean = np.zeros(n_rows)
    knn_price_std = np.zeros(n_rows)
    knn_price_median = np.zeros(n_rows)
    knn_price_min = np.zeros(n_rows)
    knn_price_max = np.zeros(n_rows)
    price_diff = np.zeros(n_rows)

    # Ask for k+1 neighbors (the query point itself is among them), but never
    # more than there are data points
    k_to_use = min(k_neighbors + 1, n_rows)

    if k_to_use > 1:
        price_values = enhanced_df['price'].to_numpy(dtype=float, na_value=np.nan)

        # BallTree with the haversine metric expects (lat, lon) in radians
        coords = np.radians(enhanced_df[['latitude', 'longitude']].to_numpy(dtype=float))
        tree = BallTree(coords, metric='haversine')

        # Query in row slices so the (rows x k) matrices stay small
        for start in tqdm(range(0, n_rows, chunk_size), desc="Processing spatial chunks"):
            stop = min(start + chunk_size, n_rows)
            _, indices = tree.query(coords[start:stop], k=k_to_use)

            # Column 0 is treated as the query point itself and dropped.
            # NOTE(review): with duplicated coordinates another row at
            # distance 0 may take that slot instead - confirm acceptable.
            neighbor_prices = price_values[indices[:, 1:]]
            neighbor_mean = neighbor_prices.mean(axis=1)

            knn_price_mean[start:stop] = neighbor_mean
            knn_price_std[start:stop] = neighbor_prices.std(axis=1)
            knn_price_median[start:stop] = np.median(neighbor_prices, axis=1)
            knn_price_min[start:stop] = neighbor_prices.min(axis=1)
            knn_price_max[start:stop] = neighbor_prices.max(axis=1)

            # Rows without an own price get a difference of 0
            own_price = price_values[start:stop]
            price_diff[start:stop] = np.where(
                np.isnan(own_price), 0.0, own_price - neighbor_mean
            )

    # Add features to dataframe
    enhanced_df['knn_price_mean'] = knn_price_mean
    enhanced_df['knn_price_std'] = knn_price_std
    enhanced_df['knn_price_median'] = knn_price_median
    enhanced_df['knn_price_min'] = knn_price_min
    enhanced_df['knn_price_max'] = knn_price_max
    enhanced_df['price_diff_from_neighbors'] = price_diff

    # Add neighborhood aggregated statistics
    # NOTE(review): these aggregates are computed from `price` in this same
    # frame, so each row's own price contributes to its group statistics.
    if 'neighbourhood_cleansed_encoded' in enhanced_df.columns:
        # Group by neighborhood and calculate statistics
        neighborhood_stats = enhanced_df.groupby('neighbourhood_cleansed_encoded').agg({
            'price': ['mean', 'std', 'median', 'min', 'max', 'count']
        })

        # Flatten ('price', 'mean') -> 'price_mean'
        neighborhood_stats.columns = ['_'.join(col).strip() for col in neighborhood_stats.columns.values]
        neighborhood_stats = neighborhood_stats.reset_index()

        # Merge statistics back to the dataframe
        enhanced_df = enhanced_df.merge(neighborhood_stats, on='neighbourhood_cleansed_encoded', how='left')

    # Standardize the new features
    spatial_features = [
        'distance_to_center', 'north_south', 'east_west',
        'knn_price_mean', 'knn_price_std', 'knn_price_median',
        'knn_price_min', 'knn_price_max', 'price_diff_from_neighbors'
    ]

    for col in spatial_features:
        mean_val = enhanced_df[col].mean()
        std_val = enhanced_df[col].std()
        # Constant or all-NaN columns (std 0 or NaN) are left untouched
        if std_val > 0:
            enhanced_df[col] = (enhanced_df[col] - mean_val) / std_val

    # Save to cache if path provided
    if cache_path:
        print(f"Saving spatial features to cache: {cache_path}")
        with open(cache_path, 'wb') as f:
            pickle.dump(enhanced_df, f)

    return enhanced_df

def extract_amenity_features(df):
    """Return the values of every column whose name starts with ``has_``."""
    amenity_cols = list(filter(lambda name: name.startswith('has_'), df.columns))
    amenity_frame = df[amenity_cols]
    return amenity_frame.values

def extract_location_features(df):
    """Return the ``latitude`` and ``longitude`` columns as a two-column array."""
    coordinate_cols = ['latitude', 'longitude']
    coordinate_frame = df[coordinate_cols]
    return coordinate_frame.values

def extract_temporal_features(df):
    """Return the five ``DTF_*`` temporal columns as an array.

    If any of them is missing, a warning is printed and an all-zero array of
    shape ``(len(df), 5)`` is returned instead.
    """
    temporal_cols = ['DTF_day_of_week', 'DTF_month', 'DTF_is_weekend', 'DTF_season_sin', 'DTF_season_cos']
    missing_cols = [name for name in temporal_cols if name not in df.columns]

    if missing_cols:
        # Placeholder block so downstream shapes stay consistent
        print("Warning: Temporal features not found. Using dummy values.")
        return np.zeros((len(df), 5))

    return df[temporal_cols].values

def extract_neighborhood_stats(df):
    """Return the five ``knn_price_*`` columns as an array.

    If any of them is missing, a warning is printed and an all-zero array of
    shape ``(len(df), 5)`` is returned instead.
    """
    neighborhood_cols = [
        'knn_price_mean', 'knn_price_std', 'knn_price_median',
        'knn_price_min', 'knn_price_max'
    ]

    have_all = True
    for name in neighborhood_cols:
        if name not in df.columns:
            have_all = False
            break

    if not have_all:
        # Placeholder block so downstream shapes stay consistent
        print("Warning: Neighborhood statistics not found. Using dummy values.")
        return np.zeros((len(df), 5))

    return df[neighborhood_cols].values


def extract_property_features(df, amenity_cols=None, temporal_cols=None, neighborhood_cols=None, spatial_cols=None):
    """Return the values of all columns that are not in another feature group.

    A column is dropped when it is an amenity, temporal, neighborhood or
    spatial column, or one of ``listing_id``, ``date``, ``price``. Each group
    argument left as ``None`` falls back to its built-in default list.
    """
    amenity_cols = (
        [name for name in df.columns if name.startswith('has_')]
        if amenity_cols is None else amenity_cols
    )
    temporal_cols = (
        ['DTF_day_of_week', 'DTF_month', 'DTF_is_weekend', 'DTF_season_sin', 'DTF_season_cos']
        if temporal_cols is None else temporal_cols
    )
    neighborhood_cols = (
        [
            'knn_price_mean', 'knn_price_std', 'knn_price_median',
            'knn_price_min', 'knn_price_max'
        ]
        if neighborhood_cols is None else neighborhood_cols
    )
    spatial_cols = (
        ['latitude', 'longitude', 'distance_to_center', 'north_south', 'east_west', 'price_diff_from_neighbors']
        if spatial_cols is None else spatial_cols
    )

    # Everything that belongs to another group, plus identifiers and the target
    identifier_cols = ['listing_id', 'date', 'price']
    excluded = frozenset(
        amenity_cols + temporal_cols + neighborhood_cols + spatial_cols + identifier_cols
    )

    # Keep the remaining columns in their original order
    property_cols = []
    for name in df.columns:
        if name in excluded:
            continue
        property_cols.append(name)

    return df[property_cols].values


def create_batch_aware_edge_index(batch_data, k_neighbors=5, distance_threshold=2.0):
    """
    Build a symmetric k-NN edge index over the rows of one batch

    Node ids are row positions inside ``batch_data``, so every index is
    valid for tensors built from the same batch.

    Args:
        batch_data: DataFrame containing the current batch (needs
            ``latitude`` and ``longitude`` columns)
        k_neighbors: Number of neighbors to connect each node to
        distance_threshold: Maximum distance (in km) to consider for connections

    Returns:
        Long tensor of shape (2, num_edges); (2, 0) when there are no edges
    """
    n_nodes = len(batch_data)

    # Nothing to connect for an empty or single-row batch
    if n_nodes <= 1:
        return torch.tensor([], dtype=torch.long).view(2, 0)

    # The haversine metric works on (lat, lon) in radians
    coords_rad = np.radians(batch_data[['latitude', 'longitude']].values)
    tree = BallTree(coords_rad, metric='haversine')

    # +1 because the query point is returned too; capped at the batch size
    k_to_use = min(k_neighbors + 1, n_nodes)
    distances, indices = tree.query(coords_rad, k=k_to_use)

    # Angular distance (radians) -> km using the Earth radius
    distances = distances * 6371.0

    # A set drops the duplicates produced by adding both directions
    unique_edges = set()
    for src, (row_neighbors, row_distances) in enumerate(zip(indices, distances)):
        # Position 0 is skipped as the node itself
        for neighbor_idx, distance in zip(row_neighbors[1:], row_distances[1:]):
            if not (distance <= distance_threshold):
                continue
            dst = int(neighbor_idx)
            unique_edges.add((src, dst))
            # Reverse direction for an undirected graph
            unique_edges.add((dst, src))

    if not unique_edges:
        # No pair within the threshold
        return torch.tensor([], dtype=torch.long).view(2, 0)

    edge_list = [list(edge) for edge in unique_edges]
    return torch.tensor(edge_list, dtype=torch.long).t()

class AirbnbNewListingDataset(Dataset):
    """Map-style dataset that serves per-listing feature groups and the price.

    Every feature group is extracted from the DataFrame once, at construction
    time; ``__getitem__`` only indexes the resulting arrays.
    """

    def __init__(self, data):
        self.data = data
        # Same frame under a second name: used for batch-aware edge index creation
        self.original_df = data

        # Regression target
        self.prices = data['price'].values

        # One array per feature group
        self.property_features = extract_property_features(data)
        self.amenity_features = extract_amenity_features(data)
        self.location_features = extract_location_features(data)
        self.temporal_features = extract_temporal_features(data)
        self.neighborhood_stats = extract_neighborhood_stats(data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return (property, amenity, location, temporal, neighborhood, price) float tensors."""
        feature_groups = (
            self.property_features,
            self.amenity_features,
            self.location_features,
            self.temporal_features,
            self.neighborhood_stats,
        )
        sample = [torch.FloatTensor(group[idx]) for group in feature_groups]
        # The price is wrapped in a list so it becomes a 1-element tensor
        sample.append(torch.FloatTensor([self.prices[idx]]))
        return tuple(sample)

class BatchCollator:
    """Collate function that turns a list of row positions into batch tensors.

    In addition to the feature and price tensors it builds an edge index
    restricted to the rows of the batch.
    """

    def __init__(self, dataset, k_neighbors=5, distance_threshold=2.0):
        self.dataset = dataset
        self.k_neighbors = k_neighbors
        self.distance_threshold = distance_threshold

    def __call__(self, batch_indices):
        """Return (property, amenity, location, temporal, neighborhood, prices, edge_index)."""
        source = self.dataset

        # Rows of this batch, copied so graph construction cannot touch the dataset frame
        batch_frame = source.original_df.iloc[batch_indices].copy()

        # Graph over exactly these rows; node ids are positions within the batch
        edge_index = create_batch_aware_edge_index(
            batch_frame,
            k_neighbors=self.k_neighbors,
            distance_threshold=self.distance_threshold
        )

        # Slice each pre-extracted array with the same positions
        arrays = (
            source.property_features,
            source.amenity_features,
            source.location_features,
            source.temporal_features,
            source.neighborhood_stats,
            source.prices,
        )
        tensors = tuple(torch.FloatTensor(array[batch_indices]) for array in arrays)

        return (*tensors, edge_index)

def _listing_batch_loss(model, criterion, batch, device):
    """Run one six-tensor batch through ``model`` and return the loss tensor."""
    property_feat, amenity_feat, location_feat, temporal_feat, neighborhood_feat, target = (
        tensor.to(device) for tensor in batch
    )

    # The extra feature groups go in by keyword: the model's forward() takes
    # only two positional inputs plus **kwargs.
    output = model(
        property_feat,
        amenity_feat,
        location_features=location_feat,
        temporal_features=temporal_feat,
        neighborhood_stats=neighborhood_feat,
    )

    # squeeze(-1) keeps the batch axis even for a batch of one; a bare
    # squeeze() would yield a 0-d target that broadcasts against the output.
    return criterion(output, target.squeeze(-1))


def train_new_listing_model(train_data, validation_data=None, epochs=10, batch_size=64, learning_rate=0.001, k_neighbors=5):
    """Train the new listing price prediction model.

    Args:
        train_data: DataFrame accepted by ``AirbnbNewListingDataset``.
        validation_data: Optional DataFrame of the same layout. When given,
            the weights of the epoch with the lowest validation loss are
            restored before returning.
        epochs: Number of passes over the training data.
        batch_size: Mini-batch size for both loaders.
        learning_rate: Adam learning rate.
        k_neighbors: Currently unused (no graph is built during training);
            kept so existing callers keep working.

    Returns:
        Tuple ``(model, history)`` where ``history`` holds the per-epoch
        ``training_losses`` and ``validation_losses`` lists.
    """
    # Create datasets and standard loaders (default collation, six tensors per batch)
    train_dataset = AirbnbNewListingDataset(train_data)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    val_loader = None
    if validation_data is not None:
        val_dataset = AirbnbNewListingDataset(validation_data)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Input widths are taken from the extracted feature arrays
    property_dim = train_dataset.property_features.shape[1]
    amenities_dim = train_dataset.amenity_features.shape[1]

    print(f"Property features dimension: {property_dim}")
    print(f"Amenities features dimension: {amenities_dim}")

    # Initialize model
    device = torch.device('cpu')  # Use CPU for initial debugging
    print(f"Using device: {device}")

    model = NewListingPricePredictor(
        input_dim=property_dim,
        amenities_dim=amenities_dim
    ).to(device)

    # Define optimizer and loss function
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
    criterion = nn.MSELoss()

    best_val_loss = float('inf')
    best_model = None
    training_losses = []
    validation_losses = []

    for epoch in range(epochs):
        # Training
        model.train()
        total_loss = 0.0
        batch_count = 0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            loss = _listing_batch_loss(model, criterion, batch, device)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()
            batch_count += 1

        avg_train_loss = total_loss / batch_count if batch_count > 0 else float('inf')
        training_losses.append(avg_train_loss)

        if val_loader is None:
            print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.6f}")
            continue

        # Validation. The loader yields the same six tensors as the training
        # loader; no edge index is produced by default collation.
        model.eval()
        val_loss = 0.0
        val_batch_count = 0

        with torch.no_grad():
            for batch in val_loader:
                val_loss += _listing_batch_loss(model, criterion, batch, device).item()
                val_batch_count += 1

        avg_val_loss = val_loss / val_batch_count if val_batch_count > 0 else float('inf')
        validation_losses.append(avg_val_loss)

        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {avg_val_loss:.6f}")

        # Save best model. The tensors must be cloned: state_dict() hands out
        # references to the live parameters, so a shallow dict copy would keep
        # changing as training continues.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_model = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }

    # Load best model if validation was used
    if best_model is not None:
        model.load_state_dict(best_model)

    return model, {
        'training_losses': training_losses,
        'validation_losses': validation_losses
    }


def predict_new_listings(model, test_data, batch_size=64, k_neighbors=5):
    """Make predictions for new listings using batch-aware edge indices.

    Args:
        model: Trained model; it is switched to eval mode and run on the
            device its parameters live on.
        test_data: DataFrame accepted by ``AirbnbNewListingDataset`` that also
            has ``listing_id`` and ``date`` columns.
        batch_size: Number of rows per forward pass.
        k_neighbors: Neighbors per node for the per-batch edge index.

    Returns:
        Tuple ``(predictions_df, metrics)``. ``predictions_df`` has the columns
        ``listing_id``, ``date``, ``price``, ``predicted``, ``error`` and
        ``abs_error``; ``metrics`` is a dict with ``rmse``, ``mae`` and ``r2``.
    """
    # Create dataset and batch collator
    test_dataset = AirbnbNewListingDataset(test_data)
    test_collator = BatchCollator(test_dataset, k_neighbors=k_neighbors)

    # The loader iterates row positions; the collator turns each list of
    # positions into tensors plus an edge index for that batch
    test_loader = DataLoader(
        dataset=range(len(test_dataset)),
        batch_size=batch_size,
        shuffle=False,
        collate_fn=test_collator
    )

    # Move model to evaluation mode
    device = next(model.parameters()).device
    model.eval()

    batch_outputs = []

    with torch.no_grad():
        for property_feat, amenity_feat, location_feat, temporal_feat, neighborhood_feat, _target, edge_index in test_loader:
            # An empty edge index is passed on as None
            edge_index = edge_index.to(device) if edge_index.numel() > 0 else None

            # Extra inputs go in by keyword: the model's forward() takes only
            # two positional inputs plus **kwargs.
            output = model(
                property_feat.to(device),
                amenity_feat.to(device),
                location_features=location_feat.to(device),
                temporal_features=temporal_feat.to(device),
                neighborhood_stats=neighborhood_feat.to(device),
                edge_index=edge_index,
            )

            batch_outputs.append(output.cpu().numpy())

    # shuffle=False means batches arrive in row order, so concatenation
    # already lines up with test_data
    if batch_outputs:
        predictions = np.concatenate(batch_outputs)
    else:
        predictions = np.array([], dtype=np.float32)

    # Get actual targets
    all_targets = test_dataset.prices

    # Combine predictions and targets
    predictions_df = test_data[['listing_id', 'date']].copy()
    predictions_df['price'] = test_data['price']  # Actual price
    predictions_df['predicted'] = predictions
    predictions_df['error'] = predictions_df['price'] - predictions_df['predicted']
    predictions_df['abs_error'] = np.abs(predictions_df['error'])

    # Calculate metrics
    metrics = {
        'rmse': np.sqrt(mean_squared_error(all_targets, predictions)),
        'mae': mean_absolute_error(all_targets, predictions),
        'r2': r2_score(all_targets, predictions)
    }

    return predictions_df, metrics


def _derived_output_path(output_path, suffix):
    """Return ``output_path`` with its extension replaced by ``suffix``."""
    base, _ = os.path.splitext(output_path)
    return base + suffix


def _save_and_show_figure(output_path, suffix, description):
    """Save the current figure next to ``output_path`` (if given), then show it."""
    if output_path:
        figure_path = _derived_output_path(output_path, suffix)
        plt.savefig(figure_path)
        print(f"{description} saved to {figure_path}")

    plt.show()


def run_new_listing_prediction(train_path, test_path, output_path=None, use_validation=True, val_split=0.2,
                               k_neighbors=5, epochs=20, batch_size=64, learning_rate=0.001):
    """Run the entire pipeline for new listing price prediction.

    Loads both CSV files, adds spatial features (cached next to the inputs),
    optionally holds out the most recent part of the training data for
    validation, trains the model, predicts on the test set, prints metrics
    and draws diagnostic plots.

    Args:
        train_path: CSV with the training rows (needs a ``date`` column).
        test_path: CSV with the test rows (needs a ``date`` column).
        output_path: Optional path of the predictions CSV. The model file is
            written to the same directory; metrics and plots are written next
            to it using the same base name.
        use_validation: Hold out the latest ``val_split`` share of the
            training data for validation.
        val_split: Fraction of training rows used for validation.
        k_neighbors: Neighbor count for spatial features and edge indices.
        epochs: Training epochs.
        batch_size: Batch size for training and prediction.
        learning_rate: Optimizer learning rate.

    Returns:
        Tuple ``(model, predictions_df, metrics)``.
    """
    print("Starting new listing price prediction pipeline...")

    # Load data
    print("Loading data...")
    train_data = pd.read_csv(train_path)
    test_data = pd.read_csv(test_path)

    # Convert date columns to datetime
    train_data['date'] = pd.to_datetime(train_data['date'])
    test_data['date'] = pd.to_datetime(test_data['date'])

    # Create enhanced spatial features with caching
    train_cache_path = get_cache_path(train_path, prefix="spatial_features_train")
    test_cache_path = get_cache_path(test_path, prefix="spatial_features_test")

    train_data = create_enhanced_spatial_features(
        train_data,
        k_neighbors=k_neighbors,
        cache_path=train_cache_path
    )

    # NOTE(review): the test features are derived from the test set's own
    # `price` column - confirm this is intended for listings that are
    # supposed to have no price history.
    test_data = create_enhanced_spatial_features(
        test_data,
        k_neighbors=k_neighbors,
        cache_path=test_cache_path
    )

    # Split train data into train and validation if requested
    if use_validation:
        # Sort by date so the hold-out is the most recent data
        train_data = train_data.sort_values('date')

        val_size = int(len(train_data) * val_split)
        validation_data = train_data.tail(val_size)
        train_data = train_data.head(len(train_data) - val_size)
        print(f"Training data size: {len(train_data)}")
        print(f"Validation data size: {len(validation_data)}")
    else:
        validation_data = None

    # Train model
    model, training_history = train_new_listing_model(
        train_data=train_data,
        validation_data=validation_data,
        epochs=epochs,
        batch_size=batch_size,
        learning_rate=learning_rate,
        k_neighbors=k_neighbors
    )

    # Save the trained model
    if output_path:
        # dirname is '' for a bare file name, and os.makedirs('') raises
        model_dir = os.path.dirname(output_path) or '.'
        os.makedirs(model_dir, exist_ok=True)

        model_path = os.path.join(model_dir, "new_listing_model.pt")
        torch.save(model.state_dict(), model_path)
        print(f"Model saved to {model_path}")

    # Make predictions
    predictions_df, metrics = predict_new_listings(
        model=model,
        test_data=test_data,
        batch_size=batch_size,
        k_neighbors=k_neighbors
    )

    # Print metrics
    print("\n=== New Listing Price Prediction Results ===")
    print(f"RMSE: {metrics['rmse']:.4f}")
    print(f"MAE: {metrics['mae']:.4f}")
    print(f"R²: {metrics['r2']:.4f}")

    # Calculate MAPE (the epsilon avoids division by zero)
    mape = np.mean(np.abs(predictions_df['error'] / (predictions_df['price'] + 1e-8))) * 100
    print(f"MAPE: {mape:.2f}%")

    # Save results if output path is provided
    if output_path:
        predictions_df.to_csv(output_path, index=False)
        print(f"Results saved to {output_path}")

        # Metrics go to a separate file. The name is derived with splitext:
        # str.replace('.csv', ...) returns the unchanged path for a non-.csv
        # output and would overwrite the predictions file.
        metrics_path = _derived_output_path(output_path, '_metrics.csv')
        pd.DataFrame([{
            'rmse': metrics['rmse'],
            'mae': metrics['mae'],
            'r2': metrics['r2'],
            'mape': mape
        }]).to_csv(metrics_path, index=False)
        print(f"Metrics saved to {metrics_path}")

    # Plot training history
    if training_history['training_losses']:
        plt.figure(figsize=(10, 6))
        plt.plot(training_history['training_losses'], label='Training Loss')

        if training_history['validation_losses']:
            plt.plot(training_history['validation_losses'], label='Validation Loss')

        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Training History')
        plt.legend()
        plt.grid(True)

        _save_and_show_figure(output_path, '_training_history.png', 'Training history plot')

    # Create visualization of actual vs predicted prices
    plt.figure(figsize=(10, 8))

    # Sample a subset if the dataset is large
    if len(predictions_df) > 1000:
        plot_data = predictions_df.sample(1000)
    else:
        plot_data = predictions_df

    plt.scatter(plot_data['price'], plot_data['predicted'], alpha=0.5)

    # Add perfect prediction line
    min_val = min(plot_data['price'].min(), plot_data['predicted'].min())
    max_val = max(plot_data['price'].max(), plot_data['predicted'].max())
    plt.plot([min_val, max_val], [min_val, max_val], 'r--')

    plt.xlabel('Actual Price')
    plt.ylabel('Predicted Price')
    plt.title('New Listing Price Prediction')
    plt.grid(True)

    _save_and_show_figure(output_path, '_scatter_plot.png', 'Scatter plot')

    # Create heatmap of spatial errors
    coordinate_cols = ['latitude', 'longitude']
    if all(col in test_data.columns for col in coordinate_cols):
        plt.figure(figsize=(12, 10))

        # predictions_df carries no coordinates, so take them from test_data.
        # Rows are aligned by position, the same way the predictions were
        # attached to the frame.
        error_map = test_data[coordinate_cols].copy()
        error_map['abs_error'] = predictions_df['abs_error'].to_numpy()

        # Take a sample for better visualization
        if len(error_map) > 2000:
            error_map = error_map.sample(2000)

        # Create a scatter plot with error as color
        scatter = plt.scatter(
            error_map['longitude'],
            error_map['latitude'],
            c=error_map['abs_error'],
            cmap='viridis',
            alpha=0.7,
            s=30
        )

        plt.colorbar(scatter, label='Absolute Error')
        plt.xlabel('Longitude')
        plt.ylabel('Latitude')
        plt.title('Spatial Distribution of Prediction Errors')
        plt.grid(True)

        _save_and_show_figure(output_path, '_error_heatmap.png', 'Error heatmap')

    # Create feature importance plot using correlation analysis
    feature_importance = analyze_feature_importance(predictions_df, test_data)

    if feature_importance is not None and len(feature_importance) > 0:
        # Plot top 20 features or fewer if not available
        n_features = min(20, len(feature_importance))
        plt.figure(figsize=(12, 8))

        sns.barplot(
            x='importance',
            y='feature',
            data=feature_importance.head(n_features),
            palette='viridis'
        )

        plt.title('Feature Importance (Correlation with Target)')
        plt.xlabel('Absolute Correlation with Price')
        plt.ylabel('Feature')
        plt.tight_layout()

        _save_and_show_figure(output_path, '_feature_importance.png', 'Feature importance plot')

    return model, predictions_df, metrics


def analyze_feature_importance(predictions_df, test_data):
    """
    Analyze feature importance using correlation with target value

    Importance of a feature is the absolute Pearson correlation between that
    column and the 'price' column, computed over the rows of test_data whose
    'listing_id' also appears in predictions_df.

    Args:
        predictions_df: DataFrame with predictions (must contain 'listing_id')
        test_data: Original test data with features (must contain 'listing_id';
            'price' is required for any correlation to be computed)

    Returns:
        DataFrame with columns 'feature' and 'importance', sorted by
        importance in descending order, or None when there are no feature
        columns or no correlation could be computed
    """
    # Restrict the test data to the listings that actually received a prediction
    listing_ids = predictions_df['listing_id'].values
    matched_test_data = test_data[test_data['listing_id'].isin(listing_ids)]

    # Get features (excluding non-feature columns)
    non_feature_cols = ('listing_id', 'date', 'price')
    feature_cols = [col for col in matched_test_data.columns if col not in non_feature_cols]

    if not feature_cols:
        print("No features found for importance analysis")
        return None

    # Without a target column, or with fewer than two observations, every
    # correlation is undefined; report that up front instead of failing
    # (or warning) once per feature.
    if 'price' not in matched_test_data.columns or len(matched_test_data) < 2:
        print("No valid correlations found")
        return None

    # The target is the same for every feature, so extract it once
    price_values = matched_test_data['price'].values

    correlations = []

    # Zero-variance columns make corrcoef divide by zero; the resulting NaN is
    # filtered below, so the floating-point warnings carry no information.
    with np.errstate(divide='ignore', invalid='ignore'):
        for col in feature_cols:
            column = matched_test_data[col]
            try:
                # Text, datetime and other non-numeric columns have no
                # meaningful Pearson correlation; skip them explicitly.
                if not pd.api.types.is_numeric_dtype(column):
                    continue
                corr = np.abs(np.corrcoef(column.values, price_values)[0, 1])
            except (TypeError, ValueError):
                # Best-effort: skip columns numpy cannot correlate (e.g. values
                # that cannot be promoted to float, or a shape mismatch), but
                # let unrelated errors and interrupts propagate.
                continue

            if not np.isnan(corr):
                correlations.append({
                    'feature': col,
                    'importance': corr
                })

    if not correlations:
        print("No valid correlations found")
        return None

    # Create DataFrame and sort by importance
    importance_df = pd.DataFrame(correlations)
    importance_df = importance_df.sort_values('importance', ascending=False)

    return importance_df

def compare_model_performance(model_results):
    """
    Compare performance across multiple models

    Builds one row per model with its RMSE, MAE, R² and MAPE, then shows two
    bar charts: a grouped chart of the error metrics and a separate chart of
    R². When model_results is empty, no charts are drawn.

    Args:
        model_results: Dictionary mapping model names to their results
            dictionaries. Each results dictionary must have a 'metrics' entry
            providing 'rmse', 'mae' and 'r2'. MAPE is read from the top-level
            'mape' key when present, otherwise from metrics['mape'], and
            defaults to 0 when neither exists.

    Returns:
        DataFrame with comparison metrics (columns 'Model', 'RMSE', 'MAE',
        'R²', 'MAPE (%)'); empty, but with those columns, when no results
        were given
    """
    columns = ['Model', 'RMSE', 'MAE', 'R²', 'MAPE (%)']
    comparison = []

    for model_name, results in model_results.items():
        metrics = results['metrics']

        # Every other metric lives under 'metrics'; accept MAPE from there as
        # well so it is not silently reported as 0 when it is stored alongside
        # them. A top-level 'mape' keeps precedence for existing callers.
        if 'mape' in results:
            mape = results['mape']
        else:
            mape = metrics.get('mape', 0)

        comparison.append({
            'Model': model_name,
            'RMSE': metrics['rmse'],
            'MAE': metrics['mae'],
            'R²': metrics['r2'],
            'MAPE (%)': mape
        })

    # Passing the columns explicitly keeps the frame well-formed even when
    # there are no rows.
    comparison_df = pd.DataFrame(comparison, columns=columns)

    if comparison_df.empty:
        # Nothing to plot; melting a frame without rows/columns would fail.
        print("No model results to compare")
        return comparison_df

    # Create comparison plot
    plt.figure(figsize=(12, 8))

    # Reshape to long format so each error metric becomes its own bar group
    comparison_df_melted = pd.melt(
        comparison_df,
        id_vars=['Model'],
        value_vars=['RMSE', 'MAE', 'MAPE (%)'],
        var_name='Metric',
        value_name='Value'
    )

    sns.barplot(
        data=comparison_df_melted,
        x='Model',
        y='Value',
        hue='Metric'
    )

    plt.title('Model Performance Comparison')
    plt.ylabel('Error Value (lower is better)')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()

    # R² gets its own chart: it is a score (higher is better), not an error
    plt.figure(figsize=(10, 6))

    sns.barplot(
        data=comparison_df,
        x='Model',
        y='R²',
        palette='viridis'
    )

    plt.title('Model R² Comparison (higher is better)')
    plt.ylabel('R² Score')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()

    return comparison_df

if __name__ == "__main__":
    # Example usage: run the full new-listing pipeline on the thesis subset.
    # The paths below are absolute Windows paths on the author's machine;
    # adjust them before running anywhere else.
    # NOTE(review): the run output recorded after this cell stops in epoch 1
    # with "TypeError: NewListingPricePredictor.forward() takes 3 positional
    # arguments but 6 were given" - presumably the training loop passes more
    # positional tensors than forward() accepts; confirm and fix there.
    train_path = r"C:\Users\mvk\Documents\DATA_school\thesis\Subset\top_price_changers_subset\train.csv"
    test_path = r"C:\Users\mvk\Documents\DATA_school\thesis\Subset\top_price_changers_subset\test_feb.csv"
    output_path = r"C:\Users\mvk\Documents\DATA_school\thesis\Output\NN\new_listing_predictions.csv"
    
    # Make sure the output directory exists (no error if it is already there)
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    
    # Run the complete pipeline; it returns the model, the predictions and
    # the evaluation metrics
    model, predictions, metrics = run_new_listing_prediction(
        train_path=train_path,
        test_path=test_path,
        output_path=output_path,
        k_neighbors=5  # Reduce neighbor count to avoid issues with small datasets
    )

Starting new listing price prediction pipeline...
Loading data...
Loading spatial features from cache: C:\Users\mvk\Documents\DATA_school\thesis\Subset\top_price_changers_subset\cache\spatial_features_train_781ee4073c.pkl
Cache loaded successfully
Loading spatial features from cache: C:\Users\mvk\Documents\DATA_school\thesis\Subset\top_price_changers_subset\cache\spatial_features_test_9f5884e967.pkl
Cache loaded successfully
Training data size: 1312312
Validation data size: 328077
Property features dimension: 22
Amenities features dimension: 20
Using device: cpu


Epoch 1/20:   0%|          | 0/20505 [00:00<?, ?it/s]


TypeError: NewListingPricePredictor.forward() takes 3 positional arguments but 6 were given