In [1]:
from pathlib import Path
import pandas as pd

# Machine-specific location of the metadata inputs; edit this one line to run on another host.
METADATA_DIR = Path("/local/scratch/carlyn.1/dna/metadata")

# GWAS summary statistics, tab-separated (region per the file name: Herato1001 4,570,000-4,850,000).
gwas_points_file_path = METADATA_DIR / "gwas_Herato1001_4570000_4850000.txt"
gwas_df = pd.read_csv(gwas_points_file_path, sep="\t")

# Feature annotation sheet for the locus (gene / phenotype / atacSeq peak / exon rows).
annotation_file_path = METADATA_DIR / "wnta_annotation_mel2x5_eratodem1.xlsx"
anno_df = pd.read_excel(annotation_file_path, sheet_name="Complete_wnta_annt_demopho")
anno_df.head()

Unnamed: 0,Scaffold,Name,Type,Start,End,Length,Intervals,Direction,info
0,Herato1001,wnta,gene,4626226,4707387,81161,1,forward,vanbellemg concaha
1,Herato1001,Sd1,phenotype,4637657,4637727,70,1,forward,vanbellemg concaha
2,Herato1001,Sd2,phenotype,4639853,4641535,1682,1,forward,vanbellemg concaha
3,Herato1001,St,phenotype,4657452,4658207,755,1,forward,vanbellemg concaha
4,Herato1001,Ly1,phenotype,4666909,4670474,3565,1,forward,vanbellemg concaha


In [2]:
# Preview the first rows of the GWAS association table.
gwas_df.head(n=5)

Unnamed: 0,chr,rs,ps,n_miss,allele1,allele0,af,beta,se,logl_H1,l_remle,l_mle,p_wald,p_lrt,p_score,neg_log10_p
0,Herato1001,.,4639773,0,A,C,0.303,-56.16131,1.913198,-2210.711,6808.26,100000.0,1.912446e-108,5.500456e-109,3.923419e-53,107.718411
1,Herato1001,.,4639853,0,C,G,0.298,-56.35173,1.929564,-2212.244,100000.0,100000.0,9.120289e-108,2.555516e-108,5.462232e-53,107.039991
2,Herato1001,.,4639936,0,C,T,0.299,-56.80083,1.949337,-2212.931,100000.0,100000.0,1.8095339999999998e-107,5.08816e-108,6.341158e-53,106.742433
3,Herato1001,.,4640008,0,A,T,0.305,-56.54915,1.880458,-2204.094,294.4706,100000.0,1.141958e-111,7.255491000000001e-112,9.685145e-54,110.94235
4,Herato1001,.,4640096,0,C,A,0.305,-57.67782,1.856515,-2194.273,199.9867,100000.0,4.1961619999999997e-116,3.865583e-116,1.322778e-54,115.377148


In [14]:
# Machine-specific root for all experiment inputs/outputs; change this one line to run elsewhere.
DNA_SCRATCH_DIR = Path("/local/scratch/carlyn.1/dna")
# Trained-model outputs (one sub-directory per species/wing/color/chromosome run).
EXP_DIR = DNA_SCRATCH_DIR / "training_output" / "pca_10"
# Genotype data root passed to collect_chromosome_position_metadata.
DNA_ROOT_DIR = DNA_SCRATCH_DIR / "processed" / "genotypes" / "genome"
# Order matters: position i of EDIT_OPTIONS selects data[pos][i] in collect_plot_data.
EDIT_OPTIONS = ["AA", "Aa/aA", "aa", "zero-out"]
# Order matters: per-phase results are looked up with PHASE_NAMES.index(...).
PHASE_NAMES = ["training", "validation", "test"]

In [15]:
import numpy as np
from gtp.dataloading.tools import load_json, collect_chromosome_position_metadata


def load_data(exp_dir, species, wing, color, chromosome):
    """Load per-phase windowed editing attributions and metrics for one trained model.

    Args:
        exp_dir: Directory holding one sub-directory per run, named
            ``{species}_{wing}_{color}_chromosome_{chromosome}``.
        species: Species name; selects the run directory and the genotype metadata.
        wing: Wing identifier; part of the run directory name.
        color: Color-trait identifier; part of the run directory name.
        chromosome: Chromosome number; selects the run directory and the metadata.

    Returns:
        ``(all_data, position_metadata)``: ``all_data`` holds one
        ``[attributions, rmse, pearsonr]`` list per phase, in ``PHASE_NAMES`` order;
        ``position_metadata`` is whatever ``collect_chromosome_position_metadata``
        returns for this species/chromosome.
    """
    position_metadata = collect_chromosome_position_metadata(
        DNA_ROOT_DIR, species, chromosome
    )
    data_dir = exp_dir / f"{species}_{wing}_{color}_chromosome_{chromosome}"
    all_data = []
    # Iterate the shared PHASE_NAMES constant (previously a duplicated literal list)
    # so the returned order always agrees with PHASE_NAMES.index(...) lookups.
    for phase in PHASE_NAMES:
        # The .npy holds a pickled Python object, unwrapped with .item().
        # allow_pickle can execute code on load: only use with trusted, locally produced files.
        attributions = np.load(
            data_dir / f"{phase}_windowed_editing_attributions.npy", allow_pickle=True
        ).item()
        metrics = load_json(data_dir / f"{phase}_metrics.json")
        pearsonr = metrics["pearsonr"]
        rmse = metrics["rmse"]
        all_data.append([attributions, rmse, pearsonr])

    return all_data, position_metadata


def collect_plot_data(data, nominal_positions):
    """Build one stacked attribution matrix per edit option.

    For the edit option at index ``k`` of ``EDIT_OPTIONS``, each window ``pos`` in
    ``nominal_positions`` contributes ``np.concatenate(data[pos][k])``; these
    arrays are stacked along axis 1, so column ``j`` of the result corresponds to
    ``nominal_positions[j]``.

    Returns:
        dict mapping each edit-option name to its stacked ``np.ndarray``.
    """

    def _window_column(pos, option_idx):
        # One window's pieces for a single edit option, joined along the first axis.
        return np.concatenate(data[pos][option_idx])

    return {
        option_name: np.stack(
            [_window_column(pos, option_idx) for pos in nominal_positions],
            axis=1,
        )
        for option_idx, option_name in enumerate(EDIT_OPTIONS)
    }

In [16]:
# Extent covered by the annotation sheet: earliest Start to latest End.
bp_pos_min, bp_pos_max = anno_df["Start"].min(), anno_df["End"].max()

In [17]:
# Distinct annotation categories present in the sheet.
anno_df["Type"].unique()

array(['gene', 'phenotype', 'atacSeq peak', 'exon'], dtype=object)

In [18]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches

species = "erato"
wing = "forewings"
color = "color_3"
chromosome = 10
# Pass the `chromosome` variable through (was hard-coded to 10): the loaded data must
# match the chromosome that the exported HTML/CSV file names are built from.
all_phase_data, position_metadata = load_data(
    EXP_DIR, species=species, wing=wing, color=color, chromosome=chromosome
)

# Nominal window positions (the attribution dict's keys) for each phase, in PHASE_NAMES order.
all_nominal_positions = [list(entry[0].keys()) for entry in all_phase_data]
# NOTE(review): an assert that all phases share identical positions was disabled in the
# original; only the "test" phase's positions are used from here on.
TEST_PHASE_IDX = PHASE_NAMES.index("test")
# Half-width, in position_metadata entries, of the flanks recorded for each window.
# NOTE(review): the printed stride between windows is 20 (= 2 * this) — confirm intent.
WINDOW_HALF_WIDTH = 10

last_meta_idx = len(position_metadata) - 1
chromosome_position_metadata = []
chromosome_position_metadata_start_end = []
for nom_pos in all_nominal_positions[TEST_PHASE_IDX]:
    # A nominal position one past the end of position_metadata is clamped to the last entry.
    if nom_pos == len(position_metadata):
        nom_pos -= 1
    chromosome_position_metadata.append(position_metadata[nom_pos])
    # Metadata records at the window's left/right flanks, clamped to the valid index range.
    chromosome_position_metadata_start_end.append([
        position_metadata[max(0, nom_pos - WINDOW_HALF_WIDTH)],
        position_metadata[min(last_meta_idx, nom_pos + WINDOW_HALF_WIDTH)],
    ])

# Everything below is computed from the "test" phase results.
phase_data = all_phase_data[PHASE_NAMES.index("test")]
attrs, rmse, pearsonr = phase_data
nominal_window_positions = [*attrs]
plot_data = collect_plot_data(attrs, nominal_window_positions)
# Sanity check on the window stride: difference between the first two nominal positions.
window_step = int(nominal_window_positions[0]) - int(nominal_window_positions[1])
print(window_step)

# Coordinate of each window's nominal position and of its [start, end] flanks.
# NOTE(review): assumes element [1] of each metadata record is the bp coordinate —
# confirm against collect_chromosome_position_metadata.
all_real_positions = [record[1] for record in chromosome_position_metadata]
all_real_positions_start_and_end = [
    [left[1], right[1]] for left, right in chromosome_position_metadata_start_end
]

# Index range [start_pos, end_pos] of windows covering the annotated region
# [bp_pos_min, bp_pos_max], padded by one window on each side. Both are list
# indices (consumed as `[start_pos:end_pos + 1]` slices).
start_pos = None
# Default to the last *index*. This was all_real_positions[-1], a bp coordinate,
# which only worked because slicing silently clamps out-of-range bounds.
end_pos = len(all_real_positions) - 1
for i, rp in enumerate(all_real_positions):
    if start_pos is None and rp > bp_pos_min:
        start_pos = max(0, i - 1)
    if rp > bp_pos_max:
        end_pos = i
        break

if start_pos is None:
    # Otherwise a `[None:...]` slice would silently select every window, none of
    # which reaches the annotated region.
    raise ValueError(
        f"No window position exceeds bp_pos_min={bp_pos_min}; "
        "the annotated region lies outside the loaded windows."
    )



-20


In [None]:
# Largest (signed) value in the "AA" edit matrix — a quick scale check for the plots.
np.max(plot_data["AA"])

In [None]:
# Height (y-units) of every drawn bar; also read by the Plotly cell.
w = 0.0005
# Annotation Type -> (fill color, lane offset in multiples of w below the top lane).
ANNOTATION_TRACKS = {
    "gene": ("blue", 2),
    "phenotype": ("yellow", 4),
    "atacSeq peak": ("red", 6),
    "exon": ("green", 8),
}
window_slice = slice(start_pos, end_pos + 1)

max_height = 0
fig, ax = plt.subplots(1, 1, figsize=(24, 9))
for edit_opt in EDIT_OPTIONS:
    # Mean absolute attribution per window (averaged over axis 0 of the stacked matrix).
    means = np.abs(plot_data[edit_opt]).mean(axis=0)
    y = means[window_slice]
    # `initial` keeps this working for an empty window range (builtin max() would raise).
    max_height = np.max(y, initial=max_height)
    # Translucent bar spanning each window's [start, end] flanks at its mean value.
    for (s, e), v in zip(all_real_positions_start_and_end[window_slice], y):
        ax.add_patch(patches.Rectangle(
            (s, v - w / 2), e - s, w,
            linewidth=1, edgecolor="black", facecolor="purple", alpha=0.1,
        ))
    ax.scatter(all_real_positions[window_slice], y, label=edit_opt, s=10)

# Annotation lanes stacked above the highest attribution, one lane per Type.
for row_idx, row in anno_df.iterrows():
    track = ANNOTATION_TRACKS.get(row["Type"])
    if track is None:
        # Unknown Type: the old if/elif chain silently reused the previous row's color
        # (or raised NameError on the first row); skip it instead.
        continue
    fill_color, lane_offset = track
    y = max_height + w * (10 - lane_offset)
    # Centred on y (y ± w/2), matching the Plotly version of this figure.
    ax.add_patch(patches.Rectangle(
        (row.Start, y - w / 2), row.End - row.Start, w,
        linewidth=1, edgecolor="black", facecolor=fill_color,
    ))

legend_handles = ax.get_legend_handles_labels()[0] + [
    patches.Patch(facecolor=track_color, edgecolor="black", label=track_type)
    for track_type, (track_color, _offset) in ANNOTATION_TRACKS.items()
]
ax.legend(handles=legend_handles)
ax.set(
    xlabel="Chromosome position",
    ylabel="Mean |attribution|",
    title=f"{species} {wing} {color} - chromosome {chromosome}",
)
ax.grid()
fig.tight_layout()
plt.show()

In [None]:
import plotly.graph_objects as go

# Annotation Type -> (fill color, lane offset in multiples of w below the top lane).
# Uses `w` and `max_height` from the matplotlib cell so both figures share one layout.
ANNOTATION_TRACKS = {
    "gene": ("blue", 2),
    "phenotype": ("yellow", 4),
    "atacSeq peak": ("red", 6),
    "exon": ("green", 8),
}
window_slice = slice(start_pos, end_pos + 1)

fig = go.Figure()

# GWAS -log10(p), min-max rescaled onto [0, max_height] so it shares the attribution axis.
# Built as a new Series: the former in-place `y -= ...; y /= ...` on
# gwas_df["neg_log10_p"] wrote through to gwas_df itself.
gwas_shifted = gwas_df["neg_log10_p"] - gwas_df["neg_log10_p"].min()
gwas_peak = gwas_shifted.max()
# Guard against 0/0 (NaN) when every score is identical.
gwas_scaled = gwas_shifted / gwas_peak * max_height if gwas_peak > 0 else gwas_shifted
fig.add_trace(go.Scatter(
    x=gwas_df["ps"],
    y=gwas_scaled,
    name="GWAS",
    mode="markers",
    marker=dict(size=5, color="black"),
))

for edit_opt in EDIT_OPTIONS:
    # Mean absolute attribution per window (averaged over axis 0 of the stacked matrix).
    means = np.abs(plot_data[edit_opt]).mean(axis=0)
    x = all_real_positions[window_slice]
    y = means[window_slice]
    s, e = zip(*all_real_positions_start_and_end[window_slice])
    fig.add_trace(go.Scatter(
        x=x,
        y=y,
        name=edit_opt,
        mode="markers",
        marker=dict(size=5),
        # Asymmetric horizontal error bars spanning each window's [start, end] flanks.
        error_x=dict(
            type="data",
            array=np.array(e) - np.array(x),
            arrayminus=np.array(x) - np.array(s),
        ),
    ))

for row_idx, row in anno_df.iterrows():
    track = ANNOTATION_TRACKS.get(row["Type"])
    if track is None:
        # Unknown Type: the old if/elif chain silently reused the previous row's color.
        continue
    fill_color, lane_offset = track
    lane_y = max_height + w * (10 - lane_offset)
    fig.add_shape(
        type="rect",
        xref="x",  # x in data coordinates
        yref="y",  # y in data coordinates
        x0=row.Start,
        y0=lane_y - w / 2,
        x1=row.End,
        y1=lane_y + w / 2,
        line=dict(color=fill_color, width=0.002),
        fillcolor=fill_color,
        opacity=1,
        layer="below",  # draw beneath the data points
    )
    # Tiny marker at the feature midpoint that carries the hover label. Use the sheet's
    # Name column: `row.name` is the DataFrame index label (0, 1, ...), not the feature name.
    fig.add_trace(go.Scatter(
        x=[row.Start + (row.End - row.Start) / 2],
        y=[lane_y],
        name=row.Type,
        hovertemplate=f"{row.Type} {row['Name']} : {row.Start}-{row.End}",
        showlegend=False,
        mode="markers",
        marker=dict(color="black", size=1),
    ))

fig.update_layout(dict(hovermode="closest", xaxis=dict(showspikes=True), xaxis_spikemode="across"))

# Create the output directory up front so the export does not fail on a fresh checkout.
Path("../output/results").mkdir(parents=True, exist_ok=True)
fig.write_html(f"../output/results/{species}_{wing}_{color}_{chromosome}.html")
fig.show()

In [None]:
window_slice = slice(start_pos, end_pos + 1)
all_dfs = []
for edit_opt in EDIT_OPTIONS:
    # Mean absolute attribution per window (averaged over axis 0 of the stacked matrix).
    means = np.abs(plot_data[edit_opt]).mean(axis=0)
    # Build typed columns directly; np.stack((x, y)) upcast positions to float64 only to
    # cast them back with astype(int). Attribution stays float64, as that stack produced.
    df = pd.DataFrame({
        "Position": np.asarray(all_real_positions[window_slice]).astype(int),
        "Attribution": np.asarray(means[window_slice], dtype=np.float64),
    })
    df["Edit_Type"] = edit_opt
    all_dfs.append(df)

all_df = pd.concat(all_dfs)

# Create the output directory up front so the export does not fail on a fresh checkout.
Path("../output/results").mkdir(parents=True, exist_ok=True)
all_df.to_csv(f"../output/results/{species}_{wing}_{color}_{chromosome}.csv")