In [5]:
!pip install mercantile
import os
import re
import numpy as np
from PIL import Image
import timm
import torch
import time
from pathlib import Path
from typing import List, Tuple, Optional
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import os
import re
import numpy as np
import duckdb
import mercantile
from pathlib import Path
import time
from typing import List, Tuple, Optional

images_dir = "images/"
batch_size = 512
num_workers = 4
device = torch.device('cuda')
output_prefix = "resnet34_java"

def parse_tile_from_filename(filename: str) -> Optional[Tuple[int, int, int]]:
    """
    Parse tile coordinates from filename like 'tile_14_13292_8548.png'
    Returns (zoom, x, y) or None if parsing fails
    """
    pattern = r'tile_(\d+)_(\d+)_(\d+)\.'
    match = re.search(pattern, filename)

    if match:
        zoom, x, y = map(int, match.groups())
        return zoom, x, y
    else:
        print(f"Could not parse tile coordinates from: {filename}")
        return None


class TileImageDataset(Dataset):
    """
    Custom Dataset for loading tile images
    """
    def __init__(self, image_paths, transforms):
        self.image_paths = image_paths
        self.transforms = transforms
        self.valid_indices = []
        self.tile_coords = []

        # Pre-filter valid images and parse coordinates
        for i, image_path in enumerate(image_paths):
            coords = parse_tile_from_filename(image_path.name)
            if coords is not None:
                self.valid_indices.append(i)
                self.tile_coords.append(coords)

        print(f"Dataset: {len(self.valid_indices)}/{len(image_paths)} images have valid tile coordinates")

    def __len__(self):
        return len(self.valid_indices)

    def __getitem__(self, idx):
        # Get the actual image index
        actual_idx = self.valid_indices[idx]
        image_path = self.image_paths[actual_idx]
        coords = self.tile_coords[idx]

        try:
            # Load and transform image
            img = Image.open(image_path).convert('RGB')
            img_tensor = self.transforms(img)

            return {
                'image': img_tensor,
                'path': str(image_path),
                'coords': coords,
                'valid': True
            }
        except Exception as e:
            # Return dummy data for failed images
            print(f"Failed to load {image_path}: {e}")
            return {
                'image': torch.zeros(3, 224, 224),  # Dummy tensor
                'path': str(image_path),
                'coords': coords,
                'valid': False
            }

Collecting mercantile
  Downloading mercantile-1.2.1-py3-none-any.whl.metadata (4.8 kB)
Downloading mercantile-1.2.1-py3-none-any.whl (14 kB)
Installing collected packages: mercantile
Successfully installed mercantile-1.2.1


In [2]:
print("Loading ResNet34 model...")

model = timm.create_model(
    'resnet34.a3_in1k',
    pretrained=True,
    num_classes=0,  # remove classifier nn.Linear
)
model = model.eval()

# Get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

print(f"Model loaded. Feature dimension: {model.num_features}")

Loading ResNet34 model...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/87.3M [00:00<?, ?B/s]

Model loaded. Feature dimension: 512


In [4]:
"""
Extract embeddings from tile images using PyTorch DataLoader
"""


model.to(device)

print(f"Starting embedding extraction from {images_dir}")
print(f"Batch size: {batch_size}, Workers: {num_workers}")

# Find all image files
image_files = []
for ext in ['*.png', '*.jpg', '*.jpeg']:
    image_files.extend(Path(images_dir).glob(ext))

if not image_files:
    print(f"No image files found in {images_dir}")


print(f"Found {len(image_files)} image files")

# Sort files for consistent ordering
image_files = sorted(image_files)


# Create dataset and dataloader
dataset = TileImageDataset(image_files, transforms)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False,  # Keep order consistent
    pin_memory=True if device.type == 'cuda' else False
)

# Extract embeddings
embeddings = []
image_paths = []
tile_coords = []
failed_files = []

print(f"Processing {len(dataset)} valid images in {len(dataloader)} batches...")
start_time = time.time()

with torch.no_grad():
    for batch_idx, batch in enumerate(dataloader):
        # Move batch to device
        images = batch['image'].to(device)
        paths = batch['path']
        coords = batch['coords']
        valid = batch['valid']

        # Extract embeddings for the batch
        batch_embeddings = model(images)
        batch_embeddings = batch_embeddings.cpu().numpy()

        # Process each item in the batch
        for i in range(len(paths)):
            if valid[i]:  # Only keep valid images
                embeddings.append(batch_embeddings[i])
                image_paths.append(paths[i])
                # Convert coords tuple to regular tuple (DataLoader may return tensors)
                coord_tuple = (int(coords[0][i]), int(coords[1][i]), int(coords[2][i]))
                tile_coords.append(coord_tuple)
            else:
                failed_files.append(paths[i])

        # Progress update
        if (batch_idx + 1) % 10 == 0:
            processed = min((batch_idx + 1) * batch_size, len(dataset))
            elapsed = time.time() - start_time
            rate = processed / elapsed
            eta = (len(dataset) - processed) / rate if rate > 0 else 0
            print(f"Processed {processed}/{len(dataset)} images "
                  f"({rate:.1f} img/s, ETA: {eta:.1f}s)")

total_time = time.time() - start_time
print(f"Extraction completed in {total_time:.1f} seconds")
print(f"Average speed: {len(embeddings) / total_time:.1f} images/second")


Starting embedding extraction from images/
Batch size: 512, Workers: 4
Found 20976 image files
Dataset: 20976/20976 images have valid tile coordinates
Processing 20976 valid images in 41 batches...
Processed 5120/20976 images (1419.7 img/s, ETA: 11.2s)
Processed 10240/20976 images (1710.1 img/s, ETA: 6.3s)
Processed 15360/20976 images (1662.3 img/s, ETA: 3.4s)
Processed 20480/20976 images (1770.8 img/s, ETA: 0.3s)
Extraction completed in 12.4 seconds
Average speed: 1685.8 images/second


In [None]:
def prepare_tile_data(embeddings: np.ndarray, image_paths: List[str]) -> List[dict]:
    """Prepare tile data with embeddings and geometries"""
    print("Preparing tile data...")

    tile_data = []
    failed_parses = 0

    for i, (embedding, image_path) in enumerate(zip(embeddings, image_paths)):
        filename = os.path.basename(image_path)

        # Parse tile coordinates
        tile_coords = parse_tile_from_filename(filename)
        if tile_coords is None:
            failed_parses += 1
            continue

        zoom, x, y = tile_coords

        # Get tile bounds using mercantile
        tile_bounds = mercantile.bounds(x, y, zoom)

        # Calculate tile center
        center_lon = (tile_bounds.west + tile_bounds.east) / 2
        center_lat = (tile_bounds.south + tile_bounds.north) / 2

        # Create WKT point geometry for tile center
        geometry_wkt = f"POINT({center_lon} {center_lat})"

        # Create tile ID
        tile_id = f"{zoom}_{x}_{y}"

        tile_data.append({
            'id': tile_id,
            'zoom': zoom,
            'x': x,
            'y': y,
            'embedding': embedding.tolist(),
            'geometry_wkt': geometry_wkt,
            'image_path': image_path,
            'west': tile_bounds.west,
            'south': tile_bounds.south,
            'east': tile_bounds.east,
            'north': tile_bounds.north,
            'center_lon': center_lon,
            'center_lat': center_lat
        })

        if (i + 1) % 100 == 0:
            print(f"Processed {i + 1}/{len(image_paths)} tiles...")

    if failed_parses > 0:
        print(f"Failed to parse {failed_parses} tiles")

    print(f"Prepared {len(tile_data)} tile records")
    return tile_data

def create_duckdb_database(tile_data: List[dict], output_file: str, embedding_dim: int) -> None:
    """Create DuckDB database with spatial and vector search capabilities"""
    print(f"Creating DuckDB database: {output_file}")

    con = duckdb.connect(database=output_file)

    try:
        # Load extensions
        print("Loading extensions...")
        con.execute("INSTALL spatial; LOAD spatial;")
        con.execute("INSTALL vss; LOAD vss;")
        con.execute("SET hnsw_enable_experimental_persistence = true;")

        # Create table
        print("Creating table...")
        create_table_sql = f"""
        CREATE OR REPLACE TABLE tile_embeddings (
            id VARCHAR PRIMARY KEY,
            zoom INTEGER,
            x INTEGER,
            y INTEGER,
            embedding FLOAT[{embedding_dim}],
            geometry GEOMETRY,
            image_path VARCHAR,
            west DOUBLE,
            south DOUBLE,
            east DOUBLE,
            north DOUBLE,
            center_lon DOUBLE,
            center_lat DOUBLE
        );
        """
        con.execute(create_table_sql)

        # Insert data
        print("Inserting data...")
        start_time = time.time()

        for i, row in enumerate(tile_data):
            insert_sql = """
            INSERT INTO tile_embeddings
            (id, zoom, x, y, embedding, geometry, image_path,
             west, south, east, north, center_lon, center_lat)
            VALUES (?, ?, ?, ?, ?, ST_GeomFromText(?), ?, ?, ?, ?, ?, ?, ?)
            """

            values = (
                row['id'], row['zoom'], row['x'], row['y'],
                row['embedding'], row['geometry_wkt'], row['image_path'],
                row['west'], row['south'], row['east'], row['north'],
                row['center_lon'], row['center_lat']
            )

            con.execute(insert_sql, values)

            if (i + 1) % 100 == 0:
                print(f"Inserted {i + 1}/{len(tile_data)} tiles...")

        insert_time = time.time() - start_time
        print(f"Data insertion completed in {insert_time:.1f} seconds")

        # Create indexes
        print("Creating indexes...")
        con.execute("CREATE INDEX idx_tile_embeddings_geom ON tile_embeddings USING RTREE (geometry);")
        con.execute(f"""
        CREATE INDEX idx_tile_embeddings_hnsw
        ON tile_embeddings
        USING HNSW (embedding)
        WITH (metric = 'cosine');
        """)
        con.execute("CREATE INDEX idx_tile_embeddings_zoom ON tile_embeddings (zoom);")
        con.execute("CREATE INDEX idx_tile_embeddings_coords ON tile_embeddings (zoom, x, y);")

        # Print summary
        result = con.execute("SELECT COUNT(*) FROM tile_embeddings").fetchone()
        print(f"Total tiles: {result[0]}")

        result = con.execute("SELECT MIN(zoom), MAX(zoom) FROM tile_embeddings").fetchone()
        print(f"Zoom range: {result[0]} - {result[1]}")

        print("âœ… Database created successfully!")

    except Exception as e:
        print(f"Error creating database: {e}")
        raise
    finally:
        con.close()

"""Main function to build the tile embeddings database"""
print("ðŸš€ Creating tile embeddings database")

# Prepare tile data
embeddings = np.array(embeddings)
tile_data = prepare_tile_data(embeddings, image_paths)

if not tile_data:
    print("No valid tile data found!")

# Create database
output_db = f"{output_prefix}.db"
embedding_dim = embeddings.shape[1]

create_duckdb_database(tile_data, output_db, embedding_dim)

print(f"Database saved as: {output_db}")

ðŸš€ Creating tile embeddings database
Preparing tile data...
Processed 100/20976 tiles...
Processed 200/20976 tiles...
Processed 300/20976 tiles...
Processed 400/20976 tiles...
Processed 500/20976 tiles...
Processed 600/20976 tiles...
Processed 700/20976 tiles...
Processed 800/20976 tiles...
Processed 900/20976 tiles...
Processed 1000/20976 tiles...
Processed 1100/20976 tiles...
Processed 1200/20976 tiles...
Processed 1300/20976 tiles...
Processed 1400/20976 tiles...
Processed 1500/20976 tiles...
Processed 1600/20976 tiles...
Processed 1700/20976 tiles...
Processed 1800/20976 tiles...
Processed 1900/20976 tiles...
Processed 2000/20976 tiles...
Processed 2100/20976 tiles...
Processed 2200/20976 tiles...
Processed 2300/20976 tiles...
Processed 2400/20976 tiles...
Processed 2500/20976 tiles...
Processed 2600/20976 tiles...
Processed 2700/20976 tiles...
Processed 2800/20976 tiles...
Processed 2900/20976 tiles...
Processed 3000/20976 tiles...
Processed 3100/20976 tiles...
Processed 3200/20