# Data visualization
### This notebook contains the code necessary to re-create the visualizations we present in our paper.
### All the visualizations were created in Python. The structure of this notebook is as follows:
1. Importing necessary libraries. Please see *visualization_requirements.txt* for version information.
2. Creating the helper functions for our queries to the database.
3. A markdown cell explaining the next block of cells.
4. A cell containing an SQL query that collects aggregated data and writes it to an Excel file.
5. One or more cells that visualize the collected data. These cells are self-contained and can be run in any order.
6. An empty cell, after which the pattern repeats from step 3.

### Extra files needed are "faostat_production.xlsx" and "grouped_food_items.xlsx".

Author: Osman Mutlu and Nehir Kızılilsoley\
Edited by: Osman Mutlu

### Import necessary libraries

In [None]:
import psycopg2
import numpy as np
import pandas as pd

import os
from dotenv import load_dotenv

import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.colorbar as colorbar
import matplotlib.colors as mcolors
import matplotlib.ticker as mticker
from matplotlib.ticker import LogLocator
from matplotlib.patches import Patch
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.cm as cm

from d3blocks import D3Blocks

### Setup connection to database and helper functions

In [None]:
# Class that connects to the PostgreSQL database with helper functions to execute SQL queries
#
# Before defining it, register an adapter so that numpy int64 values can be passed as
# query parameters: AsIs renders the value with str() and inserts it unquoted.
# AsIs is taken from the public psycopg2.extensions namespace rather than from the
# private psycopg2._psycopg implementation module.
psycopg2.extensions.register_adapter(np.int64, psycopg2.extensions.AsIs)

class PostgresDatabase:
    """Small convenience wrapper around one psycopg2 connection and one cursor.

    Instances can be used as context managers; leaving the ``with`` block calls
    :meth:`close` without committing, so pending changes are not committed here.
    """

    def __init__(self, db_host, db_port, db_name, db_user, db_password):
        """Open the connection and create the cursor used by all helper methods."""
        self.connection = psycopg2.connect(user = db_user,
                                           password = db_password,
                                           host = db_host,
                                           port = db_port,
                                           database = db_name)
        self.cursor = self.connection.cursor()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returns None, so an exception raised inside the ``with`` block propagates.
        self.close()

    def commit(self):
        """Commit the current transaction."""
        self.connection.commit()

    def rollback(self):
        """Roll back the current transaction."""
        self.connection.rollback()

    def close(self, commit=False):
        """Close the cursor and the connection, optionally committing first.

        The connection is always closed, even when the commit or the cursor
        close raises; the first exception is still propagated to the caller.
        """
        try:
            if commit:
                self.commit()
        finally:
            try:
                self.cursor.close()
            finally:
                self.connection.close()

    def execute(self, sql: str, params=None):
        """Execute ``sql`` with optional parameters; results stay on the cursor."""
        self.cursor.execute(sql, params or ())

    def fetchall(self):
        """Return all remaining rows of the last executed statement."""
        return self.cursor.fetchall()

    def fetchone(self):
        """Return the next row of the last executed statement."""
        return self.cursor.fetchone()

    def query(self, sql: str, params=None):
        """Execute ``sql`` and return all result rows."""
        self.cursor.execute(sql, params or ())
        return self.fetchall()

    def querydf(self, sql: str, params=None) -> pd.DataFrame:
        """Run ``sql`` and return the result as a pandas DataFrame.

        When parameters are given they are bound by the cursor (``mogrify``)
        before the final statement text is handed to pandas.
        """
        if params:
            sql = self.cursor.mogrify(sql, params).decode()

        return pd.read_sql_query(sql, self.connection)

In [None]:
# Load the database connection settings stored in a .env file.
# The query cells below read them with os.getenv under the names
# DB_HOST, DB_PORT, DB_NAME, DB_USERNAME and DB_PASSWORD.
load_dotenv()

# Or set them yourself. Because they are read with os.getenv they have to be
# environment variables (plain Python variables are not picked up), and the
# user name key is DB_USERNAME:
# os.environ["DB_HOST"] = ""
# os.environ["DB_PORT"] = ""
# os.environ["DB_NAME"] = ""
# os.environ["DB_USERNAME"] = ""
# os.environ["DB_PASSWORD"] = ""

### Total and positive number of measurements per contaminant (parameter).

In [None]:
# Contaminant measurements
# Counts total and positive measurements per file type and parameter, attaches the
# full hierarchy path of every parameter and writes the result to an Excel file.

with PostgresDatabase(os.getenv("DB_HOST"), os.getenv("DB_PORT"), os.getenv("DB_NAME"), os.getenv("DB_USERNAME"), os.getenv("DB_PASSWORD")) as db:
    # (name, code) pairs of every parameter; codes are dot-separated hierarchy paths
    out = db.query("select termextendedname as name, masterhierarchycode as code from ontologies_efsa.param")
    code_to_name = {code: name for (name, code) in out}

    # Map each code to the '::'-joined names of all its ancestors and itself,
    # e.g. code "A.B.C" -> name(A)::name(A.B)::name(A.B.C)
    code_to_full_name = {}
    for (_, code) in out:
        if code is not None:
            split_code = code.split(".")
            code_to_full_name[code] = "::".join([code_to_name[".".join(split_code[:i+1])] for i in range(len(split_code))])

    # The derived table carries an explicit alias: PostgreSQL versions before 16
    # reject a subquery in FROM that has none.
    query = """
    SELECT type, paramid, termextendedname AS param_name, masterhierarchycode AS full_name, total_measurements, positive_measurements
    FROM (
        SELECT filetype AS type, param_id AS paramid, COUNT(*) as total_measurements, COUNT(1) FILTER (WHERE evalcode_id IN (2, 9, 11, 14)) AS positive_measurements
        FROM efsa.measurement_core
        GROUP BY type, paramid
    ) AS measurement_counts
    LEFT JOIN ontologies_efsa.param
    ON paramid=ontologies_efsa.param.id
    """
    df = db.querydf(query)

    # The LEFT JOIN can leave the hierarchy code empty (no matching parameter or a
    # parameter without a code); .get() maps those rows to None instead of raising KeyError.
    df["full_name"] = df.full_name.apply(lambda x: code_to_full_name.get(x))

    df.to_excel("contaminant_measurements.xlsx", index=False)

In [None]:
# Load the per-contaminant counts exported by the query cell
contaminant_measurements = pd.read_excel("contaminant_measurements.xlsx")

# Expand the '::'-separated hierarchy path into one column per level (Lev1, Lev2, ...)
level_columns = contaminant_measurements['full_name'].str.split('::', expand=True)
level_columns.columns = [f'Lev{depth}' for depth in range(1, level_columns.shape[1] + 1)]
df_separated = pd.concat([contaminant_measurements, level_columns], axis=1)

# Nutrients are grouped by their third hierarchy level, everything else by the first
is_nutrient = df_separated['Lev1'] == 'nutrients'
df_separated['category'] = np.where(is_nutrient, df_separated['Lev3'], df_separated['Lev1'])

# Total and positive counts per category
df_summary = df_separated.groupby('category').agg(
    {'total_measurements': 'sum', 'positive_measurements': 'sum'}
)

# Share of positive measurements in percent, largest categories first
df_summary['proportion_positive'] = (
    df_summary['positive_measurements'] * 100 / df_summary['total_measurements']
)
df_summary = df_summary.reset_index().sort_values(by='total_measurements', ascending=False)
df_summary['category'] = df_summary['category'].str.title()

# Drop the categories that are not shown in the figure
excluded_categories = ['terms used for grouping purposes', 'oft term', 'not in list', 'microorganisms']
is_excluded = df_summary['category'].str.lower().isin(excluded_categories)
df_summary_category = df_summary[~is_excluded].copy()

# Shorten the two longest category labels
label_replacements = [
    (r'Persistent Organic Pollutants \(Pops\) And Other Organic Contaminants', 'POPs and Other Organic Contaminants'),
    (r'Chemical Elements \(Including Derivatives\) And Others', 'Chemical Elements and Derivatives'),
]
for long_label_pattern, short_label in label_replacements:
    df_summary_category['category'] = df_summary_category['category'].str.replace(
        long_label_pattern, short_label, regex=True
    )



##### Lev1 #####

pal = "Reds"
plot_dat = df_summary_category.copy()

# Bubble sizes are driven by log10 of the number of measurements
plot_dat['log_size'] = np.log10(plot_dat['total_measurements'])
log_min, log_max = plot_dat['log_size'].min(), plot_dat['log_size'].max()

# Range of the untransformed counts, used to pick the legend entries
raw_min, raw_max = plot_dat['total_measurements'].min(), plot_dat['total_measurements'].max()

# Powers of ten from 10^3 to 10^8 that fall inside the observed range
size_legend_values = [
    power_of_ten
    for power_of_ten in (10**exponent for exponent in range(3, 9))
    if raw_min <= power_of_ten <= raw_max
]

# Thousands-separated legend labels
labels = [f"{int(v):,}" for v in size_legend_values]


# --- Scale function for circle sizes ---
def scale_size(val, vmin, vmax, smin=50, smax=400):
    """Linearly map ``val`` from the range [vmin, vmax] to a marker size in [smin, smax].

    Parameters:
        val: value to convert (here: log10 of a measurement count).
        vmin, vmax: lower and upper end of the value range.
        smin, smax: marker sizes assigned to ``vmin`` and ``vmax``.

    Returns:
        The interpolated marker size. When ``vmin == vmax`` (a single data point
        or identical values) ``smin`` is returned instead of dividing by zero.
    """
    span = vmax - vmin
    if span == 0:
        return smin
    return smin + (val - vmin) / span * (smax - smin)

# --- Plot separate size legend ---
# The legend figure is created first and the proxy markers are drawn on its own
# axes. Calling plt.scatter before a figure exists would implicitly open an
# additional, empty figure that is then displayed together with the legend.
legend_fig, legend_ax = plt.subplots(figsize=(3, 3))

# --- Create size legend handles ---
# Empty scatter artists: they draw nothing but provide correctly sized markers
# for the legend, using the same log scaling as the bubbles in the main chart.
legend_handles = [
    legend_ax.scatter([], [],
                      s=scale_size(np.log10(val), log_min, log_max),
                      label=label,
                      color='gray',
                      alpha=0.6,
                      edgecolors='black')
    for val, label in zip(size_legend_values, labels)
]

legend_ax.legend(
    handles=legend_handles,
    title="Total Measurements",
    scatterpoints=1,
    frameon=False,
    labelspacing=1.2,
    loc='center'
)
legend_ax.axis('off')
legend_fig.tight_layout()
legend_fig.savefig('contaminant_Lev1_legend.png', dpi=600, bbox_inches='tight')
plt.show()



# Stand-alone colour bar that uses the colormap of the bubble chart,
# normalised to the observed range of the positive share
fig, ax = plt.subplots(figsize=(0.5, 6))
fig.subplots_adjust(bottom=0.5)

positive_share = plot_dat['proportion_positive']
norm = mcolors.Normalize(vmin=positive_share.min(), vmax=positive_share.max())

cbar = colorbar.ColorbarBase(ax, cmap=pal, norm=norm, orientation='vertical')
cbar.set_label('Measurements above limits (%)')

plt.savefig('contaminant_Lev1_cbar.png', dpi=600, bbox_inches='tight')
plt.show()


# --- Plot main bubble chart ---
plt.figure(figsize=(6, 5))
plt.grid(True, zorder=0, alpha = 0.5)

# One bubble per category: position and colour show the positive share,
# the area shows log10 of the number of measurements
bubble_kwargs = dict(
    x='proportion_positive',
    y='category',
    hue='proportion_positive',
    size='log_size',
    sizes=(50, 400),
    palette=pal,
    zorder=3,
    edgecolor='black',
    linewidth=0.7,
    legend=False,
)
sns.scatterplot(data=plot_dat, **bubble_kwargs)

# Hardcoded text annotations: (label, x position, category row)
annotations = [
    {"text": text, "x": x_pos, "y": category}
    for text, x_pos, category in [
        ("297 Million", 0.12, "Pesticides"),
        ("75 Million", 0.1, "Pharmacologically Active Substances"),
        ("17 Million", 0.22, "POPs and Other Organic Contaminants"),
        ("1.6 Million", 0.68, "Toxins"),
        ("1 Million", 0.85, "Chemical Elements and Derivatives"),
        ("0.3 Million", 0.8, "Process Contaminants"),
        ("59,179", 0.15, "Botanicals"),
        ("13,574", 0.27, "Erucic Acid"),
        ("5,663", 0.65, "Food Contact Materials"),
        ("4,045", 0.35, "Food Additives"),
    ]
]

# Text style shared by all annotations
annotation_style = dict(ha='left', va='center', fontsize=9, color='dimgray', alpha=1)
for ann in annotations:
    plt.text(ann["x"], ann["y"], ann["text"], **annotation_style)

plt.title('The first level of contaminant hierarchies')
plt.xlabel('Measurements above limits (%)')  # or just rename to 'Proportion Positive (%)' if more accurate
plt.ylabel('')
plt.tight_layout()
plt.savefig('contaminant_Lev1.png', dpi=600, bbox_inches='tight')
plt.show()



##### Lev2 #####

def _summarise_lev2(level1_rows):
    """Aggregate measurement counts per second-level group.

    Parameters:
        level1_rows: rows of ``df_separated`` belonging to one first-level category.

    Returns:
        DataFrame with one row per ``Lev2`` value (title-cased), the summed total
        and positive measurements and ``proportion_positive`` (positive / total,
        on a 0-1 scale), sorted by total measurements in descending order.
    """
    summary = level1_rows.groupby('Lev2').agg(
        total_measurements=('total_measurements', 'sum'),
        positive_measurements=('positive_measurements', 'sum')
    ).reset_index()
    summary['proportion_positive'] = summary['positive_measurements'] / summary['total_measurements']
    # A single reset_index() is enough; resetting a second time would only add a
    # meaningless 'index' column to the frame.
    summary = summary.sort_values(by='total_measurements', ascending=False)
    summary['Lev2'] = summary['Lev2'].str.title()
    return summary


def _plot_lev2_bubbles(summary, title, filename):
    """Draw and save a bubble chart of one second-level summary.

    Bubble position and colour show the proportion of positive measurements,
    bubble size the total number of measurements. The legend is placed outside
    the axes and the figure is written to ``filename``.
    """
    plt.figure(figsize=(8, 7))
    plt.grid(True, zorder=0)
    sns.scatterplot(
        x='proportion_positive',
        y='Lev2',
        size='total_measurements',
        hue='proportion_positive',
        sizes=(50, 600),
        data=summary,
        palette='coolwarm',
        zorder=3,
        edgecolor='black',
        linewidth=0.7
    )

    # Title and labels
    plt.title(title)
    plt.xlabel('Proportion of Positive Measurements')
    plt.ylabel('Category')

    # Move the legend outside the plot
    plt.legend(title='', loc='upper left', bbox_to_anchor=(1, 1))
    plt.tight_layout()
    plt.savefig(filename, dpi=600, bbox_inches='tight')
    plt.show()


# First-level names are compared case-insensitively
lev1_lower = df_separated['Lev1'].str.lower()

toxins = _summarise_lev2(df_separated[lev1_lower == 'toxins'])
_plot_lev2_bubbles(toxins, 'Toxins', 'contaminant_Lev2toxins.png')


# %%

pops = _summarise_lev2(
    df_separated[lev1_lower == 'persistent organic pollutants (pops) and other organic contaminants']
)
_plot_lev2_bubbles(
    pops,
    'Persistent Organic Pollutants (POPs) and other organic contaminants',
    'contaminant_Lev2pops.png'
)

# %%
# Second-level groups of the chemical elements that are shown in the figure
heavies = [
    "lead and derivatives", "cadmium and derivatives", "mercury and derivatives",
    "arsenic and derivatives", "nickel and derivatives", "copper and derivatives",
    "zinc and derivatives", "chromium and derivatives", "selenium and derivatives",
    "manganese and derivatives", "tin and derivatives", "thallium and derivatives",
    "barium and derivatives", "cobalt and derivatives", "antimony and derivatives",
    "silver and derivatives", "beryllium and derivatives", "vanadium and derivatives",
    "uranium and derivatives", "thorium and derivatives"
]

chems = df_separated[lev1_lower == 'chemical elements (including derivatives) and others']

# Match the lower-case names in `heavies` case-insensitively, in line with the
# first-level filters above
chems = _summarise_lev2(chems[chems['Lev2'].str.lower().isin(heavies)])
_plot_lev2_bubbles(
    chems,
    'Chemical Elements (Including Derivatives) and Others',
    'contaminant_Lev2chems.png'
)

In [None]:
contamin = pd.read_excel("contaminant_measurements.xlsx")

# Total and positive measurement counts per file type and contaminant
summ = contamin.groupby(["type", "param_name"]).agg(
    total_meas=("total_measurements", "sum"),
    pos_meas=("positive_measurements", "sum"),
).reset_index()

# Positive share in percent
summ["percentage"] = summ["pos_meas"].div(summ["total_meas"]).mul(100)

# Panel titles, in the order chemical, pesticides, vmpr
title_dict = ["Chemical contaminants", "Pesticide residues", "VMPRs"]

# Two identically configured scientific-notation tick formatters
formatter0, formatter1 = (mticker.ScalarFormatter(useMathText=True) for _ in range(2))
for _formatter in (formatter0, formatter1):
    _formatter.set_scientific(True)
    _formatter.set_powerlimits((-2, 2))


#%%
def plot_top_contaminants(contamin, ind, context="paper", style="ticks", figsize=(7, 3)):
    """Plot the ten most-measured contaminants of one file type as two bar charts.

    The left panel shows the total number of measurements, the right panel the
    number of measurements above limits. The figure is saved as
    ``contaminant_stats_bar_<type>.png`` and then shown.

    Parameters:
        contamin: not used by the function; kept so existing calls keep working.
            The data are taken from the module-level ``summ`` frame.
        ind: position of the file type in ``pd.unique(summ['type'])``; also
            selects the y-axis label from ``title_dict`` and the bar colour.
        context, style: seaborn context and style, applied globally.
        figsize: figure size in inches.
    """
    # File type selected by position, e.g. array(['chemical', 'pesticides', 'vmpr'], dtype=object)
    var = pd.unique(summ['type'])[ind]

    # Top 10 contaminants of that type by total number of measurements
    df = (
        summ[summ["type"] == var]
        .nlargest(10, "total_meas")
        .copy()
    )
    df['param_name'] = df['param_name'].str.title()

    # Set style and context
    sns.set_style(style)
    sns.set_context(context)

    # Both panels use the same colour, one per file type
    bar_color = sns.color_palette("tab10")[ind]

    # Create figure with two subplots
    fig, axes = plt.subplots(1, 2, figsize=figsize, sharey=True)

    # (value column, panel title, y-axis label) for the left and right panel
    panels = [
        ("total_meas", "Total measurements", title_dict[ind]),
        ("pos_meas", "Measurements above limits", ""),
    ]
    for ax, (value_column, panel_title, y_label) in zip(axes, panels):
        sns.barplot(
            data=df,
            x=value_column,
            y="param_name",
            color=bar_color,
            edgecolor="white",
            linewidth=1.5,
            ax=ax
        )
        ax.set_title(panel_title, fontsize=11)
        ax.set_xlabel("Count", fontsize=9)
        if y_label:
            ax.set_ylabel(y_label, fontsize=9)
        else:
            ax.set_ylabel("")
        ax.tick_params(axis='both', which='major', labelsize=9)
        ax.grid(axis='x', linestyle='--', alpha=0.7)

        # A formatter keeps a reference to the axis it is attached to, so every
        # axis gets its own instance instead of a shared module-level one.
        formatter = mticker.ScalarFormatter(useMathText=True)
        formatter.set_scientific(True)
        formatter.set_powerlimits((-2, 2))
        ax.xaxis.set_major_formatter(formatter)

    sns.despine()

    # Adjust layout
    plt.tight_layout()
    plt.savefig(f'contaminant_stats_bar_{var}.png', dpi=600, bbox_inches='tight')
    plt.show()


# One figure per file type
for x in range(3):
    plot_top_contaminants(contamin, x)



def plot_all_top_contaminants(summ, title_dict):
    """Draw a 3x2 grid of bar charts with the ten most-measured contaminants per file type.

    Each row is one file type ('chemical', 'pesticides', 'vmpr'). The left column
    shows total measurements, the right column measurements above limits with the
    percentage of positives written next to each bar. The figure is saved as
    ``combined_contaminants_barplot.png``.

    Parameters:
        summ: frame with the columns ``type``, ``param_name``, ``total_meas`` and ``pos_meas``.
        title_dict: row labels, indexed in the same order as the file types.
    """
    def _scientific_formatter():
        # Fresh scientific-notation tick formatter for a single axis
        fmt = mticker.ScalarFormatter(useMathText=True)
        fmt.set_scientific(True)
        fmt.set_powerlimits((-2, 2))
        return fmt

    sns.set_style("ticks")
    sns.set_context("paper")

    pal = sns.color_palette("tab10")
    types = ['chemical', 'pesticides', 'vmpr']

    # Keyword arguments shared by every bar chart
    bar_style = dict(y="param_name", edgecolor="white", linewidth=1.5)

    fig, axes = plt.subplots(3, 2, figsize=(7, 8), sharey='row')

    for i, var in enumerate(types):
        ax_total, ax_pos = axes[i, 0], axes[i, 1]

        # Ten contaminants with the most measurements for this file type
        df = summ[summ["type"] == var].nlargest(10, "total_meas").copy()
        df['param_name'] = df['param_name'].str.title()

        color = pal[i]
        formatter0 = _scientific_formatter()
        formatter1 = _scientific_formatter()

        # Total measurements
        sns.barplot(data=df, x="total_meas", color=color, ax=ax_total, **bar_style)
        ax_total.set_title("Total measurements", fontsize=9, weight='bold')
        ax_total.set_ylabel(title_dict[i], fontsize=10, weight='bold')
        ax_total.set_xlabel("")
        ax_total.xaxis.set_major_formatter(formatter0)
        ax_total.tick_params(labelsize=8)
        ax_total.grid(axis='x', linestyle='--', alpha=0.5)

        # Percentage labels just right of where each positive-count bar ends
        label_offset = 0.01 * df['pos_meas'].max()
        for j, (val, total) in enumerate(zip(df['pos_meas'], df['total_meas'])):
            if total > 0:
                pct = val / total * 100
            else:
                pct = 0
            ax_pos.text(
                val + label_offset,
                j,
                f"{pct:.2f}%",
                va='center',
                ha='left',
                fontsize=7,
                color='black'
            )

        # Positive measurements
        sns.barplot(data=df, x="pos_meas", color=color, ax=ax_pos, **bar_style)
        ax_pos.set_title("Measurements above limits", fontsize=9, weight='bold')
        ax_pos.set_ylabel("")
        ax_pos.set_xlabel("")
        ax_pos.xaxis.set_major_formatter(formatter1)
        ax_pos.tick_params(labelsize=8)
        ax_pos.grid(axis='x', linestyle='--', alpha=0.5)

    sns.despine()
    fig.align_ylabels()
    plt.tight_layout()
    plt.savefig("combined_contaminants_barplot.png", dpi=1200, bbox_inches='tight')
    plt.show()

plot_all_top_contaminants(summ, title_dict)

### Total and positive number of contaminants per product per sampling strategy (either all, per year, per sampling country, or per origin country)

In [None]:
# Contaminants per product including sampling strategy

def _hierarchy_full_names(rows, prefix):
    """Map every hierarchy code to its prefixed, '::'-joined path of names.

    ``rows`` are (name, code) pairs; codes are dot-separated paths. For a code
    "A.B.C" the result is prefix + name(A)::name(A.B)::name(A.B.C). Rows without
    a code are skipped.
    """
    name_by_code = {code: name for (name, code) in rows}
    full_names = {}
    for _, code in rows:
        if code is None:
            continue
        parts = code.split(".")
        ancestor_names = [name_by_code[".".join(parts[:depth])] for depth in range(1, len(parts) + 1)]
        full_names[code] = prefix + "::".join(ancestor_names)
    return full_names


with PostgresDatabase(os.getenv("DB_HOST"), os.getenv("DB_PORT"), os.getenv("DB_NAME"), os.getenv("DB_USERNAME"), os.getenv("DB_PASSWORD")) as db:
    # Product catalogue with the names and hierarchy codes from both ontologies
    product_df = db.querydf("""
    SELECT t1.id AS product_id, t2.termextendedname AS matrix_product_name, t2.prodclasshierarchycode AS matrix_hierarchy,
        t3.termextendedname AS mtx_product_name, t3.masterhierarchycode AS mtx_hierarchy
    FROM ontologies_efsa.product_catalogue AS t1
    LEFT JOIN ontologies_efsa.matrix AS t2
        ON t1.matrix_id=t2.id
    LEFT JOIN ontologies_efsa.mtx AS t3
        ON t1.mtx_id=t3.id
    """)

    # Full hierarchy paths from the matrix ontology
    matrix_paths = _hierarchy_full_names(
        db.query("select termextendedname as name, prodclasshierarchycode as code from ontologies_efsa.matrix"),
        "matrix::",
    )
    product_df["matrix_full_name"] = product_df.matrix_hierarchy.apply(
        lambda code: matrix_paths[code] if code is not None else None
    )

    # Full hierarchy paths from the mtx ontology
    mtx_paths = _hierarchy_full_names(
        db.query("select termextendedname as name, masterhierarchycode as code from ontologies_efsa.mtx"),
        "mtx::",
    )
    product_df["mtx_full_name"] = product_df.mtx_hierarchy.apply(
        lambda code: mtx_paths[code] if code is not None else None
    )

    # Prefer the matrix entry, fall back to mtx; products without a hierarchy
    # path use their prefixed product name instead
    product_df["full_name"] = product_df["matrix_full_name"].combine_first(product_df["mtx_full_name"])
    product_df["matrix_product_name"] = "matrix::" + product_df["matrix_product_name"]
    product_df["mtx_product_name"] = "mtx::" + product_df["mtx_product_name"]
    product_df["product_name"] = product_df["matrix_product_name"].combine_first(product_df["mtx_product_name"])
    path_missing = product_df["full_name"].isna()
    product_df.loc[path_missing, "full_name"] = product_df.loc[path_missing, "product_name"]
    product_df = product_df[["product_id", "product_name", "full_name"]]

    def _counts_with_products(sql):
        # Run an aggregate query and attach the product names to its rows
        return db.querydf(sql).merge(product_df, on="product_id", how="left")

    # Counts per sampling strategy, product and file type
    df = _counts_with_products("""
    SELECT t3.termextendedname AS samp_strategy, t2.product_id AS product_id, t1.filetype AS type, COUNT(*) AS total_measurements, COUNT(1) FILTER (WHERE t1.evalcode_id IN (2, 9, 11, 14)) AS positive_measurements
    FROM efsa.measurement_core AS t1
    LEFT JOIN efsa.sample_core AS t2
        ON t1.sample_core_id=t2.id
    LEFT JOIN ontologies_efsa.sampstr as t3
        ON t2.sampstr_id=t3.id
    GROUP BY samp_strategy, product_id, type
    """)

    # ... additionally split by sampling year
    yearly_df = _counts_with_products("""
    SELECT t3.termextendedname AS samp_strategy, t2.product_id AS product_id, t1.filetype AS type, t2.sampy AS year, COUNT(*) AS total_measurements,
        COUNT(1) FILTER (WHERE t1.evalcode_id IN (2, 9, 11, 14)) AS positive_measurements
    FROM efsa.measurement_core AS t1
    LEFT JOIN efsa.sample_core AS t2
        ON t1.sample_core_id=t2.id
    LEFT JOIN ontologies_efsa.sampstr as t3
        ON t2.sampstr_id=t3.id
    GROUP BY samp_strategy, product_id, type, year
    """)

    # ... additionally split by sampling country
    samp_df = _counts_with_products("""
    SELECT t4.termextendedname AS samp_strategy, t2.product_id AS product_id, t1.filetype AS type, t3.termextendedname AS samp_country,
        COUNT(*) AS total_measurements, COUNT(1) FILTER (WHERE t1.evalcode_id IN (2, 9, 11, 14)) AS positive_measurements
    FROM efsa.measurement_core AS t1
    LEFT JOIN efsa.sample_core AS t2
        ON t1.sample_core_id=t2.id
    LEFT JOIN ontologies_efsa.country AS t3
        ON t2.sampcountry_id=t3.id
    LEFT JOIN ontologies_efsa.sampstr as t4
        ON t2.sampstr_id=t4.id
    GROUP BY samp_country, samp_strategy, product_id, type
    """)

    # ... additionally split by country of origin
    orig_df = _counts_with_products("""
    SELECT t4.termextendedname AS samp_strategy, t2.product_id AS product_id, t1.filetype AS type, t3.termextendedname AS orig_country,
        COUNT(*) AS total_measurements, COUNT(1) FILTER (WHERE t1.evalcode_id IN (2, 9, 11, 14)) AS positive_measurements
    FROM efsa.measurement_core AS t1
    LEFT JOIN efsa.sample_core AS t2
        ON t1.sample_core_id=t2.id
    LEFT JOIN ontologies_efsa.country AS t3
        ON t2.origcountry_id=t3.id
    LEFT JOIN ontologies_efsa.sampstr as t4
        ON t2.sampstr_id=t4.id
    GROUP BY orig_country, samp_strategy, product_id, type
    """)

    # One sheet per aggregation level
    sheets = {
        'distribution': df,
        'yearly distribution': yearly_df,
        'samp country distribution': samp_df,
        'orig country distribution': orig_df,
    }
    with pd.ExcelWriter('sampling_strategy_based_positive_contaminants_per_product.xlsx') as writer:
        for sheet_name, sheet_frame in sheets.items():
            sheet_frame.to_excel(writer, sheet_name=sheet_name, index=False)

In [None]:
# Load the per-sampling-country sheet and title-case the country names
country_dist = pd.read_excel(
    "sampling_strategy_based_positive_contaminants_per_product.xlsx",
    sheet_name="samp country distribution",
)
country_dist["samp_country"] = country_dist["samp_country"].str.title()

# Fixed strategy order so that the stacked bars are comparable across panels
strategy_order = ['objective sampling', 'selective sampling', 'suspect sampling',
                  'convenient sampling', 'other']

# File types and their display titles
types = ["chemical", "pesticides", "vmpr"]
type_titles = ["Chemical contaminants", "Pesticide residues", "VMPRs"]

# Plot-ready data per file type: {"data": DataFrame, "title": str}
plot_data = {}

for type_val, type_title in zip(types, type_titles):
    # Rows of this file type only
    df_type = country_dist[country_dist["type"] == type_val].copy()

    # The 15 countries with the most measurements, largest first
    country_totals = df_type.groupby("samp_country")["total_measurements"].sum()
    top_countries = country_totals.sort_values(ascending=False).head(15).index.tolist()
    df_type_top = df_type[df_type["samp_country"].isin(top_countries)].copy()

    # Total and positive measurements per country
    totals_df = df_type_top.groupby("samp_country").agg(
        total_meas=("total_measurements", "sum"),
        pos_meas=("positive_measurements", "sum")
    )

    # Measurements per country and sampling strategy, one column per listed strategy
    strat_df = (
        df_type_top
        .groupby(["samp_country", "samp_strategy"])["total_measurements"]
        .sum()
        .unstack(fill_value=0)
        .reindex(columns=strategy_order, fill_value=0)
    )

    # Share of every listed strategy within a country (each row sums to 1)
    strat_normalized = strat_df.div(strat_df.sum(axis=1), axis=0)

    # Counts and shares side by side, rows in the top-country order
    combined = totals_df.join(strat_normalized).loc[top_countries]

    plot_data[type_val] = {"data": combined, "title": type_title}

# === Setup
# 3x3 grid: one row per file type; columns are total measurements, measurements
# above limits and the share of each sampling strategy.
fig, axes = plt.subplots(3, 3, figsize=(13, 10), sharey=False)
pal2 = sns.color_palette("tab10")

# Color mapping for strategies (must match order)
strategy_order = ['objective sampling', 'selective sampling', 'suspect sampling', 
                  'convenient sampling', 'other']
strategy_colors = {
    'objective sampling': '#1f77b4',
    'selective sampling': '#ff7f0e',
    'suspect sampling': '#9467bd',
    'convenient sampling': '#2ca02c',
    'other': '#8c564b'
}


def _scientific_formatter():
    """Return a new scientific-notation tick formatter.

    A formatter stores a reference to the axis it is attached to, so one
    instance must not be shared by several axes; each axis gets its own.
    """
    formatter = mticker.ScalarFormatter(useMathText=True)
    formatter.set_scientific(True)
    formatter.set_powerlimits((-2, 2))
    return formatter


# === Loop over types
for i, (type_val, content) in enumerate(plot_data.items()):
    df = content["data"]
    title = content["title"]
    countries = df.index.tolist()

    # --- Column 1: Total measurements
    sns.barplot(
        y=countries,
        x="total_meas",
        data=df.reset_index(),
        color=pal2[i],
        edgecolor="white",
        linewidth=1.5,
        ax=axes[i, 0]
    )
    axes[i, 0].set_title(f"{title} - Total measurements", fontsize=9, weight='bold')
    axes[i, 0].set_ylabel("")
    axes[i, 0].set_xlabel("")
    axes[i, 0].xaxis.set_major_formatter(_scientific_formatter())
    axes[i, 0].tick_params(labelsize=8)
    axes[i, 0].grid(axis='x', linestyle='--', alpha=0.5)

    # --- Column 2: Positive measurements with %
    sns.barplot(
        y=countries,
        x="pos_meas",
        data=df.reset_index(),
        color=pal2[i],
        edgecolor="white",
        linewidth=1.5,
        ax=axes[i, 1]
    )
    # Percentage of positive measurements, written just right of each bar
    for j, (val, total) in enumerate(zip(df["pos_meas"], df["total_meas"])):
        pct = val / total * 100 if total > 0 else 0
        axes[i, 1].text(
            val + 0.01 * df["pos_meas"].max(),
            j,
            f"{pct:.2f}%",
            va='center',
            ha='left',
            fontsize=8,
            color='black'
        )
    axes[i, 1].set_title("Measurements above limits", fontsize=9, weight='bold')
    axes[i, 1].set_ylabel("")
    axes[i, 1].set_xlabel("")
    axes[i, 1].xaxis.set_major_formatter(_scientific_formatter())
    axes[i, 1].tick_params(labelsize=8)
    axes[i, 1].grid(axis='x', linestyle='--', alpha=0.5)

    # --- Column 3: 100% stacked sampling strategy
    # barh draws the first row at the bottom, so the rows are reversed to keep
    # the first country at the top
    df_reversed = df.iloc[::-1]
    bottom = np.zeros(len(df_reversed))
    for strat in strategy_order:
        axes[i, 2].barh(
            df_reversed.index,
            df_reversed[strat],
            left=bottom,
            height=0.8,
            label=strat,
            color=strategy_colors[strat]
        )
        bottom += df_reversed[strat].values

    axes[i, 2].set_title("Sampling Strategy Share", fontsize=9, weight='bold')
    axes[i, 2].set_xlabel("")
    axes[i, 2].set_ylabel("")
    axes[i, 2].tick_params(labelsize=8)
    axes[i, 2].grid(axis='x', linestyle='--', alpha=0.5)
    # One legend is enough; the strategy colours are the same in every row
    if i == 0:
        axes[i, 2].legend(title='Strategy', bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)

# === Final formatting
sns.despine()
plt.tight_layout()
# plt.savefig("final_typewise_top15_plot.png", dpi=1200, bbox_inches='tight')
plt.show()

### Total and positive number of contaminants per product (either all, per year, per sampling country, or per origin country)

In [None]:
# Contaminants per product
# Builds a product lookup (id, display name, full hierarchy path), then counts total and
# positive measurements per product overall, per year, per sampling country and per origin
# country, and writes the four tables to positive_contaminants_per_product.xlsx.

with PostgresDatabase(os.getenv("DB_HOST"), os.getenv("DB_PORT"), os.getenv("DB_NAME"), os.getenv("DB_USERNAME"), os.getenv("DB_PASSWORD")) as db:
    # Collect product info
    # The LEFT JOINs keep every catalogue row; the matrix_* or mtx_* columns are NULL when
    # the product has no match in that ontology.
    query = """
    SELECT t1.id AS product_id, t2.termextendedname AS matrix_product_name, t2.prodclasshierarchycode AS matrix_hierarchy,
        t3.termextendedname AS mtx_product_name, t3.masterhierarchycode AS mtx_hierarchy
    FROM ontologies_efsa.product_catalogue AS t1
    LEFT JOIN ontologies_efsa.matrix AS t2
        ON t1.matrix_id=t2.id
    LEFT JOIN ontologies_efsa.mtx AS t3
        ON t1.mtx_id=t3.id
    """
    product_df = db.querydf(query)
    # matrix
    # Map each hierarchy code to its term name, then expand every dotted code (a.b.c) into the
    # "::"-joined names of the code and all its ancestors, prefixed with "matrix::".
    # `out` is iterated twice, so db.query presumably returns a re-iterable collection
    # (e.g. a list of (name, code) tuples) — TODO confirm.
    out = db.query("select termextendedname as name, prodclasshierarchycode as code from ontologies_efsa.matrix")
    code_to_name = {code: name for (name, code) in out}
    code_to_full_name = {}
    for (_, code) in out:
        if code is not None:
            split_code = code.split(".")
            # NOTE(review): raises KeyError if any ancestor code is missing from the table.
            code_to_full_name[code] = "matrix::" + "::".join([code_to_name[".".join(split_code[:i+1])] for i in range(len(split_code))])
    product_df["matrix_full_name"] = product_df.matrix_hierarchy.apply(lambda x: code_to_full_name[x] if x is not None else None)

    # mtx
    # Same expansion for the mtx ontology; code_to_name / code_to_full_name are rebuilt from scratch.
    out = db.query("select termextendedname as name, masterhierarchycode as code from ontologies_efsa.mtx")
    code_to_name = {code: name for (name, code) in out}
    code_to_full_name = {}
    for (_, code) in out:
        if code is not None:
            split_code = code.split(".")
            code_to_full_name[code] = "mtx::" + "::".join([code_to_name[".".join(split_code[:i+1])] for i in range(len(split_code))])

    product_df["mtx_full_name"] = product_df.mtx_hierarchy.apply(lambda x: code_to_full_name[x] if x is not None else None)

    # combine
    # Prefer the matrix path/name and fall back to the mtx one; when neither hierarchy path
    # exists, the prefixed product name is used as the full name.
    product_df["full_name"] = product_df.matrix_full_name.combine_first(product_df.mtx_full_name)
    product_df.matrix_product_name = "matrix::" + product_df.matrix_product_name
    product_df.mtx_product_name = "mtx::" + product_df.mtx_product_name
    product_df["product_name"] = product_df.matrix_product_name.combine_first(product_df.mtx_product_name)
    product_df.loc[product_df.full_name.isna(), "full_name"] = product_df.loc[product_df.full_name.isna(), "product_name"]
    product_df = product_df[["product_id", "product_name", "full_name"]]


    # Measurement counts per product and file type. "Positive" measurements are those whose
    # evalcode_id is 2, 9, 11 or 14 (the plots label them "above limits"; the meaning of the
    # individual codes is not visible here — TODO confirm against the evalcode ontology).
    query = """
    SELECT t2.product_id AS product_id, t1.filetype AS type, COUNT(*) AS total_measurements, COUNT(1) FILTER (WHERE t1.evalcode_id IN (2, 9, 11, 14)) AS positive_measurements
    FROM efsa.measurement_core AS t1
    LEFT JOIN efsa.sample_core AS t2
        ON t1.sample_core_id=t2.id
    GROUP BY product_id, type
    """
    df = db.querydf(query)
    df = df.merge(product_df, on="product_id", how="left")

    # Same counts, additionally split by sampling year (sample_core.sampy).
    query = """
    SELECT t2.product_id AS product_id, t1.filetype AS type, t2.sampy AS year, COUNT(*) AS total_measurements,
        COUNT(1) FILTER (WHERE t1.evalcode_id IN (2, 9, 11, 14)) AS positive_measurements
    FROM efsa.measurement_core AS t1
    LEFT JOIN efsa.sample_core AS t2
        ON t1.sample_core_id=t2.id
    GROUP BY product_id, type, year
    """
    yearly_df = db.querydf(query)
    yearly_df = yearly_df.merge(product_df, on="product_id", how="left")

    # Same counts, additionally split by sampling country name.
    query = """
    SELECT t2.product_id AS product_id, t1.filetype AS type, t3.termextendedname AS samp_country,
        COUNT(*) AS total_measurements, COUNT(1) FILTER (WHERE t1.evalcode_id IN (2, 9, 11, 14)) AS positive_measurements
    FROM efsa.measurement_core AS t1
    LEFT JOIN efsa.sample_core AS t2
        ON t1.sample_core_id=t2.id
    LEFT JOIN ontologies_efsa.country AS t3
        ON t2.sampcountry_id=t3.id
    GROUP BY samp_country, product_id, type
    """
    samp_df = db.querydf(query)
    samp_df = samp_df.merge(product_df, on="product_id", how="left")

    # Same counts, additionally split by origin country name.
    query = """
    SELECT t2.product_id AS product_id, t1.filetype AS type, t3.termextendedname AS orig_country,
        COUNT(*) AS total_measurements, COUNT(1) FILTER (WHERE t1.evalcode_id IN (2, 9, 11, 14)) AS positive_measurements
    FROM efsa.measurement_core AS t1
    LEFT JOIN efsa.sample_core AS t2
        ON t1.sample_core_id=t2.id
    LEFT JOIN ontologies_efsa.country AS t3
        ON t2.origcountry_id=t3.id
    GROUP BY orig_country, product_id, type
    """
    orig_df = db.querydf(query)
    orig_df = orig_df.merge(product_df, on="product_id", how="left")

    # One sheet per granularity; 'distribution' is written first, so it is also the
    # workbook's default (first) sheet.
    with pd.ExcelWriter('positive_contaminants_per_product.xlsx') as writer:
        df.to_excel(writer, sheet_name='distribution', index=False)
        yearly_df.to_excel(writer, sheet_name='yearly distribution', index=False)
        samp_df.to_excel(writer, sheet_name='samp country distribution', index=False)
        orig_df.to_excel(writer, sheet_name='orig country distribution', index=False)

In [None]:
# Load the per-sampling-country counts and the FAO production table
country_dist = pd.read_excel("positive_contaminants_per_product.xlsx", sheet_name="samp country distribution")
faostat = pd.read_excel("extra_files/faostat_production.xlsx")
# Rename the FAO columns so production can later be merged on the shared country key
faostat = faostat.rename(columns={"value": "productionFAO", "area": "samp_country"})
# .copy() detaches the filtered rows, so the column assignment below does not write
# into a slice of the renamed frame (avoids pandas' SettingWithCopyWarning)
faostat = faostat[(faostat["year"] >= 2000) & (faostat["year"] <= 2024)].copy()

# Title-case the country names in both sources so they can be matched by name
country_dist["samp_country"] = country_dist["samp_country"].str.title()
faostat["samp_country"] = faostat["samp_country"].str.title()


# Summarize country_dist data by type and country
country_meas_df = country_dist.groupby(["type", "samp_country"]).agg(
    total_meas=("total_measurements", "sum"),
    pos_meas=("positive_measurements", "sum")
).reset_index()

country_meas_df["percentage"] = (country_meas_df["pos_meas"] / country_meas_df["total_meas"]) * 100

# Countries ordered by their total number of measurements (largest first);
# this order drives the x-axis below and the row order of the per-type figure
top_countries = (
    country_meas_df.groupby("samp_country")
    .agg(total_meas_sum=("total_meas", "sum"))
    .sort_values("total_meas_sum", ascending=False)
    .reset_index()
)

# Keep only the countries listed in top_countries (currently all of them,
# because top_countries is not truncated)
filtered_df = country_meas_df[country_meas_df["samp_country"].isin(top_countries["samp_country"])]

# Total FAO production per country over the retained years
fao_summ = faostat.groupby("samp_country").agg(
    total_production=("productionFAO", lambda x: x.sum(skipna=True))
).reset_index()

# Left merge: countries without an FAO match keep a missing total_production
summary_df = top_countries.merge(
    fao_summ, on="samp_country", how="left"
)


# One scientific-notation formatter per axis
formatter0 = mticker.ScalarFormatter(useMathText=True)
formatter0.set_scientific(True)
formatter0.set_powerlimits((-2, 2))

formatter1 = mticker.ScalarFormatter(useMathText=True)
formatter1.set_scientific(True)
formatter1.set_powerlimits((-2, 2))

#%%
# Style and context must be set BEFORE the axes are created: matplotlib reads the
# rcParams when it builds the axes, so setting them afterwards leaves this figure
# dependent on whichever cell happened to run earlier in the kernel.
sns.set_style("ticks")
sns.set_context("paper")

fig, axes = plt.subplots(2, 1, figsize=(8, 7), sharex=True)

pal = sns.color_palette("tab20", n_colors=len(summary_df))

# Top panel: total number of measurements per sampling country
sns.barplot(
    data=summary_df,
    x="samp_country",
    y="total_meas_sum",
    palette=pal,
    edgecolor="white",
    linewidth=1.5,
    ax=axes[0]
)
axes[0].set_title("Total measurements", fontsize=11)
axes[0].set_xlabel(" ", fontsize=9)
axes[0].set_ylabel("Number of measurements", fontsize=9)
axes[0].tick_params(axis='both', which='major', labelsize=9)
axes[0].yaxis.set_major_formatter(formatter0)
axes[0].grid(axis='y', linestyle='--', alpha=0.7)

# Bottom panel: total FAO production per sampling country (same country order)
sns.barplot(
    data=summary_df,
    x="samp_country",
    y="total_production",
    palette=pal,
    edgecolor="white",
    linewidth=1.5,
    ax=axes[1]
)
axes[1].set_title("Total production (FAO)", fontsize=11)
axes[1].set_xlabel(" ", fontsize=9)
axes[1].set_ylabel("Tonnes")
axes[1].tick_params(axis='both', which='major', labelsize=9)
axes[1].grid(axis='y', linestyle='--', alpha=0.7)
axes[1].yaxis.set_major_formatter(formatter1)
# The x-axis is shared, so only the bottom panel shows (rotated) country labels
axes[1].tick_params(axis='x', rotation=90)

sns.despine()
plt.tight_layout()
plt.savefig('country_stats_FAO.png', dpi=600, bbox_inches='tight')
plt.show()


#%%
def create_combined_df(filtered_df, top_countries):
    """Pivot the long (type, country) table into one row per country.

    Parameters
    ----------
    filtered_df : DataFrame with columns "type", "samp_country", "total_meas", "pos_meas".
    top_countries : DataFrame whose "samp_country" column fixes which countries appear
        and in which order.

    Returns
    -------
    DataFrame with "samp_country" (ordered categorical following top_countries) plus a
    "<type>_total" / "<type>_pos" column pair for every type present in filtered_df.
    A (country, type) combination without data gets 0 in both columns; when several
    rows match, the first one is used.
    """
    type_values = filtered_df["type"].unique()
    country_order = top_countries["samp_country"].unique()

    records = []
    for country in country_order:
        in_country = filtered_df["samp_country"] == country
        record = {"samp_country": country}
        for type_val in type_values:
            rows = filtered_df[(filtered_df["type"] == type_val) & in_country]
            has_rows = not rows.empty
            # Missing combinations are reported as zero measurements
            record[f"{type_val}_total"] = rows["total_meas"].iloc[0] if has_rows else 0
            record[f"{type_val}_pos"] = rows["pos_meas"].iloc[0] if has_rows else 0
        records.append(record)

    wide = pd.DataFrame(records)
    # An ordered categorical makes sort_values follow the top_countries ranking
    wide["samp_country"] = pd.Categorical(wide["samp_country"], categories=country_order, ordered=True)
    return wide.sort_values(by="samp_country")

# Wide per-country table: one "<type>_total" / "<type>_pos" column pair per measurement type
combined_df = create_combined_df(filtered_df, top_countries)

types = ["chemical", "pesticides", "vmpr"]
tit = ["Chemical contaminants", "Pesticide residues", "VMPRs"]
pal2 = sns.color_palette("tab10")

# Keep the first 15 rows (rows follow the top_countries order) and rebuild the
# categorical so that only those 15 countries remain as categories on the y-axis
combined_df = combined_df.head(15).reset_index(drop=True)
kept_countries = combined_df["samp_country"].tolist()
combined_df["samp_country"] = pd.Categorical(
    combined_df["samp_country"],
    categories=kept_countries,
    ordered=True
)


# 3x2 grid: one row per measurement type, left = totals, right = above-limit counts
fig, axes = plt.subplots(3, 2, figsize=(7, 10), sharey=True)

for i, type_val in enumerate(types):
    total_col = f"{type_val}_total"
    pos_col = f"{type_val}_pos"
    panels = (
        (axes[i, 0], total_col, f"{tit[i]} - Total measurements"),
        (axes[i, 1], pos_col, "Measurements above limits"),
    )

    for ax, value_col, panel_title in panels:
        sns.barplot(
            data=combined_df,
            y="samp_country",
            x=value_col,
            color=pal2[i],
            edgecolor="white",
            linewidth=1.5,
            ax=ax
        )

        if value_col == pos_col:
            # Write the share of above-limit measurements just past the end of each bar
            label_offset = 0.01 * combined_df[pos_col].max()
            for j, (val, total) in enumerate(zip(combined_df[pos_col], combined_df[total_col])):
                pct = val / total * 100 if total > 0 else 0
                ax.text(
                    val + label_offset,   # x-pos
                    j,                    # y-pos
                    f"{pct:.2f}%",
                    va='center',
                    ha='left',
                    fontsize=8,
                    color='black'
                )

        # Each axis gets its own scientific-notation formatter
        sci_formatter = mticker.ScalarFormatter(useMathText=True)
        sci_formatter.set_scientific(True)
        sci_formatter.set_powerlimits((-2, 2))

        ax.set_title(panel_title, fontsize=9, weight='bold')
        ax.set_ylabel("", fontsize=11, weight='bold')
        ax.set_xlabel("")
        ax.xaxis.set_major_formatter(sci_formatter)
        ax.tick_params(labelsize=8)
        ax.grid(axis='x', linestyle='--', alpha=0.5)
        ax.tick_params(axis='y')


sns.despine()
plt.tight_layout()
plt.savefig('country_stats_total_positive_combined.png', dpi=1200, bbox_inches='tight')
plt.show()

In [None]:
file_path = "positive_contaminants_per_product.xlsx"

# Read the per-product sheet by name rather than by position, like the other cells do,
# so the result does not depend on the order in which the sheets were written
food_measurements = pd.read_excel(file_path, sheet_name="distribution")

# Split the "::"-separated hierarchy path in 'full_name' into one column per level
level_columns = food_measurements['full_name'].str.split('::', expand=True)
# Name the level columns Lev1, Lev2, ... (Lev1 is the first path element)
level_columns.columns = [f'Lev{i+1}' for i in range(level_columns.shape[1])]
# Original columns followed by the level columns
df_separated = pd.concat([food_measurements, level_columns], axis=1)


# Apply the logic to create the 'result' column
def create_result(row):
    """Choose the hierarchy level that labels one product row in the plots.

    `row` must provide 'type' and the level fields 'Lev2' .. 'Lev5' (missing levels
    are None/NaN). The level that is returned depends on the measurement type and on
    how deep the product's hierarchy path goes; rows of an unrecognised type yield
    pd.NA.
    """
    plant_rpcs = "plant commodities (rpcs)"

    def food_label():
        # Food items: use Lev5 for plant commodities or whenever Lev5 exists,
        # otherwise fall back to "Other <Lev4>"
        lev4 = row['Lev4']
        lev5 = row['Lev5']
        if lev4 == plant_rpcs or not pd.isna(lev5):
            return lev5
        return f"Other {lev4}"

    kind = row['type']

    if kind == "pesticides":
        lev2 = row['Lev2']
        if lev2 == "all lists":
            lev3 = row['Lev3']
            if lev3 == "food":
                return food_label()
            if lev3 == "feed":
                return lev3
            return f"Other {lev3}"
        if lev2 == "animal products":
            return row['Lev3']
        if lev2 == "cereals":
            lev3 = row['Lev3']
            return lev2 if pd.isna(lev3) else lev3
        if lev2 == "fruits and nuts, vegetables and other plant products":
            lev4 = row['Lev4']
            return f"Other {lev2}" if pd.isna(lev4) else lev4
        # Any other second-level group is used as its own label
        return lev2

    if kind == "chemical":
        lev3 = row['Lev3']
        if lev3 == "food":
            return food_label()
        if lev3 == "feed":
            return lev3
        # NOTE(review): this fallback is built from Lev4, while the pesticide
        # equivalent uses Lev3 — confirm that the difference is intended.
        return f"Other {row['Lev4']}"

    if kind == "vmpr":
        lev3 = row['Lev3']
        if lev3 == "food":
            lev5 = row['Lev5']
            return row['Lev4'] if pd.isna(lev5) else lev5
        if lev3 in ("feed", "facets", "non-food matrices"):
            return lev3
        return "Other VMPR"

    return pd.NA

# Label every product row with the hierarchy level chosen by create_result
df2 = df_separated.copy()
df2['result'] = df2.apply(create_result, axis=1)

# Total and positive measurements per (type, label), largest totals first
df_summary2 = (
    df2.groupby(['type', 'result'])
    .agg(
        total_measurements=('total_measurements', 'sum'),
        positive_measurements=('positive_measurements', 'sum'),
    )
    .reset_index()
)
df_summary2 = (
    df_summary2
    .assign(proportion_positive=df_summary2['positive_measurements'] / df_summary2['total_measurements'])
    .sort_values(by='total_measurements', ascending=False)
    .dropna(subset=['result'])
)

# Load the item -> category table (an Excel workbook)
category_df = pd.read_excel("extra_files/grouped_food_items.xlsx")

# Mapping from the trimmed, lower-cased item to its trimmed category;
# when an item occurs more than once the last row wins
CATEGORY_MAPPING = {
    item.strip().lower(): category.strip()
    for item, category in zip(category_df['Item'], category_df['Category'])
}

def assign_category(row):
    """Map a row's 'result' label to a food category.

    Returns the category of the first CATEGORY_MAPPING item (in mapping order) that
    occurs as a substring of the lower-cased label. A missing label is returned
    unchanged, and a label without any matching item is returned as is.
    """
    label = row['result']
    if pd.isna(label):
        return label
    lowered = label.lower()
    matching_categories = (
        category for item, category in CATEGORY_MAPPING.items() if item in lowered
    )
    return next(matching_categories, label)


# Attach the food category to every (type, label) summary row
df = df_summary2.assign(category=df_summary2.apply(assign_category, axis=1))

# Re-aggregate the counts per (type, category)
df_grouped = (
    df.groupby(['type', 'category'])
    .agg(
        total_measurements=('total_measurements', 'sum'),
        positive_measurements=('positive_measurements', 'sum'),
    )
    .reset_index()
)

# Upper-case only the first character of each category, leaving the rest untouched
category_text = df_grouped['category']
df_grouped['category'] = category_text.str[0].str.upper() + category_text.str[1:]

# Largest totals first, then the share of positive measurements per category
df_grouped = df_grouped.sort_values(by='total_measurements', ascending=False)
df_grouped = df_grouped.assign(
    proportion_positive=df_grouped['positive_measurements'] / df_grouped['total_measurements']
)



# Panel titles, matched by position to the alphabetically sorted type names below
title_dict = ["Chemical contaminants", "Pesticide residues", "VMPRs"]
type_list = sorted(df_grouped["type"].unique())

def plot_all_top_contaminants(food, context="paper", style="ticks", figsize=(9, 12)):
    """Plot total and above-limit measurements per food category, one row per type.

    Parameters
    ----------
    food : DataFrame with columns "type", "category", "total_measurements" and
        "positive_measurements"; rows containing missing values are dropped.
    context, style : passed to seaborn's set_context / set_style.
    figsize : size of the whole figure.

    Reads the module-level `type_list` (types to draw, one grid row each) and
    `title_dict` (row titles, matched to `type_list` by position). The figure is
    saved to "all_bar_food_groupings.png" and shown; nothing is returned.
    """
    toy = food.dropna()

    sns.set_style(style)
    sns.set_context(context)

    pal = sns.color_palette("tab10")

    # One grid row per type instead of a fixed 3: more than three types no longer
    # raises IndexError and fewer no longer leaves blank rows. At least one row is
    # kept so an empty type list still produces a (blank) figure. squeeze=False keeps
    # `axes` two-dimensional even for a single row.
    n_rows = max(len(type_list), 1)
    fig, axes = plt.subplots(n_rows, 2, figsize=figsize, sharey='row', squeeze=False)

    for ind, var in enumerate(type_list):
        df = toy[toy["type"] == var].copy()

        # Title-case the labels and fix their order; the category list is reversed
        # here and the y-axis is inverted after each barplot call further down.
        df['category'] = df['category'].str.title()
        df['category'] = pd.Categorical(df['category'], categories=df['category'].unique()[::-1], ordered=True)

        color = pal[ind % len(pal)]
        # Fall back to the raw type name when no title is defined for this position
        row_title = title_dict[ind] if ind < len(title_dict) else str(var)

        panels = (
            (axes[ind, 0], "total_measurements", f"{row_title} – Total measurements", False),
            (axes[ind, 1], "positive_measurements", "Measurements above limits", True),
        )

        for ax, value_col, panel_title, is_positive_panel in panels:
            sns.barplot(
                data=df,
                x=value_col,
                y="category",
                color=color,
                edgecolor="white",
                linewidth=1.5,
                ax=ax
            )

            if is_positive_panel:
                # Percentage of above-limit measurements next to each non-empty bar
                label_offset = 0.001 * df['positive_measurements'].max()
                for cat, pos_val, total_val in zip(df['category'], df['positive_measurements'], df['total_measurements']):
                    if total_val > 0 and pos_val > 0:
                        pct = pos_val / total_val * 100
                        ax.text(
                            pos_val + label_offset,
                            cat,
                            f"{pct:.2f}%",
                            va='center',
                            ha='left',
                            fontsize=7,
                            color='black'
                        )

            # Each axis gets its own formatter instance
            sci_formatter = mticker.ScalarFormatter(useMathText=True)
            sci_formatter.set_scientific(True)
            sci_formatter.set_powerlimits((-2, 2))

            ax.set_title(panel_title, fontsize=9, weight='bold')
            ax.set_xscale("log")
            ax.set_xlabel("")
            if is_positive_panel:
                ax.set_ylabel("")
            else:
                ax.set_ylabel("" if ind > 0 else " ", fontsize=8)
            ax.tick_params(axis='both', which='major', labelsize=8)
            # Formatter and locators are set after set_xscale, as in the original sequence
            ax.xaxis.set_major_formatter(sci_formatter)
            ax.grid(True, which='both', axis='x', linestyle='--', alpha=0.4)
            ax.xaxis.set_major_locator(LogLocator(base=10.0, subs=(1.0,), numticks=10))
            ax.xaxis.set_minor_locator(LogLocator(base=10.0, subs=np.arange(2, 10) * 0.1, numticks=10))
            ax.tick_params(which='minor', length=3)
            # The row shares its y-axis, so this must stay paired with the barplot call
            # above (left panel first, then right) to keep the resulting orientation.
            ax.invert_yaxis()

    sns.despine()
    plt.tight_layout()
    plt.savefig("all_bar_food_groupings.png", dpi=1200, bbox_inches="tight")
    plt.show()

# Draw, save (all_bar_food_groupings.png) and show the per-category figure
plot_all_top_contaminants(df_grouped)

In [None]:
contamin = pd.read_excel("positive_contaminants_per_product.xlsx", sheet_name="yearly distribution")

# Total and positive measurements per (type, year)
summ = (
    contamin.groupby(["type", "year"])
    .agg(
        total_meas=("total_measurements", "sum"),
        pos_meas=("positive_measurements", "sum"),
    )
    .reset_index()
)

# Add percentage column
summ["percentage"] = (summ["pos_meas"] / summ["total_meas"]) * 100
# Years left out of the plots (presumably implausible or sparsely populated sampling
# years — TODO confirm); the same list is applied to every measurement type
summ = summ.drop(summ[summ['year'].isin([1970, 1998, 1999, 2107])].index)

pal = sns.color_palette("tab10")


def plot_yearly_distribution(yearly, title, color, out_path):
    """Draw, save and show the two-panel yearly line plot for one measurement type.

    Parameters
    ----------
    yearly : DataFrame with columns "year", "total_meas" and "pos_meas".
    title : title of the top panel.
    color : line colour used in both panels.
    out_path : file the figure is saved to (600 dpi).

    The top panel shows total measurements, the bottom panel the measurements
    above limits; both share the x-axis.
    """
    fig, axes = plt.subplots(2, 1, figsize=(6, 6), sharex=True)

    panels = (
        ("total_meas", "Total measurements"),
        ("pos_meas", "Measurements above limits"),
    )
    for ax, (column, ylabel) in zip(axes, panels):
        ax.plot(yearly['year'], yearly[column],
                marker='o',
                linestyle='-',
                color=color,
                linewidth=2)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.grid(True, linestyle='--', alpha=.5)
        # Integer tick positions only
        ax.xaxis.set_major_locator(mticker.MaxNLocator(nbins='auto', integer=True, prune=None))
        ax.tick_params(axis='both', labelsize=11)
        # Fresh formatter per axis rather than one instance shared between figures
        sci_formatter = mticker.ScalarFormatter(useMathText=True)
        sci_formatter.set_scientific(True)
        sci_formatter.set_powerlimits((-2, 2))
        ax.yaxis.set_major_formatter(sci_formatter)

    axes[0].set_title(title, fontsize=14)
    axes[1].set_title('')
    axes[1].set_xlabel('Year', fontsize=12)

    fig.align_ylabels(axes)
    plt.setp(axes[1].get_xticklabels(), rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig(out_path, dpi=600, bbox_inches='tight')
    plt.show()


# One figure per measurement type, coloured consistently with the other plots
chem = summ[summ['type'] == 'chemical']
plot_yearly_distribution(chem, 'Chemical contaminants', pal[0], 'yearly_dist_lineplot_chem.png')

pest = summ[summ['type'] == 'pesticides']
plot_yearly_distribution(pest, 'Pesticide residues', pal[1], 'yearly_dist_lineplot_pest.png')

# Year 2107 is already removed from `summ` above, so no extra filtering is needed here
vmpr = summ[summ['type'] == 'vmpr']
plot_yearly_distribution(vmpr, 'VMPRs', pal[2], 'yearly_dist_lineplot_vmpr.png')





# Standalone legend: one colour patch per measurement type
labels = ['Chemical contaminants', 'Pesticide residues', 'VMPRs']
colors = pal[:3]

# Pair each of the first three palette colours with its label
legend_handles = [
    Patch(facecolor=swatch, edgecolor='black', label=text)
    for swatch, text in zip(colors, labels)
]

# Small figure without axes that only holds the legend
fig, ax = plt.subplots(figsize=(3, 1))
ax.axis('off')
legend = ax.legend(
    handles=legend_handles,
    loc='center',
    frameon=False,
    ncol=1,
    fontsize=12,
)

plt.savefig("GENERAL_legends.png", dpi=600, bbox_inches='tight')
plt.show()

### Total and positive (above-limit) number of samples per origin and sampling country pair (all or per year)

In [None]:
# Origin, sampling countries distribution
# Counts samples per (origin country, sampling country, year), once for all samples and once
# for samples with at least one positive measurement, and writes yearly and year-summed
# tables to original_sampling_country.xlsx.

with PostgresDatabase(os.getenv("DB_HOST"), os.getenv("DB_PORT"), os.getenv("DB_NAME"), os.getenv("DB_USERNAME"), os.getenv("DB_PASSWORD")) as db:
    # Samples grouped by original country, sampling country, and year
    query = "SELECT origcountry_id AS orig_country, sampcountry_id AS samp_country, sampy AS year, count(*) AS number_of_samples FROM efsa.sample_core GROUP BY orig_country, samp_country, year ORDER BY orig_country, samp_country, year ASC"
    df = db.querydf(query)
    # 9999 is a placeholder id for samples without an origin country; it is mapped to "unknown" below
    df.loc[df.orig_country.isna(), "orig_country"] = 9999

    # Lookup from country id to country name, extended with the placeholder id
    query = "SELECT id, termextendedname FROM ontologies_efsa.country"
    country_id_to_name = {id: name for (id, name) in db.query(query)}
    country_id_to_name[9999] = "unknown"
    
    # Replace the ids by names.
    # NOTE(review): samp_country has no placeholder handling; a NULL sampcountry_id would
    # raise KeyError in this lookup — confirm the column is never NULL.
    df.orig_country = df.orig_country.apply(lambda x: country_id_to_name[x])
    df.samp_country = df.samp_country.apply(lambda x: country_id_to_name[x])


    # Samples with at least one above legal limit measurement
    # The subquery yields each sample_core_id once (GROUP BY), so the INNER JOIN counts a
    # sample a single time however many of its measurements match.
    query = """
    SELECT t1.origcountry_id AS orig_country, t1.sampcountry_id AS samp_country, t1.sampy AS year, count(*) AS number_of_samples
    FROM efsa.sample_core AS t1
    INNER JOIN (SELECT sample_core_id FROM efsa.measurement_core WHERE evalcode_id IN (2, 9, 11, 14) GROUP BY sample_core_id) AS t2
        ON t1.id=t2.sample_core_id
    GROUP BY orig_country, samp_country, year
    ORDER BY orig_country, samp_country, year ASC
    """
    pos_df = db.querydf(query)
    pos_df.loc[pos_df.orig_country.isna(), "orig_country"] = 9999
    pos_df.orig_country = pos_df.orig_country.apply(lambda x: country_id_to_name[x])
    pos_df.samp_country = pos_df.samp_country.apply(lambda x: country_id_to_name[x])

    # Sum over years to get one row per (origin, sampling) country pair
    all_df = df.groupby(by=["orig_country", "samp_country"], as_index=False).number_of_samples.sum()
    all_pos_df = pos_df.groupby(by=["orig_country", "samp_country"], as_index=False).number_of_samples.sum()
    # The DataFrame index is written as the first column of every sheet (no index=False);
    # the plotting cells read these sheets back with index_col=0.
    with pd.ExcelWriter('original_sampling_country.xlsx') as writer:
        all_df.to_excel(writer, sheet_name='distribution')
        df.to_excel(writer, sheet_name="yearly distribution")
        all_pos_df.to_excel(writer, sheet_name='pos distribution')
        pos_df.to_excel(writer, sheet_name="pos yearly distribution")

In [None]:
# Yearly sample counts (all samples / samples above legal limits) per country pair
workbook_path = "original_sampling_country.xlsx"
data = pd.read_excel(workbook_path, sheet_name="yearly distribution", index_col=0)
positive_data = pd.read_excel(
    workbook_path, sheet_name="pos yearly distribution", index_col=0
).rename(columns={'number_of_samples': 'pos_number_of_samples'})

# Drop the excluded sampling years from both tables
excluded_years = [1970, 1998, 1999, 2107]
data = data.drop(data[data['year'].isin(excluded_years)].index)
positive_data = positive_data.drop(positive_data[positive_data['year'].isin(excluded_years)].index)

# Attach the positive counts (0 where a pair/year has none) and sum over the years,
# leaving one row per (origin, sampling) country pair
pair_keys = ['orig_country', 'samp_country']
merged_data = (
    data.merge(positive_data, on=pair_keys + ['year'], how='left')
    .fillna({'pos_number_of_samples': 0})
    .groupby(pair_keys, as_index=False)
    .agg({'number_of_samples': 'sum', 'pos_number_of_samples': 'sum'})
)

# The ratio is computed on the year-summed counts
merged_data['positive_ratio'] = merged_data['pos_number_of_samples'] / merged_data['number_of_samples']
for country_col in pair_keys:
    merged_data[country_col] = merged_data[country_col].str.title()



sample_counts = merged_data['number_of_samples']

# Parameters
top_n = 20
# NOTE(review): `ord` shadows the Python builtin of the same name for the rest of the session
ord = 'positive_ratio' #'positive_ratio' or "number_of_samples"
sample_cutoff = 5000


# Keep pairs whose origin differs from the sampling country and that have more than
# `sample_cutoff` samples, then take the `top_n` pairs ranked by the `ord` column
differs_from_sampling = merged_data['orig_country'] != merged_data['samp_country']
above_cutoff = merged_data['number_of_samples'] > sample_cutoff
filtered_data = merged_data[differs_from_sampling & above_cutoff].nlargest(top_n, ord)

# Edge table for the chord diagram, using the column names D3Blocks expects
chord_data = filtered_data[pair_keys + ['number_of_samples']].rename(
    columns={'orig_country': 'source', 'samp_country': 'target', 'number_of_samples': 'weight'}
)



tit = f"chord_diag_top{top_n}count_top{sample_cutoff}samp_order{ord}"

# Positive ratio of every edge, used afterwards to colour the edges
edge_ratios = filtered_data[pair_keys + ['positive_ratio']].rename(
    columns={'orig_country': 'source', 'samp_country': 'target'}
)

# Create chord diagram
d3 = D3Blocks()
d3.chord(
    chord_data,
    filepath=f'{tit}.html',
    title=tit,
    figsize=[600, 600],
    fontsize=12,
    arrowhead=20,
    cmap='tab20',  # This applies to nodes only
    notebook=False
)

def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=256):
    """Build a colormap covering only the [minval, maxval] slice of ``cmap``.

    ``cmap`` is sampled at ``n`` evenly spaced positions between ``minval`` and
    ``maxval``; the sampled colours are handed to
    ``LinearSegmentedColormap.from_list`` to create the new colormap, whose name
    records the source colormap and the two bounds.
    """
    sample_positions = np.linspace(minval, maxval, n)
    sampled_colors = cmap(sample_positions)
    truncated_name = f'trunc({cmap.name},{minval:.2f},{maxval:.2f})'
    return LinearSegmentedColormap.from_list(truncated_name, sampled_colors)

# Edge colour scale: the 'Reds' colormap with its palest 15% cut off, normalised to
# the observed range of positive ratios.
# plt.get_cmap is used because matplotlib.cm.get_cmap was deprecated in
# Matplotlib 3.7 and removed in 3.9.
original_cmap = plt.get_cmap('Reds')
cmap = truncate_colormap(original_cmap, 0.15, 1.0)
norm = mcolors.Normalize(vmin=edge_ratios['positive_ratio'].min(), vmax=edge_ratios['positive_ratio'].max())

# Apply color to each edge based on positive_ratio.
# itertuples(name=None) yields plain (source, target, positive_ratio) tuples in column order.
for source, target, ratio in edge_ratios.itertuples(index=False, name=None):
    color = mcolors.to_hex(cmap(norm(ratio)))  # Convert RGBA to HEX

    d3.edge_properties.loc[
        (d3.edge_properties['source'] == source) &
        (d3.edge_properties['target'] == target),
        'color'
    ] = color

# Show diagram again so the recoloured edges are rendered
d3.show()

# %%
# Stand-alone colour-bar legend for the edge colours of the chord diagram:
# a narrow figure whose single axes is used entirely by the colour bar.
fig, ax = plt.subplots(figsize=(0.5, 4))
fig.subplots_adjust(left=0.4, right=0.9)

# The mappable only carries the colour scale (cmap + norm); it has no data attached.
sm = cm.ScalarMappable(norm=norm, cmap=cmap)
sm.set_array([])

# Ratios are fractions in [0, 1]; show them as percentages with one decimal.
percent_formatter = mticker.PercentFormatter(xmax=1.0, decimals=1)
legend_path = f'{tit}_legend.png'

# Draw the colour bar into the prepared axes rather than letting Matplotlib steal space
cbar = fig.colorbar(sm, cax=ax, orientation='vertical')
cbar.ax.tick_params(labelsize=11)
cbar.ax.yaxis.set_major_formatter(percent_formatter)
cbar.set_label('Sample above legal limits', fontsize=12, labelpad=15)

# Save before showing
fig.savefig(legend_path, dpi=200, bbox_inches='tight')
plt.show()

In [None]:
# Load the per-year sample counts and per-year positive counts for each
# (origin country, sampling country) pair.
source_file = "original_sampling_country.xlsx"
data = pd.read_excel(source_file, sheet_name="yearly distribution", index_col=0)
positive_data = pd.read_excel(source_file, sheet_name="pos yearly distribution", index_col=0)
positive_data = positive_data.rename(columns={'number_of_samples': 'pos_number_of_samples'})

# Years excluded from this figure (2107 is not a plausible sampling year; the
# reason for excluding the other years is not documented here - TODO confirm).
excluded_years = [1970, 1998, 1999, 2107]

# Filter with a boolean mask instead of drop(<matching index labels>): dropping by
# label would also remove unrelated rows whenever index labels are duplicated.
data = data[~data['year'].isin(excluded_years)]
positive_data = positive_data[~positive_data['year'].isin(excluded_years)]

# Keep only the most recent years. The same constant feeds the title/file name
# below, so the label cannot drift away from the data that is actually plotted.
n_latest_years = 5
latest_years = sorted(data['year'].unique())[-n_latest_years:]
data = data[data['year'].isin(latest_years)]
positive_data = positive_data[positive_data['year'].isin(latest_years)]

# Left-join the positive counts onto the totals; rows without a matching positive
# record get a positive count of 0.
merged_data = data.merge(positive_data, on=['orig_country', 'samp_country', 'year'], how='left')
merged_data['pos_number_of_samples'] = merged_data['pos_number_of_samples'].fillna(0)

# Sum across years for same (origin, sampling) country pairs
merged_data = merged_data.groupby(['orig_country', 'samp_country'], as_index=False).agg({
    'number_of_samples': 'sum',
    'pos_number_of_samples': 'sum'
})

# Recalculate positive ratio after aggregation (ratio of sums, not mean of yearly ratios)
merged_data['positive_ratio'] = merged_data['pos_number_of_samples'] / merged_data['number_of_samples']

merged_data['orig_country'] = merged_data['orig_country'].str.title()
merged_data['samp_country'] = merged_data['samp_country'].str.title()

# Not referenced again in the visible part of this cell.
sample_counts = merged_data['number_of_samples']

# Parameters
sample_cutoff = 100            # keep only pairs with more than this many samples
top_n = 20                     # number of country pairs (edges) to draw
order_by = 'positive_ratio'    # ranking column: 'positive_ratio' or 'number_of_samples'

# Keep cross-border pairs above the sample cutoff, then the top_n by the ranking column
filtered_data = merged_data[
    (merged_data['orig_country'] != merged_data['samp_country']) &
    (merged_data['number_of_samples'] > sample_cutoff)
]
filtered_data = filtered_data.nlargest(top_n, order_by)

# Prepare chord diagram data using the source/target/weight column names passed to D3Blocks
chord_data = filtered_data[['orig_country', 'samp_country', 'number_of_samples']].copy()
chord_data.columns = ['source', 'target', 'weight']

# Used both as the chart title and as the stem of the output file names.
# The year count comes from n_latest_years (it previously read "Last3Years"
# although five years were kept).
tit = f"Last{n_latest_years}Years_chord_diag_top{top_n}count_top{sample_cutoff}samp_order{order_by}"

# Store positive ratios for each edge (used afterwards to colour the edges)
edge_ratios = filtered_data[['orig_country', 'samp_country', 'positive_ratio']].copy()
edge_ratios.columns = ['source', 'target', 'positive_ratio']

# Create chord diagram
d3 = D3Blocks()
d3.chord(
    chord_data,
    filepath=f'{tit}.html',
    title= tit,
    figsize = [600,600],
    fontsize=12,
    arrowhead=20,
    cmap='tab20',  # This applies to nodes only
    notebook=False
)

def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=256):
    """Build a colormap covering only the [minval, maxval] slice of ``cmap``.

    ``cmap`` is sampled at ``n`` evenly spaced positions between ``minval`` and
    ``maxval``; the sampled colours are handed to
    ``LinearSegmentedColormap.from_list`` to create the new colormap, whose name
    records the source colormap and the two bounds.
    """
    sample_positions = np.linspace(minval, maxval, n)
    sampled_colors = cmap(sample_positions)
    truncated_name = f'trunc({cmap.name},{minval:.2f},{maxval:.2f})'
    return LinearSegmentedColormap.from_list(truncated_name, sampled_colors)

# Edge colour scale: the 'Reds' colormap with its palest 15% cut off, normalised to
# the observed range of positive ratios.
# plt.get_cmap is used because matplotlib.cm.get_cmap was deprecated in
# Matplotlib 3.7 and removed in 3.9.
original_cmap = plt.get_cmap('Reds')
cmap = truncate_colormap(original_cmap, 0.15, 1.0)
norm = mcolors.Normalize(vmin=edge_ratios['positive_ratio'].min(), vmax=edge_ratios['positive_ratio'].max())

# Apply color to each edge based on positive_ratio.
# itertuples(name=None) yields plain (source, target, positive_ratio) tuples in column order.
for source, target, ratio in edge_ratios.itertuples(index=False, name=None):
    color = mcolors.to_hex(cmap(norm(ratio)))  # Convert RGBA to HEX

    d3.edge_properties.loc[
        (d3.edge_properties['source'] == source) &
        (d3.edge_properties['target'] == target),
        'color'
    ] = color

# Show diagram again so the recoloured edges are rendered
d3.show()


# Stand-alone colour-bar legend for the edge colours of the chord diagram:
# a narrow figure whose single axes is used entirely by the colour bar.
fig, ax = plt.subplots(figsize=(0.5, 4))
fig.subplots_adjust(left=0.4, right=0.9)

# The mappable only carries the colour scale (cmap + norm); it has no data attached.
sm = cm.ScalarMappable(norm=norm, cmap=cmap)
sm.set_array([])

# Ratios are fractions in [0, 1]; show them as percentages with one decimal.
percent_formatter = mticker.PercentFormatter(xmax=1.0, decimals=1)
legend_path = f'{tit}_legend.png'

# Draw the colour bar into the prepared axes rather than letting Matplotlib steal space
cbar = fig.colorbar(sm, cax=ax, orientation='vertical')
cbar.ax.tick_params(labelsize=11)
cbar.ax.yaxis.set_major_formatter(percent_formatter)
cbar.set_label('Sample above legal limits', fontsize=12, labelpad=15)

# Save before showing
fig.savefig(legend_path, dpi=200, bbox_inches='tight')
plt.show()



# Share of samples whose origin country is 'unknown', per year, and the sampling
# countries that reported those samples.
source_file = "original_sampling_country.xlsx"
data = pd.read_excel(source_file, sheet_name="yearly distribution", index_col=0)
positive_data = pd.read_excel(source_file, sheet_name="pos yearly distribution", index_col=0)
positive_data = positive_data.rename(columns={'number_of_samples': 'pos_number_of_samples'})

# Remove the implausible year 2107. A boolean mask is used instead of
# drop(<matching index labels>), which would also remove unrelated rows if index
# labels were duplicated.
data = data[data['year'] != 2107]

# Left-join the positive counts onto the totals; rows without a matching positive
# record get a positive count of 0.
merged_data = data.merge(positive_data, on=['orig_country', 'samp_country', 'year'], how='left')
merged_data['pos_number_of_samples'] = merged_data['pos_number_of_samples'].fillna(0)

# Total samples per year (all remaining years, all origins)
total_samples_per_year = (
    merged_data.groupby('year')['number_of_samples']
    .sum()
    .reset_index()
    .rename(columns={'number_of_samples': 'total_samples'})
)

# Rows whose origin country is 'unknown' (case-insensitive match)
unknown_origin = merged_data[merged_data['orig_country'].str.lower() == 'unknown']

# Unknown-origin samples per year
unknown_samples_per_year = (
    unknown_origin.groupby('year')['number_of_samples']
    .sum()
    .reset_index()
    .rename(columns={'number_of_samples': 'unknown_samples'})
)

# Years without any unknown-origin sample are kept and counted as 0
percentages = total_samples_per_year.merge(unknown_samples_per_year, on='year', how='left')
percentages['unknown_samples'] = percentages['unknown_samples'].fillna(0)

# Percentage of each year's samples that have unknown origin
percentages['unknown_percentage'] = (percentages['unknown_samples'] / percentages['total_samples']) * 100

print(percentages)

# Unknown-origin samples per sampling country, largest first
targets = (
    unknown_origin.groupby('samp_country')['number_of_samples']
    .sum()
    .reset_index()
    .sort_values(by='number_of_samples', ascending=False)
)

print(targets)

### Sample distribution with respect to the number of positive measurements and the total number of measurements

In [None]:
# Distribution of samples according to (1) their number of positive measurements
# and (2) their total number of measurements. Each result is written to its own
# Excel file, which the plotting cell reads back.
#
# The derived tables in FROM carry an explicit alias (per_sample): PostgreSQL
# releases before 16 reject an unaliased sub-SELECT in FROM with
# "subquery in FROM must have an alias".
with PostgresDatabase(os.getenv("DB_HOST"), os.getenv("DB_PORT"), os.getenv("DB_NAME"), os.getenv("DB_USERNAME"), os.getenv("DB_PASSWORD")) as db:
    # (1) Measurements with evalcode_id 2, 9, 11 or 14 are the ones counted as
    # positive here (meaning of the codes is defined outside this cell).
    query = """
            SELECT positive_measurements, count(*) AS num_samples
            FROM
            (
                SELECT sample_core_id, count(*) as positive_measurements
                FROM (SELECT sample_core_id FROM efsa.measurement_core WHERE evalcode_id IN (2, 9, 11, 14)) AS t1
                LEFT JOIN efsa.sample_core AS t2
                    ON t1.sample_core_id=t2.id
                GROUP BY t1.sample_core_id
            ) AS per_sample
            GROUP BY positive_measurements
            ORDER BY positive_measurements
            """
    df = db.querydf(query)
    df.to_excel("sample_distribution_wrt_positive_measurements.xlsx", index=False)

    # (2) Same aggregation over all measurements, regardless of evalcode_id.
    query = """
            SELECT total_measurements, count(*) AS num_samples
            FROM
            (
                SELECT sample_core_id, count(*) as total_measurements
                FROM (SELECT sample_core_id FROM efsa.measurement_core) AS t1
                LEFT JOIN efsa.sample_core AS t2
                    ON t1.sample_core_id=t2.id
                GROUP BY t1.sample_core_id
            ) AS per_sample
            GROUP BY total_measurements
            ORDER BY total_measurements
            """
    df = db.querydf(query)
    df.to_excel("sample_distribution_wrt_total_measurements.xlsx", index=False)

In [None]:
# Scatter plot: for each measurement count (x), how many samples have exactly
# that many measurements (y, log scale). Saved as "samples_vs_measurements.png".

# Load the two distributions written by the query cell.
# `data` has the columns total_measurements / num_samples used below.
data = pd.read_excel("sample_distribution_wrt_total_measurements.xlsx")
# NOTE(review): positive_data is loaded here but not used in the visible part of this cell.
positive_data = pd.read_excel("sample_distribution_wrt_positive_measurements.xlsx")


# Set Seaborn style
sns.set(style="whitegrid")

# Create square plot
fig, ax = plt.subplots(figsize=(6, 6))  # Square: width == height

# Scatter plot of the total-measurement distribution, drawn on the axes created above
sns.scatterplot(
    x="total_measurements",
    y="num_samples",
    data=data,
    color="tab:blue",
    s=15,
    alpha=0.8,
    ax=ax
)


# Log scale for y-axis
ax.set_yscale("log")

# Titles and labels
ax.set_title("Relationship Between Total Measurements and Number of Samples", fontsize=14)
ax.set_xlabel("Total Measurements", fontsize=12)
ax.set_ylabel("Number of Samples (Log Scale)", fontsize=12)

plt.tight_layout()
# Aspect = 1 / (data y-range : x-range ratio), intended to keep the plotting area
# square. NOTE(review): with a log y-axis, confirm the result on the installed Matplotlib.
ax.set_aspect(1.0 / ax.get_data_ratio(), adjustable='box')
plt.savefig("samples_vs_measurements.png", dpi=600, bbox_inches='tight')  # Save before plt.show()
plt.show()