# Sequence Width Comparison

In [None]:
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure

In [None]:
def _summarise_metric_by_width(data_folder: str, dataset_size: str, metric_prefix: str,
                               width_list: list, tools: list) -> tuple:
    """
    Load one summary table per width and compute, for every tool, the mean and
    the standard deviation of the selected metric.

    data_folder: folder containing the "summary_table_{dataset_size}_width_{width}.tsv" files
    dataset_size: value used to build the file names
    metric_prefix: prefix of the metric columns ("AUROC", "AUPRC" or "F1")
    width_list: widths to load, one file per width
    tools: tool names expected as the second "_"-separated token of the column names

    Returns a pair (means, stds) of dictionaries shaped {width: {tool: value}},
    with values rounded to 3 decimals.

    Raises FileNotFoundError (from pandas) if a table is missing and KeyError if
    a table has no column of the requested metric for one of the tools.
    """
    means_by_width = {}
    stds_by_width = {}
    for width in width_list:
        file_path = f"{data_folder}/summary_table_{dataset_size}_width_{width}.tsv"
        table = pd.read_csv(file_path, sep="\t")

        # Pick the metric columns by name instead of relying on the position of
        # the "EXP" column. The tool is taken from the second "_"-separated token
        # (assumes names shaped like "<METRIC>_<tool>" - confirm against the tables).
        tool_to_column = {}
        for col in table.columns:
            name_parts = col.split("_")
            if col.startswith(metric_prefix) and len(name_parts) > 1:
                tool_to_column[name_parts[1]] = col

        missing_tools = [tool for tool in tools if tool not in tool_to_column]
        if missing_tools:
            raise KeyError(
                f"no '{metric_prefix}' column found for tool(s) {missing_tools} in {file_path}"
            )

        means_by_width[width] = {
            tool: np.round(table[tool_to_column[tool]].mean(), 3) for tool in tools
        }
        stds_by_width[width] = {
            tool: np.round(table[tool_to_column[tool]].std(), 3) for tool in tools
        }

    return means_by_width, stds_by_width


def width_barplot(data_type: str, dataset_size: str, input_folder: str, output_folder: str, metric: str) -> None:
    """
    Barplot showing the average performance across the experiments of meme, streme and ls-gkm
    when using different sequence widths.

    data_type: one of the following ["dnase-shuffle", "shuffle-dnase", "dnase", "shuffle"]
    dataset_size: either "optimal" or "full"
    input_folder: folder containing one sub-folder per data_type
    output_folder: root folder of the plots, structured based on the data_type
    metric: one of the following ["auroc", "auprc", "f1score"]

    The summary tables are read from
    "{input_folder}/{data_type}/sequence-width-comparison/summary_table_{dataset_size}_width_{width}.tsv"
    for the widths 50, 100, 150, 200 and "full". The barplot is saved as both .svg and .png in
    "{output_folder}/{data_type}/width_comparison_barplot" and then shown.

    Raises ValueError if data_type, metric or dataset_size is not one of the accepted values.
    """
    possible_data_types = ["dnase-shuffle", "shuffle-dnase", "dnase", "shuffle"]
    if data_type not in possible_data_types:
        raise ValueError(f"data_type must be one of the following: {possible_data_types}")

    possible_metrics = ["auroc", "auprc", "f1score"]
    if metric not in possible_metrics:
        raise ValueError(f"metric must be one of the following: {possible_metrics}")

    possible_dataset_sizes = ["full", "optimal"]
    if dataset_size not in possible_dataset_sizes:
        raise ValueError(f"dataset_size must be one of the following: {possible_dataset_sizes}")

    # Input and output locations (the output_folder argument is left untouched)
    data_folder = f"{input_folder}/{data_type}/sequence-width-comparison"
    plot_folder = f"{output_folder}/{data_type}/width_comparison_barplot"

    # Widths and tools that will be compared; "svm" is shown as "ls-gkm" in the legend
    width_list = [50, 100, 150, 200, "full"]
    tools = ["meme", "streme", "svm"]
    tool_labels = ["meme", "streme", "ls-gkm"]
    colors = ["#f7b32bff", "#d52a39ff", "#366fdaff"]  # meme, streme, ls-gkm

    # Load everything before touching the file system, so that a missing or
    # malformed table does not leave an empty output folder behind.
    metric_col_name = "F1" if metric == "f1score" else metric.upper()
    mean_by_width, std_by_width = _summarise_metric_by_width(
        data_folder, dataset_size, metric_col_name, width_list, tools
    )

    ylabel = "F1-Score" if metric == "f1score" else metric.upper()
    bar_width = 0.3
    fig, ax = plt.subplots(figsize=(20, 15), dpi=300)
    try:
        for index, width in enumerate(width_list):
            # One group of adjacent bars per width, starting at x = index
            positions = [index + bar_width * offset for offset in range(len(tools))]
            heights = [mean_by_width[width][tool] for tool in tools]

            # Only the first group feeds the legend. The keyword is omitted for the
            # other groups instead of passing label=None, since a non-string label
            # may be treated as a per-bar sequence by Matplotlib.
            legend_kwargs = {"label": tool_labels} if index == 0 else {}
            ax.bar(
                positions,
                heights,
                color=colors,
                width=bar_width,
                edgecolor="grey",
                zorder=3,
                **legend_kwargs,
            )

            for position, tool in zip(positions, tools):
                mean = mean_by_width[width][tool]
                std = std_by_width[width][tool]
                # The std is NaN when it cannot be computed (e.g. a single
                # experiment); it is displayed as zero, with the same 3 decimals
                # used for the mean.
                shown_std = 0.0 if pd.isna(std) else std
                ax.text(
                    position,
                    mean + 0.055,
                    f"std: \n{shown_std:.3f}",
                    horizontalalignment="center",
                    fontsize=15,
                )
                ax.text(
                    position,
                    mean + 0.01,
                    f"mean: \n{mean:.3f}",
                    horizontalalignment="center",
                    fontsize=15,
                )

        # Axis configuration: each tick sits at the centre of its group of bars
        group_centre = bar_width * (len(tools) - 1) / 2
        ax.set_xticks([index + group_centre for index in range(len(width_list))])
        ax.set_xticklabels([str(width) for width in width_list], fontsize=15)
        ax.tick_params(axis="y", labelsize=15)
        ax.legend(loc=2, prop={"size": 18})
        ax.grid(axis="y", zorder=0, alpha=0.3)
        ax.set_ylabel(ylabel, fontsize=30)
        ax.set_xlabel("Width", fontsize=30)
        ax.set_title(f"{dataset_size} {data_type} {ylabel} mean of tools", fontsize=30)
        ax.set_ylim([0, 1.1])

        # Create the output folder(s) only once there is something to save
        os.makedirs(plot_folder, exist_ok=True)
        fig.savefig(f"{plot_folder}/{ylabel}_width_{dataset_size}.svg")
        fig.savefig(f"{plot_folder}/{ylabel}_width_{dataset_size}.png")
        plt.show()
    finally:
        # Release the figure so that repeated calls do not accumulate open figures
        plt.close(fig)

In [None]:
# Placeholder invocation: fill in the arguments before running this cell.
# As written, the empty data_type is rejected by the validation in
# width_barplot, which raises ValueError.
width_barplot(
    data_type="",
    dataset_size="",
    input_folder="",
    output_folder="",
    metric="",
)