In [None]:
# Load the data and preprocess it
from tqdm import tqdm 
from datetime import datetime 
from sklearn.preprocessing import RobustScaler
from utils import load_folder, calculate_slope_with_dates, load_checkpoint
import rasterio 
import os
import numpy as np
import pandas as pd
import cv2

# Feature columns selected, in this order, from each tile's feature table before prediction.
features = [
    # red band: amplitude / phase / offset features
    'amplitude_red', 'cos_phase_red', 'offset_red',
    # blue band
    'cos_phase_blue',
    # CRSWIR
    'amplitude_crswir', 'cos_phase_crswir', 'sin_phase_crswir', 'offset_crswir',
    # topography
    'elevation',
]

model_name = 'XGBoost'
config = "no_resample_cloud_disturbance_weights_3Y"
# Suffix used both to locate the checkpoint and to name the output rasters.
extra = f"{config}_Group"
# The fitted estimator is taken from the checkpoint's `best_estimator_` attribute.
model = load_checkpoint(
    model_name,
    checkpoint_dir='checkpoints',
    extra=extra,
).best_estimator_
# Root folder containing one sub-directory per validation tile.
directory = '/Users/arthurcalvi/Data/species/validation/tiles'


def load_data_from_tile_inf(path: str, config: str) -> tuple:
    """Load the per-pixel inference features and forest mask for one tile.

    Reads, from the tile directory ``path``: the date-prefixed files in
    ``rgb/``, ``tree_map/CHM2020.tif``,
    ``features/APO_{R,G,B,CRSWIR}_{config}.tif`` and
    ``features/elevation_aspect.tif``. Every raster is closed after reading.

    Args:
        path: Tile directory. The second ``_``-separated token of its
            basename is used as the tile id.
        config: Suffix selecting which ``APO_*_{config}.tif`` files to read.

    Returns:
        ``(features_, forest_mask, shape)``:
        ``features_`` maps each feature name (plus ``'tile_id'``) to a
        flattened 1-D array. ``forest_mask`` is a 2-D boolean mask derived
        from the CHM. ``shape`` is the 2-D shape of the CRSWIR amplitude map,
        which the caller uses to reshape flat predictions back into an image.

    Raises:
        Whatever the underlying readers raise. Examples: IndexError if the
        basename contains no ``_``, ValueError for an ``rgb/`` filename
        without a ``YYYY-MM-DD`` prefix, and rasterio errors for missing
        or unreadable files.
    """
    tile_id = os.path.basename(path).split('_')[1]

    rgb_dir = os.path.join(path, 'rgb')
    # Acquisition dates come from the 'YYYY-MM-DD_...' filename prefix. Hidden
    # entries (e.g. macOS '.DS_Store') are skipped because strptime would raise on them.
    # NOTE(review): dates are sorted independently of load_folder's ordering —
    # confirm load_folder also returns the images in chronological order.
    dates = sorted(
        datetime.strptime(name.split('_')[0], '%Y-%m-%d')
        for name in os.listdir(rgb_dir)
        if not name.startswith('.')
    )
    rgb = load_folder(rgb_dir)

    # Forest mask: canopy height > 250 (units not visible here; presumably cm — confirm).
    # It is then eroded by one pixel with a 3x3 kernel, removing pixels on the mask border.
    with rasterio.open(os.path.join(path, 'tree_map', 'CHM2020.tif')) as src:
        chm = src.read(1)
    kernel = np.ones((3, 3), np.uint8)
    forest_mask = cv2.erode((chm > 250).astype(np.uint8), kernel, iterations=1).astype(bool)

    # NOTE(review): slope_map / weights are computed but neither used nor returned.
    # They are kept so that a tile whose RGB stack fails here is still rejected
    # as before. Remove them (and the rgb load) if that gating is not needed.
    slope_map = calculate_slope_with_dates(rgb[:, 0], dates, len(rgb[:, 0]) / 2, len(rgb[:, 0])) / 100
    weights = (1 - abs(slope_map.ravel())).clip(0, 1)

    features_dir = os.path.join(path, 'features')

    def read_bands(filename):
        # Read every band of one feature raster and close the dataset immediately.
        with rasterio.open(os.path.join(features_dir, filename)) as src:
            return src.read()

    # Each APO raster is unpacked as band 0 = amplitude, 1 = phase, 2 = offset.
    r_apo = read_bands(f'APO_R_{config}.tif')
    # NOTE(review): the G raster is not used as a feature. It is still read so
    # that a tile missing it fails as before.
    read_bands(f'APO_G_{config}.tif')
    b_apo = read_bands(f'APO_B_{config}.tif')
    crswir_apo = read_bands(f'APO_CRSWIR_{config}.tif')
    dem = read_bands('elevation_aspect.tif')

    amplitude_map_crswir, phase_map_crswir, offset_map_crswir = crswir_apo[0], crswir_apo[1], crswir_apo[2]
    elevation, aspect = dem[0], dem[1]

    # Phase maps are passed through cos/sin (presumably to avoid the angle wrap-around).
    features_ = {
        'amplitude_red': r_apo[0].ravel(),
        'cos_phase_red': np.cos(r_apo[1].ravel()),
        'offset_red': r_apo[2].ravel(),
        'cos_phase_blue': np.cos(b_apo[1].ravel()),
        'amplitude_crswir': amplitude_map_crswir.ravel(),
        'cos_phase_crswir': np.cos(phase_map_crswir.ravel()),
        'sin_phase_crswir': np.sin(phase_map_crswir.ravel()),
        'offset_crswir': offset_map_crswir.ravel(),
        'elevation': elevation.ravel(),
        # Constant column tagging every pixel with this tile's id.
        'tile_id': np.full(aspect.size, tile_id),
    }

    return features_, forest_mask, amplitude_map_crswir.shape

print(f"Using configuration: {config}")
# Sorted so tiles are processed in a reproducible order (os.listdir order is arbitrary).
for folder in tqdm(sorted(os.listdir(directory))):
    path = os.path.join(directory, folder)

    # Skip macOS metadata, text notes, and any other entry that is not a tile directory.
    if '.DS_Store' in folder or '.txt' in folder or not os.path.isdir(path):
        continue

    try:
        print(f"Processing {folder}")
        features_, forest_mask, shape = load_data_from_tile_inf(path, config)
        features_ = pd.DataFrame(features_)[features]

        # Preprocess: treat +/-inf as missing, then forward-fill from the previous row.
        # DataFrame.ffill() replaces the deprecated fillna(method='ffill').
        # NOTE(review): NaNs before the first valid row remain NaN.
        features_ = features_.replace([np.inf, -np.inf], np.nan).ffill()

        if model_name in ["Logistic Regression", "KNN", "MLP"]:
            # NOTE(review): this fits a fresh scaler on each tile rather than reusing
            # the scaler fitted on the training data, so scaled values are not
            # comparable across tiles or with training. Confirm this is intended.
            scaler = RobustScaler()
            features_ = scaler.fit_transform(features_)

        # +1 shifts predicted labels to match the classes in the tree map.
        # 0 is kept for non-forest pixels, and is also written as nodata below.
        results = model.predict(features_) + 1
        results[~forest_mask.ravel()] = 0

        # Copy the georeferencing profile from the tile's CHM, closing the dataset afterwards.
        with rasterio.open(os.path.join(path, 'tree_map', 'CHM2020.tif')) as ref:
            profile = ref.profile
        profile.update(dtype=rasterio.uint8, count=1, compress='lzw', nodata=0)

        # Save the results to a GeoTIFF file in <tile>/results/.
        path_results = os.path.join(path, 'results')
        os.makedirs(path_results, exist_ok=True)
        path_file = os.path.join(path_results, f'{model_name}_{extra}.tif')
        with rasterio.open(path_file, 'w', **profile) as dst:
            # Cast explicitly to the uint8 dtype declared in the profile.
            dst.write(results.reshape(shape).astype(np.uint8), 1)

    except Exception as e:
        # Report the failure and move on to the next tile instead of aborting the run.
        print(f"Error processing {folder}: {e}")
