# Analyze the results of the prediction
1. Generate the metadataset by running all cells in `create-metadataset.ipynb`
2. Run `train.py`. Make sure to use a config with the postprocessing step `VisualizationBlock`
3. Find the directory in `outputs/` that was created at the time you ran `train.py` and copy the path
4. Paste the path in the `RESULTS_PATH` variable below
5. View the plots, explore the tabels in your IDE, and view the images with the dashboard using their `tile_id`

In [17]:
from pathlib import Path

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

In [18]:
RESULTS_PATH = Path('../outputs/2024-01-29/16-09-04/results.csv')

In [19]:
# Load the processed metadataset
metadataset = pd.read_csv('../data/processed/metadata.csv', index_col=0)
metadataset

Unnamed: 0_level_0,cloud,land,missing_landsat,kelp,in_train
tile_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
JW725114,0.008294,0.142604,0.000000,0.000082,True
UX493605,0.004155,0.303135,0.000000,0.007404,True
OU500661,0.039673,0.254376,0.000000,0.000000,True
DC227980,0.009371,0.429110,0.000000,0.000000,True
SS602790,0.061763,0.837020,0.000000,0.000000,True
...,...,...,...,...,...
UT495238,0.297796,0.601306,0.296580,,False
GE987629,0.307053,0.125967,0.307020,,False
EN974536,0.348498,0.714710,0.348269,,False
KI806222,0.215600,0.396090,0.133959,,False


Load prediction results csv to analyze performance

In [20]:
# Load the prediction results
results = pd.read_csv(RESULTS_PATH, index_col=0)
results.head()

Unnamed: 0_level_0,in_val,sum_targets,sum_preds,intersections,dice_coef,Unnamed: 6
image_key,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
ZZ975262,0.0,0.0,2.4934e-15,0.0,0.0,
MG581061,0.0,0.0,1.425192e-14,0.0,0.0,
MG668025,0.0,0.0,1.800233e-15,0.0,0.0,
MG965101,0.0,0.0,1.365679e-15,0.0,0.0,
MH318964,0.0,0.0,3.495734e-07,0.0,0.0,


Join the results with the metadata

In [21]:
results = metadataset.merge(results,left_index=True, right_index=True, how='inner')
results

Unnamed: 0,cloud,land,missing_landsat,kelp,in_train,in_val,sum_targets,sum_preds,intersections,dice_coef,Unnamed: 6
JW725114,0.008294,0.142604,0.000000,0.000082,True,0.0,10.000000,1.342341e+03,1.441950e-22,2.132523e-25,
UX493605,0.004155,0.303135,0.000000,0.007404,True,0.0,906.999939,1.315559e+03,7.524381e+02,6.770917e-01,
OU500661,0.039673,0.254376,0.000000,0.000000,True,0.0,0.000000,1.750560e-14,0.000000e+00,0.000000e+00,
DC227980,0.009371,0.429110,0.000000,0.000000,True,0.0,0.000000,7.180388e-15,0.000000e+00,0.000000e+00,
SS602790,0.061763,0.837020,0.000000,0.000000,True,0.0,0.000000,6.384926e+00,0.000000e+00,0.000000e+00,
...,...,...,...,...,...,...,...,...,...,...,...
CW974988,0.378718,0.085527,0.356580,0.000000,True,0.0,0.000000,4.159625e-14,0.000000e+00,0.000000e+00,
VQ623772,0.486841,0.512327,0.396473,0.000000,True,0.0,0.000000,1.833054e-14,0.000000e+00,0.000000e+00,
LX380049,0.394204,0.742816,0.326229,0.000000,True,0.0,0.000000,2.039163e-14,0.000000e+00,0.000000e+00,
OY863116,0.449755,0.585037,0.368482,0.000000,True,0.0,0.000000,1.388007e-14,0.000000e+00,0.000000e+00,


In [22]:
results['union'] = results[' sum_targets'] + results['sum_preds']

# Compute error
Error is roughly "how much would our score increase if we had perfect predictions for this error"
Lower is better.

In [26]:
# compute error (difference between dice if prediction would've been perfect and actual dice)
total_intersection = results['intersections'].sum()
total_union = results['union'].sum()

perfect_dice = (2 * results[' sum_targets'] + 2 * (total_intersection - results['intersections'])) / ((total_union - results['sum_preds'] - results[' sum_targets']) + 2 * results[' sum_targets'])
all_dice = (2 * total_intersection) / (total_union)
results['perfect_dice'] = perfect_dice
results["all_dice"] = all_dice
results['error'] = perfect_dice - all_dice
results["error_per"] = results['error'] * 100
results

Unnamed: 0,cloud,land,missing_landsat,kelp,in_train,in_val,sum_targets,sum_preds,intersections,dice_coef,Unnamed: 6,union,perfect_dice,all_dice,error,error_per
JW725114,0.008294,0.142604,0.000000,0.000082,True,0.0,10.000000,1.342341e+03,1.441950e-22,2.132523e-25,,1.352341e+03,0.747921,0.74781,1.100744e-04,0.011007
UX493605,0.004155,0.303135,0.000000,0.007404,True,0.0,906.999939,1.315559e+03,7.524381e+02,6.770917e-01,,2.222559e+03,0.747877,0.74781,6.656276e-05,0.006656
OU500661,0.039673,0.254376,0.000000,0.000000,True,0.0,0.000000,1.750560e-14,0.000000e+00,0.000000e+00,,1.750560e-14,0.747810,0.74781,0.000000e+00,0.000000
DC227980,0.009371,0.429110,0.000000,0.000000,True,0.0,0.000000,7.180388e-15,0.000000e+00,0.000000e+00,,7.180388e-15,0.747810,0.74781,0.000000e+00,0.000000
SS602790,0.061763,0.837020,0.000000,0.000000,True,0.0,0.000000,6.384926e+00,0.000000e+00,0.000000e+00,,6.384926e+00,0.747811,0.74781,5.170505e-07,0.000052
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
CW974988,0.378718,0.085527,0.356580,0.000000,True,0.0,0.000000,4.159625e-14,0.000000e+00,0.000000e+00,,4.159625e-14,0.747810,0.74781,0.000000e+00,0.000000
VQ623772,0.486841,0.512327,0.396473,0.000000,True,0.0,0.000000,1.833054e-14,0.000000e+00,0.000000e+00,,1.833054e-14,0.747810,0.74781,0.000000e+00,0.000000
LX380049,0.394204,0.742816,0.326229,0.000000,True,0.0,0.000000,2.039163e-14,0.000000e+00,0.000000e+00,,2.039163e-14,0.747810,0.74781,0.000000e+00,0.000000
OY863116,0.449755,0.585037,0.368482,0.000000,True,0.0,0.000000,1.388007e-14,0.000000e+00,0.000000e+00,,1.388007e-14,0.747810,0.74781,0.000000e+00,0.000000


In [27]:
# Now lets filter the df on where sum_targets is larger than 0 and sort on asc dice coef

df_filter = results[results['kelp'] > 0]
print(len(results), len(df_filter))
df_filter = df_filter.sort_values(by='error', ascending=True)
df_filter.head(25)

5635 3526


Unnamed: 0,cloud,land,missing_landsat,kelp,in_train,in_val,sum_targets,sum_preds,intersections,dice_coef,Unnamed: 6,union,perfect_dice,all_dice,error,error_per
UK316277,0.115037,0.425143,0.034922,8e-06,True,0.0,1.0,2.447695e-15,0.0,0.0,,1.0,0.747811,0.74781,1.355986e-07,1.4e-05
WV436832,0.000237,0.069829,0.0,8e-06,True,0.0,1.0,6.001526e-16,1.60958e-36,3.2191609999999996e-36,,1.0,0.747811,0.74781,1.355986e-07,1.4e-05
HN769989,0.000784,0.855388,0.0,8e-06,True,0.0,1.0,6.245881e-12,2.241892e-32,4.483783e-32,,1.0,0.747811,0.74781,1.355986e-07,1.4e-05
BR449497,0.212237,0.158098,0.212237,8e-06,True,0.0,1.0,4.745836e-14,0.0,0.0,,1.0,0.747811,0.74781,1.355986e-07,1.4e-05
MN946165,0.017437,0.875943,0.0,8e-06,True,0.0,1.0,1.508601e-15,6.533922e-33,1.306784e-32,,1.0,0.747811,0.74781,1.355986e-07,1.4e-05
KF876043,0.050678,0.257355,0.0,8e-06,True,0.0,1.0,1.507092e-14,5.3324819999999996e-20,1.0664959999999999e-19,,1.0,0.747811,0.74781,1.355986e-07,1.4e-05
VX117844,4.1e-05,0.210318,0.0,8e-06,True,0.0,1.0,7.493432e-16,1.31014e-22,2.62028e-22,,1.0,0.747811,0.74781,1.355986e-07,1.4e-05
HG208639,0.024939,0.478759,0.0,8e-06,True,0.0,1.0,1.787332e-15,1.913392e-22,3.8267830000000003e-22,,1.0,0.747811,0.74781,1.355986e-07,1.4e-05
HP740967,0.081543,0.896955,0.0,8e-06,True,0.0,1.0,4.319727e-15,0.0,0.0,,1.0,0.747811,0.74781,1.355986e-07,1.4e-05
VM993605,7.3e-05,0.726555,0.0,8e-06,True,0.0,1.0,7.631091e-16,1.370935e-34,2.741869e-34,,1.0,0.747811,0.74781,1.355986e-07,1.4e-05


In [None]:
# Plot the correlation matrix with sns

#Use a good colormap 
cmap = sns.diverging_palette(230, 20, as_cmap=True)
fig, ax = plt.subplots(figsize=(10,10))
corr = results.corr()
sns.heatmap(corr, annot=True, cmap=cmap, ax=ax)
plt.show()

# Plots against error

### Do we predict too much or too little kelp?


In [None]:
total_predicted = results['sum_preds'].sum()
total_actual = results['sum_targets'].sum()
print(total_predicted, total_actual)
# Print as a percentage

print(f"In total we predict {total_predicted / total_actual * 100}% of the true kelp in the dataset")
print(f"Kelp in dataset: {total_actual / (350 * 350 * len(results)) * 100}%")

results["TP"] = results["intersections"]
a = (350 * 350) - results["sum_preds"]
b = (350 * 350) - results["sum_targets"]
#Set results TN to the min(a,b)
results["TN"] = a.where(a < b, b)
results["FP"] = results["sum_preds"] - results["intersections"]
results["FN"] = results["sum_targets"] - results["intersections"]

#Print the sum of the confusion matrix with annotation f string
print(f"TP: {results['TP'].sum()}, FP: {results['FP'].sum()}, FN: {results['FN'].sum()}, TN: {results['TN'].sum()}")


# Create the confusion matrix from the results
#Print in a nice format rounded by 3 decimals
confusion_matrix = results[['TP', 'FP', 'FN', 'TN']].sum().values.reshape(2,2)
# Print the confusion matrix using seaborn in float with 3 decimals

#Print the one with percentages on the right

confusion_matrix_perc = (confusion_matrix / confusion_matrix.sum()) * 100
fig, ax = plt.subplots(figsize=(10,10))

#Name the axis correctly to the format of a heatmap with TP, FP, FN, TN

sns.heatmap(confusion_matrix_perc, annot=True, fmt='.3f', cmap=cmap)
#Add correct predictions and labels to the confusion matrix axes
# Set x axis to 0,1
ax.set_xticklabels([1,0])
ax.set_yticklabels([1,0])
ax.set_ylabel('Predicted labels')
ax.set_xlabel('True labels')



plt.show()





In [24]:
# Plot scatterplots of error
for col in ['dice_coef','kelp', 'land', 'cloud', 'missing_landsat']:
    sns.scatterplot(data=results, x=col, y='error')
    plt.show()

ValueError: Could not interpret value `error` for `y`. An entry with this name does not appear in `data`.

# Plots against dice coef

In [25]:
for col in ['error','kelp', 'land', 'cloud', 'missing_landsat']:
    sns.scatterplot(data=results, x=col, y='dice_coef')
    plt.show()

ValueError: Could not interpret value `error` for `x`. An entry with this name does not appear in `data`.

In [None]:
# There is a clear correlation between error and kelp. 
# Create a new column that is the error divided by kelp
# Then make the plots against error again
results['error_per_kelp'] = results['error'] / results['kelp']


# remove outliers with an error_per_kelp larger than n stds
n = 2
results_clean = results[results['kelp'] > 0]
mean = results_clean['error_per_kelp'].mean()
stds = results_clean['error_per_kelp'].std()
results_clean = results_clean[results_clean['error_per_kelp'] < mean + n*stds]


for col in ['dice_coef','kelp', 'land', 'cloud', 'missing_landsat']:
    
    sns.scatterplot(data=results_clean, x=col, y='error_per_kelp')
    plt.show()