# In [None]:  (Jupyter notebook cell prompt left over from export)
import xarray as xr
import numpy as np
from scipy.stats import wilcoxon
import netCDF4 as nc
import pandas as pd


# Define parameters for analysis
# These module-level settings are read by every later section of the script.
start_year = 2014  # first calendar year of the analysis window (inclusive)
end_year = 2015  # last calendar year of the analysis window (inclusive)
longhurst_region_code = 39  # value of the 'longhurst' variable that selects the region
region_name = "PNEC"  # label used in output file names and printed messages
combined_cat_values = [3, 4]  # Magnitude of heatwave categories to consider
num_samples = 100  # number of random resampling iterations for the Wilcoxon tests
consecutive_months_threshold = 3  # run length (in months) used by the consecutive-month scan

# Section 1: Load and Preprocess Data
# Load the marine-heatwave atlas (variables 'cat' and 'mhw' are used below).
# Time decoding is disabled, so the time axis is addressed by integer
# position rather than by date.
dataset = xr.open_dataset('/Users/sayooj/Downloads/GlobalAtlas_MHW_ESACCISST_1deg_1982-2021.nc', decode_times=False)

# The positional slice assumes one time step per day with index 0 being
# 1 January of the first data year.
# NOTE(review): the first year is taken from the file name - confirm against
# the file's own time variable.
data_first_year = 1982
num_time_steps = dataset.sizes['time']

if num_time_steps % 365 == 0:
    # Axis length is a whole number of 365-day years, i.e. leap days are not
    # stored: plain 365-day arithmetic is exact.
    start_idx = (start_year - data_first_year) * 365
    end_idx = start_idx + (end_year - start_year + 1) * 365 - 1
else:
    # Real-calendar daily axis: count actual days so that leap days between
    # the first data year and start_year do not shift the window (multiplying
    # by 365 would start the 2014 window 8 days early).
    first_day = pd.Timestamp(year=data_first_year, month=1, day=1)
    start_idx = (pd.Timestamp(year=start_year, month=1, day=1) - first_day).days
    end_idx = (pd.Timestamp(year=end_year + 1, month=1, day=1) - first_day).days - 1

# Create a new dataset with data only for the specified time range
# (end_idx is inclusive, hence the + 1 for the half-open slice)
new_dataset = dataset.isel(time=slice(start_idx, end_idx + 1))

# Store the category and heatwave variables as float32 so that NaN can be
# used as the "outside the region" marker when the data are masked later
new_dataset['cat'] = new_dataset['cat'].astype('float32')
new_dataset['mhw'] = new_dataset['mhw'].astype('float32')

# Save the time-sliced dataset to a new netCDF file.
# `dataset` is intentionally left open: `new_dataset` is derived from it and
# is used again when the regional mask is applied.
new_dataset.to_netcdf(f'/Users/sayooj/Downloads/{region_name}_{start_year}_{end_year}.nc')

# Section 2: Mask Based on Longhurst Regions
# Read the Longhurst province map. The array is pulled fully into memory with
# `.values`, so the file can be closed as soon as it has been read (the
# original left this handle open for the rest of the run).
longhurst_file = '/Users/sayooj/Downloads/Longhurst_1_deg.nc'
with xr.open_dataset(longhurst_file) as longhurst_dataset:
    longhurst = longhurst_dataset['longhurst'].values

# Boolean mask that is True where the grid cell belongs to the selected
# province. It is transposed so that it lines up with the heatwave grid.
# NOTE(review): presumably 'longhurst' is stored as (lon, lat) and the
# heatwave variables as (..., lat, lon) - confirm.
mask = np.isin(longhurst, [longhurst_region_code]).T

# Apply the mask to the entire time range: cells where the mask is False
# become NaN in every variable
masked_dataset = new_dataset.where(mask)

# Save the masked data to a new netCDF file
masked_file_path = f'/Users/sayooj/Downloads/masked_{region_name}_{start_year}_{end_year}.nc'
masked_dataset.to_netcdf(masked_file_path)

# Section 3: Create Monthly Masks with Values Only Inside Longhurst Region
# Load the netCDF file containing the masked data
masked_nc_file = xr.open_dataset(masked_file_path, decode_times=False)

# Extract the masked cat variable and re-apply the Longhurst mask
masked_cat = masked_nc_file['cat'].where(mask)

# Assign every daily time step to a calendar month. Splitting the axis into
# fixed 30-day blocks (as done previously) drifts away from the real months,
# silently drops the trailing days of the window and produces the wrong
# number of months for longer windows (e.g. 3650 days -> 121 blocks).
# NOTE(review): assumes the first time step is 1 January of start_year and
# that there is one step per day - confirm against the input file.
num_days = len(masked_nc_file['time'])
day_dates = pd.date_range(start=f'{start_year}-01-01', periods=num_days, freq='D')
month_of_day = np.asarray((day_dates.year - start_year) * 12 + (day_dates.month - 1))

# Calculate the number of months (a partial final month is kept)
num_months = int(month_of_day[-1]) + 1 if num_days else 0

# month_starts[m] is the first day index of month m; the extra last entry is
# the end of the axis, so month m spans month_starts[m]:month_starts[m + 1]
month_starts = np.searchsorted(month_of_day, np.arange(num_months + 1))

# NaN-filled array that receives one (lat, lon) field per month
monthly_masks = np.full((num_months, len(masked_nc_file['lat']), len(masked_nc_file['lon'])), np.nan)

# Iterate over each month
for month in range(num_months):
    # Day-index range of the current calendar month
    start_idx = int(month_starts[month])
    end_idx = int(month_starts[month + 1])

    # Extract the masked daily cat values for the current month
    month_data = masked_cat[start_idx:end_idx]

    # Highest category reached at each lat-lon point during the month
    max_values = np.max(month_data, axis=0)

    # Keep the value inside the Longhurst region, NaN everywhere else
    monthly_masks[month] = np.where(mask, max_values, np.nan)

# Assemble the output dataset. The time values are 1-based month counters
# (1 = first month of start_year); later steps read them back undecoded.
output_file = xr.Dataset(
    data_vars={
        'lat': ('lat', masked_nc_file['lat'].values),
        'lon': ('lon', masked_nc_file['lon'].values),
        'time': ('time', np.arange(1, num_months + 1)),
        'monthly_masks': (['time', 'lat', 'lon'], monthly_masks)
    }
)

# Add attributes
output_file['lat'].attrs['units'] = 'degrees_north'
output_file['lon'].attrs['units'] = 'degrees_east'
output_file['time'].attrs['units'] = f'months since {start_year}-01-01'
output_file['monthly_masks'].attrs['units'] = '1'
output_file.attrs['description'] = f'Monthly masks for marine heatwaves in {region_name}'

# Save the monthly masks to a new netCDF file
output_file.to_netcdf(f'/Users/sayooj/Downloads/monthly_masks_{region_name}_{start_year}_{end_year}.nc')

# Close the netCDF files
masked_nc_file.close()

# Section 4: Create Consecutive Monthly Mask for the Selected Categories
# Load the netCDF file containing the monthly masks
monthly_masks_file = xr.open_dataset(f'/Users/sayooj/Downloads/monthly_masks_{region_name}_{start_year}_{end_year}.nc', decode_times=False)

# Pull everything that is needed into memory up front so that nothing is read
# from the file after it has been closed
monthly_masks_data = monthly_masks_file['monthly_masks'].values
lat_values = monthly_masks_file['lat'].values
lon_values = monthly_masks_file['lon'].values
month_numbers = monthly_masks_file['time'].values

# True where the monthly maximum category is one of the selected categories
# (NaN cells compare False)
is_heatwave_month = np.isin(monthly_masks_data, combined_cat_values)

# Measure, for every cell and month, the length of the uninterrupted run of
# heatwave months it belongs to. Two passes along the time axis are enough:
# the forward pass gives the position inside the run (1, 2, 3, ...), the
# backward pass gives the number of run months still to come (including the
# current one); their sum minus one is the total run length.
num_steps = is_heatwave_month.shape[0]
position_in_run = np.zeros(is_heatwave_month.shape, dtype=int)
remaining_in_run = np.zeros(is_heatwave_month.shape, dtype=int)

running = np.zeros(is_heatwave_month.shape[1:], dtype=int)
for step in range(num_steps):
    running = np.where(is_heatwave_month[step], running + 1, 0)
    position_in_run[step] = running

running = np.zeros(is_heatwave_month.shape[1:], dtype=int)
for step in range(num_steps - 1, -1, -1):
    running = np.where(is_heatwave_month[step], running + 1, 0)
    remaining_in_run[step] = running

run_length = position_in_run + remaining_in_run - 1

# Keep a month only if it is part of a run of at least
# consecutive_months_threshold months. The earlier per-cell scan kept isolated
# heatwave months that never formed such a run, cut every qualifying run off
# at exactly the threshold length and ignored any later run.
in_qualifying_run = is_heatwave_month & (run_length >= consecutive_months_threshold)
consecutive_monthly_mask = np.where(in_qualifying_run, monthly_masks_data, 0.0)

# Apply the Longhurst mask: cells outside the region become NaN, cells inside
# keep their category (or 0 when they are not part of a qualifying run)
consecutive_monthly_mask = np.where(mask, consecutive_monthly_mask, np.nan)

# Create a new netCDF file to save the consecutive monthly mask
consecutive_monthly_mask_file = xr.Dataset(
    data_vars={
        'lat': ('lat', lat_values),
        'lon': ('lon', lon_values),
        'time': ('time', month_numbers),
        'consecutive_monthly_mask': (['time', 'lat', 'lon'], consecutive_monthly_mask)
    }
)

# Add attributes
consecutive_monthly_mask_file['lat'].attrs['units'] = 'degrees_north'
consecutive_monthly_mask_file['lon'].attrs['units'] = 'degrees_east'
consecutive_monthly_mask_file['time'].attrs['units'] = f'months since {start_year}-01-01'
consecutive_monthly_mask_file['consecutive_monthly_mask'].attrs['units'] = '1'
consecutive_monthly_mask_file.attrs['description'] = f'Consecutive monthly mask for values 3 or 4 in {region_name}'

# Save the consecutive monthly mask to a new netCDF file
consecutive_monthly_mask_file.to_netcdf(f'/Users/sayooj/Downloads/consecutive_monthly_mask_{region_name}_{start_year}_{end_year}.nc')

# Close the netCDF files
monthly_masks_file.close()
consecutive_monthly_mask_file.close()

# Build one record per grid cell that has at least one qualifying run
consecutive_heatwave_info = []
start_date = pd.to_datetime(f'{start_year}-01-01')

# NaN (outside the region) and 0 (no qualifying run) both compare False here
kept_months = np.isin(consecutive_monthly_mask, combined_cat_values)

# argwhere walks the grid in the same lat-major order as a nested lat/lon loop
for lat_idx, lon_idx in np.argwhere(kept_months.any(axis=0)):
    # 0-based positions along the time axis of the retained months
    heatwave_indices = np.flatnonzero(kept_months[:, lat_idx, lon_idx])

    # Month counters as stored in the mask file
    heatwave_dates = month_numbers[heatwave_indices]

    # First day of each retained month. The 0-based position is used as the
    # month offset so that the first month maps to <start_year>-01-01;
    # offsetting by the stored 1-based counter shifted every date forward by
    # one month.
    exact_dates = [(start_date + pd.DateOffset(months=int(offset))).strftime('%Y-%m-%d') for offset in heatwave_indices]

    # Append the lat, lon, months, and exact_dates to the list
    consecutive_heatwave_info.append({
        'lat': lat_values[lat_idx],
        'lon': lon_values[lon_idx],
        'months': heatwave_dates.tolist(),
        'exact_dates': exact_dates
    })

# Create a DataFrame with the extracted information; naming the columns keeps
# the CSV header intact even when no cell qualifies
consecutive_heatwave_info_df = pd.DataFrame(consecutive_heatwave_info, columns=['lat', 'lon', 'months', 'exact_dates'])

# Save the DataFrame to a CSV file with region name, start year, and end year in the filename
csv_file_path = f'/Users/sayooj/Downloads/consecutive_heatwave_info_{region_name}_{start_year}_{end_year}.csv'
consecutive_heatwave_info_df.to_csv(csv_file_path, index=False)


# Slice the chlorophyll anomaly dataset to the analysis years and store the
# result as a new NetCDF file in the current working directory.
# The home directory is spelled in lower case like every other path in this
# script, so the script also runs on a case-sensitive file system.
file_path_chlorophyll = '/Users/sayooj/Downloads/chlorophyll_anomaly_dataset.nc'

# Define the time range you want to slice for chlorophyll data
start_date_chlorophyll = f'{start_year}-01-01'
end_date_chlorophyll = f'{end_year}-12-31'

# Name of the sliced output file
output_file_path_chlorophyll = f'sliced_OC-CCI_chlor_a_{region_name}_{start_year}_{end_year}.nc'

# The context manager closes the source file even if slicing or writing fails
with xr.open_dataset(file_path_chlorophyll) as ds_chlorophyll:
    # Label-based selection on the time coordinate; both end points are included
    ds_chlorophyll_sliced = ds_chlorophyll.sel(time=slice(start_date_chlorophyll, end_date_chlorophyll))
    ds_chlorophyll_sliced.to_netcdf(output_file_path_chlorophyll)

print(f'Sliced chlorophyll dataset saved to {output_file_path_chlorophyll}')

# Slice the wind speed/direction anomaly dataset to the analysis years and
# store the result as a new NetCDF file in the current working directory.
# The home directory is spelled in lower case like every other path in this
# script, so the script also runs on a case-sensitive file system.
file_path_wind = '/Users/sayooj/Downloads/wind_anomaly_dataset.nc'

# Calculate the start and end dates based on start_year and end_year
start_date_wind = f'{start_year}-01-01'
end_date_wind = f'{end_year}-12-31'

# Name of the sliced output file
output_file_path_wind = f'sliced_OC-CCI_CCMP_v3.0_wind_{region_name}_{start_year}_{end_year}.nc'

# The context manager closes the source file even if slicing or writing fails
with xr.open_dataset(file_path_wind) as ds_wind:
    # Label-based selection on the time coordinate; both end points are included
    ds_wind_sliced = ds_wind.sel(time=slice(start_date_wind, end_date_wind))
    ds_wind_sliced.to_netcdf(output_file_path_wind)

print(f'Sliced wind dataset saved to {output_file_path_wind}')

# Re-open the sliced chlorophyll and wind files and the consecutive-month
# heatwave mask, mask the three fields with it, and reduce each field to one
# spatial median per time step. The three file handles stay open; they are
# closed at the end of the script.
chlorophyll_path = f'sliced_OC-CCI_chlor_a_{region_name}_{start_year}_{end_year}.nc'
chlorophyll_file = nc.Dataset(chlorophyll_path)
chlorophyll = chlorophyll_file.variables['OC-CCI_chlor_a'][:]

wind_path = f'sliced_OC-CCI_CCMP_v3.0_wind_{region_name}_{start_year}_{end_year}.nc'
wind_file = nc.Dataset(wind_path)
wind_speed = wind_file.variables['CCMP_w'][:]
wind_direction = wind_file.variables['CCMP_wind_dir'][:]

mask_path = f'/Users/sayooj/Downloads/consecutive_monthly_mask_{region_name}_{start_year}_{end_year}.nc'
mask_file = nc.Dataset(mask_path)
mask_region = mask_file.variables['consecutive_monthly_mask'][:]

# Hide every cell whose mask value is falsy (0), keep the others.
# NOTE(review): the mask file is written with NaN outside the Longhurst
# region, and np.logical_not(NaN) is False, so cells outside the region
# appear to stay unmasked here and would enter the medians below - confirm
# whether that is intended.
# NOTE(review): the fields and the mask must share one (time, lat, lon)
# shape for this to work - confirm for the chlorophyll and wind inputs.
chlorophyll_masked_year = np.ma.masked_array(chlorophyll, mask=np.logical_not(mask_region))
wind_speed_masked_year = np.ma.masked_array(wind_speed, mask=np.logical_not(mask_region))
wind_direction_masked_year = np.ma.masked_array(wind_direction, mask=np.logical_not(mask_region))

# One median per time step, taken over the two spatial axes
spatial_axes = (1, 2)
chlorophyll_median_region_year = np.ma.median(chlorophyll_masked_year, axis=spatial_axes)
wind_speed_median_region_year = np.ma.median(wind_speed_masked_year, axis=spatial_axes)
wind_direction_median_region_year = np.ma.median(wind_direction_masked_year, axis=spatial_axes)

# Time indices of mask cells that are / are not in the selected categories.
# NOTE(review): the mask is three-dimensional, so taking element [0] yields
# the time index of every matching cell; a time step appears once per
# matching cell rather than once overall - confirm that this weighting is
# wanted.
in_selected_categories = np.isin(mask_region, combined_cat_values)
indices_heatwave_region_year = np.nonzero(in_selected_categories)[0]
indices_non_heatwave_region_year = np.nonzero(~in_selected_categories)[0]

# Resampling comparison of heatwave versus non-heatwave time steps.
# In every iteration one set of heatwave indices and one equally sized set of
# non-heatwave indices are drawn with replacement; the same two draws are
# used for chlorophyll, wind speed and wind direction.
# NOTE(review): scipy's wilcoxon is the signed-rank test for paired samples,
# while the two draws are independent - confirm that this test is intended.
region_series = {
    'chlorophyll': chlorophyll_median_region_year,
    'wind_speed': wind_speed_median_region_year,
    'wind_direction': wind_direction_median_region_year,
}

# Per-variable accumulators, one entry appended per iteration
p_value_draws = {name: [] for name in region_series}
median_diff_draws = {name: [] for name in region_series}
std_dev_draws = {name: [] for name in region_series}

draw_size = len(indices_heatwave_region_year)

for _ in range(num_samples):
    # Heatwave draw first, non-heatwave draw second
    heatwave_draw = np.random.choice(indices_heatwave_region_year, draw_size, replace=True)
    background_draw = np.random.choice(indices_non_heatwave_region_year, draw_size, replace=True)

    for name, series in region_series.items():
        heatwave_sample = series[heatwave_draw]
        background_sample = series[background_draw]

        # Only the p-value of the test result is kept
        _statistic, p_value = wilcoxon(heatwave_sample, background_sample)
        p_value_draws[name].append(p_value)

        # Element-wise difference between the two draws
        sample_difference = heatwave_sample - background_sample
        median_diff_draws[name].append(np.median(sample_difference))
        std_dev_draws[name].append(np.std(sample_difference))

# Expose the accumulators under the per-variable names
p_values_chlorophyll_region_year = p_value_draws['chlorophyll']
p_values_wind_speed_region_year = p_value_draws['wind_speed']
p_values_wind_direction_region_year = p_value_draws['wind_direction']

median_diff_chlorophyll_region_year = median_diff_draws['chlorophyll']
median_diff_wind_speed_region_year = median_diff_draws['wind_speed']
median_diff_wind_direction_region_year = median_diff_draws['wind_direction']

std_dev_chlorophyll_region_year = std_dev_draws['chlorophyll']
std_dev_wind_speed_region_year = std_dev_draws['wind_speed']
std_dev_wind_direction_region_year = std_dev_draws['wind_direction']

# Summarise the iterations by their median
median_p_value_chlorophyll_region_year = np.median(p_values_chlorophyll_region_year)
median_p_value_wind_speed_region_year = np.median(p_values_wind_speed_region_year)
median_p_value_wind_direction_region_year = np.median(p_values_wind_direction_region_year)

median_median_diff_chlorophyll_region_year = np.median(median_diff_chlorophyll_region_year)
median_median_diff_wind_speed_region_year = np.median(median_diff_wind_speed_region_year)
median_median_diff_wind_direction_region_year = np.median(median_diff_wind_direction_region_year)

median_std_dev_chlorophyll_region_year = np.median(std_dev_chlorophyll_region_year)
median_std_dev_wind_speed_region_year = np.median(std_dev_wind_speed_region_year)
median_std_dev_wind_direction_region_year = np.median(std_dev_wind_direction_region_year)

# Coefficient of variation: median standard deviation of the differences
# divided by the median of the full per-time-step series
cv_chlorophyll_region_year = median_std_dev_chlorophyll_region_year / np.median(region_series['chlorophyll'])
cv_wind_speed_region_year = median_std_dev_wind_speed_region_year / np.median(region_series['wind_speed'])
cv_wind_direction_region_year = median_std_dev_wind_direction_region_year / np.median(region_series['wind_direction'])

# Function to check if the median difference passes the filter
def apply_filter(median_diff, precision):
    """Return True when the size of *median_diff* reaches *precision*.

    The comparison uses the absolute value, so a decrease (negative
    difference) passes the filter just like an increase of the same size;
    comparing the signed value rejected every negative difference no matter
    how large it was.

    Args:
        median_diff: Median difference to check (a number or NumPy scalar).
        precision: Smallest difference regarded as resolvable, in the same
            units as ``median_diff``.

    Returns:
        Truthy if ``abs(median_diff) >= precision``, falsy otherwise
        (including when ``median_diff`` is NaN).
    """
    return abs(median_diff) >= precision

# Report the resampling results and release the netCDF handles.

# Smallest chlorophyll difference that is reported
precision_chlorophyll = 0.01  # mg m^-3

# Time steps whose regional chlorophyll median reaches the quality threshold.
# NOTE(review): the filtered series is computed here but not used by any of
# the printed results below.
chlorophyll_quality_filter = 0.01  # Threshold for chlorophyll-a in mg m^-3
chlorophyll_filtered_indices = np.where(chlorophyll_median_region_year >= chlorophyll_quality_filter)[0]
chlorophyll_median_region_year_filtered = chlorophyll_median_region_year[chlorophyll_filtered_indices]

# Chlorophyll: details are printed only when the median difference passes
chlorophyll_passes = apply_filter(median_median_diff_chlorophyll_region_year, precision_chlorophyll)
if not chlorophyll_passes:
    print("\nResults for chlorophyll in {} in {}/{} do not pass the quality filter.".format(region_name, start_year, end_year))
else:
    print("\nResults for chlorophyll in {} in {}/{} pass the quality filter:".format(region_name, start_year, end_year))
    for label, value in (
        ("Median p-value:", median_p_value_chlorophyll_region_year),
        ("Median median difference:", median_median_diff_chlorophyll_region_year),
        ("Median standard deviation:", median_std_dev_chlorophyll_region_year),
    ):
        print(label, value)

# Wind speed: always printed
print(f"Results for wind speed in {region_name} in {start_year}/{end_year}:")
for label, value in (
    ("Median p-value:", median_p_value_wind_speed_region_year),
    ("Median median difference:", median_median_diff_wind_speed_region_year),
    ("Median standard deviation:", median_std_dev_wind_speed_region_year),
    ("Coefficient of Variation:", cv_wind_speed_region_year),
):
    print(label, value)

# Wind direction: always printed
print(f"Results for wind direction in {region_name} in {start_year}/{end_year}:")
for label, value in (
    ("Median p-value:", median_p_value_wind_direction_region_year),
    ("Median median difference:", median_median_diff_wind_direction_region_year),
    ("Median standard deviation:", median_std_dev_wind_direction_region_year),
    ("Coefficient of Variation:", cv_wind_direction_region_year),
):
    print(label, value)

# Close the netCDF files
for open_file in (chlorophyll_file, wind_file, mask_file):
    open_file.close()