In [1]:
import pandas as pd
import requests
from bs4 import BeautifulSoup
import osmnx as ox
import numpy as np
import geopandas as gpd
from shapely.geometry import box
from pyproj import Transformer
import matplotlib.pyplot as plt
import re
from rasterio.features import rasterize
from rasterio.transform import from_origin
import os
import json
from pathlib import Path

class BoundingBox:
    """Geographic extent described by its longitude and latitude limits."""

    def __init__(self, minLon, maxLon, minLat, maxLat):
        # The four limits are stored exactly as given; nothing is validated
        # or reordered.
        self.minLon, self.maxLon = minLon, maxLon
        self.minLat, self.maxLat = minLat, maxLat

    def __str__(self):
        """Return the limits formatted with six decimal places."""
        lon_range = f"lon: {self.minLon:.6f} to {self.maxLon:.6f}"
        lat_range = f"lat: {self.minLat:.6f} to {self.maxLat:.6f}"
        return f"BoundingBox({lon_range}, {lat_range})"

class CityRasterizer:
    def __init__(self, place_name=None, country=None, continent=None, bounding_box=None, grid_width=1000, grid_height=1000, gdp=None, population=None):
        """
        Initialize the city rasterizer with either a place name or bounding box.
        
        Args:
            place_name (str): Name of the place to rasterize (e.g., "Munich, Germany")
            country (str): Name of the country
            continent (str): Name of the continent
            bounding_box (BoundingBox): Geographic bounding box (alternative to place_name)
            grid_width (int): Number of cells in x-direction
            grid_height (int): Number of cells in y-direction
            gdp (float): Nominal GDP in USD billions
            population (int): Population of the city
        """
        if place_name is None and bounding_box is None:
            raise ValueError("Either place_name or bounding_box must be provided")
        
        self.place_name = place_name
        self.country = country
        self.continent = continent
        self.bounding_box = bounding_box
        self.grid_width = grid_width
        self.grid_height = grid_height
        self.grid = None
        self.buildings = None
        self.gdp = gdp
        self.population = population
        self.transformer = None
        self.transform = None
        self.cell_width = None
        self.cell_height = None
        self.x_min = None
        self.y_min = None
        self.x_max = None
        self.y_max = None
        
        if self.place_name and self.bounding_box is None:
            self._get_place_bounds()
        
        if self.bounding_box:
            self._initialize_grid()
            self._process_buildings()

    def _get_place_bounds(self):
        """Geocode ``place_name`` and store its extent in ``self.bounding_box``.

        Any geocoding error is printed and then re-raised to the caller.
        """
        print(f"Getting bounds for {self.place_name}...")
        try:
            # total_bounds is ordered (min x, min y, max x, max y).
            min_lon, min_lat, max_lon, max_lat = ox.geocode_to_gdf(self.place_name).total_bounds
            self.bounding_box = BoundingBox(
                minLon=min_lon,
                maxLon=max_lon,
                minLat=min_lat,
                maxLat=max_lat,
            )
            print(f"Bounding box: {self.bounding_box}")
        except Exception as e:
            print(f"Error geocoding {self.place_name}: {e}")
            raise

    def _initialize_grid(self):
        """Project the bounding box to EPSG:3857 and set up an empty raster.

        Fills in the projected extent (``x_min`` .. ``y_max``), the cell size,
        the affine ``transform`` anchored at the north-west corner and an
        all-zero ``uint8`` grid of shape ``(grid_height, grid_width)``.
        """
        bbox = self.bounding_box
        self.transformer = Transformer.from_crs("EPSG:4326", "EPSG:3857", always_xy=True)
        project = self.transformer.transform

        # South-west and north-east corners in Web Mercator coordinates.
        self.x_min, self.y_min = project(bbox.minLon, bbox.minLat)
        self.x_max, self.y_max = project(bbox.maxLon, bbox.maxLat)

        extent_x = self.x_max - self.x_min
        extent_y = self.y_max - self.y_min
        self.cell_width = extent_x / self.grid_width
        self.cell_height = extent_y / self.grid_height

        self.transform = from_origin(
            west=self.x_min,
            north=self.y_max,
            xsize=self.cell_width,
            ysize=self.cell_height,
        )

        grid_shape = (self.grid_height, self.grid_width)
        self.grid = np.zeros(grid_shape, dtype=np.uint8)

        print(f"Grid initialized: {self.grid_width}x{self.grid_height}")
        print(f"Cell size: {self.cell_width:.2f}m x {self.cell_height:.2f}m")

    def _fetch_osm_buildings(self):
        """Download all OSM features tagged ``building`` for the configured area.

        The place name is used when available; otherwise the bounding box is
        queried. The box is sent as a lon/lat polygon rather than through
        ``features_from_bbox``, whose positional argument order differs
        between OSMnx releases.

        Returns:
            geopandas.GeoDataFrame: The building features, or an empty
            GeoDataFrame when nothing usable was returned or the request
            failed (the error is printed, not raised).
        """
        label = self.place_name or 'custom bounding box'
        print(f"Fetching building data for {label}...")

        try:
            if self.place_name:
                buildings = ox.features_from_place(self.place_name, tags={'building': True})
            else:
                bbox = self.bounding_box
                # shapely.box takes (minx, miny, maxx, maxy) = (lon, lat, lon, lat).
                search_area = box(bbox.minLon, bbox.minLat, bbox.maxLon, bbox.maxLat)
                buildings = ox.features_from_polygon(search_area, tags={'building': True})

            # Check the type first so `.empty` is only read on a real frame.
            if not isinstance(buildings, gpd.GeoDataFrame) or buildings.empty:
                print(f"No valid building data returned for {label}.")
                return gpd.GeoDataFrame()

            if 'geometry' not in buildings.columns or buildings.geometry.isna().all():
                print(f"Invalid or missing geometry data for {label}.")
                return gpd.GeoDataFrame()

            return buildings

        except Exception as e:
            print(f"Error fetching buildings for {label}: {e}")
            return gpd.GeoDataFrame()

    def _process_buildings(self, min_buildings=1000):
        """Fetch, filter, rasterize and save the buildings of the configured area.

        The area is skipped (and all data released) when fewer than
        ``min_buildings`` features are found, either in total or inside the
        grid extent, or when any step raises. Errors are printed, not raised.

        On success the data is saved to the default output directory and the
        rasterized grid stays on the instance, so ``plot()``, ``get_grid()``
        and ``save_data(output_dir=...)`` can still be used afterwards. Only
        the building geometries are released.

        Args:
            min_buildings (int): Minimum number of buildings required to process
        """
        label = self.place_name or 'custom bounding box'
        self.buildings = self._fetch_osm_buildings()

        if self.buildings.empty:
            print(f"No buildings found for {label}. Skipping.")
            self._cleanup()
            return

        succeeded = False
        try:
            building_count = len(self.buildings)
            print(f"Found {building_count} buildings in OSM data")

            if building_count < min_buildings:
                print(f"Skipping {label} - only {building_count} buildings found (needs {min_buildings}+)")
                return  # the finally clause below releases the data

            print("Converting to Web Mercator projection...")
            self.buildings = self.buildings.to_crs(epsg=3857)

            # Pad the extent by one cell so buildings straddling the border
            # are kept.
            buffer = max(self.cell_width, self.cell_height)
            bounds_box = box(self.x_min - buffer, self.y_min - buffer,
                             self.x_max + buffer, self.y_max + buffer)

            self.buildings = self.buildings[self.buildings.geometry.intersects(bounds_box)]
            filtered_count = len(self.buildings)
            print(f"Buildings within bounds: {filtered_count}")

            if filtered_count < min_buildings:
                print(f"Skipping {label} - only {filtered_count} buildings within bounds (needs {min_buildings}+)")
                return  # the finally clause below releases the data

            self._rasterize_buildings()
            self._save_data()
            succeeded = True

        except Exception as e:
            print(f"Error processing buildings for {label}: {e}")
        finally:
            if succeeded:
                # Keep the grid for later use; the geometries are the bulky
                # part and are no longer needed once rasterized.
                self.buildings = None
            else:
                self._cleanup()

    def _rasterize_buildings(self):
        """Ultra-fast rasterization using rasterio with error handling"""
        print("Performing ultra-fast rasterization with rasterio...")
        
        try:
            shapes = ((geom, 1) for geom in self.buildings.geometry if geom is not None and geom.is_valid)
            
            self.grid = rasterize(
                shapes=shapes,
                out_shape=(self.grid_height, self.grid_width),
                transform=self.transform,
                dtype=np.uint8,
                all_touched=True
            )
            
            building_cells = np.sum(self.grid)
            density = building_cells / self.grid.size
            print(f"Rasterized {building_cells} cells ({density:.2%} coverage)")
            
        except Exception as e:
            print(f"Error during rasterization for {self.place_name or 'custom bounding box'}: {e}")
            self.grid = np.zeros((self.grid_height, self.grid_width), dtype=np.uint8)

    def _save_data(self, output_dir="city_data"):
        """Save rasterization data to a city-specific subfolder in city_data"""
        if self.grid is None:
            print(f"No grid data to save for {self.place_name or 'custom bounding box'}.")
            return
        
        safe_name = "".join(c for c in (self.place_name or "CustomArea") if c.isalnum() or c in (' ', '_')).rstrip().replace(' ', '_')
        city_dir = Path(output_dir) / safe_name
        city_dir.mkdir(parents=True, exist_ok=True)
        
        building_cells = np.sum(self.grid)
        total_cells = self.grid.size
        building_density = building_cells / total_cells
        
        np.save(city_dir / "grid.npy", self.grid)
        edge_grid = self.get_edge_grid()
        np.save(city_dir / "edge_grid.npy", edge_grid)
        
        metadata = {
            "place_name": self.place_name,
            "country": self.country,
            "continent": self.continent,
            "bounding_box": {
                "minLon": self.bounding_box.minLon,
                "maxLon": self.bounding_box.maxLon,
                "minLat": self.bounding_box.minLat,
                "maxLat": self.bounding_box.maxLat
            },
            "grid_dimensions": {
                "width": self.grid_width,
                "height": self.grid_height,
                "cell_width_meters": self.cell_width if self.cell_width else 0.0,
                "cell_height_meters": self.cell_height if self.cell_height else 0.0
            },
            "economic_data": {
                "gdp_usd_billions": self.gdp,
                "population": self.population
            },
            "grid_statistics": {
                "total_cells": total_cells,
                "building_cells": int(building_cells),
                "building_density": float(building_density),
                "edge_cells": int(np.sum(edge_grid))
            }
        }
        
        # Save metadata
        with open(city_dir / "metadata.json", 'w') as f:
            json.dump(metadata, f, indent=2)
        
        print(f"Saved data to {city_dir}")

    def _cleanup(self):
        """Clear large objects to free memory"""
        self.grid = None
        self.buildings = None
        print(f"Cleaned up memory for {self.place_name or 'custom bounding box'}")

    def get_edge_grid(self):
        """Create a grid showing only building edges"""
        if self.grid is None:
            raise ValueError("No grid data available.")
        
        edge_grid = np.zeros_like(self.grid)
        
        for i in range(1, self.grid.shape[0] - 1):
            for j in range(1, self.grid.shape[1] - 1):
                if self.grid[i, j] == 1:
                    if (self.grid[i-1, j] == 0 or self.grid[i+1, j] == 0 or 
                        self.grid[i, j-1] == 0 or self.grid[i, j+1] == 0):
                        edge_grid[i, j] = 1
        
        return edge_grid

    def save_data(self, output_dir="city_data"):
        """Save rasterization data to a city-specific subfolder (public method)"""
        self._save_data(output_dir=output_dir)
        self._cleanup()

    @classmethod
    def load_data(cls, city_dir):
        """Load rasterization data from a city-specific subfolder without re-rasterizing"""
        city_dir = Path(city_dir)
        if not city_dir.exists():
            raise FileNotFoundError(f"Directory {city_dir} does not exist")

        # Load metadata
        with open(city_dir / "metadata.json", 'r') as f:
            metadata = json.load(f)

        # Create instance without running __init__
        instance = cls.__new__(cls)

        # Assign metadata fields manually
        instance.place_name = metadata.get("place_name")
        instance.country = metadata.get("country")
        instance.continent = metadata.get("continent")

        bbox = metadata["bounding_box"]
        instance.bounding_box = BoundingBox(
            minLon=bbox["minLon"],
            maxLon=bbox["maxLon"],
            minLat=bbox["minLat"],
            maxLat=bbox["maxLat"]
        )

        dims = metadata["grid_dimensions"]
        instance.grid_width = dims["width"]
        instance.grid_height = dims["height"]
        instance.cell_width = dims["cell_width_meters"]
        instance.cell_height = dims["cell_height_meters"]

        econ = metadata["economic_data"]
        instance.gdp = econ.get("gdp_usd_billions")
        instance.population = econ.get("population")

        instance.grid = np.load(city_dir / "grid.npy")

        print(f"Loaded pre-rasterized data from {city_dir}")
        return instance

    def plot(self, show_edges=False, save_arrays=False, output_dir="city_data"):
        """
        Plot the building distribution with economic data in title
        
        Args:
            show_edges (bool): Whether to show only building edges
            save_arrays (bool): Whether to save data (grids, metadata, plots)
            output_dir (str): Root directory for saved data
        """
        if self.grid is None:
            raise ValueError("No grid data available.")

        if save_arrays:
            self.save_data(output_dir=output_dir)
            return 
        
        plot_grid = self.get_edge_grid() if show_edges else self.grid
        title_suffix = " (Building Edges)" if show_edges else ""
        
        city_name = self.place_name if self.place_name else "Custom Area"
        gdp_str = f"GDP: ${self.gdp:.1f}B" if self.gdp is not None else "GDP: N/A"
        pop_str = f"Pop: {int(self.population):,}" if self.population is not None else "Pop: N/A"
        title = f"Building Distribution in {city_name}\n{gdp_str}, {pop_str}{title_suffix}"

        fig, ax = plt.subplots(figsize=(12, 12))
        im = ax.imshow(
            plot_grid,
            cmap='binary',
            interpolation='nearest',
            origin='lower',
            extent=[self.bounding_box.minLon, self.bounding_box.maxLon,
                   self.bounding_box.minLat, self.bounding_box.maxLat]
        )
        
        ax.set_title(title)
        ax.set_xlabel("Longitude")
        ax.set_ylabel("Latitude")
        
        cbar = plt.colorbar(im, ax=ax)
        cbar.set_label('Building Edge Presence' if show_edges else 'Building Presence')
        
        plt.tight_layout()
        plt.show()
        plt.close(fig)
        
        total_cells = plot_grid.size
        building_cells = np.sum(plot_grid)
        print(f"Grid statistics for {city_name}:")
        print(f"  Total cells: {total_cells}")
        print(f"  Building cells: {building_cells}")
        print(f"  Building density: {building_cells/total_cells*100:.2f}%")
        print(f"  GDP: {gdp_str}")
        print(f"  Population: {pop_str}")

    def get_grid(self):
        """Return the building grid as numpy array"""
        return self.grid.copy() if self.grid is not None else None

    def get_bounds(self):
        """Return the geographic bounds"""
        return self.bounding_box

    def get_economic_data(self):
        """Return the economic data"""
        return {'gdp': self.gdp, 'population': self.population}

class BoundingBox:
    """Longitude/latitude limits of a rectangular geographic area."""

    def __init__(self, minLon, maxLon, minLat, maxLat):
        # Values are kept as passed in, without validation.
        self.minLon, self.maxLon = minLon, maxLon
        self.minLat, self.maxLat = minLat, maxLat

    def __str__(self):
        """Return both ranges with six decimal places."""
        lon_range = f"lon: {self.minLon:.6f} to {self.maxLon:.6f}"
        lat_range = f"lat: {self.minLat:.6f} to {self.maxLat:.6f}"
        return f"BoundingBox({lon_range}, {lat_range})"


In [None]:
class CityCSVProcessor:
    """Batch driver that rasterizes every sufficiently populous city in a CSV file."""

    def __init__(self, csv_path='cities.csv', output_dir='city_data', min_population=10000):
        """
        Initialize the CSV processor with path to cities.csv and output directory

        Args:
            csv_path (str): Path to the cities.csv file
            output_dir (str): Directory to save rasterized city data
            min_population (int): Minimum population threshold for processing
        """
        self.csv_path = csv_path
        self.output_dir = output_dir
        self.min_population = min_population
        self.df = None

    def load_csv(self):
        """Load the semicolon-separated CSV and keep cities above the population threshold.

        Returns:
            bool: True on success, False if the file could not be read or
            processed (the error is printed).
        """
        try:
            self.df = pd.read_csv(self.csv_path, sep=';')
            print(f"Successfully loaded {len(self.df)} cities from {self.csv_path}")

            # The coordinates are only tidied, never required by this class,
            # so a file without that column is still usable.
            if 'Coordinates' in self.df.columns:
                self.df['Coordinates'] = self.df['Coordinates'].str.strip()
            self.df = self.df[self.df['Population'] > self.min_population]
            print(f"Found {len(self.df)} cities with population > {self.min_population}")
            return True
        except Exception as e:
            print(f"Error loading CSV file {self.csv_path}: {e}")
            return False

    def process_all_cities(self, grid_width=3000, grid_height=3000):
        """
        Process all cities in the CSV file and rasterize them using OSM boundaries

        Cities whose place name already appears in a metadata file below
        ``output_dir`` are skipped. A failure for one city is printed and
        counted as skipped; it does not stop the run.

        Args:
            grid_width (int): Number of cells in x-direction
            grid_height (int): Number of cells in y-direction
        """
        if not self.load_csv():
            return

        processed_place_names = {city['place_name'] for city in self.get_processed_cities()}
        print(f"Found {len(processed_place_names)} already processed cities in {self.output_dir}")

        success_count = 0
        skip_count = 0
        total = len(self.df)

        # Count positions explicitly: after filtering, the DataFrame index
        # labels no longer run from 0 to total - 1.
        for position, (_, row) in enumerate(self.df.iterrows(), start=1):
            try:
                city_name = row['Name']
                country = row['Country name EN']
                population = row['Population']
                place_name = f"{city_name}, {country}"

                if place_name in processed_place_names:
                    print(f"Skipping {place_name} - already processed")
                    skip_count += 1
                    continue

                print(f"\nProcessing {position}/{total}: {place_name}")
                print(f"  Population: {population:,}")

                # Let CityRasterizer handle the bounding box automatically
                rasterizer = CityRasterizer(
                    place_name=place_name,
                    country=country,
                    grid_width=grid_width,
                    grid_height=grid_height,
                    population=population
                )

                rasterizer.save_data(output_dir=self.output_dir)
                # NOTE(review): this also counts cities the rasterizer itself
                # skipped (e.g. too few buildings), because it reports that by
                # printing rather than raising.
                success_count += 1

            except Exception as e:
                print(f"Error processing {row['Name']}, {row['Country name EN']}: {e}")
                skip_count += 1
                continue

        print(f"\nFinished. Newly processed: {success_count}, Skipped: {skip_count}")

    def get_processed_cities(self):
        """Return a list of all successfully processed cities

        Every sub-directory of ``output_dir`` with a readable ``metadata.json``
        yields one dict with the keys 'place_name', 'country' and 'path'.
        Unreadable or incomplete metadata files are reported and skipped
        instead of aborting the scan.
        """
        processed = []

        for city_dir in Path(self.output_dir).glob('*'):
            metadata_path = city_dir / 'metadata.json'
            if not city_dir.is_dir() or not metadata_path.exists():
                continue

            try:
                with open(metadata_path, 'r', encoding='utf-8') as f:
                    metadata = json.load(f)
                entry = {
                    'place_name': metadata['place_name'],
                    'country': metadata['country'],
                    'path': str(city_dir)
                }
            except (OSError, ValueError, KeyError, TypeError) as e:
                print(f"Skipping unreadable metadata in {city_dir}: {e}")
                continue

            processed.append(entry)

        return processed

In [None]:

# Rasterize every city listed in cities.csv on a 3000x3000 grid.
processor = CityCSVProcessor(csv_path='cities.csv', output_dir='city_data')
processor.process_all_cities(grid_width=3000, grid_height=3000)

# Report how many city folders with metadata exist in the output directory.
processed = processor.get_processed_cities()
print(f"Successfully processed {len(processed)} cities at higher resolution")

In [None]:
import os
import json
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from pathlib import Path
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from tensorflow.keras import mixed_precision
from sklearn.manifold import TSNE

# Enable mixed precision for memory efficiency: install 'mixed_float16' as the
# global Keras dtype policy.
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)

class CityDataGenerator(keras.utils.Sequence):
    """Generator for 1024x1024 resolution images (resized from 3000x3000)

    Each batch is ``(X, targets)``: ``X`` holds the min-max normalised grids
    resized to ``image_size`` with one channel, and ``targets`` is a dict with
    the keys 'country_output' (one-hot country), 'population_output'
    (standardised population) and 'reconstruction' (``X`` itself).
    Directories that cannot be loaded or labelled are left out of the batch.
    """

    # Only grids stored at this resolution are used; others are skipped.
    SOURCE_GRID_SHAPE = (3000, 3000)

    def __init__(self, city_dirs, metadata_df, batch_size=20, image_size=(1024, 1024),
                 shuffle=True, mode='train', **kwargs):
        """
        Args:
            city_dirs (list[Path]): City folders, each holding a grid.npy
            metadata_df (pandas.DataFrame): Needs the columns 'place_name',
                'Country' and 'Population'; also used to fit the label
                encoder and the population scaler
            batch_size (int): Maximum number of cities per batch
            image_size (tuple): Target (height, width) of the resized grids
            shuffle (bool): Reorder ``city_dirs`` on construction and after
                every epoch
            mode (str): Stored on the instance; not otherwise used here
            **kwargs: Passed on to ``keras.utils.Sequence``
        """
        super().__init__(**kwargs)
        self.city_dirs = city_dirs
        self.metadata_df = metadata_df
        self.batch_size = batch_size
        self.image_size = image_size
        self.shuffle = shuffle
        self.mode = mode

        self.country_encoder = LabelEncoder()
        valid_countries = metadata_df['Country'].dropna().unique()
        self.country_encoder.fit(valid_countries)

        self.pop_scaler = StandardScaler()
        valid_populations = metadata_df[['Population']].dropna()
        self.pop_scaler.fit(valid_populations)

        self.on_epoch_end()

    def __len__(self):
        """Return the number of batches per epoch (the last one may be short)."""
        return int(np.ceil(len(self.city_dirs) / self.batch_size))

    def _candidate_place_names(self, city_dir):
        """Yield the names under which ``city_dir`` may appear in ``metadata_df``.

        The name stored in the folder's metadata.json comes first. The
        directory name with underscores turned into ', ' follows as a
        fallback; on its own it only reproduces place names whose city and
        country are single words.
        """
        try:
            with open(city_dir / "metadata.json", 'r', encoding='utf-8') as f:
                stored_name = json.load(f).get('place_name')
        except (OSError, ValueError, AttributeError):
            stored_name = None
        if stored_name:
            yield stored_name
        yield city_dir.name.replace('_', ', ')

    def _find_city_meta(self, city_dir):
        """Return the first metadata row matching ``city_dir``, or None."""
        for name in self._candidate_place_names(city_dir):
            rows = self.metadata_df[self.metadata_df['place_name'] == name]
            if not rows.empty:
                return rows.iloc[0]
        return None

    def __getitem__(self, index):
        """Build batch number ``index`` from the current order of ``city_dirs``."""
        batch_dirs = self.city_dirs[index * self.batch_size:(index + 1) * self.batch_size]
        num_classes = len(self.country_encoder.classes_)

        X_list = []
        y_country_list = []
        y_population_list = []

        for city_dir in batch_dirs:
            grid_path = city_dir / "grid.npy"
            if not grid_path.exists():
                continue

            try:
                # Resolve the labels first so the large grid is only loaded
                # for cities that can actually be used.
                city_meta = self._find_city_meta(city_dir)
                if city_meta is None:
                    continue

                country = city_meta['Country']
                if country not in self.country_encoder.classes_:
                    continue

                img = np.load(grid_path).astype('float32')
                if img.shape != self.SOURCE_GRID_SHAPE:
                    continue

                # Min-max normalise; the epsilon guards against constant grids.
                img = (img - img.min()) / (img.max() - img.min() + 1e-6)

                img_tensor = tf.convert_to_tensor(img[..., np.newaxis], dtype=tf.float32)
                resized_img = tf.image.resize(img_tensor, self.image_size, method='bilinear')

                country_encoded = self.country_encoder.transform([country])[0]
                population_scaled = self.pop_scaler.transform(
                    pd.DataFrame({'Population': [city_meta['Population']]})
                )[0]

                X_list.append(resized_img.numpy())
                y_country_list.append(country_encoded)
                y_population_list.append(population_scaled)

            except Exception as exc:
                # Keep the batch going, but say which city was dropped and why.
                print(f"Skipping {city_dir}: {exc}")
                continue

        if not X_list:
            # Nothing usable in this slice: return one all-zero placeholder
            # sample so the caller still receives a well-formed batch.
            # NOTE(review): the placeholder carries an all-zero country
            # vector and will be trained on like real data.
            X = np.zeros((1, *self.image_size, 1), dtype=np.float32)
            y_country = np.zeros((1, num_classes), dtype=np.float32)
            y_population = np.zeros((1, 1), dtype=np.float32)
            return X, {'country_output': y_country, 'population_output': y_population, 'reconstruction': X}

        X = np.array(X_list, dtype=np.float32)
        y_country = np.array(y_country_list, dtype=int)
        y_population = np.array(y_population_list, dtype=np.float32)

        y_country = keras.utils.to_categorical(y_country, num_classes=num_classes)

        return X, {'country_output': y_country, 'population_output': y_population, 'reconstruction': X}

    def on_epoch_end(self):
        """Reshuffle the city order when ``shuffle`` is enabled.

        ``city_dirs`` itself is reordered, so batches are always plain slices
        of it. ``indexes`` only records the permutation that was just applied.
        """
        self.indexes = np.arange(len(self.city_dirs))
        if self.shuffle:
            np.random.shuffle(self.indexes)
            self.city_dirs = [self.city_dirs[i] for i in self.indexes]

class ProgressiveAutoencoderModel:
    """Builder for a convolutional autoencoder with two auxiliary heads.

    The encoder halves the spatial size six times. From the encoded tensor
    the model predicts the country (softmax), the population (linear) and,
    through six transposed-convolution blocks, a reconstruction of the input.
    """

    def __init__(self, num_countries, input_shape=(1024, 1024, 1)):
        """
        Args:
            num_countries (int): Number of classes of the country head
            input_shape (tuple): Shape of one input image (height, width, channels)
        """
        self.num_countries = num_countries
        self.input_shape = input_shape

    def build_model(self):
        """Assemble and return the (uncompiled) Keras model.

        Returns:
            keras.Model: Model with the outputs 'country_output',
            'population_output' and 'reconstruction', in that order.
        """
        inputs = keras.Input(shape=self.input_shape)

        x = layers.Conv2D(16, (3, 3), strides=2, activation='relu', padding='same')(inputs)
        x = layers.BatchNormalization()(x)

        x = self._downsample_block(x, 32, strides=2)
        x = self._downsample_block(x, 64, strides=2)
        x = self._downsample_block(x, 128, strides=2)
        x = self._downsample_block(x, 256, strides=2)
        x = self._downsample_block(x, 512, strides=2)

        encoded = x

        pooled = layers.GlobalAveragePooling2D(name='global_avg_pool')(encoded)

        country_output = layers.Dense(
            self.num_countries,
            activation='softmax',
            name='country_output',
            dtype='float32'
        )(pooled)

        population_output = layers.Dense(
            1,
            name='population_output',
            dtype='float32'
        )(pooled)

        x = self._upsample_block(encoded, 256, strides=2)
        x = self._upsample_block(x, 128, strides=2)
        x = self._upsample_block(x, 64, strides=2)
        x = self._upsample_block(x, 32, strides=2)
        x = self._upsample_block(x, 16, strides=2)
        x = self._upsample_block(x, 16, strides=2)

        # Like the other two heads, the final output is computed in float32 so
        # that the loss is not evaluated on float16 values when a mixed
        # precision policy is active.
        reconstruction = layers.Conv2D(
            1, (3, 3),
            activation='sigmoid',
            padding='same',
            name='reconstruction',
            dtype='float32'
        )(x)

        model = keras.Model(inputs=inputs, outputs=[country_output, population_output, reconstruction])
        return model

    def _downsample_block(self, x, filters, strides):
        """Strided 3x3 convolution followed by batch normalisation and dropout."""
        x = layers.Conv2D(filters, (3, 3), strides=strides,
                          activation='relu', padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.Dropout(0.3)(x)
        return x

    def _upsample_block(self, x, filters, strides):
        """Strided 3x3 transposed convolution followed by batch normalisation and dropout."""
        x = layers.Conv2DTranspose(filters, (3, 3), strides=strides,
                                   activation='relu', padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.Dropout(0.3)(x)
        return x

def prepare_datasets(data_dir="city_data", test_size=0.2, val_size=0.1, batch_size=4,
                     min_population=10000, min_building_density=0.005):
    """Collect usable cities and build the train/validation/test generators.

    A city folder is used when its metadata.json is readable, has a place
    name, country, population and building density, and both the population
    and the building density exceed their thresholds.

    Args:
        data_dir (str): Root directory holding one folder per city
        test_size (float): Fraction of all cities used for testing
        val_size (float): Fraction of all cities used for validation
        batch_size (int): Batch size of all three generators
        min_population (float): Cities at or below this population are dropped
        min_building_density (float): Cities at or below this density are dropped

    Returns:
        tuple: (train_gen, val_gen, test_gen, metadata_df), where metadata_df
        has the columns 'place_name', 'Country', 'Population' and
        'Building_Density' for the selected cities.

    Raises:
        ValueError: If no city passes the filters.
    """
    data_dir = Path(data_dir)
    records = []
    lookup_names = []
    valid_dirs = []

    # Sorted so the seeded splits below do not depend on the order in which
    # the filesystem lists the folders.
    for city_dir in sorted(d for d in data_dir.iterdir() if d.is_dir()):
        meta_path = city_dir / "metadata.json"
        if not meta_path.exists():
            continue
        try:
            with open(meta_path, 'r', encoding='utf-8') as f:
                meta = json.load(f)
            record = {
                'place_name': meta['place_name'],
                'Country': meta['country'],
                'Population': meta['economic_data']['population'],
                'Building_Density': meta['grid_statistics']['building_density']
            }
            if any(pd.isna(value) for value in record.values()):
                continue
            if (record['Population'] <= min_population
                    or record['Building_Density'] <= min_building_density):
                continue
        except (OSError, ValueError, KeyError, TypeError) as exc:
            print(f"Skipping {city_dir}: unusable metadata ({exc})")
            continue

        records.append(record)
        valid_dirs.append(city_dir)
        lookup_names.append(city_dir.name.replace('_', ', '))

    if not records:
        raise ValueError(f"No usable city metadata found in {data_dir}")

    metadata_df = pd.DataFrame(records)

    # CityDataGenerator matches a folder to its row through the folder name
    # with '_' replaced by ', '. That only equals the real place name when
    # city and country are single words. Hand the generators a copy keyed on
    # that derived name so every selected folder resolves; the returned frame
    # keeps the real place names.
    generator_df = metadata_df.assign(place_name=lookup_names)

    train_dirs, test_dirs = train_test_split(valid_dirs, test_size=test_size, random_state=42)
    train_dirs, val_dirs = train_test_split(train_dirs, test_size=val_size/(1-test_size), random_state=42)

    train_gen = CityDataGenerator(
        train_dirs, generator_df,
        batch_size=batch_size,
        image_size=(1024, 1024),
        mode='train'
    )
    # Evaluation data keeps a fixed order so results can be traced back to
    # the city folders.
    val_gen = CityDataGenerator(
        val_dirs, generator_df,
        batch_size=batch_size,
        image_size=(1024, 1024),
        shuffle=False,
        mode='val'
    )
    test_gen = CityDataGenerator(
        test_dirs, generator_df,
        batch_size=batch_size,
        image_size=(1024, 1024),
        shuffle=False,
        mode='test'
    )

    return train_gen, val_gen, test_gen, metadata_df

def visualize_latent_embeddings(model, test_gen, num_samples=100):
    """Visualize latent embeddings using t-SNE

    Embeddings are taken from the model's 'global_avg_pool' layer for up to
    ``num_samples`` test samples, projected to 2-D and coloured by country.

    Args:
        model (keras.Model): Model containing a layer named 'global_avg_pool'
        test_gen (CityDataGenerator): Generator supplying the test batches
        num_samples (int): Maximum number of samples to embed
    """
    embedding_model = keras.Model(
        inputs=model.inputs,
        outputs=model.get_layer('global_avg_pool').output
    )

    embeddings = []
    country_labels = []
    city_names = []

    for batch_index in range(len(test_gen)):
        remaining = num_samples - len(embeddings)
        if remaining <= 0:
            break

        X, y = test_gen[batch_index]
        batch_embeddings = embedding_model.predict(X, verbose=0)
        batch_countries = np.argmax(y['country_output'], axis=1)
        batch_start = batch_index * test_gen.batch_size

        for j in range(min(len(X), remaining)):
            embeddings.append(batch_embeddings[j])
            country_labels.append(batch_countries[j])
            # Batches are plain slices of city_dirs (already in its current,
            # possibly shuffled, order), so index it directly.
            # NOTE(review): if the generator dropped a city from this batch,
            # the names after it are shifted by one.
            city_dir = test_gen.city_dirs[batch_start + j]
            city_names.append(city_dir.name.replace('_', ', '))

    if len(embeddings) < 2:
        # t-SNE needs a perplexity below the number of samples.
        print("Need at least two samples to compute a t-SNE embedding.")
        return

    embeddings = np.asarray(embeddings, dtype=np.float32)
    country_labels = test_gen.country_encoder.inverse_transform(country_labels)

    tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(embeddings)-1))
    embeddings_2d = tsne.fit_transform(embeddings)

    plt.figure(figsize=(10, 8))
    unique_countries = np.unique(country_labels)
    colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_countries)))
    color_map = dict(zip(unique_countries, colors))

    for country in unique_countries:
        mask = country_labels == country
        plt.scatter(
            embeddings_2d[mask, 0],
            embeddings_2d[mask, 1],
            label=country,
            c=[color_map[country]],
            alpha=0.6
        )

    plt.title("Latent Embeddings of City Grids (t-SNE)")
    plt.xlabel("t-SNE Component 1")
    plt.ylabel("t-SNE Component 2")
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.show()

    print("Example embeddings:")
    for name, country in list(zip(city_names, country_labels))[:5]:
        print(f"City: {name}, Country: {country}")

def visualize_reconstructions(model, test_gen, num_samples=5):
    """Visualize original vs reconstructed images to see learned structures

    Args:
        model (keras.Model): Model whose third output is the reconstruction
        test_gen (CityDataGenerator): Generator supplying the test batches
        num_samples (int): Maximum number of image pairs to show
    """
    originals = []
    reconstructions = []
    city_names = []

    for batch_index in range(len(test_gen)):
        remaining = num_samples - len(originals)
        if remaining <= 0:
            break

        X, y = test_gen[batch_index]
        preds = model.predict(X, verbose=0)
        batch_recons = preds[2]
        batch_start = batch_index * test_gen.batch_size

        for j in range(min(len(X), remaining)):
            originals.append(X[j].squeeze())
            reconstructions.append(batch_recons[j].squeeze())
            # Batches are plain slices of city_dirs (already in its current,
            # possibly shuffled, order), so index it directly.
            # NOTE(review): if the generator dropped a city from this batch,
            # the names after it are shifted by one.
            city_dir = test_gen.city_dirs[batch_start + j]
            city_names.append(city_dir.name.replace('_', ', '))

    n = len(originals)
    if n == 0:
        # A subplot grid with zero columns cannot be created.
        print("No samples available to visualize.")
        return

    plt.figure(figsize=(20, 4))
    for i in range(n):
        # Top row: input image.
        plt.subplot(2, n, i + 1)
        plt.imshow(originals[i], cmap='viridis')
        plt.title(f"Original: {city_names[i]}")
        plt.axis('off')

        # Bottom row: the model's reconstruction of the same image.
        plt.subplot(2, n, i + 1 + n)
        plt.imshow(reconstructions[i], cmap='viridis')
        plt.title("Reconstructed")
        plt.axis('off')
    plt.tight_layout()
    plt.show()

def load_and_resize_images(data_dir="city_data", image_size=(1024, 1024), num_images=3,
                           expected_shape=(3000, 3000)):
    """Load, normalise and resize the first ``num_images`` usable ``grid.npy`` files.

    Args:
        data_dir: Directory containing one sub-directory per city.
        image_size: Target (height, width) passed to ``tf.image.resize``.
        num_images: Maximum number of images to return.
        expected_shape: Shape a raw grid must have to be accepted; grids with
            any other shape are skipped.

    Returns:
        Tuple ``(images, city_names)``: the resized 2-D arrays and the city
        names derived from the directory names (``_`` replaced by ``, ``).
    """
    data_dir = Path(data_dir)
    # Sort so that "the first num_images" is reproducible; iterdir() order is
    # arbitrary.
    city_dirs = sorted(d for d in data_dir.iterdir() if d.is_dir())

    images = []
    city_names = []

    for city_dir in city_dirs:
        if len(images) >= num_images:
            break

        grid_path = city_dir / "grid.npy"
        if not grid_path.exists():
            continue

        try:
            img = np.load(grid_path).astype('float32')
            if img.shape != tuple(expected_shape):
                continue

            # Min-max scale to [0, 1); the epsilon guards constant grids.
            img = (img - img.min()) / (img.max() - img.min() + 1e-6)

            img_tensor = tf.convert_to_tensor(img[..., np.newaxis], dtype=tf.float32)
            resized_img = tf.image.resize(img_tensor, image_size, method='bilinear')
            resized_img = resized_img.numpy().squeeze()
        except Exception as exc:
            # Best-effort loader: keep going, but say which city was dropped
            # instead of hiding the failure.
            print(f"Skipping {city_dir.name}: could not load grid ({exc})")
            continue

        images.append(resized_img)
        city_names.append(city_dir.name.replace('_', ', '))

    return images, city_names

def plot_images(images, city_names):
    """Plot the images side by side with their city names as titles.

    The grid keeps at least three columns and grows when more than three
    images are supplied, so any number of images can be shown.
    """
    pairs = list(zip(images, city_names))
    n_cols = max(3, len(pairs))

    plt.figure(figsize=(5 * n_cols, 5))
    for position, (img, name) in enumerate(pairs, start=1):
        plt.subplot(1, n_cols, position)
        plt.imshow(img, cmap='viridis')
        plt.title(name, fontsize=12)
        plt.axis('off')
    plt.tight_layout()
    plt.show()

def train_model():
    """Build, train and evaluate the multi-output autoencoder.

    Returns:
        Tuple ``(model, history, test_gen)``: the trained model, the Keras
        history object returned by ``fit`` and the test data generator.

    Raises:
        ValueError: If the training or validation generator has no cities.
    """
    train_gen, val_gen, test_gen, metadata_df = prepare_datasets(batch_size=4)

    if len(train_gen.city_dirs) == 0 or len(val_gen.city_dirs) == 0:
        raise ValueError("No valid training or validation data found. Check city_data directory and metadata.")

    num_countries = len(metadata_df['Country'].unique())
    model_builder = ProgressiveAutoencoderModel(num_countries, input_shape=(1024, 1024, 1))
    model = model_builder.build_model()

    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=0.0001),
        loss={
            'country_output': 'categorical_crossentropy',
            'population_output': 'mse',
            'reconstruction': 'mse'
        },
        metrics={
            'country_output': 'accuracy',
            'population_output': ['mae', tf.keras.metrics.RootMeanSquaredError()],
            'reconstruction': ['mae']
        },
        loss_weights={
            'country_output': 0.5,
            'population_output': 0.05,
            'reconstruction': 1.0
        }
    )

    callbacks = [
        keras.callbacks.ModelCheckpoint(
            'best_model_1024_autoencoder.keras',
            save_best_only=True,
            monitor='val_loss',
            mode='min'
        ),
        keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=10,
            restore_best_weights=True
        ),
        keras.callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=3
        )
    ]

    history = model.fit(
        train_gen,
        validation_data=val_gen,
        epochs=50,
        callbacks=callbacks,
        verbose=1
    )

    # Look metrics up by name: with two metrics on the population head a
    # positional index is easy to get wrong (the slot after population MAE is
    # population RMSE, not reconstruction MAE).
    test_results = model.evaluate(test_gen, verbose=1, return_dict=True)
    print(
        f"Test results - Loss: {test_results.get('loss')}, "
        f"Country Accuracy: {test_results.get('country_output_accuracy')}, "
        f"Population MAE: {test_results.get('population_output_mae')}, "
        f"Reconstruction MAE: {test_results.get('reconstruction_mae')}"
    )

    return model, history, test_gen

# Main execution
if __name__ == "__main__":
    # Preview a few resized grids, then train and inspect the model.
    try:
        images, city_names = load_and_resize_images(
            data_dir="city_data",
            image_size=(1024, 1024),
            num_images=3,
        )
        if images:
            print(f"Displaying {len(images)} resized images:")
            plot_images(images, city_names)
        else:
            print("No valid images found in the city_data directory.")

        model, history, test_gen = train_model()
        visualize_latent_embeddings(model, test_gen, num_samples=100)
        visualize_reconstructions(model, test_gen, num_samples=5)
    except Exception as exc:
        # Report the failure, then let it propagate with its traceback.
        print(f"Error during execution: {exc}")
        raise

In [3]:
def plot_training_history(history):
    """Draw loss curves and, when recorded, country-accuracy curves.

    The left panel shows training and validation loss per epoch. The right
    panel is drawn only if both ``country_output_accuracy`` and
    ``val_country_output_accuracy`` are present in ``history.history``.
    """
    hist = history.history

    def _draw_panel(position, train_key, val_key, train_label, val_label, ylabel, title):
        # One panel: a training curve and a validation curve over epochs.
        plt.subplot(1, 2, position)
        plt.plot(hist[train_key], label=train_label)
        plt.plot(hist[val_key], label=val_label)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.title(title)
        plt.legend()
        plt.grid(True, alpha=0.3)

    plt.figure(figsize=(14,5))

    _draw_panel(1, 'loss', 'val_loss',
                'Training Loss', 'Validation Loss',
                'Loss', 'Training vs Validation Loss')

    # The accuracy panel belongs to the classification branch only.
    accuracy_keys = ('country_output_accuracy', 'val_country_output_accuracy')
    if all(key in hist for key in accuracy_keys):
        _draw_panel(2, accuracy_keys[0], accuracy_keys[1],
                    'Training Accuracy', 'Validation Accuracy',
                    'Accuracy', 'Country Classification Accuracy')

    plt.tight_layout()
    plt.show()

In [None]:
# Plot the training curves; `history` is the object returned by train_model()
# in the main-execution cell, so that cell has to run first.
plot_training_history(history)

In [None]:

from sklearn.cluster import KMeans

def get_all_embeddings(model, data_gen):
    """Extract latent embeddings + metadata for all cities in a generator.

    Args:
        model: Trained model with a ``global_avg_pool`` layer whose third
            prediction output is the reconstruction.
        data_gen: Data generator exposing ``batch_size``, ``indexes``,
            ``city_dirs`` and ``country_encoder``.

    Returns:
        Tuple ``(embeddings, city_names, countries, originals,
        reconstructions)`` of NumPy arrays aligned on the first axis.
    """
    embedding_model = keras.Model(
        inputs=model.inputs,
        outputs=model.get_layer('global_avg_pool').output
    )

    embeddings = []
    city_names = []
    countries = []
    originals = []
    reconstructions = []

    batch_size = data_gen.batch_size
    for i in range(len(data_gen)):
        X, y = data_gen[i]
        emb = embedding_model.predict(X, verbose=0)
        recons = model.predict(X, verbose=0)[2]

        embeddings.append(emb)
        # Decode with the encoder of the generator actually passed in, not a
        # global one that may belong to a different dataset.
        countries.extend(
            data_gen.country_encoder.inverse_transform(np.argmax(y['country_output'], axis=1))
        )
        # Go through `indexes` so names stay aligned with the samples even
        # when the generator reorders its cities.
        start = i * batch_size
        batch_indexes = data_gen.indexes[start:start + len(X)]
        city_names.extend(
            data_gen.city_dirs[idx].name.replace('_', ', ') for idx in batch_indexes
        )
        originals.extend(X)
        reconstructions.extend(recons)

    embeddings = np.vstack(embeddings)
    originals = np.array(originals)
    reconstructions = np.array(reconstructions)

    return embeddings, np.array(city_names), np.array(countries), originals, reconstructions

def run_kmeans_on_embeddings(embeddings, countries, n_clusters=2):
    """Cluster the embeddings with K-means and print a cluster/country table.

    Returns the cluster label assigned to each embedding.
    """
    cluster_labels = KMeans(
        n_clusters=n_clusters,
        random_state=42,
        n_init=10,
    ).fit_predict(embeddings)

    assignments = pd.DataFrame({"Country": countries, "Cluster": cluster_labels})
    # Rows are clusters, columns are countries, cells are city counts.
    counts = assignments.groupby(["Cluster", "Country"]).size().unstack(fill_value=0)
    print(counts)
    return cluster_labels

def visualize_clusters(originals, reconstructions, city_names, cluster_labels, n_samples=5):
    """For every cluster, show up to ``n_samples`` originals above their reconstructions."""
    for cluster in np.unique(cluster_labels):
        member_ids = np.where(cluster_labels == cluster)[0][:n_samples]

        print(f"\nCluster {cluster} examples:")
        plt.figure(figsize=(15, 4))
        for column, idx in enumerate(member_ids, start=1):
            # Top row: original image titled with the city name.
            plt.subplot(2, n_samples, column)
            plt.imshow(originals[idx].squeeze(), cmap='viridis')
            plt.title(city_names[idx], fontsize=10)
            plt.axis("off")

            # Bottom row: reconstruction of the same city.
            plt.subplot(2, n_samples, column + n_samples)
            plt.imshow(reconstructions[idx].squeeze(), cmap='viridis')
            plt.title("Reconstructed", fontsize=10)
            plt.axis("off")
        plt.suptitle(f"Cluster {cluster}", fontsize=14)
        plt.tight_layout()
        plt.show()


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans, SpectralClustering
from sklearn.decomposition import PCA
from tensorflow import keras

def prepare_datasets_all(data_dir="city_data", batch_size=4):
    """Prepare a single dataset generator for all valid cities in city_data.

    A city is kept when its ``metadata.json`` can be read and provides a
    non-null population and building density.

    Args:
        data_dir: Directory containing one sub-directory per city.
        batch_size: Batch size handed to ``CityDataGenerator``.

    Returns:
        Tuple ``(all_gen, metadata_df)``: the unshuffled generator over the
        kept cities and the metadata table it was built from.
    """
    data_dir = Path(data_dir)
    city_dirs = [d for d in data_dir.iterdir() if d.is_dir()]
    columns = ['place_name', 'Country', 'Population', 'Building_Density']
    metadata = []

    for city_dir in city_dirs:
        try:
            meta_path = city_dir / "metadata.json"
            if not meta_path.exists():
                continue
            with open(meta_path, 'r') as f:
                meta = json.load(f)
            population = meta['economic_data']['population']
            building_density = meta['grid_statistics']['building_density']
            if population is None or building_density is None:
                continue
            metadata.append({
                'place_name': meta['place_name'],
                'Country': meta['country'],
                'Population': population,
                'Building_Density': building_density
            })
        except (OSError, ValueError, KeyError, TypeError) as exc:
            # Unreadable or malformed metadata: skip the city, but say so.
            print(f"Skipping {city_dir.name}: unusable metadata ({exc})")
            continue

    # Naming the columns keeps the frame well-formed even when no city
    # qualified; otherwise dropna(subset=...) raises KeyError on an empty frame.
    metadata_df = pd.DataFrame(metadata, columns=columns)
    metadata_df = metadata_df.dropna(subset=['Country', 'Population', 'Building_Density'])

    valid_cities = set(metadata_df['place_name'])
    city_dirs = [d for d in city_dirs if d.name.replace('_', ', ') in valid_cities]

    all_gen = CityDataGenerator(
        city_dirs, metadata_df,
        batch_size=batch_size,
        image_size=(1024, 1024),
        mode='all',
        shuffle=False
    )

    return all_gen, metadata_df

def get_all_embeddings(model, all_gen, num_samples=None):
    """
    Collect embeddings, city names, countries, inputs and reconstructions.

    Batches are processed in generator order. When ``num_samples`` is given,
    collection stops after that many cities; otherwise the cap is the number
    of city directories held by the generator.
    """
    embedding_model = keras.Model(
        inputs=model.inputs,
        outputs=model.get_layer('global_avg_pool').output
    )

    limit = len(all_gen.city_dirs) if num_samples is None else num_samples

    embeddings, city_names, countries = [], [], []
    originals, reconstructions = [], []

    for batch_no in range(len(all_gen)):
        X, y = all_gen[batch_no]
        batch_embeddings = embedding_model.predict(X, verbose=0)
        batch_recons = model.predict(X, verbose=0)[2]
        batch_countries = np.argmax(y['country_output'], axis=1)
        offset = batch_no * all_gen.batch_size

        for j in range(len(X)):
            if len(city_names) >= limit:
                break
            embeddings.append(batch_embeddings[j])
            originals.append(X[j].squeeze())
            reconstructions.append(batch_recons[j].squeeze())
            # Translate the in-batch position into the city directory.
            city_dir = all_gen.city_dirs[all_gen.indexes[offset + j]]
            city_names.append(city_dir.name.replace('_', ', '))
            decoded = all_gen.country_encoder.inverse_transform([batch_countries[j]])
            countries.append(decoded[0])

        if len(city_names) >= limit:
            break

    return (
        np.array(embeddings),
        np.array(city_names),
        np.array(countries),
        np.array(originals),
        np.array(reconstructions),
    )


In [203]:
# Build one unshuffled generator over every city with usable metadata.
all_gen, metadata_df = prepare_datasets_all(batch_size=4)

# Run the model over the whole generator; num_samples=None means no explicit
# cap. `model` must already exist in the session.
embeddings, city_names, countries, originals, reconstructions = get_all_embeddings(
    model, all_gen, num_samples=None
)


In [206]:
# Number of cities for which an embedding was extracted.
len(embeddings)

4638

In [None]:
import pandas as pd

df = pd.read_csv("oecd_cities.csv")

# Keep only the GDP_PC_REAL_PPP rows and the three columns of interest,
# renamed to short labels.
gdp_df = df.loc[
    df["VAR"] == "GDP_PC_REAL_PPP",
    ["Metropolitan areas", "TIME_PERIOD", "OBS_VALUE"],
].rename(
    columns={
        "Metropolitan areas": "City",
        "TIME_PERIOD": "Year",
        "OBS_VALUE": "GDP_per_capita",
    }
)

gdp_df

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LinearRegression, RANSACRegressor
from sklearn.manifold import TSNE
from scipy.stats import pearsonr, spearmanr, f_oneway, kruskal
import seaborn as sns
from pathlib import Path
import json
import requests
from fuzzywuzzy import fuzz, process
import pandas as pd
from statsmodels.stats.multicomp import MultiComparison
from statsmodels.multivariate.manova import MANOVA
from scipy import stats
from scipy import ndimage
from collections import Counter

def load_oecd_gdp_data(csv_path="oecd_cities.csv"):
    """Load OECD GDP data and return as dictionary for fast lookup.

    Only ``GDP_PC_REAL_PPP`` rows are used. For each city the most recent
    year that actually has a value is kept, so a missing latest observation
    does not hide usable earlier data.

    Args:
        csv_path: Path to the OECD CSV export.

    Returns:
        Dict mapping the stripped, lower-cased city name to its GDP value;
        empty if the file does not exist or holds no usable rows.
    """
    try:
        df = pd.read_csv(csv_path)
    except FileNotFoundError:
        print(f"Error: {csv_path} not found")
        return {}

    gdp_df = df.loc[
        df["VAR"] == "GDP_PC_REAL_PPP",
        ["Metropolitan areas", "TIME_PERIOD", "OBS_VALUE"],
    ].rename(columns={
        "Metropolitan areas": "City",
        "TIME_PERIOD": "Year",
        "OBS_VALUE": "GDP_per_capita"
    })
    # Drop incomplete rows first: otherwise idxmax can select a year whose
    # value is NaN (or fail on a group whose years are all missing).
    gdp_df = gdp_df.dropna(subset=["City", "Year", "GDP_per_capita"])
    if gdp_df.empty:
        return {}

    latest = gdp_df.loc[gdp_df.groupby('City')['Year'].idxmax()]
    return dict(zip(latest['City'].str.strip().str.lower(), latest['GDP_per_capita']))

def get_population_from_metadata(city_name, data_dir="city_data"):
    """Get population for a single city from metadata.

    Args:
        city_name: City name; ``', '`` is replaced by ``'_'`` to form the
            directory name.
        data_dir: Directory containing the per-city sub-directories.

    Returns:
        The stored population, or ``np.nan`` when the metadata file is
        missing, unreadable, malformed, or stores a null population.
    """
    meta_path = Path(data_dir) / city_name.replace(', ', '_') / "metadata.json"
    try:
        with open(meta_path, 'r') as f:
            meta = json.load(f)
        population = meta['economic_data']['population']
    except (OSError, ValueError, KeyError, TypeError):
        # Missing file, invalid JSON or unexpected structure. A bare
        # ``except`` would also swallow KeyboardInterrupt/SystemExit.
        return np.nan
    return population if population is not None else np.nan

def get_building_density(city_name, data_dir="city_data"):
    """Get building_density for a single city from metadata.

    Args:
        city_name: City name; ``', '`` is replaced by ``'_'`` to form the
            directory name.
        data_dir: Directory containing the per-city sub-directories.

    Returns:
        The stored building density, or ``np.nan`` when the metadata file is
        missing, unreadable, malformed, or stores a null density.
    """
    meta_path = Path(data_dir) / city_name.replace(', ', '_') / "metadata.json"
    try:
        with open(meta_path, 'r') as f:
            meta = json.load(f)
        density = meta['grid_statistics']['building_density']
    except (OSError, ValueError, KeyError, TypeError):
        # Missing file, invalid JSON or unexpected structure. A bare
        # ``except`` would also swallow KeyboardInterrupt/SystemExit.
        return np.nan
    return density if density is not None else np.nan

def get_gdp_for_cities(city_names, countries, oecd_gdp_dict):
    """Get GDP data for cities using OECD data and country fallbacks.

    Lookup order per city: exact OECD match on the city part of the name,
    fuzzy OECD match (token-sort ratio >= 80), country-level World Bank
    value, then NaN.

    Args:
        city_names: Iterable of names such as ``"Munich, Germany"``.
        countries: Iterable of country names aligned with ``city_names``.
        oecd_gdp_dict: Mapping of lower-cased city name to GDP value.

    Returns:
        ``float64`` array with one GDP value (or NaN) per city.
    """
    gdp_values = []
    print("Fetching country-level GDP data...")
    try:
        url = "https://api.worldbank.org/v2/country/all/indicator/NY.GDP.PCAP.PP.KD?format=json&date=2023&per_page=300"
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        data = response.json()[1]
        country_gdp = {entry['country']['value'].lower(): entry['value']
                       for entry in data if entry['value'] is not None}
    except Exception:
        # Network boundary: any failure to fetch or parse only disables the
        # country-level fallback. Unlike a bare ``except``, this lets
        # KeyboardInterrupt/SystemExit through.
        print("Warning: Could not fetch country GDP data")
        country_gdp = {}

    for city, country in zip(city_names, countries):
        # Keep only the part before the first comma. The previous
        # ``replace(',.*', '')`` was a literal replacement, not a regex, so
        # the country suffix was never removed and exact matches never hit.
        city_clean = city.split(',', 1)[0].strip().lower()
        if city_clean in oecd_gdp_dict:
            gdp_values.append(oecd_gdp_dict[city_clean])
            continue

        matches = [(k, fuzz.token_sort_ratio(city_clean, k)) for k in oecd_gdp_dict]
        best_key, best_score = max(matches, key=lambda m: m[1]) if matches else (None, 0)
        if best_score >= 80:
            gdp_values.append(oecd_gdp_dict[best_key])
            print(f"Fuzzy matched {city} to {best_key} (score: {best_score})")
            continue

        country_key = country.lower() if isinstance(country, str) else None
        if country_key in country_gdp:
            gdp_values.append(country_gdp[country_key])
            print(f"Using country GDP for {city} ({country})")
        else:
            gdp_values.append(np.nan)
            print(f"No GDP data found for {city} ({country})")
    return np.array(gdp_values, dtype=np.float64)

def validate_inputs(embeddings, city_names, countries, originals, reconstructions, n_clusters):
    """Check the clustering inputs and return the (possibly upcast) embeddings.

    Raises ``NameError`` if any array argument is ``None`` and ``ValueError``
    for mismatched lengths, an out-of-range ``n_clusters`` or non-finite
    embeddings. Embeddings holding magnitudes above 1e15 are returned as
    float64; otherwise the input object is returned unchanged.
    """
    aligned = (embeddings, city_names, countries, originals, reconstructions)
    if any(item is None for item in aligned):
        raise NameError("One or more input variables are not defined")

    n_samples = len(embeddings)
    if any(len(item) != n_samples for item in aligned[1:]):
        raise ValueError("Input arrays must have the same length")
    if n_clusters > n_samples or n_clusters < 2:
        raise ValueError(f"n_clusters must be between 2 and {n_samples}")

    if not np.all(np.isfinite(embeddings)):
        raise ValueError("Embeddings contain NaN or infinite values")

    has_huge_values = np.any(np.abs(embeddings) > 1e15)
    if has_huge_values:
        print("Warning: Embeddings contain very large values, converting to float64")
        return embeddings.astype(np.float64)
    return embeddings

def _print_finite_range(label, values):
    """Print the min/max over the finite entries of ``values`` (NaN if none)."""
    values = np.asarray(values, dtype=np.float64)
    finite = values[np.isfinite(values)]
    low, high = (finite.min(), finite.max()) if finite.size else (np.nan, np.nan)
    print(f"{label} range: min={low:.2e}, max={high:.2e}")


def cluster_cities_memory_efficient(embeddings, city_names, countries, originals, reconstructions, 
                                  n_clusters=7, data_dir="city_data", min_building_density=0.001):
    """Cluster cities with K-means after filtering out unusable records.

    A city is kept when its population, GDP and building density are finite,
    its building density is at least ``min_building_density`` and its urban
    metrics (from ``extract_urban_metrics``) are finite. Clustering runs on
    the standardised embeddings plus population, image density and average
    component size.

    Args:
        embeddings: 2-D array of embeddings, one row per city.
        city_names, countries, originals, reconstructions: Sequences aligned
            with ``embeddings``.
        n_clusters: Number of K-means clusters.
        data_dir: Directory holding the per-city metadata.
        min_building_density: Minimum metadata building density for a city
            to be kept.

    Returns:
        Tuple ``(original_indices, cluster_labels, final_city_names,
        final_countries, final_populations, final_gdp, final_originals,
        final_reconstructions, features_scaled, kmeans, final_embeddings)``.

    Raises:
        NameError, ValueError: From ``validate_inputs``.
        ValueError: If fewer cities than ``n_clusters`` survive filtering or
            the final embeddings are not finite.
    """
    embeddings = validate_inputs(embeddings, city_names, countries, originals, reconstructions, n_clusters)
    print(f"Starting analysis for {len(city_names)} cities...")

    # Load data
    populations = np.array([get_population_from_metadata(city, data_dir) for city in city_names], dtype=np.float64)
    building_densities = np.array([get_building_density(city, data_dir) for city in city_names], dtype=np.float64)
    oecd_gdp_dict = load_oecd_gdp_data()
    gdp_values = get_gdp_for_cities(city_names, countries, oecd_gdp_dict)

    # Flag suspicious values. All three arrays are already float64, so no
    # conversion is required here.
    for name, arr in [("populations", populations), ("building_densities", building_densities), ("gdp_values", gdp_values)]:
        if np.any(~np.isfinite(arr)):
            print(f"Warning: {name} contains NaN or infinite values")
        if np.any(np.abs(arr) > 1e15):
            print(f"Warning: {name} contains very large values")

    # Diagnostic ranges over finite entries only; a plain min/max would just
    # print NaN as soon as one city lacks data.
    _print_finite_range("Embeddings", embeddings)
    _print_finite_range("Populations", populations)
    _print_finite_range("Building density", building_densities)
    _print_finite_range("GDP values", gdp_values)

    # Keep cities with finite data and sufficient building density.
    finite_mask = np.isfinite(populations) & np.isfinite(gdp_values) & np.isfinite(building_densities)
    valid_mask = finite_mask & (building_densities >= min_building_density)
    valid_indices = np.where(valid_mask)[0]

    print(f"Cities with valid population, GDP, and building_density >= {min_building_density}: "
          f"{np.sum(valid_mask)} out of {len(city_names)}")

    # Print rejected cities for debugging
    invalid_indices = np.where(~valid_mask)[0]
    if invalid_indices.size:
        print(f"Cities with invalid data or building_density < {min_building_density}:")
        for i in invalid_indices:
            print(f"  {city_names[i]}: pop={populations[i]}, gdp={gdp_values[i]}, density={building_densities[i]}")

    valid_embeddings = embeddings[valid_mask]
    valid_city_names = [city_names[i] for i in valid_indices]
    valid_countries = [countries[i] for i in valid_indices]
    valid_populations = populations[valid_mask]
    valid_gdp = gdp_values[valid_mask]
    valid_originals = [originals[i] for i in valid_indices]
    valid_reconstructions = [reconstructions[i] for i in valid_indices]

    print(f"\nFinal dataset: {len(valid_city_names)} cities")

    densities, avg_sizes, num_components = extract_urban_metrics(valid_originals)

    urban_valid_mask = np.isfinite(densities) & np.isfinite(avg_sizes)
    urban_indices = np.where(urban_valid_mask)[0]
    print(f"Cities with valid urban metrics: {np.sum(urban_valid_mask)} out of {len(valid_city_names)}")

    final_embeddings = valid_embeddings[urban_valid_mask].astype(np.float64)
    final_city_names = [valid_city_names[i] for i in urban_indices]
    final_countries = [valid_countries[i] for i in urban_indices]
    final_populations = valid_populations[urban_valid_mask].astype(np.float64)
    final_gdp = valid_gdp[urban_valid_mask].astype(np.float64)
    final_originals = [valid_originals[i] for i in urban_indices]
    final_reconstructions = [valid_reconstructions[i] for i in urban_indices]
    densities = densities[urban_valid_mask].astype(np.float64)
    avg_sizes = avg_sizes[urban_valid_mask].astype(np.float64)
    num_components = num_components[urban_valid_mask].astype(np.float64)

    if np.any(~np.isfinite(final_embeddings)):
        raise ValueError("Final embeddings contain NaN or infinite values")
    if len(final_city_names) < n_clusters:
        # Fail with a clear message instead of an opaque K-means error.
        raise ValueError(
            f"Only {len(final_city_names)} cities remain after filtering; "
            f"cannot form {n_clusters} clusters"
        )

    features_to_cluster = np.column_stack([final_embeddings, final_populations.reshape(-1, 1), 
                                          densities.reshape(-1, 1), avg_sizes.reshape(-1, 1)])
    scaler = StandardScaler()
    try:
        features_scaled = scaler.fit_transform(features_to_cluster.astype(np.float64))
    except ValueError as e:
        print(f"Error in StandardScaler: {e}")
        print(f"Features range: min={np.min(features_to_cluster, axis=0)}, max={np.max(features_to_cluster, axis=0)}")
        raise

    print(f"Performing K-means clustering with {n_clusters} clusters...")
    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
    cluster_labels = kmeans.fit_predict(features_scaled)

    silhouette_avg = silhouette_score(features_scaled, cluster_labels)
    print(f"Silhouette Score: {silhouette_avg:.3f}")

    # Positions of the surviving cities in the arrays passed by the caller.
    original_indices = valid_indices[urban_valid_mask]

    return (original_indices, cluster_labels, final_city_names, final_countries, 
            final_populations, final_gdp, final_originals, final_reconstructions, 
            features_scaled, kmeans, final_embeddings)

def analyze_correlations(embeddings, populations, gdp_values):
    """Compute Pearson, Spearman, and partial correlations with GDP.

    The partial correlation of each embedding dimension with GDP controls
    for population (first-order partial correlation):
    ``(r_xy - r_xz * r_yz) / sqrt((1 - r_xz**2) * (1 - r_yz**2))`` with
    x = embedding dimension, y = GDP, z = population.

    Args:
        embeddings: 2-D array, one row per city.
        populations: 1-D array of populations.
        gdp_values: 1-D array of GDP values.

    Returns:
        Dict with keys ``'pearson'``, ``'spearman'`` and ``'partial'``; each
        maps a name (``'population_gdp'`` or ``'embedding_<i>'``) to a
        ``(coefficient, p_value)`` tuple. Partial entries carry ``None`` as
        p-value.

    Raises:
        ValueError: If populations or GDP values contain NaN/inf.
    """
    correlations = {'pearson': {}, 'spearman': {}, 'partial': {}}

    if np.any(~np.isfinite(populations)) or np.any(~np.isfinite(gdp_values)):
        raise ValueError("Populations or GDP values contain NaN or infinite values")

    pearson_pop_gdp, p_pop_gdp = pearsonr(populations, gdp_values)
    spearman_pop_gdp, sp_pop_gdp = spearmanr(populations, gdp_values)
    correlations['pearson']['population_gdp'] = (pearson_pop_gdp, p_pop_gdp)
    correlations['spearman']['population_gdp'] = (spearman_pop_gdp, sp_pop_gdp)

    n_dims = embeddings.shape[1]
    for i in range(n_dims):
        column = embeddings[:, i]
        pearson_corr, p_value = pearsonr(column, gdp_values)
        spearman_corr, sp_value = spearmanr(column, gdp_values)
        correlations['pearson'][f'embedding_{i}'] = (pearson_corr, p_value)
        correlations['spearman'][f'embedding_{i}'] = (spearman_corr, sp_value)

        # Remove the share of the embedding/GDP correlation explained by
        # population; the earlier code stored the plain Pearson value here.
        r_emb_pop, _ = pearsonr(column, populations)
        denom = np.sqrt((1.0 - r_emb_pop ** 2) * (1.0 - pearson_pop_gdp ** 2))
        if np.isfinite(denom) and denom > 0:
            partial_corr = (pearson_corr - r_emb_pop * pearson_pop_gdp) / denom
        else:
            # Undefined when population is perfectly collinear or constant.
            partial_corr = np.nan
        correlations['partial'][f'embedding_{i}'] = (partial_corr, None)

    print("\nCorrelation Analysis:")
    print(f"Population vs GDP: Pearson r={pearson_pop_gdp:.3f}, p={p_pop_gdp:.3f}")
    print(f"Population vs GDP: Spearman r={spearman_pop_gdp:.3f}, p={sp_pop_gdp:.3f}")
    print("Significant embedding correlations (p < 0.05):")
    for i in range(n_dims):
        r_value, p_val = correlations['pearson'][f'embedding_{i}']
        if p_val < 0.05:
            print(f"Embedding dim {i}: Pearson r={r_value:.3f}, p={p_val:.3f}")
    return correlations

def incremental_regression_analysis(embeddings, populations, gdp_values):
    """Compare in-sample R² of GDP regressions on population, embeddings and both.

    Returns ``(reg_pop, reg_emb, reg_full, r2_pop, r2_emb, r2_full)`` and
    prints the three scores together with their differences. Raises
    ``ValueError`` when any input holds NaN or infinite values.
    """
    inputs_finite = all(
        np.all(np.isfinite(arr)) for arr in (embeddings, populations, gdp_values)
    )
    if not inputs_finite:
        raise ValueError("Inputs to incremental regression contain NaN or infinite values")

    try:
        embeddings_scaled = StandardScaler().fit_transform(embeddings)
        populations_scaled = StandardScaler().fit_transform(populations.reshape(-1, 1))
    except ValueError as e:
        print(f"Error in scaling for incremental regression: {e}")
        raise

    def _fit_and_score(features):
        # Ordinary least squares fit, scored on the same data.
        fitted = LinearRegression().fit(features, gdp_values)
        return fitted, fitted.score(features, gdp_values)

    reg_pop, r2_pop = _fit_and_score(populations_scaled)
    reg_emb, r2_emb = _fit_and_score(embeddings_scaled)
    reg_full, r2_full = _fit_and_score(np.column_stack([embeddings_scaled, populations_scaled]))

    gain_over_pop = r2_full - r2_pop
    gain_over_emb = r2_full - r2_emb

    print("\nIncremental Regression Analysis:")
    print(f"R² (Population only): {r2_pop:.3f}")
    print(f"R² (Embeddings only): {r2_emb:.3f}")
    print(f"R² (Full model): {r2_full:.3f}")
    print(f"Incremental R² from embeddings (beyond population): {gain_over_pop:.3f}")
    print(f"Incremental R² from population (beyond embeddings): {gain_over_emb:.3f}")

    return reg_pop, reg_emb, reg_full, r2_pop, r2_emb, r2_full

def plot_incremental_r2(r2_pop, r2_emb, r2_full):
    """Bar chart of the three R² scores, annotated with the full model's gains."""
    scores = {
        'Population Only': r2_pop,
        'Embeddings Only': r2_emb,
        'Full Model': r2_full,
    }

    plt.figure(figsize=(10, 6))
    bars = plt.bar(list(scores.keys()), list(scores.values()))
    plt.ylabel('R² Score')
    plt.title('Incremental Predictive Power for GDP Prediction')
    plt.ylim(0, 1)

    # Write each score just above its bar.
    for bar in bars:
        top = bar.get_height()
        x_center = bar.get_x() + bar.get_width()/2.
        plt.text(x_center, top + 0.01, f'{top:.3f}', ha='center', va='bottom')

    # Gains of the full model, placed inside the third bar (x position 2).
    gain_over_pop = r2_full - r2_pop
    gain_over_emb = r2_full - r2_emb
    plt.text(2, r2_full - 0.05, f'Δ from Pop: +{gain_over_pop:.3f}', ha='center', va='top', color='blue')
    plt.text(2, r2_full - 0.10, f'Δ from Emb: +{gain_over_emb:.3f}', ha='center', va='top', color='green')

    plt.tight_layout()
    plt.show()

def extract_urban_metrics(originals):
    """Compute per-image density, mean component size and component count.

    Each image is thresholded at 0.5. Density is the mean of the binary mask;
    connected components come from ``ndimage.label``. Images containing NaN
    or infinite values yield NaN for all three metrics.

    Returns three float64 arrays: ``(densities, avg_sizes, num_components)``.
    """
    per_image = []

    for position, image in enumerate(originals):
        if not np.all(np.isfinite(image)):
            print(f"Warning: Image {position} contains NaN or infinite values, skipping")
            per_image.append((np.nan, np.nan, np.nan))
            continue

        footprint = (image > 0.5).astype(np.float64)
        component_map, n_found = ndimage.label(footprint)
        # Pixel count of every labelled component.
        component_sizes = ndimage.sum(footprint, component_map, range(1, n_found + 1))
        mean_size = 0 if n_found <= 0 else np.mean(component_sizes)

        per_image.append((np.mean(footprint), mean_size, n_found))

    densities = [row[0] for row in per_image]
    avg_sizes = [row[1] for row in per_image]
    num_components = [row[2] for row in per_image]

    return (
        np.array(densities, dtype=np.float64),
        np.array(avg_sizes, dtype=np.float64),
        np.array(num_components, dtype=np.float64),
    )

def analyze_urban_metrics(densities, avg_sizes, num_components, gdp_values, populations):
    """Print Pearson correlations of each urban metric with GDP and population.

    Pairs where either value is non-finite are ignored; a correlation is
    reported only when more than one valid pair remains. Returns the dict of
    metric arrays keyed by metric name.
    """
    metrics = {
        'density': densities,
        'avg_building_size': avg_sizes,
        'num_building_clusters': num_components
    }
    targets = (("GDP", gdp_values), ("Population", populations))

    print("\nUrban Metrics Analysis:")
    for name, values in metrics.items():
        heading = name.capitalize()
        for target_label, target in targets:
            usable = np.isfinite(values) & np.isfinite(target)
            if np.sum(usable) <= 1:
                print(f"{heading} vs {target_label}: Insufficient valid data")
                continue
            r_value, p_value = pearsonr(values[usable], target[usable])
            print(f"{heading} vs {target_label}: Pearson r={r_value:.3f}, p={p_value:.3f}")

    return metrics

def plot_urban_metrics_vs_gdp(densities, avg_sizes, num_components, gdp_values):
    """Draw three side-by-side scatter plots of urban metrics against GDP.

    Points where either the metric or the GDP value is non-finite are left
    out of the respective panel.
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    panels = (
        (densities, 'Building Density', 'Building Density vs GDP'),
        (avg_sizes, 'Average Building Size', 'Avg Building Size vs GDP'),
        (num_components, 'Number of Building Clusters', 'Building Clusters vs GDP'),
    )

    for position, (metric, x_label, title) in enumerate(panels):
        axis = axes[position]
        usable = np.isfinite(metric) & np.isfinite(gdp_values)
        axis.scatter(metric[usable], gdp_values[usable], alpha=0.5)
        axis.set_xlabel(x_label)
        if position == 0:
            # Only the leftmost panel carries the y-axis label.
            axis.set_ylabel('GDP per capita ($)')
        axis.set_title(title)

    plt.tight_layout()
    plt.show()

def dimensionality_reduction_visualization(embeddings, gdp_values, populations, cluster_labels):
    """Project the embeddings to 2-D with t-SNE and draw two scatter plots.

    The first plot colours points by GDP, the second by cluster label; both
    size points by population and omit cities with non-finite GDP or
    population. Returns the 2-D coordinates, or ``None`` if t-SNE raises
    ``ValueError``. Raises ``ValueError`` for non-finite embeddings.
    """
    if not np.all(np.isfinite(embeddings)):
        raise ValueError("Embeddings contain NaN or infinite values for t-SNE")

    try:
        perplexity = min(30, len(embeddings)-1)
        embeddings_2d = TSNE(n_components=2, random_state=42, perplexity=perplexity).fit_transform(embeddings)
    except ValueError as e:
        print(f"Error in t-SNE: {e}")
        return None

    usable = np.isfinite(gdp_values) & np.isfinite(populations)
    xs = embeddings_2d[usable, 0]
    ys = embeddings_2d[usable, 1]
    marker_sizes = populations[usable]/1e5

    def _finish(title):
        # Axis decoration shared by both figures.
        plt.title(title)
        plt.xlabel('t-SNE Dimension 1')
        plt.ylabel('t-SNE Dimension 2')
        plt.grid(True, alpha=0.3)
        plt.show()

    plt.figure(figsize=(12, 10))
    scatter = plt.scatter(xs, ys, c=gdp_values[usable], cmap='viridis', s=marker_sizes, alpha=0.7)
    plt.colorbar(scatter, label='GDP per capita ($)')
    _finish('t-SNE Visualization of Embeddings (Colored by GDP, Sized by Population)')

    plt.figure(figsize=(12, 10))
    plt.scatter(xs, ys, c=cluster_labels[usable], cmap='tab10', s=marker_sizes, alpha=0.7)
    _finish('t-SNE Visualization of Embeddings (Colored by Clusters, Sized by Population)')

    return embeddings_2d

def outlier_robustness_analysis(embeddings, populations, gdp_values, city_names):
    """Compare a plain linear fit with a RANSAC fit and report the outliers.

    Features are the embeddings with population appended as a last column,
    standardised with ``StandardScaler``. A ``LinearRegression`` and a
    ``RANSACRegressor`` are fitted against ``gdp_values``; samples RANSAC
    leaves out of its inlier set are reported as outliers, the five with the
    largest absolute linear-model residual are printed, and a residual plot
    is shown.

    Args:
        embeddings: Array of embedding vectors, one row per city.
        populations: 1-D array of populations aligned with ``embeddings``.
        gdp_values: Array of regression targets aligned with ``embeddings``.
        city_names: Sequence of names indexed by sample position.

    Returns:
        Tuple ``(linear_model, ransac_model, linear_residuals, inlier_mask)``.

    Raises:
        ValueError: If any numeric input contains NaN or infinite values, or
            if scaling fails (re-raised after printing the error).
    """
    for values in (embeddings, populations, gdp_values):
        if not np.all(np.isfinite(values)):
            raise ValueError("Inputs to outlier analysis contain NaN or infinite values")

    features = np.column_stack([embeddings, populations.reshape(-1, 1)])
    scaler = StandardScaler()
    try:
        X_scaled = scaler.fit_transform(features)
    except ValueError as e:
        print(f"Error in scaling for outlier analysis: {e}")
        raise

    # Ordinary least squares baseline.
    linear_model = LinearRegression().fit(X_scaled, gdp_values)
    linear_residuals = gdp_values - linear_model.predict(X_scaled)
    linear_r2 = linear_model.score(X_scaled, gdp_values)

    # Robust fit; its inlier mask defines which samples count as outliers.
    ransac_model = RANSACRegressor(random_state=42).fit(X_scaled, gdp_values)
    inlier_mask = ransac_model.inlier_mask_
    ransac_r2 = ransac_model.score(X_scaled, gdp_values)

    is_outlier = ~inlier_mask
    outlier_positions = np.where(is_outlier)[0]
    n_outliers = len(outlier_positions)

    print("\nOutlier Analysis:")
    print(f"Number of outliers detected: {n_outliers}")
    print(f"R² (Standard Linear): {linear_r2:.3f}")
    print(f"R² (Robust RANSAC): {ransac_r2:.3f}")
    if n_outliers > 0:
        print("Top outliers by residual magnitude:")
        # Rank outliers by descending absolute residual and keep the first five.
        by_magnitude = np.argsort(np.abs(linear_residuals[is_outlier]))[::-1]
        for idx in outlier_positions[by_magnitude][:5]:
            print(f"  {city_names[idx]}: Residual = {linear_residuals[idx]:.2f}, GDP = {gdp_values[idx]:.2f}")

    plt.figure(figsize=(12, 8))
    plt.scatter(gdp_values, linear_residuals, c='blue', alpha=0.5, label='All Points')
    plt.scatter(
        gdp_values[is_outlier],
        linear_residuals[is_outlier],
        c='red',
        alpha=0.7,
        label='Outliers',
    )
    plt.axhline(0, color='black', linestyle='--', alpha=0.5)
    plt.xlabel('Actual GDP per capita ($)')
    plt.ylabel('Residuals')
    plt.title('Regression Residuals (Red: Outliers)')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    return linear_model, ransac_model, linear_residuals, inlier_mask

def analyze_clusters_memory_efficient(cluster_labels, city_names, countries, populations, gdp_values, 
                                     embeddings, originals, reconstructions, n_clusters=7):
    """Run statistical tests across clusters and summarise every cluster.

    Prints ANOVA, Kruskal-Wallis, Tukey HSD and MANOVA results, the
    per-cluster Pearson correlation between population and GDP, and a text
    summary of each cluster (size, GDP/population statistics, country
    counts and the five highest-GDP cities).

    Args:
        cluster_labels: Array of cluster ids; compared element-wise against
            each id in ``range(n_clusters)``.
        city_names: Sequence of city names aligned with ``cluster_labels``.
        countries: Sequence of country names aligned with ``cluster_labels``.
        populations: Array of populations aligned with ``cluster_labels``.
        gdp_values: Array of GDP values aligned with ``cluster_labels``.
        embeddings: 2-D array of embeddings (one MANOVA column per
            ``embeddings.shape[1]``).
        originals: Not used by this function; kept for interface compatibility.
        reconstructions: Not used by this function; kept for interface
            compatibility.
        n_clusters: Number of cluster ids to iterate over.

    Returns:
        Tuple ``(cluster_stats, correlations, reg_full, cluster_correlations,
        reg_pop, reg_emb, r2_pop, r2_emb, r2_full, gdp_by_cluster)`` where
        ``cluster_stats`` and ``cluster_correlations`` are dicts keyed by
        cluster id.
    """
    print("\n" + "="*60)
    print("CLUSTER ANALYSIS")
    print("="*60)

    correlations = analyze_correlations(embeddings, populations, gdp_values)
    (reg_pop, reg_emb, reg_full,
     r2_pop, r2_emb, r2_full) = incremental_regression_analysis(embeddings, populations, gdp_values)

    gdp_by_cluster = [gdp_values[cluster_labels == i] for i in range(n_clusters)]
    # Both omnibus tests only receive clusters that actually have members.
    non_empty_groups = [group for group in gdp_by_cluster if len(group) > 0]

    print("\nStatistical Tests for GDP Differences:")
    # The tests are best-effort: a failure is reported and the analysis goes on.
    # `except Exception` (not a bare `except`) lets Ctrl-C / SystemExit through.
    try:
        f_stat, p_value = f_oneway(*non_empty_groups)
        print(f"ANOVA test for GDP differences: F={f_stat:.3f}, p={p_value:.3f}")
    except Exception as e:
        print(f"ANOVA test failed: {e}")

    try:
        h_stat, p_value = kruskal(*non_empty_groups)
        print(f"Kruskal-Wallis test: H={h_stat:.3f}, p={p_value:.3f}")
    except Exception as e:
        print(f"Kruskal-Wallis test failed: {e}")

    print("\nPairwise GDP Comparisons:")
    df = pd.DataFrame({'gdp': gdp_values, 'cluster': cluster_labels})
    mc = MultiComparison(df['gdp'], df['cluster'])
    tukey_result = mc.tukeyhsd()
    print(tukey_result)

    print("\nMANOVA for Embeddings Across Clusters:")
    try:
        emb_columns = [f'emb_{i}' for i in range(embeddings.shape[1])]
        df_manova = pd.DataFrame(embeddings, columns=emb_columns)
        df_manova['cluster'] = cluster_labels
        manova = MANOVA.from_formula(' + '.join(emb_columns) + ' ~ cluster', data=df_manova)
        print(manova.mv_test())
    except Exception as e:
        print(f"MANOVA failed: {e}")

    print("\nCluster-wise Pearson Correlations (GDP vs Population):")
    cluster_correlations = {}
    for cluster_id in range(n_clusters):
        cluster_mask = cluster_labels == cluster_id
        if np.sum(cluster_mask) > 1:
            corr, p_value = pearsonr(populations[cluster_mask], gdp_values[cluster_mask])
            cluster_correlations[cluster_id] = (corr, p_value)
            print(f"Cluster {cluster_id}: r={corr:.3f}, p={p_value:.3f}")
        else:
            print(f"Cluster {cluster_id}: Not enough data for correlation")
            cluster_correlations[cluster_id] = (np.nan, np.nan)

    # Convert once so boolean-mask indexing works for plain lists as well.
    city_name_array = np.array(city_names)
    country_array = np.array(countries)

    cluster_stats = {}
    for cluster_id in range(n_clusters):
        cluster_mask = cluster_labels == cluster_id
        member_indices = np.where(cluster_mask)[0]
        cluster_cities = city_name_array[cluster_mask]
        cluster_countries = country_array[cluster_mask]
        cluster_pop = populations[cluster_mask]
        cluster_gdp = gdp_values[cluster_mask]

        has_gdp = len(cluster_gdp) > 0
        has_pop = len(cluster_pop) > 0
        gdp_mean = np.mean(cluster_gdp) if has_gdp else np.nan
        gdp_std = np.std(cluster_gdp) if has_gdp else np.nan

        unique_countries, counts = np.unique(cluster_countries, return_counts=True)

        # Positions (within the cluster) of the five highest-GDP members.
        top_positions = np.argsort(cluster_gdp)[::-1][:5]

        # Named `summary` rather than `stats` to avoid hiding the module-level
        # `stats` name used elsewhere in this file.
        summary = {
            'count': len(cluster_cities),
            'gdp_mean': gdp_mean,
            'gdp_std': gdp_std,
            # Coefficient of variation; undefined for empty clusters or zero mean.
            'gdp_cv': gdp_std / gdp_mean if has_gdp and gdp_mean != 0 else np.nan,
            'pop_mean': np.mean(cluster_pop) if has_pop else np.nan,
            'pop_std': np.std(cluster_pop) if has_pop else np.nan,
            'top_countries': dict(zip(unique_countries, counts)),
            'sample_cities': [
                {
                    'name': cluster_cities[pos],
                    'country': cluster_countries[pos],
                    'gdp': cluster_gdp[pos],
                    'pop': cluster_pop[pos],
                    # Index into the full (unfiltered-by-cluster) arrays.
                    'index': member_indices[pos],
                }
                for pos in top_positions
            ],
            'all_cities': list(zip(cluster_cities, cluster_countries, cluster_gdp, cluster_pop)),
        }

        cluster_stats[cluster_id] = summary
        print(f"\nCluster {cluster_id}:")
        print(f"  Cities: {summary['count']}")
        print(f"  GDP per capita: ${summary['gdp_mean']:,.0f} ± ${summary['gdp_std']:,.0f} (CV: {summary['gdp_cv']:.3f})")
        print(f"  Population: {summary['pop_mean']:,.0f} ± {summary['pop_std']:,.0f}")
        top_countries = sorted(summary['top_countries'].items(), key=lambda x: x[1], reverse=True)[:3]
        country_str = ', '.join([f'{country}({count})' for country, count in top_countries])
        print(f"  Top countries: {country_str}")
        print(f"  Sample cities (top 5 by GDP):")
        for city_info in summary['sample_cities']:
            print(f"    • {city_info['name']} ({city_info['country']}): ${city_info['gdp']:,.0f}")

    return cluster_stats, correlations, reg_full, cluster_correlations, reg_pop, reg_emb, r2_pop, r2_emb, r2_full, gdp_by_cluster

def plot_correlation_heatmap(populations, gdp_values):
    """Show an annotated heatmap of the population/GDP correlation matrix.

    Rows where either value is NaN or infinite are dropped before the
    correlation is computed with ``DataFrame.corr()``.

    Args:
        populations: Array of populations.
        gdp_values: Array of GDP values aligned with ``populations``.
    """
    finite_rows = np.isfinite(populations) & np.isfinite(gdp_values)
    columns = {
        'Population': populations[finite_rows],
        'GDP': gdp_values[finite_rows],
    }
    correlation = pd.DataFrame(columns).corr()

    plt.figure(figsize=(8, 6))
    sns.heatmap(
        correlation,
        annot=True,
        cmap='coolwarm',
        center=0,
        fmt='.3f',
        square=True,
        cbar_kws={"shrink": .8},
    )
    plt.title('Correlation Heatmap: Population vs GDP', fontsize=14)
    plt.tight_layout()
    plt.show()

def plot_gdp_distribution_by_cluster(cluster_labels, gdp_values, n_clusters=7):
    """Draw one GDP boxplot per cluster, with the cluster mean overlaid.

    Boxes are filled from the ``Set3`` colormap and each finite cluster mean
    is marked with a red diamond. Tick labels are the cluster ids, rotated
    and small so that many clusters remain legible.

    Args:
        cluster_labels: Array of cluster ids aligned with ``gdp_values``.
        gdp_values: Array of GDP values.
        n_clusters: Number of cluster ids (``0 .. n_clusters - 1``) to plot.
    """
    plt.figure(figsize=(20, 8))
    groups = [gdp_values[cluster_labels == cid] for cid in range(n_clusters)]

    drawn = plt.boxplot(groups, patch_artist=True)

    palette = plt.cm.Set3(np.linspace(0, 1, n_clusters))
    for box, fill in zip(drawn['boxes'], palette):
        box.set_facecolor(fill)

    # Boxplot positions are 1-based, hence start=1.
    group_means = [np.mean(group) for group in groups]
    for position, group_mean in enumerate(group_means, start=1):
        if not np.isfinite(group_mean):
            continue
        plt.scatter(position, group_mean, color='red', zorder=3, s=30, marker='D')

    plt.xlabel('Cluster', fontsize=12)
    plt.ylabel('GDP per capita ($)', fontsize=12)
    plt.title(f'GDP Distribution by Cluster (n_clusters={n_clusters})', fontsize=14)
    tick_labels = [f'{cid}' for cid in range(n_clusters)]
    plt.xticks(range(1, n_clusters + 1), tick_labels, rotation=90, fontsize=6)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

def plot_population_vs_gdp_by_cluster(cluster_labels, city_names, populations, gdp_values, n_clusters=7):
    """Scatter population against GDP per cluster, with a fitted line each.

    Every non-empty cluster is drawn with very low opacity in its own
    ``tab10`` colour. When a cluster has more than one finite point, a
    ``stats.linregress`` line is added and its r value shown in the legend.
    The x axis is logarithmic.

    Args:
        cluster_labels: Array of cluster ids aligned with the value arrays.
        city_names: Not used by this function.
        populations: Array of populations.
        gdp_values: Array of GDP values.
        n_clusters: Number of cluster ids (``0 .. n_clusters - 1``) to plot.
    """
    plt.figure(figsize=(14, 10))
    palette = plt.cm.tab10(np.linspace(0, 1, n_clusters))

    for cluster_id in range(n_clusters):
        in_cluster = cluster_labels == cluster_id
        if not np.sum(in_cluster) > 0:
            continue

        usable = in_cluster & np.isfinite(populations) & np.isfinite(gdp_values)
        colour = palette[cluster_id]
        plt.scatter(
            populations[usable],
            gdp_values[usable],
            c=[colour],
            label=f'Cluster {cluster_id}',
            alpha=0.05,
            s=100,
        )

        # A regression line needs at least two points.
        if not np.sum(usable) > 1:
            continue

        slope, intercept, r_value, _p_value, _std_err = stats.linregress(
            populations[usable], gdp_values[usable])
        x_range = np.linspace(
            np.min(populations[usable]),
            np.max(populations[usable]),
            100,
        )
        plt.plot(
            x_range,
            slope*x_range + intercept,
            color=colour,
            linestyle='--',
            alpha=0.7,
            label=f'Cluster {cluster_id} (r={r_value:.2f})',
        )

    plt.xscale('log')
    plt.xlabel('Population (log scale)')
    plt.ylabel('GDP per capita ($)')
    plt.title('Population vs GDP by Cluster with Regression Lines', fontsize=16)
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

def plot_cluster_composition(cluster_stats, n_clusters=7):
    """Draw one bar chart per cluster showing how many cities each country has.

    Bars are ordered by descending count and annotated with the count.

    Args:
        cluster_stats: Mapping from cluster id to a dict providing
            ``'top_countries'`` (country -> count) and ``'count'``.
        n_clusters: Number of cluster ids (``0 .. n_clusters - 1``) to plot,
            one subplot row each.
    """
    fig, axes = plt.subplots(n_clusters, 1, figsize=(12, 4 * n_clusters))
    # With a single row, subplots returns a bare Axes rather than an array.
    if n_clusters == 1:
        axes = [axes]

    for cluster_id in range(n_clusters):
        ax = axes[cluster_id]
        info = cluster_stats[cluster_id]

        ranked = sorted(info['top_countries'].items(), key=lambda item: item[1], reverse=True)
        labels = [name for name, _ in ranked]
        heights = [n for _, n in ranked]
        positions = range(len(labels))

        ax.bar(positions, heights)
        ax.set_title(f'Cluster {cluster_id} - Country Distribution ({info["count"]} cities)')
        ax.set_ylabel('Count')
        ax.set_xticks(positions)
        ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=8)

        # Write each count just above its bar.
        for x_pos, height in enumerate(heights):
            ax.text(x_pos, height + 0.1, str(height), ha='center', va='bottom')

    plt.tight_layout()
    plt.show()

def _sanitize_unit_image(image, kind, city_name):
    """Return ``image`` unchanged if finite, else a copy forced into [0, 1]."""
    if np.all(np.isfinite(image)):
        return image
    print(f"Warning: {kind} image for {city_name} contains invalid values")
    # np.clip alone propagates NaN, so replace non-finite values first:
    # NaN -> 0, +inf -> 1, -inf -> 0 (the infinities match what clip would give).
    cleaned = np.nan_to_num(image, nan=0.0, posinf=1.0, neginf=0.0)
    return np.clip(cleaned, 0, 1)


def plot_cluster_images_memory_efficient(cluster_labels, city_names, countries, gdp_values, 
                                       populations, originals, reconstructions, cluster_id, n_cities=10, select_mode='top_gdp'):
    """Show original (top row) and reconstructed (bottom row) images of a cluster.

    Args:
        cluster_labels: Array of cluster ids aligned with the other inputs.
        city_names: Sequence of city names indexed by sample position.
        countries: Sequence of country names indexed by sample position.
        gdp_values: Array of GDP values.
        populations: Array of populations.
        originals: Indexable collection of original images.
        reconstructions: Indexable collection of reconstructed images.
        cluster_id: Id of the cluster to display.
        n_cities: Maximum number of cities (``'top_gdp'``) or countries
            (``'one_per_country'``) to show.
        select_mode: ``'top_gdp'`` picks the highest-GDP cities of the
            cluster; ``'one_per_country'`` picks the highest-GDP city of each
            of the most frequent countries in the cluster.

    Returns:
        None. Prints a message and returns early when the cluster is empty or
        nothing was selected.

    Raises:
        ValueError: If ``select_mode`` is not one of the supported modes.
    """
    cluster_mask = cluster_labels == cluster_id
    cluster_indices = np.where(cluster_mask)[0]
    if len(cluster_indices) == 0:
        print(f"No cities in cluster {cluster_id}")
        return

    cluster_gdp = gdp_values[cluster_mask]

    if select_mode == 'top_gdp':
        by_gdp_desc = np.argsort(cluster_gdp)[::-1][:n_cities]
        selected_indices = cluster_indices[by_gdp_desc]
        suptitle_text = f'Cluster {cluster_id} - Top {n_cities} Cities by GDP'
    elif select_mode == 'one_per_country':
        # Converted once; the original rebuilt this array for every country.
        country_array = np.array(countries)
        country_counts = Counter(country_array[cluster_mask])
        top_countries = sorted(country_counts, key=country_counts.get, reverse=True)[:min(n_cities, len(country_counts))]
        selected_indices = []
        for country in top_countries:
            country_mask = cluster_mask & (country_array == country)
            country_indices = np.where(country_mask)[0]
            # Representative city = highest GDP within this country and cluster.
            selected_indices.append(country_indices[np.argmax(gdp_values[country_mask])])
        suptitle_text = f'Cluster {cluster_id} - One City per Top {len(top_countries)} Countries (by frequency)'
    else:
        raise ValueError(f"Unknown select_mode: {select_mode}")

    n_images = len(selected_indices)
    if n_images == 0:
        print(f"No images selected for cluster {cluster_id}")
        return

    # squeeze=False always yields a 2-D axes array, so a single image needs
    # no special-casing.
    fig, axes = plt.subplots(2, n_images, figsize=(3 * n_images, 6), squeeze=False)

    for col, city_idx in enumerate(selected_indices):
        city_name = city_names[city_idx]
        gdp = gdp_values[city_idx]
        population = populations[city_idx]

        original_array = _sanitize_unit_image(originals[city_idx], "Original", city_name)
        axes[0, col].imshow(original_array, cmap='binary', vmin=0, vmax=1)
        axes[0, col].set_title(f"{city_name}\nGDP: ${gdp:,.0f}", fontsize=10)
        axes[0, col].axis('off')

        reconstruction_array = _sanitize_unit_image(
            reconstructions[city_idx], "Reconstruction", city_name)
        im = axes[1, col].imshow(reconstruction_array, cmap='viridis', vmin=0, vmax=1)
        axes[1, col].set_title(f"Pop: {population:,.0f}", fontsize=10)
        axes[1, col].axis('off')

    # Shared colorbar for the reconstruction row, placed right of the grid.
    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
    fig.colorbar(im, cax=cbar_ax)

    finite_gdp = cluster_gdp[np.isfinite(cluster_gdp)]
    cluster_mean_gdp = np.mean(finite_gdp)
    cluster_std_gdp = np.std(finite_gdp)
    plt.suptitle(f'{suptitle_text}\nMean GDP: ${cluster_mean_gdp:,.0f} ± ${cluster_std_gdp:,.0f}', 
                 fontsize=16, y=0.95)
    plt.tight_layout(rect=[0, 0, 0.9, 0.95])
    plt.show()

def plot_regression_predictions(embeddings, populations, gdp_values, city_names, cluster_labels, 
                              reg_pop, reg_emb, reg_full, r2_pop, r2_emb, r2_full, n_clusters=7):
    """Plot actual vs predicted GDP for three fitted models, then a t-SNE map.

    Rows with a non-finite population, GDP or embedding entry are dropped.
    Embeddings and populations are each standardised with a scaler fitted
    here on the filtered data, then fed to the supplied models
    (NOTE(review): presumably the models were trained on identically scaled
    features — confirm against where they are fitted). A final figure shows
    a t-SNE projection of the raw embeddings coloured by the full model's
    predictions and sized by ``populations / 1e5``.

    Args:
        embeddings: 2-D array of embeddings.
        populations: 1-D array of populations.
        gdp_values: Array of actual GDP values.
        city_names: Sequence of city names indexed by sample position.
        cluster_labels: Array of cluster ids used to colour the points.
        reg_pop: Fitted model predicting from scaled population only.
        reg_emb: Fitted model predicting from scaled embeddings only.
        reg_full: Fitted model predicting from scaled embeddings plus
            scaled population (population as last column).
        r2_pop: R² shown in the population-only panel title.
        r2_emb: R² shown in the embeddings-only panel title.
        r2_full: R² shown in the full-model panel title.
        n_clusters: Not used by this function.
    """
    keep = np.isfinite(populations) & np.isfinite(gdp_values) & np.isfinite(embeddings).all(axis=1)
    kept_positions = np.where(keep)[0]
    embeddings = embeddings[keep]
    populations = populations[keep]
    gdp_values = gdp_values[keep]
    city_names = [city_names[i] for i in kept_positions]
    cluster_labels = cluster_labels[keep]

    # Each feature group is standardised independently.
    embeddings_scaled = StandardScaler().fit_transform(embeddings)
    populations_scaled = StandardScaler().fit_transform(populations.reshape(-1, 1))
    X_full = np.column_stack([embeddings_scaled, populations_scaled])

    pred_pop = reg_pop.predict(populations_scaled)
    pred_emb = reg_emb.predict(embeddings_scaled)
    pred_full = reg_full.predict(X_full)

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    panels = (
        (pred_pop, 'Population-only Model', r2_pop),
        (pred_emb, 'Embeddings-only Model', r2_emb),
        (pred_full, 'Full Model (Pop + Emb)', r2_full),
    )
    for ax, (predicted, model_name, r2) in zip(axes, panels):
        ax.scatter(gdp_values, predicted, alpha=0.5, c=cluster_labels, cmap='tab10', s=100)
        # Diagonal y = x: where a perfect model would place every point.
        diagonal = [gdp_values.min(), gdp_values.max()]
        ax.plot(diagonal, diagonal, 'k--', alpha=0.7, label='Perfect Prediction')
        ax.set_xlabel('Actual GDP per capita ($)')
        ax.set_ylabel('Predicted GDP per capita ($)')
        ax.set_title(f'{model_name}\nR² = {r2:.3f}')
        ax.grid(True, alpha=0.3)
        ax.legend()

    plt.suptitle('Actual vs Predicted GDP per Capita by Model', fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

    try:
        projector = TSNE(n_components=2, random_state=42, perplexity=min(30, len(embeddings)-1))
        embeddings_2d = projector.fit_transform(embeddings)

        plt.figure(figsize=(12, 10))
        points = plt.scatter(
            embeddings_2d[:, 0],
            embeddings_2d[:, 1],
            c=pred_full,
            cmap='viridis',
            s=populations/1e5,
            alpha=0.7,
        )
        plt.colorbar(points, label='Predicted GDP per capita ($)')
        plt.title('t-SNE of Embeddings (Colored by Full Model Predictions, Sized by Population)')
        plt.xlabel('t-SNE Dimension 1')
        plt.ylabel('t-SNE Dimension 2')
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()
    except ValueError as e:
        print(f"Error in t-SNE visualization: {e}")

def _show_variance_cluster_report(cluster_id, variance, extreme, box_color, cluster_stats,
                                  cluster_labels, city_names, countries, gdp_values,
                                  populations, originals, reconstructions):
    """Print and plot the country mix, GDP boxplot and sample images of one cluster.

    ``extreme`` is the word used in the boxplot title ('Lowest' or 'Highest')
    and ``box_color`` is the boxplot fill colour.
    """
    info = cluster_stats[cluster_id]
    print(f"Cluster {cluster_id}: variance={variance:,.2f}, count={info['count']}, "
          f"mean_gdp=${info['gdp_mean']:,.2f}")
    country_counts = sorted(info['top_countries'].items(), key=lambda x: x[1], reverse=True)
    print("Country counts:", country_counts)

    # Bar chart of cities per country, most frequent first.
    fig, ax = plt.subplots(figsize=(8, 6))
    country_names = [c[0] for c in country_counts]
    counts = [c[1] for c in country_counts]
    ax.bar(range(len(country_names)), counts)
    ax.set_title(f"Cluster {cluster_id} - Country Distribution ({info['count']} cities)")
    ax.set_ylabel('Number of Cities')
    ax.set_xticks(range(len(country_names)))
    ax.set_xticklabels(country_names, rotation=45, ha='right', fontsize=8)
    for i, count in enumerate(counts):
        ax.text(i, count + 0.1, str(count), ha='center', va='bottom')
    plt.tight_layout()
    plt.show()

    # Boxplot of the cluster's finite GDP values ('all_cities' tuples hold GDP at index 2).
    gdps = np.array([city[2] for city in info['all_cities'] if np.isfinite(city[2])])
    plt.figure(figsize=(8, 6))
    plt.boxplot(gdps, patch_artist=True, boxprops=dict(facecolor=box_color))
    plt.title(f"GDP Boxplot for Cluster {cluster_id} ({extreme} Variance)")
    plt.ylabel('GDP per capita ($)')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    plot_cluster_images_memory_efficient(
        cluster_labels, city_names, countries, gdp_values,
        populations, originals, reconstructions,
        cluster_id, n_cities=10, select_mode='one_per_country'
    )


def run_memory_efficient_clustering_analysis(embeddings, city_names, countries, originals, reconstructions, 
                                            n_clusters=7, data_dir="city_data"):
    """Run the full clustering pipeline: validate, cluster, analyse and plot.

    Steps, in order: input validation, clustering, statistical cluster
    analysis, urban-metric analysis, t-SNE visualisation, outlier analysis,
    regression prediction plots, summary plots, per-cluster image grids and
    a detailed report on the clusters with the lowest and highest GDP
    variance (only when at least 10 clusters have 30 or more cities).

    Args:
        embeddings: Embedding vectors, one per city.
        city_names: Sequence of city names.
        countries: Sequence of country names.
        originals: Collection of original images.
        reconstructions: Collection of reconstructed images.
        n_clusters: Number of clusters to form.
        data_dir: Directory passed to ``cluster_cities_memory_efficient``.

    Returns:
        Dict with keys ``'cluster_labels'``, ``'valid_indices'``,
        ``'cluster_stats'``, ``'silhouette_score'``, ``'valid_city_names'``,
        ``'valid_countries'``, ``'valid_populations'``, ``'valid_gdp'``,
        ``'correlations'``, ``'cluster_correlations'``,
        ``'regression_model'`` and ``'kmeans_model'``.
    """
    print(f"Starting memory-efficient clustering analysis with {n_clusters} clusters...")
    print(f"Input data: {len(city_names)} cities")

    embeddings = validate_inputs(embeddings, city_names, countries, originals, reconstructions, n_clusters)

    (valid_indices, cluster_labels, valid_city_names, valid_countries,
     valid_populations, valid_gdp, valid_originals, valid_reconstructions,
     features_scaled, kmeans, valid_embeddings) = cluster_cities_memory_efficient(
        embeddings, city_names, countries, originals, reconstructions, n_clusters, data_dir
    )

    (cluster_stats, correlations, reg_full, cluster_correlations, reg_pop, reg_emb,
     r2_pop, r2_emb, r2_full, _gdp_by_cluster) = analyze_clusters_memory_efficient(
        cluster_labels, valid_city_names, valid_countries, valid_populations,
        valid_gdp, valid_embeddings, valid_originals, valid_reconstructions, n_clusters
    )

    print("\n=== Additional Analyses ===")

    plot_incremental_r2(r2_pop, r2_emb, r2_full)

    densities, avg_sizes, num_components = extract_urban_metrics(valid_originals)
    analyze_urban_metrics(densities, avg_sizes, num_components, valid_gdp, valid_populations)
    plot_urban_metrics_vs_gdp(densities, avg_sizes, num_components, valid_gdp)

    dimensionality_reduction_visualization(
        valid_embeddings, valid_gdp, valid_populations, cluster_labels
    )

    outlier_robustness_analysis(
        valid_embeddings, valid_populations, valid_gdp, valid_city_names
    )

    plot_regression_predictions(
        valid_embeddings, valid_populations, valid_gdp, valid_city_names, cluster_labels,
        reg_pop, reg_emb, reg_full, r2_pop, r2_emb, r2_full, n_clusters
    )

    print("\nCreating visualizations...")
    plot_correlation_heatmap(valid_populations, valid_gdp)
    plot_gdp_distribution_by_cluster(cluster_labels, valid_gdp, n_clusters)
    plot_population_vs_gdp_by_cluster(cluster_labels, valid_city_names, valid_populations, valid_gdp, n_clusters)
    plot_cluster_composition(cluster_stats, n_clusters)

    print("\nDisplaying cluster images...")
    for cluster_id in range(n_clusters):
        cluster_size = np.sum(cluster_labels == cluster_id)
        if cluster_size > 0:
            print(f"\nCluster {cluster_id}: {cluster_size} cities")
            plot_cluster_images_memory_efficient(
                cluster_labels, valid_city_names, valid_countries, valid_gdp,
                valid_populations, valid_originals, valid_reconstructions,
                cluster_id, n_cities=min(10, cluster_size), select_mode='top_gdp'
            )

    print("\n=== Clusters with Highest and Lowest GDP Variance (min 30 cities) ===")
    large_clusters = [cid for cid in range(n_clusters) if cluster_stats[cid]['count'] >= 30]
    if len(large_clusters) < 10:
        print(f"Not enough large clusters (only {len(large_clusters)} found)")
    else:
        # (cluster id, GDP variance) pairs, ascending by variance.
        variances = sorted(
            ((cid, cluster_stats[cid]['gdp_std']**2)
             for cid in large_clusters
             if np.isfinite(cluster_stats[cid]['gdp_std'])),
            key=lambda pair: pair[1],
        )
        lowest_5 = variances[:5]
        highest_5 = variances[-5:][::-1]

        # Arguments shared by every per-cluster report.
        report_inputs = (
            cluster_stats, cluster_labels, valid_city_names, valid_countries,
            valid_gdp, valid_populations, valid_originals, valid_reconstructions,
        )

        print("\nLowest 5 variances:")
        for cid, variance in lowest_5:
            _show_variance_cluster_report(cid, variance, 'Lowest', 'lightblue', *report_inputs)

        print("\nHighest 5 variances:")
        for cid, variance in highest_5:
            _show_variance_cluster_report(cid, variance, 'Highest', 'lightcoral', *report_inputs)

    return {
        'cluster_labels': cluster_labels,
        'valid_indices': valid_indices,
        'cluster_stats': cluster_stats,
        'silhouette_score': silhouette_score(features_scaled, cluster_labels),
        'valid_city_names': valid_city_names,
        'valid_countries': valid_countries,
        'valid_populations': valid_populations,
        'valid_gdp': valid_gdp,
        'correlations': correlations,
        'cluster_correlations': cluster_correlations,
        'regression_model': reg_full,
        'kmeans_model': kmeans
    }

# Driver: run the full analysis on the variables expected to exist in the
# notebook namespace (embeddings, city_names, countries, originals,
# reconstructions). A NameError is reported as a hint that an input variable
# is not defined; a ValueError as a hint that some input holds invalid values.
try:
    results = run_memory_efficient_clustering_analysis(
        embeddings,
        city_names,
        countries,
        originals,
        reconstructions,
        n_clusters=100,
        data_dir="city_data",
    )
except ValueError as e:
    print(f"ValueError: {e}. Check for invalid values in embeddings, populations, GDP, or building density data.")
except NameError as e:
    print(f"Error: {e}. Please ensure all input variables are defined.")