# Data augmentation
## Primo bilanciamento tra sottocartelle 

In [None]:
import os
import random
import numpy as np
from PIL import Image, ImageFile
import csv
from tqdm import tqdm

# Allow Pillow to open truncated image files rather than fail on them.
# This is a module-level flag on ImageFile, so it affects every image this
# process opens through Pillow, not only the ones in this script.
ImageFile.LOAD_TRUNCATED_IMAGES = True

def set_seed(seed):
    """Seed both random number generators used by the augmentation code.

    Python's ``random`` module picks which transformation to apply, while
    NumPy's global generator drives the randomness inside the transformations,
    so both have to be seeded for a run to be repeatable.
    """
    for seeder in (random.seed, np.random.seed):
        seeder(seed)

def load_image(file_path):
    """Open the image at *file_path* with Pillow and hand back the ``Image``."""
    opened = Image.open(file_path)
    return opened

def save_image(image, file_path):
    """Write *image* to *file_path* via its ``save`` method.

    No explicit format is passed, so Pillow infers it from the file name.
    """
    image.save(fp=file_path)

def time_shifting(image_array, shift_max=0.2):
    """Circularly roll *image_array* along axis 1 (columns) by a random offset.

    The offset is a uniform draw from ``[-shift_max, shift_max)`` times the
    number of columns, truncated toward zero. Columns pushed off one edge
    reappear on the other, so no data is lost.
    """
    width = image_array.shape[1]
    offset = int(np.random.uniform(-shift_max, shift_max) * width)
    return np.roll(image_array, offset, axis=1)

def noise_addition(image_array, noise_factor=0.1):
    """Return *image_array* plus zero-mean Gaussian noise.

    The noise has standard deviation *noise_factor*, in the same units as the
    array values. On 0-255 pixel data the default of 0.1 is therefore a very
    small perturbation. The result is a float array and is not clipped.
    """
    gaussian = np.random.randn(*image_array.shape)
    return image_array + gaussian * noise_factor

def noise_reduction(image_array, reduction_factor=0.1):
    """Subtract ``reduction_factor * mean(image_array)`` from every element.

    Despite the name, this is a uniform darkening by one constant, not a
    denoising filter. The result is a float array, and elements smaller than
    the subtracted constant become negative.
    """
    offset = reduction_factor * image_array.mean()
    return image_array - offset

def time_masking(image_array, mask_max_percentage=0.2):
    """Zero out a random contiguous band of columns (axis 1) of an image.

    The band width is a uniform random fraction in ``[0, mask_max_percentage)``
    of the number of columns, truncated to an integer (so it may be 0) and
    capped at the full width. Its start position is uniform over every
    position where the band still fits, including positions that end on the
    last column.

    Args:
        image_array: NumPy array with at least two dimensions; axis 1 is the
            masked axis (time, for a spectrogram image).
        mask_max_percentage: upper bound on the masked fraction of columns.

    Returns:
        A masked copy of *image_array*. The input array is left unmodified.
    """
    masked = image_array.copy()
    width = masked.shape[1]
    mask_percentage = np.random.uniform(0, mask_max_percentage)
    t = min(int(mask_percentage * width), width)
    # randint's upper bound is exclusive: "+ 1" lets the band end on the last
    # column, and keeps the call valid for zero-width arrays.
    t0 = np.random.randint(0, width - t + 1)
    masked[:, t0:t0 + t] = 0
    return masked

def frequency_masking(image_array, mask_max_percentage=0.2):
    """Zero out a random contiguous band of rows (axis 0) of an image.

    The band height is a uniform random fraction in ``[0, mask_max_percentage)``
    of the number of rows, truncated to an integer (so it may be 0) and capped
    at the full height. Its start position is uniform over every position
    where the band still fits, including positions that end on the last row.

    Args:
        image_array: NumPy array with at least one dimension; axis 0 is the
            masked axis (frequency, for a spectrogram image).
        mask_max_percentage: upper bound on the masked fraction of rows.

    Returns:
        A masked copy of *image_array*. The input array is left unmodified.
    """
    masked = image_array.copy()
    height = masked.shape[0]
    mask_percentage = np.random.uniform(0, mask_max_percentage)
    f = min(int(mask_percentage * height), height)
    # randint's upper bound is exclusive: "+ 1" lets the band end on the last
    # row, and keeps the call valid for zero-height arrays.
    f0 = np.random.randint(0, height - f + 1)
    masked[f0:f0 + f, :] = 0
    return masked

def apply_random_transformation(image_array):
    """Apply one augmentation, chosen uniformly at random, to *image_array*.

    The choice is made with Python's ``random`` module (seeded by ``set_seed``).
    The candidates are time shifting, noise addition, noise reduction,
    time masking and frequency masking. Callers should not assume the
    returned array has the input's dtype, nor that *image_array* itself is
    left unchanged.
    """
    candidates = (
        time_shifting,
        noise_addition,
        noise_reduction,
        time_masking,
        frequency_masking,
    )
    chosen = random.choice(candidates)
    return chosen(image_array)

def _list_pngs(directory):
    """Return the sorted names of the ``.png`` files directly inside *directory*."""
    # Sorting makes the processing order (and therefore which image gets which
    # seeded random draw) independent of os.listdir's arbitrary ordering.
    return sorted(f for f in os.listdir(directory) if f.endswith('.png'))


def _augment_directory(animal_dir, source_files, target_count, generated_files):
    """Add augmented spectrograms to *animal_dir* until it holds *target_count*.

    The images in *source_files* are cycled through repeatedly. Each pass
    applies a freshly drawn random transformation to every source image. The
    paths of newly written files are appended to *generated_files*.

    Returns:
        The number of ``.png`` files the folder holds afterwards. This is lower
        than *target_count* if the folder has no images, or if none of its
        images could be processed.
    """
    current_count = len(source_files)
    if current_count >= target_count:
        return current_count
    if not source_files:
        # Without this guard the loop below would never terminate.
        print(f"Nessuno spettrogramma in {animal_dir}: cartella saltata.")
        return current_count

    desc = f"Processing {os.path.basename(animal_dir)}"
    with tqdm(total=target_count - current_count, desc=desc) as pbar:
        while current_count < target_count:
            produced_this_pass = 0
            for file_name in source_files:
                if current_count >= target_count:
                    break

                file_path = os.path.join(animal_dir, file_name)
                try:
                    image_array = np.array(load_image(file_path))
                    augmented_array = apply_random_transformation(image_array)
                    # Some transformations return floats outside [0, 255]
                    # (e.g. subtracting a constant can go negative). Clip
                    # before casting, otherwise those values wrap around in
                    # uint8 and dark pixels turn bright.
                    augmented_image = Image.fromarray(
                        np.clip(augmented_array, 0, 255).astype(np.uint8)
                    )
                    new_file_name = f"{os.path.splitext(file_name)[0]}_aug_{current_count}.png"
                    new_file_path = os.path.join(animal_dir, new_file_name)
                    save_image(augmented_image, new_file_path)
                except OSError as e:
                    print(f"Errore nel caricamento dell'immagine {file_path}: {e}")
                    continue

                generated_files.append(new_file_path)
                current_count += 1
                produced_this_pass += 1
                pbar.update(1)

            if produced_this_pass == 0:
                # Every source image failed; another pass would fail the same
                # way, so stop instead of looping forever.
                print(f"Impossibile elaborare le immagini in {animal_dir}: cartella saltata.")
                break
    return current_count


def _write_generated_csv(csv_file_path, generated_files):
    """Write *generated_files* to *csv_file_path* as a one-column CSV with a header."""
    with open(csv_file_path, mode='w', newline='', encoding='utf-8') as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(["file_path"])
        writer.writerows([path] for path in generated_files)


def main(target_dir, seed=42):
    """Balance the per-class spectrogram subfolders of *target_dir* by augmentation.

    Each immediate subdirectory of *target_dir* is treated as one class. The
    subfolder with the most ``.png`` files sets the target size (the first one
    in sorted order wins ties). Every other subfolder is topped up with
    augmented copies of its own images until it reaches that size. The paths
    of all generated files are listed in ``<target_dir>/generated_files.csv``.

    Args:
        target_dir: directory whose subdirectories hold the spectrogram PNGs.
        seed: seed for ``random`` and NumPy, so that augmentations are repeatable.
    """
    set_seed(seed)

    animal_dirs = sorted(
        os.path.join(target_dir, d)
        for d in os.listdir(target_dir)
        if os.path.isdir(os.path.join(target_dir, d))
    )
    files_by_dir = {animal_dir: _list_pngs(animal_dir) for animal_dir in animal_dirs}

    # Find the subfolder with the largest number of spectrograms.
    max_count = 0
    max_animal_dir = None
    for animal_dir in animal_dirs:
        count = len(files_by_dir[animal_dir])
        if count > max_count:
            max_count = count
            max_animal_dir = animal_dir

    print(f"La sottocartella con più spettrogrammi è: {max_animal_dir} con {max_count} file.")

    # Augment every other subfolder up to max_count, collecting new file paths.
    generated_files = []
    for animal_dir in animal_dirs:
        if animal_dir != max_animal_dir:
            _augment_directory(animal_dir, files_by_dir[animal_dir], max_count, generated_files)

    csv_file_path = os.path.join(target_dir, "generated_files.csv")
    _write_generated_csv(csv_file_path, generated_files)

    print(f"I percorsi dei file generati sono stati salvati in {csv_file_path}")

if __name__ == "__main__":
    # Script entry point: balance the training-set class folders with a fixed seed.
    target_directory = "DatasetSpettrogrammi/Training/Target"
    main(target_dir=target_directory, seed=42)
