In [1]:
import dataclasses
import os

import click
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

from gtp.configs.loaders import load_configs
from gtp.configs.project import GenotypeToPhenotypeConfigs
from gtp.dataloading.path_collectors import (
    get_post_processed_genotype_directory,
    get_results_training_output_directory,
)
from gtp.dataloading.tools import collect_chromosome_position_metadata
from gtp.options.plot_attribution import PlotAttributionOptions
from gtp.tools.simple import create_exp_info_text


def _attributions_to_frame(
    snp_labels: list,
    attributions: np.ndarray,
    chromosome: int,
    top_n: int,
) -> pd.DataFrame:
    """Build the Manhattan-plot table for one chromosome and one data split.

    Args:
        snp_labels: One ``"{scaffold}:{position}"`` label per SNP, in genotype order.
        attributions: One attribution score per SNP, aligned with ``snp_labels``.
        chromosome: Chromosome number written into every row's ``CHR`` column.
        top_n: When positive, keep only the ``top_n`` rows with the largest
            (signed) attribution; otherwise keep every row.

    Returns:
        DataFrame with columns ``BP`` (int32, the SNP's ordinal index within the
        chromosome), ``CHR`` (int64), ``SNP`` (str) and ``Attribution``
        (float32), ordered by ``BP`` with a fresh 0..k-1 index.
    """
    n_snps = len(snp_labels)
    frame = pd.DataFrame(
        {
            # The SNP's ordinal index stands in for a base-pair coordinate.
            "BP": np.arange(n_snps, dtype=np.int32),
            "CHR": np.full(n_snps, int(chromosome), dtype=np.int64),
            "SNP": list(snp_labels),
            "Attribution": np.asarray(attributions, dtype=np.float32),
        }
    )
    if top_n > 0:
        # Keep the highest-scoring SNPs, then restore genomic (index) order.
        frame = (
            frame.nlargest(top_n, "Attribution")
            .sort_values("BP")
            .reset_index(drop=True)
        )
    return frame


def _process_chromosome(
    configs: GenotypeToPhenotypeConfigs,
    options: PlotAttributionOptions,
    chromosome: "int | None" = None,
):
    """Load one chromosome's attributions and shape them for a Manhattan plot.

    Args:
        configs: Project configs. ``configs.io`` locates the training outputs and
            post-processed genotypes; ``configs.experiment.genotype_scope``
            selects the genotype sub-folder.
        options: Plot options (species, wing, color, exp_name, chromosome, top_n).
        chromosome: When given, overrides ``options.chromosome`` on a copy of
            ``options`` so the caller's object is not modified. When ``None``
            (the default), ``options.chromosome`` is used as-is.

    Returns:
        ``(val_df, test_df)``: one DataFrame per split with columns ``BP``,
        ``CHR``, ``SNP`` and ``Attribution``. Returns ``None`` (after printing
        a message) when the experiment directory does not exist.

    Raises:
        AssertionError: if the number of SNP positions does not match both
            attribution arrays.
    """
    if chromosome is not None:
        # Copy before overriding so the caller's options object is untouched.
        options = dataclasses.replace(options)
        options.chromosome = chromosome

    training_output_dir = get_results_training_output_directory(configs.io)
    exp_info = create_exp_info_text(
        species=options.species,
        wing=options.wing,
        color=options.color,
        chromosome=options.chromosome,
    )

    experiment_dir = training_output_dir / options.exp_name / exp_info
    if not os.path.exists(experiment_dir):
        print(
            f"{experiment_dir} does not exist. Unable to process chromosome {options.chromosome}"
        )
        return None

    # Load attributions saved by the training run.
    val_attributions = np.load(experiment_dir / "validation_attributions.npy")
    test_attributions = np.load(experiment_dir / "test_attributions.npy")

    # Get chromosome position metadata (one entry per SNP).
    genotype_folder = get_post_processed_genotype_directory(configs.io)
    position_metadata = collect_chromosome_position_metadata(
        genotype_folder / configs.experiment.genotype_scope,
        options.species,
        options.chromosome,
    )

    n_positions = len(position_metadata)
    assert (
        n_positions == val_attributions.shape[0] == test_attributions.shape[0]
    ), (
        "Must have matching length to be accurate. "
        f"positions={n_positions}, validation={val_attributions.shape[0]}, "
        f"test={test_attributions.shape[0]}"
    )

    # SNP labels of the form "{SCAFFOLD}:{SCAFFOLD_POSITION}".
    snp_labels = [f"{x[0]}:{x[1]}" for x in position_metadata]

    val_df = _attributions_to_frame(
        snp_labels, val_attributions, options.chromosome, options.top_n
    )
    test_df = _attributions_to_frame(
        snp_labels, test_attributions, options.chromosome, options.top_n
    )
    return val_df, test_df


def _process_genome(
    configs: GenotypeToPhenotypeConfigs, options: PlotAttributionOptions
):
    """Process every chromosome and stack the per-chromosome attribution tables.

    Chromosomes are numbered 1..``configs.global_butterfly_metadata.number_of_chromosomes``.
    A chromosome whose experiment output is missing (``_process_chromosome``
    returns ``None``) is skipped rather than aborting the whole run.

    Returns:
        ``(val_df, test_df)`` concatenated over all processed chromosomes,
        each with a fresh 0..n-1 index.

    Raises:
        ValueError: if no chromosome produced any results.
    """
    n_chromosomes = configs.global_butterfly_metadata.number_of_chromosomes
    val_dfs = []
    test_dfs = []
    for chromosome in tqdm(
        range(1, n_chromosomes + 1),
        desc="Processing Chromosomes",
        colour="blue",
    ):
        result = _process_chromosome(configs, options, chromosome=chromosome)
        if result is None:
            # Missing experiment directory; _process_chromosome printed why.
            continue
        val_df, test_df = result
        val_dfs.append(val_df)
        test_dfs.append(test_df)

    if not val_dfs:
        raise ValueError(
            f"No attribution results found for any chromosome 1-{n_chromosomes} "
            f"(exp_name={options.exp_name!r})."
        )

    # ignore_index avoids duplicate index labels from per-chromosome frames.
    return (
        pd.concat(val_dfs, ignore_index=True),
        pd.concat(test_dfs, ignore_index=True),
    )


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from dash import Dash, html, dcc, Input, Output, callback
import dash_bio as dashbio


def plot_attributions(
    configs: GenotypeToPhenotypeConfigs, options: PlotAttributionOptions
):
    """Serve an interactive Dash Manhattan plot of validation attributions.

    A slider (1-10) sets the ``genomewideline_value`` threshold line, and the
    figure is regenerated whenever the slider moves. The app is launched with
    ``jupyter_mode="tab"``.

    Note: tooltip labels may be wrong unless dash-bio is patched as discussed
    in https://github.com/plotly/dash-bio/issues/720.

    Raises:
        ValueError: if ``configs.experiment.genotype_scope`` is neither
            ``"genome"`` nor ``"chromosome"``.
        FileNotFoundError: if, in chromosome scope, the experiment output for
            ``options.chromosome`` is missing.
    """
    scope = configs.experiment.genotype_scope
    if scope == "genome":
        val_df, _ = _process_genome(configs, options)
    elif scope == "chromosome":
        # There is a bug if we just process a single chromosome in the dashbio package
        result = _process_chromosome(configs, options, None)
        if result is None:
            raise FileNotFoundError(
                f"No attribution results for chromosome {options.chromosome}."
            )
        val_df, _ = result
    else:
        raise ValueError(
            f"Unsupported genotype_scope {scope!r}; expected 'genome' or 'chromosome'."
        )

    plot_kwargs = {
        "p": "Attribution",
        "chrm": "CHR",
        "bp": "BP",
        "snp": "SNP",
        "gene": None,
        "logp": False,
        "ylabel": "Model Attribution Score",
        "highlight": False,
    }

    slider_min, slider_max = 1, 10
    initial_threshold = 6

    app = Dash("Manhattan Plot")

    app.layout = html.Div(
        [
            "Threshold value",
            dcc.Slider(
                id="default-manhattanplot-input",
                min=slider_min,
                max=slider_max,
                # Label every integer in [slider_min, slider_max], both ends included.
                marks={
                    i: {"label": str(i)} for i in range(slider_min, slider_max + 1)
                },
                value=initial_threshold,
            ),
            html.Br(),
            html.Div(
                dcc.Graph(
                    id="default-dashbio-manhattanplot",
                    figure=dashbio.ManhattanPlot(
                        dataframe=val_df,
                        genomewideline_value=initial_threshold,
                        **plot_kwargs,
                    ),
                )
            ),
        ]
    )

    # Register on this app instance (not the global `callback`) so calling
    # this function more than once does not register duplicate callbacks.
    @app.callback(
        Output("default-dashbio-manhattanplot", "figure"),
        Input("default-manhattanplot-input", "value"),
    )
    def update_manhattanplot(threshold):
        return dashbio.ManhattanPlot(
            dataframe=val_df, genomewideline_value=threshold, **plot_kwargs
        )

    app.run(jupyter_mode="tab")

In [5]:
def create_manhattan_plot_static(
    configs: GenotypeToPhenotypeConfigs,
    options: PlotAttributionOptions,
    split: str = "validation",
):
    """Build a static dash-bio Manhattan plot of model attribution scores.

    Args:
        configs: Project configs; ``configs.experiment.genotype_scope`` selects
            whole-genome or single-chromosome processing.
        options: Plot options (species, wing, color, chromosome, top_n, ...).
        split: Which attributions to plot, ``"validation"`` (default) or
            ``"test"``.

    Returns:
        The figure object returned by ``dashbio.ManhattanPlot``.

    Raises:
        ValueError: if ``split`` or the genotype scope is not recognised.
        FileNotFoundError: if, in chromosome scope, the experiment output for
            ``options.chromosome`` is missing.
    """
    if split not in ("validation", "test"):
        raise ValueError(f"split must be 'validation' or 'test', got {split!r}.")

    scope = configs.experiment.genotype_scope
    if scope == "genome":
        val_df, test_df = _process_genome(configs, options)
    elif scope == "chromosome":
        # There is a bug if we just process a single chromosome in the dashbio package
        result = _process_chromosome(configs, options, None)
        if result is None:
            raise FileNotFoundError(
                f"No attribution results for chromosome {options.chromosome}."
            )
        val_df, test_df = result
    else:
        raise ValueError(
            f"Unsupported genotype_scope {scope!r}; expected 'genome' or 'chromosome'."
        )

    plot_kwargs = {
        "p": "Attribution",
        "chrm": "CHR",
        "bp": "BP",
        "snp": "SNP",
        "gene": None,
        "logp": False,
        "ylabel": "Model Attribution Score",
        "highlight": False,
    }

    title_str = (
        f"{options.species.capitalize()} ({options.wing}) | Phenotype: {options.color}"
    )
    if split == "test":
        title_str += " | Split: test"

    return dashbio.ManhattanPlot(
        dataframe=val_df if split == "validation" else test_df,
        highlight_color="#00FFAA",
        title=title_str,
        **plot_kwargs,
    )

In [7]:
from pathlib import Path

# Project configuration and the experiment / phenotype to visualise.
config_path = Path("../configs/default.yaml")
cfgs: GenotypeToPhenotypeConfigs = load_configs(config_path)

opts: PlotAttributionOptions = PlotAttributionOptions(
    exp_name="base",
    species="erato",
    wing="forewings",
    color="color_3",
    top_n=50,  # -1 keeps every SNP instead of only the top-N
    verbose=True,
)

# Static figure (saved to disk and shown inline). For the interactive Dash
# app with a threshold slider, call plot_attributions(cfgs, opts) instead.
fig = create_manhattan_plot_static(cfgs, opts)
fig.write_image("test.png", width=800, height=450)
fig

Processing Chromosomes: 100%|[34m██████████[0m| 21/21 [01:52<00:00,  5.37s/it]
