# Searching for Specific Moths

In this notebook, we search a dataset of inference results for specific moths. Specifically, we look for moths that are:
- large
- sloth moths
- from the species with the highest test accuracy

In [None]:
import os
import pandas as pd
import boto3
import json
from boto3.s3.transfer import TransferConfig
from PIL import Image
import numpy as np
from tqdm import tqdm

# NOTE(review): presumably set so that loading more than one copy of the OpenMP
# runtime (via the imported libraries) does not abort the kernel — confirm it is still needed.
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

In [None]:
# Set the working directory: the relative paths used further down
# (./data/..., ./credentials.json, ../gbif_download_standalone/...) are resolved against it.
os.chdir(os.path.expanduser('~/amber-inferences'))

In [None]:
# Deployment being analysed.
region = 'cri'
country = 'costarica'

# Local output folder for this country (created if it is missing).
download_dir = f'./data/qc_plots/{country}'
os.makedirs(download_dir, exist_ok=True)

# Folder that is searched below for inference CSVs.
inference_dir = os.path.abspath(
    f'/gws/nopw/j04/ceh_generic/kgoldmann/{country}_inferences_tracking/'
)

def listdir_recursive(path):
    """Yield the full path of every file found anywhere below ``path``."""
    for folder, _subfolders, filenames in os.walk(path):
        yield from (os.path.join(folder, name) for name in filenames)

# Every CSV file found anywhere under the inference directory
inference_csvs = [
    path for path in listdir_recursive(inference_dir) if path.endswith('.csv')
]

In [None]:
# Number of inference CSVs found
len(inference_csvs)

## Plotting and Data Wrangling Functions

In [None]:
def download_images(s3_client, config, key, download_dir, bucket_name):
    """
    Download a single object from a bucket into a local directory.

    Args:
        s3_client: Object exposing ``download_file(bucket, key, path, Config=...)``.
        config: Passed through unchanged as the ``Config`` argument.
        key (str): Object key; only its basename is used for the local file name.
        download_dir (str): Directory the file is written into.
        bucket_name (str): Bucket that holds the object.
    """
    local_name = os.path.basename(key)
    s3_client.download_file(
        bucket_name,
        key,
        os.path.join(download_dir, local_name),
        Config=config,
    )

In [None]:
def initialise_session(credentials_file="credentials.json"):
    """
    Build an S3 client from the credentials stored in a JSON file.

    The file must provide the keys ``AWS_ACCESS_KEY_ID``,
    ``AWS_SECRET_ACCESS_KEY``, ``AWS_REGION`` and ``AWS_URL_ENDPOINT``.

    Args:
        credentials_file (str): Path to the credentials JSON file.

    Returns:
        boto3.Client: S3 client bound to the configured endpoint.
    """
    with open(credentials_file, encoding="utf-8") as handle:
        creds = json.load(handle)

    aws_session = boto3.Session(
        aws_access_key_id=creds["AWS_ACCESS_KEY_ID"],
        aws_secret_access_key=creds["AWS_SECRET_ACCESS_KEY"],
        region_name=creds["AWS_REGION"],
    )
    return aws_session.client("s3", endpoint_url=creds["AWS_URL_ENDPOINT"])

# S3 client used by the download loops further down; the credentials file is
# looked up relative to the current working directory.
client = initialise_session('./credentials.json')

In [None]:
# S3 download tuning: many parallel transfers, multipart above 8 MB
transfer_config = TransferConfig(
    max_concurrency=20,
    multipart_threshold=8 * 1024**2,  # 8 MB
    max_io_queue=1000,
    io_chunksize=256 * 1024,  # 256 KB
)

In [None]:
def subset_by_species(inference_csvs, species_names, top_n=1, confidence_threshold=0):
    """
    Collect the inference rows that predict any of the given species.

    Each CSV is read in turn. A row is kept when, for at least one rank ``i``
    in ``1..top_n``, ``top_{i}_species`` is in ``species_names`` and
    ``top_{i}_confidence`` is strictly greater than ``confidence_threshold``.
    A row that qualifies at several ranks is returned once, and rows keep the
    order they have in their CSV.

    A ``key`` column is added to every kept row:
    ``<name of the CSV's parent folder>/snapshot_images/<basename of image_path>``.

    Files that cannot be read, or that have no ``image_path`` column, are
    reported and skipped. Ranks whose species/confidence columns are missing
    are reported and ignored for that file.

    Args:
        inference_csvs (iterable of str): Paths of the inference CSV files.
        species_names (list-like of str): Species names to look for.
        top_n (int): Number of prediction ranks to check (1 = top-1 only).
        confidence_threshold (float): Exclusive lower bound on the confidence
            of the matching prediction.

    Returns:
        pandas.DataFrame: All kept rows with the extra ``key`` column; an empty
        frame when no file could be used.
    """
    matched_frames = []
    for c in tqdm(inference_csvs, desc='reading in the csvs'):
        try:
            input_df = pd.read_csv(c, low_memory=False)
        except Exception as e:
            print(f" - Error reading {c}: {e}")
            continue

        # Without image_path no S3 key can be built for the rows of this file.
        if 'image_path' not in input_df.columns:
            print(f" - Column image_path not found in {c}. Skipping this file.")
            continue

        # OR the per-rank matches together so that a row hit at several ranks
        # is selected only once.
        keep = pd.Series(False, index=input_df.index)
        for i in range(top_n):
            species_col = f'top_{i+1}_species'
            confidence_col = f'top_{i+1}_confidence'
            missing = [col for col in (species_col, confidence_col) if col not in input_df.columns]
            if missing:
                print(f" - Column {', '.join(missing)} not found in {c}. Skipping this rank.")
                continue
            keep |= (input_df[species_col].isin(species_names) &
                     (input_df[confidence_col] > confidence_threshold))

        # Copy so that adding the key column does not write into a view of input_df.
        subset_df = input_df.loc[keep].copy()
        prefix = os.path.basename(os.path.dirname(c)) + "/snapshot_images/"
        subset_df['key'] = [f"{prefix}{os.path.basename(x)}" for x in subset_df['image_path']]
        matched_frames.append(subset_df)

    if not matched_frames:
        return pd.DataFrame()
    # Concatenate once at the end instead of re-copying the result for every file.
    return pd.concat(matched_frames, ignore_index=True)

# High Test Accuracy Moths

In [None]:
# load in the json files
def load_json(filename):
    """
    Read a JSON file and return its parsed content.

    Args:
        filename (str): Path to the JSON file.

    Returns:
        The decoded JSON value (a dict for the accuracy files used here).
    """
    # Read as UTF-8 explicitly rather than relying on the platform's default encoding.
    with open(filename, encoding="utf-8") as f:
        return json.load(f)

# The `vNN` suffix used in the taxon-accuracy file name of each country
region_list = {
    'costarica': '03',
}
country = 'costarica'

accuracy = load_json(
    f'/home/users/katriona/amber-inferences/sandbox/turing-{country}_v{region_list[country]}_taxon-accuracy.json'
)

# Split the metadata entry off by name, so that only the per-rank accuracy
# tables remain in `accuracy` no matter where 'About' sits in the file.
# `k` stays bound at notebook level because later cells still mention it.
k = 'About'
info = accuracy.pop(k)
info

In [None]:
# Number of entries in the species-level accuracy table
len(accuracy['species'])

In [None]:
# One accuracy table per key of `accuracy`, stored under that key
tax_df_list = {}

for tax, tax_acc in accuracy.items():
    print(tax)

    # Each entry maps a name to [top-1 accuracy, number of test points];
    # turn that into one row per name.
    df = pd.DataFrame.from_dict(tax_acc, orient='index', columns=['Top1 Accuracy', 'Total Test Points'])

    # Move the names out of the index into a column named after the current key.
    df = df.reset_index().rename(columns={'index': tax})

    # NOTE(review): assumes the test set is 15% and the training set 75% of
    # all records — confirm against the training configuration.
    df['Total Train Points'] = df['Total Test Points']/0.15 * 0.75

    tax_df_list[tax] = df

In [None]:
# Aside: check which names in the species accuracy table contain 'Timocratica'
all_species = tax_df_list['species']['species']

# keep the names that contain the substring 'Timocratica'
species_names = [s for s in all_species if 'Timocratica' in s]
species_names

In [None]:
# Species with a top-1 accuracy above 90 and more than 200 (estimated) training points
species_table = tax_df_list['species']
high_accuracy = species_table['Top1 Accuracy'] > 90
enough_training = species_table['Total Train Points'] > 200
no_sig = species_table.loc[high_accuracy & enough_training]

print(f'There are {no_sig.shape[0]} species with >90% accuracy and >200 training points')

In [None]:
# Filter the inferences for these species. With the default arguments only the
# top-1 prediction is checked and its confidence just has to be above 0.
df_moths = subset_by_species(inference_csvs, no_sig['species'])
df_moths.head()

In [None]:
# Number of rows per predicted (top-1) species
species_counts = df_moths['top_1_species'].value_counts()
crops = species_counts.to_frame()
crops

In [None]:
# Up to 20 random rows per predicted species.
# Shuffle once with a fixed seed, then take the first 20 rows of every species:
# a species with fewer than 20 rows keeps all of them, and a rare species no
# longer caps the sample size of every other species.
df_moths_subset = (
    df_moths
    .sample(frac=1, random_state=42)
    .groupby('top_1_species')
    .head(20)
    .sort_values('top_1_species', kind='stable')  # keep each species' rows together
    .reset_index(drop=True)
)

In [None]:
download_dir = '/gws/nopw/j04/ceh_generic/kgoldmann/cr_confident_species'
# to_csv does not create missing folders, so make sure the target exists
os.makedirs(download_dir, exist_ok=True)

# save the csv
df_moths_subset.to_csv(os.path.join(download_dir, 'cr_confident_species.csv'), index=False)

In [None]:
# Padding added on every side of the bounding box before cropping
buffer = 5

for i, row in tqdm(df_moths_subset.iterrows(), desc='downloading images', total=df_moths_subset.shape[0]):
    image_name = os.path.basename(row['key'])
    image_path = os.path.join(download_dir, image_name)

    try:
        download_images(client, transfer_config, row['key'], download_dir, 'cri')
    except Exception as e:
        # Report the key that actually failed and move on: without the
        # downloaded file there is nothing to crop.
        print(f" - Error downloading {row['key']}: {e}")
        continue

    # Crops are filed in one folder per predicted species
    species_dir = os.path.join(download_dir, row['top_1_species'].replace(' ', '_'))
    os.makedirs(species_dir, exist_ok=True)
    cropped_image_path = os.path.join(species_dir, f"{row['crop_status']}_{image_name}")

    try:
        with Image.open(image_path) as img:
            # Bounding box widened by the buffer on every side
            x_min = float(row['x_min']) - buffer
            y_min = float(row['y_min']) - buffer
            x_max = float(row['x_max']) + buffer
            y_max = float(row['y_max']) + buffer

            img_cropped = img.crop((x_min, y_min, x_max, y_max))
            img_cropped.save(cropped_image_path)
        os.remove(image_path)  # Remove the original image after cropping
    except Exception as e:
        print(f" - Error cropping {image_path}: {e}")
        continue

# Sloth Moths

In [None]:
# Load the Costa Rica moth checklist; it is narrowed to a single family a couple of cells below.
sloth_moths = pd.read_csv('../gbif_download_standalone/species_checklists/costarica-moths-keys-nodup.csv')

In [None]:
# Species names that are searched for in the inference output below
example_sms = ['Bradypodicola hahneli',
               'Cryptoses choloepi',
               'Cryptoses waagei',
               'Cryptoses rufipictus',
               'Bradypophila garbei']

In [None]:
# Keep only the checklist rows of the family Pyralidae
sloth_moths = sloth_moths.loc[sloth_moths['family_name'] == 'Pyralidae', ]
# Alternative (disabled): keep only the rows whose name matches one of the example species
# sloth_moths = sloth_moths.loc[(sloth_moths['gbif_species_name'].isin(example_sms)) |
#                               (sloth_moths['search_species_name'].isin(example_sms)) |
#                               (sloth_moths['species_name_provided'].isin(example_sms)), ]

sloth_moths

In [None]:
# Inference rows where one of the top-5 predictions is an example species with confidence above 0.05
df_sm = subset_by_species(inference_csvs, example_sms, 5, 0.05)

In [None]:
# Distinct top-1 predictions among the selected rows
df_sm['top_1_species'].unique()

In [None]:
# Histogram of the top-1 confidence of the selected rows
df_sm['top_1_confidence'].plot(kind='hist', bins=50, title='Top 1 Confidence for Sloth Moths')

In [None]:
download_dir = '/gws/nopw/j04/ceh_generic/kgoldmann/sloth_moths'
# to_csv does not create missing folders, so make sure the target exists
os.makedirs(download_dir, exist_ok=True)

# save the csv
df_sm.to_csv(os.path.join(download_dir, 'sloth_moths.csv'), index=False)

In [None]:
# Padding added on every side of the bounding box before cropping
buffer = 5

for i, row in tqdm(df_sm.iterrows(), desc='downloading images', total=df_sm.shape[0]):
    image_name = os.path.basename(row['key'])
    image_path = os.path.join(download_dir, image_name)

    try:
        download_images(client, transfer_config, row['key'], download_dir, 'cri')
    except Exception as e:
        # Report the key that actually failed and move on: without the
        # downloaded file there is nothing to crop.
        print(f" - Error downloading {row['key']}: {e}")
        continue

    # Crops are filed in one folder per top-1 predicted species
    species_dir = os.path.join(download_dir, row['top_1_species'].replace(' ', '_'))
    os.makedirs(species_dir, exist_ok=True)
    cropped_image_path = os.path.join(species_dir, f"{row['crop_status']}_{image_name}")

    try:
        with Image.open(image_path) as img:
            # Bounding box widened by the buffer on every side
            x_min = float(row['x_min']) - buffer
            y_min = float(row['y_min']) - buffer
            x_max = float(row['x_max']) + buffer
            y_max = float(row['y_max']) + buffer

            img_cropped = img.crop((x_min, y_min, x_max, y_max))
            img_cropped.save(cropped_image_path)
        os.remove(image_path)  # Remove the original image after cropping
    except Exception as e:
        print(f" - Error cropping {image_path}: {e}")
        continue

In [None]:
# Display the selected rows
df_sm