# Report generation example code

### 1. Show the metadata as a table

In [None]:
import pandas as pd
import plotly.graph_objects as go
import os

## ---------------------------------------------------
## Load the meta data, in csv, tsv, or excel format  -
## ---------------------------------------------------
meta_file = "meta.csv" # Replace with your file path
data = pd.read_csv(meta_file) # for tsv add sep="\t"; for excel use pd.read_excel(meta_file)

## ------------------------------------
## The following section is optional  -
## ------------------------------------
# sort data first by Serotype and then by Date of Collection
data = data.sort_values(by=['Serotype', 'Date of Collection']) # Example columns; You may sort by other columns
# rename column "Origin" to "Country of Origin" as example
data = data.rename(columns={"Origin": "Country of Origin"})

## ----------------------------------------------------
## Format the values for display                      -
## ----------------------------------------------------
# show two places after decimal; DataFrame.round only touches numeric columns.
# This must happen BEFORE NaN is replaced: filling a numeric column with ''
# turns it into an object column, which would no longer be rounded.
data = data.round(2)
# replace NaN with empty string so that missing values are shown as blank cells
data = data.fillna('')

## ----------------------------------------------------
## Plot the data in a table, details are configurable -
## ----------------------------------------------------
table = go.Table(
    header=dict(
        values=list(data.columns),
        align='center',
        font=dict(color='black', size=14),
    ),
    cells=dict(
        values=[data[col] for col in data.columns],  # one list of cell values per column
        fill_color='lavender',
        align='center',
        font=dict(color='black', size=14),
        height=25,
    ),
    # columnwidth=[80, 80, 80, 80, 80, 80], # you may adjust the column width depending on the data
)
fig = go.Figure(data=[table])

fig.update_layout(
    margin=dict(l=30, r=30, t=30, b=30),
    width=1500,  # You may change the width and height to better display the data
    height=1000,
)
fig.show()

### 2. Plot the whole phylogenetic tree from a Newick-format file

In [None]:
from Bio import Phylo
import pandas as pd
from plotly.offline import init_notebook_mode, iplot
import plotly.graph_objects as go

init_notebook_mode(connected=False)
import colorsys
import numpy as np
import os
import random



def get_max_depth(tree):
    """Return the largest root-to-tip distance of ``tree``.

    The distances are reduced directly instead of being stored in a dict keyed
    by tip name: tips that share a name (or have no name at all) would
    overwrite each other in such a dict and the deepest tip could be lost.

    Args:
        tree: object providing ``root``, ``get_terminals()`` and
            ``distance(start, target)`` (e.g. a tree read with ``Bio.Phylo``).

    Returns:
        The maximum of ``tree.distance(tree.root, tip)`` over all tips.
    """
    return max(tree.distance(tree.root, terminal) for terminal in tree.get_terminals())


def get_x_coordinates(tree):
    """Map every clade of ``tree`` to its x coordinate.

    The x coordinate is the value reported by ``tree.depths()``. When the
    largest such value is falsy (e.g. 0), the depths are recomputed with
    ``unit_branch_lengths=True`` so the tree still has a horizontal extent.

    Returns:
        dict {clade: x-coord}
    """
    depths = tree.depths()
    if max(depths.values()):
        return depths
    return tree.depths(unit_branch_lengths=True)


def get_y_coordinates(tree, dist=1.3):
    """Map every clade of ``tree`` to its y coordinate.

    Tips get one row each, ``dist`` apart; the last tip returned by
    ``tree.get_terminals()`` sits at ``tree.count_terminals()`` and every
    earlier tip is placed ``dist`` lower. An internal clade is placed halfway
    between its first and its last child.

    Returns:
        dict {clade: y-coord}
    """
    n_tips = tree.count_terminals()
    tips_last_to_first = list(reversed(tree.get_terminals()))
    y_by_clade = {tip: n_tips - rank * dist for rank, tip in enumerate(tips_last_to_first)}

    def place_internal(node):
        # Children must be positioned before their parent can be centered on them
        for child in node:
            if child not in y_by_clade:
                place_internal(child)
        first_child, last_child = node.clades[0], node.clades[-1]
        y_by_clade[node] = (y_by_clade[first_child] + y_by_clade[last_child]) / 2

    root = tree.root
    if root.clades:
        place_internal(root)
    return y_by_clade


def get_clade_lines(
    orientation="horizontal",
    y_curr=0,
    x_start=0,
    x_curr=0,
    y_bot=0,
    y_top=0,
    line_color="rgb(25,25,25)",
    line_width=0.5,
):
    """Build the plotly shape dict (type 'line') for one branch segment.

    A horizontal segment runs from ``x_start`` to ``x_curr`` at height
    ``y_curr``; a vertical segment runs from ``y_bot`` to ``y_top`` at
    ``x_curr``.

    Raises:
        ValueError: if ``orientation`` is neither "horizontal" nor "vertical".
    """
    if orientation == "horizontal":
        endpoints = dict(x0=x_start, y0=y_curr, x1=x_curr, y1=y_curr)
    elif orientation == "vertical":
        endpoints = dict(x0=x_curr, y0=y_bot, x1=x_curr, y1=y_top)
    else:
        raise ValueError("Line type can be 'horizontal' or 'vertical'")

    return {
        "type": "line",
        "layer": "below",
        "line": {"color": line_color, "width": line_width},
        **endpoints,
    }


def draw_clade(
    clade,
    x_start,
    line_shapes,
    x_coords,
    y_coords,
    line_color="rgb(25,25,25)",
    line_width=1,
):
    """Recursively draw the tree branches, down from the given clade.

    Appends to ``line_shapes`` (in place) one horizontal line for the branch
    leading to ``clade`` and, if it has children, one vertical line connecting
    them, then recurses into each child.

    Args:
        clade: the clade to draw; must be a key of ``x_coords`` and ``y_coords``.
        x_start: x coordinate where the branch leading to ``clade`` starts.
        line_shapes: list that collects the plotly shape dicts.
        x_coords: dict {clade: x-coord}.
        y_coords: dict {clade: y-coord}.
        line_color: color used for every branch of the subtree.
        line_width: width used for every branch of the subtree.
    """
    x_curr = x_coords[clade]
    y_curr = y_coords[clade]

    # Horizontal line from the parent's x position to this clade
    line_shapes.append(
        get_clade_lines(
            orientation="horizontal",
            y_curr=y_curr,
            x_start=x_start,
            x_curr=x_curr,
            line_color=line_color,
            line_width=line_width,
        )
    )

    if clade.clades:
        # Vertical line connecting the first and the last child
        line_shapes.append(
            get_clade_lines(
                orientation="vertical",
                x_curr=x_curr,
                y_bot=y_coords[clade.clades[-1]],
                y_top=y_coords[clade.clades[0]],
                line_color=line_color,
                line_width=line_width,
            )
        )

        # Draw descendants; the style is passed down explicitly, otherwise only
        # the root branch would use the requested color and width.
        for child in clade:
            draw_clade(
                child,
                x_curr,
                line_shapes,
                x_coords=x_coords,
                y_coords=y_coords,
                line_color=line_color,
                line_width=line_width,
            )

# These are predefined colors. Feel free to make edits to the list.
def generate_distinct_colors(elements):
    """Assign a color from a predefined palette to every element.

    Elements are colored in iteration order. When there are more elements than
    palette entries the palette is reused from its start, so colors are only
    guaranteed to be distinct for up to ``len(palette)`` elements.

    Args:
        elements: iterable of hashable values (e.g. species names).

    Returns:
        dict {element: color string}
    """
    predefined_rgb_color = [
        "red",
        "blue",
        "green",
        "darkorange",
        "purple",
        "brown",
        "darkgoldenrod",
        "deepskyblue",
        "magenta",
        "darkviolet",
        "darkturquoise",
        "darkslategray",
        "darkkhaki",
        "peru",
        "darkred",
        "darkcyan",
        "goldenrod",
        "lightblue",
        "#bcbd22",
        "#e377c2",
        "#8c564b",
        "#d62728",
        "#2ca02c",
        "#ff7f0e",
        "#1f77b4",
    ]
    palette_size = len(predefined_rgb_color)
    # Wrap around instead of raising IndexError once the palette is exhausted
    return {
        element: predefined_rgb_color[i % palette_size]
        for i, element in enumerate(elements)
    }


def plot_tree(
    tree,
    meta,
    colors_for_counties,
    title,
    xaxis_title="Distance",
    intermediate_node_color="rgb(25,25,25)",
    width=1500,
    height=2000,
    plot_column_name="speciesID_mash", # the column name from the meta data to be used for coloring, such as "speciesID_mash"
    hover_columns=None,
):
    """Build a plotly figure of a phylogenetic tree whose tips are colored by a metadata column.

    Args:
        tree: tree to draw (in this file: the result of ``Phylo.read``).
        meta: DataFrame with a "Sample ID" column matching the tip names, the
            column ``plot_column_name`` and every column named in ``hover_columns``.
        colors_for_counties: dict {value of ``plot_column_name``: color}. Despite
            the name, the values may be species, genes or any other category.
        title: figure title.
        xaxis_title: label of the x axis.
        intermediate_node_color: color of the branches and of internal nodes.
        width: figure width in pixels.
        height: figure height in pixels.
        plot_column_name: metadata column used to color tips and legend entries.
        hover_columns: optional dict {label shown in the hover text: metadata column}.
            Defaults to the five columns of the original example.

    Returns:
        plotly ``go.Figure``.

    Raises:
        IndexError: if a tip of the tree has no row in ``meta``.
        KeyError: if a required metadata column or a color is missing.
    """
    if hover_columns is None:
        ## ------------------------------------------------------
        ## This section is to update what to show in hover text -
        ## ------------------------------------------------------
        hover_columns = {
            "SpeciesID_mash": "speciesID_mash",  # speciesID_mash is a column in the metadata
            "SpeciesID_kraken": "speciesID_kraken",  # speciesID_kraken is a column in the metadata
            "Date": "Collection Date",  # Collection Date is a column in the metadata
            "County": "County",  # County is a column in the metadata
            "NDM genes": "NDM",  # NDM is a column in the metadata
        }

    max_depth = get_max_depth(tree)

    x_coords = get_x_coordinates(tree)
    y_coords = get_y_coordinates(tree)

    line_shapes = []
    draw_clade(
        tree.root,
        0,
        line_shapes,
        x_coords=x_coords,
        y_coords=y_coords,
        line_color=intermediate_node_color,
        line_width=1,
    )

    X = []
    Y = []
    text = []
    hover_text = []
    color = []
    size = []
    leaf_position = {}  # tip name -> position of the tip in the lists above

    for cl in x_coords:
        X.append(x_coords[cl])
        Y.append(y_coords[cl])
        if cl.is_terminal():
            name = cl.name
            size.append(10)
            # color the tip according to its value in plot_column_name
            # the meta data file has to have a column named "Sample ID", or you have to change the sample id column name here
            county_name = meta.loc[meta["Sample ID"] == name, plot_column_name].iloc[0]
            color.append(colors_for_counties[county_name])
            # keep the first position if a name occurs more than once
            leaf_position.setdefault(name, len(text))
        else:
            # internal nodes are invisible (size 0) and carry no label
            name = ""
            size.append(0)
            color.append(intermediate_node_color)
        text.append(name)
        hover_text.append(name)

    # Positional index so that .loc[k, column] addresses row k
    meta_county = meta.reset_index(drop=True)
    for k in meta_county.index:
        # Look the tip up by its "Sample ID" (not by column position) in O(1)
        i = leaf_position.get(meta_county.loc[k, "Sample ID"])
        if i is None:
            # metadata row of a sample that is not a tip of this tree
            continue
        for label, column in hover_columns.items():
            hover_text[i] = hover_text[i] + "<br>" + label + ": " + str(meta_county.loc[k, column])

    axis = dict(showline=False, zeroline=False, showgrid=False, showticklabels=False, title="")

    nodes = dict(
        type="scatter",
        x=X,
        y=Y,
        mode="markers+text",
        marker=dict(
            color=color,
            size=size,
        ),
        text=text,
        hovertext=hover_text,
        hoverinfo="text",
        textposition="middle right",
        textfont=dict(family="Arial", size=12, color=color),
    )

    layout = dict(
        title=title,  #'Phylogeny of Nmen Virus<br>86 genomes',
        font=dict(family="Balto", size=20),
        width=width,
        height=height,
        autosize=False,
        showlegend=False,
        xaxis=dict(
            showline=True,
            zeroline=False,
            showgrid=False,
            ticklen=4,
            showticklabels=True,
            title=xaxis_title,  # "Phylogeny of Nmen Virus"
        ),
        yaxis=axis,
        hovermode="closest",
        shapes=line_shapes,
        plot_bgcolor="rgb(250,250,250)",
        margin=dict(l=10),
    )

    ## ----------------------------------------------------------
    ## The following section is to update annotation positions  -
    ## The annotation color is based on plot_column_name.       -
    ## In this example, even though the variable is county, but -
    ## the actual variable could be species or other variables  -
    ## ----------------------------------------------------------
    # One colored label per category, stacked downwards from the top right corner;
    # together they act as the legend of the plot.
    annotations = []
    start_y = 0.90  # Starting y position, near the top of the plot
    y_offset = 0.025 * (2500 / height)  # Vertical space between annotations

    for i, county in enumerate(meta[plot_column_name].unique()):
        annotations.append(
            {
                "x": 0.91,  # Near the right edge of the plot
                "y": start_y - (i * y_offset),  # Move down for each new annotation
                "xref": "paper",
                "yref": "paper",
                "text": county,
                "showarrow": False,
                "font": {"family": "Arial", "size": 20, "color": "black"},
                "xanchor": "left",  # Align text to the left
                "align": "left",
                "bgcolor": colors_for_counties[county],  # Background color of annotation
            }
        )

    # Update layout with annotations
    layout["annotations"] = annotations

    fig = dict(data=[nodes], layout=layout)
    # leave 20% extra room on the right for the tip labels
    fig["layout"]["xaxis"].update({"range": [0, max_depth * 1.2]})  # x-axis range
    return go.Figure(fig)


def draw_annotation_line(
    fig,
    node_0: str,
    node_1: str,
    line_color: str,
    dist_df,
):
    """Connect two tips of a tree figure with a dotted line labelled with their distance.

    Args:
        fig: figure whose first trace holds the node labels (``text``) and
            positions (``x``, ``y``), e.g. the figure returned by ``plot_tree``.
        node_0: label of the first tip.
        node_1: label of the second tip.
        line_color: color of the line and of the distance label.
        dist_df: pairwise distance table; the value is read from column
            ``node_0`` and row ``node_1``.

    Raises:
        ValueError: if ``node_0`` or ``node_1`` is not a label of the first trace.
        KeyError: if the pair is not present in ``dist_df``.
    """
    trace = fig.data[0]
    labels = list(trace.text)
    x = trace.x
    y = trace.y
    index_0 = labels.index(node_0)
    index_1 = labels.index(node_1)

    fig.add_shape(
        type="line",
        x0=x[index_0],
        y0=y[index_0],  # start position
        x1=x[index_1],
        y1=y[index_1],  # end position
        line=dict(
            color=line_color,
            width=4,
            dash="dot",  # Make the line dotted
        ),
    )
    # Use the table passed in by the caller, not a global variable
    snp_distance = dist_df.loc[node_1, node_0]
    if snp_distance > 1:
        # distances above 1 are shown without decimals
        snp_distance = int(snp_distance)
    fig.add_annotation(
        x=(x[index_0] + x[index_1]) / 2,  # label sits at the midpoint of the line
        y=(y[index_0] + y[index_1]) / 2,
        text=f"<b>Distance: {snp_distance}</b>",
        font=dict(size=18, color=line_color),
        arrowhead=2,
        arrowsize=1,
        arrowwidth=2,
        arrowcolor="red",
    )


# Meta data table; replace with your meta data file.
# It has to have a column named "Sample ID", or you have to change "Sample ID" in the plot_tree function.
meta_file = "meta.csv"
meta = pd.read_csv(meta_file)

# In this example, the data points are colored by speciesID_mash
species = meta["speciesID_mash"].unique()
colors_for_species = generate_distinct_colors(species)
intermediate_node_color = "rgb(25,25,25)"

# Pairwise distance file (tab separated, first column used as row labels); replace with your file
dists_df = pd.read_csv("distances.tsv", sep="\t", index_col=0)

# Tree in newick format; replace with your tree file
tree = Phylo.read("tree.newick", "newick")

fig = plot_tree(
    tree,
    meta,
    colors_for_species,
    title=f"Phylogeny of <i>NDM Positive CROs</i> samples<br>{len(meta)} samples", # the title of the plot
    height=3000,
    plot_column_name="speciesID_mash",
)

## ---------------------------------------------------------------------------
## This section is optional, it is to draw the annotation line between nodes -
## ---------------------------------------------------------------------------
# example lines; replace every "ID#" pair with the two sample ids to connect
for node_a, node_b in [("ID#", "ID#"), ("ID#", "ID#"), ("ID#", "ID#")]:
    draw_annotation_line(fig, node_a, node_b, "black", dists_df)

iplot(fig)

### 3. Plot multiple sub-trees according to a column. In this example, the column is "speciesID_mash". Thus, one tree will be plotted for each species.

In [None]:
from Bio import Phylo
import pandas as pd
from plotly.offline import init_notebook_mode, iplot

init_notebook_mode(connected=False)
import colorsys
import numpy as np
from Bio.Phylo.BaseTree import Tree
from ete3 import Tree

from io import StringIO  # the pruned subtree is handed to Bio.Phylo in memory

meta_file = "meta.csv" # read the meta data file; replace with your meta data file
meta = pd.read_csv(meta_file)
all_species = meta["speciesID_mash"].unique() # in this example, the column to separate the sub-trees is "speciesID_mash"

# Build the species colors from the species of THIS cell, not from a variable
# left over from a previously executed cell.
colors_for_species = generate_distinct_colors(all_species)
intermediate_node_color = "rgb(25,25,25)"


genes = meta["NDM"].unique()
genes.sort()
colors_for_genes = generate_distinct_colors(genes) # in this example, the data points are colored by NDM in each sub-tree

tree_file = "tree.newick" # read the tree file; replace with your tree file

# NOTE: `Tree` is the ete3 class here; the later import shadows Bio.Phylo.BaseTree.Tree.
for species_name in all_species:
    # obtain the samples in meta that belong to this species
    species_samples = meta[meta["speciesID_mash"] == species_name]
    if len(species_samples) <= 1:
        # a tree needs at least two samples
        continue
    sample_list_by_species = species_samples["Sample ID"].tolist()

    # prune() modifies the tree in place, so the whole tree is read again for every species
    tree = Tree(tree_file, format=1)
    tree.prune(sample_list_by_species)  # prune the tree to the samples in the species

    # write() returns the newick string when no outfile is given; parsing it from
    # memory avoids a temporary file whose name would depend on the species string
    subtree = Phylo.read(StringIO(tree.write(format=1)), "newick")

    title = f"{species_name}<br>{len(species_samples)} samples" # the title of the plot

    subfig = plot_tree(
        subtree,
        species_samples,
        colors_for_genes,
        title=title,
        xaxis_title="Mash distance",
        width=1500, # the width of the plot
        height=500 + 10*len(species_samples), # the height of the plot
        plot_column_name="NDM", # the column to be used for coloring
    )

    iplot(subfig)

### 4. 3D plot of the samples based on the pairwise distances. Either Multidimensional Scaling (MDS) or Uniform Manifold Approximation and Projection (UMAP) may be used for generating the coordinates.

In [None]:
import pandas as pd
import numpy as np
from sklearn.manifold import MDS
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import umap.umap_ as umap
from sklearn.decomposition import PCA
import warnings

# Suppress all FutureWarnings and two known, harmless UMAP messages
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", message="using precomputed metric; inverse_transform will be unavailable")
warnings.filterwarnings("ignore", message="n_jobs value -1 overridden to 1 by setting random_state. Use no seed for parallelism.")


metadata_path = "meta.csv" # read the meta data file; replace with your meta data file
metadata = pd.read_csv(metadata_path)
species = metadata['speciesID_mash'].unique() # in this example, the data points are colored by speciesID_mash
colors_for_species = generate_distinct_colors(species)

# species -> color (same content as colors_for_species)
species_colors = dict(colors_for_species)

distance_matrix_path = "distances.tsv" # read the pairwise distance file; replace with your pairwise distance file
distance_matrix = pd.read_csv(distance_matrix_path, sep="\t", index_col=0)
distance_matrix = distance_matrix.fillna(0)

# Row i of the coordinates belongs to row i of the distance matrix, so the
# metadata must be in the same order. Assumes "Sample ID" values are unique.
if distance_matrix.index.isin(metadata['Sample ID']).all():
    metadata = (
        metadata.set_index('Sample ID', drop=False)
        .loc[distance_matrix.index]
        .reset_index(drop=True)
    )
else:
    warnings.warn(
        "Not every row label of the distance matrix is a 'Sample ID' of the meta data; "
        "assuming that both files list the samples in the same order."
    )

distance_array = distance_matrix.values

mds = MDS(n_components=3, dissimilarity='precomputed', random_state=42) # random_state makes the layout reproducible
# mds = umap.UMAP(n_components=3, metric='precomputed', random_state=42) # you may also use UMAP for dimensionality reduction
coordinates = mds.fit_transform(distance_array)

# example hover data; you may change the column names here
# (the position in this list is the index used in the hover template below)
hover_columns = ['Sample ID', 'Collection Date', 'County', 'speciesID_mash', 'NDM']

plot_data = metadata[hover_columns].copy()
plot_data['speciesID_mash'] = plot_data['speciesID_mash'].astype(str)  # discrete colors need string categories
plot_data['Dimension 1'] = coordinates[:, 0]
plot_data['Dimension 2'] = coordinates[:, 1]
plot_data['Dimension 3'] = coordinates[:, 2]

fig = px.scatter_3d(
    plot_data,
    x='Dimension 1',
    y='Dimension 2',
    z='Dimension 3',
    color='speciesID_mash',
    # custom_data lets plotly split the hover values per species trace; setting one
    # full-table customdata on every trace would show the wrong sample on hover
    custom_data=hover_columns,
    color_discrete_map={str(k): v for k, v in species_colors.items()},
    width=1500, # the width of the plot
    height=1500, # the height of the plot
)

# Update hover template to exclude dimension data
fig.update_traces(
    hovertemplate="<br>".join([
        "Species: %{customdata[3]}",
        "Sample ID: %{customdata[0]}",
        "Collection Date: %{customdata[1]}",
        "County: %{customdata[2]}",
        "NDM: %{customdata[4]}",
    ]),
)

fig.update_traces(marker=dict(size=3))
# Configure the legend
fig.update_layout(
    legend=dict(
        title='Species',
        tracegroupgap=10,
        font=dict(size=14),
        itemsizing='constant',
        x=1,
        y=0.5,
    ),
    font=dict(size=12)
)
fig.show()

### 5. Antimicrobial Resistance Genes. The following table lists the antimicrobial resistance (AMR) genes present in each sample. The data were generated using AMRFinderPlus (v3.10.1), which is included in the Sanibel pipeline.

In [None]:
import pandas as pd
import plotly.graph_objects as go
import os
from IPython.display import display, Markdown

# Parent directory of the current working directory
cwd = os.path.dirname(os.getcwd())
# Lower-cased gene symbols; filled by scan_gene_names() and read by italic_gene_names()
gene_names = list()


def scan_gene_names(path):
    """Collect the gene symbols of all AMRFinderPlus reports below ``path``.

    Every sub-directory ``<name>`` of ``path`` (except "filtered") is expected
    to contain ``<name>_assembly/<name>_amrfinderplus_report.tsv``. The
    lower-cased values of its "Gene symbol" column are appended to the
    module-level list ``gene_names``; symbols already in the list are skipped.

    Args:
        path: directory holding one sub-directory per sample.

    Raises:
        FileNotFoundError: if a sample directory has no report file.
        KeyError: if a report has no "Gene symbol" column.
    """
    sample_dirs = sorted(
        d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d)) and d != "filtered"
    )
    # set for O(1) membership tests; the list keeps the order of first appearance
    known = set(gene_names)
    for sample_dir in sample_dirs:
        amr_file = os.path.join(
            path, sample_dir, sample_dir + "_assembly", sample_dir + "_amrfinderplus_report.tsv"
        )
        report = pd.read_csv(amr_file, sep="\t", dtype=str)
        # dropna(): an empty cell is read as NaN (a float), which has no .lower()
        for gene in report["Gene symbol"].dropna():
            symbol = gene.lower()
            if symbol not in known:
                known.add(symbol)
                gene_names.append(symbol)


def italic_gene_names_column(value: pd.Series) -> pd.Series:
    """Return a Series in which every cell has been passed through italic_gene_names()."""
    italicized = value.apply(italic_gene_names)
    return italicized


def italic_gene_names(value: str) -> str:
    """Wrap every word of ``value`` that is a known gene name in ``<i>...</i>``.

    ``value`` is converted to a string and split on whitespace; a word counts
    as a gene name when its lower-cased form is in ``gene_names``. The words
    are joined back with single spaces.
    """
    return " ".join(
        f"<i>{token}</i>" if token.lower() in gene_names else token
        for token in str(value).split()
    )


def plot_table(data: pd.DataFrame, title: str):
    """Display ``title`` followed by ``data`` rendered as a plotly table.

    Gene names inside the cells are italicized with italic_gene_names_column().
    The figure height grows with the number of rows.
    """
    header = dict(values=list(data.keys()), align="center", font=dict(color="black", size=14))
    cells = dict(
        values=[italic_gene_names_column(data[col]) for col in data],
        fill_color="lavender",
        align="center",
        font=dict(color="black", size=14),
    )
    # relative column widths; the third column is the widest
    table = go.Table(header=header, cells=cells, columnwidth=[70, 70, 290, 70, 70, 70, 70])
    fig = go.Figure(data=[table])

    # the title is shown as notebook output above the table
    display(Markdown(f"<div align='left' style='font-size:24px;'>{title}</div>"))

    fig.update_layout(
        margin=dict(l=30, r=30, t=30, b=30),
        width=1500,
        height=100 + 40 * len(data),
    )
    fig.show()


path = "Sanibel/output-date"  # example input data path; you may change the path here
scan_gene_names(path)

# get all directories under path, and the directory name cannot be "filtered"
dirs = sorted(d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d)) and d != "filtered")

# columns of the AMRFinderPlus report that are shown in the table
report_columns = [
    "Gene symbol",
    "Sequence name",
    "Element type",
    "Class",
    "% Coverage of reference sequence",
    "% Identity to reference sequence",
]

# Collect one table per sample and concatenate once at the end; growing a
# DataFrame with pd.concat inside the loop copies all previous rows every time.
sample_tables = []
for sample_dir in dirs:
    amr_file = os.path.join(path, sample_dir, sample_dir + "_assembly", sample_dir + "_amrfinderplus_report.tsv")
    data = pd.read_csv(amr_file, sep="\t", dtype=str)[report_columns]
    # add a column "Sample ID" as the first column; the directory name is the sample id
    data.insert(0, "Sample ID", sample_dir)
    # change column "Gene symbol" to "Gene"
    sample_tables.append(data.rename(columns={"Gene symbol": "Gene"}))

if sample_tables:
    all_data = pd.concat(sample_tables, ignore_index=True)
else:
    # no sample directories found: keep an empty table with the expected header
    all_data = pd.DataFrame(columns=["Sample ID", "Gene"] + report_columns[1:])

plot_table(all_data, "Antimicrobial Resistance Genes") # option 1: show the table in the notebook

all_data.to_excel("AMR_genes.xlsx", index=False) # option 2: save the table to an excel file

### 6. Geo map by county

In [None]:
import plotly.graph_objects as go
import pandas as pd
import plotly.express as px
import os
import json

cwd = os.getcwd()

# County boundaries in geojson format; update the path to the geojson file
with open("../Data/meta/fl-counties-fips.json", "r") as geojson_file:
    geojson = json.loads(geojson_file.read())

# Meta data table; example file, you may change the meta data file here.
# The columns are renamed by position, so the file must have exactly as many
# columns as there are names in renamed_columns; you may change the names here.
renamed_columns = ["Sample ID", "Collection Date", "County"]
meta = pd.read_csv("../Data/meta/meta_data.csv").set_axis(renamed_columns, axis=1)


def get_county_data(meta, geojson, state_fips="12"):
    """Count the samples per county for every county of one state.

    Args:
        meta: DataFrame with a "County" column (one row per sample).
        geojson: dict with a "features" list; every feature needs the
            properties "NAME", "STATE" and "GEO_ID".
        state_fips: value of the "STATE" property selecting the state.
            Defaults to "12" (Florida).

    Returns:
        DataFrame with the columns "County", "FIPS" (last five characters of
        "GEO_ID") and "count" (0 for counties without samples), one row per
        county of the state in geojson order.
    """
    state_counties = [
        entry["properties"]
        for entry in geojson["features"]
        if entry["properties"]["STATE"] == state_fips
    ]
    county_counts = pd.DataFrame(
        {
            "County": [props["NAME"] for props in state_counties],
            "FIPS": [props["GEO_ID"][-5:] for props in state_counties],
        }
    )

    # Name both columns explicitly: the column names produced by
    # value_counts().reset_index() differ between pandas versions.
    counts = meta["County"].value_counts().rename_axis("County").reset_index(name="count")

    county_counts = county_counts.merge(counts, on="County", how="left")
    # Assign the result back; an in-place fillna on a selected column does not
    # reliably modify the DataFrame (it is a no-op under copy-on-write).
    county_counts["count"] = county_counts["count"].fillna(0)

    return county_counts


# Number of samples per county
county_counts = get_county_data(meta, geojson)

# Color every county by its sample count; counties are matched through the "FIPS" column
fig = px.choropleth(
    county_counts,
    geojson=geojson,
    locations="FIPS",
    color="count",
    hover_name="County",
    scope="usa",
    title="Florida by County",
    color_continuous_scale="blues",
)

# thin, semi-transparent county borders
fig.update_traces(marker_line_color="rgba(0,0,0,0.3)", marker_line_width=0.2)
# zoom to the plotted counties and hide the base map
fig.update_geos(fitbounds="locations", visible=False)
fig.update_layout(
    autosize=True,
    margin=dict(t=0, b=0, l=0, r=0),
    width=1500,
    height=1000,
)
fig.show()

### 7. Summary of Varpipe-WGS results

In [None]:
import pandas as pd
import plotly.graph_objects as go
import os
from IPython.display import display, Markdown

# Parent directory of the current working directory
cwd = os.path.dirname(os.getcwd())

# Lower-case names of the genes kept in the result table and shown in italics
gene_names = [
    "rpob", "katg", "inha", "fabg1",
    "embb", "pnca", "gyra", "gyrb",
    "mmpr",  # "rv0678 (mmpr)"
    "atpe", "rplc", "rrl",
]


def extract_interpretations(file_path):
    """Read the "Interpretations Summary:" table of a summary file.

    The table starts two lines below the line "Interpretations Summary:" (the
    line in between is skipped) and ends at the first blank line or at the end
    of the file. Each table line is stripped and split on tabs.

    Returns:
        DataFrame with the columns "Drug", "Variant" and "Interpretation".

    Raises:
        ValueError: if the file has no "Interpretations Summary:" line.
    """
    with open(file_path, "r") as handle:
        content = handle.readlines()

    marker_index = next(
        (idx for idx, row in enumerate(content) if row.strip() == "Interpretations Summary:"),
        None,
    )
    if marker_index is None:
        raise ValueError("Interpretations Summary not found in the file")

    table_rows = []
    for row in content[marker_index + 2 :]:
        stripped = row.strip()
        if stripped == "":
            # a blank line terminates the table
            break
        table_rows.append(stripped.split("\t"))

    return pd.DataFrame(table_rows, columns=["Drug", "Variant", "Interpretation"])


def scan_target_files(path, sample_id_length=11):
    """Find the Varpipe result files below ``path`` and index them by sample id.

    All directories below ``path`` are searched, except directories named "QC"
    (and everything below them). The sample id of a file is the first
    ``sample_id_length`` characters of its name.

    Args:
        path: root directory to search.
        sample_id_length: number of leading file-name characters that form the
            sample id. Defaults to 11.

    Returns:
        Three dicts {sample id: file path}, in this order:
        files ending with "_DR_loci_Final_annotation.txt",
        files ending with "_summary.txt",
        files ending with "_Lineage.txt".
        If several files map to the same sample id, the one visited last wins.
    """
    target_files = {}
    summary_files = {}
    lineage_files = {}
    suffix_buckets = (
        ("_DR_loci_Final_annotation.txt", target_files),
        ("_summary.txt", summary_files),
        ("_Lineage.txt", lineage_files),
    )
    for dirpath, dirnames, filenames in os.walk(path):
        # modify dirnames in place so that os.walk does not descend into QC directories
        dirnames[:] = [d for d in dirnames if d != "QC"]
        for filename in filenames:
            for suffix, bucket in suffix_buckets:
                if filename.endswith(suffix):
                    bucket[filename[:sample_id_length]] = os.path.join(dirpath, filename)
                    break
    return target_files, summary_files, lineage_files


def italic_gene_names_column(value: pd.Series) -> pd.Series:
    """Apply italic_gene_names() to every cell of ``value`` and return the new Series."""
    formatted = value.apply(italic_gene_names)
    return formatted


def italic_gene_names(value: str) -> str:
    """Return ``value`` with every gene name wrapped in ``<i>...</i>``.

    ``value`` is converted to a string and split on whitespace. A word is a
    gene name when its lower-cased form is contained in ``gene_names``. The
    words are joined with single spaces.
    """
    marked = []
    for token in str(value).split():
        is_gene = token.lower() in gene_names
        marked.append("<i>" + token + "</i>" if is_gene else token)
    return " ".join(marked)


def plot_table_varpipe_target(target_file: str, summary_file, sample_id: str, lineage_file: str):
    """Build the table of reportable variants of one sample.

    Despite its name this function does not plot anything; it returns the
    filtered table.

    Args:
        target_file: tab separated "..._DR_loci_Final_annotation.txt" file.
        summary_file: summary file read with extract_interpretations().
        sample_id: not used by this function (kept for the callers).
        lineage_file: "..._Lineage.txt" file of the sample.

    Returns:
        DataFrame with the variants in the genes listed in ``gene_names`` that
        pass the allele-frequency thresholds, plus the columns "Lineage" and
        "Sub-lineage".

    Raises:
        KeyError: if an expected column is missing from ``target_file``.
        ValueError: if ``summary_file`` has no interpretations summary.
    """
    data = pd.read_csv(target_file, sep="\t", dtype=str)
    summary_data = extract_interpretations(summary_file)
    # for the column "Sample ID", only keep the first 11 characters
    data["Sample ID"] = data["Sample ID"].str[:11]

    # remove the columns that are not shown in the report
    data = data.drop(
        [
            "CHROM",
            "REF",
            "ALT",
            "Variant Type",
            "Position within CDS ",  # the trailing blank is part of the column name
            "REF Amino acid",
            "ALT Amino acid",
            "Codon Position",
            "Gene ID",
        ],
        axis=1,
    )
    # remove the rows whose gene name is not in gene_names, ignoring case;
    # .copy() so that the column assignment below works on an independent
    # frame and not on a filtered view of the previous one
    data = data[data["Gene Name"].str.lower().isin(gene_names)].copy()
    # keep a row only if 'Percent Alt Allele' is at least 5 for gyrA/gyrB and at least 10 for all other genes
    data["Percent Alt Allele"] = pd.to_numeric(data["Percent Alt Allele"], errors="coerce")
    is_gyr = data["Gene Name"].isin(["gyrA", "gyrB"])
    data = data[
        (is_gyr & (data["Percent Alt Allele"] >= 5))
        | (~is_gyr & (data["Percent Alt Allele"] >= 10))
    ]

    # remove the rows with "Amino acid Change" "p.Leu449Gln" in gene "rpoB"
    data = data[~((data["Amino acid Change"] == "p.Leu449Gln") & (data["Gene Name"] == "rpoB"))]

    # Join the summary by gene name: the gene is the part of "Variant" before the first "_".
    # NOTE(review): every column added by this merge is dropped again below (the
    # interpretation is no longer wanted), so its only remaining effect is that a
    # gene with several summary rows repeats its variant rows - confirm whether
    # the merge is still needed.
    summary_data["Gene Name"] = summary_data["Variant"].str.split("_").str[0]
    data = data.merge(summary_data, on="Gene Name", how="left")
    data = data.drop(["Drug", "Variant", "Interpretation"], axis=1)

    # read lineage and sub-lineage from the lineage file
    lineage = ""
    sublineages = []
    with open(lineage_file, "r") as file:
        for line in file:
            line = line.rstrip()
            if line.startswith("Lineage"):
                # text after "Lineage: " without the last character of the line
                # (presumably a trailing delimiter - confirm against the file format)
                lineage = line[len("Lineage: ") : -1]
            elif "lineage" in line:
                index = line.index("lineage") + len("lineage") + 1
                sublineages.append(line[index:].strip())
            elif "suggests" in line:
                index = line.index("suggests") + len("suggests") + 1
                sublineages.append(line[index:].strip())
    # the longest candidate is reported; empty when the file names no sub-lineage
    long_sublineage = max(sublineages, key=len, default="")

    data["Lineage"] = lineage
    data["Sub-lineage"] = long_sublineage

    return data


def plot_table_lineage(lineage_file, sample_id):
    """Show a one-row table with the lineage and sub-lineage of a sample.

    Args:
        lineage_file: "..._Lineage.txt" file of the sample.
        sample_id: value shown in the "Sample ID" column.

    The lineage is taken from the line starting with "Lineage"; the
    sub-lineage is the longest text found after the word "lineage" or
    "suggests" in the other lines. Both are empty when the file does not
    contain such lines.
    """
    lineage = ""  # stays empty when the file has no "Lineage" line
    sublineages = []
    with open(lineage_file, "r") as file:
        for line in file:
            line = line.rstrip()
            if line.startswith("Lineage"):
                # text after "Lineage: " without the last character of the line
                # (presumably a trailing delimiter - confirm against the file format)
                lineage = line[len("Lineage: ") : -1]
            elif "lineage" in line:
                index = line.index("lineage") + len("lineage") + 1
                sublineages.append(line[index:].strip())
            elif "suggests" in line:
                index = line.index("suggests") + len("suggests") + 1
                sublineages.append(line[index:].strip())
    # default="" avoids a ValueError when no sub-lineage line was found
    long_sublineage = max(sublineages, key=len, default="")

    # one-row table with the columns "Sample ID", "Lineage", "Sub-lineage"
    data = pd.DataFrame(columns=["Sample ID", "Lineage", "Sub-lineage"])
    data.loc[0] = [sample_id, lineage, long_sublineage]

    fig = go.Figure(
        data=[
            go.Table(
                header=dict(
                    values=list(data.keys()),
                    align="center",
                    font=dict(color="black", size=14),
                ),
                cells=dict(
                    values=[italic_gene_names_column(data[col]) for col in data],
                    fill_color="lavender",
                    align="center",
                    font=dict(color="black", size=14),
                    height=25,
                ),
            )
        ]
    )

    fig.update_layout(margin=dict(l=30, r=30, t=30, b=30))
    fig.update_layout(width=600, height=100 + len(data) * 32)
    fig.show()



meta_data = pd.read_csv("meta.csv")  # add the path here
# meta_data.drop(columns=["SOURCE", "Date of receipt", "Submitter"], inplace=True)
# the columns are renamed by position, so the file must have exactly three columns at this point
meta_data.columns = ["Sample ID", "Collection Date", "County"]


path = "path/to/fastqs_varpipe"  # add the path here
target_files, summary_files, lineage_files = scan_target_files(path)

# set of lower-case gene names for fast membership tests
gene_names = {item.lower() for item in gene_names}

# Collect one table per sample and concatenate once at the end; growing a
# DataFrame with pd.concat inside the loop copies all previous rows every time.
sample_tables = []
for sample_id in meta_data["Sample ID"].tolist():
    # only samples that have an annotation file are reported
    if sample_id in target_files:
        sample_tables.append(
            plot_table_varpipe_target(
                target_files[sample_id],
                summary_files[sample_id],
                sample_id,
                lineage_files[sample_id],
            )
        )

total_data = pd.concat(sample_tables) if sample_tables else pd.DataFrame()

def plot_table_total(total_data):
    """Show ``total_data`` as a plotly table with gene names in italics.

    The figure is 1500 pixels wide; its height grows with the number of rows.
    """
    text_font = dict(color="black", size=14)
    header = dict(values=list(total_data.keys()), align="center", font=text_font)
    cells = dict(
        values=[italic_gene_names_column(total_data[col]) for col in total_data],
        fill_color="lavender",
        align="center",
        font=dict(color="black", size=14),
        height=25,
    )
    fig = go.Figure(data=[go.Table(header=header, cells=cells)])

    fig.update_layout(
        margin=dict(l=30, r=30, t=30, b=30),
        width=1500,
        height=100 + len(total_data) * 32,
    )
    fig.show()

# Tables with 100 rows or more are saved to a csv file; smaller tables are shown in the notebook
if len(total_data) >= 100:
    total_data.to_csv("Table - Genetic modifications.csv", index=False)
else:
    plot_table_total(total_data)

### 8. To generate a report in html format, run the following command. 
#### - Make sure you change the .ipynb file name to the one you are preparing.
#### - Make sure you have nbconvert installed in your conda environment
#### - Add the tag "exclude" to the cells which you do not want to show in the generated html file


       `jupyter nbconvert --execute summary.ipynb --to html --no-prompt --TagRemovePreprocessor.remove_cell_tags "exclude" --TemplateExporter.exclude_input=True`