# This notebook takes in images from a specified directory and crops them to certain sizes based on perturbed GPS coordinates.
# The cropped images are then stored in designated directories for further analysis.

In [5]:
from Satellite_Imagery_Module import LocationPerturber
from Satellite_Imagery_Module import ImageClass
import sys
import os
import pandas as pd
import numpy as np

def add_tif_extension(master_dir):
    """
    Ensures every file below master_dir carries a '.tif' extension.

    Parameters:
    master_dir (str): The root directory that is walked recursively.

    Files whose name already ends with '.tif' are left untouched; every other file is
    renamed in place by appending '.tif' to its name. macOS '.DS_Store' metadata files
    are deleted instead of being renamed, so they are never mistaken for images later.
    Directories without files are skipped without error.
    """
    for root, _, files in os.walk(master_dir):
        for file in files:
            file_path = os.path.join(root, file)
            if file == '.DS_Store':
                # Finder metadata, not an image: drop it before it can be renamed to '.DS_Store.tif'.
                os.remove(file_path)
            elif not file.endswith('.tif'):
                os.rename(file_path, file_path + '.tif')

def crop_and_store_master_image(analysis_type, perturbation_magnitude, master_dir, gps_perturber, storage_dir_prefix, image_crop_sizes=[32, 64, 128, 256], master_dim=[928, 928, 3], visualize=False):
    """
    Crops and stores images from the master directory based on perturbed GPS coordinates.

    Parameters:
    analysis_type (str): The type of analysis being performed (e.g., 'cluster' or 'original').
    perturbation_magnitude (int): The magnitude of perturbation; only used to name the output directory.
    master_dir (str): The directory containing the original images to be cropped.
    gps_perturber (LocationPerturber): Object exposing `data` (original GPS rows) and `perturbed_data`
        (with 'horizontal_km' and 'vertical_km' columns).
    storage_dir_prefix (str): The prefix for the directory where cropped images will be stored.
    image_crop_sizes (list, optional): A list of sizes for cropping the images. Default is [32, 64, 128, 256].
        The default list is never mutated here.
    master_dim (list, optional): The dimensions of the master image; index 0 is combined with the vertical
        offset and index 1 with the horizontal offset. Default is [928, 928, 3].
    visualize (bool, optional): If True, visualizes the cropped images. Default is False.

    Raises:
    ValueError: If the length of original GPS data does not match the number of files in the master directory.
    ValueError: If there are fewer perturbed rows than files in the master directory.
    ValueError: If a perturbed centre pixel falls outside the master image (or is NaN).

    The i-th file of the alphabetically sorted master directory is paired with the i-th row (by position)
    of the perturbed data. Each crop is written to
    <storage_dir_prefix>/<crop size>/30/<analysis_type>/magnitude_<perturbation_magnitude>/<filename>.
    """
    add_tif_extension(master_dir)  # for safety
    center_row, center_col = master_dim[0] // 2, master_dim[1] // 2

    # Plain arrays make every later lookup positional, regardless of the index carried by the
    # perturbed data (a label-based lookup could pick a different row than the bounds check).
    perturbed_data = gps_perturber.perturbed_data
    horizontal_km = np.asarray(perturbed_data['horizontal_km'], dtype=float)
    vertical_km = np.asarray(perturbed_data['vertical_km'], dtype=float)

    # Convert distances to pixel values: km -> m, then divide by 30
    # (presumably the pixel size in metres, matching the '30' in the storage path -- TODO confirm).
    perturbed_horizontal_pixels = np.round(horizontal_km * 1000 / 30)
    perturbed_vertical_pixels = np.round(vertical_km * 1000 / 30)
    perturbed_rows = center_row + perturbed_vertical_pixels
    perturbed_cols = center_col + perturbed_horizontal_pixels

    # List the directory once so the length check and the iteration see the same files.
    filenames = sorted(os.listdir(master_dir))
    original_gps_length = len(gps_perturber.data)
    num_files_in_master_dir = len(filenames)

    # Check whether they are of the same length
    if original_gps_length != num_files_in_master_dir:
        raise ValueError(f"Mismatch between original GPS data length ({original_gps_length}) and number of files in master directory ({num_files_in_master_dir})")
    if len(perturbed_rows) < num_files_in_master_dir:
        raise ValueError(f"Perturbed GPS data has fewer rows ({len(perturbed_rows)}) than files in master directory ({num_files_in_master_dir})")

    for i, filename in enumerate(filenames):
        row, col = perturbed_rows[i], perturbed_cols[i]
        # Written as a negated chain so that NaN offsets are rejected as well.
        if not (0 <= row < master_dim[0] and 0 <= col < master_dim[1]):
            raise ValueError(f"Perturbed pixel coordinates are out of bounds: ({row}, {col})")
        if i % 100 == 0:
            print("At image ", i)
        # Fetch Image
        image_path = os.path.join(master_dir, filename)
        image = ImageClass(image_path, lon=None, lat=None, resolution=None, ImageCollection=None, crs=None, fileFormat=None, bands=None, vmin=None, vmax=None)
        for image_size in image_crop_sizes:
            storage_dir = os.path.join(storage_dir_prefix, str(image_size), "30", analysis_type, f"magnitude_{perturbation_magnitude}")
            os.makedirs(storage_dir, exist_ok=True)
            storage_path = os.path.join(storage_dir, filename)

            # Centre is passed as [horizontal, vertical], derived from master_dim rather than a fixed 928.
            perturbed_image_center = [col, row]
            image.crop_store_raster(storage_path, perturbed_image_center, image_size, image_size)
            # Re-opened after every crop even when not visualizing; kept because the constructor
            # may act as a read-back check of the stored file -- TODO confirm.
            image_stored = ImageClass(storage_path)
            if visualize:
                image_stored.visualize()

    print("Finished crop_and_store_master_image function")

def perturb_and_store_images(pertMags, image_crop_sizes, lat_col, lon_col, master_dir, storage_dir_prefix, original_gps_path, perturbed_gps_dir, analysis_type, perturbation_method, random_seed):
    """
    Perturbs GPS data, stores the perturbed coordinates, and crops the master images accordingly.

    Parameters:
    pertMags (list): A list of perturbation magnitudes to apply.
    image_crop_sizes (list): A list of sizes for cropping images.
    lat_col (str): The name of the column containing latitude information.
    lon_col (str): The name of the column containing longitude information.
    master_dir (str): The directory containing the master images.
    storage_dir_prefix (str): The prefix for the storage directory where perturbed images will be saved.
    original_gps_path (str): The file path to the original GPS data CSV.
    perturbed_gps_dir (str): The directory where perturbed GPS data will be stored (created if missing).
    analysis_type (str): The type of analysis being performed; the cluster branch is taken when it contains 'cluster'.
    perturbation_method (str): The method used for perturbing the GPS data.
    random_seed (int): Base seed; it is incremented by one before each magnitude, so the first
        magnitude uses random_seed + 1.

    For each magnitude the original GPS CSV is read, perturbed, and written to
    <perturbed_gps_dir>/magnitude_<magnitude>.csv. If the analysis type contains 'cluster', rows are
    grouped by the 'geo' column and the mean of lat_col / lon_col per group is perturbed instead; the
    per-group result is stored as magnitude_<magnitude>_per_village.csv and its 'horizontal_km' /
    'vertical_km' offsets are merged back onto every original row via 'geo'.
    Finally, crop_and_store_master_image is called to crop and store images based on the perturbed coordinates.
    """
    for i in pertMags:
        random_seed += 1
        # Perturbing and storing
        perturbation_parameters = {'lower': 3, 'upper': i} # This could be changed to a different range. For the simulation, we used 3-7
        perturbation_magnitude = i
        print(perturbation_magnitude)
        perturbed_gps_path = os.path.join(perturbed_gps_dir, f"magnitude_{perturbation_magnitude}.csv")
        print("Perturbed_gps_dir", perturbed_gps_dir)
        os.makedirs(perturbed_gps_dir, exist_ok=True)
        # Re-read per magnitude so every perturber starts from an untouched frame.
        original_gps_data = pd.read_csv(original_gps_path)
        if 'cluster' in analysis_type:
            # One row per 'geo' group, using the caller-supplied coordinate columns.
            gps_per_village = original_gps_data.groupby('geo').agg({
                lat_col: 'mean',
                lon_col: 'mean'
            }).reset_index()
            gps_perturber = LocationPerturber(gps_per_village, lat_col=lat_col, lon_col=lon_col, mode='km', random_seed=random_seed)
            gps_perturber.perturb_data(perturbation_method, perturbation_parameters)

            # Stores per village gps
            per_village_gps_path = os.path.join(perturbed_gps_dir, f"magnitude_{perturbation_magnitude}_per_village.csv")
            gps_perturber.store(per_village_gps_path)

            # Merge perturbed data with original data based on 'geo' column.
            # validate guards the row count: a duplicated 'geo' on the right side would silently
            # multiply rows and misalign them with the image files.
            merged_data = original_gps_data.merge(
                gps_perturber.perturbed_data[['geo', 'horizontal_km', 'vertical_km']],
                on='geo',
                how='left',
                validate='many_to_one'
            )
            # Swap the per-row frames into the perturber so both store() and the cropping step see them.
            gps_perturber.perturbed_data = merged_data
            gps_perturber.data = original_gps_data
            gps_perturber.store(perturbed_gps_path)
        else:
            gps_perturber = LocationPerturber(original_gps_data, lat_col=lat_col, lon_col=lon_col, mode='km', random_seed=random_seed)
            gps_perturber.perturb_data(perturbation_method, perturbation_parameters)
            gps_perturber.store(perturbed_gps_path)

        crop_and_store_master_image(analysis_type, perturbation_magnitude, master_dir, gps_perturber, storage_dir_prefix, image_crop_sizes, visualize=False)

def process_countries(countries, monte_i, perturbation_method, pertMags, image_crop_sizes):
    """
    Runs the perturb-and-crop pipeline once per country for a single Monte Carlo replicate.

    Parameters:
    countries (list): Country names; each one is used as a directory name and inside file names.
    monte_i (int): Index of the Monte Carlo replicate; selects the output directory and the base seed.
    perturbation_method (str): The method used for perturbing the GPS data.
    pertMags (list): A list of perturbation magnitudes to apply.
    image_crop_sizes (list): A list of sizes for cropping images.

    Every country is processed with analysis type 'original_<country>' and the same base seed,
    monte_i * 100000. All paths are built relative to the current working directory.
    """
    # Settings that do not depend on the country.
    master_dir_prefix = "./"
    random_seed = monte_i * 100000

    for country in countries:
        print(f"Random seed set to: {random_seed}")  # Debug statement
        analysis_type = f"original_{country}"

        # The directory structure could be different depending on the user's setup.
        country_paths = {
            "master_dir": f"{master_dir_prefix}/{country}/master/satellite_images/landsat/0.25/30/{analysis_type}/",
            "storage_dir_prefix": f"{master_dir_prefix}/{country}/monte_carlo_{monte_i}/satellite_images/landsat/{perturbation_method}",
            "original_gps_path": f'{master_dir_prefix}/{country}/master/gps_locations/{country}_cluster_center_coordinates_per_person_fully_subsetted.csv',
            "perturbed_gps_dir": f'{master_dir_prefix}/{country}/monte_carlo_{monte_i}/gps_locations/{perturbation_method}/{analysis_type}',
        }

        print(f"Starting original perturbation for country: {country}")  # Debug statement
        # Perturb the original (non-clustered) coordinates and crop the images.
        perturb_and_store_images(
            pertMags=pertMags,
            image_crop_sizes=image_crop_sizes,
            lat_col='average_lat',
            lon_col='average_lon',
            analysis_type=analysis_type,
            perturbation_method=perturbation_method,
            random_seed=random_seed,
            **country_paths,
        )

In [6]:
# Run configuration shared by every Monte Carlo replicate.
pertMags = [7]
image_crop_sizes = [32]
perturbation_method = "DHS"
countries = ["peru"]

# Seven replicates (0-6); the configuration is echoed before each one.
for monte_i in range(7):
    print("pertMags", pertMags)
    print("image_crop_sizes", image_crop_sizes)
    print("perturbation_method", perturbation_method)
    print("countries", countries)
    process_countries(countries, monte_i, perturbation_method, pertMags, image_crop_sizes)

pertMags [7]
image_crop_sizes [32]
perturbation_method DHS
countries ['peru']
Random seed set to: 0
Starting perturbation for country: peru
7
Perturbed_gps_dir /n/holylabs/LABS/meng_lab/Lab/IndividIm/Data/peru/monte_carlo_0/gps_locations/DHS/cluster_peru
IN km mode
At image  0
At image  100
At image  200
At image  300
At image  400
At image  500
At image  600
At image  700
At image  800
At image  900
At image  1000
At image  1100
At image  1200
At image  1300
At image  1400
At image  1500
Finished crop_and_store_master_image function
Starting original perturbation for country: peru
7
Perturbed_gps_dir /n/holylabs/LABS/meng_lab/Lab/IndividIm/Data/peru/monte_carlo_0/gps_locations/DHS/original_peru
IN km mode
At image  0
At image  100
At image  200
At image  300
At image  400
At image  500
At image  600
At image  700
At image  800
At image  900
At image  1000
At image  1100
At image  1200
At image  1300
At image  1400
At image  1500
Finished crop_and_store_master_image function
pertMags [