# About this notebook

This notebook prepares the FLAIR dataset for training a ControlNet model that generates synthetic satellite-style images of territories. It converts the GeoTIFF images and segmentation masks into PNG files, and queries OpenStreetMap (through the Nominatim reverse-geocoding API) to get geographical information about each image.

##### Limitations

The FLAIR dataset covers France only, which will bias the learned terrain representation toward French landscapes. See https://ignf.github.io/FLAIR/

### 1. Imports

In [4]:
import json
import os
import time
from glob import glob

import numpy as np
import requests
from osgeo import gdal, osr
from PIL import Image
from pyproj import CRS, Transformer
from tqdm import tqdm

### 2. Function definitions

In [5]:
# Land-cover classes and their display colours. Mask value k + 1 maps to
# entry k (class ids are 1-based, see convert_to_seg), so order matters.
LUT = [
    {"color": hex_color, "class": class_name}
    for hex_color, class_name in (
        ("#db0e9a", "building"),
        ("#938e7b", "pervious surface"),
        ("#f80c00", "impervious surface"),
        ("#a97101", "bare soil"),
        ("#1553ae", "water"),
        ("#194a26", "coniferous"),
        ("#46e483", "deciduous"),
        ("#f3a60d", "brushwood"),
        ("#660082", "vineyard"),
        ("#55ff00", "herbaceous vegetation"),
        ("#fff30d", "agricultural land"),
        ("#e4df7c", "plowed land"),
        ("#3de6eb", "swimming pool"),
        ("#ffffff", "snow"),
        ("#8ab3a0", "clear cut"),
        ("#6b714f", "mixed"),
        ("#c5dc42", "ligneous"),
        ("#9999ff", "greenhouse"),
        ("#000000", "other"),
    )
]

In [6]:
def convert_to_seg(image_array, lut):
    """Colourise a class-id mask in place using a colour look-up table.

    Args:
        image_array: array whose last axis has length 3; channel 0 holds the
            class id of each pixel (channels 1-2 are ignored and overwritten).
        lut: list of ``{"color": "#rrggbb", ...}`` dicts; entry ``k`` is the
            colour for class id ``k + 1`` (class ids are 1-based).

    Returns:
        ``image_array`` itself, with every pixel replaced by its class colour.
        Class ids outside ``1..len(lut)`` (including 0) are painted black.
    """
    def hex_to_tuple(hex_color):
        h = hex_color.lstrip('#')
        return tuple(int(h[i:i + 2], 16) for i in (0, 2, 4))

    # Row k of the palette is the RGB colour of class id k; row 0 is the
    # black fallback used for ids that have no LUT entry.
    palette = np.zeros((len(lut) + 1, 3), dtype=np.int64)
    for class_id, entry in enumerate(lut, start=1):
        palette[class_id] = hex_to_tuple(entry["color"])

    # Copy channel 0 as signed ints before it is overwritten below, so the
    # range check cannot be fooled by unsigned wrap-around (e.g. uint8 0 - 1).
    class_ids = image_array[..., 0].astype(np.int64)
    in_range = (class_ids >= 1) & (class_ids <= len(lut))
    image_array[...] = palette[np.where(in_range, class_ids, 0)]

    return image_array


def get_osm_data(latlon, user_agent="flair-controlnet-dataset-prep", timeout=30):
    """Reverse-geocode a WGS 84 point with the public Nominatim API.

    Args:
        latlon: ``(latitude, longitude)`` pair in decimal degrees.
        user_agent: value sent as the ``User-Agent`` header. Nominatim's usage
            policy requires an agent that identifies the application; the
            default ``python-requests/x.y`` agent is not acceptable.
        timeout: seconds to wait for the server before giving up.

    Returns:
        The decoded JSON body of the response.

    Raises:
        requests.HTTPError: if the server answers with a non-2xx status.
    """
    lat, lon = latlon
    response = requests.get(
        'https://nominatim.openstreetmap.org/reverse',
        params={'format': 'json', 'lat': lat, 'lon': lon, 'zoom': 18, 'addressdetails': 1},
        headers={'User-Agent': user_agent},
        timeout=timeout,
    )
    # Fail loudly (e.g. when blocked or rate-limited) instead of trying to
    # decode an HTML error page as JSON.
    response.raise_for_status()
    return response.json()


# One Transformer per input EPSG code: this function runs once per tile, and
# pyproj recommends reusing Transformers rather than rebuilding them.
_LATLON_TRANSFORMERS = {}


def convert_to_latlon(coord, proj=None):
    """Convert a coordinate from an EPSG-coded CRS to WGS 84 latitude/longitude.

    Args:
        coord: ``(x, y)`` pair in the input CRS, unpacked into
            ``Transformer.transform``.
        proj: EPSG code (int or numeric string) of the input CRS. Falsy values
            (``None``, ``""``, ``0``) fall back to EPSG:2154 (RGF93 / Lambert-93).

    Returns:
        ``(latitude, longitude)``. The transformer is built without
        ``always_xy=True``, so pyproj uses the EPSG:4326 authority axis order.
    """
    if not proj:
        proj = 2154

    transformer = _LATLON_TRANSFORMERS.get(proj)
    if transformer is None:
        input_crs = CRS(f"EPSG:{proj}")
        wgs84_crs = CRS("EPSG:4326")  # WGS 84 (latitude and longitude)
        transformer = Transformer.from_crs(input_crs, wgs84_crs)
        _LATLON_TRANSFORMERS[proj] = transformer
    return transformer.transform(*coord)


def convert_seg(input_path, output_path, LUT):
    """Render a single-band class-id mask GeoTIFF as a colour image.

    Args:
        input_path: path to the mask GeoTIFF; band 1 holds the class ids.
        output_path: destination file; PIL picks the format from its extension.
        LUT: colour look-up table handed to ``convert_to_seg``.

    Raises:
        OSError: if GDAL cannot open ``input_path``.
    """
    input_tif = gdal.Open(input_path)
    if input_tif is None:
        # Without gdal.UseExceptions(), Open() signals failure by returning None.
        raise OSError(f"GDAL could not open {input_path!r}")

    # Stage the class ids in channel 0 of an (H, W, 3) buffer; convert_to_seg
    # then overwrites all three channels with the class colour.
    rgb_data = np.zeros((input_tif.RasterYSize, input_tif.RasterXSize, 3), dtype=np.uint8)
    rgb_data[..., 0] = input_tif.GetRasterBand(1).ReadAsArray()

    # Drop the dataset reference to close the input TIFF file.
    input_tif = None

    rgb_data = convert_to_seg(rgb_data, LUT)
    Image.fromarray(rgb_data).save(output_path)


def convert_image(input_path, output_path):
    """Save the first three bands of a GeoTIFF as an 8-bit RGB image.

    Args:
        input_path: path to the source GeoTIFF; bands 1-3 become R, G, B.
        output_path: destination file; PIL picks the format from its extension.

    Raises:
        OSError: if GDAL cannot open ``input_path``.
    """
    input_tif = gdal.Open(input_path)
    if input_tif is None:
        # Without gdal.UseExceptions(), Open() signals failure by returning None.
        raise OSError(f"GDAL could not open {input_path!r}")

    # (H, W, 3) uint8 buffer; any further bands in the file are ignored.
    rgb_data = np.zeros((input_tif.RasterYSize, input_tif.RasterXSize, 3), dtype=np.uint8)
    for channel in range(3):
        rgb_data[..., channel] = input_tif.GetRasterBand(channel + 1).ReadAsArray()

    # Drop the dataset reference to close the input TIFF file.
    input_tif = None

    Image.fromarray(rgb_data).save(output_path)


def get_tif_metadata(input_path):
    """Collect georeferencing metadata for a GeoTIFF.

    Args:
        input_path: path to a raster that GDAL can open.

    Returns:
        dict with keys:
          - ``"origin"``: ``(x, y)`` of the top-left corner, in the raster's CRS.
          - ``"dimensions"``: ``(width, height)`` in pixels.
          - ``"unit_system"``: linear unit name of the CRS.
          - ``"code"``: authority code of the CRS as a string, or None if absent.
          - ``"latlong"``: the origin converted by ``convert_to_latlon``.

    Raises:
        OSError: if GDAL cannot open ``input_path``.
    """
    tif_dataset = gdal.Open(input_path)
    if tif_dataset is None:
        raise OSError(f"GDAL could not open {input_path!r}")

    # GDAL geotransform indices 0 and 3 are the x / y of the top-left corner.
    geotransform = tif_dataset.GetGeoTransform()
    origin = (geotransform[0], geotransform[3])
    dimensions = (tif_dataset.RasterXSize, tif_dataset.RasterYSize)

    # A single SpatialReference serves both the unit and the EPSG lookups.
    srs = osr.SpatialReference()
    srs.ImportFromWkt(tif_dataset.GetProjection())
    unit_type = srs.GetLinearUnitsName()
    code = srs.GetAuthorityCode(None)

    latlon = convert_to_latlon(origin, code)

    return {
        "origin": origin,
        "dimensions": dimensions,
        "unit_system": unit_type,
        "code": code,
        "latlong": latlon,
    }


def get_output_path(path, output_dir, ext):
    """Return ``{output_dir}/{patch_id}/{stem}.{ext}``, creating the folder.

    The stem is the file name up to its first '.', and the patch id is the
    second '_'-separated field of the stem (``IMG_000834.tif`` -> ``000834``).
    An image and its mask (``MSK_000834.tif``) therefore share a folder.

    Args:
        path: source file path.
        output_dir: root output directory; created along with any missing parents.
        ext: extension for the output file, without the dot.

    Returns:
        The output file path (the file itself is not created).
    """
    basename = os.path.basename(path).split('.')[0]
    folder = basename.split('_')[1]

    # exist_ok makes this idempotent, and makedirs also creates missing parents.
    os.makedirs(f"{output_dir}/{folder}", exist_ok=True)

    return f"{output_dir}/{folder}/{basename}.{ext}"

### 3. Load image paths

In [7]:
image_folder = "./data/flair_aerial_train"
output_dir = "./data/hf_dataset"
metadata_path = "./data/flair-1_metadata_aerial.json"

# makedirs(exist_ok=True) also creates ./data if it is missing and is safe
# to re-run, unlike the exists()/mkdir() pair.
os.makedirs(output_dir, exist_ok=True)

In [8]:
def convert_path_img_to_msk(path):
    """Map an aerial-image path to the path of its label mask.

    Applies three plain substring substitutions, in this order, to the whole
    path: ``aerial`` -> ``labels``, ``img`` -> ``msk``, ``IMG`` -> ``MSK``.
    Every occurrence is replaced, including inside parent directory names.
    """
    substitutions = (('aerial', 'labels'), ('img', 'msk'), ('IMG', 'MSK'))
    mask_path = path
    for old, new in substitutions:
        mask_path = mask_path.replace(old, new)
    return mask_path

In [9]:
# Collect every aerial tile. glob's order depends on the filesystem, so sort
# it to make the processing order (and tqdm progress) reproducible.
test_tif_img = sorted(glob(f"{image_folder}/*/*/img/*.tif"))
assert test_tif_img, f"no .tif images found under {image_folder!r}"

# Matching mask path for each image, in the same order.
test_tif_msk = [convert_path_img_to_msk(p) for p in test_tif_img]

data = {'image': test_tif_img, 'seg': test_tif_msk}

### 4. Data preprocessing

In [10]:
# Convert every aerial tile to PNG. A full run takes hours, so tiles whose
# PNG already exists are skipped. Set OVERWRITE_IMAGES = True to regenerate
# them (e.g. if an interrupted run left a partially written file).
OVERWRITE_IMAGES = False

images_path = data['image']

for path in tqdm(images_path):
    png_path = get_output_path(path, output_dir, 'png')
    if os.path.exists(png_path) and not OVERWRITE_IMAGES:
        continue
    convert_image(path, png_path)

100%|██████████| 61712/61712 [3:40:59<00:00,  4.65it/s]   


In [18]:
# Colourise every mask to PNG, skipping masks that were already converted.
# Set OVERWRITE_SEGS = True after changing LUT or convert_to_seg.
OVERWRITE_SEGS = False

seg_path = data['seg']

for path in tqdm(seg_path):
    png_path = get_output_path(path, output_dir, 'png')
    if os.path.exists(png_path) and not OVERWRITE_SEGS:
        continue
    convert_seg(path, png_path, LUT)

100%|██████████| 50/50 [01:19<00:00,  1.58s/it]


In [21]:
# Write one JSON file per tile that merges three sources: GeoTIFF
# georeferencing, the FLAIR metadata JSON, and a Nominatim reverse geocode.
# On key collisions, later sources win (tif < FLAIR < OSM).
images_path = data['image']

with open(metadata_path, 'r') as f:
    metadata_json = json.load(f)

# Nominatim's usage policy allows at most one request per second. For the
# full split (~60k tiles) that means many hours; consider a local instance.
NOMINATIM_MIN_INTERVAL_S = 1.0
last_request_time = float("-inf")

for path in tqdm(images_path):
    basename = os.path.basename(path).split('.')[0]
    ocsge_metadata = metadata_json[basename]

    tif_metadata = get_tif_metadata(path)

    # The patch centroid is given as projected map coordinates (x, y), not
    # lat/lon, so convert it before querying Nominatim.
    # NOTE(review): assumes it is in the same CRS as the GeoTIFF
    # (Lambert-93 / EPSG:2154 for FLAIR) - confirm against the dataset docs.
    centroid = (ocsge_metadata['patch_centroid_x'], ocsge_metadata['patch_centroid_y'])
    centroid_latlon = convert_to_latlon(centroid, tif_metadata['code'])

    wait_s = NOMINATIM_MIN_INTERVAL_S - (time.monotonic() - last_request_time)
    if wait_s > 0:
        time.sleep(wait_s)
    osm_metadata = get_osm_data(centroid_latlon)
    last_request_time = time.monotonic()

    combined_metadata = tif_metadata | ocsge_metadata | osm_metadata

    json_path = get_output_path(path, output_dir, 'json')
    with open(json_path, "w") as fp:
        json.dump(combined_metadata, fp)

100%|██████████| 50/50 [00:49<00:00,  1.01it/s]


In [18]:
# Sample colourised mask written by the "convert seg" cell
# (PIL's Image is already imported in the imports cell).
path = "./data/hf_dataset/000834/MSK_000834.png"
def get_color_percentages(image_path, color_mapper=None):
    """Return the fraction of pixels taken by each colour of an image.

    Args:
        image_path: path to an image PIL can open; it is converted to RGB.
        color_mapper: optional callable that maps a ``"#rrggbb"`` string to
            the key to report (e.g. a class name). Fractions of colours that
            map to the same key are summed.

    Returns:
        dict mapping each key to its fraction of the total pixel count.
    """
    # Image.open is lazy and keeps the file open, so close it via the context
    # manager once the RGB copy has been made.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    # maxcolors = pixel count, so getcolors can never exceed it and return None.
    color_counts = image.getcolors(image.size[0] * image.size[1])
    total_pixels = sum(count for count, _ in color_counts)

    color_percentages = {}
    for count, (r, g, b) in color_counts:
        key = f"#{r:02x}{g:02x}{b:02x}"
        if color_mapper:
            key = color_mapper(key)
        color_percentages[key] = color_percentages.get(key, 0) + count / total_pixels

    return color_percentages


def color_mapper(color_code, lut=None):
    """Translate a ``"#rrggbb"`` colour into its class name.

    Args:
        color_code: hex colour string; compared case-sensitively against the
            table (get_color_percentages produces lower-case codes).
        lut: look-up table to search; defaults to the module-level ``LUT``.

    Returns:
        The ``"class"`` of the first entry whose ``"color"`` equals
        ``color_code``, or ``color_code`` unchanged if none matches.
    """
    if lut is None:
        lut = LUT
    return next((entry['class'] for entry in lut if entry['color'] == color_code), color_code)
  

# Fraction of the sample mask's pixels per land-cover class; colours not in LUT keep their hex code.
get_color_percentages(path, color_mapper)

{'coniferous': 0.039398193359375,
 'brushwood': 0.02703857421875,
 'deciduous': 0.053955078125,
 'building': 0.348846435546875,
 'herbaceous vegetation': 0.2403564453125,
 'impervious surface': 0.2904052734375}