<a href="https://colab.research.google.com/github/FFI-Vietnam/camtrap-tools/blob/main/04_create_visualization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
"""
This script evaluates the recall of MegaDetector against a ground-truth dataset
and a MegaDetector result JSON file. It then creates a set of visualizations of
the recall values for each group of species.

After running this script, plots saved with `save_fig` are written to the
'visualization' folder (see `image_folder`):

visualization
    |__ 
    
"""

In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import json
import requests
import os
from tqdm.notebook import tqdm

In [2]:
# Mount Google Drive into this Colab runtime; every path below is built on /content/drive.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
# specifies Colab directories and file names
root = '/content/drive/'

dataset_folder = 'My Drive/FFI/MegaDetector Test/confusion-matrix/dataset'
contain_folder = 'My Drive/FFI/MegaDetector Test/confusion-matrix/data cleaning'
image_folder = 'My Drive/FFI/MegaDetector Test/confusion-matrix/visualization'

ground_truth_file_name = '01_ground-truth-table_Kon-Plong.csv'
MD_result_file_name = 'MegaDetector_result_2021-08-27.json'
taxon_match_table_file_name = '02_taxon-database-with-conservation-status.csv'
confusion_matrix_file_name = '03_confusion-matrix.csv'

In [None]:
# read and save file functions
def read_csv_Google_drive(root, contain_folder, file_name):
  """
  function to read a csv file from Google Drive into a pandas DataFrame
  param examples:
    root = '/content/drive/'
    contain_folder = 'My Drive/FFI/dataset'
    file_name = '01_ground-truth-table_Kon-Plong.csv'
  """
  csv_path = os.path.join(root, contain_folder, file_name)
  frame = pd.read_csv(csv_path)
  return frame

def save_csv_Google_drive(df, root, contain_folder, file_name, index=False):
  """
  function to save a DataFrame as a csv file to Google Drive
  param examples:
    root = '/content/drive/'
    contain_folder = 'My Drive/FFI/dataset'
    file_name = 'image_metadata(2020-06-26)_full.csv'
  @params index: passed to DataFrame.to_csv; write the row index when True
  The containing folder (and any missing parents) is created if it does not exist yet.
  """
  folder_path = os.path.join(root, contain_folder)
  file_path = os.path.join(folder_path, file_name)
  # exist_ok only tolerates an existing folder; other failures (e.g. permissions) still surface
  os.makedirs(folder_path, exist_ok=True)

  # write straight to Drive instead of copying through a temporary 'dataframe.csv'
  df.to_csv(file_path, index=index)

  print(f'File is saved to {file_name} in Google Drive at {file_path}')

def save_image_Google_drive(plt, root, contain_folder, file_name):
  """
  function to save an image file to Google Drive
  param examples:
    root = '/content/drive/'
    contain_folder = 'My Drive/FFI/dataset'
    file_name = 'MD_recall_all-species.jpg'
  @params plt: the pyplot module (any object with a savefig(path) method); the current figure is saved
  The containing folder (and any missing parents) is created if it does not exist yet.
  """
  folder_path = os.path.join(root, contain_folder)
  file_path = os.path.join(folder_path, file_name)
  # exist_ok only tolerates an existing folder; other failures (e.g. permissions) still surface
  os.makedirs(folder_path, exist_ok=True)

  plt.savefig(file_path)

  print(f'File is saved to {file_name} in Google Drive at {file_path}')

In [None]:
def create_confusion_matrix_by_group(taxon_match_table, confusion_matrix, species_group):
  """
  create a confusion matrix by species group, e.g. "ungulates", "small carnivores", "birds", "small mammals (squirrels and rats)", "primates"
  @params taxon_match_table: taxonomy table, passed to find_FFI_species_by_taxonomy to list a group's species
  @params confusion_matrix: DataFrame with one column per species; its rows are aligned by label onto
    'Animal', 'Human', 'Blank', 'Total', 'Recall'
  @params species_group: dict {group_common_name: [taxonomical_level, group_scientific_name, except]}
    example: {"Birds":["class", "Aves", []], -> all Birds
              "Small carnivores":["order", "Carnivora", ["Asian Black Bear"]]} -> all carnivores except for Bears
  @returns DataFrame with one column per group (in species_group order) holding the summed counts of
    the group's species. Its 'Recall' row is Animal / Total ('Human' / Total for the column named
    'Human'), rounded to 2 decimals, or NaN when the group has no images (Total == 0).
  """
  rows = ['Animal', 'Human', 'Blank', 'Total', 'Recall']
  confusion_matrix_by_group = pd.DataFrame(index=rows)
  for group, spec in species_group.items():
    level, scientific_name, excluded = spec[0], spec[1], spec[2]
    species_list = find_FFI_species_by_taxonomy(taxon_match_table, level, scientific_name)
    # drop excepted species (e.g. Asian Black Bear from Small carnivores) and species that have
    # no column in the confusion matrix (e.g. Bat, Maxomys)
    members = set(species_list) - set(excluded)
    member_columns = confusion_matrix.loc[:, confusion_matrix.columns.isin(members)]
    # lump into one group (row values are aligned by label onto `rows`)
    confusion_matrix_by_group[group] = member_columns.astype('float').sum(axis=1)

  # update recall; .loc writes into the frame itself, whereas chained `df[col][4] = ...` relies on
  # deprecated positional Series indexing and is silently dropped under pandas Copy-on-Write
  for col in confusion_matrix_by_group.columns:
    detected_row = 'Human' if col == 'Human' else 'Animal'
    detected = int(confusion_matrix_by_group.loc[detected_row, col])
    total = int(confusion_matrix_by_group.loc['Total', col])
    confusion_matrix_by_group.loc['Recall', col] = round(detected / total, 2) if total else float('nan')
  return confusion_matrix_by_group

# Lump species into taxonomic groups.
# Each entry: group name -> [taxonomic level, scientific name, species to exclude]
confusion_matrix_by_group = create_confusion_matrix_by_group(
    taxon_match_table,
    confusion_matrix,
    {
        'Ungulates': ['order', 'Cetartiodactyla', []],
        'Birds': ['class', 'Aves', []],
        'Small carnivores': ['order', 'Carnivora', ['Asian Black Bear']],
        'Small mammals': ['order', 'Rodentia', []],
        'Primates': ['order', 'Primates', ['Human']],
        'Bear': ['family', 'Ursidae', []],
        'Pangolin': ['family', 'Manidae', []],
        'Human': ['species', 'sapiens', []],
    },
)

In [None]:
def create_confusion_matrix_by_conservation_status(taxon_match_table, confusion_matrix, conservation_status):
  """
  create a confusion matrix by conservation status: keep only the species (columns of
  confusion_matrix) whose 'conservation_status' in taxon_match_table is one of the given statuses
  @params taxon_match_table: DataFrame with columns 'FFI_species_name' and 'conservation_status'
  @params confusion_matrix: DataFrame with one column per species
  @params conservation_status: list of status
    example: ["Endangered", "Vulnerable"]
  @returns DataFrame with rows 'Animal', 'Human', 'Blank', 'Total', 'Recall' (taken by label from
    confusion_matrix) and the selected species as columns, in confusion_matrix column order.
    Species absent from taxon_match_table (e.g. Banded Krait) are skipped.
  """
  # one lookup table instead of re-filtering taxon_match_table per species; keep='first' mirrors
  # taking the first matching row when a species name is listed more than once
  first_rows = taxon_match_table.drop_duplicates(subset='FFI_species_name', keep='first')
  status_by_species = dict(zip(first_rows['FFI_species_name'], first_rows['conservation_status']))

  # only unrecorded species are skipped; genuine errors (e.g. a missing column) now propagate
  selected = [species for species in confusion_matrix.columns
              if species in status_by_species and status_by_species[species] in conservation_status]
  return confusion_matrix[selected].reindex(['Animal', 'Human', 'Blank', 'Total', 'Recall'])

# species whose conservation status is one of the listed threatened categories
confusion_matrix_by_conservation_status = create_confusion_matrix_by_conservation_status(
    taxon_match_table,
    confusion_matrix,
    ["Endangered", "Vulnerable", "Critically Endangered"],
)

# show the table as this cell's output
confusion_matrix_by_conservation_status

In [None]:
def create_confusion_matrix_in_group(taxon_match_table, confusion_matrix, species_group):
  """
  create a confusion matrix in each species group, e.g. "ungulates", "small carnivores", "birds", "small mammals (squirrels and rats)", "primates"
  @params taxon_match_table: taxonomy table, passed to find_FFI_species_by_taxonomy to list the group's species
  @params confusion_matrix: DataFrame with one column per species
  @params species_group: list [group_common_name, taxonomical_level, group_scientific_name, except]
    example: + ["Birds", "class", "Aves", []] -> for all Birds
             + ["Small carnivores", "order", "Carnivora", ["Asian Black Bear"]] -> for all carnivores except for Bears
  @returns float copy of the group's species columns of confusion_matrix, in confusion_matrix column
    order. The row at position 4 is recomputed as row 0 / row 3 (row 1 / row 3 for a 'Human'
    column), rounded to 2 decimals, or NaN when row 3 is 0. This assumes the rows are ordered
    'Animal', 'Human', 'Blank', 'Total', 'Recall' -- TODO confirm against the loaded matrix.
  """
  level, scientific_name, excluded = species_group[1], species_group[2], species_group[3]
  species_list = find_FFI_species_by_taxonomy(taxon_match_table, level, scientific_name)
  # drop excepted species (e.g. Asian Black Bear from Small carnivores) and species that have no
  # column in the confusion matrix (e.g. Bat, Maxomys); the boolean mask keeps the matrix's own
  # column order, so the output does not depend on set iteration order
  members = set(species_list) - set(excluded)
  confusion_matrix_in_group = confusion_matrix.loc[:, confusion_matrix.columns.isin(members)].astype('float')

  # update recall by row position; .iloc writes into the frame itself, whereas chained
  # `df[col][4] = ...` is silently dropped under pandas Copy-on-Write
  for j, col in enumerate(confusion_matrix_in_group.columns):
    detected_pos = 1 if col == 'Human' else 0
    detected = int(confusion_matrix_in_group.iloc[detected_pos, j])
    total = int(confusion_matrix_in_group.iloc[3, j])
    confusion_matrix_in_group.iloc[4, j] = round(detected / total, 2) if total else float('nan')
  return confusion_matrix_in_group

# per-species breakdown of the ungulates (order Cetartiodactyla)
confusion_matrix_in_group = create_confusion_matrix_in_group(
    taxon_match_table,
    confusion_matrix,
    ['Ungulates', 'order', 'Cetartiodactyla', []],
)
confusion_matrix_in_group

In [None]:
def _recall_stats_for_plot(confusion_matrix, color_dict, num_image_threshold):
  """Return name/recall/total/colors of every plottable column, sorted by ascending recall."""
  names, recall, total, colors = [], [], [], []
  for species in confusion_matrix.columns.to_list():
    # skip the 'All' column (presumably an all-species aggregate) and columns with too few images
    if species != 'All' and confusion_matrix[species]['Total'] >= num_image_threshold:
      colors.append(color_dict[species])
      recall.append(confusion_matrix[species]['Recall'])
      total.append(confusion_matrix[species]['Total'])
      names.append(species)

  recall_stats = pd.DataFrame({'name': names, 'recall': recall, 'total': total, 'colors': colors})
  recall_stats['recall'] = recall_stats['recall'].astype('float')
  recall_stats['total'] = recall_stats['total'].astype('int')
  recall_stats.sort_values('recall', inplace=True, ascending=True)
  return recall_stats


def visualize_recall_bargraph(group_name, confusion_matrix, threshold, color_dict, num_image_threshold=0, legend_by_color=False, custom_size=None, save_fig=False, legend_colors=None):
  """
  visualize recall values as a horizontal bar graph, one bar per column of confusion_matrix
  @params group_name: name used in the plot title
  @params confusion_matrix: DataFrame with one column per species/group and at least the rows
    'Total' and 'Recall'; a column named 'All' is not plotted
  @params threshold: only shown in the plot title (presumably the MegaDetector confidence threshold)
  @params color_dict: {column name: bar color}; must contain every plotted column
  @params num_image_threshold: columns whose 'Total' is below this value are not plotted
  @params legend_by_color: draw a legend of colored patches
  @params custom_size: optional figure size passed to plt.figure(figsize=...)
  @params save_fig: False, or the file name under which the plot is saved to image_folder on Google Drive
  @params legend_colors: {label: color} used for the legend; defaults to the module-level `color_map`
  Also reads the module-level `mega_result` and, when saving, `root` and `image_folder`.
  """
  print(f"Generating visualization plot for {group_name}...")

  # NOTE(review): counts every image in the global `mega_result`, not only this group's images
  image_count = len(mega_result['images'])

  recall_stats = _recall_stats_for_plot(confusion_matrix, color_dict, num_image_threshold)
  species_name = recall_stats['name'].to_list()
  values = recall_stats['recall'].to_list()
  total = recall_stats['total'].to_list()
  colors = recall_stats['colors'].to_list()
  avg_recall = np.mean(recall_stats['recall'])

  if custom_size:
    plt.figure(figsize=custom_size)

  plt.barh(species_name, values, color=colors)
  # x offset (in data units) of the "N images" label from the end of each bar
  image_label_offset = 0.09 if not custom_size else 1 / (1.5 * custom_size[0])
  for i, (v, n) in enumerate(zip(values, total)):
    plt.text(v, i, str(round(v, 2)), color='blue', fontweight='bold')
    plt.text(v + image_label_offset, i, f'{n} images', color='blue', fontweight='bold')

  plt.xlabel("Values")
  plt.ylabel("Species name")
  if legend_by_color:
    # resolved only when a legend is requested, so callers without a global `color_map` still work
    legend_map = color_map if legend_colors is None else legend_colors
    labels = list(legend_map.keys())
    handles = [plt.Rectangle((0, 0), 1, 1, color=legend_map[label]) for label in labels]
    plt.legend(handles, labels)
  plt.title(f"Evaluate over {image_count} images for {group_name}" +
            f"\nminimum image amount: {num_image_threshold}" +
            f"\nthreshold:  {threshold}" +
            f"\naverage_recall_value: {round(avg_recall,2)}")
  if save_fig:
    # save before show(): show() may leave an empty current figure behind
    save_image_Google_drive(plt, root, image_folder, save_fig)
  plt.show()
