# Across Country Plots

Plots: 
- Latitude
  - vs crop size
  - vs activity
- Activity through the season


In [None]:
import os
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
from amber_inferences.utils.config import load_credentials
from amber_inferences.utils.deployment_summary import deployment_data

In [None]:
# Collect every CSV that sits somewhere below an "inferences_tracking" folder,
# leaving out files whose name mentions "compute" or "gpu".
search_root = os.path.abspath('/gws/ssde/j25b/ceh_generic/kgoldmann/')

inference_csvs = [
    os.path.join(folder, filename)
    for folder, _subdirs, filenames in os.walk(search_root)
    if 'inferences_tracking' in folder
    for filename in filenames
    if filename.endswith('.csv') and 'compute' not in filename and 'gpu' not in filename
]


In [None]:
# Number of inference CSV paths collected by the directory walk
len(inference_csvs)

In [None]:
# Output locations (relative to the working directory); both are created
# on demand and left alone if they already exist.
download_dir = './data/qc_plots/all'
plot_dir = './sandbox/plots/all'

for out_dir in (download_dir, plot_dir):
    os.makedirs(out_dir, exist_ok=True)

In [None]:
# Modern Web Design Color Palette
# Define web-optimized colors for better accessibility and modern design

# Hex codes paired position-by-position with the role names below.
_hex_codes = ('#2E86AB', '#A23B72', '#F18F01', '#C73E1D', '#7209B7',
              '#2F9B69', '#dbb037', '#a8dadc', '#264653')
_roles = ('primary', 'secondary', 'accent1', 'accent2', 'accent3',
          'accent4', 'accent5', 'accent6', 'text')

# Role name -> hex colour, plus a light page background.
web_colors = dict(zip(_roles, _hex_codes))
web_colors['background'] = '#f5f5f5'  # Light background

# Matplotlib-compatible list: every role except 'text' and 'background'.
modern_palette = [web_colors[role] for role in _roles[:-1]]

# Make it the default seaborn palette.
sns.set_palette(modern_palette)

# Preview: one swatch per role with the role name printed underneath.
plt.figure(figsize=(10, 2))
for slot, (role, hex_code) in enumerate(web_colors.items()):
    plt.fill_between([slot, slot + 1], 0, 1, color=hex_code)
    plt.text(slot + 0.5, -0.1, role, ha='center', va='top', fontsize=10)
plt.xlim(0, len(web_colors))
plt.ylim(-0.2, 1)
plt.axis('off')
plt.show()

In [None]:
# Colour per country: Anguilla is primary, Costa Rica (also spelled 'costarica') is secondary,
# Kenya is accent 2, Singapore is accent 5, and Thailand is accent 6
country_map = {
    'Anguilla': web_colors['primary'],
    'Costa Rica': web_colors['secondary'],
    'costarica': web_colors['secondary'],
    'Kenya': web_colors['accent2'],
    'Singapore': web_colors['accent5'],
    'Thailand': web_colors['accent6']
}


# Country name -> short code that is passed to deployment_data(subset_countries=...).
# NOTE(review): keys here are lower-case (except 'Costa Rica') while country_map
# uses title case — confirm each dict matches the values it is looked up with.
bucket_names = {
    'anguilla': 'aia',
    'Costa Rica': 'cri',
    'costarica': 'cri',
    'kenya': 'ken',
    'singapore': 'sgp',
    'thailand': 'tha'
}

# General Plots

## Number of images

In [None]:
# Count unique image paths per country and deployment
def count_unique_images_per_country_dep(inference_csvs):
    """
    Tally the distinct image paths found in each inference CSV.

    The country is the name of the directory one level above the CSV's own
    folder (with '_inferences_tracking' removed); the deployment id is the
    part of the file name before the first underscore.

    Returns a DataFrame with one row per non-empty, readable CSV and the
    columns country, dep, unique_image_count and csv_file, sorted by country
    and deployment. A summary is printed when at least one row was collected.
    """
    rows = []

    for csv_path in tqdm(inference_csvs, desc='Counting unique images per country/deployment'):
        try:
            # Only the image path column is needed for the tally.
            paths = pd.read_csv(csv_path, usecols=['image_path'], low_memory=False, on_bad_lines='skip')
            if paths.empty:
                continue

            file_name = os.path.basename(csv_path)
            tracking_folder = os.path.dirname(csv_path).split('/')[-2]

            rows.append({
                'country': tracking_folder.replace('_inferences_tracking', ''),
                'dep': file_name.split('.')[0].split('_')[0],
                'unique_image_count': paths['image_path'].nunique(),
                'csv_file': file_name,
            })
        except Exception as e:
            print(f" - Error reading {csv_path}: {e}")
            continue

    results_df = pd.DataFrame(rows)
    if results_df.empty:
        return results_df

    results_df = results_df.sort_values(['country', 'dep']).reset_index(drop=True)

    # Overall totals
    n_countries = results_df['country'].nunique()
    n_deployments = results_df['dep'].nunique()
    n_images = results_df['unique_image_count'].sum()
    print("\nSummary:")
    print(f"Total countries: {n_countries}")
    print(f"Total deployments: {n_deployments}")
    print(f"Total unique images across all deployments: {n_images:,}")

    # Per-country breakdown
    per_country = (
        results_df.groupby('country')
        .agg({'dep': 'nunique', 'unique_image_count': 'sum'})
        .rename(columns={'dep': 'num_deployments', 'unique_image_count': 'total_images'})
    )
    print("\nImages per country:")
    print(per_country.to_string())

    return results_df

In [None]:
# Build the per-CSV unique image counts and preview the first ten rows
image_counts_df = count_unique_images_per_country_dep(inference_csvs)
image_counts_df.head(10)

In [None]:
# Get deployment names for the image counts data
credentials = load_credentials("../credentials.json")


def get_deployment_names_for_images(image_counts_df):
    """
    Attach deployment metadata to a table that has 'country' and 'dep' columns.

    For every distinct value of ``image_counts_df['country']`` the matching
    code is looked up in ``bucket_names`` and ``deployment_data`` is queried
    for that country's deployment ids. Countries whose lookup or query fails
    are reported and skipped.

    The collected records are converted to a DataFrame and left-merged onto
    the input (``dep`` == ``deployment_id``), so input rows without metadata
    are kept. Columns present in both tables keep their plain name for the
    input side and get a ``_y`` suffix for the metadata side.

    Returns the merged DataFrame.
    """
    deployment_info = []

    for bucket_name in image_counts_df['country'].unique():
        print(f"Processing {bucket_name}")
        deps = image_counts_df.loc[image_counts_df['country'] == bucket_name, 'dep'].unique()

        print(f" - Deployment: {deps}")

        # Get deployment data
        try:
            dep_data = deployment_data(
                credentials,
                subset_countries=[bucket_names[bucket_name]],
                subset_deployments=list(deps),  # Convert to list
                include_file_count=False,
            )
            # A list result contributes several records, anything else one record
            if isinstance(dep_data, list):
                deployment_info.extend(dep_data)
            else:
                deployment_info.append(dep_data)
        except Exception as e:
            print(f" - Error fetching deployment data for {deps}: {e}")
            continue

    # One frame per record, concatenated once instead of growing a frame
    # inside the loop (which re-copies all previous rows on every iteration).
    # NOTE(review): assumes each record is a mapping of key -> row mapping — confirm.
    frames = [pd.DataFrame.from_dict(record, orient='index') for record in deployment_info]
    deployment_df = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()

    # suffixes=('', '_y') leaves the input's column names untouched. The
    # previous approach stripped every '_x' substring after the merge, which
    # would also mangle unrelated columns that happen to contain '_x'.
    image_counts_with_names = image_counts_df.merge(
        deployment_df,
        left_on='dep',
        right_on='deployment_id',
        how='left',
        suffixes=('', '_y'),
    )

    return image_counts_with_names



In [None]:
# Add deployment metadata (used later for the 'location_name' axis) to the image counts
image_counts_with_names = get_deployment_names_for_images(image_counts_df)


In [None]:
# Create a visualization of image counts by country and deployment
def plot_image_counts_by_country_dep(df, save_fig=False):
    """
    Create a bar plot showing unique image counts per deployment by country.

    df must provide the columns 'location_name', 'unique_image_count' and
    'country'; it is copied, so the caller's frame is not modified.
    When save_fig is True the figure is also written to
    'images_processed.png' in the current working directory.
    Returns None.
    """
    if df.empty:
        print("No data to plot")
        return

    # Set up the plot style
    plt.style.use('default')
    sns.set_palette(modern_palette)

    # Clean country names on a copy, e.g. 'costarica' -> 'Costa Rica'
    df = df.copy()
    df['country'] = df['country'].str.replace('rica', ' rica').str.title()

    fig, ax = plt.subplots(figsize=(16, 8), dpi=300)

    # One bar per deployment location, coloured by country
    sns.barplot(
        data=df,
        x="location_name",
        y="unique_image_count",
        hue="country",
        palette=country_map,
        dodge=False,
        ax=ax,
        alpha=0.8,
        errorbar=None
    )

    # Customize the plot
    ax.set_xlabel("Deployment", fontsize=14, fontweight='bold')
    ax.set_ylabel("Number of Images", fontsize=14, fontweight='bold')
    ax.set_title("Number of Images Processed by Deployment and Country",
                fontsize=16, fontweight='bold', color=web_colors['text'], pad=20)

    # Rotate x-axis labels for better readability
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", fontsize=10)

    # Format y-axis with commas
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f'{int(x):,}'))

    # Customize legend
    legend = ax.legend(
        title='Country',
        title_fontsize=12,
        fontsize=10,
        loc='center left',
        bbox_to_anchor=(1, 0.5),
        frameon=True,
        fancybox=True,
        shadow=True
    )
    legend.get_frame().set_alpha(0.9)

    # Add grid and styling
    ax.grid(True, alpha=0.3, linestyle='-', linewidth=0.5)
    ax.set_axisbelow(True)
    ax.set_facecolor(web_colors['background'])
    sns.despine()

    plt.tight_layout()

    # Save BEFORE showing: once plt.show() has displayed (and possibly
    # released) the figure, a later plt.savefig() can write an empty canvas.
    if save_fig:
        fig.savefig('images_processed.png', dpi=300, bbox_inches='tight',
                    facecolor='white', edgecolor='none')

    plt.show()


# Plot the results
plot_image_counts_by_country_dep(image_counts_with_names, save_fig=True)

# Number of Crops and Moths by Deployment and Country

In [None]:
def create_input_df(inference_csvs):
    """
    Build a crop-level table from all inference CSVs.

    For each readable CSV, rows whose crop_status marks "no detections" or a
    corrupt image are dropped, duplicate bounding boxes are removed, and the
    remaining rows contribute the columns dep (file name up to the first
    underscore), crop_area (box width * height), order_name and country
    (name of the directory one level above the CSV's folder, without
    '_inferences_tracking').

    Files that cannot be read or processed are reported and skipped.
    Returns the concatenated DataFrame (empty if nothing was collected).
    """
    skipped_statuses = ['No detections for this image.', 'Image corrupt']
    frames = []

    for c in tqdm(inference_csvs, desc='reading in the csvs'):
        try:
            input_df = pd.read_csv(c, low_memory=False, on_bad_lines='skip')
        except Exception as e:
            print(f" - Error reading {c}: {e}")
            continue

        try:
            crops = input_df.loc[~input_df['crop_status'].isin(skipped_statuses)]
            if crops.empty:
                continue

            # NOTE(review): the duplicate key is the bounding box only, so an
            # identical box on two different images counts once; activity_pop
            # also keys on image_path — confirm which behaviour is intended.
            crops = crops.drop_duplicates(subset=['x_min', 'x_max', 'y_min', 'y_max'])

            # assign() returns a new frame, avoiding writes into a filtered slice
            crops = crops.assign(
                dep=os.path.basename(c).split('.')[0].split('_')[0],
                crop_area=(crops['x_max'] - crops['x_min']) * (crops['y_max'] - crops['y_min']),
                country=os.path.dirname(c).split('/')[-2].replace('_inferences_tracking', ''),
            )

            frames.append(crops[['dep', 'crop_area', 'order_name', 'country']])
        except Exception as e:
            print(f" - Error processing {c}: {e}")
            continue

    if not frames:
        return pd.DataFrame()

    # Concatenate once: growing a DataFrame inside the loop re-copies all
    # previously collected rows on every iteration.
    return pd.concat(frames, ignore_index=True)


In [None]:
# Crop-level table (dep, crop_area, order_name, country) across all inference CSVs
moth_df = create_input_df(inference_csvs)

In [None]:
# Preview the first rows of the crop-level table
moth_df.head()

In [None]:
# Add deployment metadata to every crop row (left merge on dep == deployment_id)
moth_df_with_names = get_deployment_names_for_images(moth_df)


In [None]:
# Preview the crop rows together with their deployment metadata
moth_df_with_names.head()

In [None]:
def counts_by_country_dep(df, save_fig=False, plot_width=15, subset=False):
    """
    Plot the number of crops (rows of df) per deployment, coloured by country.

    df needs the columns 'location_name' and 'country', plus 'order_name'
    when subset is True; the caller's frame is not modified.
    With subset=True only rows whose order_name contains 'Lepidoptera' are
    counted (missing order names are excluded) and labels read "Moth"
    instead of "Crop". plot_width is the figure width in inches.
    When save_fig is True the figure is written to
    '<Crop|Moth>_counts_by_deployment.png' in the current working directory.
    Returns None.
    """
    plt.style.use('default')
    sns.set_palette(modern_palette)  # Use modern web palette

    # Copy only the columns that are used, so the (potentially large) input
    # frame is neither duplicated in full nor modified in place.
    needed = ['location_name', 'country'] + (['order_name'] if subset else [])
    df = df[needed].copy()
    df['country'] = df['country'].str.replace('rica', ' rica').str.title()

    sub_str = 'Crop'
    if subset:
        # na=False: a missing order name must give False, otherwise the
        # boolean mask contains NaN and cannot be used for indexing.
        df = df.loc[df['order_name'].str.contains('Lepidoptera', na=False)]
        sub_str = 'Moth'

    # Number of rows per deployment and country
    df2 = df.groupby(['location_name', 'country']).size().reset_index(name='count')
    df2 = df2.rename(columns={'country': 'Country'})

    # Sort by country then count for better visual organization
    df2 = df2.sort_values(by=['Country', 'count'], ascending=[True, False])

    fig, ax = plt.subplots(figsize=(plot_width, 6), dpi=300)

    sns.barplot(
        data=df2,
        x="location_name",
        y="count",
        hue="Country",
        palette=country_map,  # Use consistent country-color mapping
        dodge=False,  # No dodging needed since we're coloring by country
        ax=ax,
        alpha=0.8  # Slight transparency for elegance
    )

    ax.set_xlabel("Deployment", fontsize=14, fontweight='bold')
    ax.set_ylabel(f"Number of {sub_str}s", fontsize=14, fontweight='bold')
    ax.set_title(f"{sub_str} Counts by Deployment and Country",
                fontsize=16, fontweight='bold', color=web_colors['text'], pad=20)

    # Rotate x-axis labels for better readability
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", fontsize=10)

    # Thousands separators on the y-axis
    ax.tick_params(axis='y', labelsize=11)
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f'{int(x):,}'))

    legend = ax.legend(
        title='Country',
        title_fontsize=12,
        fontsize=10,
        loc='center left',
        bbox_to_anchor=(1, 0.5),
        frameon=True,
        fancybox=True,
        shadow=True
    )
    legend.get_frame().set_alpha(0.9)

    ax.grid(True, alpha=0.3, linestyle='-', linewidth=0.5)
    ax.set_axisbelow(True)
    sns.despine()  # Remove top and right spines
    plt.tight_layout(pad=2.0)

    ax.set_facecolor(web_colors['background'])

    # Save BEFORE showing: once plt.show() has displayed (and possibly
    # released) the figure, a later plt.savefig() can write an empty canvas.
    if save_fig:
        fig.savefig(f'{sub_str}_counts_by_deployment.png', dpi=300, bbox_inches='tight',
                    facecolor='white', edgecolor='none')

    plt.show()

In [None]:
# Plot the number of crops (rows) per deployment and country
counts_by_country_dep(moth_df_with_names)

In [None]:
# Same plot restricted to rows whose order_name contains 'Lepidoptera'
counts_by_country_dep(moth_df_with_names, subset=True)

# Latitude Plots

In [None]:
# NOTE(review): in the code visible here `deployment_info` is only assigned as a
# local inside get_deployment_names_for_images(); this cell needs a module-level
# `deployment_info` from elsewhere, otherwise it raises NameError — confirm.
deployment_info

In [None]:
# Flatten the deployment records into one table: one frame per record,
# concatenated in a single call.
# NOTE(review): `deployment_info` is only assigned inside
# get_deployment_names_for_images() in the code visible here — confirm that a
# module-level `deployment_info` exists before this cell runs.
_deployment_frames = [
    pd.DataFrame.from_dict(sub_dict, orient='index') for sub_dict in deployment_info
]
all_df = (
    pd.concat(_deployment_frames, ignore_index=True)
    if _deployment_frames
    else pd.DataFrame()
)
all_df

In [None]:
# Match moth_df with the deployment metadata: left join on dep == deployment_id,
# so crops without metadata are kept.
# NOTE(review): later cells read 'country_x', presumably created by pandas
# suffixing a 'country' column that exists in both frames — confirm.
moth_df = moth_df.merge(all_df, left_on='dep', how='left', right_on='deployment_id')
moth_df

In [None]:
# Subset to rows whose order_name contains 'Lepidoptera'
# (na=False drops rows with a missing order_name)
moth_sub = moth_df[moth_df['order_name'].str.contains('Lepidoptera', na=False)]

In [None]:
# Expose the crop-side country ('country_x' after the merge) as 'country'.
# assign() returns a new frame; writing a column into `moth_sub` directly
# would modify a filtered slice of moth_df and trigger SettingWithCopyWarning.
moth_sub = moth_sub.assign(country=moth_sub['country_x'])

In [None]:
# Violin plots of Lepidoptera crop area per deployment, one panel per country
country_names = sorted(moth_sub['country'].unique())
colors = sns.color_palette("Set2", n_colors=len(country_names))
country_colors = dict(zip(country_names, colors))

fig, axes = plt.subplots(2, 3, figsize=(18, 12), sharey=True)
fig.patch.set_facecolor('white')
axes = axes.flatten()

for panel, (country, group) in enumerate(moth_sub.groupby('country')):
    ax = axes[panel]
    tint = country_colors[country]

    # Deployments in sorted order; the same order drives data, positions and labels
    dep_ids = sorted(group['dep'].unique())
    slots = range(len(dep_ids))
    areas_per_dep = [group.loc[group['dep'] == dep_id, 'crop_area'].values
                     for dep_id in dep_ids]

    parts = ax.violinplot(areas_per_dep, positions=slots,
                          showmeans=True, showmedians=True)

    # Violin bodies take the country colour
    for body in parts['bodies']:
        body.set_facecolor(tint)
        body.set_alpha(0.7)
        body.set_edgecolor('darkgray')
        body.set_linewidth(1)

    # Mean in red, median in dark blue, whiskers and extremes in black
    for key, line_colour in (('cmeans', 'red'), ('cmedians', 'darkblue')):
        parts[key].set_color(line_colour)
        parts[key].set_linewidth(2)
    for key in ('cbars', 'cmins', 'cmaxes'):
        parts[key].set_color('black')

    # Titles and axis labels; only the first column of panels gets a y label
    ax.set_title(country.title(), fontsize=14, fontweight='bold',
                 color=tint, pad=20)
    ax.set_xlabel('Deployment', fontsize=12, fontweight='bold')
    y_label = 'Crop Area (pixels²)' if panel % 3 == 0 else ''
    ax.set_ylabel(y_label, fontsize=12, fontweight='bold')
    ax.set_yscale('log')

    ax.set_xticks(slots)
    ax.set_xticklabels(dep_ids, rotation=45, ha='right', fontsize=10)

    # Subtle grid, off-white panel and a gray frame
    ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.5)
    ax.set_facecolor('#fafafa')
    for spine in ax.spines.values():
        spine.set_edgecolor('gray')
        spine.set_linewidth(1.2)

# Blank out the panels that have no country assigned
for ax in axes[len(country_names):]:
    ax.set_facecolor('white')
    ax.set_xticks([])
    ax.set_xticklabels([])
    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.grid(False)
    for spine in ax.spines.values():
        spine.set_visible(False)

plt.tight_layout()
plt.suptitle('Crop Area Distribution of Lepidoptera by Deployment and Country',
             fontsize=18, fontweight='bold', y=0.98, color='darkslategray')

# On-screen figure background (the saved file uses facecolor='white' below)
fig.patch.set_facecolor('#f8f9fa')

violin_path = os.path.join(plot_dir, 'crop_area_vs_deployment_by_country.png')
plt.savefig(violin_path, dpi=300, bbox_inches='tight', facecolor='white')
plt.show()


In [None]:
# Number of detection rows per deployment and country.
# NOTE(review): this is a total count; it is not yet divided by the number of nights.
avg_detections = moth_df.groupby(['dep', 'country_x']).size().reset_index(name='detections')


In [None]:
# (rows, columns) of the merged crop table
moth_df.shape

In [None]:
# Number of crops per predicted order
moth_df['order_name'].value_counts()

## Seasonal Activity Plots

In [None]:
def activity_pop(inference_csvs):
    """
    Count detections per deployment, date, order and country for every CSV.

    For each readable CSV, rows whose crop_status marks "no detections" or a
    corrupt image are dropped and duplicates of (image_path, bounding box)
    are removed. The deployment comes from the 'deployment_name' column when
    present, otherwise from the file name up to the first underscore; the
    date is the last underscore-separated part of the file stem; the country
    is the name of the directory one level above the CSV's folder (without
    '_inferences_tracking').

    Files that cannot be read or processed are reported and skipped.
    Returns a DataFrame with the columns dep, date, order_name, country and
    count (empty if nothing was collected).
    """
    skipped_statuses = ['No detections for this image.', 'Image corrupt']
    frames = []

    for c in tqdm(inference_csvs, desc='reading in the csvs'):
        try:
            # on_bad_lines='skip' matches the other CSV readers in this
            # notebook, so a malformed line drops that line, not the file.
            input_df = pd.read_csv(c, low_memory=False, on_bad_lines='skip')
        except Exception as e:
            print(f" - Error reading {c}: {e}")
            continue

        try:
            crops = input_df.loc[~input_df['crop_status'].isin(skipped_statuses)]
            if crops.empty:
                continue

            crops = crops.drop_duplicates(
                subset=['image_path', 'x_min', 'x_max', 'y_min', 'y_max']
            )

            file_stem = os.path.basename(c).split('.')[0]
            name_parts = file_stem.split('_')

            # Prefer the deployment recorded in the data over the file name
            if 'deployment_name' in crops.columns:
                dep = crops['deployment_name']
            else:
                dep = name_parts[0]

            crops = crops.assign(
                dep=dep,
                date=name_parts[-1],
                country=os.path.dirname(c).split('/')[-2].replace('_inferences_tracking', ''),
            )

            # Rows per (dep, date, order, country). Naming the column here
            # keeps 'count' stable across pandas versions.
            counts = (
                crops[['dep', 'date', 'order_name', 'country']]
                .value_counts()
                .reset_index(name='count')
            )
            frames.append(counts)
        except Exception as e:
            print(f" - Error processing {c}: {e}")
            continue

    if not frames:
        return pd.DataFrame()

    # Concatenate once instead of growing a DataFrame inside the loop
    return pd.concat(frames, ignore_index=True)

In [None]:
# Detection counts per deployment, date, order and country
activity_df = activity_pop(inference_csvs)

In [None]:
# Mean count per date, order and country (averaged over the rows, i.e. deployments, in each group)
activity_mean = activity_df.groupby(['date', 'order_name', 'country'])['count'].mean().reset_index()

activity_mean.head()

In [None]:
activity_mean['date'] = pd.to_datetime(activity_mean['date'], format='%Y-%m-%d')

# Rolling averages of the count within each order/country series.
# NOTE(review): the window spans 5 / 7 consecutive rows, i.e. recorded nights,
# not calendar days — gaps between nights are not accounted for.
activity_mean['rolling_count_5'] = activity_mean.groupby(['order_name', 'country'])['count'].transform(lambda x: x.rolling(window=5, min_periods=1).mean())
activity_mean['rolling_count_7'] = activity_mean.groupby(['order_name', 'country'])['count'].transform(lambda x: x.rolling(window=7, min_periods=1).mean())

# Continuous day index for seasonal analysis: 1 on 1 January of the earliest
# year in the data, increasing by one per calendar day across year boundaries.
# Adding a fixed 365 for a second year would make 31 Dec of a leap year
# (day 366) collide with 1 Jan of the following year (1 + 365).
_first_year = activity_mean['date'].dt.year.min()
_first_year_start = (
    pd.Timestamp(year=int(_first_year), month=1, day=1)
    if pd.notna(_first_year)
    else pd.NaT
)
activity_mean['day_of_year'] = (activity_mean['date'] - _first_year_start).dt.days + 1

activity_mean['month'] = activity_mean['date'].dt.month
activity_mean['month_name'] = activity_mean['date'].dt.strftime('%b')

# Check the data
print("Date range:", activity_mean['date'].min(), "to", activity_mean['date'].max())
print("Countries:", activity_mean['country'].unique())
activity_mean.head(20)

In [None]:
# Create seasonal activity plot - all countries on same plot with day of year as x-axis

def _visible_month_ticks(xlim):
    """Return (positions, labels) of the month starts that fall inside xlim."""
    month_starts = [1, 32, 60, 91, 121, 152, 182, 213, 244, 274, 305, 335]
    month_labels = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
                    'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']

    # Repeat the months for a second year when the axis extends past day 365
    if xlim[1] > 365:
        month_starts = month_starts + [day + 365 for day in month_starts]
        month_labels = month_labels * 2

    visible = [(day, label) for day, label in zip(month_starts, month_labels)
               if xlim[0] <= day <= xlim[1]]
    return [day for day, _ in visible], [label for _, label in visible]


def plot_seasonal_activity_by_order(activity_df, column='count'):
    """
    Draw one line-plot facet per order with a line per country.

    activity_df must provide the columns 'date' (datetime), 'order_name',
    'country', 'day_of_year' and the column named by `column`, which is
    plotted on the y-axis against 'day_of_year'. The input is not modified.

    Each facet gets month ticks, a secondary axis with year labels and, when
    the order has fewer than five rows, an "Insufficient Data" overlay.
    The global seaborn palette is set as a side effect.

    Returns the seaborn FacetGrid.
    """
    df = activity_df.sort_values(by='date').rename(columns={'order_name': 'Order'})

    # One distinct colour per country, assigned in sorted country order
    countries = sorted(df['country'].unique())
    colors = sns.color_palette("husl", n_colors=len(countries))
    country_colors = dict(zip(countries, colors))
    sns.set_palette([country_colors[country] for country in countries])

    # hue_order ties the palette positions to the sorted country list, so
    # each country really receives the colour assigned to it above.
    g = sns.FacetGrid(df, col='Order', hue='country', hue_order=countries,
                      col_wrap=3, height=5, aspect=1.2,
                      sharey=False, sharex=False, margin_titles=True)
    g.map(sns.lineplot, 'day_of_year', column, linewidth=2.5, alpha=0.8, marker='o', markersize=4)

    base_year = df['date'].dt.year.min()

    # axes_dict maps each order to its axis. The facet's data is taken from
    # that key instead of being parsed back out of the title, which this
    # loop rewrites.
    for order_name, ax in g.axes_dict.items():
        if not ax.has_data():
            ax.remove()
            continue

        order_data = df[df['Order'] == order_name]

        ax.set_xlabel('')
        ax.set_ylabel('Average Crops per Night', fontsize=11, fontweight='bold')

        # Grid, background and frame
        ax.grid(True, alpha=0.4, linestyle='-', linewidth=0.5, color='lightgray')
        ax.set_facecolor('#fafafa')  # Light background
        for spine in ax.spines.values():
            spine.set_edgecolor('darkgray')
            spine.set_linewidth(1.5)

        # Readable facet title
        ax.set_title(str(order_name).replace('_', ' ').title(),
                     fontsize=12, fontweight='bold', pad=15, color='darkslategray')

        # Month ticks along the bottom
        xlim = ax.get_xlim()
        month_ticks, month_tick_labels = _visible_month_ticks(xlim)
        if month_ticks:
            ax.set_xticks(month_ticks)
            ax.set_xticklabels(month_tick_labels, fontsize=10, fontweight='bold')
            ax.set_xlabel('Month', fontsize=11, fontweight='bold', color='darkslategray')

        # Secondary x-axis, drawn below the month axis, carrying year labels
        ax_year = ax.twiny()
        ax_year.set_frame_on(False)
        ax_year.xaxis.set_ticks_position('bottom')
        ax_year.xaxis.set_label_position('bottom')
        ax_year.spines['bottom'].set_position(('outward', 35))
        ax_year.set_xlim(xlim)

        # The first year present in this facet is labelled at the facet's
        # first data point, later years at their 365-day offset.
        years = sorted(order_data['date'].dt.year.unique())
        year_positions = []
        year_labels = []
        for year in years:
            if year == years[0]:
                year_pos = order_data['day_of_year'].min()
            else:
                year_pos = (year - base_year) * 365

            if xlim[0] <= year_pos <= xlim[1]:
                year_positions.append(year_pos)
                year_labels.append(str(year))

        if year_positions:
            ax_year.set_xticks(year_positions)
            ax_year.set_xticklabels(year_labels, fontsize=11, weight='bold', color='navy')
            ax_year.set_xlabel('Year', fontsize=11, fontweight='bold', color='navy')
        else:
            ax_year.set_xticks([])

        if len(order_data) < 5:
            # Too few rows for a meaningful line: grey out and annotate
            ax.set_facecolor('#f0f0f0')
            ax.text(0.5, 0.5, f'Insufficient Data\n({len(order_data)} observations)',
                    transform=ax.transAxes, ha='center', va='center',
                    fontsize=14, fontweight='bold', alpha=0.6, style='italic',
                    bbox=dict(boxstyle="round,pad=0.3", facecolor='lightcoral', alpha=0.3))

            # Hide the ticks on the month axis and on its year axis
            ax.set_xticks([])
            ax.set_xticklabels([])
            ax.set_xlabel('')
            ax_year.set_xticks([])
            ax_year.set_xticklabels([])
            ax_year.set_xlabel('')

    # Legend outside the grid
    g.add_legend(title='Country', bbox_to_anchor=(1.05, 1), loc='upper left',
                 borderaxespad=0., frameon=True, shadow=True, fancybox=True)

    legend = g._legend
    if legend:
        legend.set_title('Country', prop={'size': 12, 'weight': 'bold'})
        legend.get_frame().set_facecolor('white')
        legend.get_frame().set_alpha(0.9)
        legend.get_frame().set_edgecolor('darkgray')
        legend.get_frame().set_linewidth(1.5)

    plt.tight_layout()
    return g

In [None]:
# Nightly counts: build the facet grid, add the title and export it.
my_plt = plot_seasonal_activity_by_order(activity_mean)

seasonal_fig = plt.gcf()
seasonal_fig.suptitle('🦋 Seasonal Activity Patterns: Crops per Night by Country 🦋',
                      fontsize=20, fontweight='bold', y=0.98, color='darkslategray')
seasonal_fig.patch.set_facecolor('#f8f9fa')  # on-screen background

seasonal_path = os.path.join(plot_dir, 'seasonal_activity_by_order_country.png')
seasonal_fig.savefig(seasonal_path, dpi=300, bbox_inches='tight', facecolor='white')
plt.show()

In [None]:
# 5-row rolling average: build the facet grid, add the title and export it.
my_plt = plot_seasonal_activity_by_order(activity_mean, column='rolling_count_5')

seasonal_fig = plt.gcf()
seasonal_fig.suptitle('🦋 Seasonal Activity Patterns: 5-Night Rolling Average 🦋\nSmoothed Trends Across Countries',
                      fontsize=18, fontweight='bold', y=0.98, color='darkslategray')
seasonal_fig.patch.set_facecolor('#f8f9fa')

seasonal_path = os.path.join(plot_dir, 'rolling_seasonal_activity_by_order_country_5_night.png')
seasonal_fig.savefig(seasonal_path, dpi=300, bbox_inches='tight', facecolor='white')
plt.show()

In [None]:
# 7-row rolling average: build the facet grid, add the title and export it.
my_plt = plot_seasonal_activity_by_order(activity_mean, column='rolling_count_7')

seasonal_fig = plt.gcf()
seasonal_fig.suptitle('🦋 Seasonal Activity Patterns: 7-Night Rolling Average 🦋\nWeekly Trends Across Countries',
                      fontsize=18, fontweight='bold', y=0.98, color='darkslategray')
seasonal_fig.patch.set_facecolor('#f8f9fa')

seasonal_path = os.path.join(plot_dir, 'rolling_seasonal_activity_by_order_country_7_night.png')
seasonal_fig.savefig(seasonal_path, dpi=300, bbox_inches='tight', facecolor='white')
plt.show()

# Activity by Latitude

# Total Species

In [None]:
# List all unique top_1_species in inference_csvs
import pandas as pd
from tqdm import tqdm

# Gather every distinct, non-null top_1_species value across the CSVs.
species_seen = set()

for csv_path in tqdm(inference_csvs, desc='Extracting top_1_species'):
    try:
        top1 = pd.read_csv(csv_path, usecols=['top_1_species'],
                           low_memory=False, on_bad_lines='skip')['top_1_species']
        species_seen |= set(top1.dropna().unique())
    except Exception as e:
        print(f" - Error reading {csv_path}: {e}")

unique_species = sorted(species_seen)
print(f"Found {len(unique_species)} unique top_1_species.")
unique_species[:20]  # first 20 as a preview

In [None]:
# Report only the number of unique top_1_species (preview disabled below)
print(f"Found {len(unique_species)} unique top_1_species.")
# unique_species[:20]  # Show first 20 as a preview

In [None]:
# Get number of unique top_*_species for all inference_csvs
import pandas as pd
from tqdm import tqdm

def analyze_all_top_species(inference_csvs):
    """
    Collect the distinct values of every top_N_species column (N = 1..5).

    Each CSV is first probed with a one-row read to see which of the five
    columns it contains; only those columns are then read in full. Files
    that fail to read are reported and skipped.

    Returns a tuple (results, available_columns) where results maps each
    column name to {'count': int, 'species_list': sorted list} and
    available_columns is the set of columns found in at least one CSV.
    """
    species_columns = [f'top_{rank}_species' for rank in range(1, 6)]
    seen = {col: set() for col in species_columns}
    available_columns = set()

    for csv_path in tqdm(inference_csvs, desc='Analyzing top_*_species across all CSVs'):
        try:
            # Cheap probe: one row is enough to learn the header
            header = pd.read_csv(csv_path, nrows=1, low_memory=False, on_bad_lines='skip').columns
            present = [col for col in species_columns if col in header]
            available_columns.update(present)
            if not present:
                continue

            values = pd.read_csv(csv_path, usecols=present, low_memory=False, on_bad_lines='skip')
            for col in present:
                seen[col].update(values[col].dropna().unique())

        except Exception as e:
            print(f" - Error reading {csv_path}: {e}")
            continue

    # Columns never encountered get an empty entry
    results = {}
    for col in species_columns:
        ordered = sorted(seen[col]) if col in available_columns else []
        results[col] = {'count': len(ordered), 'species_list': ordered}

    return results, available_columns

# Run the analysis over all CSVs and print a one-line summary per column
species_results, available_cols = analyze_all_top_species(inference_csvs)

print("=== TOP SPECIES ANALYSIS SUMMARY ===")
print(f"Available species columns: {sorted(available_cols)}")
print()

for col in ('top_1_species', 'top_2_species', 'top_3_species', 'top_4_species', 'top_5_species'):
    n_unique = species_results[col]['count']
    status = f"{n_unique:,} unique species" if n_unique > 0 else "Not available in CSV files"
    print(f"{col}: {status}")

print()
print("=== DETAILED BREAKDOWN ===")

In [None]:
# Per-column details: the full list when short, otherwise both ends of it
for col in ('top_1_species', 'top_2_species', 'top_3_species', 'top_4_species', 'top_5_species'):
    entry = species_results[col]
    print(f"\n--- {col.upper()} ---")

    if entry['count'] <= 0:
        print("No data available")
        continue

    print(f"Total unique species: {entry['count']:,}")

    species_list = entry['species_list']
    if len(species_list) > 20:
        print("First 10 species:", species_list[:10])
        print("Last 10 species:", species_list[-10:])
        print("... (middle species omitted) ...")
    else:
        print("All species:", species_list)

print("\n=== SUMMARY ===")

# Union of the species seen in any of the ranking columns
total_unique_across_all = set()
for entry in species_results.values():
    total_unique_across_all.update(entry['species_list'])

print(f"Total unique species across ALL top_*_species columns: {len(total_unique_across_all):,}")
print(f"Files processed: {len(inference_csvs)}")
print(f"Available species ranking columns: {len(available_cols)}")

In [None]:
# Create a comparative analysis visualization
import matplotlib.pyplot as plt
import seaborn as sns

def plot_species_count_comparison(species_results):
    """
    Create a bar plot comparing unique species counts across top_*_species columns.

    species_results maps each 'top_N_species' column name (N = 1..5) to a
    dict with a 'count' entry; columns with a count of zero are left out.
    The figure is written to 'unique_species_by_ranking.png' inside
    plot_dir and then shown. Prints a message and returns None when no
    column has data.
    """
    # Prepare data for plotting
    columns = []
    counts = []

    for col in ['top_1_species', 'top_2_species', 'top_3_species', 'top_4_species', 'top_5_species']:
        if species_results[col]['count'] > 0:
            columns.append(col.replace('_', ' ').title())
            counts.append(species_results[col]['count'])

    if not columns:
        print("No data available for plotting")
        return

    fig = plt.figure(figsize=(12, 6))

    # Use the existing color palette
    bars = plt.bar(columns, counts, color=modern_palette[:len(columns)], alpha=0.8, edgecolor='darkgray', linewidth=1.5)

    plt.title('Number of Unique Species by Ranking Position\nAcross All Inference CSV Files',
              fontsize=16, fontweight='bold', color=web_colors['text'], pad=20)
    plt.xlabel('Species Ranking Column', fontsize=14, fontweight='bold')
    plt.ylabel('Number of Unique Species', fontsize=14, fontweight='bold')

    # Value labels slightly above each bar (offset: 1% of the tallest bar)
    for bar, count in zip(bars, counts):
        plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(counts)*0.01,
                f'{count:,}', ha='center', va='bottom', fontsize=12, fontweight='bold')

    # Format y-axis with commas
    plt.gca().yaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f'{int(x):,}'))

    plt.grid(True, alpha=0.3, linestyle='-', linewidth=0.5)
    plt.gca().set_facecolor(web_colors['background'])
    sns.despine()

    plt.tight_layout()

    # Save BEFORE showing: once plt.show() has displayed (and possibly
    # released) the figure, a later plt.savefig() can write an empty canvas.
    fig.savefig(os.path.join(plot_dir, 'unique_species_by_ranking.png'),
                dpi=300, bbox_inches='tight', facecolor='white')

    plt.show()


# Create the visualization
plot_species_count_comparison(species_results)