# Accuracy Assessment Stats and Confusion Matrices for Sargassum Classification

In [None]:
import os
import pandas as pd
import geopandas as gpd
import numpy as np
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
from sklearn.metrics import accuracy_score
import sklearn.metrics as skmetrics
import matplotlib.pyplot as plt
import seaborn as sns


## Load the accuracy-assessment (AA) shapefiles into GeoDataFrames

In [None]:
# Validated accuracy-assessment (AA) points, using the existing file that is
# already merged across all image dates.
source_dir = r'/Users/arbailey/Google Drive/My Drive/sargassum/aa'
file = 'aaPoints_validated_2019.shp'

# Read the points and drop every row that has a null in any column.
sargassum_aa_gdf = gpd.read_file(os.path.join(source_dir, file)).dropna()
print(sargassum_aa_gdf.describe())

# Store the validated class columns as integers.
sargassum_aa_gdf = sargassum_aa_gdf.astype({"validclass": int, "validpa": int})
print(sargassum_aa_gdf.dtypes)
sargassum_aa_gdf


In [None]:
# Original SR AA patches
sr_patch_gdf = gpd.read_file(os.path.join(source_dir, 'aaPatches_2019.shp'))

# TOA AA patches -- three versions, read in order v1, v2, v3
toa1_patch_gdf, toa2_patch_gdf, toa3_patch_gdf = (
    gpd.read_file(os.path.join(source_dir, shp_name))
    for shp_name in (
        'aaPatches_2019toa.shp',
        'aaPatches_2019toa_v2.shp',
        'aaPatches_2019toa_v3.shp',
    )
)

print(sr_patch_gdf.dtypes)
print(toa1_patch_gdf.dtypes)
toa1_patch_gdf

In [None]:
# Join samples with patch IDs to validated AA points
def combine_aa_patch(aa_gdf, patch_gdf, src):
    """Left-join patch-level predictions onto the validated AA points.

    Parameters
    ----------
    aa_gdf : DataFrame or GeoDataFrame
        Validated AA points; must contain ``aa_id`` and ``imagedate``.
    patch_gdf : DataFrame or GeoDataFrame
        Patch table keyed by ``aa_id``; its ``geometry`` column is dropped
        before the join. When ``src`` starts with ``'toa'`` the joined frame
        must contain ``toa_patch``, ``toa_sarg`` and a ``sargassum`` column
        (the latter is discarded and replaced by ``toa_sarg``); otherwise it
        must contain ``patch`` and ``sargassum``.
    src : str
        Label written to the ``source`` column (e.g. ``'sr1'``, ``'toa1'``).

    Returns
    -------
    DataFrame (a GeoDataFrame when ``aa_gdf`` is one)
        Every AA point (more than one row per point if ``patch_gdf`` repeats an
        ``aa_id``), with int64 ``patch`` and ``sargassum`` columns, a parsed
        ``Date`` column and a constant ``source`` column. AA points with no
        matching patch row get 0 in the columns that came from ``patch_gdf``.
    """
    combo_gdf = pd.merge(aa_gdf, patch_gdf.drop(columns=['geometry']), on="aa_id", how="left")
    # Only the columns brought in by the join are NaN for AA points without a
    # matching patch row, so fill just those with 0. Filling the whole frame
    # would also silently turn a missing reference label (validpa) into class 0
    # and a missing imagedate into the 1970 epoch.
    joined_cols = combo_gdf.columns.difference(aa_gdf.columns)
    combo_gdf = combo_gdf.fillna(dict.fromkeys(joined_cols, 0))
    if src.startswith('toa'):
        # TOA runs keep their own patch/prediction columns; swap them in under
        # the generic names used by aa_metrics and the plots.
        combo_gdf = combo_gdf.drop(columns='sargassum')
        combo_gdf = combo_gdf.rename(columns={'toa_patch': 'patch', 'toa_sarg': 'sargassum'})
    combo_gdf = combo_gdf.astype({"patch": 'int64', 'sargassum': 'int64'})
    combo_gdf['Date'] = pd.to_datetime(combo_gdf['imagedate'])
    combo_gdf['source'] = src
    return combo_gdf

In [None]:
# Join the original SR patch predictions onto the validated AA points.
sr_combo_gdf = combine_aa_patch(aa_gdf=sargassum_aa_gdf, patch_gdf=sr_patch_gdf, src='sr1')
sr_combo_gdf

In [None]:
# Join the TOA v1 patch predictions onto the validated AA points.
toa1_combo_gdf = combine_aa_patch(aa_gdf=sargassum_aa_gdf, patch_gdf=toa1_patch_gdf, src='toa1')
toa1_combo_gdf

In [None]:
# Join the TOA v2 patch predictions onto the validated AA points.
toa2_combo_gdf = combine_aa_patch(aa_gdf=sargassum_aa_gdf, patch_gdf=toa2_patch_gdf, src='toa2')
toa2_combo_gdf

In [None]:
# Join the TOA v3 patch predictions onto the validated AA points.
toa3_combo_gdf = combine_aa_patch(aa_gdf=sargassum_aa_gdf, patch_gdf=toa3_patch_gdf, src='toa3')
toa3_combo_gdf

In [None]:
# combo_gdf = pd.concat([sr_combo_gdf, toa1_combo_gdf, toa2_combo_gdf, toa3_combo_gdf])
# combo_gdf['Date']= pd.to_datetime(combo_gdf['imagedate'])
# combo_gdf

In [None]:
def aa_metrics(df):
    """Summarise classification-accuracy metrics for one set of AA points.

    ``validpa`` is the reference label and ``sargassum`` the predicted label.
    They are scored as a binary problem with 1 as the positive class.

    Parameters
    ----------
    df : DataFrame
        Must contain ``validpa``, ``sargassum``, ``imagedate`` and ``source``
        columns (as produced by ``combine_aa_patch``). Must be non-empty.

    Returns
    -------
    list
        ``[source, date_min, date_max, acc_score, precision,
        precision_weighted, recall, recall_weighted, f1, f1_weighted,
        support]``. Unweighted precision/recall/F1 are for class 1. The
        ``*_weighted`` versions are support-weighted averages over the classes
        present. ``support`` is the number of reference points labelled 1.
    """
    y_true = df['validpa']
    y_pred = df['sargassum']
    date_min = df['imagedate'].min()
    date_max = df['imagedate'].max()
    source = df['source'].iloc[0]
    acc_score = accuracy_score(y_true, y_pred)
    # zero_division=0 returns the same 0.0 sklearn falls back to by default,
    # without an UndefinedMetricWarning for every date that has no predicted
    # (or no reference) positives.
    precision = skmetrics.precision_score(y_true, y_pred, zero_division=0)
    precision_weighted = skmetrics.precision_score(y_true, y_pred, average='weighted', zero_division=0)
    recall = skmetrics.recall_score(y_true, y_pred, zero_division=0)
    recall_weighted = skmetrics.recall_score(y_true, y_pred, average='weighted', zero_division=0)
    f1 = skmetrics.f1_score(y_true, y_pred, zero_division=0)
    f1_weighted = skmetrics.f1_score(y_true, y_pred, average='weighted', zero_division=0)
    # Support of class 1, counted directly. The previous
    # precision_recall_fscore_support(...)[3][1] raised IndexError for any
    # group (e.g. a single date) in which only one class occurs, because the
    # per-class arrays then have length 1.
    support = int((y_true == 1).sum())
    return [source, date_min, date_max, acc_score, precision, precision_weighted,
            recall, recall_weighted, f1, f1_weighted, support]
 
# Quick look at one source (toa3): metrics pooled over all dates, then per image date.
toa3_pooled_metrics = aa_metrics(toa3_combo_gdf)
print(toa3_pooled_metrics)
toa3_date_metrics = toa3_combo_gdf.groupby(['imagedate']).apply(aa_metrics)
print(toa3_date_metrics)

In [None]:
# Combined AA tables per source, in the order used by every table and plot below.
combo_gdfs = [sr_combo_gdf, toa1_combo_gdf, toa2_combo_gdf, toa3_combo_gdf]

# Metrics pooled over all dates: one row per source.
metrics_all = [aa_metrics(gdf) for gdf in combo_gdfs]
sr_metrics_all, toa1_metrics_all, toa2_metrics_all, toa3_metrics_all = metrics_all
print(metrics_all)

# Metrics per image date: one row per (source, date), sources concatenated in order.
# (The previous version initialised metrics_bydate twice; the first was dead.)
per_source_bydate = [gdf.groupby(['imagedate']).apply(aa_metrics).tolist() for gdf in combo_gdfs]
sr_metrics_bydate, toa1_metrics_bydate, toa2_metrics_bydate, toa3_metrics_bydate = per_source_bydate
metrics_bydate = [row for source_rows in per_source_bydate for row in source_rows]

print(metrics_bydate)
print(len(metrics_bydate))

In [None]:
# Tabulate per-date metrics (column order matches aa_metrics' return list) and
# add a parsed Date column taken from each row's earliest image date.
metrics_bydate_df = pd.DataFrame(
    metrics_bydate,
    columns=['source', 'date_min', 'date_max', 'acc_score',
             'precision', 'precision_weighted',
             'recall', 'recall_weighted',
             'f1', 'f1_weighted', 'support'],
).assign(Date=lambda frame: pd.to_datetime(frame['date_min']))
metrics_bydate_df

In [None]:
sns.set_theme(style='darkgrid', palette='Set3', font='Arial', font_scale=1.2)

# Ticks 0.0, 0.1, ..., 1.0 shared by every score panel (i / 10 yields the same
# floats as the literals 0.1, 0.2, ...).
score_yticks = [i / 10 for i in range(11)]

# ((row, col), metrics column, y-axis label) for each panel whose score is in [0, 1].
score_panels = [
    ((0, 0), 'f1', 'F1'),
    ((0, 1), 'f1_weighted', 'Weighted F1'),
    ((1, 0), 'recall', 'Recall'),
    ((1, 1), 'recall_weighted', 'Weighted Recall'),
    ((2, 0), 'precision', 'Precision'),
    ((2, 1), 'precision_weighted', 'Weighted Precision'),
    ((3, 0), 'acc_score', 'Overall Accuracy'),
]

f, axs = plt.subplots(4, 2, figsize=(20, 20))

for (row, col), metric, label in score_panels:
    ax = axs[row][col]
    sns.barplot(data=metrics_bydate_df, x="date_min", y=metric, hue="source",
                linewidth=0.5, ec='.6', ax=ax)
    # All panels share the same source colours, so keep only the top-left legend.
    if (row, col) != (0, 0):
        ax.get_legend().remove()
    ax.set(ylim=(0, 1), xlabel='', ylabel=label, yticks=score_yticks)

# Support is drawn from the 'sr1' rows only (single colour, no hue/legend);
# presumably it is the same for every source because all sources share the
# same AA points -- confirm if patch files can repeat an aa_id.
sns.barplot(data=metrics_bydate_df[metrics_bydate_df.source == 'sr1'], x="date_min", y="support",
            color='#a6cee3', linewidth=0.5, ec='.6', ax=axs[3][1])
axs[3][1].set(xlabel='', ylabel='Support')
f.tight_layout()

f.suptitle("Accuracy Assessment Metrics by Date for 4 Training Sets & 2 Image Sources ",
           fontsize='x-large',
           fontweight='bold')
# Leave room at the top so the suptitle does not overlap the first row of panels.
f.subplots_adjust(top=.95)

In [None]:
# Tabulate the all-dates metrics: one row per source, columns in aa_metrics' order.
metric_table_columns = [
    'source', 'date_min', 'date_max', 'acc_score',
    'precision', 'precision_weighted',
    'recall', 'recall_weighted',
    'f1', 'f1_weighted', 'support',
]
metrics_all_df = pd.DataFrame(metrics_all, columns=metric_table_columns)
metrics_all_df

In [None]:
# Reshape to one row per (source, metric) with 'variable'/'value' columns for a
# grouped bar chart. NOTE: despite the "_wide" name, melt produces LONG-format data.
metrics_all_df_wide = metrics_all_df.melt(
    id_vars=['source'],
    value_vars=['f1', 'f1_weighted',
                'recall', 'recall_weighted',
                'precision', 'precision_weighted',
                'acc_score'],
)
metrics_all_df_wide

In [None]:
# One grouped bar per metric, coloured by source, pooled over all dates.
f, ax = plt.subplots(figsize=(20, 10))
sns.barplot(data=metrics_all_df_wide, x="variable", y="value", hue="source",
            linewidth=0.5, ec='.6', ax=ax)
ax.set(xlabel='', ylabel='Accuracy Value', yticks=[tick / 10 for tick in range(11)])

In [None]:
# sns.relplot(data=metrics_bydate_df, x="date_min", y="f1", hue="source", kind="line", height=4, aspect=2)
# sns.relplot(data=metrics_bydate_df, x="date_min", y="f1_weighted", hue="source", kind="line", height=4, aspect=2)
# sns.relplot(data=metrics_bydate_df, x="date_min", y="recall", hue="source", kind="line", height=4, aspect=2)
# sns.relplot(data=metrics_bydate_df, x="date_min", y="recall_weighted", hue="source", kind="line", height=4, aspect=2)
# sns.relplot(data=metrics_bydate_df, x="date_min", y="precision", hue="source", kind="line", height=4, aspect=2)
# sns.relplot(data=metrics_bydate_df, x="date_min", y="precision_weighted", hue="source", kind="line", height=4, aspect=2)
