# In [None]:
import os
import glob
import rasterio
from rasterio.mask import mask
from rasterio.warp import reproject, Resampling, calculate_default_transform, transform_bounds
import numpy as np
from shapely.geometry import box
from sklearn.metrics import accuracy_score
import pandas as pd

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------

# Root folder that contains one sub-directory per validation tile.
dir_ = '/Users/arthurcalvi/Data/species/validation/tiles'

# The classification map of a tile is looked up as
# <tile>/results/<model_name>_<extra>.tif.
model_name = 'XGBoost'
config = "no_resample_cloud_disturbance_weights_3Y"
extra = f"{config}_Group"

# Location of the DLT reference raster. Despite the ``_dir`` suffix this is the
# path of a single GeoTIFF; update it to match your machine.
dlt_dir = '/Users/arthurcalvi/Data/species/DLT_2018_010m_fr_03035_v020/DLT_Dominant_Leaf_Type_France.tif'

# Empty table that receives one row of agreement metrics per tile.
metrics_df = pd.DataFrame(
    columns=['tile_id', 'agreement_percentage', 'support'],
)

# Warp an in-memory raster band onto a caller-supplied destination grid
def reproject_to_match(src_array, src_transform, src_crs, dst_crs, dst_shape, dst_transform):
    """Resample ``src_array`` onto the grid described by the ``dst_*`` arguments.

    Args:
        src_array: Array holding the source pixels.
        src_transform: Affine transform georeferencing ``src_array``.
        src_crs: Coordinate reference system of ``src_array``.
        dst_crs: Coordinate reference system of the output grid.
        dst_shape: Shape of the array to allocate for the output.
        dst_transform: Affine transform georeferencing the output grid.

    Returns:
        A new array of shape ``dst_shape`` and the same dtype as ``src_array``,
        filled by ``rasterio.warp.reproject`` with nearest-neighbour resampling.
    """
    warped = np.empty(dst_shape, dtype=src_array.dtype)

    # Georeferencing of both sides of the warp, gathered in one place.
    georef = {
        'src_transform': src_transform,
        'src_crs': src_crs,
        'dst_transform': dst_transform,
        'dst_crs': dst_crs,
    }
    reproject(
        source=src_array,
        destination=warped,
        resampling=Resampling.nearest,
        **georef,
    )
    return warped

# Function to crop and reproject VRT to the bounds of the classification map
def crop_vrt_to_bounds(vrt_path, bounds, out_shape, crs):
    """Crop a raster to ``bounds`` and warp it onto the caller's pixel grid.

    The previous implementation overwrote ``out_shape`` with a size derived
    from the *full* raster dimensions and georeferenced the cropped pixels with
    the transform of the *uncropped* raster, so the result neither matched the
    requested shape nor sat at the right location. The output grid is now
    derived from ``bounds`` and ``out_shape``, and the cropped window is warped
    using the transform that ``rasterio.mask.mask`` returns for it.

    Args:
        vrt_path: Path of the raster to sample (anything ``rasterio.open``
            accepts).
        bounds: ``(left, bottom, right, top)`` of the target grid, expressed
            in ``crs``.
        out_shape: Shape of the target grid; the last two entries are read as
            ``(height, width)``.
        crs: Coordinate reference system of ``bounds`` and of the output.

    Returns:
        ``(out_image, out_transform)``: the first band of the source resampled
        to ``(height, width)``, and the affine transform of that grid.

    Raises:
        ValueError: Propagated from ``rasterio.mask.mask`` when ``bounds`` do
            not overlap the source raster.
    """
    # Imported locally under distinct names so a module-level variable called
    # ``mask`` (the script below creates one) cannot shadow the rasterio call.
    from rasterio.mask import mask as rio_mask
    from rasterio.transform import from_bounds

    out_height, out_width = out_shape[-2:]

    # Grid spanning ``bounds`` with exactly the requested number of pixels.
    # NOTE(review): this assumes the target raster is north-up (no rotation
    # or shear in its transform) — confirm for the classification maps.
    out_transform = from_bounds(*bounds, out_width, out_height)

    with rasterio.open(vrt_path) as src:
        vrt_crs = src.crs

        # Express the target footprint in the source CRS before cropping.
        bounds_vrt = transform_bounds(crs, vrt_crs, *bounds)

        # ``mask`` returns the cropped pixels together with the transform of
        # the cropped window; that transform is the one the warp needs.
        cropped, cropped_transform = rio_mask(src, [box(*bounds_vrt)], crop=True)

    out_image = reproject_to_match(
        cropped[0],  # first band only
        cropped_transform,
        vrt_crs,
        crs,
        (out_height, out_width),
        out_transform,
    )
    return out_image, out_transform

# Compare every tile's classification map with the DLT reference raster and
# record, per tile, the share of DLT-forest pixels on which both agree.
#
# Rows are collected in a plain list and turned into a DataFrame once at the
# end: DataFrame.append was removed in pandas 2.0.
metric_rows = []

# Sorted so the console output and the CSV row order are reproducible.
for filename in sorted(os.listdir(dir_)):
    path = os.path.join(dir_, filename)
    if not os.path.isdir(path):
        continue

    # The tile id is the second '_'-separated token of the directory name.
    name_parts = filename.split('_')
    if len(name_parts) < 2:
        print(f"Skipping {path}: cannot extract a tile id from the directory name")
        continue
    tile_id = name_parts[1]

    print(path)
    classification_map_path = os.path.join(path, 'results', f'{model_name}_{extra}.tif')
    try:
        with rasterio.open(classification_map_path) as src:
            crs = src.crs
            raster = src.read(1)  # 0 no forest, 1 deciduous, 2 evergreen
            raster_bounds = src.bounds
            raster_transform = src.transform
            raster_meta = src.meta

        # Crop and reproject the DLT raster to the bounds of the classification
        # map. ``dlt_dir`` holds the DLT raster path; the name ``vrt_path``
        # used here before is not defined anywhere in this script.
        dlt_raster, dlt_transform = crop_vrt_to_bounds(dlt_dir, raster_bounds, raster.shape, crs)

        # Ensure the shapes match
        if dlt_raster.shape != raster.shape:
            print(f"Shape mismatch between DLT data and classification map for {classification_map_path}")
            continue

        # Compute metrics
        y_true = dlt_raster.ravel()
        y_pred = raster.ravel()

        # Keep only pixels the DLT marks as forest. The selector is deliberately
        # not called ``mask``: at module scope that would overwrite the imported
        # ``rasterio.mask.mask`` and break the crop for every following tile.
        # NOTE(review): any positive DLT value counts as forest here, including
        # a nodata/outside-area code if the product uses one — confirm against
        # the DLT legend.
        forest_pixels = y_true > 0
        y_true = y_true[forest_pixels]
        y_pred = y_pred[forest_pixels]

        # Compute agreement percentage; undefined (NaN) when the tile holds no
        # DLT-forest pixel at all.
        support = len(y_true)
        if support:
            agreement_percentage = np.count_nonzero(y_true == y_pred) / support * 100
        else:
            agreement_percentage = float('nan')

        # Print metrics
        print(f"Tile ID: {tile_id}")
        print(f"Agreement Percentage: {agreement_percentage}")
        print(f"Support: {support}")

        # Save metrics for the final DataFrame
        metric_rows.append({
            'tile_id': tile_id,
            'agreement_percentage': agreement_percentage,
            'support': support
        })

    except Exception as e:
        # Best effort per tile: report the failure and move on to the next one.
        print(f"Error for {classification_map_path}: {e}")
        continue

metrics_df = pd.DataFrame(metric_rows, columns=['tile_id', 'agreement_percentage', 'support'])

# Save metrics DataFrame to a CSV file (create the output folder if needed)
os.makedirs('results', exist_ok=True)
metrics_df.to_csv('results/metrics_DLT.csv', index=False)

print("Metrics saved to results/metrics_DLT.csv")
