In [1]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import glob
import hashlib
import gzip
import numpy as np
import re

In [2]:
def get_md5sum(x):
    """Return a short, stable identifier for the string ``x``.

    The string is UTF-8 encoded and hashed with MD5; only the first 10 hex
    digits of the digest are kept.
    """
    full_digest = hashlib.md5(x.encode("utf-8")).hexdigest()
    return full_digest[:10]

In [3]:
# Glob patterns for the gzipped QUAST report tables of each dataset
# (the "Horse" dataset is the one labelled "Bone" in the figures).
bone_path, gut_path, calc_path = (
    f"/data/ancient{dataset}/results/assembly-evaluation-quast/*report.tsv.gz"
    for dataset in ("Horse", "Gut", "Calc")
)
# Keep `dataset` bound to the last dataset configured, in case later cells read it.
dataset = "Calc"

In [None]:
# Raw sample labels. The QUAST tables store each label as its 10-character
# MD5 prefix (get_md5sum); labels_dict_inv maps those hashes back to names.
labels = [
    "gut_sum_high_c3", "gut_sum_high_c5", "gut_sum_high_c10",
    "calc_2095_high_c3", "calc_2095_high_c5", "calc_2095_high_c10",
    "horse_sum_high_c3", "horse_sum_high_c5", "horse_sum_high_c10",
]

# Display names, position-matched to `labels`; the text before ":" is the dataset.
labels_clean = [
    "Gut:\nHigh Damage; Cov. 3X",
    "Gut:\nHigh Damage; Cov. 5X",
    "Gut:\nHigh Damage; Cov. 10X",
    "Calculus:\nHigh Damage; Cov. 3X",
    "Calculus:\nHigh Damage; Cov. 5X",
    "Calculus:\nHigh Damage; Cov. 10X",
    "Bone:\nHigh Damage; Cov. 3X",
    "Bone:\nHigh Damage; Cov. 5X",
    "Bone:\nHigh Damage; Cov. 10X",
]

# Every raw label needs exactly one display name.
assert len(labels) == len(labels_clean), "labels and labels_clean must have the same length"

labels_dict = {key: get_md5sum(key) for key in labels}
labels_dict_inv = {value: key for key, value in labels_dict.items()}
# Inverting would silently drop a label if two truncated hashes collided.
assert len(labels_dict_inv) == len(labels_dict), "truncated MD5 collision between labels"
print(labels_dict_inv)

labels_dict_clean = dict(zip(labels, labels_clean))

In [5]:
def map_assembler(cell):
    """Map a raw assembler identifier to its display name.

    Rules are tried in order and the first substring found in ``cell`` wins.
    A value matching no rule is returned unchanged.
    """
    rules = (
        ("carpedeam", "CarpeDeam"),
        ("penguin", "PenguiN"),
        ("megahit", "MEGAHIT"),
        ("spades", "metaSPAdes"),
    )
    return next((pretty for needle, pretty in rules if needle in cell), cell)

In [6]:
def adjust_assemblerconfig(row):
    """Return a row's assembler display name, splitting CarpeDeam by run mode.

    Parameters
    ----------
    row : mapping-like (e.g. a DataFrame row passed by ``apply(axis=1)``)
        Must provide ``"assembler_clean"`` (display name of the assembler) and
        ``"assemblerconfig"`` (``"<assembler> <config>"``, e.g.
        ``"carpedeam2 configSafe"``).

    Returns
    -------
    str
        For CarpeDeam: "CarpeDeam" plus a line break and "(unsafe mode)" or
        "(safe mode)" when the configuration names a mode, plain "CarpeDeam"
        otherwise. For every other assembler: ``row["assembler_clean"]`` unchanged.
    """
    assembler = row["assembler_clean"]
    if assembler != "CarpeDeam":
        return assembler

    config = row["assemblerconfig"]
    # A missing config turns the concatenated string into NaN; treat it as "no mode".
    if not isinstance(config, str):
        return "CarpeDeam"

    # The mode may be spelled "carpedeamSafe"/"carpedeamUnsafe" or, as in the
    # main configuration list, "configSafe"/"configUnsafe". Checking only for
    # "carpedeamSafe" never matched "carpedeam2 configSafe". The match is
    # case-sensitive, so "Unsafe" cannot match "Safe"; test Unsafe first anyway.
    if "Unsafe" in config:
        return "CarpeDeam\n(unsafe mode)"
    if "Safe" in config:
        return "CarpeDeam\n(safe mode)"
    return "CarpeDeam"

In [7]:
def curate_df(path_tsv):
    """Load every gzipped QUAST report matching a glob and add derived columns.

    Parameters
    ----------
    path_tsv : str
        Glob pattern for gzip-compressed, tab-separated report files.

    Returns
    -------
    pandas.DataFrame
        All matching reports concatenated with a fresh index, plus:
        ``assemblerconfig`` ("<assembler> <config>"), ``assembler_clean``,
        ``assembler_final``, ``label_human``, ``label_clean``,
        ``dataset_clean`` (text of ``label_clean`` before ":"),
        ``mis_per_contig`` and ``mis_per_aln_base``.
        Hash labels not found in ``labels_dict_inv`` become NaN in
        ``label_human`` and stay NaN in ``label_clean``/``dataset_clean``.
        The ratio columns are inf/NaN where their denominator is 0.

    Raises
    ------
    FileNotFoundError
        If the glob matches no files. Without this check, ``pd.concat`` would
        fail with the less informative "No objects to concatenate".
    """
    # Sort so row order does not depend on the filesystem's listing order.
    files = sorted(glob.glob(path_tsv))
    if not files:
        raise FileNotFoundError(f"No QUAST reports match {path_tsv!r}")

    # `label` holds hex hash prefixes: read it as text so values such as
    # "0123456789" (leading zero) or "12e4567890" (exponent-like) are not
    # parsed as numbers, which would break the hash lookup below.
    dfs = [
        pd.read_csv(file, compression='gzip', sep='\t', dtype={"label": str})
        for file in files
    ]
    big_df = pd.concat(dfs, ignore_index=True)

    big_df["assemblerconfig"] = big_df["assembler"] + " " + big_df["config"]
    big_df["assembler_clean"] = big_df["assembler"].apply(map_assembler)
    big_df["assembler_final"] = big_df.apply(adjust_assemblerconfig, axis=1)
    big_df["label"] = big_df["label"].astype(str)
    big_df["label_human"] = big_df["label"].map(labels_dict_inv)
    big_df["label_clean"] = big_df["label_human"].replace(labels_dict_clean)
    big_df["dataset_clean"] = big_df["label_clean"].str.split(":").str[0]
    big_df["mis_per_contig"] = big_df["num_misassemblies"] / big_df["num_contigs_ge_0_bp"]
    big_df["mis_per_aln_base"] = big_df["num_misassemblies"] / big_df["total_aligned_length"]
    return big_df

In [None]:
# Load each dataset's QUAST reports (the "Horse" reports are shown as "Bone").
bone = curate_df(bone_path)
calc = curate_df(calc_path)
gut = curate_df(gut_path)
# Preview only the first rows instead of dumping the whole frame.
print(gut.head())
dfs = [bone, calc, gut]
big_df = pd.concat(dfs, ignore_index=True)

In [9]:
# Assembler configurations ("<assembler> <config>") shown in the main figure.
main = [
    'carpedeam2 configSafe',
    'carpedeam2 configUnsafe',
    'megahit config0',
    'penguin config0',
    'spades config0',
]
main_df = big_df.loc[big_df["assemblerconfig"].isin(main)]

In [None]:
# Preview seaborn's "pastel" palette: print its hex codes and render the swatches.
# The hard-coded colours in the plotting cell presumably come from this list.
pal = sns.color_palette(palette="pastel")
print(pal.as_hex())
pal

In [13]:
def _assembler_sort_key(name, custom_order=("C", "P", "M", "m")):
    """Sort key ranking an assembler label by its first character.

    Known initials rank in ``custom_order`` order (CarpeDeam, PenguiN, MEGAHIT,
    metaSPAdes). Unknown initials rank last instead of raising ``ValueError``.
    """
    initial = str(name)[:1]
    return custom_order.index(initial) if initial in custom_order else len(custom_order)


def plot_metrics_half(df, metrics, title, rows, start_row, num):
    """Draw a grid of assembler-comparison bar plots and save it as SVG.

    The grid has one column per entry in ``metrics`` and one row per coverage
    level in ``suffix_order[start_row:start_row + rows]``. Each panel plots the
    metric (y) per ``dataset_clean`` (x), coloured by ``assembler_final``.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs ``label_clean``, ``dataset_clean``, ``assembler_final`` and every
        column named in ``metrics``.
    metrics : list of str
        Columns to plot; each must be a key of ``metric_dict`` below.
    title : str
        Used only in the output file name.
    rows : int
        Number of coverage rows to draw.
    start_row : int
        Index into ``suffix_order`` of the first row drawn.
    num : str
        Suffix for the output file name.

    Side effects
    ------------
    Writes ``plots/figure3/benchmark_multipanel_{title}_{num}.svg`` (creating
    the directory if needed) and calls ``plt.show()``.
    """
    import os  # only this cell needs it

    fig_width = 11
    fig_height = 6
    # squeeze=False keeps `axs` 2-D, so axs[row, col] works even for one row or one metric.
    fig, axs = plt.subplots(rows, len(metrics), figsize=(fig_width, fig_height), squeeze=False)

    suffix_order = [
        'High Damage; Cov. 3X',
        'High Damage; Cov. 5X',
        'High Damage; Cov. 10X',
    ]

    # NOTE(review): panels for the "High Damage" samples are titled "Moderate
    # Damage"; presumably a deliberate relabelling for the figure. Confirm.
    damage_rename = {
        'High Damage; Cov. 3X' : 'Moderate Damage\nCov. 3X',
        'High Damage; Cov. 5X' : 'Moderate Damage\nCov. 5X',
        'High Damage; Cov. 10X' : 'Moderate Damage\nCov. 10X',
    }

    custom_palette = ['#a1c9f4', '#b9f2f0', '#8de5a1', '#ffb482', '#fab0e4']

    metric_dict = {
        "largest_alignment": "Largest\nAlignment",
        "genome_fraction_perc": "Genome\nFraction (%)",
        "mis_per_contig": "# Misassemblies\nper Contig",
        "mis_per_aln_base": "Misassemblies\nper aligned bp",
        "na50": "NA50",
    }

    # One hue order for the whole figure keeps each assembler on the same
    # colour in every panel (and in the shared legend), even if a panel lacks one.
    hue_order = sorted(df['assembler_final'].unique(), key=_assembler_sort_key)
    # One colour per hue level; extra colours only trigger a seaborn warning.
    palette = custom_palette[:len(hue_order)]

    legend_ax = None  # last panel actually drawn; source of the shared legend
    for i, metric in enumerate(metrics):
        metric_clean = metric_dict[metric]
        for j, suffix in enumerate(suffix_order[start_row:start_row + rows]):
            ax = axs[j, i]
            # Literal (non-regex) match. Rows with a missing label_clean are
            # dropped instead of putting NaN into the boolean mask.
            subset = df[df['label_clean'].str.contains(suffix, regex=False, na=False)]
            if subset.empty:
                continue

            sns.barplot(data=subset, x='dataset_clean', y=metric, hue='assembler_final',
                        hue_order=hue_order, palette=palette, ax=ax)

            if metric == "mis_per_aln_base":
                max_value = subset[metric].max()
                # np.arange cannot span to NaN/inf (e.g. a zero aligned length).
                if np.isfinite(max_value) and max_value > 0:
                    tick_values = np.arange(0, max_value, 2e-5)
                    # round(), not int(): x / 1e-5 can land a hair below the
                    # intended integer, which int() would truncate.
                    tick_labels = [f'{round(x / 1e-5)}e-5' for x in tick_values]
                    ax.set_yticks(tick_values)
                    ax.set_yticklabels(tick_labels)

            ax.set_title(damage_rename[suffix], fontsize=10)
            ax.set_xlabel('')
            ax.set_ylabel(metric_clean, fontsize=10)
            ax.tick_params(axis='x', rotation=0)
            # Hide the per-panel legend; one shared legend is added below.
            legend = ax.get_legend()
            if legend is not None:
                legend.set_visible(False)
            legend_ax = ax

    if legend_ax is not None:
        # Shared legend hanging from the bottom centre of the figure.
        handles, labels = legend_ax.get_legend_handles_labels()
        fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 0.02),
                   fancybox=True, shadow=True, ncol=5)

    fig.tight_layout()
    out_path = f'plots/figure3/benchmark_multipanel_{title}_{num}.svg'
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path, format="svg", bbox_inches="tight")
    plt.show()

In [None]:
# Draw all three coverage rows (3X, 5X, 10X) for the main assembler configurations.
metrics = ['na50', 'largest_alignment', 'genome_fraction_perc', 'mis_per_contig']
plot_metrics_half(main_df, metrics, title="main", rows=3, start_row=0, num='1')