In [1]:
from pathlib import Path
import numpy as np

EXP_DIR = Path("/local/scratch/carlyn.1/dna/training_output/pca_10/")
DNA_ROOT_DIR = Path("/local/scratch/carlyn.1/dna/processed/genome/")
OUTPUT_DIR = Path("/home/carlyn.1/dna-trait-analysis/tmp/edit_plots")
EDIT_OPTIONS = ["AA", "Aa/aA", "aa", "zero-out"]
GENE_INFO_DATA = Path("/home/carlyn.1/dna-trait-analysis/data/gene_alignment.csv")

In [2]:
from gtp.dataloading.tools import load_json, collect_chromosome_position_metadata


def load_data(exp_dir, species, wing, color, chromosome):
    position_metadata = collect_chromosome_position_metadata(
        DNA_ROOT_DIR, species, chromosome
    )
    data_dir = exp_dir / f"{species}_{wing}_{color}_chromosome_{chromosome}"
    all_data = []
    for phase in ["training", "validation", "test"]:
        data = np.load(
            data_dir / f"{phase}_windowed_editing_attributions.npy", allow_pickle=True
        ).item()
        metrics = load_json(data_dir / f"{phase}_metrics.json")
        pearsonr = metrics["pearsonr"]
        rmse = metrics["rmse"]
        all_data.append([data, rmse, pearsonr])

    return all_data, position_metadata


def collect_plot_data(data, nominal_positions):
    plot_data = {}
    for i, edit_opt in enumerate(EDIT_OPTIONS):
        all_pos = []
        for pos in nominal_positions:
            all_data = np.concatenate(data[pos][i])
            all_pos.append(all_data)
        plot_data[edit_opt] = np.stack(all_pos, axis=1)
    return plot_data

In [3]:
import pandas as pd

gene_info_df = pd.read_csv(GENE_INFO_DATA)

In [36]:
from tqdm import tqdm

from plotly.subplots import make_subplots
import plotly.graph_objs as go
import plotly.io as pio

import pandas as pd

PHASE_NAMES = ["training", "validation", "test"]
NUM_CHROMOSOMES = 21
PHASE_COLORS = ["#78006e", "#062475", "#016316"]
WINDOW_SIZE = 1000  # Example value, can be adjusted as needed

# Getting Data
species, wing, color, chromosome = (
    "erato",
    "forewings",
    "color_3",
    18,
)  # Example values, can be changed as needed

combined_df = None
for chromosome in tqdm(
    range(1, NUM_CHROMOSOMES + 1), desc="Collecting Chromosome Data"
):
    all_phase_data, position_metadata = load_data(
        EXP_DIR, species, wing, color, chromosome
    )

    # Gathering positions
    all_nominal_positions = [list(all_phase_data[i][0].keys()) for i in range(3)]
    assert (
        all_nominal_positions[0] == all_nominal_positions[1] == all_nominal_positions[2]
    ), "All position must match across phases."

    # HACK: to avoid issues with the last position being out of bounds
    chromosome_position_metadata = []
    for nom_pos in all_nominal_positions[0]:
        if nom_pos == len(position_metadata):
            nom_pos -= 1
        chromosome_position_metadata.append(position_metadata[nom_pos])

    phase = "test"
    phase_data = all_phase_data[PHASE_NAMES.index(phase)]
    attrs, rmse, pearsonr = phase_data
    nominal_window_positions = list(attrs.keys())
    plot_data = collect_plot_data(attrs, nominal_window_positions)

    colors = [(245, 138, 66), (66, 245, 102), (47, 74, 196), (189, 32, 131)]
    all_real_positions = [x[1] for x in chromosome_position_metadata]
    X = all_real_positions
    edit_means = {
        edit_opt: np.abs(plot_data[edit_opt]).mean(axis=0) for edit_opt in EDIT_OPTIONS
    }

    df = pd.DataFrame(
        {
            "position": X,
            "AA": edit_means["AA"],
            "Aa/aA": edit_means["Aa/aA"],
            "aa": edit_means["aa"],
            "zero-out": edit_means["zero-out"],
            "CHR": [chromosome] * len(X),
            "rmse": [rmse] * len(X),
            "pearsonr": [pearsonr] * len(X),
        }
    )

    gene_annotations = []
    filtered_gene_info_df = gene_info_df[
        gene_info_df[f"{species}_chromosome"].astype(int) == chromosome
    ]
    for gidx, gene_info in filtered_gene_info_df.iterrows():
        start_x = int(gene_info[f"{species}_start"])
        end_x = int(gene_info[f"{species}_end"])
        if start_x == -1 or pd.isna(gene_info.gene_symbol):
            continue
        opacity = 1.0
        if species == "erato":
            opacity = gene_info.erato_containment_score

        gene_annotations.append(
            [
                start_x,
                end_x,
                gene_info.gene_id,
                gene_info.gene_symbol,
                gene_info.erato_containment_score,
            ]
        )

    def is_overlap(p1s, p1e, p2s, p2e):
        if p2s >= p1s and p2s <= p2e:
            return True
        if p2e >= p1s and p2e <= p2e:
            return True
        if p2s < p1s and p2e > p1e:
            return True

        return False

    window_size = (all_nominal_positions[0][1] - all_nominal_positions[0][0]) // 2
    df["GENE"] = ""
    for ga in gene_annotations:
        for i, df_row in df.iterrows():
            if is_overlap(
                df_row["position"] - window_size,
                df_row["position"] + window_size,
                ga[0],
                ga[1],
            ):
                if df.loc[i, "GENE"] == "":
                    df.loc[i, "GENE"] = f"{ga[3]}"
                else:
                    # Append gene ID and symbol to existing GENE column
                    df.loc[i, "GENE"] = df.loc[i, "GENE"] + " | " + f"{ga[3]}"

    if combined_df is None:
        combined_df = df
    else:
        combined_df = pd.concat([combined_df, df], ignore_index=True)

combined_df.head()

Collecting Chromosome Data: 100%|██████████| 21/21 [00:20<00:00,  1.04it/s]


Unnamed: 0,position,AA,Aa/aA,aa,zero-out,CHR,rmse,pearsonr,GENE
0,178182,0.166561,0.218026,0.165243,0.161369,1,19.941073,0.891533,Def | TEP-B | Wnt10 | Wnt6 | Wnt9
1,441836,0.196061,0.198971,0.116384,0.123399,1,19.941073,0.891533,Def | TEP-B | Wnt10 | Wnt6 | Wnt9
2,692600,0.214451,0.190286,0.115888,0.131091,1,19.941073,0.891533,Def | TEP-B | Wnt10 | Wnt6 | Wnt9
3,908572,0.218764,0.181617,0.095782,0.120419,1,19.941073,0.891533,Def | TEP-B | Wnt10 | Wnt6 | Wnt9
4,1165229,0.209063,0.193608,0.084665,0.119833,1,19.941073,0.891533,Def | TEP-B | Wnt10 | Wnt6 | Wnt9


In [64]:
combined_df["BP"] = list(range(len(combined_df)))
combined_df["SNP"] = combined_df["position"].astype(str)
combined_df["P"] = 0

In [68]:
combined_df[80:100]

Unnamed: 0,position,AA,Aa/aA,aa,zero-out,CHR,rmse,pearsonr,GENE,BP,SNP,P
80,21606843,0.303735,0.224727,0.088365,0.139594,1,19.941073,0.891533,,80,21606843,0
81,21849669,0.228166,0.20163,0.122369,0.139825,1,19.941073,0.891533,,81,21849669,0
82,22060485,0.282924,0.174397,0.10788,0.14211,1,19.941073,0.891533,,82,22060485,0
83,22325749,0.131848,0.115147,0.092182,0.094408,1,19.941073,0.891533,,83,22325749,0
84,17357,0.452465,0.502701,0.159651,0.132977,2,20.725905,0.878181,CYP6AE34 | OR52 | UGT33S3 | UGT33S4 | PPO-A | ...,84,17357,0
85,385173,0.367723,0.3526,0.152033,0.161258,2,20.725905,0.878181,OR52 | UGT33S3 | UGT33S4 | PPO-A | PPO-B | CYP...,85,385173,0
86,316517,0.458164,0.371376,0.316714,0.256374,2,20.725905,0.878181,OR52 | UGT33S3 | UGT33S4 | PPO-A | PPO-B | CYP...,86,316517,0
87,649032,0.462391,0.337573,0.290324,0.226093,2,20.725905,0.878181,OR52 | UGT33S4 | CYP367B3,87,649032,0
88,868638,0.471203,0.444491,0.287745,0.224441,2,20.725905,0.878181,OR52 | UGT33S4 | CYP367B3,88,868638,0
89,1087177,0.596485,0.463261,0.260012,0.230145,2,20.725905,0.878181,OR52,89,1087177,0


In [71]:
import pandas as pd
from dash import Dash, html, dcc, Input, Output, callback
import dash_bio as dashbio

app = Dash()

app.layout = html.Div(
    [
        "Threshold value",
        dcc.Slider(
            id="default-manhattanplot-input",
            step=0.20,
            min=0,
            max=2,
            value=0.6,
        ),
        "Edit Type",
        dcc.RadioItems(
            id="edit-type-input",
            options=[
                {"label": "AA", "value": "AA"},
                {"label": "Aa/aA", "value": "Aa/aA"},
                {"label": "aa", "value": "aa"},
                {"label": "zero-out", "value": "zero-out"},
            ],
            value="AA",
        ),
        html.Br(),
        html.Div(
            dcc.Graph(
                id="default-dashbio-manhattanplot",
                figure=dashbio.ManhattanPlot(dataframe=combined_df),
            )
        ),
    ]
)


@callback(
    Output("default-dashbio-manhattanplot", "figure"),
    Input("default-manhattanplot-input", "value"),
    Input("edit-type-input", "value"),
)
def update_manhattanplot(threshold, opt):
    edit_opt = opt
    return dashbio.ManhattanPlot(
        dataframe=combined_df,
        genomewideline_value=threshold,
        logp=False,
        p=edit_opt,
        ylabel=f"Mean Edit Effect ({edit_opt})",
        suggestiveline_value=1,
        title=f"Manhattan Plot for {species} {wing} {color} Chromosome ({edit_opt})",
    )


app.run(debug=True, jupyter_mode="tab")

Dash app running on http://127.0.0.1:8050/



divide by zero encountered in log10


divide by zero encountered in log10



<IPython.core.display.Javascript object>