# Loading data

In [None]:
import os
import numpy as np
import pandas as pd
import rasterio
from datetime import datetime
from sklearn.impute import SimpleImputer
from tqdm import tqdm
from utils import load_folder, calculate_slope_with_dates

# Remove warnings
import warnings
warnings.filterwarnings("ignore")

def _read_raster(raster_path, band=None):
    """Read a raster file and close its dataset handle before returning.

    Args:
        raster_path: Path of the raster to open with ``rasterio``.
        band: 1-based band index to read; when ``None`` every band is read.

    Returns:
        The array produced by ``dataset.read()`` or ``dataset.read(band)``.
    """
    with rasterio.open(raster_path) as dataset:
        return dataset.read() if band is None else dataset.read(band)


def load_data_from_tile(path: str, config: str) -> "tuple[pd.DataFrame, np.ndarray]":
    """Build the per-pixel feature table and sample weights for one tile.

    Args:
        path: Tile directory. Its basename must contain the tile id as the
            second ``_``-separated token, and it must hold the ``rgb``,
            ``tree_map``, ``features`` and ``reference_species`` sub-folders.
        config: Suffix selecting which ``APO_<band>_<config>.tif`` feature
            rasters are read.

    Returns:
        A ``(df, weights)`` tuple. ``df`` has one row per retained pixel with
        the amplitude/phase/offset features of each band, elevation, aspect,
        ``tile_id``, ``genus``, ``phen`` and ``source``. ``weights`` holds the
        weight of each row of ``df``, in the same order.

    Raises:
        FileNotFoundError: If ``reference_species`` contains no ``.tif`` file.
    """
    tile_id = os.path.basename(path).split('_')[1]

    # Acquisition dates are parsed from the leading "YYYY-MM-DD" token of
    # every file name in the rgb folder.
    rgb_dir = os.path.join(path, 'rgb')
    dates = sorted(
        datetime.strptime(filename.split('_')[0], '%Y-%m-%d')
        for filename in os.listdir(rgb_dir)
    )
    rgb = load_folder(rgb_dir)

    chm = _read_raster(os.path.join(path, 'tree_map', 'CHM2020.tif'), band=1)
    forest_mask = chm > 250  # NOTE(review): CHM unit/threshold meaning not visible here

    first_band = rgb[:, 0]
    slope_map = calculate_slope_with_dates(
        first_band, dates, len(first_band) / 2, len(first_band)
    ) / 100
    # The larger the absolute slope, the smaller the weight, bounded to [0, 1].
    weights = (1 - abs(slope_map.ravel())).clip(0, 1)

    path_features = os.path.join(path, 'features')
    features = {}
    # Each APO raster stores amplitude, phase and offset in bands 0, 1 and 2.
    for file_tag, band_name in (
        ('R', 'red'),
        ('G', 'green'),
        ('B', 'blue'),
        ('CRSWIR', 'crswir'),
    ):
        apo = _read_raster(os.path.join(path_features, f'APO_{file_tag}_{config}.tif'))
        features[f'amplitude_{band_name}'] = apo[0].ravel()
        features[f'phase_{band_name}'] = apo[1].ravel()
        features[f'offset_{band_name}'] = apo[2].ravel()

    dem = _read_raster(os.path.join(path_features, 'elevation_aspect.tif'))
    elevation, aspect = dem[0], dem[1]
    features['elevation'] = elevation.ravel()
    features['aspect'] = aspect.ravel()
    features['tile_id'] = np.full(aspect.size, tile_id)  # one tile id per pixel

    path_reference = os.path.join(path, 'reference_species')
    # Sorted so the chosen raster does not depend on directory listing order.
    tif_names = sorted(x for x in os.listdir(path_reference) if x.endswith('.tif'))
    if not tif_names:
        raise FileNotFoundError(f"No reference .tif file found in {path_reference}")
    reference = _read_raster(os.path.join(path_reference, tif_names[0]))
    genus = reference[1]
    phen = reference[2]  # Assuming phenology data is stored in the third band
    source = reference[4]

    # Keep pixels that pass the canopy-height threshold and have a non-zero
    # phenology label.
    valid_mask = forest_mask & (phen != 0)
    flat_mask = valid_mask.ravel()

    filtered_features = {name: values[flat_mask] for name, values in features.items()}
    filtered_weights = weights[flat_mask]
    filtered_features['genus'] = genus[valid_mask]
    filtered_features['phen'] = phen[valid_mask]
    filtered_features['source'] = source[valid_mask]

    df = pd.DataFrame(filtered_features).dropna()

    # dropna keeps the original positional labels, so they select the weights
    # of exactly the surviving rows.
    return df, filtered_weights[df.index.to_numpy()]

def load_data(directory: str, config: str) -> "tuple[pd.DataFrame, np.ndarray, dict]":
    """Load every tile found in ``directory`` and stack the results.

    Entries whose name contains ``.DS_Store`` or ``.txt`` are skipped. A tile
    that fails to load is reported on stdout and ignored, so one broken tile
    does not abort the whole run.

    Args:
        directory: Folder containing one sub-folder per tile.
        config: Feature configuration forwarded to ``load_data_from_tile``.

    Returns:
        A ``(data_df, weights_array, tile_to_greco)`` tuple: the concatenated
        per-pixel table, the matching concatenated weights, and a dict mapping
        each tile id to the region string parsed from its folder name.

    Raises:
        ValueError: If no tile could be loaded.
    """
    all_data = []
    all_weights = []
    tile_to_greco = {}

    # Sorted so the row order of the result does not depend on the
    # filesystem's directory listing order.
    for folder in tqdm(sorted(os.listdir(directory))):
        if '.DS_Store' in folder or '.txt' in folder:
            continue
        path = os.path.join(directory, folder)
        try:
            tile_df, tile_weight = load_data_from_tile(path, config)
        except Exception as e:
            # Best effort: report the broken tile and carry on with the rest.
            print(f"Error processing {folder}: {e}")
            continue

        # Folder name layout: token 1 is the tile id, tokens 4..-2 joined by
        # "_" form the region identifier.
        name_parts = os.path.basename(path).split('_')
        tile_id = name_parts[1]
        tile_to_greco[tile_id] = "_".join(name_parts[4:-1])
        tile_df['tile_id'] = tile_id
        all_data.append(tile_df)
        all_weights.append(tile_weight)

    print(f"Loaded {len(all_data)} tiles")
    if not all_data:
        raise ValueError(
            f"No tiles could be loaded from {directory!r} for config {config!r}"
        )

    data_df = pd.concat(all_data, ignore_index=True)
    weights_array = np.concatenate(all_weights)

    return data_df, weights_array, tile_to_greco

# # Load data
# data_dir = '/Users/arthurcalvi/Data/species/validation/tiles'
# data, all_weights, tile_to_greco = load_data(data_dir, 'resampled_no_weights_Y1')

# # Verify data structure
# print(data.head())
# print(tile_to_greco)


In [None]:
import pandas as pd
import numpy as np
from sklearn.model_selection import GroupShuffleSplit
from collections import Counter
from utils import mapping_real_greco
# Split data into training and validation sets
def stratified_group_split(data, test_size=0.25, random_state=42):
    """Split ``data`` into train/validation rows, keeping tiles together.

    A single ``GroupShuffleSplit`` draw is made with ``data['tile_id']`` as
    the grouping key. Despite the function name, the code applies no
    stratification on any class column.

    Args:
        data: DataFrame with a ``tile_id`` column.
        test_size: Forwarded to ``GroupShuffleSplit``.
        random_state: Forwarded to ``GroupShuffleSplit``.

    Returns:
        ``(train_data, val_data)``, both positional selections of ``data``.
    """
    group_splitter = GroupShuffleSplit(
        n_splits=1,
        test_size=test_size,
        random_state=random_state,
    )
    # n_splits=1, so the generator's first item is the only split.
    train_rows, val_rows = next(group_splitter.split(data, groups=data['tile_id']))
    return data.iloc[train_rows], data.iloc[val_rows]

# Verify the split
def verify_split(train_data, val_data):
    """Print diagnostics for a train/validation split and summarise it.

    Prints the phenology and GRECO-region counts of both sets and the tile ids
    they share, then builds a table with, for each region and each set, the
    number of rows with ``phen == 1`` (Deciduous), ``phen == 2`` (Evergreen)
    and their sum, followed by one ``Total`` row per set.

    Args:
        train_data: Training rows, with ``phen``, ``greco_region`` and
            ``tile_id`` columns.
        val_data: Validation rows with the same columns.

    Returns:
        The summary DataFrame with columns ``GRECO Region``, ``Set``,
        ``Deciduous``, ``Evergreen`` and ``Total``. Region labels are passed
        through ``mapping_real_greco``.
    """
    # Check distribution of phenology classes
    print("Training phenology class distribution:", Counter(train_data['phen']))
    print("Validation phenology class distribution:", Counter(val_data['phen']))

    # Check distribution of GRECO regions
    print("Training GRECO region distribution:", Counter(train_data['greco_region']))
    print("Validation GRECO region distribution:", Counter(val_data['greco_region']))

    # Ensure no overlap of tiles between training and validation
    train_tiles = set(train_data['tile_id'])
    val_tiles = set(val_data['tile_id'])
    print("Common tiles between training and validation:", train_tiles.intersection(val_tiles))

    columns = ['GRECO Region', 'Set', 'Deciduous', 'Evergreen', 'Total']
    subsets = (('Training', train_data), ('Validation', val_data))

    # Sorted (by string form) so the row order is stable between runs; set
    # iteration order is not.
    regions = sorted(
        set(train_data['greco_region']).union(val_data['greco_region']),
        key=str,
    )

    # Rows are collected in a list and turned into a DataFrame once:
    # DataFrame.append no longer exists in pandas >= 2.0.
    region_rows = []
    for greco_region in regions:
        for set_name, subset in subsets:
            in_region = subset['greco_region'] == greco_region
            deciduous = int((in_region & (subset['phen'] == 1)).sum())
            evergreen = int((in_region & (subset['phen'] == 2)).sum())
            region_rows.append({
                'GRECO Region': greco_region,
                'Set': set_name,
                'Deciduous': deciduous,
                'Evergreen': evergreen,
                'Total': deciduous + evergreen,
            })

    # Add total rows (one per set)
    total_rows = []
    for set_name, _ in subsets:
        deciduous = sum(row['Deciduous'] for row in region_rows if row['Set'] == set_name)
        evergreen = sum(row['Evergreen'] for row in region_rows if row['Set'] == set_name)
        total_rows.append({
            'GRECO Region': 'Total',
            'Set': set_name,
            'Deciduous': deciduous,
            'Evergreen': evergreen,
            'Total': deciduous + evergreen,
        })

    summary_table = pd.DataFrame(region_rows + total_rows, columns=columns)

    # Labels the mapping does not cover (e.g. possibly 'Total') would become
    # NaN; keep the original label for those instead.
    raw_labels = summary_table['GRECO Region']
    summary_table['GRECO Region'] = raw_labels.map(mapping_real_greco).fillna(raw_labels)

    print(summary_table)

    return summary_table

# Processing variants whose feature rasters are exported.
methods = [
    "resampled_no_weights",
    "no_resample_no_weights",
    "no_resample_cloud_weights",
    "no_resample_cloud_disturbance_weights",
]

years = [1, 2, 3]

# One config name per (method, year) pair, formatted "<method>_Y<year>".
configs = [
    f"{method}_Y{year}"
    for method in methods
    for year in years
]

# Root folder holding one sub-folder per tile.
data_dir = '/Users/arthurcalvi/Data/species/validation/tiles'

# For each config: load all tiles, split them by tile, report on the split and
# write the summary, training and validation tables to CSV.
for config in tqdm(configs):
    print(f'config : {config}')

    data, all_weights, tile_to_greco = load_data(data_dir, config)

    # Attach the GRECO region of each row's tile.
    data = data.assign(greco_region=data['tile_id'].map(tile_to_greco))

    train_data, val_data = stratified_group_split(data)

    # Print the split diagnostics and keep the summary for export.
    summary_table = verify_split(train_data, val_data)
    summary_table.to_csv(f'summary_table_{config}.csv', index=False)

    # Save the datasets
    train_data.to_csv(f'train_data_{config}.csv', index=False)
    val_data.to_csv(f'val_data_{config}.csv', index=False)

    print("Training and validation datasets saved.")
