In [None]:
from gis.config import Config

# Project configuration; `config.mnt_path` is used below as the root of the image/ and label/ tile folders.
config = Config()



In [None]:
import os 

# Zoom-18 image tiles live under <mnt_path>/image/18 (file names are parsed by load_tiles below).
img_pth = config.mnt_path / 'image/18'
# Show how many files there are plus a small sorted sample instead of dumping the whole listing.
img_files = sorted(os.listdir(img_pth))
len(img_files), img_files[:10]

In [None]:
import os
import re
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path

def load_tiles(tile_directory, pattern='*.jpg'):
    """Load all tile images in a directory, keyed by their (x, y) tile coordinates.

    Files are expected to be named ``{z}_{x}_{y}.<ext>`` (e.g. ``18_231007_155461.jpg``).
    Files whose names do not parse are skipped with a message; images that fail to
    open/decode are reported and skipped.

    Args:
        tile_directory: Directory (str or Path) containing the tile images.
        pattern: Glob pattern selecting tile files (default ``'*.jpg'``).

    Returns:
        dict mapping ``(x, y)`` -> ``np.ndarray`` of the decoded image, in sorted file-name
        order. The zoom level is parsed but is not part of the key, so the directory
        should hold a single zoom level.
    """
    tiles = {}
    # Sorted so the dict order (and e.g. the "first tile" stitch_tiles samples) is reproducible.
    for tile_file in sorted(Path(tile_directory).glob(pattern)):
        try:
            # "{z}_{x}_{y}.jpg" -> z, x, y; a wrong field count or a non-integer raises ValueError.
            _z, x, y = (int(part) for part in tile_file.stem.split('_'))
        except ValueError as e:
            print(f"Skipping {tile_file.name}: {e}")
            continue
        try:
            # The context manager closes the file handle once the pixels are copied into numpy.
            with Image.open(tile_file) as img:
                tiles[(x, y)] = np.array(img)
        except Exception as e:
            print(f"Error loading {tile_file.name}: {e}")
    return tiles


# Load every "{z}_{x}_{y}.jpg" tile in img_pth into a dict {(x, y): image array}.
tiles = load_tiles(str(img_pth))

In [None]:

def stitch_tiles(tiles, tile_size=256):
    """
    Assemble one mosaic image from tiles keyed by their (x, y) tile coordinates.

    Args:
        tiles: Mapping (x, y) -> numpy array; 3-D arrays are treated as H x W x C.
        tile_size: Pixel pitch of the tile grid (256 for standard web tiles).

    Returns:
        numpy array spanning the bounding box of all tile coordinates. Grid cells with
        no tile stay zero; a tile that would overflow the right/bottom edge is cropped.
        dtype and channel count are taken from the first tile in ``tiles``.

    Raises:
        ValueError: if ``tiles`` is empty.
    """
    if not tiles:
        raise ValueError("No tiles to stitch!")

    xs, ys = zip(*tiles.keys())
    min_x, max_x = min(xs), max(xs)
    min_y, max_y = min(ys), max(ys)
    print(f"Tile bounds: X({min_x} to {max_x}), Y({min_y} to {max_y})")

    width = tile_size * (max_x - min_x + 1)
    height = tile_size * (max_y - min_y + 1)

    # The first tile decides the dtype and whether the mosaic has a channel axis.
    first = next(iter(tiles.values()))
    if len(first.shape) == 3:
        canvas_shape = (height, width, first.shape[2])
    else:
        canvas_shape = (height, width)
    stitched = np.zeros(canvas_shape, dtype=first.dtype)
    print(f"Creating stitched image of size: {width}x{height}")

    has_channels = len(stitched.shape) == 3
    placed = 0
    for (tx, ty), img in tiles.items():
        # Top-left corner of this tile in the mosaic; larger y is placed further down.
        row0 = (ty - min_y) * tile_size
        col0 = (tx - min_x) * tile_size
        # Number of rows/cols that fit before the mosaic edge.
        rows = min(row0 + img.shape[0], height) - row0
        cols = min(col0 + img.shape[1], width) - col0
        if has_channels:
            stitched[row0:row0 + rows, col0:col0 + cols, :] = img[:rows, :cols, :]
        else:
            stitched[row0:row0 + rows, col0:col0 + cols] = img[:rows, :cols]
        placed += 1

    print(f"Successfully placed {placed} tiles")
    return stitched


# Full mosaic of all loaded tiles (`stiched` [sic] is the name the next cells use).
stiched = stitch_tiles(tiles)
stiched.shape

In [None]:
def visualize_stitched_image(stitched_image, figsize=(15, 15), save_path=None, tile_size=256):
    """Show a stitched mosaic with black grid lines on the tile boundaries.

    Args:
        stitched_image: 2-D array (drawn in grayscale) or 3-D array (drawn as colour).
        figsize: Matplotlib figure size in inches.
        save_path: If given, the figure is also written to this path at 150 dpi.
        tile_size: Grid pitch in pixels; should match the ``tile_size`` passed to
            stitch_tiles (default 256, as before).
    """
    fig, ax = plt.subplots(figsize=figsize)

    if len(stitched_image.shape) == 3:
        ax.imshow(stitched_image)
    else:
        ax.imshow(stitched_image, cmap='gray')

    ax.set_title(f'Stitched Tiles (Shape: {stitched_image.shape})')
    ax.set_xlabel('Pixels')
    ax.set_ylabel('Pixels')

    # Grid lines on every tile boundary (previously hard-coded to a 256-pixel pitch).
    height, width = stitched_image.shape[:2]
    for x in range(0, width, tile_size):
        ax.axvline(x, color='black', alpha=1, linewidth=2)
    for y in range(0, height, tile_size):
        ax.axhline(y, color='black', alpha=1, linewidth=2)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved visualization to {save_path}")

    plt.show()

# Mosaic with black lines on the tile boundaries.
visualize_stitched_image(stiched)

In [None]:
from gis.image_utils import load_mask 

# Keep only tiles that have a label file on disk AND whose mask has at least one
# non-zero pixel; the masks themselves are kept alongside under the same (x, y) key.
road_tiles, road_masks = {}, {}
for (x, y), tile in tiles.items():
    filepath = config.mnt_path / f'label/18/{x}_{y}.npy'
    if os.path.exists(filepath):
        mask = load_mask((18, x, y))
        if np.sum(mask) > 0:
            road_tiles[(x, y)] = tile
            road_masks[(x, y)] = mask
len(road_tiles)

In [None]:
# Mosaic of only the tiles whose label mask has at least one non-zero pixel.
visualize_stitched_image(stitch_tiles(road_tiles))

In [None]:
# The matching stitched label masks (drawn in grayscale if the masks are 2-D).
visualize_stitched_image(stitch_tiles(road_masks))

# Extracting a line segment from a mask

In [None]:
# Example tile key (x, y) at zoom 18.
ex = (231007, 155461)
# Look the tile up through `ex` rather than repeating the literal key.
example = road_tiles[ex]
plt.imshow(example)

In [None]:
# Label mask for the example tile (zoom 18, x, y).
mask = load_mask((18, *ex))
plt.imshow(mask)

In [None]:
from skimage import measure, morphology


In [None]:
# Thin the example tile's mask to a one-pixel-wide skeleton (skimage.morphology.skeletonize).
skeleton = morphology.skeletonize(mask)
plt.imshow(skeleton)

In [None]:
# Overlay the skeleton (semi-transparent) on top of the mask it came from.
fig, axs = plt.subplots(1,1,figsize=(10,10))
axs.imshow(mask, alpha=1)
axs.imshow(skeleton, alpha=0.5)

In [None]:
# Same overlay for the whole stitched road-mask mosaic.
# NOTE: this rebinds `skeleton` to the mosaic-wide skeleton.
stiched_mask = stitch_tiles(road_masks)
skeleton = morphology.skeletonize(stiched_mask)
fig, axs = plt.subplots(1,1,figsize=(10,10))
axs.imshow(stiched_mask, alpha=1)
axs.imshow(skeleton, alpha=0.5)

In [None]:
# Summarise the skeleton (shape, number of skeleton pixels) instead of dumping the whole array.
skeleton.shape, int(skeleton.sum())

In [None]:
import numpy as np
import math
from skimage import morphology
from scipy import ndimage
from collections import deque

def tile_pixel_to_latlon(tile_x, tile_y, tile_z, pixel_x, pixel_y, tile_size=256):
    """
    Map a pixel inside an XYZ (Web Mercator) tile to geographic coordinates.

    Args:
        tile_x, tile_y: Tile column / row at zoom ``tile_z``.
        tile_z: Zoom level.
        pixel_x, pixel_y: Pixel column / row within the tile (0 .. tile_size-1).
        tile_size: Tile edge length in pixels (default 256).

    Returns:
        ``(longitude, latitude)`` in degrees -- note the longitude-first order.
    """
    # Width of the whole world map in pixels at this zoom.
    world_px = tile_size * 2.0 ** tile_z
    # Normalised [0, 1] map position; u grows eastward, v grows southward.
    u = (tile_x * tile_size + pixel_x) / world_px
    v = (tile_y * tile_size + pixel_y) / world_px

    lon = u * 360.0 - 180.0
    # Inverse Mercator for the latitude.
    lat = math.degrees(math.atan(math.sinh(math.pi * (1 - 2 * v))))
    return lon, lat

def extract_skeleton_points(skeleton_mask):
    """
    Collect the pixels of a skeleton and put them in walk order.

    Args:
        skeleton_mask: Binary skeleton, e.g. from morphology.skeletonize().

    Returns:
        List of ``(row, col)`` pixel coordinates. With fewer than two pixels the raw
        (row-major) list is returned; otherwise order_skeleton_points decides the order.
    """
    # np.nonzero on the mask gives the row and column index arrays in row-major order.
    found = np.nonzero(skeleton_mask)
    rows, cols = found[0], found[1]
    skeleton_coords = list(zip(rows, cols))

    if len(skeleton_coords) >= 2:
        return order_skeleton_points(skeleton_coords, skeleton_mask)
    return skeleton_coords

def order_skeleton_points(skeleton_coords, skeleton_mask):
    """
    Order skeleton pixels into a single walk along the skeleton.

    Args:
        skeleton_coords: Sequence of ``(row, col)`` skeleton pixel coordinates.
        skeleton_mask: The skeleton the coordinates came from (unused; kept so the
            signature stays compatible).

    Returns:
        List of ``(row, col)`` in walk order. The walk starts at the first pixel that has
        exactly one neighbour (or at ``skeleton_coords[0]`` if none does). It then greedily
        steps to an unvisited 8-connected neighbour, trying orthogonal steps before diagonal
        ones so that corner pixels are not skipped. It stops at the first dead end. On a
        branched skeleton the pixels of the other branches are therefore NOT returned.
    """
    if len(skeleton_coords) < 2:
        return skeleton_coords

    coord_set = set(skeleton_coords)
    # Orthogonal offsets first. At an L-shaped corner both the orthogonal and the diagonal
    # neighbour exist, and taking the diagonal one strands the corner pixel and forces
    # the walk to double back to it.
    offsets = [(-1, 0), (0, -1), (0, 1), (1, 0),
               (-1, -1), (-1, 1), (1, -1), (1, 1)]
    adjacency = {
        (y, x): [(y + dy, x + dx) for dy, dx in offsets if (y + dy, x + dx) in coord_set]
        for y, x in skeleton_coords
    }

    # Endpoints have exactly one neighbour (neighbour counts don't depend on offset order).
    endpoints = [coord for coord, neighbors in adjacency.items() if len(neighbors) == 1]
    current = endpoints[0] if endpoints else skeleton_coords[0]

    ordered_points = []
    visited = set()
    while current is not None:
        ordered_points.append(current)
        visited.add(current)
        current = next((nb for nb in adjacency[current] if nb not in visited), None)
    return ordered_points

def skeleton_to_wkt_linestring(skeleton_mask, tile_x, tile_y, tile_z, tile_size=256, simplify_tolerance=0.0001):
    """
    Turn a single-tile skeleton into a WKT LINESTRING in lon/lat degrees.

    Skeleton pixels are ordered by extract_skeleton_points. Each (row, col) pixel is
    projected with tile_pixel_to_latlon. When ``simplify_tolerance > 0`` the vertices
    are then thinned by simplify_coordinates.

    Args:
        skeleton_mask: Binary skeleton from morphology.skeletonize().
        tile_x, tile_y, tile_z: Tile coordinates of the skeleton's tile.
        tile_size: Tile size in pixels.
        simplify_tolerance: Minimum spacing between kept vertices, in degrees.

    Returns:
        ``"LINESTRING(lon lat, ...)"``, or None if fewer than two skeleton points exist.
    """
    ordered = extract_skeleton_points(skeleton_mask)
    if len(ordered) < 2:
        return None

    # (row, col) -> (lon, lat); WKT wants x (lon) before y (lat).
    coordinates = [
        tile_pixel_to_latlon(tile_x, tile_y, tile_z, col, row, tile_size)
        for row, col in ordered
    ]

    if simplify_tolerance > 0:
        coordinates = simplify_coordinates(coordinates, simplify_tolerance)

    body = ', '.join(f"{lon} {lat}" for lon, lat in coordinates)
    return f"LINESTRING({body})"

def simplify_coordinates(coordinates, tolerance):
    """
    Drop vertices that lie within ``tolerance`` of the previously kept vertex.

    The first and last points are always kept. A middle point is kept only if its
    Euclidean distance to the last kept point is strictly greater than ``tolerance``.
    (Ramer-Douglas-Peucker would preserve shape better.)
    """
    if len(coordinates) <= 2:
        return coordinates

    kept = [coordinates[0]]
    for point in coordinates[1:-1]:
        anchor = kept[-1]
        dx = point[0] - anchor[0]
        dy = point[1] - anchor[1]
        if math.sqrt(dx*dx + dy*dy) > tolerance:
            kept.append(point)
    kept.append(coordinates[-1])
    return kept

def skeleton_to_wkt_multilinestring(skeleton_mask, tile_x, tile_y, tile_z, tile_size=256):
    """
    Convert a skeleton into WKT with one line per connected component.

    Components are labelled with full (8-) connectivity. Each component is converted by
    skeleton_to_wkt_linestring, using that function's default simplification.

    Args:
        skeleton_mask: Binary skeleton from morphology.skeletonize().
        tile_x, tile_y, tile_z: Tile coordinates of the skeleton's tile.
        tile_size: Tile size in pixels.

    Returns:
        None if no component yields a line, ``"LINESTRING(...)"`` if exactly one does,
        otherwise ``"MULTILINESTRING((...), (...), ...)"`` in component-label order.
    """
    ndim = np.ndim(skeleton_mask)
    # Skeleton lines step diagonally. ndimage.label's default structure only joins
    # orthogonal neighbours, which would split a diagonal line into one-pixel pieces.
    # skeleton_to_wkt_linestring then discards those pieces because they have < 2 points.
    connectivity = ndimage.generate_binary_structure(ndim, ndim)
    labeled_skeleton, n_components = ndimage.label(skeleton_mask, structure=connectivity)

    linestrings = []
    for component_id in range(1, n_components + 1):
        wkt_line = skeleton_to_wkt_linestring(
            labeled_skeleton == component_id, tile_x, tile_y, tile_z, tile_size
        )
        if wkt_line:
            # Keep just the "lon lat, ..." part from inside "LINESTRING(...)".
            coord_part = wkt_line[len('LINESTRING('):-1]
            linestrings.append(f"({coord_part})")

    if not linestrings:
        return None
    if len(linestrings) == 1:
        return f"LINESTRING{linestrings[0]}"
    return f"MULTILINESTRING({', '.join(linestrings)})"

def process_rail_skeleton_from_tiles(skeleton_mask, tile_info_list):
    """
    Convert a skeleton drawn on a multi-tile mosaic into one WKT LINESTRING.

    Args:
        skeleton_mask: Skeleton of the stitched mosaic.
        tile_info_list: List of dicts with keys 'x', 'y', 'z' (tile coordinates) and
            'offset_x', 'offset_y' (pixel position of that tile in the mosaic).

    Returns:
        ``"LINESTRING(lon lat, ...)"``, or None if fewer than two points could be placed.
        Skeleton pixels outside every listed tile are silently skipped. Tile lookup
        uses find_tile_for_pixel's default 256-pixel tile size.
    """
    coordinates = []
    for pixel_y, pixel_x in extract_skeleton_points(skeleton_mask):
        owner = find_tile_for_pixel(pixel_x, pixel_y, tile_info_list)
        if not owner:
            continue
        # Mosaic pixel -> pixel local to the owning tile -> (lon, lat).
        lon_lat = tile_pixel_to_latlon(
            owner['x'], owner['y'], owner['z'],
            pixel_x - owner['offset_x'], pixel_y - owner['offset_y'],
        )
        coordinates.append(lon_lat)

    if len(coordinates) < 2:
        return None

    return "LINESTRING({})".format(', '.join(f"{lon} {lat}" for lon, lat in coordinates))

def find_tile_for_pixel(pixel_x, pixel_y, tile_info_list, tile_size=256):
    """Return the first tile whose half-open square [offset, offset + tile_size) contains the pixel, else None."""
    return next(
        (
            info
            for info in tile_info_list
            if info['offset_x'] <= pixel_x < info['offset_x'] + tile_size
            and info['offset_y'] <= pixel_y < info['offset_y'] + tile_size
        ),
        None,
    )

# Example usage
def example_usage():
    """Demonstrate skeleton -> WKT conversion on a synthetic diagonal line."""
    # 50x50 canvas with a 45-pixel diagonal from (0, 0) to (44, 44).
    test_skeleton = np.zeros((50, 50), dtype=bool)
    diag = np.arange(45)
    test_skeleton[diag, diag] = True

    # Arbitrary zoom-18 tile. By the Web Mercator maths in tile_pixel_to_latlon, its
    # top-left corner is at roughly 33.1°S, 145.3°E (i.e. NOT in the Melbourne area).
    tile_x, tile_y, tile_z = 236870, 156616, 18

    wkt_result = skeleton_to_wkt_linestring(test_skeleton, tile_x, tile_y, tile_z)
    print("Example WKT output:")
    print(wkt_result)

    # Same skeleton through the per-connected-component path (computed before its header
    # is printed, so any output the conversion itself prints comes first).
    wkt_multi = skeleton_to_wkt_multilinestring(test_skeleton, tile_x, tile_y, tile_z)
    print("\nMultilinestring version:")
    print(wkt_multi)

example_usage()

# Short term: predicting train-track location

In [None]:
# Example tile key (x, y) at zoom 18 (same tile as the earlier exploration).
ex = (231007, 155461)
example = road_tiles[ex]
plt.imshow(example)

In [None]:
# Label mask of the example tile at zoom 18.
mask = load_mask((18, *ex))
plt.imshow(mask)

In [None]:
# Skeleton of the example tile's mask (rebinds `skeleton` from the mosaic-wide one above).
skeleton = morphology.skeletonize(mask)
plt.imshow(skeleton)

In [None]:
# Label skeleton components with full (8-) connectivity. ndimage.label's default
# structure only joins orthogonal neighbours and would split every diagonal step.
labeled_skeleton, n_components = ndimage.label(
    skeleton, structure=ndimage.generate_binary_structure(skeleton.ndim, skeleton.ndim)
)
plt.imshow(labeled_skeleton)

In [None]:
import shapely

# WKT for the example tile's skeleton (ex is the same (231007, 155461) tile key).
ls = skeleton_to_wkt_multilinestring(skeleton, *ex, 18)
from shapely import wkt

# skeleton_to_wkt_multilinestring returns None when no component has >= 2 points.
assert ls is not None, "skeleton produced no line geometry"
result = wkt.loads(ls)
type(result)

In [None]:
# `result` is a LineString for one component, or a MultiLineString for several.
# Normalise both to a list of lines (the old code read `x.geoms`, a leaked loop integer).
lines = list(result.geoms) if hasattr(result, 'geoms') else [result]

all_coords = []
for line_string in lines:
    # Vertices are (lon, lat) pairs -- the order they were written into the WKT.
    all_coords.extend(line_string.coords)

len(all_coords), all_coords[:5]

In [None]:
# NOTE(review): `all_coords` holds (lon, lat) pairs (the WKT was written as "lon lat"),
# while folium.PolyLine `locations` expects (lat, lon) -- swap each pair before re-enabling.
# import folium
# m = folium.Map(location=[49.0, 0.0], zoom_start=4)

# folium.PolyLine(
#     locations=all_coords,
#     color='blue',
#     weight=5,
#     tooltip='Route between cities'
# ).add_to(m)
# m

This direction actually looks pretty good. Could I extend a line in that direction and estimate where it intersects other features?

https://docs.maptiler.com/google-maps-coordinates-tile-bounds-projection/

https://learn.microsoft.com/en-us/azure/azure-maps/zoom-levels-and-tile-grid?tabs=csharp

In [None]:
def tile_to_latlon(z, x, y):
    """
    Geographic position of the top-left (north-west) corner of XYZ tile (x, y) at zoom z.

    Returns:
        ``(latitude, longitude)`` in degrees -- latitude first, unlike tile_pixel_to_latlon.
    """
    tiles_per_side = 2.0 ** z
    lat_rad = math.atan(math.sinh(math.pi - (2 * math.pi * y) / tiles_per_side))
    lon_rad = x / tiles_per_side * 2 * math.pi - math.pi
    return math.degrees(lat_rad), math.degrees(lon_rad)

# North-west corner of the example tile at zoom 18, as (lat, lon).
lat, lon = tile_to_latlon(18, ex[0], ex[1])
lat, lon

In [None]:
import math

def tile_pixel_to_lat_lon(zoom, tile_x, tile_y, pixel_x, pixel_y, width=256):
    """
    Convert a pixel inside tile (tile_x, tile_y) at ``zoom`` to geographic coordinates.

    Returns:
        ``(latitude, longitude)`` in degrees.
    """
    map_width = width * (2 ** zoom)
    # Map position relative to the map centre: x in [-0.5, 0.5] eastward, y in [-0.5, 0.5] northward.
    mercator_x = ((tile_x * width + pixel_x) / map_width) - 0.5
    mercator_y = 0.5 - ((tile_y * width + pixel_y) / map_width)
    return (
        math.degrees(math.atan(math.sinh(mercator_y * 2 * math.pi))),
        mercator_x * 360.0,
    )

# Position of pixel (50, 50) inside the example tile, as (lat, lon).
lat, lon = tile_pixel_to_lat_lon(18, ex[0], ex[1], 50, 50)
lat, lon

## Extending lines 

In [None]:
import pyproj
import mercantile
import math

def extend_line_and_get_tiles(lat1, lon1, lat2, lon2, extension_distance_m, zoom_level, 
                             local_crs='EPSG:3857'):
    """
    Extend the segment P1-P2 at both ends and list the map tiles the extended line crosses.

    The line is built in ``local_crs``: it runs from ``extension_distance_m`` before P1 to
    ``extension_distance_m`` past P2. It is sampled about every 100 CRS units (at least
    10 intervals). Wherever two consecutive samples fall in tiles that are not
    edge-adjacent, the gap is bisected so that tiles the line only clips (e.g. near a
    corner) are not missed.

    Args:
        lat1, lon1: First point (degrees, WGS84).
        lat2, lon2: Second point (degrees, WGS84).
        extension_distance_m: Extension at each end, in ``local_crs`` units.
            NOTE(review): in the default EPSG:3857 a unit is a metre only at the equator;
            ground distance is roughly units * cos(latitude). Use a local metric CRS
            (e.g. UTM) if true metres matter.
        zoom_level: Zoom level of the returned tiles.
        local_crs: Projected CRS in which the straight line is constructed.

    Returns:
        list of ``mercantile.Tile`` covering the extended line, sorted (by x, then y).

    Raises:
        ValueError: if the two points project to the same location.
    """
    transformer_to_proj = pyproj.Transformer.from_crs("EPSG:4326", local_crs, always_xy=True)
    transformer_to_geo = pyproj.Transformer.from_crs(local_crs, "EPSG:4326", always_xy=True)

    # always_xy=True: inputs/outputs are (lon, lat) / (x, y).
    x1, y1 = transformer_to_proj.transform(lon1, lat1)
    x2, y2 = transformer_to_proj.transform(lon2, lat2)

    dx, dy = x2 - x1, y2 - y1
    line_length = math.sqrt(dx * dx + dy * dy)
    if line_length == 0:
        raise ValueError("Points are identical")
    unit_dx, unit_dy = dx / line_length, dy / line_length

    # Extended end points: backward past P1, forward past P2.
    back_x = x1 - unit_dx * extension_distance_m
    back_y = y1 - unit_dy * extension_distance_m
    forward_x = x2 + unit_dx * extension_distance_m
    forward_y = y2 + unit_dy * extension_distance_m

    def tile_at(t):
        """Tile containing the point at fraction t of the extended line (0 = back, 1 = forward)."""
        sample_lon, sample_lat = transformer_to_geo.transform(
            back_x + t * (forward_x - back_x), back_y + t * (forward_y - back_y)
        )
        return mercantile.tile(sample_lon, sample_lat, zoom_level)

    tiles = set()

    def fill_gap(t0, tile0, t1, tile1, depth):
        """Bisect [t0, t1] until neighbouring samples are in the same or edge-adjacent tiles.

        Two samples in diagonally adjacent (or farther apart) tiles mean the line passed
        through at least one tile between them that no sample landed in.
        """
        if depth == 0 or abs(tile0.x - tile1.x) + abs(tile0.y - tile1.y) <= 1:
            return
        t_mid = (t0 + t1) / 2
        tile_mid = tile_at(t_mid)
        tiles.add(tile_mid)
        fill_gap(t0, tile0, t_mid, tile_mid, depth - 1)
        fill_gap(t_mid, tile_mid, t1, tile1, depth - 1)

    num_samples = max(10, int((line_length + 2 * extension_distance_m) / 100))  # ~every 100 units
    prev_t, prev_tile = 0.0, tile_at(0.0)
    tiles.add(prev_tile)
    for i in range(1, num_samples + 1):
        t = i / num_samples
        cur_tile = tile_at(t)
        tiles.add(cur_tile)
        # Depth cap: a line through an exact tile corner stays "diagonal" at every level.
        fill_gap(prev_t, prev_tile, t, cur_tile, depth=16)
        prev_t, prev_tile = t, cur_tile

    return sorted(tiles)

import random

# Pick two skeleton vertices to define the line direction. The RNG is seeded so that
# the choice -- and every tile fetched downstream -- is reproducible on Restart & Run All.
coord1, coord2 = random.Random(0).sample(all_coords, 2)
print(coord1, coord2)

# all_coords holds (lon, lat) pairs, so swap each into (lat, lon).
# (Alternative used while exploring: fixed pixels of tile `ex`, e.g.
#  tile_pixel_to_lat_lon(18, ex[0], ex[1], 50, 50) and (..., 100, 100).)
lat1, lon1 = coord1[1], coord1[0]
lat2, lon2 = coord2[1], coord2[0]

lat1, lon1, lat2, lon2

In [None]:
# extend_line_and_get_tiles returns a single list of tiles; unpacking two values from it
# fails unless exactly two tiles happen to be returned.
new_tiles = extend_line_and_get_tiles(lat1, lon1, lat2, lon2, extension_distance_m=1000, zoom_level=18)
len(new_tiles)

In [None]:


from gis.tile import download_tile

# Fetch every tile on the extended line. NOTE(review): download_tile is project code; whether it
# skips tiles already on disk is not visible here -- confirm before re-running on many tiles.
for t in new_tiles:
    download_tile(z=t.z, x=t.x, y=t.y)

In [None]:
# Rebuild `tiles` from the freshly downloaded tiles along the extended line.
# NOTE(review): this rebinds the `tiles` loaded from img_pth earlier, so re-running the
# earlier cells after this one will see this smaller set.
# NOTE(review): the root '/mnt/gis' is hard-coded; it presumably equals config.mnt_path -- confirm and switch.
tiles = {}
for t in new_tiles:
    # Context manager closes the file once the pixels are copied into numpy.
    with Image.open(f'/mnt/gis/image/{t.z}/{t.z}_{t.x}_{t.y}.jpg') as img:
        tiles[(t.x, t.y)] = np.array(img)
visualize_stitched_image(stitch_tiles(tiles))

In [None]:
# Show the tile keys (and count) rather than dumping every image array.
len(tiles), sorted(tiles)

In [None]:
# Example tile key (x, y); compare with the downloaded tile keys above.
ex

In [None]:
# Bounding box of the tile grid: min/max over the keys, as stitch_tiles does. The old
# sentinel 99999999 would break at zoom >= 27, where tile indices exceed it.
tile_xs = [x for x, _ in tiles]
tile_ys = [y for _, y in tiles]
minx, maxx = min(tile_xs), max(tile_xs)
miny, maxy = min(tile_ys), max(tile_ys)

# Number of tile columns / rows in the mosaic (inclusive range, matching stitch_tiles' size).
n_x = maxx - minx + 1
n_y = maxy - miny + 1

# Grid position (column, row) of the example tile `ex` inside the mosaic.
tile_index_x = ex[0] - minx
tile_index_y = ex[1] - miny
tile_index_x, tile_index_y

In [None]:
# Example tile key (x, y), for comparison with tile_index_x / tile_index_y above.
ex

In [None]:
def lat_lon_to_tile_pixel(latitude, longitude, zoom, width=256):
    """
    Converts latitude/longitude coordinates to tile position and pixel coordinates within that tile.

    Latitude is clamped to the Web Mercator limit (about ±85.0511°), where the projection
    reaches the top/bottom map edge. Tile indices are clamped to ``[0, 2**zoom - 1]``, so
    e.g. longitude 180 maps to the right edge (``pixel_x == width``) of the last tile
    column instead of a non-existent column ``2**zoom``. ``tile * width + pixel`` always
    equals the global pixel position, so tile_pixel_to_lat_lon round-trips exactly.

    Args:
        latitude (float): Latitude in degrees
        longitude (float): Longitude in degrees
        zoom (int): Zoom level
        width (int): Tile width in pixels (default 256)

    Returns:
        tuple: (tile_x, tile_y, pixel_x, pixel_y) -- tiles as ints, pixels as floats
    """
    # Latitude whose Mercator y is exactly the map edge: atan(sinh(pi)).
    max_lat = math.degrees(math.atan(math.sinh(math.pi)))
    latitude = max(-max_lat, min(max_lat, latitude))
    lat_rad = math.radians(latitude)

    # Normalised Web Mercator position in [0, 1] (x eastward, y southward).
    mercator_x = longitude / 360.0 + 0.5
    mercator_y = 0.5 - (math.log(math.tan(lat_rad) + 1 / math.cos(lat_rad)) / (2 * math.pi))

    map_width = width * (2 ** zoom)
    global_pixel_x = mercator_x * map_width
    global_pixel_y = mercator_y * map_width

    # Clamp to valid tile indices. At the edges the floor division can land one tile outside
    # the map (e.g. lon = 180, or a tiny negative y from round-off at the max latitude).
    last_tile = 2 ** zoom - 1
    tile_x = min(max(int(global_pixel_x // width), 0), last_tile)
    tile_y = min(max(int(global_pixel_y // width), 0), last_tile)

    pixel_x = global_pixel_x - (tile_x * width)
    pixel_y = global_pixel_y - (tile_y * width)

    return tile_x, tile_y, pixel_x, pixel_y

# lat1, lon1, lat2, lon2



In [None]:
def visualize_stitched_image(stitched_image, figsize=(15, 15), save_path=None,
                             points=None, origin_tile=None, highlight_index=(2, 3),
                             zoom=18, tile_size=256):
    """Show a stitched mosaic with tile grid lines, one highlighted tile edge and red point markers.

    Args:
        stitched_image: 2-D (drawn in grayscale) or 3-D (drawn as colour) array from stitch_tiles.
        figsize: Matplotlib figure size in inches.
        save_path: If given, the figure is also written to this path at 150 dpi.
        points: Iterable of (lat, lon) to mark. Defaults to the notebook globals
            (lat1, lon1) and (lat2, lon2), as before.
        origin_tile: (x, y) of the mosaic's top-left tile (the min x / min y of the stitched
            tiles). When given, each point is placed in its own tile, so points anywhere in
            the mosaic land correctly.
        highlight_index: (column, row) whose left/top grid lines are drawn in red. When
            ``origin_tile`` is None, this is also the tile all points are assumed to lie in
            (the previous hard-coded behaviour).
        zoom: Zoom level used to convert points to tile/pixel coordinates.
        tile_size: Tile size in pixels, for the grid lines and point offsets.
    """
    fig, ax = plt.subplots(figsize=figsize)

    if len(stitched_image.shape) == 3:
        ax.imshow(stitched_image)
    else:
        ax.imshow(stitched_image, cmap='gray')

    ax.set_title(f'Stitched Tiles (Shape: {stitched_image.shape})')
    ax.set_xlabel('Pixels')
    ax.set_ylabel('Pixels')

    # Tile boundaries in black; the left/top edges of the highlighted tile in red.
    height, width = stitched_image.shape[:2]
    highlight_col, highlight_row = highlight_index
    for i, x in enumerate(range(0, width, tile_size)):
        ax.axvline(x, color='red' if i == highlight_col else 'black', alpha=1, linewidth=2)
    for i, y in enumerate(range(0, height, tile_size)):
        ax.axhline(y, color='red' if i == highlight_row else 'black', alpha=1, linewidth=2)

    if points is None:
        # Backward-compatible default: the two points chosen in earlier cells (notebook globals).
        points = [(lat1, lon1), (lat2, lon2)]

    for lat, lon in points:
        tile_x, tile_y, pixel_x, pixel_y = lat_lon_to_tile_pixel(lat, lon, zoom, tile_size)
        if origin_tile is None:
            # Legacy behaviour: assume the point lies in the highlighted tile.
            col, row = highlight_index
        else:
            # Use the point's own tile, relative to the mosaic's top-left tile.
            col, row = tile_x - origin_tile[0], tile_y - origin_tile[1]
        ax.scatter(col * tile_size + pixel_x, row * tile_size + pixel_y, color='red')

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved visualization to {save_path}")

    plt.show()


# origin = top-left tile of the mosaic; highlight = grid position of tile `ex` (both computed above).
visualize_stitched_image(stitch_tiles(tiles), origin_tile=(minx, miny),
                         highlight_index=(tile_index_x, tile_index_y))

In [None]:
import math



def tile_pixel_to_lat_lon(zoom, tile_x, tile_y, pixel_x, pixel_y, width=256):
    """
    (tile, pixel) at ``zoom`` -> ``(latitude, longitude)`` in degrees.

    NOTE(review): this redefines a function of the same name from an earlier cell with the
    same behaviour; keep only one definition.
    """
    world = width * (2 ** zoom)
    gx = tile_x * width + pixel_x
    gy = tile_y * width + pixel_y
    longitude = (gx / world - 0.5) * 360.0
    # math.tau == 2 * math.pi exactly, so this matches "* 2 * math.pi" bit for bit.
    latitude = math.degrees(math.atan(math.sinh((0.5 - gy / world) * math.tau)))
    return latitude, longitude

# Test the functions to verify they are inverses of each other
def test_conversion_accuracy(tolerance=1e-8):
    """Round-trip lat/lon -> tile/pixel -> lat/lon for several cases and check the error.

    Prints every case. Raises AssertionError listing the cases whose latitude or
    longitude error exceeds ``tolerance`` degrees. Previously the errors were only
    printed, so a regression would go unnoticed.
    """
    test_cases = [
        # (lat, lon, zoom)
        (-33.8688, 151.2093, 10),  # Sydney
        (-33.8688, 151.2093, 16),  # Sydney high zoom
        (0.0, 0.0, 5),             # Equator/Prime Meridian
        (85.0, 180.0, 12),         # Near max latitude
        (-85.0, -180.0, 12),       # Near min latitude
        (40.7589, -73.9851, 14),   # New York
    ]
    
    print("Testing conversion accuracy (forward and back):")
    print("=" * 70)

    failures = []
    for lat, lon, zoom in test_cases:
        # Forward conversion
        tile_x, tile_y, pixel_x, pixel_y = lat_lon_to_tile_pixel(lat, lon, zoom)
        
        # Backward conversion
        lat_back, lon_back = tile_pixel_to_lat_lon(zoom, tile_x, tile_y, pixel_x, pixel_y)
        
        # Calculate error
        lat_error = abs(lat - lat_back)
        lon_error = abs(lon - lon_back)
        
        print(f"Original: ({lat:8.5f}, {lon:9.5f}) zoom={zoom}")
        print(f"  Tile: ({tile_x:5d}, {tile_y:5d}) Pixel: ({pixel_x:6.2f}, {pixel_y:6.2f})")
        print(f"  Back:   ({lat_back:8.5f}, {lon_back:9.5f})")
        print(f"  Error:  ({lat_error:8.7f}, {lon_error:9.7f})")
        print()

        if lat_error > tolerance or lon_error > tolerance:
            failures.append((lat, lon, zoom, lat_error, lon_error))

    assert not failures, f"Round-trip error above {tolerance} deg (lat, lon, zoom, lat_err, lon_err): {failures}"

# Example point: Sydney (the same coordinates as the "Sydney" cases in test_conversion_accuracy).
latitude = -33.8688
longitude = 151.2093
zoom = 16

print(f"Converting lat/lon to tile coordinates:")
print(f"Input: Latitude={latitude}, Longitude={longitude}, Zoom={zoom}")

# Forward: lat/lon -> integer tile plus fractional pixel offset within that tile.
tile_x, tile_y, pixel_x, pixel_y = lat_lon_to_tile_pixel(latitude, longitude, zoom)

print(f"Result:")
print(f"  Tile: ({tile_x}, {tile_y})")  
print(f"  Pixel within tile: ({pixel_x:.3f}, {pixel_y:.3f})")

# Verify by converting back; the printed differences are the round-trip error.
lat_verify, lon_verify = tile_pixel_to_lat_lon(zoom, tile_x, tile_y, pixel_x, pixel_y)
print(f"\nVerification (converting back):")
print(f"  Original: ({latitude}, {longitude})")
print(f"  Converted back: ({lat_verify:.6f}, {lon_verify:.6f})")
print(f"  Difference: ({abs(latitude-lat_verify):.8f}, {abs(longitude-lon_verify):.8f})")

# Run comprehensive tests
print("\n" + "="*50)
test_conversion_accuracy()