<a href="https://colab.research.google.com/github/VincentCsNv/SymbioseManagement/blob/main/main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install rasterio laspy lazrs --quiet
#!pip install pdal --quiet
!pip install torch torchvision --quiet
!pip install scikit-learn --quiet



In [2]:
import os
import numpy as np
import matplotlib.pyplot as plt
import rasterio
import laspy
from pathlib import Path
import glob 
import pandas as pd
import lib.functions_utils as fu
import torch
import sklearn

from importlib import reload

# Re-import lib.functions_utils so edits to that file are picked up without
# restarting the kernel. Assigning the result (instead of leaving `reload(fu)`
# as the last expression) suppresses the module repr, which would otherwise
# print an absolute local file path in the saved output.
fu = reload(fu)

<module 'lib.functions_utils' from '/Users/Vincent/Documents/2025/Projets_pro/Symbiose Management/SymbioseManagement/lib/functions_utils.py'>

In [3]:
DATA_dir = "data/"  # root folder scanned for the .tiff / .laz files

# Fraction of the samples assigned to each split (must sum to 1).
repartition = {'train': 0.8, 'val': 0.1, 'test': 0.1}

random_seed = 42  # seed passed to the train/val/test split

# Fail fast on a mistyped split configuration instead of deep inside the split helper.
assert abs(sum(repartition.values()) - 1.0) < 1e-9, (
    f"split fractions sum to {sum(repartition.values())}, expected 1"
)

# 1. Loading the data

In [75]:
# Catalogue every data file under DATA_dir, then tag each sample with a
# train / val / test split following the fractions in `repartition`.
data_df = fu.rep_df_train_test_val(
    fu.extrating_data_to_df(DATA_dir),
    repartition,
    random_seed,
)

# Check the per-species split on the imagery rows only.
fu.detailed_distribution(data_df.loc[data_df["type"] == "imagery"])


Getting all the files .tiff and .laz path...
Number of files : 17138
🌳 Sample by species:
  Picea_abies: 4074 samples (47.5%)
  Castanea_sativa: 3684 samples (43.0%)
  Abies_alba: 811 samples (9.5%)

🎯 Dataset distribution:
  train: 6855 samples (80.0%)
  val: 857 samples (10.0%)
  test: 857 samples (10.0%)


🔍 Detailed distribution:

🌳 Abies_alba:
  train:  80.0% (attendu: 80%) ✅
  val  :  10.0% (attendu: 10%) ✅
  test :  10.0% (attendu: 10%) ✅

🌳 Castanea_sativa:
  train:  80.0% (attendu: 80%) ✅
  val  :  10.0% (attendu: 10%) ✅
  test :  10.0% (attendu: 10%) ✅

🌳 Picea_abies:
  train:  80.0% (attendu: 80%) ✅
  val  :  10.0% (attendu: 10%) ✅
  test :  10.0% (attendu: 10%) ✅


## Creating data loaders for imagery and lidar data

In [None]:
def create_dataloader(df, batch_size=4, input_type="aerial", shuffle=True,
                      num_workers=0, seed=None, transform=None):
    """
    Wrap ``df`` in a TreeDataset and return a single DataLoader over it.

    Call it once per split (train, val, test subsets of the dataframe) to get
    one loader per split.

    Args:
        df (pd.DataFrame): DataFrame containing file paths and labels.
        batch_size (int): Number of samples per batch.
        input_type (str): Type of input data ('aerial', 'lidar', 'fusion'),
            forwarded to TreeDataset.
        shuffle (bool): Reshuffle samples every epoch. Use False for val/test.
        num_workers (int): Worker processes used for loading (0 = main process).
        seed (int, optional): If given, seeds the shuffling so the batch order
            is reproducible across runs.
        transform (callable, optional): Forwarded to TreeDataset.

    Returns:
        torch.utils.data.DataLoader: loader yielding ``(inputs, labels)`` batches.
    """
    dataset = TreeDataset(df, input_type=input_type, transform=transform)

    # A dedicated generator makes shuffling reproducible without touching
    # torch's global RNG state.
    generator = torch.Generator().manual_seed(seed) if seed is not None else None

    # `DataLoader` is not imported in this notebook, so reach it through `torch`.
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        generator=generator,
    )


class TreeDataset(torch.utils.data.Dataset):
    """Dataset of (input, label) pairs for tree-species classification.

    Each item comes from one row of ``df``. The label is the integer index of
    ``row["species"]``. The input is read from ``row["aerial_path"]`` when
    ``input_type`` is 'aerial' or 'fusion', and from ``row["lidar_path"]``
    for any other ``input_type``.

    Args:
        df (pd.DataFrame): one row per sample with a ``species`` column and the
            path column required by ``input_type``.
        input_type (str): 'aerial', 'lidar' or 'fusion'. NOTE: 'fusion'
            currently loads only the aerial image (TODO: add the lidar input).
        transform (callable, optional): applied to the loaded array before it
            is converted to a tensor.
        classes (iterable, optional): explicit ordered species names. Pass the
            same list to every split so a split missing a species still uses
            the same indices. Defaults to the sorted species found in ``df``.
    """

    def __init__(self, df, input_type='aerial', transform=None, classes=None):
        self.df = df
        self.input_type = input_type
        self.transform = transform
        # Sorting makes the class -> index mapping independent of row order, so
        # datasets built from different splits agree on label indices.
        if classes is None:
            classes = sorted(df['species'].unique())
        self.class_to_idx = {cls: idx for idx, cls in enumerate(classes)}
        self.idx_to_class = {idx: cls for cls, idx in self.class_to_idx.items()}
        self.num_classes = len(self.class_to_idx)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(inputs, label)`` as a float32 tensor and a long scalar tensor."""
        row = self.df.iloc[idx]
        label = self.class_to_idx[row['species']]

        if self.input_type in ('aerial', 'fusion'):
            # TODO: 'fusion' should also use row["lidar_path"].
            inputs = self.load_aerial_image(row["aerial_path"])
        else:
            las = self.load_lidar_data(row["lidar_path"])
            # torch cannot build a tensor from a laspy point-cloud object, so
            # keep its scaled x/y/z coordinates as an (N, 3) array. N differs
            # between files: batch_size > 1 needs a custom collate_fn
            # (padding or fixed-size point sampling).
            inputs = np.stack(
                [np.asarray(las.x), np.asarray(las.y), np.asarray(las.z)], axis=1
            )

        if self.transform is not None:
            inputs = self.transform(inputs)

        # as_tensor also accepts a tensor returned by the transform without
        # the extra copy/warning torch.tensor would produce.
        return torch.as_tensor(inputs, dtype=torch.float32), torch.tensor(label, dtype=torch.long)

    def load_aerial_image(self, path):
        """Read every band of the raster at ``path`` as an (H, W, bands) array divided by 255."""
        with rasterio.open(path) as src:
            img = src.read()  # read all bands
        # rasterio returns (bands, rows, cols); move the bands last (HWC).
        # NOTE(review): PyTorch conv layers expect CHW, so transpose back (or use
        # a transform such as torchvision ToTensor) before feeding a model.
        # Dividing by 255 assumes 8-bit bands; confirm for this imagery.
        img = np.transpose(img, (1, 2, 0))
        return img / 255.0

    def load_lidar_data(self, path):
        """Open the .las/.laz file at ``path`` with laspy and return the data it reads."""
        with laspy.open(path) as fh:
            las = fh.read()
        return las

In [95]:
# data_df catalogues every .tiff and .laz file (see the "type" column used in
# the distribution check), so keep only the imagery rows for the aerial loader.
# TODO: also restrict to the train split once its column name is confirmed.
imagery_df = data_df[data_df["type"] == "imagery"]

dataloader = create_dataloader(
    df=imagery_df,
    batch_size=4,  # adjust to the available RAM/GPU
    input_type='aerial',
)

inputs, labels = next(iter(dataloader))
inputs.shape, labels