# Imports

In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
import time

# Selecting New Reference Data With Clustering (K-Means Centroid, k Nearest Observations)

In [2]:
def select_reference_data(state, species, year, num_trees=500, k=1):
    """
    Selects weekly reference data from the citizen observations of one species in one year.

    For every week present in the data, the centroid of that week's observations is found
    (KMeans with a single cluster), the k observations closest to that centroid are picked,
    and the column-wise median of those k observations becomes the reference row for the week.

    Args:
        state (string): State name; observations are read from "all data/citizen/{state}.csv"
        species (string): Value of the Species_name column to keep
        year (int or string): Year to keep (converted with int())
        num_trees (int): Unused; kept so that existing callers keep working
        k (int): Number of observations closest to the centroid that the median is taken over
    Returns:
        DataFrame: One row per week, sorted by week. The first column is "Week", followed by
        the median of every column that was used for clustering and of "Distance_to_Centroid".
        When no observation matches the species and year, an empty DataFrame with those same
        columns is returned.
    """

    df = pd.read_csv(f"all data/citizen/{state}.csv")
    df = df[df["Species_name"] == species]
    df = df[df["Year"] == int(year)]
    df = df.drop(["Date_of_observation", "Observation_ID", "User_id", "User_Tree_id", "Species_id", "Species_name", "Year"], axis=1) # Only keep week, phenophases, and coordinates
    df = df.reset_index(drop=True)

    weeks = df["Week"].sort_values().unique()

    if df.empty:
        # Without this guard the result would only contain a "Week" column, and looking up
        # any phenophase column on it would raise a KeyError.
        feature_columns = [col for col in df.columns if col != "Week"]
        return pd.DataFrame(columns=["Week", *feature_columns, "Distance_to_Centroid"])

    selected_ref_dict = {} # Saves the median of the closest citizen observations to the centroid, keyed by week

    for week in weeks:

        # Only use data from the given week; the week column itself is not used for clustering
        week_df = df[df["Week"] == week].drop("Week", axis=1).reset_index(drop=True)

        # Clustering: a single cluster, so its centroid summarizes all observations of the week
        km = KMeans(n_clusters=1, random_state=42, n_init="auto")
        km.fit(week_df)
        centroid = km.cluster_centers_[0].astype(float)

        # Euclidean distance from every observation to the centroid, computed for all rows at once
        features = week_df.to_numpy(dtype=float)
        week_df["Distance_to_Centroid"] = np.linalg.norm(features - centroid, axis=1)

        # Stable sort so that observations at equal distance are picked in their original order
        closest_k_labels = week_df["Distance_to_Centroid"].sort_values(kind="mergesort").index[:k]

        # Median over the k selected observations is the selected reference data for the week
        selected_ref_dict[week] = week_df.loc[closest_k_labels].median()

    # Reformatting new reference data dictionary to be returned as a dataframe matching style of reference data
    selected_ref_df = pd.DataFrame(selected_ref_dict).T.sort_index() # Dictionary -> DataFrame, one row per week
    selected_ref_df.insert(0, "Week", selected_ref_df.index) # Add the week column back

    return selected_ref_df

# Plotting Selected Reference vs. Citizen Data (Percentages)

In [3]:
# Helper Function for create_selected_ref_plots
def create_percentage_plots(state, species, year, phenophase):
    """
    Computes, per week, the fraction of citizen observations reporting a phenophase.

    Args:
        state (string): State name; observations are read from "all data/citizen/{state}.csv"
        species (string): Value of the Species_name column to keep
        year (int or string): Year to keep (converted with int())
        phenophase (string): Name of the phenophase column to evaluate
    Returns:
        dict: Maps each week (in ascending order) to the share of that week's observations
        whose phenophase value is at least 1
    """
    unused_columns = ["Lat", "Long", "Date_of_observation", "Observation_ID", "User_id",
                      "User_Tree_id", "Species_id", "Species_name", "Year"]

    observations = pd.read_csv(f"all data/citizen/{state}.csv")
    matches = (observations["Species_name"] == species) & (observations["Year"] == int(year))
    observations = observations[matches].drop(unused_columns, axis=1).reset_index(drop=True)

    def fraction_present(week):
        # Share of the week's observations in which the phenophase was recorded as >= 1
        in_week = observations[observations["Week"] == week]
        observed = int((in_week[phenophase] >= 1).sum())
        return observed / len(in_week)

    ordered_weeks = observations["Week"].sort_values().unique()
    return {week: fraction_present(week) for week in ordered_weeks}

def create_selected_ref_plots(plot_path, states, n_species, years, k):
    """
    Saves plots comparing selected reference data with citizen observations

    One figure is written per state, species, year and phenophase. Each figure shows the
    selected reference value per week on the left axis and the fraction of citizen
    observations reporting the phenophase on the right axis. A species/year combination
    for which no reference data could be selected is skipped with a message.

    Args:
        plot_path (string): OS Path for where all plots will be saved
        states (List(string)): List of state names to create plots for
        n_species (int): Top n most prevalent species in the given state
        years (List(int)): List of years to create plots for
        k (int): Number of closest points to centroid considered for selecting reference data
    Returns:
        None
    """

    # Phenophase columns to plot; the same for every state, species and year
    phenophases = ['Leaves_fresh','Leaves_mature','Leaves_old','Flowers_bud','Flowers_open',
                   'Flowers_male','Flowers_Female','Fruits_unripe','Fruits_ripe','Fruits_open']

    for state in states:
        species_in_state = pd.read_csv(f"all data/citizen/{state}.csv")['Species_name'].value_counts().index[:n_species] # Top n most prevalent species in the given state
        state_start_time = time.time()
        for species in species_in_state:
            species_start_time = time.time()

            # Directory name is the species name lowercased with spaces and periods removed
            state_species_plot_path = f"{plot_path}/{state}/{species.replace(' ', '').replace('.', '').lower()}"

            for year in years:
                year_start_time = time.time()

                species_ref_df = select_reference_data(state, species, year, k=k) # Select reference data for the given state, species and year

                if species_ref_df.empty:
                    # A species that is common overall may have no observations in a given year;
                    # there is nothing to plot, and indexing the phenophase columns could fail.
                    print(f"No observations for {species} in {state} in {year}; skipping")
                    continue

                os.makedirs(state_species_plot_path, exist_ok=True)
                ref_x = species_ref_df['Week']

                for phenophase in phenophases:

                    ref_y = species_ref_df[phenophase]
                    # NOTE: this helper reads the state's CSV again on every call
                    cit_pcts = create_percentage_plots(state, species, year, phenophase)
                    cit_x = list(cit_pcts.keys())
                    cit_y = list(cit_pcts.values())

                    fig, ax = plt.subplots()
                    try:
                        ax.plot(ref_x, ref_y, label='Selected citizen observations (selected reference data)', color='orange')
                        ax.set_ylim(-0.01, 2.02)
                        ax.set_title(f'{phenophase} Reference Data vs. Percentage Observing Phenophase for {species} in {state} in {year}', fontsize=8)
                        ax.set(xlabel='Week of the year (0-47)', ylabel='Phenophase Value')

                        # Second y-axis so the fraction (0 to 1) is readable next to the phenophase value
                        twin = ax.twinx()
                        twin.plot(cit_x, cit_y, label='Percentage Observing Phenophase')
                        twin.set_ylim(-0.01, 1.01)
                        twin.set(ylabel='Percent')

                        ax.legend()
                        twin.legend(loc='lower left')
                        fig.savefig(f"{state_species_plot_path}/{phenophase}_{year}")
                    finally:
                        plt.close(fig) # Release the figure even if plotting or saving failed
                print(f"{year} in {species} in {state} finished in {time.time()-year_start_time} seconds")
            print(f"{species} in {state} finished in {time.time()-species_start_time} seconds")
        print(f"{state} finished in {time.time()-state_start_time} seconds")

In [4]:
# Run configuration for the plotting cell
plot_path = "plots/selected_ref_vs_cit"  # Directory under which all plots are stored

# To run every state instead, use:
# [state.replace('.csv','') for state in os.listdir("all data/citizen")]
states = ['kerala']

n_species = 3  # Top n most prevalent species within each state
years = list(range(2018, 2024))  # 2018 through 2023 inclusive
k = 3  # Number of observations closest to the centroid used for the reference data

create_selected_ref_plots(
    plot_path=plot_path,
    states=states,
    n_species=n_species,
    years=years,
    k=k,
)

2018 in Mango (all varieties)-Mangifera indica in kerala finished in 9.584952116012573 seconds
2019 in Mango (all varieties)-Mangifera indica in kerala finished in 10.597656011581421 seconds
2020 in Mango (all varieties)-Mangifera indica in kerala finished in 10.432561874389648 seconds
2021 in Mango (all varieties)-Mangifera indica in kerala finished in 9.695079803466797 seconds
2022 in Mango (all varieties)-Mangifera indica in kerala finished in 10.113780975341797 seconds
2023 in Mango (all varieties)-Mangifera indica in kerala finished in 9.993515014648438 seconds
Mango (all varieties)-Mangifera indica in kerala finished in 60.4184250831604 seconds
2018 in Jackfruit-Artocarpus heterophyllus in kerala finished in 12.664247035980225 seconds
2019 in Jackfruit-Artocarpus heterophyllus in kerala finished in 12.995655059814453 seconds
2020 in Jackfruit-Artocarpus heterophyllus in kerala finished in 10.895534992218018 seconds
2021 in Jackfruit-Artocarpus heterophyllus in kerala finished in 