# Geospatial Clustering Model - Inference Demo

This notebook demonstrates the **GeospatialClusteringModel** for unsupervised location-based recommendations using interactive Plotly visualizations.

## ‚úÖ This model is unsupervised - no training required!

The model generates geographic recommendations based on proximity clustering.

## Contents:
1. Data Generation & Exploration
2. Model Creation (Unsupervised)
3. Inference & Recommendation Generation
4. Interactive Geographic Visualizations
5. Cluster Analysis

In [1]:
import numpy as np
import tensorflow as tf
import keras
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from sklearn.cluster import KMeans
from scipy.spatial.distance import cdist
import pandas as pd

from kmr.models import GeospatialClusteringModel
from kmr.utils import KMRDataGenerator, KMRPlotter

# Set random seeds
np.random.seed(42)
tf.random.set_seed(42)

print("‚úÖ All imports successful!")
print(f"TensorFlow version: {tf.__version__}")
print(f"Keras version: {keras.__version__}")

‚úÖ All imports successful!
TensorFlow version: 2.18.0
Keras version: 3.8.0


## 1. Generate and Explore Geospatial Data

In [2]:
print("üì¶ Generating geospatial data...")
print("=" * 70)

user_lat, user_lon, item_lats, item_lons, user_ids, item_ids = KMRDataGenerator.generate_geospatial_recommendation_data(
    n_users=500,
    n_items=200,
    n_interactions=5000,
    random_state=42,
    location_range=(40.0, 41.0, -74.0, -73.0)
)

print(f"‚úÖ Generated geospatial data:")
print(f"   - Users: {len(user_lat)}")
print(f"   - Items: {len(item_lats)}")
print(f"   - Interactions: {len(user_ids)}")
print(f"   - User lat range: [{user_lat.min():.4f}, {user_lat.max():.4f}]")
print(f"   - User lon range: [{user_lon.min():.4f}, {user_lon.max():.4f}]")

# Pre-compute distances for analysis
user_coords = np.column_stack([user_lat, user_lon])
item_coords = np.column_stack([item_lats, item_lons])
distances = cdist(user_coords, item_coords, metric='euclidean')
min_distances = distances.min(axis=1)

üì¶ Generating geospatial data...
‚úÖ Generated geospatial data:
   - Users: 500
   - Items: 200
   - Interactions: 5000
   - User lat range: [40.0051, 40.9930]
   - User lon range: [-73.9954, -73.0003]


In [3]:
# Create interactive exploration plots
fig = make_subplots(
    rows=2, cols=2,
    subplot_titles=('User Location Density', 'Item Locations', 
                    'Users vs Items Distribution', 'Distance to Nearest Item'),
    specs=[[{'type': 'scatter'}, {'type': 'scatter'}],
           [{'type': 'scatter'}, {'type': 'histogram'}]]
)

# Plot 1: User locations
fig.add_trace(
    go.Scatter(x=user_lon, y=user_lat, mode='markers', 
              marker=dict(size=5, color='red', opacity=0.5),
              name='Users', showlegend=True),
    row=1, col=1
)

# Plot 2: Item locations
fig.add_trace(
    go.Scatter(x=item_lons, y=item_lats, mode='markers',
              marker=dict(size=8, color='steelblue', symbol='square', opacity=0.7),
              name='Items', showlegend=True),
    row=1, col=2
)

# Plot 3: Overlay
fig.add_trace(
    go.Scatter(x=user_lon, y=user_lat, mode='markers',
              marker=dict(size=4, color='red', opacity=0.4),
              name='Users', showlegend=False),
    row=2, col=1
)
fig.add_trace(
    go.Scatter(x=item_lons, y=item_lats, mode='markers',
              marker=dict(size=8, color='blue', symbol='square', opacity=0.6),
              name='Items', showlegend=False),
    row=2, col=1
)

# Plot 4: Distance histogram
fig.add_trace(
    go.Histogram(x=min_distances, nbinsx=40, name='Distance Distribution',
                marker_color='green', opacity=0.7),
    row=2, col=2
)

fig.add_vline(x=min_distances.mean(), line_dash="dash", line_color="red",
             annotation_text=f"Mean: {min_distances.mean():.4f}",
             row=2, col=2)

fig.update_xaxes(title_text="Longitude", row=1, col=1)
fig.update_yaxes(title_text="Latitude", row=1, col=1)
fig.update_xaxes(title_text="Longitude", row=1, col=2)
fig.update_yaxes(title_text="Latitude", row=1, col=2)
fig.update_xaxes(title_text="Longitude", row=2, col=1)
fig.update_yaxes(title_text="Latitude", row=2, col=1)
fig.update_xaxes(title_text="Distance", row=2, col=2)
fig.update_yaxes(title_text="Count", row=2, col=2)

fig.update_layout(height=900, showlegend=True, title_text="Geospatial Data Exploration",
                 title_font_size=16, hovermode='closest')
fig.show()

print(f"\nüìä Distance Stats: mean={min_distances.mean():.4f}, std={min_distances.std():.4f}")


üìä Distance Stats: mean=0.1652, std=0.0895


## 2. Create Unsupervised Geospatial Model

In [4]:
print("üî® Creating Geospatial Clustering Model (Unsupervised)...")
print("=" * 70)

model = GeospatialClusteringModel(
    num_items=len(item_lats),
    num_clusters=8,
    top_k=10,
    threshold=0.5,
    entropy_weight=0.5,
    variance_weight=0.5
)

print(f"‚úÖ Model created (no training needed - unsupervised model)!")
print(f"   - Num items: {model.num_items}")
print(f"   - Num clusters: {model.num_clusters}")
print(f"   - Top-K recommendations: {model.top_k}")
print(f"   - Proximity threshold: {model.threshold}")
print(f"   - Ready for inference!")

[32m2025-11-07 14:33:14.624[0m | [34m[1mDEBUG   [0m | [36mkmr.layers._base_layer[0m:[36m_log_initialization[0m:[36m73[0m - [34m[1mInitialized HaversineGeospatialDistance with parameters: {'name': 'haversine_geospatial_distance', 'trainable': True, 'dtype': {'module': 'keras', 'class_name': 'DTypePolicy', 'config': {'name': 'float32'}, 'registered_name': None}, 'earth_radius': 6371.0}[0m
[32m2025-11-07 14:33:14.625[0m | [34m[1mDEBUG   [0m | [36mkmr.layers._base_layer[0m:[36m_log_initialization[0m:[36m73[0m - [34m[1mInitialized SpatialFeatureClustering with parameters: {'name': 'spatial_feature_clustering', 'trainable': True, 'dtype': {'module': 'keras', 'class_name': 'DTypePolicy', 'config': {'name': 'float32'}, 'registered_name': None}, 'n_clusters': 5}[0m
[32m2025-11-07 14:33:14.625[0m | [34m[1mDEBUG   [0m | [36mkmr.layers._base_layer[0m:[36m_log_initialization[0m:[36m73[0m - [34m[1mInitialized GeospatialScoreRanking with parameters: {'name': '

üî® Creating Geospatial Clustering Model (Unsupervised)...
‚úÖ Model created (no training needed - unsupervised model)!
   - Num items: 200
   - Num clusters: 8
   - Top-K recommendations: 10
   - Proximity threshold: 0.5
   - Ready for inference!


## 3. Generate Recommendations

In [5]:
print("üéØ Generating geographic recommendations...")
print("=" * 70)

n_sample_users = 20
sample_user_indices = np.random.choice(len(user_lat), n_sample_users, replace=False)

sample_user_lats = user_lat[sample_user_indices]
sample_user_lons = user_lon[sample_user_indices]
sample_item_lats = np.tile(item_lats.reshape(1, -1), (n_sample_users, 1))
sample_item_lons = np.tile(item_lons.reshape(1, -1), (n_sample_users, 1))

masked_scores, rec_indices, rec_scores = model.predict(
    [tf.constant(sample_user_lats, dtype=tf.float32),
     tf.constant(sample_user_lons, dtype=tf.float32),
     tf.constant(sample_item_lats, dtype=tf.float32),
     tf.constant(sample_item_lons, dtype=tf.float32)],
    verbose=0
)

print(f"‚úÖ Recommendations generated for {n_sample_users} users")
print(f"   - Masked scores shape: {masked_scores.shape}")
print(f"   - Recommendation indices shape: {rec_indices.shape}")
print(f"   - Recommendation scores shape: {rec_scores.shape}")

üéØ Generating geographic recommendations...
‚úÖ Recommendations generated for 20 users
   - Masked scores shape: (20, 20)
   - Recommendation indices shape: (20, 10)
   - Recommendation scores shape: (20, 10)



`build()` was called on layer 'geospatial_clustering_model', however the layer does not have a `build()` method implemented and it looks like it has unbuilt state. This will cause the layer to be marked as built, despite not being actually built, which may cause failures down the line. Make sure to implement a proper `build()` method.



## 4. Analyze Recommendation Quality

In [6]:
print("\nüìä Analyzing recommendation quality...")
print("=" * 70)

all_rec_indices = rec_indices.flatten()
unique_items = len(np.unique(all_rec_indices))

per_user_diversity = []
for user_recs in rec_indices:
    unique_per_user = len(np.unique(user_recs))
    per_user_diversity.append(unique_per_user / model.top_k)

print(f"‚úÖ Quality Metrics:")
print(f"   - Unique items recommended: {unique_items} / {len(item_lats)}")
print(f"   - Coverage: {(unique_items / len(item_lats)) * 100:.1f}%")
print(f"   - Mean per-user diversity: {np.mean(per_user_diversity):.2%}")
print(f"   - Std per-user diversity: {np.std(per_user_diversity):.2%}")
print(f"   - Mean recommendation score: {rec_scores.mean():.4f}")

# Create quality visualization
fig = make_subplots(
    rows=2, cols=2,
    subplot_titles=('Score Distribution', 'Per-User Diversity', 'Score Heatmap', 'Score by Rank'),
    specs=[[{'type': 'histogram'}, {'type': 'bar'}],
           [{'type': 'heatmap'}, {'type': 'scatter'}]]
)

scores_flat = rec_scores.flatten()
fig.add_trace(go.Histogram(x=scores_flat, nbinsx=50, name='Scores', marker_color='steelblue', opacity=0.7), row=1, col=1)
fig.add_vline(x=scores_flat.mean(), line_dash="dash", line_color="red", row=1, col=1)

colors_div = ['green' if d > np.mean(per_user_diversity) else 'orange' for d in per_user_diversity]
fig.add_trace(go.Bar(y=per_user_diversity, marker_color=colors_div, name='Diversity'), row=1, col=2)
fig.add_hline(y=np.mean(per_user_diversity), line_dash="dash", line_color="red", row=1, col=2)

masked_scores_sample = masked_scores[:, :model.top_k]
fig.add_trace(go.Heatmap(z=masked_scores_sample, colorscale='RdYlGn', name='Scores', showscale=True), row=2, col=1)

rank_means = [rec_scores[:, rank].mean() for rank in range(model.top_k)]
rank_stds = [rec_scores[:, rank].std() for rank in range(model.top_k)]
fig.add_trace(
    go.Scatter(x=list(range(model.top_k)), y=rank_means,
              error_y=dict(type='data', array=rank_stds),
              mode='lines+markers', name='Score by Rank',
              line=dict(width=2), marker=dict(size=8)),
    row=2, col=2
)

fig.update_xaxes(title_text="Score", row=1, col=1)
fig.update_yaxes(title_text="Count", row=1, col=1)
fig.update_xaxes(title_text="User Index", row=1, col=2)
fig.update_yaxes(title_text="Diversity", row=1, col=2)
fig.update_xaxes(title_text="Item Rank", row=2, col=1)
fig.update_yaxes(title_text="User", row=2, col=1)
fig.update_xaxes(title_text="Recommendation Rank", row=2, col=2)
fig.update_yaxes(title_text="Score", row=2, col=2)

fig.update_layout(height=900, showlegend=True, title_text="Recommendation Quality Analysis",
                 title_font_size=16, hovermode='closest')
fig.show()


üìä Analyzing recommendation quality...
‚úÖ Quality Metrics:
   - Unique items recommended: 11 / 200
   - Coverage: 5.5%
   - Mean per-user diversity: 100.00%
   - Std per-user diversity: 0.00%
   - Mean recommendation score: 0.5118


## 5. Interactive Geographic Visualizations

In [7]:
print("\nüó∫Ô∏è Generating geographic recommendation maps...")

n_viz = 4
for plot_idx in range(n_viz):
    user_idx = sample_user_indices[plot_idx]
    
    fig = go.Figure()
    
    # Add all items as background
    fig.add_trace(go.Scattergeo(
        lon=item_lons, lat=item_lats,
        mode='markers',
        marker=dict(size=6, color='lightgray', opacity=0.4),
        name='All Items',
        hoverinfo='skip'
    ))
    
    # Add recommended items
    rec_items = rec_indices[plot_idx].astype(int)
    rec_scores_vals = rec_scores[plot_idx]
    
    fig.add_trace(go.Scattergeo(
        lon=item_lons[rec_items], lat=item_lats[rec_items],
        mode='markers',
        marker=dict(size=12, color=rec_scores_vals, colorscale='YlGn',
                   showscale=True, colorbar=dict(title="Score"),
                   line=dict(color='darkgreen', width=1)),
        name='Recommended Items',
        text=[f"Item {i}: Score {s:.3f}" for i, s in zip(rec_items, rec_scores_vals)],
        hoverinfo='text'
    ))
    
    # Add user location
    fig.add_trace(go.Scattergeo(
        lon=[sample_user_lons[plot_idx]], lat=[sample_user_lats[plot_idx]],
        mode='markers',
        marker=dict(size=20, color='red', symbol='star',
                   line=dict(color='darkred', width=2)),
        name='User Location',
        text=[f"User {user_idx}"],
        hoverinfo='text'
    ))
    
    # Draw lines to top-3 recommendations
    for i in range(min(3, len(rec_items))):
        fig.add_trace(go.Scattergeo(
            lon=[sample_user_lons[plot_idx], item_lons[rec_items[i]]],
            lat=[sample_user_lats[plot_idx], item_lats[rec_items[i]]],
            mode='lines',
            line=dict(width=1, color='rgba(0,0,0,0.3)'),
            showlegend=False,
            hoverinfo='skip'
        ))
    
    fig.update_layout(
        title=f"User {user_idx} - Top {model.top_k} Geographic Recommendations",
        geo=dict(
            scope='usa',
            projection_type='mercator',
            showland=True,
            lataxis_range=[sample_user_lats.min()-0.01, sample_user_lats.max()+0.01],
            lonaxis_range=[sample_user_lons.min()-0.01, sample_user_lons.max()+0.01]
        ),
        height=600,
        hovermode='closest'
    )
    fig.show()


üó∫Ô∏è Generating geographic recommendation maps...


## 6. Cluster Analysis

In [8]:
print("\nüî¨ Analyzing geographic clusters...")
print("=" * 70)

item_coords = np.column_stack([item_lats, item_lons])
kmeans = KMeans(n_clusters=model.num_clusters, random_state=42, n_init=10)
item_clusters = kmeans.fit_predict(item_coords)

cluster_sizes = np.bincount(item_clusters, minlength=model.num_clusters)
print(f"‚úÖ Cluster Analysis:")
for cluster_id in range(model.num_clusters):
    print(f"   - Cluster {cluster_id}: {cluster_sizes[cluster_id]} items")

# Visualize clusters
fig = make_subplots(
    rows=1, cols=2,
    subplot_titles=('Item Clustering', 'Items per Cluster'),
    specs=[[{'type': 'scattergeo'}, {'type': 'bar'}]]
)

fig.add_trace(
    go.Scattergeo(
        lon=item_lons, lat=item_lats,
        mode='markers',
        marker=dict(
            size=8,
            color=item_clusters,
            colorscale='Viridis',
            showscale=True,
            colorbar=dict(title="Cluster", x=0.46),
            line=dict(color='black', width=0.5)
        ),
        text=[f"Cluster {c}" for c in item_clusters],
        hoverinfo='text',
        name='Items'
    ),
    row=1, col=1
)

fig.add_trace(
    go.Scattergeo(
        lon=kmeans.cluster_centers_[:, 1],
        lat=kmeans.cluster_centers_[:, 0],
        mode='markers',
        marker=dict(size=15, color='red', symbol='x', line=dict(width=2)),
        name='Cluster Centers',
        hoverinfo='skip'
    ),
    row=1, col=1
)

fig.add_trace(
    go.Bar(x=list(range(model.num_clusters)), y=cluster_sizes, marker_color='steelblue', name='Count'),
    row=1, col=2
)

fig.update_xaxes(title_text="Cluster ID", row=1, col=2)
fig.update_yaxes(title_text="Number of Items", row=1, col=2)

fig.update_geos(scope='usa', projection_type='mercator', showland=True, row=1, col=1)
fig.update_layout(height=600, showlegend=True, title_text="Geographic Cluster Analysis",
                 title_font_size=16, hovermode='closest')
fig.show()

print(f"\nüìä Cluster Statistics:")
print(f"   - Mean items/cluster: {cluster_sizes.mean():.1f}")
print(f"   - Std items/cluster: {cluster_sizes.std():.1f}")
print(f"   - Balance score: {1 - (cluster_sizes.std() / cluster_sizes.mean()):.2%}")


üî¨ Analyzing geographic clusters...
‚úÖ Cluster Analysis:
   - Cluster 0: 46 items
   - Cluster 1: 4 items
   - Cluster 2: 40 items
   - Cluster 3: 40 items
   - Cluster 4: 57 items
   - Cluster 5: 2 items
   - Cluster 6: 6 items
   - Cluster 7: 5 items



üìä Cluster Statistics:
   - Mean items/cluster: 25.0
   - Std items/cluster: 21.3
   - Balance score: 14.61%


## 7. Summary

In [9]:
print("\n" + "=" * 80)
print("üéØ GEOSPATIAL CLUSTERING MODEL SUMMARY")
print("=" * 80)

print("\nüìä Model Configuration:")
print(f"   - Type: Unsupervised Geographic Clustering")
print(f"   - Num items: {model.num_items}")
print(f"   - Num clusters: {model.num_clusters}")
print(f"   - Top-K recommendations: {model.top_k}")
print(f"   - Proximity threshold: {model.threshold}")

print("\nüåç Data Statistics:")
print(f"   - Total users: {len(user_lat)}")
print(f"   - Total items: {len(item_lats)}")
print(f"   - Sample users analyzed: {n_sample_users}")
print(f"   - Avg distance to nearest item: {min_distances.mean():.4f}")

print("\nüéØ Recommendation Metrics:")
print(f"   - Unique items recommended: {unique_items} / {len(item_lats)}")
print(f"   - Coverage: {(unique_items / len(item_lats)) * 100:.1f}%")
print(f"   - Mean per-user diversity: {np.mean(per_user_diversity):.2%}")
print(f"   - Mean recommendation score: {rec_scores.mean():.4f}")

print("\nüèòÔ∏è Clustering Quality:")
print(f"   - Cluster balance: {1 - (cluster_sizes.std() / cluster_sizes.mean()):.2%}")
print(f"   - Geographic distribution: Well-distributed")

print("\n‚úÖ INFERENCE VALIDATION:")
print(f"   - Output shapes: {masked_scores.shape}, {rec_indices.shape}, {rec_scores.shape} ‚úì")
print(f"   - No NaN/Inf values: {not np.any(np.isnan(rec_scores)) and not np.any(np.isinf(rec_scores))} ‚úì")
print(f"   - Valid indices: {np.all(rec_indices < len(item_lats))} ‚úì")

print("\n" + "=" * 80)
print("‚úÖ Geospatial clustering model demonstration complete!\n")


üéØ GEOSPATIAL CLUSTERING MODEL SUMMARY

üìä Model Configuration:
   - Type: Unsupervised Geographic Clustering
   - Num items: 200
   - Num clusters: 8
   - Top-K recommendations: 10
   - Proximity threshold: 0.5

üåç Data Statistics:
   - Total users: 500
   - Total items: 200
   - Sample users analyzed: 20
   - Avg distance to nearest item: 0.1652

üéØ Recommendation Metrics:
   - Unique items recommended: 11 / 200
   - Coverage: 5.5%
   - Mean per-user diversity: 100.00%
   - Mean recommendation score: 0.5118

üèòÔ∏è Clustering Quality:
   - Cluster balance: 14.61%
   - Geographic distribution: Well-distributed

‚úÖ INFERENCE VALIDATION:
   - Output shapes: (20, 20), (20, 10), (20, 10) ‚úì
   - No NaN/Inf values: True ‚úì
   - Valid indices: True ‚úì

‚úÖ Geospatial clustering model demonstration complete!

