<div>
<img src= "ERMA_workflow.png" width="1000"/>
</div>

In [None]:
# 1. Step similarity search
# Input: fasta files
# Output: tabular similarity search results
# commands:
# usearch vs silva_v138.2 database (510495 reads): usearch -usearch_local {input.fasta} -db {input.silva} -blast6out {output.silva_results} -evalue 1e-5 -threads {params.internal_threads} -strand plus -mincols 200 2> {log}
# diamond vs card_v3.3.0 database (4840 reads): diamond blastx -d {input.card} -q {input.fasta} -o {output.card_results} --outfmt 6 --evalue 1e-5 --quiet --threads {params.internal_threads} 2> {log}
# Notes: Many rules that prepare the similarity search are reproduced with simple bash commands

import subprocess
import pathlib, os
from pathlib import Path
import shutil
from IPython.display import display, Markdown

# === Paths ===
# All locations are derived from the parent of the current working directory.
base = pathlib.Path().resolve().parent
github = base / ".github"
test_out = base / ".test_steps"
result_dir = test_out / "results"

# Input data directories
data_dir = github / "data"
silva_dir = data_dir / "silva_db"
card_dir = data_dir / "card_db"
fastq_dir = data_dir / "fastq"

# Reads: the FASTA name is the FASTQ name with its last suffix (".gz") swapped
fastq = fastq_dir / "test_epic_data.fastq.gz"
fasta = fastq.with_suffix(".fasta")

# SILVA: compressed archive, decompressed file, and the file written by `seqtk seq -r`
silva_gz = silva_dir / "sub_silva_seq_RNA.fasta.gz"
silva_fa = silva_gz.with_suffix("")
translated_silva = silva_fa.with_name(silva_fa.name.replace("_RNA", ""))

# CARD: archive, extracted protein FASTA, and the diamond database built from it
card_tar = card_dir / "card_seq.tar.bz2"
card_fasta = card_dir / "protein_fasta_protein_homolog_model.fasta"
card_db = card_dir / "card_db.dmnd"

# Search outputs
card_results = result_dir / "card_results.txt"
silva_results = result_dir / "SILVA_results.txt"

# === Utils ===
def run(cmd, silent=False):
    """Run ``cmd`` through the shell and raise if it exits non-zero.

    Args:
        cmd: Command line, passed verbatim to the shell.
        silent: When true, stdout and stderr of the command are discarded.

    Raises:
        RuntimeError: If the command's exit status is not zero.
    """
    # None means "inherit the notebook's streams".
    sink = subprocess.DEVNULL if silent else None
    completed = subprocess.run(cmd, shell=True, stdout=sink, stderr=sink)
    if completed.returncode:
        raise RuntimeError(f"Command failed: {cmd}")

def count_lines(file, pattern=None):
    """Count lines in ``file``, optionally only those matching ``pattern``.

    The count is done in Python instead of shelling out to ``grep``/``wc``.
    ``grep -c`` exits with status 1 when nothing matches, which made the
    previous implementation raise instead of returning 0, and unquoted paths
    containing spaces broke the shell command.

    Args:
        file: Path of the file to inspect.
        pattern: Regular expression searched for in every line (interpreted
            with Python's ``re`` syntax). When empty or ``None`` all lines
            are counted.

    Returns:
        int: Number of matching lines, or, without a pattern, the number of
        newline characters (the same definition ``wc -l`` uses).
    """
    import re  # local import so the cell's import block stays untouched

    with open(file, "rb") as handle:
        if not pattern:
            # Read in 1 MiB chunks so large result files are never held in memory.
            return sum(
                chunk.count(b"\n")
                for chunk in iter(lambda: handle.read(1 << 20), b"")
            )
        matcher = re.compile(pattern.encode())
        return sum(1 for line in handle if matcher.search(line))

def clean(folder, keep):
    """Delete every direct child of ``folder`` whose name is not in ``keep``.

    Args:
        folder: Directory to tidy up.
        keep: Container of entry names (not paths) that must survive.

    Regular files are unlinked and directories are removed recursively;
    entries that are neither are left alone.
    """
    for entry in Path(folder).iterdir():
        if entry.name in keep:
            continue
        if entry.is_file():
            entry.unlink()
        elif entry.is_dir():
            shutil.rmtree(entry)

# === Prepare and run similarity search ===
run(f"mkdir -p {result_dir}")
run(f"seqtk seq -a {fastq} > {fasta}")
run(f"gzip -dk {silva_gz}")
run(f"seqtk seq -r {silva_fa} > {translated_silva}")
run(f"tar -xjf {card_tar} -C {card_dir}")
run(f"diamond makedb --in {card_fasta} -d {card_db.with_suffix('')}")
run(f"diamond blastx -d {card_db} -q {fasta} -o {card_results} --outfmt 6 --evalue 1e-5 --threads 1 --quiet")
# silent=True already discards stdout/stderr, so no shell redirection is needed.
run(f"usearch -usearch_local {fasta} -db {translated_silva} -blast6out {silva_results} -evalue 1e-5 -threads 1 -strand plus -mincols 200", silent=True)

# === Summary ===
# Count once, before cleanup deletes the FASTA, and reuse the numbers in the report.
fastq_reads = count_lines(fasta, '^>')
card_hits = count_lines(card_results)
silva_hits = count_lines(silva_results)

# Every row carries all three fields, in the order announced by the header.
print("\nsample,state,total_count")
print(f"test,Number of FastQ input reads,{fastq_reads}")
print(f"test,Diamond output hits,{card_hits}")
print(f"test,Usearch output hits,{silva_hits}")

# === Cleanup ===
# Keep only the CARD archive; drop the derived FASTA files next to the inputs.
clean(card_dir, {card_tar.name})
for f in fastq_dir.glob("*.fasta"):
    f.unlink()
for f in silva_dir.glob("*.fasta"):
    f.unlink()

# === Report ===
display(Markdown(f"### Processing Complete\n- CARD hits: `{card_hits}`\n- SILVA hits: `{silva_hits}`"))

In [None]:
# 2. Integrate similarity search results
# Self-written Python script "integrate_blast_data.py"
# Input: diamond, usearch results, ARO Mapping file
# Output: Processed integrated search results

import pandas as pd
import concurrent.futures
import os, pathlib, subprocess

# === Paths ===
# All locations are derived from the parent of the current working directory.
base = pathlib.Path().resolve().parent
github = base / ".github"

# The log directory is created right away if it does not exist yet.
log_dir = base / "logs"
log_dir.mkdir(exist_ok=True)

silva_dir = github / "data" / "silva_db"
card_dir = github / "data" / "card_db"
result_dir = base / ".test_steps" / "results"

# Inputs: raw search results and the ARO index packed inside the CARD archive
silva_res = result_dir / "SILVA_results.txt"
card_res = result_dir / "card_results.txt"
aro_file = "aro_index.tsv"
aro_path = card_dir / aro_file
aro_tar = card_dir / "card_seq.tar.bz2"

# Outputs: per-database intermediates and the integrated table
card_interm = result_dir / "card_intermed.csv"
silva_interm = result_dir / "silva_intermed.csv"
result = result_dir / "integrated_result.csv"

# === Utils ===
def run(cmd):
    """Execute ``cmd`` via the shell; raise ``RuntimeError`` on a non-zero exit.

    Output of the command goes to the notebook's own stdout/stderr.
    """
    exit_code = subprocess.run(cmd, shell=True).returncode
    if exit_code != 0:
        raise RuntimeError(f"Command failed: {cmd}")

# === Extract Aro ===
# Pull only the ARO index out of the CARD archive, then move it next to the archive.
# `&&` (instead of `;`) makes sure `mv` only runs after a successful extraction, so a
# failed `tar` can no longer be hidden by `mv` picking up a stale file from an earlier run.
run(f"tar -xvjf {aro_tar} ./{aro_file} && mv {aro_file} {card_dir}")

# === Integrate Script ===
def process_card_results(card_path, aro_path, blast_columns, output_path):
    """Annotate CARD (diamond) hits with ARO metadata and write them as CSV.

    Args:
        card_path: Tab-separated, header-less diamond result file.
        aro_path: Tab-separated ARO index containing an "ARO Accession" column.
        blast_columns: Column names assigned to the diamond result fields.
        output_path: Destination of the annotated CSV.

    The output file is only opened once both inputs have been read, so a
    failure while reading no longer leaves an empty output file behind.
    """
    aro_df = pd.read_csv(aro_path, sep="\t")
    card_df = pd.read_csv(card_path, sep="\t", names=blast_columns)
    card_df["part"] = "ABR"

    # Subject ids look like "...|...|ACCESSION|...": take the third field.
    # `.str[2]` yields NaN when the field is missing and also works when there
    # are no hits at all, whereas `split(expand=True)[2]` raised a KeyError in
    # both situations. The cast keeps an all-NaN column mergeable with the
    # string accessions of the ARO index.
    card_df["ARO Accession"] = (
        card_df["subject_id"].str.split("|").str[2].astype(object)
    )

    merged_df = card_df.merge(aro_df, on="ARO Accession", how="left")
    merged_df.to_csv(output_path, index=False)


def process_silva_results(silva_path, blast_columns, output_path):
    """Annotate SILVA (usearch) hits with accession and genus and write them as CSV.

    Args:
        silva_path: Tab-separated, header-less usearch result file.
        blast_columns: Column names assigned to the result fields.
        output_path: Destination of the annotated CSV.

    The output file is only opened once the input has been read.
    """
    silva_df = pd.read_csv(silva_path, sep="\t", names=blast_columns)
    silva_df["part"] = "16S"

    subject = silva_df["subject_id"]
    # Primary accession: everything before the first '.'. Using `.str[0]`
    # instead of `split(expand=True)[0]` keeps this working when the result
    # file contains no hits (the expanded frame then has no column 0).
    silva_df["primaryAccession"] = subject.str.split(".").str[0]
    # Genus: second-to-last ';'-separated field of the subject id.
    silva_df["genus"] = subject.str.split(";").str[-2]

    silva_df.to_csv(output_path, index=False)


def merge_results(card_output, silva_output, final_output):
    """Stack the processed SILVA and CARD tables and write them to one CSV.

    Args:
        card_output: CSV produced for the CARD hits.
        silva_output: CSV produced for the SILVA hits.
        final_output: Destination of the combined CSV (SILVA rows first).

    The number of combined rows is printed as "Merged similarity hits,<n>".
    """
    card_hits = pd.read_csv(card_output)
    silva_hits = pd.read_csv(silva_output)

    stacked = pd.concat([silva_hits, card_hits])
    stacked.to_csv(final_output, index=False)

    print(f"Merged similarity hits,{len(stacked)}\n")

# Names for the 12 fields of the tabular search output (diamond --outfmt 6 / usearch -blast6out).
blast_columns = [
    "query_id",
    "subject_id",
    "perc_identity",
    "align_length",
    "mismatches",
    "gap_opens",
    "q_start",
    "q_end",
    "s_start",
    "s_end",
    "evalue",
    "bit_score",
]

# The two result files are independent of each other, so they are processed in parallel threads.
with concurrent.futures.ThreadPoolExecutor() as executor:
    future_card = executor.submit(
        process_card_results, card_res, aro_path, blast_columns, card_interm
    )
    future_silva = executor.submit(
        process_silva_results, silva_res, blast_columns, silva_interm
    )

    # .result() re-raises any exception raised inside the worker.
    future_card.result()
    future_silva.result()

merge_results(card_interm, silva_interm, result)

# === Cleanup ===
# A single `rm` call attempts every path and reports any failure through its exit
# status; with `rm a; rm b` a failure of the first command went unnoticed.
run(f"rm {result_dir}/*intermed* {card_dir}/aro*")

In [None]:
# 3. Filter BLAST results
# Self-written Python script "filter_blast_results.py"
# Input: integrated_result.csv
# Output: filtered_result.csv

import pandas as pd
import os, pathlib, subprocess

# === Paths ===
# Results live in <parent of the working directory>/.test_steps/results.
base = pathlib.Path().resolve().parent
result_dir = base.joinpath(".test_steps", "results")

overview_table = result_dir.joinpath("overview_table.txt")    # summary statistics
merge_result = result_dir.joinpath("integrated_result.csv")   # input of this step
filter_result = result_dir.joinpath("filtered_result.csv")    # output of this step

# === Utils ===
def run(cmd):
    """Run a shell command line and raise ``RuntimeError`` unless it exits with 0."""
    if subprocess.run(cmd, shell=True).returncode != 0:
        raise RuntimeError(f"Command failed: {cmd}")
    
# === Filter Script ===
# Columns read from the integrated table, each with the dtype it is parsed as.
dtype_dict = dict(
    [
        ("query_id", "string"),
        ("perc_identity", "float"),
        ("align_length", "int"),
        ("evalue", "float"),
        ("part", "string"),
        ("genus", "string"),
        ("AMR Gene Family", "string"),
    ]
)


def read_input_data(input_file):
    """Read the integrated CSV, restricted to the columns listed in ``dtype_dict``.

    Each of those columns is parsed with the dtype given for it there.
    """
    wanted_columns = dtype_dict.keys()
    return pd.read_csv(
        input_file,
        sep=",",
        usecols=wanted_columns,
        dtype=dtype_dict,
    )


def filter_by_identity(df, part, min_similarity):
    """Select the hits of one part that exceed the identity cut-off.

    Args:
        df: Table with "part" and "perc_identity" columns.
        part: Value of the "part" column to keep (e.g. "ABR" or "16S").
        min_similarity: Cut-off as a fraction; it is compared against the
            percentage column after multiplying by 100, strictly greater-than.

    Returns:
        tuple: (rows of ``part`` above the cut-off, number of rows of
        ``part`` that were dropped).
    """
    part_hits = df.loc[df["part"] == part]
    cutoff = min_similarity * 100
    passing = part_hits.loc[part_hits["perc_identity"] > cutoff]
    return passing, len(part_hits) - len(passing)


def keep_max_identity_per_query(df):
    """For each query_id, keep only rows with the highest percent identity"""
    max_identities = df.groupby("query_id")["perc_identity"].max().reset_index()
    merged = df.merge(max_identities, on=["query_id", "perc_identity"])
    return merged

def keep_best_per_query(df):
    """Reduce the table to one row per query_id.

    Rows are ordered by query_id, then by descending perc_identity, then by
    ascending evalue; the first row of every query_id is kept.
    """
    sort_keys = ["query_id", "perc_identity", "evalue"]
    ranked = df.sort_values(by=sort_keys, ascending=[True, False, True])
    return ranked.drop_duplicates(subset="query_id", keep="first")

def clean_16s_query_ids(df):
    """Return a copy of ``df`` whose query_id is cut at the first whitespace.

    The input frame is left untouched. It is typically a filtered slice of a
    larger table, and assigning into it in place triggers pandas'
    SettingWithCopyWarning and silently modifies the caller's data.

    Args:
        df: Table with a string "query_id" column.

    Returns:
        pandas.DataFrame: New frame with the shortened query ids.
    """
    return df.assign(query_id=df["query_id"].str.split().str[0])


def merge_parts_on_query_id(abr_data, s16_data):
    """Restrict both tables to the query_ids they have in common.

    Returns:
        tuple: (ABR rows with a shared query_id, 16S rows with a shared query_id).
    """
    shared_ids = pd.Index(abr_data["query_id"]).intersection(s16_data["query_id"])
    abr_shared = abr_data["query_id"].isin(shared_ids)
    s16_shared = s16_data["query_id"].isin(shared_ids)
    return abr_data[abr_shared], s16_data[s16_shared]

def write_summary(sample, stats):
    """Print the filtering statistics and store them if no overview table exists yet.

    Args:
        sample: Sample name used as the first field of every row.
        stats: Mapping of statistic name to value.

    Each row has the form "<sample>,<statistic>,<value>".
    """
    rows = [f"{sample},{stat_name},{value}" for stat_name, value in stats.items()]

    if not overview_table.is_file():
        with open(overview_table, "a") as file:
            for row in rows:
                file.write(row + "\n")
                print(row)
        return

    # NOTE(review): an existing overview table is never appended to - the rows are
    # only printed. Confirm this is intended.
    for row in rows:
        print(row)

def rename_for_merge(df,part):
    """Return ``df`` with its hit-metric columns suffixed by ``_<part>``.

    "perc_identity", "align_length" and "evalue" are renamed so that the ABR
    and 16S tables can sit side by side after a merge; other columns keep
    their names.
    """
    metric_columns = ("perc_identity", "align_length", "evalue")
    suffixed = {column: column + "_" + part for column in metric_columns}
    return df.rename(columns=suffixed)

def filter_blast_results(input_file, output_file, min_similarity):
    """Filter the integrated hits and pair ABR with 16S hits per read.

    Args:
        input_file: Integrated CSV containing both "ABR" and "16S" rows.
        output_file: Destination CSV with one row per read hit in both parts.
        min_similarity: Identity cut-off as a fraction (e.g. 0.8).

    For each part the hits above the cut-off are reduced to a single best hit
    per query_id. The two parts are then inner-joined on query_id, the result
    is written to ``output_file`` and the bookkeeping counts are handed to
    ``write_summary``.
    """
    hits = read_input_data(input_file)

    # --- ABR part ---
    abr_passed, abr_below_cutoff = filter_by_identity(hits, "ABR", min_similarity)
    abr_final = rename_for_merge(
        keep_best_per_query(keep_max_identity_per_query(abr_passed)), "ABR"
    )
    abr_not_best = len(abr_passed) - len(abr_final)

    # --- 16S part (query ids are trimmed before picking the best hit) ---
    s16_passed, s16_below_cutoff = filter_by_identity(hits, "16S", min_similarity)
    s16_trimmed = clean_16s_query_ids(s16_passed)
    s16_final = rename_for_merge(
        keep_best_per_query(keep_max_identity_per_query(s16_trimmed)), "16S"
    )
    s16_not_best = len(s16_passed) - len(s16_final)

    # Only the ABR side of the shared rows is subtracted, so
    # (ABR rows + 16S rows) minus this count equals the number of shared ABR rows.
    abr_shared, _s16_shared = merge_parts_on_query_id(abr_final, s16_final)
    unpaired = (len(abr_final) + len(s16_final)) - len(abr_shared)

    # --- Side-by-side table, one row per query_id present in both parts ---
    abr_columns = ["query_id", "AMR Gene Family", "perc_identity_ABR", "align_length_ABR", "evalue_ABR"]
    s16_columns = ["query_id", "genus", "perc_identity_16S", "align_length_16S", "evalue_16S"]
    fused = pd.merge(abr_final[abr_columns], s16_final[s16_columns], on="query_id", how="inner")
    fused.to_csv(output_file, index=False)

    # The sample name is fixed for the test data set.
    sample = "test_epic_data"

    stats = {
        "Diamond hits < similarity threshold": f"-{abr_below_cutoff}",
        "Diamond hits NOT highest percentage identity per query": f"-{abr_not_best}",
        "Usearch hits < similarity threshold": f"-{s16_below_cutoff}",
        "Usearch hits NOT highest percentage identity per query": f"-{s16_not_best}",
        "Query hit in only one of two databases": f"-{unpaired}",
        "Filtered fusion reads": len(fused),
    }
    write_summary(sample, stats)

# Run the filter step on the integrated table with an 80 % identity cut-off.
filter_blast_results(
    input_file=merge_result,
    output_file=filter_result,
    min_similarity=0.8,
)


In [None]:
# 4. Create abundance table
# Self-written Python script "generate_genus_distribution_table.py"
# Input: all filtered_result.csv parts of one sample
# Output: genus abundance table per AMR Gene Family (genera_abundance.csv)

import pandas as pd
import os, pathlib

# === Paths ===
# Results live in <parent of the working directory>/.test_steps/results.
base = pathlib.Path().resolve().parents[0]
result_dir = base / ".test_steps" / "results"

filter_result, abundance_result = (
    result_dir / "filtered_result.csv",   # input of this step
    result_dir / "genera_abundance.csv",  # output of this step
)

# === Abundance Table Script ===

def process_combined_data(combined_data, sample_name):
    """Count genera per AMR Gene Family and derive their relative abundance.

    Args:
        combined_data: Table with "AMR Gene Family" and "genus" columns, one
            row per read. The frame is not modified.
        sample_name: Value written to the "sample" column of the result.

    Returns:
        pandas.DataFrame: One row per (sample, AMR Gene Family, genus) with
        "genus_count", the family's "total_count" and "relative_genus_count"
        (their ratio, rounded to 4 decimals). A frame without any columns -
        what ``pd.DataFrame()`` yields when no input could be read - results
        in an empty table with these columns instead of a KeyError.
    """
    family_keys = ["sample", "AMR Gene Family"]

    if combined_data.columns.empty:
        return pd.DataFrame(
            columns=family_keys
            + ["genus", "genus_count", "total_count", "relative_genus_count"]
        )

    # assign() returns a new frame, so the caller's DataFrame stays untouched.
    labelled = combined_data.assign(sample=sample_name)

    # Number of rows per genus within each AMR Gene Family
    genus_counts = (
        labelled.groupby(family_keys + ["genus"])
        .size()
        .reset_index(name="genus_count")
    )

    # Number of rows per AMR Gene Family
    total_counts = (
        genus_counts.groupby(family_keys)["genus_count"]
        .sum()
        .reset_index(name="total_count")
    )

    result = pd.merge(genus_counts, total_counts, on=family_keys)
    result["relative_genus_count"] = (
        result["genus_count"] / result["total_count"]
    ).round(4)
    return result

def load_and_merge_parts(file_list):
    """Read every CSV in ``file_list`` and stack the tables.

    Files that cannot be read are reported on stdout and skipped
    (best effort). If nothing could be read an empty DataFrame is returned.
    """
    frames = []
    for path in file_list:
        try:
            frames.append(pd.read_csv(path))
        except Exception as e:
            print(f"Skipping file due to read error [{path}]: {repr(e)}")

    if not frames:
        return pd.DataFrame()
    return pd.concat(frames, ignore_index=True)


def export_genera_abundance(input_files, output_path):
    """Build the genus abundance table for the test sample and write it as CSV.

    Args:
        input_files: Path of the filtered result file; it is converted with
            ``str()`` and treated as a single file.
        output_path: Destination CSV.

    The table is sorted descending by sample, AMR Gene Family and genus_count,
    written without index and displayed in the notebook.
    """
    # The sample name is fixed for the test data; all inputs belong to it.
    sample_to_files = {"test_epic_data": [str(input_files)]}

    per_sample_tables = [
        process_combined_data(load_and_merge_parts(files), sample_name)
        for sample_name, files in sample_to_files.items()
    ]

    final_df = pd.concat(per_sample_tables, ignore_index=True).sort_values(
        by=["sample","AMR Gene Family","genus_count"], ascending=False
    )

    final_df.to_csv(output_path, index=False)
    display(final_df)

# Build the abundance table from the filtered hits of step 3.
export_genera_abundance(input_files=filter_result, output_path=abundance_result)


In [None]:
# 5. Create stacked bar abundance plot
# Self-written Python script "generate_genus_distribution_plot.py"
# Input: abundance file
# Output: stacked bar plot of relative genus abundance per AMR Gene Family

import os, pathlib
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# === Paths ===
# Results live in <parent of the working directory>/.test_steps/results.
base = pathlib.Path().resolve().parent
result_dir = base / ".test_steps" / "results"

abundance_result = result_dir / "genera_abundance.csv"
abundance_bar_plot = result_dir / "combined_genus_abundance_barplot.html"

# ─── Constants ─────────────────────────────────────────────────────────
AMR_MIN_FRACTION = 0.01                 # minimum share of the summed total_count for a family to be plotted
RESERVED_COLOR = 'rgb(217,217,217)'     # colour reserved for the "Others" bucket

def get_genus_colors(all_genera):
    """Map every genus to its own colour; "Others" always gets RESERVED_COLOR.

    Args:
        all_genera: Iterable of genus names, in the order colours are handed out.

    Returns:
        dict: genus name -> colour string.

    Raises:
        ValueError: If there are more genera than colours in the pool.
    """
    preferred = [
        '#D62728',  # dark red
        '#FF7F0E',  # orange
        '#8B4513',  # brown
        '#1F77B4',  # dark blue
        '#800080',  # purple
        '#7F7F7F',  # gray
        '#2CA02C',  # dark green
        '#1E90FF',  # blue
        '#BA55D3',  # medium orchid
        '#BCBD22',  # yellow-green
    ]
    qualitative = px.colors.qualitative
    fallback = [
        *qualitative.Pastel,
        *qualitative.Set3,
        *qualitative.Alphabet,
        *qualitative.Light24,
        *qualitative.Bold,
    ]

    # dict.fromkeys drops duplicates but keeps first-seen order; the reserved
    # colour is taken out so that no genus can share it with "Others".
    color_pool = [
        color
        for color in dict.fromkeys(preferred + fallback)
        if color != RESERVED_COLOR
    ]

    genus_list = [genus for genus in all_genera if genus != "Others"]
    if len(genus_list) > len(color_pool):
        raise ValueError(f"Too many genera ({len(genus_list)}) for available color pool.")

    genus_colors = dict(zip(genus_list, color_pool))
    genus_colors["Others"] = RESERVED_COLOR
    return genus_colors

def preprocess_abundance(df, amr, min_genus_abundance, force_include, force_exclude):
    """Prepare the rows of one AMR Gene Family for a stacked bar plot.

    Args:
        df: Abundance table.
        amr: AMR Gene Family to select.
        min_genus_abundance: Genera at or below this relative abundance are
            pooled into "Others" unless they are force-included.
        force_include: Genera that are always shown individually.
        force_exclude: Genera that are always pooled into "Others".

    Returns:
        pandas.DataFrame: Individually shown genera (sorted by sample,
        family and descending genus_count) followed by one "Others" row per
        (sample, total_count), with a "sample_label" column of the form
        "<sample> (<total_count>)".
    """

    def with_label(frame):
        # "<sample> (<total_count>)" is the bar label of the plot.
        frame["sample_label"] = frame["sample"] + " (" + frame["total_count"].astype(str) + ")"
        return frame

    family = df[df["AMR Gene Family"] == amr].copy()

    # Rows pooled into "Others": rare and not force-included, or force-excluded
    rare = (family["relative_genus_count"] <= min_genus_abundance) & (
        ~family["genus"].isin(force_include)
    )
    pooled = family[rare | family["genus"].isin(force_exclude)]
    others = with_label(
        pooled.groupby(['sample', 'total_count'], as_index=False)
        .agg({"relative_genus_count": "sum"})
        .assign(genus="Others")
    )

    # Rows shown individually: not excluded, and abundant or force-included
    shown = family[~family["genus"].isin(force_exclude)]
    shown = shown.sort_values(
        by=['sample','AMR Gene Family','genus_count'], ascending=[True,False,False]
    )
    keep = (shown["relative_genus_count"] > min_genus_abundance) | (
        shown["genus"].isin(force_include)
    )

    return with_label(pd.concat([shown[keep], others], ignore_index=True))


def plot_stacked_abundance(
    observed_csv,
    output_html,
    min_genus_abundance,
    force_include=None,
    force_exclude=None,
    min_reads=20
):
    """Show one stacked bar subplot of relative genus abundance per AMR Gene Family.

    Args:
        observed_csv: CSV with the columns "sample", "AMR Gene Family",
            "genus", "genus_count", "total_count" and "relative_genus_count".
        output_html: Currently unused - the ``write_html`` call at the end of
            the function is commented out.
        min_genus_abundance: Threshold passed on to ``preprocess_abundance``.
        force_include: Genera passed on as always shown (default: none).
        force_exclude: Genera passed on as always pooled (default: none).
        min_reads: Rows whose total_count is not greater than this are dropped.

    Returns:
        None. The figure is shown with ``fig.show()``; when no family passes
        the abundance threshold a message is printed instead.
    """

    force_include = force_include or []
    force_exclude = force_exclude or []

    df = pd.read_csv(observed_csv)
    df = df.sort_values(["sample", "genus_count"], ascending=[True, False])
    df = df[df["total_count"] > min_reads]
    # Sum of the total_count column over all remaining rows of a family.
    # NOTE(review): if total_count repeats on every genus row of a family, this
    # sum scales with the number of genera - confirm that is intended.
    amr_totals = df.groupby("AMR Gene Family")["total_count"].sum()
    total_all = amr_totals.sum()
    # Keep families contributing at least AMR_MIN_FRACTION of the overall sum.
    amrs_to_plot = amr_totals[amr_totals >= total_all * AMR_MIN_FRACTION].index.tolist()

    if not amrs_to_plot:
        print("No AMR Gene Families meet the abundance threshold.")
        return

    df = df[df["AMR Gene Family"].isin(amrs_to_plot)]
    amrs = sorted(df["AMR Gene Family"].unique())
    samples = df["sample"].nunique()

    # One subplot row per family, titled with the family name
    fig = make_subplots(
        rows=len(amrs),
        cols=1,
        subplot_titles=amrs,
        vertical_spacing=0.15,
    )

    for i, amr in enumerate(amrs, start=1):
        df_amr = preprocess_abundance(
            df, amr, min_genus_abundance, force_include, force_exclude
        )
        # Colours are assigned per subplot from the genera of this family.
        genus_colors = get_genus_colors(df_amr["genus"].unique())

        legendgroup = f"group{i}"  # unique group per subplot
        for genus in df_amr["genus"].unique():
            genus_data = df_amr[df_amr["genus"] == genus]
            fig.add_trace(
                go.Bar(
                    x=genus_data["sample_label"],
                    y=genus_data["relative_genus_count"],
                    name=genus,
                    marker_color=genus_colors[genus],
                    legendgroup=legendgroup,
                    # Only the first genus of a subplot carries the group title.
                    legendgrouptitle=dict(text=amr) if genus == df_amr["genus"].unique()[0] else None,
                    showlegend=True,
                ),
                row=i,
                col=1,
            )

        # Custom legend positioning for each subplot (optional, only needed if separating legends visually)
        # The arguments do not depend on the loop variable, so every iteration
        # applies the same layout update.
        fig.update_layout(
            legend=dict(
                y=1,
                yanchor="top",
                x=2.5-np.log10(samples),
                xanchor="left",
                tracegroupgap=500  # adds spacing between legend groups
            ),
            margin=dict(r=300)  # enough space for long legends
        )

    # NOTE(review): `yaxis=...` presumably formats only the first subplot's y axis - confirm.
    fig.update_layout(
        barmode="stack",
        title="Relative Genus Abundance per AMR Gene Family",
        height=800 * len(amrs),
        width=1000 * np.log10(samples) if samples > 2 else 500,
        plot_bgcolor="white",
        yaxis=dict(tickformat=".0%"),
        showlegend=True,
        margin=dict(r=300),
    )
    fig.update_xaxes(tickangle=45)
    # df_amr is the frame of the last family processed in the loop above.
    # NOTE(review): the category order is built from sample labels (the x values) but
    # applied to the y axes - possibly meant for update_xaxes; confirm.
    fig.update_yaxes(title_text="Relative Abundance",categoryorder="array",categoryarray=sorted(df_amr["sample_label"].unique()))

    fig.show()
    # fig.write_html(output_html)


if __name__ == "__main__":
    # Plot the abundance table written by step 4 instead of a hard-coded absolute
    # path that only exists on one machine; both paths are defined in the
    # "Paths" section of this cell.
    input_csv = abundance_result
    # plot_stacked_abundance currently ignores this argument (its write_html call is commented out).
    output_html = abundance_bar_plot
    min_abundance = 0.01
    plot_stacked_abundance(input_csv, output_html, float(min_abundance))


In [None]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import plotly.express as px

# Colour of the pooled "Others" bucket; never handed out to a genus.
RESERVED_COLOR = "rgb(217,217,217)"

def get_genus_colors(all_genera):
    """Return a genus -> colour mapping with "Others" fixed to RESERVED_COLOR.

    Colours are handed out in the order of ``all_genera`` from a pool of ten
    hand-picked colours followed by several plotly qualitative palettes.

    Raises:
        ValueError: If there are more genera than colours in the pool.
    """
    hand_picked = (
        "#D62728", "#FF7F0E", "#8B4513", "#1F77B4", "#800080",
        "#7F7F7F", "#2CA02C", "#1E90FF", "#BA55D3", "#BCBD22",
    )
    palettes = px.colors.qualitative
    candidates = list(hand_picked)
    for palette in (
        palettes.Pastel,
        palettes.Set3,
        palettes.Alphabet,
        palettes.Light24,
        palettes.Bold,
    ):
        candidates.extend(palette)

    # De-duplicate in first-seen order, then drop the reserved colour.
    color_pool = [c for c in dict.fromkeys(candidates) if c != RESERVED_COLOR]

    genus_list = [name for name in all_genera if name != "Others"]
    if len(genus_list) > len(color_pool):
        raise ValueError(
            f"Too many genera ({len(genus_list)}) for available color pool."
        )

    genus_colors = dict(zip(genus_list, color_pool))
    genus_colors["Others"] = RESERVED_COLOR
    return genus_colors


def preprocess_abundance(df, min_genus_abundance, force_include, force_exclude):
    """Split the table into individually shown genera and a pooled "Others" bucket.

    Args:
        df: Abundance table (all rows are used).
        min_genus_abundance: Genera at or below this relative abundance are
            pooled unless force-included.
        force_include: Genera that are always shown individually.
        force_exclude: Genera that are always pooled.

    Returns:
        pandas.DataFrame: Shown genera (sorted by sample, family and
        descending genus_count) followed by one "Others" row per
        (sample, total_count); every row carries a "sample_label" of the
        form "<sample> (<total_count>)".
    """

    def label_of(frame):
        return frame["sample"] + " (" + frame["total_count"].astype(str) + ")"

    table = df.copy()

    # --- pooled rows ---
    is_rare = table["relative_genus_count"] <= min_genus_abundance
    is_included = table["genus"].isin(force_include)
    is_excluded = table["genus"].isin(force_exclude)
    pooled = table[(is_rare & ~is_included) | is_excluded]
    others = (
        pooled.groupby(["sample", "total_count"], as_index=False)
        .agg({"relative_genus_count": "sum"})
        .assign(genus="Others")
    )
    others["sample_label"] = label_of(others)

    # --- individually shown rows ---
    candidates = table[~is_excluded]
    shown = candidates[
        (candidates["relative_genus_count"] > min_genus_abundance)
        | candidates["genus"].isin(force_include)
    ].sort_values(["sample", "AMR Gene Family", "genus_count"], ascending=[True, False, False])
    shown["sample_label"] = label_of(shown)

    return pd.concat([shown, others], ignore_index=True)


def plot_stacked_abundance(
    observed_csv,
    output_html=None,
    min_genus_abundance=0.01,
    force_include=None,
    force_exclude=None,
):
    """Plot a stacked bar chart of genus abundance for one AMR Gene Family.

    The family with the largest summed total_count over the whole table is
    selected; its name is printed and its rows are displayed before plotting.

    Args:
        observed_csv: Abundance CSV.
        output_html: If truthy, the figure is written to this path;
            otherwise it is shown.
        min_genus_abundance: Threshold passed to ``preprocess_abundance``.
        force_include: Genera always shown individually (default: none).
        force_exclude: Genera always pooled into "Others" (default: none).
    """
    if not force_include:
        force_include = []
    if not force_exclude:
        force_exclude = []

    table = pd.read_csv(observed_csv)
    family_totals = table.groupby("AMR Gene Family")["total_count"].sum()
    top_family = family_totals.idxmax()
    print(top_family)

    # Rows of the selected family only
    top_rows = table[table["AMR Gene Family"] == top_family]
    display(top_rows)
    amrs = top_rows["AMR Gene Family"].unique()
    samples = top_rows["sample"].nunique()

    plot_rows = preprocess_abundance(top_rows, min_genus_abundance, force_include, force_exclude)
    genera = plot_rows["genus"].unique()
    genus_colors = get_genus_colors(genera)

    fig = go.Figure()
    # unique() yields every genus once, so each trace gets its own legend entry.
    for genus in genera:
        genus_rows = plot_rows[plot_rows["genus"] == genus]
        bar = go.Bar(
            x=genus_rows["sample_label"],
            y=genus_rows["relative_genus_count"],
            name=genus,
            marker_color=genus_colors[genus],
            showlegend=True,
        )
        fig.add_trace(bar)

    # Layout
    figure_width = 1000 * np.log10(samples) if samples > 2 else 500
    fig.update_layout(
        barmode="stack",
        title=f"Relative Genus Abundance (Most Abundant AMR Gene Family: {', '.join(amrs)})",
        height=800,
        width=figure_width,
        plot_bgcolor="white",
        legend_title="Genus",
        xaxis=dict(tickangle=45),
        yaxis=dict(title="Relative Abundance", tickformat=".0%"),
    )

    if not output_html:
        fig.show()
    else:
        fig.write_html(output_html)



if __name__ == "__main__":
    # Plot the abundance table written by step 4 (`abundance_result` is defined in the
    # preceding cells) instead of a hard-coded absolute path that only exists on one machine.
    input_csv = abundance_result
    # An empty output path makes plot_stacked_abundance show the figure instead of writing it.
    output_html = ""
    min_abundance = 0.01
    plot_stacked_abundance(input_csv, output_html, float(min_abundance))


In [None]:
# 6. Create bubble plot
# Self-written Python script "generate_genus_distribution_plot.py"
# Input: abundance file
# Output: bubble plot per sample

import pandas as pd
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.io as pio
import os, pathlib

# === Paths ===
# Results live in <parent of the working directory>/.test_steps/results.
base = pathlib.Path().resolve().parent
result_dir = base.joinpath(".test_steps").joinpath("results")

abundance_result = result_dir.joinpath("genera_abundance.csv")                # input
bubble_plot = result_dir.joinpath("combined_genus_abundance_bubbleplot.html")  # output

# === Bubble Plot Script ===
def load_filtered_data(input_csv, min_total_count=100):
    """Read the abundance CSV and keep rows whose total_count exceeds the minimum.

    Args:
        input_csv: Comma-separated abundance table with a "total_count" column.
        min_total_count: Rows must have a total_count strictly greater than this.
    """
    table = pd.read_csv(input_csv, sep=",")
    has_enough_reads = table["total_count"] > min_total_count
    return table[has_enough_reads]


def get_top_genera_per_sample(df, top_n):
    """Collect the ``top_n`` genera with the highest relative abundance per sample.

    Returns:
        tuple: (sample -> set of genera, sample -> list of genera ranked by
        descending relative_genus_count).
    """
    top_lists = {}
    for sample in df["sample"].unique():
        ranked = df[df["sample"] == sample].sort_values(
            by="relative_genus_count", ascending=False
        )
        top_lists[sample] = ranked.head(top_n)["genus"].tolist()

    top_sets = {sample: set(genera) for sample, genera in top_lists.items()}
    return top_sets, top_lists


def select_genera(top_sets, top_lists, max_genera, min_overlap):
    """Choose the genera to display, preferring those shared by all samples.

    Args:
        top_sets: sample -> set of its top genera.
        top_lists: sample -> list of its top genera, best first.
        max_genera: Upper bound on the number of genera picked in the overlap
            and round-robin cases.
        min_overlap: Minimum number of genera common to all samples for the
            overlap alone to be used.

    Returns:
        list: Selected genera in a deterministic order (rank of first
        appearance across the samples' lists).

    Strategy:
        1. If at least ``min_overlap`` genera are common to all samples,
           return up to ``max_genera`` of them.
        2. Otherwise, if the lists together hold more than ``max_genera``
           entries, start from the common genera and let the samples take
           turns contributing their next best genus until ``max_genera`` is
           reached or every list is exhausted.
        3. Otherwise return every genus that occurs in any list.

    The previous implementation never left the round-robin loop when all
    lists were exhausted before ``max_genera`` was reached, and its result
    depended on set iteration order.
    """
    if not top_sets:
        return []

    overlap = set.intersection(*top_sets.values())

    # Every genus once, in order of first appearance (sample by sample, best first).
    ranked = list(dict.fromkeys(g for genera in top_lists.values() for g in genera))
    ranked_set = set(ranked)
    ordered_overlap = [g for g in ranked if g in overlap]
    # Overlap members that are missing from the lists (inconsistent input) go last.
    ordered_overlap += sorted(overlap - ranked_set, key=str)

    if len(overlap) >= min_overlap:
        return ordered_overlap[:max_genera]

    total_genus = sum(len(genera) for genera in top_lists.values())
    if total_genus <= max_genera:
        return ranked

    selected = list(ordered_overlap)
    chosen = set(selected)
    exhausted = object()  # sentinel: the sample has no unseen genus left
    iterators = [iter(genera) for genera in top_lists.values()]

    while len(selected) < max_genera:
        added_this_round = False
        for ranked_iter in iterators:
            genus = next((g for g in ranked_iter if g not in chosen), exhausted)
            if genus is exhausted:
                continue
            chosen.add(genus)
            selected.append(genus)
            added_this_round = True
            if len(selected) >= max_genera:
                break
        if not added_this_round:
            # Every list is used up: stop instead of spinning forever.
            break
    return selected


def add_amr_family_subplot(
    fig, df, amr_family, col_idx, max_genera, min_overlap, top_per_sample
):
    """Add the bubble traces of one AMR Gene Family to column ``col_idx`` of ``fig``.

    Nothing is added when the family has no rows. Otherwise the genera to
    show are picked from the per-sample top lists and drawn as a scatter
    plot: sample on x, genus on y, bubble size by relative abundance and
    colour by total_count.
    """
    family_rows = df[df["AMR Gene Family"] == amr_family]
    if family_rows.empty:
        return

    top_sets, top_lists = get_top_genera_per_sample(family_rows, top_per_sample)
    shown_genera = select_genera(top_sets, top_lists, max_genera, min_overlap)
    plot_rows = family_rows[family_rows["genus"].isin(shown_genera)]

    # The sample is already the x value, so it is hidden from the hover box.
    hover_columns = {
        "genus_count": True,
        "relative_genus_count": True,
        "total_count": True,
        "sample": False,
    }
    bubbles = px.scatter(
        plot_rows,
        x="sample",
        y="genus",
        size="relative_genus_count",
        color="total_count",
        hover_name="genus",
        hover_data=hover_columns,
        size_max=20,
        color_continuous_scale="Greens",
    )

    # Move the traces of the stand-alone scatter figure into the shared subplot grid.
    for trace in bubbles.data:
        fig.add_trace(trace, row=1, col=col_idx)


def create_bubble_plot_grid(df, max_genera, min_overlap, top_per_sample):
    """Build a one-row figure with one bubble subplot per AMR Gene Family.

    Returns:
        The plotly figure containing all subplots.
    """
    families = df["AMR Gene Family"].unique()
    # At least one column, so the grid is valid even when the table is empty.
    num_cols = max(len(families), 1)

    fig = make_subplots(
        rows=1,
        cols=num_cols,
        subplot_titles=list(families),
        horizontal_spacing=0.2,
    )

    for column, family in enumerate(families, start=1):
        add_amr_family_subplot(
            fig, df, family, column, max_genera, min_overlap, top_per_sample
        )

    fig.update_layout(
        title="Bubble Plots of Top Genera for Each AMR Gene Family",
        plot_bgcolor="lightgrey",
        height=900,
        width=500 * num_cols,
        coloraxis_colorbar=dict(title="Fusion Read Count"),
    )
    # Genera are listed in descending, samples in ascending category order.
    fig.update_yaxes(categoryorder="category descending")
    fig.update_xaxes(categoryorder="category ascending")

    return fig


def create_bubble_plots_combined(
    input_csv, output_html, max_genera=20, min_overlap=10, top_per_sample=20
):
    """Create the bubble plot grid from ``input_csv``, save it as HTML and show it.

    Args:
        input_csv: Abundance CSV.
        output_html: Destination of the HTML export.
        max_genera: Maximum number of genera selected per subplot.
        min_overlap: Number of shared genera required to plot the overlap only.
        top_per_sample: Number of top genera considered per sample.
    """
    abundance = load_filtered_data(input_csv)
    figure = create_bubble_plot_grid(abundance, max_genera, min_overlap, top_per_sample)
    pio.write_html(figure, file=output_html)
    pio.show(figure)

# Draw the bubble plots for the abundance table of step 4.
create_bubble_plots_combined(input_csv=abundance_result, output_html=bubble_plot)

In [None]:
# 7. Create boxplots
# Self-written Python script "percidt_per_genus.py"
# Input: all filtered_result.csv parts of one sample
# Output: boxplot over all samples per percentage identity, number of unique hits and genera

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os, pathlib

# === Paths ===
# Results live in <parent of the working directory>/.test_steps/results.
base = pathlib.Path().resolve().parent
result_dir = base / ".test_steps" / "results"

filter_result = result_dir / "filtered_result.csv"    # input of this step
boxplot = result_dir / "genus_idt_per_genus_plot.png"  # output of this step


def generate_percentage_idt_per_genus(input_files, output_file):
    """Plot 16S percentage identity and hit counts for the top 20 genera.

    The input tables are concatenated, the 20 genera with the most unique
    ``query_id`` values are selected, and a boxplot of ``perc_identity_16S``
    per genus is drawn. The unique query count per genus is overlaid as
    semi-transparent bars on a second y-axis. The figure is saved, shown
    and closed.

    Args:
        input_files: Path (``str`` or path-like) of one filtered result CSV,
            or an iterable of such paths. Each CSV needs the columns
            ``genus``, ``query_id`` and ``perc_identity_16S``.
        output_file: Path the figure is saved to.
    """
    # A single path must be wrapped explicitly: iterating a str would yield
    # characters, and str() of a list is not a usable file name.
    if isinstance(input_files, (str, os.PathLike)):
        input_paths = [input_files]
    else:
        input_paths = list(input_files)

    all_data = [
        pd.read_csv(str(path), sep=",", header=0) for path in input_paths
    ]

    # Combine all partitions; a fresh index avoids duplicate labels when
    # several files are concatenated.
    combined_data = pd.concat(all_data, ignore_index=True)

    # Number of distinct queries assigned to each genus
    genus_query_counts = (
        combined_data.groupby("genus")["query_id"].nunique().reset_index()
    )
    genus_query_counts.columns = ["genus", "unique_query_count"]

    # Keep only the top 20 genera, both for the boxes and for the bars
    top20_genera = genus_query_counts.nlargest(20, "unique_query_count")
    combined_data = combined_data[combined_data["genus"].isin(top20_genera["genus"])]

    # Shared x-axis order for both plots: most frequently hit genus first
    genus_order = top20_genera.sort_values(
        by="unique_query_count", ascending=False
    )["genus"].tolist()

    # Plotting
    fig, ax1 = plt.subplots(figsize=(15, 8))
    sns.boxplot(
        x="genus",
        y="perc_identity_16S",
        data=combined_data,
        ax=ax1,
        order=genus_order,
        fliersize=0.0,
        color="dodgerblue",
    )
    ax1.set_xlabel("Bacterial Genus")
    ax1.set_ylabel("Percentage Identity (boxplot)", color="royalblue")
    ax1.set_title(
        "Boxplot of Percentage Identity and Read Counts for Each Bacterial Genus"
    )
    # Rotate the labels without re-assigning tick label objects.
    ax1.tick_params(axis="x", labelrotation=90)

    # Second y-axis for the unique query counts
    ax2 = ax1.twinx()
    sns.barplot(
        x="genus",
        y="unique_query_count",
        data=top20_genera,
        ax=ax2,
        alpha=0.2,
        color="purple",
        order=genus_order,
    )
    ax2.set_ylabel("Number of hits (bar)", color="violet")

    plt.tight_layout()
    plt.savefig(output_file)
    plt.show()
    plt.close()

# Run the step on the filtered result table and save the figure to `boxplot`.
generate_percentage_idt_per_genus(filter_result, boxplot)


In [None]:
# 7. Create boxplots
# Self-written Python scripts "boxplot_[align_lengths,evalue,percidt].py"
# Input: all filtered_result.csv parts of one sample
# Output: boxplot over all samples for one parameter (alignment length, E-value or percentage identity), split by part (ABR/16S)

import pandas as pd
import seaborn as sns
import os, pathlib
import matplotlib.pyplot as plt

"""
This script takes a list of all filtered fasta files, combines e-value information 
across samples, and visualizes the distribution of e-values using boxplots split 
by part (ABR/16S) and sample.
"""

# Human-readable labels keyed by the parameter name used in the column names.
PRETTY_LABELS = {
    "align_length": "Alignment length",
    "perc_identity": "Percentage identity",
    "evalue": "E-value"
}

def read_and_process_partitioned_data(partition_files, sample, param):
    """Read one sample's partition CSVs and reshape them to long format.

    Every existing file is read and its ``<param>_ABR`` / ``<param>_16S``
    columns are melted into one value column named ``param`` plus a ``part``
    column holding ``"ABR"`` or ``"16S"``. Files that do not exist are
    skipped silently.

    Args:
        partition_files: Iterable of CSV paths belonging to the sample.
        sample: Sample name written into the ``sample`` column.
        param: Column prefix to extract, e.g. ``"perc_identity"``.

    Returns:
        A DataFrame with the columns ``query_id``, ``part``, ``param`` and
        ``sample`` and a fresh 0..n-1 index, or ``None`` when no file could
        be read.
    """
    prefix = f"{param}_"
    value_columns = [f"{prefix}ABR", f"{prefix}16S"]

    frames = []
    for part_file in partition_files:
        if not os.path.exists(part_file):
            continue
        df = pd.read_csv(part_file, header=0, sep=",")
        # NOTE(review): a "* 3" rescaling of the ABR column was left disabled
        # here in the original code - confirm whether it is needed for
        # align_length before comparing ABR and 16S values.
        long_df = pd.melt(
            df,
            id_vars=["query_id"],
            value_vars=value_columns,
            var_name="part",
            value_name=param,
        )

        # Normalize part labels: "<param>_ABR" -> "ABR". The prefix is a
        # literal string, not a regular expression.
        long_df["part"] = long_df["part"].str.replace(prefix, "", regex=False)
        long_df["sample"] = sample
        frames.append(long_df)

    if not frames:
        return None
    # Each melted frame starts at index 0 again; renumber so the combined
    # frame has unique row labels.
    return pd.concat(frames, ignore_index=True)


def plot_boxplots(data, output_file, param="perc_identity"):
    """Draw, save and show boxplots of one parameter per sample and part.

    Args:
        data (pd.DataFrame): Combined long-format frame containing the
            columns 'sample', 'part' and the value column named by ``param``.
        output_file (str): Path the plot is saved to; saving is skipped when
            this is empty or None.
        param (str): Name of the value column to plot. Defaults to
            "perc_identity", the column the plot previously hard-coded.
    """
    # Fall back to the raw column name for parameters without a pretty label.
    label = PRETTY_LABELS.get(param, param)

    plt.figure(figsize=(15, 10))
    flierprops = dict(markerfacecolor="0.75", markersize=2, linestyle="none")
    sns.boxplot(x="sample", y=param, hue="part", data=data, flierprops=flierprops)
    # NOTE(review): a log-scaled y-axis was left disabled in the original
    # code - confirm whether it should be enabled for param == "evalue".
    plt.title(
        f"Boxplot of {label.lower()} for ABR and 16S parts across samples -Filtered-"
    )
    plt.xlabel("Sample")
    plt.ylabel(label)
    plt.xticks(rotation=45)
    plt.tight_layout()
    if output_file:
        plt.savefig(output_file)
    plt.show()
    plt.close()


def main(filtered_fasta_files, sample_names, param, output_file):
    """Collect the long-format data of every sample and plot it.

    Args:
        filtered_fasta_files: Iterable of filtered result CSV paths.
        sample_names: Iterable of sample names.
        param: Column prefix to plot, e.g. "perc_identity".
        output_file: Path handed to ``plot_boxplots``.
    """
    # NOTE(review): every sample receives the full file list; no per-sample
    # file selection happens here - confirm before passing several samples.
    files = list(filtered_fasta_files)

    all_data = []
    for sample in sample_names:
        data = read_and_process_partitioned_data(files, sample, param)
        if data is not None:
            all_data.append(data)

    if not all_data:
        print("No data found.")
        return

    # Fresh index so rows of different samples do not share labels.
    combined_data = pd.concat(all_data, ignore_index=True)
    plot_boxplots(combined_data, output_file, param)


# Driver for the boxplot step: one sample, one filtered result table.
if __name__ == "__main__":
    # Parent directory of the current working directory.
    base = pathlib.Path(os.path.dirname(pathlib.Path().resolve()))
    result_dir = base / ".test_steps/results"

    filter_result = result_dir / "filtered_result.csv"
    boxplot = result_dir / f"combine_boxplot.png"
    
    # Despite the name, this is the path of the filtered result CSV.
    filtered_fasta_files = filter_result
    output_file = boxplot  # Single output file for all panels
    sample_names = "test_epic_data"
    param = "perc_identity"
    # main() expects lists, so the single file and sample are wrapped here.
    main([str(filtered_fasta_files)], [sample_names], param, output_file)


In [None]:
# 8. Create attrition plot
# Self-written Python script "plot_attrition.py"
# Input: overview table
# Output: plot of the read and hit counts throughout the ERMA process, with a breakdown of the rejection reasons

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os, pathlib

# === Paths ===
# Parent directory of the current working directory.
base = pathlib.Path(os.path.dirname(pathlib.Path().resolve()))
sample = "test_epic_data"
result_dir = base / ".test_steps/results"
overview_table = result_dir / "overview_table.txt"
overview_plot = result_dir / "overview_plot.png"

# === Category Definitions (now match the final labels directly) ===
# States drawn as the side-by-side main bars, in plotting order.
MAIN_CATEGORIES = [
    "Number of FastQ input reads",
    "Merged similarity hits",
    "Filtered fusion reads",
]

# Rejection reasons mapped to their bar colors; drawn as stacked bars.
FILTER_REASONS = {
    "Diamond hits < similarity threshold": "royalblue",
    "Diamond hits NOT highest percentage identity per query": "purple",
    "Usearch hits < similarity threshold": "#a6d854",
    "Usearch hits NOT highest percentage identity per query": "#66c2a5",
    "Query hit in only one of two databases": "#ffd92f",
}

# Bar color for each main category.
MAIN_COLOR_MAP = {
    "Number of FastQ input reads": "seagreen",
    "Merged similarity hits": "#fc8d62",
    "Filtered fusion reads": "#8da0cb",
}

# === Load and summarize the table ===
def load_and_summarize_data(path):
    """Load the overview table and pivot it into two sample-by-state tables.

    The file is read without a header row as the columns ``sample``,
    ``state`` and ``count``; counts are cast to int and made non-negative.

    Args:
        path: Path of the comma-separated overview table.

    Returns:
        Tuple ``(main_df, filter_df)``: counts for the states listed in
        ``MAIN_CATEGORIES`` and in ``FILTER_REASONS`` respectively, indexed
        by sample with one column per state and missing cells filled with 0.
    """
    table = pd.read_csv(path, names=["sample", "state", "count"])
    table["count"] = table["count"].astype(int).abs()

    def _pivot_states(states):
        # Keep only the requested states, then spread them into columns.
        subset = table[table["state"].isin(states)]
        wide = subset.pivot(index="sample", columns="state", values="count")
        return wide.fillna(0)

    return _pivot_states(MAIN_CATEGORIES), _pivot_states(FILTER_REASONS)

# === Plotting function ===
def plot_summary(main_df, filter_df, output_path):
    """Plot the main count bars per sample plus the stacked rejection reasons.

    Args:
        main_df: Sample-indexed table with one column per main category.
        filter_df: Sample-indexed table with one column per filter reason.
        output_path: Path the figure is saved to.
    """
    samples = main_df.index
    x = np.arange(len(samples))
    bar_width = 0.18
    overlay_width = 0.1

    fig, ax = plt.subplots(figsize=(12, 7))

    # Main bars, spread evenly around each sample position
    offsets = np.linspace(-bar_width, bar_width, len(MAIN_CATEGORIES))
    for offset, col in zip(offsets, MAIN_CATEGORIES):
        if col not in main_df.columns:
            continue
        ax.bar(
            x + offset,
            main_df[col],
            bar_width,
            label=col,
            color=MAIN_COLOR_MAP.get(col, "gray"),
        )

    # Filter reasons are stacked *on top* of "Filtered fusion reads". The
    # running base is kept as float: pivoted counts become float as soon as
    # one cell had to be filled, and adding those in place to an integer
    # array raises a numpy casting error.
    if "Filtered fusion reads" in main_df.columns:
        bottom = np.array(main_df["Filtered fusion reads"], dtype=float)
    else:
        bottom = np.zeros(len(samples), dtype=float)

    for reason, color in FILTER_REASONS.items():
        if reason in filter_df.columns:
            # Align to the sample order of the main bars; samples without
            # any rejected hits contribute a height of 0.
            heights = np.asarray(
                filter_df[reason].reindex(samples).fillna(0), dtype=float
            )
        else:
            heights = np.zeros(len(samples), dtype=float)
        ax.bar(
            x + bar_width,
            heights,
            overlay_width,
            bottom=bottom,
            label=reason,
            color=color,
        )
        # New array each round, so bars already drawn never share a buffer.
        bottom = bottom + heights

    # Axis formatting
    ax.set_xticks(x)
    ax.set_xticklabels(samples, rotation=45)
    ax.set_ylabel("Similarity search hit count")
    ax.set_xlabel("Sample")
    ax.set_title("Similarity Search Processing with Rejection Breakdown")

    # Split legend into main vs. filter. Handles and labels are filtered
    # together so an absent category cannot shift the labels.
    handles, labels = ax.get_legend_handles_labels()
    main_labels = [name for name in MAIN_CATEGORIES if name in labels]
    filter_labels = [name for name in FILTER_REASONS if name in labels]

    legend1 = ax.legend(
        [handles[labels.index(name)] for name in main_labels],
        main_labels,
        loc="upper left",
        bbox_to_anchor=(1.02, 1),
        title="Hit Process",
    )
    ax.legend(
        [handles[labels.index(name)] for name in filter_labels],
        filter_labels,
        loc="upper left",
        bbox_to_anchor=(1.02, 0.55),
        title="Filtering Reasons",
    )
    # The second legend() call replaced the first one; add it back.
    ax.add_artist(legend1)

    plt.tight_layout()
    plt.savefig(output_path)
    plt.show()
    plt.close(fig)

# === Execute ===
# Summarize the overview table, then draw and save the attrition plot.
main_df, filter_df = load_and_summarize_data(overview_table)
plot_summary(main_df, filter_df, overview_plot)


In [None]:
# 9. Create abundance data
# Self-written Python script "single_genera_abundance_table.py"
# Input: filtered_result.csv of one sample
# Output: table of genus counts per AMR gene family (absolute and relative) for that sample
# Note: The overview table is created here after the process, while in the original Snakemake run
#       it is created iteratively within the workflow.

import pandas as pd
import os, sys

def write_dummy_line(sample_name):
    """Return a one-row placeholder abundance table for ``sample_name``.

    The row carries "NA" for the AMR gene family and genus and 0 for all
    three count columns.
    """
    return pd.DataFrame(
        {
            "sample": [sample_name],
            "AMR Gene Family": ["NA"],
            "genus": ["NA"],
            "genus_count": [0],
            "total_count": [0],
            "relative_genus_count": [0],
        }
    )

def process_combined_data(combined_data, sample_name):
    """Count rows per genus within each AMR gene family of one sample.

    Args:
        combined_data: DataFrame with at least the columns
            "AMR Gene Family" and "genus". It is not modified.
        sample_name: Value written into the "sample" column of the result.

    Returns:
        DataFrame with one row per (sample, AMR gene family, genus) and the
        columns ``genus_count`` (rows of that genus), ``total_count`` (rows
        of the whole family) and ``relative_genus_count`` (their ratio,
        rounded to 4 decimals).
    """
    family_keys = ["sample", "AMR Gene Family"]

    # assign() returns a new frame, so the caller's table is left untouched.
    labelled = combined_data.assign(sample=sample_name)

    genus_counts = (
        labelled.groupby(family_keys + ["genus"])
        .size()
        .reset_index(name="genus_count")
    )

    # Family total broadcast back onto every genus row of that family.
    genus_counts["total_count"] = genus_counts.groupby(family_keys)[
        "genus_count"
    ].transform("sum")
    genus_counts["relative_genus_count"] = (
        genus_counts["genus_count"] / genus_counts["total_count"]
    ).round(4)

    return genus_counts


def export_genera_abundance(input_files, sample_name, parts, output_path):
    """Build, save and display the genus abundance table of one sample.

    Args:
        input_files: Iterable of filtered result CSV paths of the sample.
        sample_name: Sample name written into the table.
        parts: Iterable of partition identifiers; one table is read per
            entry.
        output_path: CSV path the resulting table is written to. When no
            table could be read, a single placeholder row is written instead.
    """
    sample_input_files = list(input_files)

    part_dfs = []
    for _part in parts:
        # NOTE(review): the part identifier does not select a file here; the
        # first input file is read once per part, so several parts would
        # count the same rows repeatedly - confirm against the workflow
        # script.
        if not sample_input_files:
            continue
        part_dfs.append(pd.read_csv(sample_input_files[0], sep=",", header=0))

    if not part_dfs:
        print(f"No valid parts found for sample: {sample_name}")
        write_dummy_line(sample_name).to_csv(output_path, index=False)
        return

    full_sample_df = pd.concat(part_dfs, ignore_index=True)
    processed_data = process_combined_data(full_sample_df, sample_name)

    # Most abundant genera first
    processed_data = processed_data.sort_values(
        by=["sample", "genus_count"], ascending=False
    )

    # Persist the table as well; previously only the placeholder row was
    # ever written to output_path.
    processed_data.to_csv(output_path, index=False)
    display(processed_data)


# Driver for the abundance step: one sample, one filtered result table.
# NOTE(review): `pathlib` is not imported in this cell; it relies on an
# import made by an earlier cell.
if __name__ == "__main__":
    # Parent directory of the current working directory.
    base = pathlib.Path(os.path.dirname(pathlib.Path().resolve()))
    result_dir = base / ".test_steps/results"

    filter_result = result_dir / "filtered_result.csv"
    table = result_dir / f"single_abundance_table.csv"
    
    # Not used by the call below.
    filtered_fasta_files = filter_result
    
    input_file = filter_result
    output_path = table
    sample_name = "test_epic_data"    
    parts = ["001"]
    # The single input file is wrapped in a list, as the function expects.
    export_genera_abundance([str(input_file)], sample_name, parts, output_path)


In [None]:
import pandas as pd
import pathlib
from IPython.core.display import HTML

# === Paths ===
base = pathlib.Path().resolve()
result_dir = base / "results"
overview_table = result_dir / "overview_table.txt"
overview_html = "overview_table.html"

# Read the input table (no header row: sample, step, count)
df = pd.read_csv(
    overview_table,
    sep=",",
    header=None,
    names=["sample", "step", "total_count"],
)

# Mapping step -> State
step_to_state = {
    "Number of FastQ input reads": "Input reads",
    "Diamond output hits": "Similarity search",
    "Usearch output hits": "Similarity search",
    "Merged similarity hits": "Similarity search",
    "Diamond hits < similarity threshold": "Filtration",
    "Diamond hits NOT highest percentage identity per query": "Filtration",
    "Usearch hits < similarity threshold": "Filtration",
    "Usearch hits NOT highest percentage identity per query": "Filtration",
    "Query hit in only one of two databases": "Filtration",
    "Filtered fusion reads": "Output reads"
}

df["state"] = df["step"].map(step_to_state)

# Reorder and sort; the categorical makes states sort in workflow order
df = df[["sample", "state", "step", "total_count"]]
state_order = ["Input reads", "Similarity search", "Filtration", "Output reads"]
df["state"] = pd.Categorical(df["state"], categories=state_order, ordered=True)
df = df.sort_values(by=["sample", "state"])

# === HTML with rowspan for merged cells ===

html = """
<html>
<head>
<style>
    table.styled-table {
        border-collapse: collapse;
        margin: 25px 0;
        font-size: 0.95em;
        font-family: sans-serif;
        min-width: 600px;
        box-shadow: 0 0 10px rgba(0, 0, 0, 0.15);
    }
    table.styled-table thead tr {
        background-color: #009879;
        color: #ffffff;
        text-align: left;
    }
    table.styled-table th,
    table.styled-table td {
        padding: 10px 12px;
        border: 1px solid #ddd;
    }
    table.styled-table tbody tr:nth-child(even) {
        background-color: #f3f3f3;
    }
</style>
</head>
<body>
<table class="styled-table">
<thead>
    <tr><th>Sample</th><th>State</th><th>Step</th><th>Count</th></tr>
</thead>
<tbody>
"""

# Group and track rowspans.
# Rows whose step has no mapped state are dropped by the groupby, so the
# sample rowspan counts only the rows that are actually emitted.
rows_per_sample = df.dropna(subset=["state"]).groupby("sample").size()
samples_started = set()
row_fragments = []

# observed=True skips (sample, state) combinations without any rows.
grouped = df.groupby(["sample", "state"], observed=True)
for (sample, state), group in grouped:
    state_rowspan = len(group)

    for position, (_, row) in enumerate(group.iterrows()):
        cells = []
        # The merged sample cell belongs to the first row emitted for it.
        if sample not in samples_started:
            samples_started.add(sample)
            cells.append(
                f'<td rowspan="{rows_per_sample.loc[sample]}">{sample}</td>'
            )
        # The merged state cell belongs to the first row of its group.
        if position == 0:
            cells.append(f'<td rowspan="{state_rowspan}">{state}</td>')
        cells.append(f"<td>{row['step']}</td><td>{row['total_count']}</td>")
        row_fragments.append("<tr>" + "".join(cells) + "</tr>")

html += "".join(row_fragments)

html += """
</tbody>
</table>
</body>
</html>
"""
display(HTML(html))
# Write to file
with open(overview_html, "w", encoding="utf-8") as f:
    f.write(html)


In [None]:
import pandas as pd
import pathlib
from IPython.core.display import HTML

# === Paths ===
base = pathlib.Path().resolve()
result_dir = base / "results"
overview_table = result_dir / "genera_abundance.csv"
# Target HTML file; leave empty to only display the table.
overview_html = ""

# Read the input table
df = pd.read_csv(overview_table, sep=",", header=0)

# === HTML with rowspan for merged cells ===

html = """
<html>
<head>
<style>
    table.styled-table {
        border-collapse: collapse;
        margin: 25px 0;
        font-size: 0.95em;
        font-family: sans-serif;
        min-width: 600px;
        box-shadow: 0 0 10px rgba(0, 0, 0, 0.15);
    }
    table.styled-table thead tr {
        background-color: #009879;
        color: #ffffff;
        text-align: left;
    }
    table.styled-table th,
    table.styled-table td {
        padding: 10px 12px;
        border: 1px solid #ddd;
    }
    table.styled-table tbody tr:nth-child(even) {
        background-color: #f3f3f3;
    }
    table.styled-table tbody tr:hover {
        background-color: #f1f1f1;
    }
</style>
</head>
<body>
<table class="styled-table">
<thead>
    <tr><th>Sample</th><th>AMR Gene Family</th><th>Genus</th><th>Fusion Read Count</th><th>Relative</th></tr>
</thead>
<tbody>
"""

# Group and track rowspans.
# Rows without an AMR gene family are dropped by the groupby, so the sample
# rowspan counts only the rows that are actually emitted.
rows_per_sample = df.dropna(subset=["AMR Gene Family"]).groupby("sample").size()
samples_started = set()
row_fragments = []

grouped = df.groupby(["sample", "AMR Gene Family"])
for (sample, family), group in grouped:
    family_rowspan = len(group)
    # The group already holds exactly the rows of this sample and family.
    reads_per_amr = group["genus_count"].sum()
    amr_line = f"{family}<br><span style='font-size: 0.85em'> Total Fusion Reads: {reads_per_amr}</span>"

    for position, (_, row) in enumerate(group.iterrows()):
        cells = []
        # The merged sample cell must sit in the first row emitted for the
        # sample. The groupby orders families alphabetically, so that row is
        # not necessarily the sample's first row in the CSV.
        if sample not in samples_started:
            samples_started.add(sample)
            cells.append(
                f'<td rowspan="{rows_per_sample.loc[sample]}">{sample}</td>'
            )
        if position == 0:
            cells.append(f'<td rowspan="{family_rowspan}">{amr_line}</td>')
        cells.append(
            f"<td>{row['genus']}</td><td>{row['genus_count']}</td><td>{row['relative_genus_count']}</td>"
        )
        row_fragments.append("<tr>" + "".join(cells) + "</tr>")

html += "".join(row_fragments)

html += """
</tbody>
</table>
</body>
</html>
"""
display(HTML(html))
# Write to file only when a target path has been configured above
if overview_html:
    with open(overview_html, "w", encoding="utf-8") as f:
        f.write(html)


In [None]:
import os
import json
import shutil

# Root log directory (adjust if needed)
log_dir = "/local/work/adrian/ERMA/logs"   # replace with your actual path
# Derived from log_dir so that adjusting the path above is sufficient.
output = os.path.join(log_dir, "logs.json")

# rule name -> {sample name -> log text}
logs = {}

for rule in sorted(os.listdir(log_dir)):
    rule_path = os.path.join(log_dir, rule)
    if not os.path.isdir(rule_path):
        continue

    rule_logs = {}
    # Walk the rule folder recursively: a nested directory must neither be
    # opened like a file nor lose its logs when the folder is removed below.
    for root, subdirs, files in os.walk(rule_path):
        subdirs.sort()  # deterministic traversal order
        for log_file in sorted(files):
            file_path = os.path.join(root, log_file)

            # Skip empty files
            if os.path.getsize(file_path) == 0:
                continue

            # Use the path relative to the rule folder, without extension,
            # as sample name; for files directly in the rule folder this is
            # the plain file name without extension.
            relative_name = os.path.relpath(file_path, rule_path)
            sample = os.path.splitext(relative_name)[0]
            if sample in rule_logs:
                # Two logs differing only in their extension: keep both by
                # falling back to the full relative file name.
                sample = relative_name

            # Read log text
            with open(file_path, "r", encoding="utf-8", errors="replace") as f:
                text = f.read().strip()

            rule_logs[sample] = text

    if rule_logs:  # only keep non-empty rules
        logs[rule] = rule_logs

# Write JSON
with open(output, "w", encoding="utf-8") as out:
    json.dump(logs, out, indent=2, ensure_ascii=False)

print("logs.json written with", sum(len(v) for v in logs.values()), "log entries.")

# Remove all subfolders in the log directory (runs only after the JSON file
# has been written successfully)
for rule in os.listdir(log_dir):
    rule_path = os.path.join(log_dir, rule)
    if os.path.isdir(rule_path):
        shutil.rmtree(rule_path)

print("All subfolders removed, only logs.json remains.")
