# Aggregation Edit Analysis
The purpose of this notebook is to analyze the aggregated effect of editing certain locations of the gene

In [1]:
import os
from pathlib import Path
from dataclasses import dataclass
import tempfile

import torch

from gtp.configs.loaders import load_configs
from gtp.configs.project import GenotypeToPhenotypeConfigs
from gtp.dataloading.path_collectors import (
    get_experiment_directory,
    get_post_processed_genotype_directory,
)
from gtp.dataloading.tools import collect_chromosome
from gtp.models.net import SoyBeanNet
from gtp.options.process_attribution import ProcessAttributionOptions

# SPECIFY GPU
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
device = "cuda"

# Set temp directory
tempfile.tempdir = "/home/carlyn.1/tmp"
print(tempfile.gettempdir())


class GeneEditingDashboardState:
    def __init__(
        self,
        wing: str,
        species: str,
        color: str,
        chromosome: int,
        exp_name: str,
        config_path: str,
    ):
        self.wing = wing
        self.species = species
        self.color = color
        self.chromosome = chromosome
        self.exp_name = exp_name
        self.configs: GenotypeToPhenotypeConfigs = load_configs(config_path)
        self.model = None
        self.camids = None
        self.data = None

        self.load_data_and_model()

    def load_data_and_model(self):
        match self.exp_name:
            case "base":
                drop_out_prob = 0.75
                out_dims = 1
                out_dims_start_idx = 0
                insize = 3
                hidden_dim = 10
            case "pca_10":
                drop_out_prob = 0.75
                out_dims = 10
                out_dims_start_idx = 0
                insize = 3
                hidden_dim = 10
            case _:
                raise NotImplementedError(
                    f"Experiment {self.exp_name} isn't implemented for this visualization. Please create a case and update configs."
                )

        opts: ProcessAttributionOptions = ProcessAttributionOptions(
            drop_out_prob=drop_out_prob,
            out_dims=out_dims,
            out_dims_start_idx=out_dims_start_idx,
            insize=insize,
            hidden_dim=hidden_dim,
            species=self.species,
            chromosome=self.chromosome,
            color=self.color,
            wing=self.wing,
            exp_name=self.exp_name,
        )

        processed_genotype_dir = (
            get_post_processed_genotype_directory(self.configs.io)
            / self.configs.experiment.genotype_scope
        )

        # Collect genotype data
        self.camids, self.data = collect_chromosome(
            processed_genotype_dir, self.species, self.chromosome
        )

        self.model = SoyBeanNet(
            window_size=self.data.shape[1],
            num_out_dims=opts.out_dims,
            insize=opts.insize,
            hidden_dim=opts.hidden_dim,
            drop_out_prob=opts.drop_out_prob,
        )

        experiment_dir = get_experiment_directory(
            self.configs.io,
            species=self.species,
            wing=self.wing,
            color=self.color,
            chromosome=self.chromosome,
            exp_name=self.exp_name,
        )

        self.model.load_state_dict(
            torch.load(experiment_dir / "model.pt", weights_only=True)
        )
        self.model = self.model.to(device)
        self.model.eval()


"""
EDIT HERE
"""
gene_editing_dashboard_state = GeneEditingDashboardState(
    wing="forewings",
    species="erato",
    color="color_3",
    chromosome=18,
    exp_name="pca_10",
    config_path=Path("../configs/default.yaml"),
)

/home/carlyn.1/tmp


In [None]:
import math
from copy import copy
from collections import defaultdict

from tqdm.auto import tqdm
from ipywidgets import VBox
import plotly.graph_objs as go
import numpy as np
from torch.utils.data import DataLoader


def get_model_input(camid):
    input_idx = np.where(gene_editing_dashboard_state.camids == camid)[0][0]
    x = gene_editing_dashboard_state.data[input_idx]
    x = torch.tensor(x).unsqueeze(0).unsqueeze(0).float().to(device)
    return x


def get_model_output(input_x):
    with torch.no_grad():
        return gene_editing_dashboard_state.model(input_x)


def get_edited_input(batch, edit_loc, edit, ws):
    input_x = batch.detach().clone()
    match edit:
        case "AA":
            edit_vec = torch.tensor([0, 0, 1])
        case "Aa/aA":
            edit_vec = torch.tensor([0, 1, 0])
        case "aa":
            edit_vec = torch.tensor([1, 0, 0])
        case "zero-out":
            edit_vec = torch.tensor([0, 0, 0])
        case _:
            pass

    if ws != 0:
        min_idx = max(0, edit_loc - ws)
        max_idx = min(input_x.shape[2], edit_loc + ws)
        edit_vec = edit_vec.unsqueeze(0).repeat(max_idx - min_idx, 1)

    if ws == 0:
        input_x[:, 0, edit_loc] = edit_vec
    else:
        input_x[:, 0, min_idx:max_idx] = edit_vec

    return input_x


def get_gene_edit_attrs(ws=10_000, batch_size=16):
    dataloader = DataLoader(
        gene_editing_dashboard_state.data,
        batch_size=batch_size,
        num_workers=4,
        shuffle=False,
    )
    edit_options = ["AA", "Aa/aA", "aa", "zero-out"]
    change_tracker = defaultdict(lambda: defaultdict(list))
    for batch in tqdm(
        dataloader, desc="Collecting Batched Edits", leave=True, position=1
    ):
        input_x = torch.tensor(batch).unsqueeze(1).float().to(device)
        org_model_pca_output = get_model_output(input_x)

        chromosome_length = input_x.shape[2]
        num_of_edits = math.ceil(chromosome_length / (ws * 2))
        for edit_num in tqdm(
            range(num_of_edits),
            desc="recording changes",
            colour="#8822CC",
            leave=False,
            position=0,
        ):
            edit_loc = min((edit_num * ws * 2) + ws, chromosome_length)
            for edit_value in edit_options:
                edited_input_x = get_edited_input(input_x, edit_loc, edit_value, ws)
                edited_model_pca_output = get_model_output(edited_input_x)
                diff = (edited_model_pca_output - org_model_pca_output).detach().cpu()
                l2_dist = np.sqrt((diff**2).sum(1).numpy())
                change_tracker[edit_loc][edit_value].append(l2_dist)

    return change_tracker


def get_gene_edit_attrs_by_pca(ws=10_000, batch_size=16):
    dataloader = DataLoader(
        gene_editing_dashboard_state.data,
        batch_size=batch_size,
        num_workers=4,
        shuffle=False,
    )
    edit_options = ["AA", "Aa/aA", "aa", "zero-out"]
    change_tracker = defaultdict(lambda: defaultdict(list))
    for batch in tqdm(
        dataloader, desc="Collecting Batched Edits", leave=True, position=1
    ):
        input_x = torch.tensor(batch).unsqueeze(1).float().to(device)
        org_model_pca_output = get_model_output(input_x)

        chromosome_length = input_x.shape[2]
        num_of_edits = math.ceil(chromosome_length / (ws * 2))
        for edit_num in tqdm(
            range(num_of_edits),
            desc="recording changes",
            colour="#8822CC",
            leave=False,
            position=0,
        ):
            edit_loc = min((edit_num * ws * 2) + ws, chromosome_length)
            for edit_value in edit_options:
                edited_input_x = get_edited_input(input_x, edit_loc, edit_value, ws)
                edited_model_pca_output = get_model_output(edited_input_x)
                diff = (edited_model_pca_output - org_model_pca_output).detach().cpu()
                l2_dist = diff.abs().numpy()
                change_tracker[edit_loc][edit_value].append(l2_dist)

    return change_tracker


In [None]:
# change_tracker = get_gene_edit_attrs(ws=10_000, batch_size=16)
change_tracker = get_gene_edit_attrs_by_pca(ws=10_000, batch_size=16)

Collecting Batched Edits:   0%|          | 0/31 [00:00<?, ?it/s]


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



recording changes:   0%|          | 0/83 [00:00<?, ?it/s]

recording changes:   0%|          | 0/83 [00:00<?, ?it/s]

recording changes:   0%|          | 0/83 [00:00<?, ?it/s]

recording changes:   0%|          | 0/83 [00:00<?, ?it/s]

recording changes:   0%|          | 0/83 [00:00<?, ?it/s]

recording changes:   0%|          | 0/83 [00:00<?, ?it/s]

recording changes:   0%|          | 0/83 [00:00<?, ?it/s]

recording changes:   0%|          | 0/83 [00:00<?, ?it/s]

recording changes:   0%|          | 0/83 [00:00<?, ?it/s]

recording changes:   0%|          | 0/83 [00:00<?, ?it/s]

recording changes:   0%|          | 0/83 [00:00<?, ?it/s]

recording changes:   0%|          | 0/83 [00:00<?, ?it/s]

recording changes:   0%|          | 0/83 [00:00<?, ?it/s]

recording changes:   0%|          | 0/83 [00:00<?, ?it/s]

recording changes:   0%|          | 0/83 [00:00<?, ?it/s]

recording changes:   0%|          | 0/83 [00:00<?, ?it/s]

recording changes:   0%|          | 0/83 [00:00<?, ?it/s]

recording changes:   0%|          | 0/83 [00:00<?, ?it/s]

recording changes:   0%|          | 0/83 [00:00<?, ?it/s]

recording changes:   0%|          | 0/83 [00:00<?, ?it/s]

recording changes:   0%|          | 0/83 [00:00<?, ?it/s]

recording changes:   0%|          | 0/83 [00:00<?, ?it/s]

recording changes:   0%|          | 0/83 [00:00<?, ?it/s]

recording changes:   0%|          | 0/83 [00:00<?, ?it/s]

recording changes:   0%|          | 0/83 [00:00<?, ?it/s]

recording changes:   0%|          | 0/83 [00:00<?, ?it/s]

recording changes:   0%|          | 0/83 [00:00<?, ?it/s]

recording changes:   0%|          | 0/83 [00:00<?, ?it/s]

recording changes:   0%|          | 0/83 [00:00<?, ?it/s]

recording changes:   0%|          | 0/83 [00:00<?, ?it/s]

recording changes:   0%|          | 0/83 [00:00<?, ?it/s]

Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/carlyn.1/miniconda3/envs/gtp/lib/python3.10/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/home/carlyn.1/miniconda3/envs/gtp/lib/python3.10/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/home/carlyn.1/miniconda3/envs/gtp/lib/python3.10/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/home/carlyn.1/miniconda3/envs/gtp/lib/python3.10/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/home/carlyn.1/miniconda3/envs/gtp/lib/python3.10/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/home/carlyn.1/miniconda3/envs/gtp/lib/python3.10/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/home/carlyn.1/m

In [43]:
import plotly.graph_objs as go

edit_options = ["AA", "Aa/aA", "aa", "zero-out"]
gene_positions = list(change_tracker.keys())
plot_data = {}
for edit_opt in edit_options:
    all_pos = []
    for pos in gene_positions:
        all_data = np.concatenate(change_tracker[pos][edit_opt])
        all_pos.append(all_data)
    plot_data[edit_opt] = np.stack(all_pos, axis=1)

fig = go.Figure()
colors = [(245, 138, 66), (66, 245, 102), (47, 74, 196), (189, 32, 131)]
for i, edit_opt in enumerate(edit_options):
    means = plot_data[edit_opt].mean(axis=0)
    stds = plot_data[edit_opt].std(axis=0)
    y_upper = means + stds
    y_lower = means - stds

    std_color = f"rgba({','.join(str(c) for c in colors[i])},0.2)"
    print(std_color)
    fig.add_traces(
        go.Scatter(
            x=gene_positions + gene_positions[::-1],
            y=y_upper.tolist() + y_lower.tolist()[::-1],
            fill="toself",
            fillcolor=std_color,
            line=dict(color="rgba(255, 255, 255, 0)"),
            hoverinfo="skip",
            showlegend=False,
            name=edit_opt,
        )
    )
    line_color = f"rgb({','.join(str(c) for c in colors[i])})"
    fig.add_traces(
        go.Scatter(
            x=gene_positions,
            y=means,
            mode="lines",
            line=dict(color=line_color, width=2.5),
            name=edit_opt,
        )
    )

fig


rgba(245,138,66,0.2)
rgba(66,245,102,0.2)
rgba(47,74,196,0.2)
rgba(189,32,131,0.2)
