# In [None]:
# Import necessary libraries
import os

import xarray as xr
import rioxarray
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# ---------------------------------------------------------------------------
# Inputs / outputs -- replace the three placeholders below before running.
# ---------------------------------------------------------------------------
temperature_data_path = r"YOUR_TEMPERATURE_DATA_PATH.nc"  # dataset with a 'tas' variable
tree_density_path = r"YOUR_TREE_DENSITY_DATA_PATH.tif"    # raster opened via rioxarray
output_folder = r"YOUR_OUTPUT_FOLDER_PATH"                # where the monthly PNGs are written

# Open both sources: the temperature file with xarray, the tree raster with
# rioxarray (opened in this order, temperature first).
temperature_data, tree_density_data = (
    xr.open_dataset(temperature_data_path),
    rioxarray.open_rasterio(tree_density_path),
)

# Collapse the 'tas' time series to one mean value per calendar day.
daily_mean_temp = (
    temperature_data['tas']
    .resample(time='D')
    .mean()
)

# Sanity check: report the raw array shapes before any regridding.
print("Shape of daily mean temperature:", daily_mean_temp.shape)
print("Shape of tree count:", tree_density_data.shape)

# Drop length-1 dimensions (e.g. a single raster band), then sample the tree
# raster at the temperature grid's X/Y coordinates, taking the nearest raster
# pixel for each temperature cell. No reprojection is performed here.
# NOTE(review): this assumes both datasets use the same coordinate reference
# system and that the temperature coords are named 'X'/'Y' -- confirm.
tree_count_static = tree_density_data.squeeze()
tree_count_interp = tree_count_static.interp(
    x=daily_mean_temp['X'],
    y=daily_mean_temp['Y'],
    method='nearest',
)

# Keep strictly positive counts. Everything else becomes 0: NaN (e.g. where
# the temperature grid extends past the raster) and zero/negative values
# (possibly nodata sentinels -- TODO confirm the raster's nodata value).
tree_count_interp = tree_count_interp.where(tree_count_interp > 0, other=0)

# Aggregate temperature by tree count: one record per (day, positive tree-count
# value) holding the mean daily temperature over all grid cells with that
# count. Records are collected in a plain list and turned into a DataFrame
# afterwards.
results = []

# The tree-count grid does not depend on time, so everything derived from it
# is computed once here instead of once per day.
tree_count_grid = tree_count_interp.values
# cell_group[i] is the index into unique_tree_counts of (flattened) cell i.
# .ravel() because NumPy >= 2 shapes the inverse like the input array.
unique_tree_counts, cell_group = np.unique(tree_count_grid, return_inverse=True)
cell_group = cell_group.ravel()
cells_per_group = np.bincount(cell_group, minlength=unique_tree_counts.size)
# Non-positive (and NaN) tree counts are not reported.
keep = unique_tree_counts > 0
positive_tree_counts = unique_tree_counts[keep]

# Order the temperature dims as (time, <tree-grid dims>) so that each day's
# values flatten in exactly the same cell order as tree_count_grid. This
# raises if the dimension names of the two grids do not match.
temp_aligned = daily_mean_temp.transpose('time', *tree_count_interp.dims)

for day_index, date in enumerate(temp_aligned.time.values):
    date_str = pd.to_datetime(date).strftime('%Y-%m-%d')
    temp_for_day = temp_aligned.isel(time=day_index).values.ravel()

    # Per-group sum / per-group cell count = per-group mean, in a single
    # O(cells) pass instead of one full-grid boolean mask per distinct count.
    # A NaN temperature inside a group makes that group's mean NaN, as
    # np.mean over the masked cells would.
    group_sums = np.bincount(
        cell_group, weights=temp_for_day, minlength=unique_tree_counts.size
    )
    group_means = (group_sums / cells_per_group)[keep]

    for tree_count, mean_temp in zip(positive_tree_counts, group_means):
        results.append({'date': date_str, 'tree_count': tree_count, 'mean_temp': mean_temp})

# Convert results to a DataFrame. Naming the columns explicitly keeps
# 'date' / 'tree_count' / 'mean_temp' present even when `results` is empty
# (e.g. no cell has a positive tree count); without it the column accesses
# below raise KeyError on an empty result.
results_df = pd.DataFrame(results, columns=['date', 'tree_count', 'mean_temp'])

# Parse the date strings back to datetimes and derive the calendar month
# (1-12) that the plots are split by.
results_df['date'] = pd.to_datetime(results_df['date'])
results_df['month'] = results_df['date'].dt.month

# Set the style for better-looking plots
sns.set(style="whitegrid")

# Months that actually have data, in calendar order.
valid_months = np.sort(results_df['month'].unique())

# Lay the monthly panels out three wide with just enough rows: ceil(n / 3),
# but at least one so plt.subplots always receives a valid shape.
ncols = 3
nrows = max(1, (len(valid_months) + ncols - 1) // ncols)
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(18, 6 * nrows))

# savefig does not create missing directories.
os.makedirs(output_folder, exist_ok=True)


def _plot_month(ax, month_data, month):
    """Draw the tree-count vs. mean-temperature scatter for one month on *ax*."""
    ax.scatter(month_data['tree_count'], month_data['mean_temp'], s=20, color='blue')
    ax.set_title(f"Month {month}", fontsize=14)
    ax.set_xlabel('Number of Trees', fontsize=12)
    # NOTE(review): the °C label assumes 'tas' is already in Celsius; 'tas'
    # is often stored in kelvin -- confirm against the input file's units.
    ax.set_ylabel('Mean Temperature (°C)', fontsize=12)


for month, ax in zip(valid_months, axes.flat):
    month_data = results_df[results_df['month'] == month]
    _plot_month(ax, month_data, month)

    # Write each month to its own PNG by redrawing it on a standalone figure.
    # Calling savefig here on the shared figure would save the whole grid,
    # partially drawn, rather than just this subplot. os.path.join replaces a
    # hard-coded backslash so the path is valid on every OS.
    month_fig, month_ax = plt.subplots(figsize=(6, 6))
    _plot_month(month_ax, month_data, month)
    month_fig.tight_layout()
    month_fig.savefig(os.path.join(output_folder, f"Month{month}.png"))
    plt.close(month_fig)

# Hide the empty panels left over when the month count isn't a multiple of ncols.
for ax in axes.flat[len(valid_months):]:
    ax.set_visible(False)

fig.tight_layout()
plt.show()