In [16]:
import pickle
import os
import rasterio
from pyproj import Transformer
import numpy as np
import matplotlib.pyplot as plt
import random
import geopandas as gpd
import pandas as pd
from shapely import wkt
from tqdm import tqdm


In [17]:
def show(data):
    """
    Print a tree-like overview of a ``{year: {field: value}}`` dictionary.

    Every top-level key is printed, followed by one line per sub-key describing
    its value: ``list[n]`` for lists and tuples, ``dict[n]`` for dicts, the shape
    and dtype for anything exposing ``.shape`` (``N/A`` when there is no dtype),
    or the value's type otherwise. Top-level values that are not dicts are
    reported with a warning and skipped.
    """
    def describe(value):
        # Short human-readable summary of one field value.
        if isinstance(value, (list, tuple)):
            return f"list[{len(value)}]"
        if isinstance(value, dict):
            return f"dict[{len(value)}]"
        if hasattr(value, "shape"):
            return f"array shape={value.shape}, dtype={getattr(value, 'dtype', 'N/A')}"
        return str(type(value))

    separator = "-" * 60
    print("üîé Exploring data structure\n" + separator)
    for group_name, group in data.items():
        print(f"{group_name}")
        if isinstance(group, dict):
            for field, value in group.items():
                print(f"  ‚îú‚îÄ {field:<25} ‚Üí " + describe(value))
            print(separator)
        else:
            print(f"  ‚ö†Ô∏è Expected dict, got {type(group)}\n")

def save(data, save_path):
    """Pickle ``data`` to ``save_path`` (overwriting it) and print a confirmation."""
    with open(save_path, "wb") as out_file:
        pickle.dump(data, out_file)
    print(f"‚úÖ Data saved successfully at: {save_path}")

def load(save_path):
    """Unpickle and return the object stored at ``save_path``.

    Only use this on files you created yourself: unpickling an untrusted file
    can execute arbitrary code.
    """
    with open(save_path, "rb") as in_file:
        restored = pickle.load(in_file)
    print("‚úÖ Data loaded successfully!")
    return restored
    
def read_band(file):
    """Read band 1 of a raster as a float array, zeroing invalid pixels.

    Values ``<= 0`` (null) or ``>= 60000`` (treated as saturated) are replaced
    by 0. The replacement is 0, not NaN, so those pixels are still counted by
    any later statistics computed on the array.
    """
    with rasterio.open(file) as dataset:
        values = dataset.read(1).astype(float)
    out_of_range = (values <= 0) | (values >= 60000)
    values[out_of_range] = 0
    return values
    
def get_array_img(path, date):
    """Build an 8-bit, 3-channel composite image from Landsat band files.

    Band files are found by walking ``path`` recursively and matching file names
    that end (case-insensitively) in ``b<N>.tif``. If several files match the
    same band, the last one visited by ``os.walk`` wins. The three channels
    depend on ``date``:

    * ``date >= 2013``: bands 6, 5, 4, each first divided by its own maximum;
    * otherwise: bands 5, 4, 3, unscaled.

    NOTE(review): presumably this is a SWIR1/NIR/Red false-colour composite for
    Landsat 8 (2013+) vs Landsat 5/7. Confirm against the data source.

    The stacked image is contrast-stretched linearly between its 2nd and 98th
    percentiles, clipped to [0, 1] and scaled to ``uint8``.

    Parameters
    ----------
    path : str
        Directory (searched recursively) holding the band GeoTIFFs.
    date : int
        Acquisition year. It selects the band combination.

    Returns
    -------
    numpy.ndarray
        ``(rows, cols, 3)`` array of dtype ``uint8``.

    Raises
    ------
    FileNotFoundError
        If a required band file is missing under ``path``.
    """
    band_ids = (6, 5, 4) if date >= 2013 else (5, 4, 3)

    band_files = {}
    for dirpath, _, filenames in os.walk(path):
        for f in filenames:
            lower = f.lower()
            for n in band_ids:
                if lower.endswith(f"b{n}.tif"):
                    band_files[n] = os.path.join(dirpath, f)

    missing = [n for n in band_ids if n not in band_files]
    if missing:
        # Previously read_band("") failed with an opaque rasterio error.
        missing_names = ", ".join(f"B{n}" for n in missing)
        raise FileNotFoundError(f"Band file(s) {missing_names} not found under {path}")

    channels = [read_band(band_files[n]) for n in band_ids]

    if date >= 2013:
        # Per-band scaling by the band maximum. Skip it for an all-zero band
        # to avoid 0/0 -> NaN.
        scaled = []
        for band in channels:
            peak = band.max()
            scaled.append(band / peak if peak > 0 else band)
        channels = scaled

    rgb = np.dstack(channels)
    p2, p98 = np.nanpercentile(rgb, (2, 98))
    spread = p98 - p2
    if spread > 0:
        rgb_norm = np.clip((rgb - p2) / spread, 0, 1)
    else:
        # Constant image: avoid dividing by zero and return an all-black image.
        rgb_norm = np.zeros_like(rgb)

    return (rgb_norm * 255).astype(np.uint8)

def return_tif_files(dir):
    """Return the path of the first georeferenced ``.tif`` found under ``dir``.

    The tree is walked with ``os.walk``, so "first" follows its traversal order.
    Files whose name ends in ``.tif`` (case-insensitive) are opened with
    rasterio. Files rasterio cannot open are skipped. A file qualifies when its
    ``crs`` and ``transform`` are both not None. Returns ``None`` when no file
    qualifies.
    """
    candidates = (
        os.path.join(root, name)
        for root, _, names in os.walk(dir)
        for name in names
        if name.lower().endswith(".tif")
    )
    for candidate in candidates:
        try:
            with rasterio.open(candidate) as src:
                georeferenced = src.crs is not None and src.transform is not None
        except rasterio.errors.RasterioIOError:
            continue
        if georeferenced:
            return candidate
    return None

def get_grid_coordinates(path, format="lat_lon"):
    """Return the coordinates of every pixel centre of a GeoTIFF.

    Parameters
    ----------
    path : str
        Path to a georeferenced ``.tif`` file.
    format : str, optional
        ``"lat_lon"`` (default) reprojects the centres to EPSG:4326 and returns
        ``(lons, lats)``. Any other value returns ``(xs, ys)`` in the raster's
        own CRS.

    Returns
    -------
    tuple
        Two ``(height, width)`` grids, x / longitude first.
    """
    with rasterio.open(path) as src:
        transform, crs = src.transform, src.crs
        height, width = src.height, src.width

    # row_idx[r, c] == r and col_idx[r, c] == c
    row_idx, col_idx = np.meshgrid(np.arange(height), np.arange(width), indexing="ij")
    xs, ys = rasterio.transform.xy(transform, row_idx, col_idx, offset='center')

    grid_shape = (height, width)
    xs = np.array(xs).reshape(grid_shape)
    ys = np.array(ys).reshape(grid_shape)

    if format == "lat_lon":
        to_wgs84 = Transformer.from_crs(crs, "EPSG:4326", always_xy=True)
        lons, lats = to_wgs84.transform(xs, ys)
        return lons, lats
    return xs, ys

def world_to_pixel(x, y, transform):
    """Map a geospatial point ``(x, y)`` to a ``(row, col)`` pixel index.

    The point goes through the inverse geotransform, and both fractional pixel
    coordinates are rounded to the nearest integer. This yields the nearest
    pixel corner, so a point in the right or lower half of a pixel maps to the
    next index. That is exact for corner points such as ``transform * (col, row)``.
    For arbitrary points, floor semantics would give the containing pixel.
    """
    inverse = ~transform
    col_frac, row_frac = inverse * (x, y)
    row_idx = int(round(row_frac))
    col_idx = int(round(col_frac))
    return row_idx, col_idx

def pixel_to_world(col, row, transform):
    """Apply the geotransform to pixel position ``(col, row)`` and return ``(x, y)``.

    With a standard affine geotransform, integer ``(col, row)`` give the
    upper-left corner of that pixel. Add 0.5 to each for its centre.
    """
    world_x, world_y = transform * (col, row)
    return (world_x, world_y)

def get_city_patch_params(city: str):
    """
    Return ``(i_start, j_start, i_length, j_length)`` for ``city``.

    These describe the pixel window (row start, column start, row extent,
    column extent) in which patch anchors are sampled.

    Raises
    ------
    ValueError
        If ``city`` is not one of Lyon, Paris, Toulouse or Bordeaux.
    """
    windows = {
        # city: (i_start, j_start, i_length, j_length)
        "Lyon": (4400, 5000, 1000, 800),
        "Paris": (3000, 3000, 1600, 1600),
        "Toulouse": (1800, 2600, 1000, 600),
        "Bordeaux": (2600, 4100, 800, 800),
    }
    for name, params in windows.items():
        if city == name:
            return params
    raise ValueError(f"Unknown city: {city}")

def get_geospatial_coordinates(path, num_patch, i_start, j_start, i_length, j_length, seed=None):
    """Draw random patch anchor pixels inside a window and return their map coordinates.

    Row indices are drawn uniformly from ``[i_start, i_start + i_length)`` and
    column indices from ``[j_start, j_start + j_length)``. Both are in the pixel
    grid of the first georeferenced ``.tif`` found under ``<path>/2021``. Each
    ``(row, col)`` is converted to map coordinates with that raster's
    geotransform.

    Parameters
    ----------
    path : str
        City directory containing a ``2021`` sub-directory.
    num_patch : int
        Number of anchors to draw.
    i_start, j_start, i_length, j_length : int
        Pixel window to sample from (row start, column start, row extent,
        column extent).
    seed : int or None, optional
        If given, draws come from ``np.random.default_rng(seed)`` and are
        reproducible. If None (default), the global NumPy RNG is used as before,
        so results depend on any earlier ``np.random.seed`` call.

    Returns
    -------
    x_s, y_s : list
        Map coordinates, in the 2021 raster's CRS, of ``transform * (col, row)``
        for each anchor, i.e. the upper-left corner of the anchor pixel.

    Raises
    ------
    FileNotFoundError
        If no georeferenced ``.tif`` exists under ``<path>/2021``.
    """
    path2021 = os.path.join(path, "2021")
    tif_2021 = return_tif_files(path2021)
    if tif_2021 is None:
        raise FileNotFoundError(f"No georeferenced .tif file found under {path2021}")
    with rasterio.open(tif_2021) as src2021:
        transform2021 = src2021.transform

    if seed is None:
        # Global RNG: same draws as the original implementation.
        i_s = np.random.randint(i_start, i_start + i_length, num_patch)
        j_s = np.random.randint(j_start, j_start + j_length, num_patch)
    else:
        rng = np.random.default_rng(seed)
        i_s = rng.integers(i_start, i_start + i_length, num_patch)
        j_s = rng.integers(j_start, j_start + j_length, num_patch)

    # pixel_to_world takes (col, row).
    anchors = [pixel_to_world(j, i, transform2021) for i, j in zip(i_s, j_s)]
    x_s = [x for x, _ in anchors]
    y_s = [y for _, y in anchors]
    return x_s, y_s

def get_patches(x_s, y_s, path, patch_size, years=range(2013, 2022)):
    """Cut the same set of square patches out of each year's composite image.

    For every year, the composite ``get_array_img(<path>/<year>, year)`` is
    built. Each anchor ``(x_s[k], y_s[k])`` is converted to a ``(row, col)``
    pixel index with that year's geotransform. A ``patch_size`` x ``patch_size``
    window starting at that pixel is copied, together with the latitude and
    longitude of every pixel centre in the window.

    NOTE(review): anchors are interpreted in each year's raster CRS. This
    assumes all years share the CRS of the rasters the anchors were drawn
    from. TODO confirm.

    Parameters
    ----------
    x_s, y_s : sequence of float
        Anchor map coordinates, one pair per patch.
    path : str
        City directory containing one sub-directory per year.
    patch_size : int
        Side length of each square patch, in pixels.
    years : iterable of int, optional
        Years to process. Defaults to 2013-2021 inclusive, the range previously
        hard-coded.

    Returns
    -------
    dict
        ``{str(year): {"imgs_array": uint8 array (n, patch_size, patch_size, 3),
        "space_coordinates": float32 array (n, patch_size, patch_size, 2)}}``,
        where ``space_coordinates[..., 0]`` is latitude and ``[..., 1]`` is
        longitude.

    Raises
    ------
    FileNotFoundError
        If a year directory contains no georeferenced ``.tif``.
    ValueError
        If a patch does not lie entirely inside that year's image. Without this
        check, a negative start index would silently wrap around and copy
        pixels from the opposite edge of the image.
    """
    num_patches = len(x_s)
    data = {}

    for date in years:
        print(date)
        year_dir = os.path.join(path, str(date))
        # NOTE(review): the geotransform comes from the first georeferenced .tif
        # in the folder. It must share the grid of the bands read by
        # get_array_img; confirm if the folder may hold other-resolution rasters.
        path_tif_file = return_tif_files(year_dir)
        if path_tif_file is None:
            raise FileNotFoundError(f"No georeferenced .tif file found under {year_dir}")
        img = get_array_img(year_dir, date)
        img_height, img_width = img.shape[:2]

        imgs_array = np.zeros((num_patches, patch_size, patch_size, 3), dtype=np.uint8)
        space_coordinates = np.zeros((num_patches, patch_size, patch_size, 2), dtype=np.float32)

        with rasterio.open(path_tif_file) as src:
            transform = src.transform
            crs = src.crs

        # Map CRS -> WGS84. always_xy=True keeps (x/lon, y/lat) argument order.
        transformer = Transformer.from_crs(crs, "EPSG:4326", always_xy=True)

        for k in range(num_patches):
            i, j = world_to_pixel(x_s[k], y_s[k], transform)
            if i < 0 or j < 0 or i + patch_size > img_height or j + patch_size > img_width:
                raise ValueError(
                    f"Patch {k} starting at pixel (row={i}, col={j}) with size "
                    f"{patch_size} does not fit in the {date} image "
                    f"({img_height}x{img_width})."
                )

            imgs_array[k] = img[i:i + patch_size, j:j + patch_size, :]

            # Grid of pixel indices covered by the patch.
            rows = np.arange(i, i + patch_size)
            cols = np.arange(j, j + patch_size)
            cols_grid, rows_grid = np.meshgrid(cols, rows)

            # Pixel centres in map coordinates, then converted to lon/lat.
            xs, ys = rasterio.transform.xy(transform, rows_grid, cols_grid, offset="center")
            xs = np.array(xs).reshape((patch_size, patch_size))
            ys = np.array(ys).reshape((patch_size, patch_size))
            lons, lats = transformer.transform(xs, ys)

            # Stored as (lat, lon) in the last axis.
            space_coordinates[k] = np.dstack((lats, lons)).astype(np.float32)

        data[str(date)] = {
            "imgs_array": imgs_array,
            "space_coordinates": space_coordinates,
        }

    return data

def assign_iris_to_data(data, iris_csv_paths, batch_size=100_000):
    """
    Add an 'iris_index' array to each year entry of ``data`` (in place).

    This applies to every year key made only of digits that has an entry in
    ``iris_csv_paths``. The year's IRIS polygons are loaded from the CSV (WKT in
    a ``geometry`` column, a ``CODE_IRIS`` column, treated as EPSG:4326). Every
    pixel of ``data[year]["space_coordinates"]`` is then spatially joined to the
    polygon it lies within.

    The resulting ``iris_index`` has shape ``[num_patches, patch_size, patch_size]``.
    Each entry is the CODE_IRIS string of the containing polygon, or NaN when
    the pixel is within no polygon. Points exactly on a boundary are not
    "within" and stay NaN. If polygons overlap, the first match is kept.

    Parameters
    ----------
    data : dict
        ``{year_str: {"space_coordinates": array [N, H, W, 2] of (lat, lon), ...}}``.
        It is mutated in place.
    iris_csv_paths : dict
        Maps an int year to the path of that year's IRIS CSV. Years without an
        entry are skipped.
    batch_size : int, optional
        Number of pixels joined per ``gpd.sjoin`` call.

    Returns
    -------
    dict
        The same ``data`` object.
    """
    for year, content in data.items():
        if not year.isdigit():
            continue
        if int(year) not in iris_csv_paths:
            print(f"‚ö†Ô∏è No IRIS CSV for {year}, skipping...")
            continue

        print(f"\nüìÖ Processing year {year}...")

        # Load the IRIS polygons. CODE_IRIS is read as text so codes keep any
        # leading zeros and stay strings. Parsed as integers, the column would
        # be upcast to float whenever a left join leaves a pixel unmatched (NaN).
        iris_csv = iris_csv_paths[int(year)]
        iris_df = pd.read_csv(iris_csv, dtype={"CODE_IRIS": str})
        iris_df["geometry"] = iris_df["geometry"].apply(wkt.loads)
        iris_gdf = gpd.GeoDataFrame(iris_df, geometry="geometry", crs="EPSG:4326")
        # Loop-invariant column subset: build it once per year, not once per batch.
        polygons = iris_gdf[["CODE_IRIS", "geometry"]]

        coords = content["space_coordinates"]  # [num_patches, H, W, 2], last axis = (lat, lon)
        num_patches, H, W, _ = coords.shape
        total_points = num_patches * H * W

        print(f"   ‚Üí Flattening {total_points:,} pixels...")

        flat_coords = coords.reshape(-1, 2)
        lats = flat_coords[:, 0]
        lons = flat_coords[:, 1]

        codes = np.empty(total_points, dtype=object)

        for start in tqdm(range(0, total_points, batch_size), desc=f"Joining {year}", ncols=80):
            end = min(start + batch_size, total_points)

            # Points are built in (x=lon, y=lat) order.
            points_gdf = gpd.GeoDataFrame(
                geometry=gpd.points_from_xy(lons[start:end], lats[start:end]),
                crs="EPSG:4326",
            )

            joined = gpd.sjoin(points_gdf, polygons, how="left", predicate="within")

            # A point inside overlapping polygons yields several rows. Keep the
            # first match and realign on the points' index. This gives exactly
            # one value per point, in point order, whatever row order sjoin uses.
            first_match = joined.loc[~joined.index.duplicated(keep="first"), "CODE_IRIS"]
            codes[start:end] = first_match.reindex(points_gdf.index).to_numpy()

        data[year]["iris_index"] = codes.reshape(num_patches, H, W)

        matched = np.count_nonzero(~pd.isna(codes))
        print(f"‚úÖ Finished {year}: matched {matched:,}/{total_points:,} pixels")

    print("\n‚úÖ Added 'iris_index' to all matching years.")
    return data

# Year -> IRIS polygon CSV (WKT geometry + CODE_IRIS) for 2011-2021 inclusive.
# Joined with explicit backslashes so the Windows paths are identical on any OS.
iris_csv_paths = {
    year: "\\".join((r"C:\Users\adamh\Desktop\IRIS", str(year), "iris_latlon.csv"))
    for year in range(2011, 2022)
}



In [18]:
city = "Toulouse"
base_dir = r"C:\Users\adamh\Desktop\Satelite_images"

# Per-city imagery lives in <base_dir>/Satellite_Images/<city>. The processed
# dataset is pickled as <base_dir>/<city>_data.pkl. Note the differing spellings
# "Satelite_images" (root) and "Satellite_Images" (sub-folder).
city_path = os.path.join(base_dir, "Satellite_Images", city)
save_path = os.path.join(base_dir, f"{city}_data.pkl")

In [19]:
patch_size = 128
num_patch = 500
i_start, j_start, i_length, j_length = get_city_patch_params(city)

# Draw random anchors inside the city's pixel window (2021 grid), then cut a
# patch_size x patch_size patch at each anchor for every processed year.
x_s, y_s = get_geospatial_coordinates(
    city_path, num_patch, i_start, j_start, i_length, j_length
)
data = get_patches(x_s, y_s, city_path, patch_size)


2013
2014
2015
2016
2017
2018
2019
2020
2021


In [20]:
# Attach the per-pixel IRIS code array to every year, then pickle the result.
data = assign_iris_to_data(data, iris_csv_paths=iris_csv_paths)
save(data, save_path=save_path)


üìÖ Processing year 2013...
   ‚Üí Flattening 8,192,000 pixels...


Joining 2013: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 82/82 [00:10<00:00,  7.92it/s]


‚úÖ Finished 2013: matched 8,192,000/8,192,000 pixels

üìÖ Processing year 2014...
   ‚Üí Flattening 8,192,000 pixels...


Joining 2014: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 82/82 [00:11<00:00,  7.44it/s]


‚úÖ Finished 2014: matched 8,192,000/8,192,000 pixels

üìÖ Processing year 2015...
   ‚Üí Flattening 8,192,000 pixels...


Joining 2015: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 82/82 [00:11<00:00,  7.35it/s]


‚úÖ Finished 2015: matched 8,192,000/8,192,000 pixels

üìÖ Processing year 2016...
   ‚Üí Flattening 8,192,000 pixels...


Joining 2016: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 82/82 [00:11<00:00,  7.34it/s]


‚úÖ Finished 2016: matched 8,192,000/8,192,000 pixels

üìÖ Processing year 2017...
   ‚Üí Flattening 8,192,000 pixels...


Joining 2017: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 82/82 [00:11<00:00,  7.21it/s]


‚úÖ Finished 2017: matched 8,192,000/8,192,000 pixels

üìÖ Processing year 2018...
   ‚Üí Flattening 8,192,000 pixels...


Joining 2018: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 82/82 [00:11<00:00,  7.23it/s]


‚úÖ Finished 2018: matched 8,192,000/8,192,000 pixels

üìÖ Processing year 2019...
   ‚Üí Flattening 8,192,000 pixels...


Joining 2019: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 82/82 [00:11<00:00,  7.06it/s]


‚úÖ Finished 2019: matched 8,192,000/8,192,000 pixels

üìÖ Processing year 2020...
   ‚Üí Flattening 8,192,000 pixels...


Joining 2020: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 82/82 [00:11<00:00,  7.12it/s]


‚úÖ Finished 2020: matched 8,192,000/8,192,000 pixels

üìÖ Processing year 2021...
   ‚Üí Flattening 8,192,000 pixels...


Joining 2021: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 82/82 [00:18<00:00,  4.43it/s]


‚úÖ Finished 2021: matched 8,192,000/8,192,000 pixels

‚úÖ Added 'iris_index' to all matching years.
‚úÖ Data saved successfully at: C:\Users\adamh\Desktop\Satelite_images\Toulouse_data.pkl
