# Imports

In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.ensemble import IsolationForest
from sklearn.cluster import KMeans

# Selecting New Reference Data With Isolation Forests

In [2]:
def select_reference_data(state, species, year, data_dir="all data/citizen"):
    """Select one citizen observation per week to serve as new reference data.

    For each week present in the citizen data of ``species`` in ``state`` for
    ``year``:

    1. An Isolation Forest labels each observation of that week as inlier (+1)
       or outlier (-1), and only inliers are kept.
    2. A single-cluster KMeans is fitted on the inliers. With one cluster, the
       centroid is the centre of the inliers.
    3. The inlier closest to that centroid (Euclidean distance, first one on
       ties) is kept as the week's reference observation.

    Parameters
    ----------
    state : str
        State name; observations are read from ``{data_dir}/{state}.csv``.
    species : str
        Value matched against the ``Species_name`` column.
    year : int or str
        Year to keep; converted with ``int()`` before matching ``Year``.
    data_dir : str, optional
        Directory holding the per-state citizen CSV files.

    Returns
    -------
    pandas.DataFrame
        One row per week, sorted by week (also used as the index), with a
        leading ``Week`` column followed by the CSV columns that are not
        dropped below. Has no rows if the species has no observations in
        ``year``.
    """
    df = pd.read_csv(f"{data_dir}/{state}.csv")
    df = df[(df["Species_name"] == species) & (df["Year"] == int(year))]
    # Identifier/metadata columns are removed so only Week + feature columns remain.
    df = df.drop(["Date_of_observation", "Observation_ID", "User_id", "User_Tree_id",
                  "Species_id", "State_name", "Species_name", "Year"], axis=1)
    df = df.reset_index(drop=True)

    model = IsolationForest(n_estimators=500, verbose=1, random_state=42)

    new_ref_dict = {}  # week -> citizen observation closest to that week's centroid

    for week in df["Week"].sort_values().unique():
        week_df = df[df["Week"] == week].drop("Week", axis=1).reset_index(drop=True)

        # Isolation Forest: predict() returns +1 for inliers and -1 for outliers.
        model.fit(week_df)
        preds = model.predict(week_df)
        # Keep the prediction out of the feature frame so it is neither clustered
        # on nor leaked into the returned reference rows.
        valid_data = week_df.loc[preds == 1].reset_index(drop=True)
        if valid_data.empty:
            # Every observation was flagged; fall back to all of the week's rows
            # rather than letting KMeans fail on an empty frame.
            valid_data = week_df

        # Clustering: a single cluster, so the centroid summarises all inliers.
        km = KMeans(n_clusters=1, random_state=42, n_init="auto")
        km.fit(valid_data)
        centroid = km.cluster_centers_[0]

        distances = np.linalg.norm(valid_data.to_numpy(dtype=float) - centroid, axis=1)
        # argmin returns the first minimum, i.e. the first closest row by position.
        new_ref_dict[week] = valid_data.iloc[int(np.argmin(distances))]

    # Reformat to a DataFrame matching the style of the reference data: one row per week.
    new_ref_df = pd.DataFrame(new_ref_dict).T.sort_index()
    new_ref_df.insert(0, "Week", new_ref_df.index)
    return new_ref_df

# Plotting Selected Reference Data for Each Phenophase

In [3]:
plot_path = "plots/selected_reference_data_isolation_forests"
# To process every state instead of a single one:
# states = [state.replace('.csv', '') for state in os.listdir("all data/citizen")]
states = ['kerala']
# TODO: iterate over years (e.g. combine them via a median) instead of a single fixed year.
year = 2023
phenophases = ['Leaves_fresh', 'Leaves_mature', 'Leaves_old', 'Flowers_bud', 'Flowers_open',
               'Flowers_male', 'Flowers_Female', 'Fruits_unripe', 'Fruits_ripe', 'Fruits_open']

for state in states:
    # Top 10 most frequently observed species within this state.
    species_in_state = pd.read_csv(f"all data/citizen/{state}.csv")['Species_name'].value_counts().index[:10]
    for species in species_in_state:
        species_ref_df = select_reference_data(state, species, year)
        if species_ref_df.empty:
            # No observations of this species in `year`: there is nothing to plot,
            # and the phenophase columns would be missing (KeyError below).
            print(f"No {year} observations for {species} in {state}; skipping.")
            continue

        species_dir = f"{plot_path}/{state}/{species.replace(' ', '').replace('.', '').lower()}"
        os.makedirs(species_dir, exist_ok=True)

        for phenophase in phenophases:
            fig, ax = plt.subplots()
            ax.plot(species_ref_df['Week'], species_ref_df[phenophase],
                    label='Selected citizen observation (selected reference data)')
            ax.set_xlabel('Week of the year (0-47)')
            ax.set_ylabel(f'{phenophase} value')
            ax.set_ylim(-2.2, 2.2)
            ax.set_title(f'Annual {phenophase} reference data for {species} in {state}', fontsize=8)
            ax.legend()
            fig.savefig(f"{species_dir}/{phenophase}_{year}")
            # Close each figure explicitly; this loop produces many of them.
            plt.close(fig)