In [1]:
from pathlib import Path
import numpy as np
from rich import print
from tqdm import tqdm

EXP_DIR = Path("/local/scratch/carlyn.1/dna/training_output/deepnet_8_4_all/")
DNA_ROOT_DIR = Path("/local/scratch/carlyn.1/dna/processed/genome/")
OUTPUT_DIR = Path("/home/carlyn.1/dna-trait-analysis/tmp/edit_plots")
OUTPUT_DIR.mkdir(exist_ok=True)
EDIT_OPTIONS = ["AA", "Aa/aA", "aa", "zero-out"]
GENE_INFO_DATA = Path("/home/carlyn.1/dna-trait-analysis/data/gene_alignment.csv")

In [2]:
from gtp.dataloading.tools import load_json, collect_chromosome_position_metadata, extract_metadata_from_scaffold_str


def load_data(exp_dir, species, wing, color):
    position_metadata = collect_chromosome_position_metadata(
        DNA_ROOT_DIR, species, chromosome="all"
    )
    data_dir = exp_dir / f"{species}_{wing}_{color}_chromosome_all"
    all_data = []
    for phase in ["training", "validation", "test"]:
        data = np.load(
            data_dir / f"{phase}_lrp_attributions.npy", allow_pickle=True
        )
        metrics = load_json(data_dir / f"{phase}_metrics.json")
        pearsonr = metrics["pearsonr"]
        rmse = metrics["rmse"]
        all_data.append([data, rmse, pearsonr])

    return all_data, position_metadata

In [3]:
import pandas as pd

gene_info_df = pd.read_csv(GENE_INFO_DATA)

In [None]:
from collections import defaultdict
from plotly.subplots import make_subplots
import plotly.graph_objs as go
import plotly.io as pio

import pandas as pd

PHASE_NAMES = ["training", "validation", "test"]
NUM_CHROMOSOMES = 21
PHASE_COLORS = ["#78006e", "#062475", "#016316"]


def create_plot(
    species="erato", wing="forewings", color="color_1", phase="test"
):
    fig = make_subplots(1, 1)
    data_to_save = {}
    phase_rmses = [[], [], []]
    chromosome_position_metadata = []
    all_phase_data, position_metadata = load_data(
        EXP_DIR, species, wing, color
    )

    print(all_phase_data[0][0].shape)
    all_nominal_positions = [list(range(len(all_phase_data[i][0][0]))) for i in range(3)]
    print(len(all_nominal_positions[0]), len(all_nominal_positions[1]), len(all_nominal_positions[2]))
    assert (
        all_nominal_positions[0] == all_nominal_positions[1] == all_nominal_positions[2]
    ), "All position must match across phases."

    chromosome_position_metadata = []
    for nom_pos in all_nominal_positions[0]:
        if nom_pos == len(position_metadata):
            nom_pos -= 1
        chromosome_position_metadata.append(position_metadata[nom_pos])

    print(all_nominal_positions[0][0], all_nominal_positions[0][1])
    data_to_save = {
        "window_size": all_nominal_positions[0][0],
        "nominal_positions": all_nominal_positions[0].tolist(),
        "real_position_metadata": chromosome_position_metadata.tolist(),
        "attributions": {},
        "rmses": {},
        "pearsonr": {},
    }
    phase_idx = PHASE_NAMES.index(phase)
    phase_data = all_phase_data[phase_idx]
    attrs, rmse, pearsonr = phase_data
    plot_data = attrs[0]
    print("Plot data collected")

    data_to_save["rmses"][phase] = rmse
    data_to_save["pearsonr"][phase] = pearsonr

    colors = [(245, 138, 66), (66, 245, 102), (47, 74, 196), (189, 32, 131)]
    all_real_positions = [x[1] for x in chromosome_position_metadata]
    all_chromosomes = [extract_metadata_from_scaffold_str(x[0], species)[0] for x in chromosome_position_metadata]
    data_to_save["attributions"][phase] = plot_data.tolist()
    
    chromosome_data = defaultdict(lambda: defaultdict(list))
    for i, (scaffold_str, real_position) in enumerate(chromosome_position_metadata):
        chromosome = extract_metadata_from_scaffold_str(scaffold_str, species)[0]
        chromosome_data[chromosome]['real_position'].append(real_position)
        chromosome_data[chromosome]['nominal_position'].append(all_nominal_positions[phase_idx][i])
        chromosome_data[chromosome]['attribution'].append(plot_data[i])
        chromosome_data[chromosome]['scaffold'].append(scaffold_str)
    
    
    x_chromo = {
        "chromosome": [],
        "positions" : []
    }
    for chromosome in tqdm(range(1, 22), desc="Plotting per chromosome"):
        if chromosome > 2: break
        data = chromosome_data[chromosome]


        text = [
            f"Chromosome: {chromosome} <br>Scaffold: {scaffold}<br>Real Position: {real_position}"
            for scaffold, real_position in zip(data["scaffold"], data["real_position"])
        ]
        
        x_chromo["chromosome"].append(str(chromosome))
        x_chromo["positions"].append(int(np.array(data["nominal_position"]).mean()))

        fig.add_trace(
            go.Scattergl(
                x=data["nominal_position"],
                y=data["attribution"],
                hovertemplate="Attribution: %{y}"
                + "<br>Nominal Position: %{x}<br>"
                + "%{text}",
                text=text,
                mode="markers",
                marker=dict(
                    color="black" if chromosome % 2 == 0 else "grey",
                ),
            ),
        )

    fig.update_layout(
        title=f'Edit changes in <b style="color: #b06a15">{species}</b> <b style="color: #016304">{wing}</b> for phenotype <b style="color: #1c0036">{color}',
        title_font_size=16,
        height=400,
        width=800,
        xaxis=dict(
            tickmode='array', # Important: Set tickmode to 'array' for custom ticks
            tickvals=x_chromo['positions'],
            ticktext=x_chromo['chromosome'],
            title='Chromosomes' # Optional: Add an x-axis title
        ),
    )

    #max_y = max([max(d.y) for d in fig.data])
    #filtered_gene_info_df = gene_info_df
    #for gidx, gene_info in filtered_gene_info_df.iterrows():
    #    start_x = int(gene_info[f"{species}_start"])
    #    end_x = int(gene_info[f"{species}_end"])
    #    if start_x == -1 or pd.isna(gene_info.gene_symbol):
    #        continue
    #    opacity = 1.0
    #    if species == "erato":
    #        opacity = gene_info.erato_containment_score
    #    fig.add_trace(
    #        go.Scatter(
    #            x=[start_x, end_x, end_x, start_x, start_x],
    #            y=[0, 0, max_y, max_y, 0],
    #            fill="toself",
    #            line_color="#A30016",
    #            fillcolor="#A30016",
    #            mode="lines",
    #            opacity=0.2,
    #            name=f"{gene_info.gene_id} ({gene_info.gene_symbol})",
    #        )
    #    )
    #    fig.data = tuple([fig.data[-1]] + list(fig.data[:-1]))  # Move this to the back

    return fig, data_to_save

species, wing, color = "erato", "forewings", "color_1"
fig, data_to_save = create_plot(
    species=species, wing=wing, color=color, phase="test"
)
pio.write_image(fig, "tmp.svg")
fig.show()

AttributeError: 'int' object has no attribute 'tolist'

In [None]:
from tqdm import tqdm

from gtp.dataloading.tools import save_json

configs = []
for species in ["erato", "melpomene"]:
    for wing in ["forewings", "hindwings"]:
        for color in ["color_1", "color_2", "color_3", "total"]:
            configs.append((species, wing, color))

for species, wing, color in tqdm(configs, desc="Creating Plots", colour="#33cc22"):
    fig, data_to_save = create_plot(species=species, wing=wing, color=color)
    basename = f"{species}_{wing}_{color}"

    save_json(data_to_save, OUTPUT_DIR / f"{basename}.json")
    fig.write_html(OUTPUT_DIR / f"{basename}.html")

Creating Plots:   0%|[38;2;51;204;34m          [0m| 0/16 [00:00<?, ?it/s]

Plotting per chromosome:  29%|██▊       | 6/21 [00:02<00:05,  2.55it/s]
Creating Plots:   0%|[38;2;51;204;34m          [0m| 0/16 [01:46<?, ?it/s]


TypeError: Object of type ndarray is not JSON serializable