In [None]:
from pathlib import Path
import numpy as np

# --- Notebook configuration -------------------------------------------------
# Root of the per-experiment output folders; load_data looks for
# "<species>_<wing>_<color>_chromosome_<n>" sub-directories under it.
EXP_DIR = Path("/local/scratch/carlyn.1/dna/training_output/pca_10/")
# Root handed to collect_chromosome_position_metadata to resolve window positions.
DNA_ROOT_DIR = Path("/local/scratch/carlyn.1/dna/processed/genome/")
# Destination for the exported .json / .html files.
OUTPUT_DIR = Path("/home/carlyn.1/dna-trait-analysis/tmp/edit_plots")
# parents=True so a fresh checkout without the intermediate "tmp" directory
# does not fail with FileNotFoundError.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# Edit variants, in the order they are indexed inside each window's attribution entry.
EDIT_OPTIONS = ["AA", "Aa/aA", "aa", "zero-out"]
# CSV of gene annotations that is read into gene_info_df.
GENE_INFO_DATA = Path("/home/carlyn.1/dna-trait-analysis/data/gene_alignment.csv")

In [17]:
from gtp.dataloading.tools import load_json, collect_chromosome_position_metadata


def load_data(exp_dir, species, wing, color, chromosome):
    """Load the attribution data and metrics of every phase for one experiment.

    Args:
        exp_dir: Path-like root containing one sub-directory per experiment.
        species: Species name component of the experiment directory.
        wing: Wing name component of the experiment directory.
        color: Phenotype name component of the experiment directory.
        chromosome: Chromosome number component of the experiment directory.

    Returns:
        A tuple ``(all_data, position_metadata)``. ``all_data`` holds one
        ``[attributions, rmse, pearsonr]`` list per phase, ordered training,
        validation, test. ``position_metadata`` is whatever
        ``collect_chromosome_position_metadata`` returns for this species and
        chromosome under ``DNA_ROOT_DIR``.
    """
    position_metadata = collect_chromosome_position_metadata(
        DNA_ROOT_DIR, species, chromosome
    )
    run_dir = exp_dir / f"{species}_{wing}_{color}_chromosome_{chromosome}"

    def _load_phase(phase):
        # The .npy file stores a pickled Python object (hence allow_pickle and
        # .item()); only load files from a trusted source.
        attributions = np.load(
            run_dir / f"{phase}_windowed_editing_attributions.npy", allow_pickle=True
        ).item()
        phase_metrics = load_json(run_dir / f"{phase}_metrics.json")
        pearsonr = phase_metrics["pearsonr"]
        rmse = phase_metrics["rmse"]
        return [attributions, rmse, pearsonr]

    per_phase = [_load_phase(phase) for phase in ("training", "validation", "test")]
    return per_phase, position_metadata


def collect_plot_data(data, nominal_positions):
    """Arrange per-window attributions into one 2-D array per edit option.

    Args:
        data: Mapping from window position to a sequence indexed like
            ``EDIT_OPTIONS``; each element is a sequence of arrays that gets
            concatenated into a single 1-D array.
        nominal_positions: Window positions (keys of ``data``) in plotting order.

    Returns:
        Dict mapping each entry of ``EDIT_OPTIONS`` to an array whose columns
        (axis 1) follow ``nominal_positions`` and whose rows are the
        concatenated values of each window.
    """
    return {
        option_name: np.stack(
            [
                np.concatenate(data[position][option_idx])
                for position in nominal_positions
            ],
            axis=1,
        )
        for option_idx, option_name in enumerate(EDIT_OPTIONS)
    }

In [18]:
import pandas as pd

# Gene annotation table. create_plot_for_one_chromosome reads the
# "<species>_chromosome", "<species>_start", "<species>_end", "gene_id" and
# "gene_symbol" columns from it to draw gene regions.
gene_info_df = pd.read_csv(GENE_INFO_DATA)

In [19]:
from plotly.subplots import make_subplots
import plotly.graph_objs as go
import plotly.io as pio

import pandas as pd

# Phase names, in the same order in which load_data returns its per-phase entries.
PHASE_NAMES = ["training", "validation", "test"]
# Chromosomes are iterated as 1..NUM_CHROMOSOMES; also the column count of the grid figure.
NUM_CHROMOSOMES = 21
# Row-title colours for the training, validation and testing rows, respectively.
PHASE_COLORS = ["#78006e", "#062475", "#016316"]


def create_plot_for_one_chromosome(
    species="erato", wing="forewings", color="color_1", chromosome=1, phase="test"
):
    """Plot windowed edit attributions along one chromosome for a single phase.

    For every entry of ``EDIT_OPTIONS`` the mean absolute attribution of each
    window is drawn as a line against the window's real genomic position,
    surrounded by a shaded band of one standard deviation. Rows of
    ``gene_info_df`` that belong to this chromosome are drawn behind the curves
    as translucent rectangles spanning the gene's start/end.

    Args:
        species: Species name; part of the experiment directory name and the
            prefix of the ``gene_info_df`` columns that are read.
        wing: Wing name component of the experiment directory.
        color: Phenotype name component of the experiment directory.
        chromosome: Chromosome number to load and plot.
        phase: One of ``PHASE_NAMES``; selects which phase is plotted.

    Returns:
        ``(fig, data_to_save)``: the Plotly figure and a dict keyed by
        ``chromosome`` holding the window positions, their real-position
        metadata, and the per-edit-option means/stds, RMSE and Pearson r of
        the selected phase.

    Raises:
        ValueError: If ``phase`` is not in ``PHASE_NAMES``.
        AssertionError: If the window positions differ between phases.
    """
    fig = make_subplots(1, 1)
    all_phase_data, position_metadata = load_data(
        EXP_DIR, species, wing, color, chromosome
    )

    all_nominal_positions = [list(all_phase_data[i][0].keys()) for i in range(3)]
    assert (
        all_nominal_positions[0] == all_nominal_positions[1] == all_nominal_positions[2]
    ), "All position must match across phases."

    chromosome_position_metadata = []
    for nom_pos in all_nominal_positions[0]:
        # A window key equal to the metadata length is one past the last valid
        # index; clamp it onto the final entry.
        if nom_pos == len(position_metadata):
            nom_pos -= 1
        chromosome_position_metadata.append(position_metadata[nom_pos])

    data_to_save = {
        chromosome: {
            # NOTE(review): the first window key is stored as the window size;
            # presumably windows are keyed by their end index - confirm.
            "window_size": all_nominal_positions[0][0],
            "nominal_positions": all_nominal_positions[0],
            "real_position_metadata": chromosome_position_metadata,
            "attributions": {},
            "rmses": {},
            "pearsonr": {},
        }
    }

    attrs, rmse, pearsonr = all_phase_data[PHASE_NAMES.index(phase)]
    nominal_window_positions = list(attrs.keys())
    plot_data = collect_plot_data(attrs, nominal_window_positions)

    data_to_save[chromosome]["rmses"][phase] = rmse
    data_to_save[chromosome]["pearsonr"][phase] = pearsonr
    data_to_save[chromosome]["attributions"][phase] = {}

    # One RGB colour per entry of EDIT_OPTIONS.
    colors = [(245, 138, 66), (66, 245, 102), (47, 74, 196), (189, 32, 131)]
    all_real_positions = [x[1] for x in chromosome_position_metadata]
    hover_text = [
        f"Scaffold: {scaffold}<br>Real Position: {real_position}"
        for scaffold, real_position in chromosome_position_metadata
    ]
    for i, edit_opt in enumerate(EDIT_OPTIONS):
        abs_attributions = np.abs(plot_data[edit_opt])
        means = abs_attributions.mean(axis=0)
        stds = abs_attributions.std(axis=0)
        y_upper = means + stds
        y_lower = means - stds
        data_to_save[chromosome]["attributions"][phase][edit_opt] = {
            "means": means.tolist(),
            "stds": stds.tolist(),
        }

        rgb = ",".join(str(c) for c in colors[i])
        # Standard-deviation band: a closed polygon that walks the upper bound
        # left-to-right and the lower bound right-to-left.
        fig.add_trace(
            go.Scatter(
                x=all_real_positions + all_real_positions[::-1],
                y=y_upper.tolist() + y_lower.tolist()[::-1],
                fill="toself",
                fillcolor=f"rgba({rgb},0.2)",
                line=dict(color="rgba(255, 255, 255, 0)"),
                hoverinfo="skip",
                legendgroup=f"edit_{edit_opt}",
                showlegend=False,
                name=edit_opt,
            ),
        )
        fig.add_trace(
            go.Scatter(
                x=all_real_positions,
                y=means,
                # x carries the real genomic position here, so the nominal
                # window position is supplied through customdata; labelling
                # %{x} as "Nominal Position" would show the wrong number.
                customdata=nominal_window_positions,
                hovertemplate="Attribution: %{y}"
                + "<br>Nominal Position: %{customdata}<br>"
                + "%{text}",
                text=hover_text,
                marker=dict(size=4),
                mode="lines+markers",
                line=dict(color=f"rgb({rgb})", width=2.5),
                name=edit_opt,
                legendgroup=f"edit_{edit_opt}",
                showlegend=True,
            ),
        )

    fig.update_layout(
        title=f'Edit changes in <b style="color: #b06a15">{species}</b> <b style="color: #016304">{wing}</b> for phenotype <b style="color: #1c0036">{color}</b> on chromosome <b>{chromosome}</b>',
        title_font_size=16,
        height=400,
        width=800,
    )

    # Height of the gene rectangles: the tallest value of any trace drawn so far.
    max_y = max(max(trace.y) for trace in fig.data)
    filtered_gene_info_df = gene_info_df[
        gene_info_df[f"{species}_chromosome"].astype(int) == chromosome
    ]
    for _, gene_info in filtered_gene_info_df.iterrows():
        start_x = int(gene_info[f"{species}_start"])
        end_x = int(gene_info[f"{species}_end"])
        # Skip rows with a start of -1 or without a gene symbol.
        if start_x == -1 or pd.isna(gene_info.gene_symbol):
            continue
        fig.add_trace(
            go.Scatter(
                x=[start_x, end_x, end_x, start_x, start_x],
                y=[0, 0, max_y, max_y, 0],
                fill="toself",
                line_color="#A30016",
                fillcolor="#A30016",
                mode="lines",
                opacity=0.2,
                name=f"{gene_info.gene_id} ({gene_info.gene_symbol})",
            )
        )
        # Rotate the newly added rectangle to the front of fig.data so it is
        # drawn first, i.e. behind the attribution curves.
        fig.data = tuple([fig.data[-1]] + list(fig.data[:-1]))

    return fig, data_to_save


def create_plot_and_data(
    species="erato",
    wing="forewings",
    color="color_1",
):
    """Build the all-chromosome grid figure and the matching export data.

    The figure has one column per chromosome and two rows per phase
    (training, validation, test): a row with one line per ``EDIT_OPTIONS``
    entry showing the mean absolute attribution per window, followed by a row
    with a single bar showing that phase's RMSE.

    Args:
        species: Species name component of the experiment directory.
        wing: Wing name component of the experiment directory.
        color: Phenotype name component of the experiment directory.

    Returns:
        ``(fig, data_to_save)``: the Plotly figure and a dict keyed by
        chromosome number holding the window positions, their real-position
        metadata, and per-phase means/stds, RMSE and Pearson r.

    Raises:
        AssertionError: If the window positions of a chromosome differ
            between phases.
    """
    # Two rows (attributions + RMSE) for each of the three phases.
    specs = [[{} for _ in range(NUM_CHROMOSOMES)] for _ in range(6)]
    row_titles = [
        f'<b style="color: {PHASE_COLORS[0]}">Training<br>Attributions</b>',
        f'<b style="color: {PHASE_COLORS[0]}">Training<br>RMSE</b>',
        f'<b style="color: {PHASE_COLORS[1]}">Validation<br>Attributions</b>',
        f'<b style="color: {PHASE_COLORS[1]}">Validation<br>RMSE</b>',
        f'<b style="color: {PHASE_COLORS[2]}">Testing<br>Attributions</b>',
        f'<b style="color: {PHASE_COLORS[2]}">Testing<br>RMSE</b>',
    ]
    phase_names = [
        "Training RMSE",
        "Validation RMSE",
        "Testing RMSE",
    ]
    fig = make_subplots(
        rows=6,
        cols=NUM_CHROMOSOMES,
        shared_yaxes="rows",
        subplot_titles=[f"Chromosome {i}" for i in range(1, NUM_CHROMOSOMES + 1)],
        row_titles=row_titles,
        specs=specs,
        horizontal_spacing=0.0025,
        vertical_spacing=0.05,
    )
    # One RGB colour per entry of EDIT_OPTIONS.
    colors = [(245, 138, 66), (66, 245, 102), (47, 74, 196), (189, 32, 131)]
    data_to_save = {}
    for chromosome in range(1, NUM_CHROMOSOMES + 1):
        all_phase_data, position_metadata = load_data(
            EXP_DIR, species, wing, color, chromosome
        )

        all_nominal_positions = [list(all_phase_data[i][0].keys()) for i in range(3)]
        assert (
            all_nominal_positions[0]
            == all_nominal_positions[1]
            == all_nominal_positions[2]
        ), "All position must match across phases."

        window_metadata = []
        for nom_pos in all_nominal_positions[0]:
            # A window key equal to the metadata length is one past the last
            # valid index; clamp it onto the final entry.
            if nom_pos == len(position_metadata):
                nom_pos -= 1
            window_metadata.append(position_metadata[nom_pos])

        data_to_save[chromosome] = {
            # NOTE(review): the first window key is stored as the window size;
            # presumably windows are keyed by their end index - confirm.
            "window_size": all_nominal_positions[0][0],
            "nominal_positions": all_nominal_positions[0],
            "real_position_metadata": window_metadata,
            "attributions": {},
            "rmses": {},
            "pearsonr": {},
        }
        hover_text = [
            f"Scaffold: {scaffold}<br>Real Position: {real_position}"
            for scaffold, real_position in window_metadata
        ]
        for pi, phase_data in enumerate(all_phase_data):
            # Rows are 1-based: attributions on the odd row, RMSE on the next.
            phase_row = pi * 2 + 1
            attrs, rmse, pearsonr = phase_data

            nominal_window_positions = list(attrs.keys())
            plot_data = collect_plot_data(attrs, nominal_window_positions)

            data_to_save[chromosome]["rmses"][PHASE_NAMES[pi]] = rmse
            data_to_save[chromosome]["pearsonr"][PHASE_NAMES[pi]] = pearsonr
            data_to_save[chromosome]["attributions"][PHASE_NAMES[pi]] = {}

            for i, edit_opt in enumerate(EDIT_OPTIONS):
                abs_attributions = np.abs(plot_data[edit_opt])
                means = abs_attributions.mean(axis=0)
                stds = abs_attributions.std(axis=0)
                data_to_save[chromosome]["attributions"][PHASE_NAMES[pi]][edit_opt] = {
                    "means": means.tolist(),
                    "stds": stds.tolist(),
                }

                # Only the mean line is drawn in the grid; the standard
                # deviations are kept in data_to_save but not plotted here.
                line_color = f"rgb({','.join(str(c) for c in colors[i])})"
                fig.add_trace(
                    go.Scatter(
                        x=nominal_window_positions,
                        y=means,
                        hovertemplate="Attribution: %{y}"
                        + "<br>Nominal Position: %{x}<br>"
                        + "%{text}",
                        text=hover_text,
                        marker=dict(size=4),
                        mode="lines+markers",
                        line=dict(color=line_color, width=2.5),
                        name=edit_opt,
                        legendgroup=f"edit_{edit_opt}",
                        # Show each edit option in the legend only once.
                        showlegend=chromosome == 1 and pi == 0,
                    ),
                    col=chromosome,
                    row=phase_row,
                )

            # One RMSE bar per (chromosome, phase), in the row below the curves.
            fig.add_trace(
                go.Bar(
                    x=[0],
                    y=[rmse],
                    legendgroup=phase_names[pi],
                    showlegend=False,
                    name=phase_names[pi],
                    marker=dict(color="red"),
                    hovertemplate=f"RMSE: {rmse}<br>Pearson R: {pearsonr}<br>Chromosome: {chromosome}",
                    text=[f"RMSE: {rmse}"],
                ),
                row=phase_row + 1,
                col=chromosome,
            )
            fig.update_xaxes(showticklabels=False, row=phase_row + 1, col=chromosome)
    fig.update_layout(
        title=f'Edit changes in <b style="color: #b06a15">{species}</b> <b style="color: #016304">{wing}</b> for phenotype <b style="color: #1c0036">{color}</b>',
        title_font_size=40,
        height=1000,
        width=5000,
    )
    return fig, data_to_save


# fig, data_to_save = create_plot_and_data(
#    species="erato", wing="forewings", color="color_1"
# )
# fig

# Render one chromosome of one experiment, write it to "tmp.svg" in the
# current working directory, and display it inline.
species = "erato"
wing = "forewings"
color = "color_3"
chromosome = 18
fig, data_to_save = create_plot_for_one_chromosome(
    species=species,
    wing=wing,
    color=color,
    chromosome=chromosome,
    phase="test",
)
pio.write_image(fig, "tmp.svg")
fig.show()

In [20]:
from tqdm import tqdm

from gtp.dataloading.tools import save_json

# Every (species, wing, color) combination, with color varying fastest.
configs = [
    (species_name, wing_name, color_name)
    for species_name in ["erato", "melpomene"]
    for wing_name in ["forewings", "hindwings"]
    for color_name in ["color_1", "color_2", "color_3", "total"]
]

# Export the grid figure (.html) and its underlying numbers (.json) per combination.
for species, wing, color in tqdm(configs, desc="Creating Plots", colour="#33cc22"):
    basename = f"{species}_{wing}_{color}"
    fig, data_to_save = create_plot_and_data(species=species, wing=wing, color=color)

    save_json(data_to_save, OUTPUT_DIR / f"{basename}.json")
    fig.write_html(OUTPUT_DIR / f"{basename}.html")

Creating Plots: 100%|██████████| 16/16 [04:24<00:00, 16.50s/it]
