In [None]:
import numpy as np
import pandas as pd

from PIL import Image
import gc
import os
import sys
import cv2

In [None]:
from datasets import Dataset, load_from_disk
# Add the parent directory to the Python path so project-local modules (e.g. the `dataprocessing` package imported below) can be found
sys.path.append(os.path.abspath("../"))
import dataprocessing.rcsHandlingFunctions as rcs
from dataprocessing.slope import calculate_slope


In [None]:
# Load the training DataFrame from its pickle file.
# NOTE(review): unpickling can execute arbitrary code; only load pickle files from a trusted source.
df_loaded = pd.read_pickle("train_df_sam.pkl")

In [None]:
# Load the DataFrame that is later joined onto df_loaded (presumably the meteorological variables; verify against the pickle's columns).
# NOTE(review): unpickling can execute arbitrary code; only load pickle files from a trusted source.
df_met = pd.read_pickle("df_met.pkl")

In [None]:
# Load the DataFrame holding the DEM entries (it provides the 'dem_path' column used to derive the join id).
# NOTE(review): unpickling can execute arbitrary code; only load pickle files from a trusted source.
df_dem = pd.read_pickle("dataframe_avalanches_dem.pkl")

In [None]:
def extract_id(path):
    """
    Extract the id folder name that directly follows "avalanche_input/" in a path.

    Both "/" and "\\" separators are accepted.

    Parameters:
    path (str): File path expected to contain an "avalanche_input/" folder.

    Returns:
    str or None: The path segment right after "avalanche_input/", or None when
    the marker folder is absent or `path` is not a string (e.g. None / NaN
    coming from a missing DataFrame cell).
    """
    # Missing cells reach this function as None/NaN when it is applied to a
    # pandas column; they carry no id, so treat them like "marker not found".
    if not isinstance(path, str):
        return None
    # Normalize path separators
    normalized = path.replace("\\", "/")
    # Split on the known folder name ("avalanche_input/")
    parts = normalized.split("avalanche_input/")
    if len(parts) > 1:
        # The next segment in the path should be the id folder
        return parts[1].split("/")[0]
    return None

# Derive the shared 'id' join key for each frame from its own path column.
for frame, path_column in (
    (df_loaded, 'image_path'),
    (df_met, 'source_file'),
    (df_dem, 'dem_path'),
):
    frame['id'] = frame[path_column].apply(extract_id)

In [None]:
# Sanity check: show the extracted id of the first row of every frame,
# followed by the path each of those ids was derived from.
first_rows = (df_loaded.iloc[0], df_met.iloc[0], df_dem.iloc[0])
path_columns = ("image_path", "source_file", "dem_path")
for first_row in first_rows:
    print(first_row["id"])
for first_row, path_column in zip(first_rows, path_columns):
    print(first_row[path_column])

In [None]:
# List the columns of every input frame before merging them.
for frame in (df_loaded, df_met, df_dem):
    print(frame.columns)

In [None]:
# Left-join df_met and then df_dem onto df_loaded by 'id'. A left join keeps
# every row of df_loaded; columns whose names clash keep the left name and the
# right-hand copy receives the '_met' / '_original' suffix.
merged_df = (
    df_loaded
    .merge(df_met, on='id', how='left', suffixes=('', '_met'))
    .merge(df_dem, on='id', how='left', suffixes=('', '_original'))
)

In [None]:
# Display the column names of the merged frame to check the result of the joins.
merged_df.columns

In [None]:
# From here on df_loaded refers to the merged frame (this rebinds the name; no copy is made).
df_loaded = merged_df

In [None]:
# Collect the index labels of the rows whose mask is not empty.
# Labels (not enumerate() positions) are gathered because the rows are selected
# with label-based `.loc` afterwards; positions and labels only coincide when
# the frame has a default RangeIndex.
valid_indices = [label for label, is_empty in df_loaded['empty_mask'].items() if not is_empty]
len(valid_indices)

In [None]:
# Keep only the rows with a non-empty mask, then pull out the individual
# columns that are fed into the dataset builder.
meteo_columns = [
    'air_temperature_2m',
    'precipitation_amount',
    'wind_speed_10m',
    'relative_humidity_2m',
    'air_pressure_at_sea_level',
]
valid_rows = df_loaded.loc[valid_indices]
filtered_rcs = valid_rows['rcs']
filtered_masks = valid_rows['mask']
filtered_boxes = valid_rows['boxes']
filtered_DEM = valid_rows['dem_original']
filtered_met = valid_rows[meteo_columns]
filtered_id = valid_rows['id']
print("Mask shape:", filtered_masks.shape)

In [None]:
def convert_mask_to_tiff_format(mask):
    """
    Return a float32 copy of the mask with its values divided by 255.

    A mask stored in the 0-255 range is therefore mapped onto 0-1. The input
    array itself is never modified.

    Parameters:
    mask (numpy.ndarray): The input mask.

    Returns:
    numpy.ndarray: New float32 array of the same shape holding mask / 255.
    """
    # astype() always allocates a new array, so the caller's mask stays untouched.
    as_float = mask.astype(np.float32)
    return as_float / 255.0

In [None]:
# Function to substitute DEM data into the image
def substitute_dem(image, dem_data, channel=2, dem_scale=4000.0):
    """
    Substitute one of the layers of the image with the scaled DEM data.
    Args:
        image (np.ndarray): The original image, indexed as [row, column, channel].
            It is not modified.
        dem_data (np.ndarray): The DEM data; must be assignable to a single
            channel of `image`.
        channel (int): The channel to be replaced with DEM data.
        dem_scale (float): Divisor applied to the DEM values before they are
            written. With the default of 4000.0, values between 0 and 4000
            end up in the 0-1 range.
    Returns:
        np.ndarray: A float32 copy of the image whose `channel` holds
        dem_data / dem_scale.
    Raises:
        ValueError: If dem_scale is zero.
    """
    if dem_scale == 0:
        raise ValueError("dem_scale must be non-zero")
    # np.array copies by default, so one call both converts and detaches from the input.
    modified_image = np.array(image, dtype=np.float32)
    modified_image[:, :, channel] = np.asarray(dem_data) / dem_scale
    return modified_image

In [None]:
def prepare_meteo_data(row):
    """
    Stack the five meteorological variables of one record into a (T, 5) nested list.

    Each variable may be a scalar (treated as a single time step) or a 1-D
    sequence such as a list, tuple, numpy array or pandas Series. All five
    variables must have the same length.

    Parameters:
    row (mapping): Record providing the keys "air_temperature_2m",
        "precipitation_amount", "wind_speed_10m", "relative_humidity_2m" and
        "air_pressure_at_sea_level".

    Returns:
    list: List of T rows, each holding the five values in the key order above.

    Raises:
    ValueError: Propagated from numpy.stack when the variables differ in shape.
    """
    variables = (
        "air_temperature_2m",
        "precipitation_amount",
        "wind_speed_10m",
        "relative_humidity_2m",
        "air_pressure_at_sea_level",
    )

    def as_time_series(x):
        # Scalars become length-1 arrays; existing sequences keep their shape.
        # (Wrapping every non-list in a list would add a spurious leading axis
        # to tuples and arrays.)
        return np.atleast_1d(np.asarray(x))

    meteo = np.stack([as_time_series(row[name]) for name in variables], axis=-1)  # shape: (T, 5)
    return meteo.tolist()  # consistently return list of lists

In [None]:
# Function to process data in batches
def process_in_batches(rcs_data, masks, boxes, dems, met, ids, batch_size=100,
                       output_prefix='datasetBoxes', image_size=(512, 512)):
    """
    Build a `datasets.Dataset` for every batch of samples and save it to disk.

    The batch starting at position i covers positions [i, i + batch_size) of
    every input and is written to the directory `output_prefix + str(i)`.

    Parameters:
    rcs_data: Sliceable sequence; each element is unpacked into
        `rcs._merge_all(*element)` and the first item of the result is used
        as the merged image.
    masks: Sliceable sequence of masks; each is converted with
        `convert_mask_to_tiff_format` (masks are NOT resized here).
    boxes: Sliceable sequence of boxes, stored unchanged.
    dems: Sliceable sequence of DEM arrays; stored divided by 4000 and also
        used to compute the slope (stored divided by 90).
    met (pandas.DataFrame): Meteorological columns, one row per sample.
    ids: Sliceable sequence of sample ids.
    batch_size (int): Number of samples per saved dataset; must be positive.
    output_prefix (str): Prefix of the output directory names.
    image_size (tuple): Target size handed to `cv2.resize` for the image,
        DEM and slope layers.

    Returns:
    list of str: The directories written, in processing order.

    Raises:
    ValueError: If batch_size is not positive.
    """
    # range() rejects a zero step with an unhelpful message and a negative step
    # would silently produce no batches, so fail early and clearly.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    def resize(array):
        return cv2.resize(array, image_size, interpolation=cv2.INTER_LINEAR)

    saved_paths = []
    for i in range(0, len(rcs_data), batch_size):
        batch = slice(i, i + batch_size)
        batch_dems = dems[batch]

        # First, merge all rcs instances using _merge_all for each instance.
        merged_images = [rcs._merge_all(*rcs_instance)[0] for rcs_instance in rcs_data[batch]]
        # Channels 0-3 of each merged image are stored below as VH0, VH1, VV0, VV1.
        # NOTE(review): the original notes described a 5-channel image whose
        # last channel duplicates VV1 and can be ignored; confirm against
        # rcs._merge_all.

        dataset_dict = {
            "VH0": [resize(img[:, :, 0]) for img in merged_images],
            "VH1": [resize(img[:, :, 1]) for img in merged_images],
            "VV0": [resize(img[:, :, 2]) for img in merged_images],
            "VV1": [resize(img[:, :, 3]) for img in merged_images],
            "dem": [resize(dem / 4000.0) for dem in batch_dems],
            "label": [convert_mask_to_tiff_format(np.array(mask)) for mask in masks[batch]],
            "box": list(boxes[batch]),
            "met": [prepare_meteo_data(row) for row in met[batch].to_dict('records')],
            "slope": [resize(calculate_slope(dem) / 90) for dem in batch_dems],
            "id": list(ids[batch]),
        }

        # Create the dataset using the datasets.Dataset class
        dataset = Dataset.from_dict(dataset_dict)

        # Save the dataset to disk
        output_path = output_prefix + str(i)
        dataset.save_to_disk(output_path)
        saved_paths.append(output_path)

        # Clear memory before building the next batch
        del merged_images, dataset_dict, dataset
        gc.collect()

    return saved_paths

In [None]:
# Divisor used to derive the batch size (number of samples // divide). When the
# sample count is not a multiple of `divide`, one extra, smaller batch is produced.
divide = 7

In [None]:
# Write the per-batch datasets to disk; the batch size is the sample count
# integer-divided by `divide`.
process_in_batches(
    rcs_data=filtered_rcs,
    masks=filtered_masks,
    boxes=filtered_boxes,
    dems=filtered_DEM,
    met=filtered_met,
    ids=filtered_id,
    batch_size=len(filtered_rcs) // divide,
)

In [None]:
# Rebuild the directory names written by process_in_batches: one per batch
# start offset, using the same step as the batch size.
batch_step = len(filtered_rcs) // divide
datasetsNames = [f'datasetBoxes{start}' for start in range(0, len(filtered_rcs), batch_step)]

In [None]:
# Show the dataset directory names that will be loaded next.
print(datasetsNames)

In [None]:
def is_dataset_loaded_and_not_empty(dataset):
    """
    Tell whether a dataset object exists and contains at least one element.

    Parameters:
    dataset (Dataset): The loaded dataset, or None.

    Returns:
    bool: False when `dataset` is None or has length 0, True otherwise.
    """
    return dataset is not None and len(dataset) > 0

In [None]:
# Load every per-batch dataset from disk and keep the ones that contain data.
datasetList = []
for datasetName in datasetsNames:
    dataset = load_from_disk(datasetName)
    if is_dataset_loaded_and_not_empty(dataset):
        print(f"Dataset '{datasetName}' loaded successfully.")
        datasetList.append(dataset)
    else:
        # The check used to gate only the success message, so an empty dataset
        # was appended without any notice; report it and leave it out instead.
        print(f"Warning: dataset '{datasetName}' is empty and was skipped.")

In [None]:
# Display how many datasets were collected.
len(datasetList)

In [None]:
from datasets import concatenate_datasets
# Concatenate all datasets collected in datasetList into a single dataset.
merged_dataset = concatenate_datasets(datasetList)

In [None]:
import matplotlib.pyplot as plt
# Check dataset
print(merged_dataset.column_names)

# Fetch the first sample once and reuse it for both the figure and the slope
# statistics, instead of indexing the dataset a second time.
first_sample = merged_dataset[0]

# Draw the six image-like inputs (slope, dem, VH0, VH1, VV0, VV1) in a 2x3 grid,
# in the order in which they appear in the sample.
image_keys = ("slope", "dem", "VH0", "VH1", "VV0", "VV1")
panels = [(key, value) for key, value in first_sample.items() if key in image_keys]

plt.figure(figsize=(12, 8))
for position, (key, value) in enumerate(panels, start=1):
    plt.subplot(2, 3, position)
    plt.imshow(value, cmap='gray')
    plt.title(key)
    plt.axis('off')
plt.tight_layout()
plt.show()

# Slope max and min of the first sample
slope_array = np.array(first_sample["slope"])
print("Slope max:", slope_array.max())
print("Slope min:", slope_array.min())


In [None]:
# Save the concatenated dataset to the 'datasetTrainFinal' directory.
merged_dataset.save_to_disk('datasetTrainFinal')

In [None]:
# Display the shape of the concatenated dataset.
merged_dataset.shape

In [None]:
# Reload the dataset saved to 'datasetTrainFinal', so the cells below can be
# run without rebuilding it.
dataset = load_from_disk('datasetTrainFinal')

In [None]:
from sklearn.model_selection import train_test_split

# Hold out a test set (90% training, 10% test).
# NOTE(review): the split is done by the dataset's own `train_test_split`
# method; the sklearn function imported above is not called in this cell.
train_test_split_ratio = 0.9
test_fraction = 1 - train_test_split_ratio
train_test_parts = dataset.train_test_split(test_size=test_fraction, seed=20)
train_dataset = train_test_parts["train"]
test_dataset = train_test_parts["test"]

In [None]:
# Carve a validation set out of the remaining training data
# (90% training, 10% validation).
train_val_split_ratio = 0.9
val_fraction = 1 - train_val_split_ratio
train_val_parts = train_dataset.train_test_split(test_size=val_fraction, seed=20)
train_dataset = train_val_parts["train"]
val_dataset = train_val_parts["test"]

In [None]:
# Report the shape of each split: train, validation, test.
for split in (train_dataset, val_dataset, test_dataset):
    print(split.shape)

In [None]:




# Persist the three splits, each to its own directory.
for split, target_dir in (
    (train_dataset, 'datasetTrainDEMSeparateFloat'),
    (val_dataset, 'datasetValDEMSeparateFloat'),
    (test_dataset, 'datasetTestDEMSeparateFloat'),
):
    split.save_to_disk(target_dir)