# Inference

In [None]:
import os
import numpy as np
import rasterio
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from datetime import datetime
from utils import load_folder, normalize, crop 


def plot_seasonal_images(sample_folder: str, filename: str, factor=2) -> plt.Figure:
    """
    Plot the tree classification map next to one RGB image per season.

    The left half of the figure shows the classification map read from
    ``<sample_folder>/results/<filename>.tif``. It is drawn with alpha 0.4
    on top of the winter RGB image, when one exists. The right half is a
    2x2 grid showing the earliest available RGB image of each season
    (Winter = Dec/Jan/Feb, Spring = Mar-May, Summer = Jun-Aug,
    Autumn = Sep-Nov).

    Args:
        sample_folder (str): Folder containing the sample data. It must hold:
            - an ``rgb`` sub-folder whose file names start with ``YYYY-MM-DD_``;
            - a ``results`` sub-folder containing the classification raster.
        filename (str): Base name (without ``.tif``) of the classification
            raster inside ``results``.
        factor (int): Crop factor forwarded to ``utils.crop`` for every image.

    Returns:
        matplotlib.figure.Figure: The created figure (already passed to
        ``plt.show()``).
    """
    import matplotlib.patches as mpatches

    rgb_dir = os.path.join(sample_folder, 'rgb')
    # Parse acquisition dates from the file-name prefix. Hidden files such as
    # macOS' .DS_Store are skipped; they would otherwise make strptime raise.
    dates = sorted(
        datetime.strptime(fname.split('_')[0], '%Y-%m-%d')
        for fname in os.listdir(rgb_dir)
        if not fname.startswith('.')
    )
    # NOTE(review): indexing `rgb` with positions from `dates` assumes
    # load_folder returns images in sorted-file-name (chronological) order.
    # Confirm in utils.
    rgb = load_folder(rgb_dir)

    seasons = {
        'Winter': [12, 1, 2],
        'Spring': [3, 4, 5],
        'Summer': [6, 7, 8],
        'Autumn': [9, 10, 11],
    }

    # Read the classification band. The context manager releases the file
    # handle, which a bare rasterio.open(...) would leave open.
    with rasterio.open(os.path.join(sample_folder, 'results', f'{filename}.tif')) as src:
        tree_class = src.read(1)
    tree_class_cropped = crop(tree_class, factor=factor)

    cmap = ListedColormap(['lightgray', 'brown', 'green'])

    fig = plt.figure(figsize=(12, 6))
    grid = fig.add_gridspec(2, 4)
    # Left 2x2 block holds the classification map; the right 2x2 block holds
    # the seasonal images.
    tree_ax = fig.add_subplot(grid[:, :2])

    for idx, (season, months) in enumerate(seasons.items()):
        # Position of the earliest acquisition in this season; None if absent.
        season_index = next(
            (i for i, date in enumerate(dates) if date.month in months), None
        )
        if season_index is None:
            continue

        # Brighten, then clip to [0, 1]. imshow clips float RGB anyway, but
        # it warns on every call.
        rgb_image = np.clip(
            2.5 * normalize(crop(rgb[season_index], factor=factor).transpose(1, 2, 0)),
            0,
            1,
        )

        row, col = divmod(idx, 2)
        ax = fig.add_subplot(grid[row, col + 2])
        ax.imshow(rgb_image)
        ax.annotate(season, (10, 10), color='white', fontsize=8, fontweight='bold')
        ax.set_axis_off()

        if season == 'Winter':
            # Use the winter scene as the backdrop of the classification overlay.
            tree_ax.imshow(rgb_image)

    # Pin the colour scale so colours always match the legend, even when a
    # tile lacks one of the classes. Without this, imshow rescales to the
    # tile's data range.
    # NOTE(review): assumes class codes 0 = no forest, 1 = deciduous,
    # 2 = evergreen (legend order). Confirm against the model output.
    tree_ax.imshow(
        tree_class_cropped,
        cmap=cmap,
        interpolation='none',
        alpha=0.4,
        vmin=0,
        vmax=cmap.N - 1,
    )
    tree_ax.set_axis_off()

    # Custom legend for the classification colours, placed below the axes.
    patches = [
        mpatches.Patch(color='lightgray', label='No forest'),
        mpatches.Patch(color='brown', label='Deciduous'),
        mpatches.Patch(color='green', label='Evergreen'),
    ]
    fig.legend(handles=patches, loc='lower center', ncol=3, bbox_to_anchor=(0.5, -0.05))

    fig.tight_layout()
    plt.show()

    return fig

model_name = 'XGBoost'
config = "no_resample_cloud_disturbance_weights_3Y"
extra = config + '_Group'
filename = f'{model_name}_{extra}'

folder_dir = '/Users/arthurcalvi/Data/species/validation/tiles'
# os.listdir returns entries in arbitrary order. Sort them so that `index`
# selects the same tile on every run and machine. Hidden entries such as
# .DS_Store are skipped.
files = sorted(f for f in os.listdir(folder_dir) if not f.startswith('.'))
index = 80
sample_folder = os.path.join(folder_dir, files[index])
print(sample_folder)

fig = plot_seasonal_images(sample_folder, filename=filename, factor=2)
# bbox_inches='tight' keeps the legend, which sits below the axes, in the
# saved image.
os.makedirs('images', exist_ok=True)
fig.savefig(f'images/seasonal_images_{index}.png', dpi=300, bbox_inches='tight')

# Comparison with BDFORET

In [None]:
import os
import numpy as np
import rasterio
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from datetime import datetime
from utils import load_folder, normalize, crop 

def plot_seasonal_images(sample_folder: str, filename: str, factor=2) -> plt.Figure:
    """
    Plot the tree classification map, the BDFORET reference map, and one RGB
    image per season.

    Layout (2 rows x 6 columns):
        - columns 0-1: the classification map from ``results/<filename>.tif``;
        - columns 2-3: the BDFORET map from ``reference_species/bdforet.tif``;
        - columns 4-5: a 2x2 grid showing the earliest available RGB image
          of each season (Winter = Dec/Jan/Feb, Spring = Mar-May,
          Summer = Jun-Aug, Autumn = Sep-Nov).

    Both maps are drawn with alpha 0.4 on top of the winter RGB image, when
    one exists.

    Args:
        sample_folder (str): Folder containing the sample data. It must hold
            the sub-folders ``rgb`` (file names start with ``YYYY-MM-DD_``),
            ``results`` and ``reference_species``.
        filename (str): Base name (without ``.tif``) of the classification
            raster inside ``results``.
        factor (int): Crop factor forwarded to ``utils.crop`` for every image.

    Returns:
        matplotlib.figure.Figure: The created figure (already passed to
        ``plt.show()``).
    """
    import matplotlib.patches as mpatches

    rgb_dir = os.path.join(sample_folder, 'rgb')
    # Parse acquisition dates from the file-name prefix. Hidden files such as
    # macOS' .DS_Store are skipped; they would otherwise make strptime raise.
    dates = sorted(
        datetime.strptime(fname.split('_')[0], '%Y-%m-%d')
        for fname in os.listdir(rgb_dir)
        if not fname.startswith('.')
    )
    # NOTE(review): indexing `rgb` with positions from `dates` assumes
    # load_folder returns images in sorted-file-name (chronological) order.
    # Confirm in utils.
    rgb = load_folder(rgb_dir)

    seasons = {
        'Winter': [12, 1, 2],
        'Spring': [3, 4, 5],
        'Summer': [6, 7, 8],
        'Autumn': [9, 10, 11],
    }

    # Context managers release the raster file handles after reading band 1.
    with rasterio.open(os.path.join(sample_folder, 'reference_species', 'bdforet.tif')) as src:
        bdforet = src.read(1)
    with rasterio.open(os.path.join(sample_folder, 'results', f'{filename}.tif')) as src:
        tree_class = src.read(1)

    tree_class_cropped = crop(tree_class, factor=factor)
    bdforet_cropped = crop(bdforet, factor=factor)

    cmap = ListedColormap(['lightgray', 'brown', 'green'])

    fig = plt.figure(figsize=(18, 6))
    grid = fig.add_gridspec(2, 6)
    tree_ax = fig.add_subplot(grid[:, 0:2])
    bdforet_ax = fig.add_subplot(grid[:, 2:4])

    for idx, (season, months) in enumerate(seasons.items()):
        # Position of the earliest acquisition in this season; None if absent.
        season_index = next(
            (i for i, date in enumerate(dates) if date.month in months), None
        )
        if season_index is None:
            continue

        # Brighten, then clip to [0, 1]. imshow clips float RGB anyway, but
        # it warns on every call.
        rgb_image = np.clip(
            2.5 * normalize(crop(rgb[season_index], factor=factor).transpose(1, 2, 0)),
            0,
            1,
        )

        row, col = divmod(idx, 2)
        ax = fig.add_subplot(grid[row, col + 4])
        ax.imshow(rgb_image)
        ax.annotate(season, (10, 10), color='white', fontsize=8, fontweight='bold')
        ax.set_axis_off()

        if season == 'Winter':
            # Use the winter scene as the backdrop of both map overlays.
            tree_ax.imshow(rgb_image)
            bdforet_ax.imshow(rgb_image)

    # Pin the colour scale on both maps. This keeps their colours comparable
    # with each other and with the legend even when a tile lacks a class;
    # without it, imshow rescales each map to its own data range.
    # NOTE(review): assumes both rasters use class codes 0 = no forest,
    # 1 = deciduous, 2 = evergreen (legend order). Confirm, especially the
    # BDFORET encoding and any nodata value.
    for map_ax, class_map, title in (
        (tree_ax, tree_class_cropped, "Tree Classification Map"),
        (bdforet_ax, bdforet_cropped, "BDFORET Map"),
    ):
        map_ax.imshow(
            class_map,
            cmap=cmap,
            interpolation='none',
            alpha=0.4,
            vmin=0,
            vmax=cmap.N - 1,
        )
        map_ax.set_title(title)
        map_ax.set_axis_off()

    # Shared legend for the classification colours, placed below the axes.
    patches = [
        mpatches.Patch(color='lightgray', label='No forest'),
        mpatches.Patch(color='brown', label='Deciduous'),
        mpatches.Patch(color='green', label='Evergreen'),
    ]
    fig.legend(handles=patches, loc='lower center', ncol=3, bbox_to_anchor=(0.5, -0.05))

    fig.tight_layout()
    plt.show()

    return fig

model_name = 'XGBoost'
config = "no_resample_cloud_disturbance_weights_3Y"
extra = config + '_Group'
filename = f'{model_name}_{extra}'

folder_dir = '/Users/arthurcalvi/Data/species/validation/tiles'
# os.listdir returns entries in arbitrary order. Sort them so that `index`
# selects the same tile on every run and machine. Hidden entries such as
# .DS_Store are skipped.
files = sorted(f for f in os.listdir(folder_dir) if not f.startswith('.'))
index = 40
sample_folder = os.path.join(folder_dir, files[index])
print(sample_folder)

fig = plot_seasonal_images(sample_folder, filename=filename, factor=2)
# bbox_inches='tight' keeps the legend, which sits below the axes, in the
# saved image.
os.makedirs('images', exist_ok=True)
fig.savefig(f'images/seasonal_images_{index}_bdforet.png', dpi=300, bbox_inches='tight')
