# Heliconius Gene Editing Dashboard

This dashboard is a tool for inspecting how perturbations to the genotype affect the predictions of genotype-to-phenotype models.

When genes are edited, the resulting change in the predicted phenotype is visualized.

## Initial Configurations
Edit the configuration below.

In [None]:
import os
from pathlib import Path
from dataclasses import dataclass

import torch

from gtp.configs.loaders import load_configs
from gtp.configs.project import GenotypeToPhenotypeConfigs
from gtp.dataloading.path_collectors import (
    get_experiment_directory,
    get_post_processed_genotype_directory,
)
from gtp.dataloading.tools import collect_chromosome
from gtp.models.net import SoyBeanNet
from gtp.options.process_attribution import ProcessAttributionOptions

# SPECIFY GPU: order devices by PCI bus ID and expose only device "1" to this
# process. These variables only take effect if set before CUDA is initialised.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "1",
    }
)


class GeneEditingDashboardState:
    """Configuration, genotype data and trained model for one experiment.

    Creating an instance loads the project configs from ``config_path`` and then
    immediately calls ``load_data_and_model`` to load the post-processed
    genotype data for ``species``/``chromosome`` and the experiment's
    ``model.pt`` weights.
    """

    def __init__(
        self,
        wing: str,
        species: str,
        color: str,
        chromosome: int,
        exp_name: str,
        config_path: str,
    ):
        # Experiment selection, used to locate data, weights and phenotypes.
        self.wing = wing
        self.species = species
        self.color = color
        self.chromosome = chromosome
        self.exp_name = exp_name
        self.configs: GenotypeToPhenotypeConfigs = load_configs(config_path)

        # Placeholders populated by load_data_and_model().
        self.model = self.camids = self.data = None

        self.load_data_and_model()

    def load_data_and_model(self):
        """Load the genotype matrix and the trained SoyBeanNet for this experiment.

        Sets ``self.camids``/``self.data`` from ``collect_chromosome`` and
        ``self.model`` to a SoyBeanNet with the experiment's ``model.pt``
        weights loaded, moved to CUDA and switched to eval mode.
        """
        # NOTE(review): these hyper-parameters are hard-coded here and presumably
        # must match the ones the checkpoint was trained with — confirm.
        opts: ProcessAttributionOptions = ProcessAttributionOptions(
            drop_out_prob=0.75,
            out_dims=1,
            out_dims_start_idx=0,
            insize=3,
            hidden_dim=10,
            species=self.species,
            chromosome=self.chromosome,
            color=self.color,
            wing=self.wing,
            exp_name=self.exp_name,
        )
        io_configs = self.configs.io

        genotype_dir = (
            get_post_processed_genotype_directory(io_configs)
            / self.configs.experiment.genotype_scope
        )
        self.camids, self.data = collect_chromosome(
            genotype_dir, self.species, self.chromosome
        )

        # The network's window covers every position of the loaded chromosome.
        self.model = SoyBeanNet(
            window_size=self.data.shape[1],
            num_out_dims=opts.out_dims,
            insize=opts.insize,
            hidden_dim=opts.hidden_dim,
            drop_out_prob=opts.drop_out_prob,
        )

        checkpoint_path = (
            get_experiment_directory(
                io_configs,
                species=self.species,
                wing=self.wing,
                color=self.color,
                chromosome=self.chromosome,
                exp_name=self.exp_name,
            )
            / "model.pt"
        )
        self.model.load_state_dict(torch.load(checkpoint_path, weights_only=True))

        # eval() turns off dropout so predictions/attributions are deterministic.
        self.model = self.model.cuda()
        self.model.eval()


# ---------------------------------------------------------------------------
# EDIT HERE: select the experiment (wing, species, colour, chromosome,
# experiment name) and the project config file to load.
# ---------------------------------------------------------------------------
gene_editing_dashboard_state = GeneEditingDashboardState(
    wing="forewings",
    species="erato",
    color="color_3",
    chromosome=18,
    exp_name="base",
    config_path=Path("../configs/default.yaml"),
)

## Dashboard
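Explore the model below: pick a CAMID, click a point in the attribution plot to select a genome position, then choose an edit to see its effect on the predicted phenotype.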

In [None]:
import ipywidgets as widgets


In [None]:
from copy import copy

import ipywidgets as widgets
from ipywidgets import VBox

import plotly.graph_objs as go

import numpy as np
from captum.attr import LRP


def get_model_input(camid):
    """Return the genotype of sample ``camid`` as a model-ready CUDA tensor.

    Two leading singleton dimensions are added (one batch item, one channel)
    and the values are cast to float32. Raises IndexError if ``camid`` is not
    in ``gene_editing_dashboard_state.camids``.
    """
    state = gene_editing_dashboard_state
    row = np.where(state.camids == camid)[0][0]
    genotype = torch.tensor(state.data[row])
    return genotype[None, None].float().cuda()


@torch.no_grad()
def get_model_output(input_x):
    """Run the dashboard model on ``input_x`` without tracking gradients.

    Returns element ``[0]`` of the model output (presumably the prediction for
    the single batch item built by ``get_model_input``).
    """
    model = gene_editing_dashboard_state.model
    return model(input_x)[0]


def get_attr(camid, target=0):
    """Compute per-position LRP attributions of the model for one sample.

    Args:
        camid: sample identifier, resolved via ``get_model_input``.
        target: index of the model output to explain.

    Returns:
        NumPy array with one relevance score per genotype position.
    """
    model = gene_editing_dashboard_state.model
    explainer = LRP(model)
    model.zero_grad()

    model_input = get_model_input(camid)
    model_input.requires_grad_()
    relevance = explainer.attribute(model_input, target=target)

    # Collapse the one-hot genotype axis (e.g. [0, 0, 1]) by summing rather
    # than averaging: LRP relevances should add up to the final prediction
    # value, and averaging would break that.
    relevance = relevance.sum(dim=-1)
    # Single batch item with a single channel: keep just the position axis.
    return relevance[0, 0].detach().cpu().numpy()


@dataclass()
class CurrentGeneState:
    """Mutable record of the gene position last clicked in the Manhattan plot."""

    # Human-readable summary shown in the "Current Gene" text widget.
    text: str
    # Position index along the genotype sequence that edits are applied to.
    index: int


# Shared record of the gene last clicked in the Manhattan plot.
current_gene_state = CurrentGeneState(text="", index=0)

# Sample (CAMID) selector; the plot and projection views below read its value
# and are refreshed when it changes.
camid_dropdown_widget = widgets.Dropdown(
    options=gene_editing_dashboard_state.camids,
    value=gene_editing_dashboard_state.camids[0],  # start on the first sample
    description="CAMID",
    disabled=False,
)

# Read-only summary (genotype, attribution, position) of the clicked gene;
# written by ManhattanPlotInteractive.handle_click.
current_gene_text_widget = widgets.Text(
    value=current_gene_state.text,
    placeholder="Gene",
    description="Current Gene",
    disabled=True,  # display only, not user-editable
    layout=widgets.Layout(width="100%"),
)


class ManhattanPlotInteractive:
    """Interactive Manhattan-style scatter plot of per-position attributions.

    For the CAMID selected in ``camid_dropdown_widget``, the LRP attribution of
    every genotype position is computed. The ``top_n`` positions with the
    largest absolute attribution are plotted (x = position, y = attribution).
    Clicking a point does three things:
    - stores its position in ``current_gene_state``;
    - writes a summary into ``current_gene_text_widget``;
    - highlights the point.
    """

    # Marker colours for ordinary and clicked points.
    default_color = "#0000aa"
    highlight_color = "#ff0000"

    def __init__(self, top_n=1_000):
        """Build the plot for the currently selected CAMID.

        Args:
            top_n: number of highest-|attribution| positions to plot.
        """
        self.top_n = top_n
        self.display_box = VBox([])
        self.update_data()

    def update_data(self, *args):
        """Recompute attributions for the selected CAMID and rebuild the plot.

        ``*args`` absorbs the change notification passed by ``widget.observe``.
        """
        self.camid = camid_dropdown_widget.value
        self.Y = get_attr(self.camid)
        self.X = np.arange(len(self.Y))
        # Positions ordered by descending absolute attribution.
        top_attr_idx = np.argsort(np.abs(self.Y))[::-1]
        self.top_n_idx = top_attr_idx[: self.top_n]
        self.top_X = self.X[self.top_n_idx]
        self.top_Y = self.Y[self.top_n_idx]

        self.update_plot()

    def update_plot(self, *args):
        """Create a fresh FigureWidget for the current top-N points."""
        self.figure_widget = go.FigureWidget(
            [go.Scatter(x=self.top_X, y=self.top_Y, mode="markers")]
        )

        # Initial colours are kept so a click can reset every other point.
        self.init_colors = [self.default_color] * len(self.top_X)
        scatter_plot = self.get_scatter_plot()
        scatter_plot.marker.color = self.init_colors

        # Register on-click events
        scatter_plot.on_click(self.handle_click)

        self.display_box.children = [self.figure_widget]

    def get_scatter_plot(self):
        """Return the single scatter trace of the current figure."""
        return self.figure_widget.data[0]

    def handle_click(self, trace, points, state):
        """Select the clicked position as the gene to edit.

        Args:
            trace: the clicked trace (unused).
            points: plotly points for the click; only the first one is used.
            state: plotly input-device state (unused).
        """
        # Ignore callbacks that carry no clicked points.
        if not points.point_inds:
            return
        idx = points.point_inds[0]  # index within the plotted top-N points
        # x is the genotype position; cast so it can index tensors.
        x = int(points.xs[0])
        y = points.ys[0]

        # Update the index before the text widget: views observing that widget
        # read current_gene_state.index when it fires.
        current_gene_state.index = x
        input_x = get_model_input(camid_dropdown_widget.value)
        vec_state = input_x[0][0][x]
        # One-hot genotype encoding ordered [aa, aA/Aa, AA].
        if vec_state[0].item() == 1:
            gene_str = "aa"
        elif vec_state[1].item() == 1:
            gene_str = "aA/Aa"
        elif vec_state[2].item() == 1:
            gene_str = "AA"
        else:
            gene_str = ""

        current_gene_state.text = f"({gene_str}) | Attribution: ({y}) | Position: ({x})"
        current_gene_text_widget.value = current_gene_state.text

        colors = copy(self.init_colors)
        colors[idx] = self.highlight_color
        with self.figure_widget.batch_update():
            self.get_scatter_plot().marker.color = colors


# Attribution plot for the selected CAMID; rebuilt whenever the CAMID changes.
manhattan_plot = ManhattanPlotInteractive()
camid_dropdown_widget.observe(manhattan_plot.update_data, names="value")


In [None]:
from ipywidgets import HBox

# Genotype to write at the selected position(s); "zero-out" writes an
# all-zero vector.
edit_btns = widgets.RadioButtons(
    options=["aa", "aA/Aa", "AA", "zero-out"],
)
edit_lbl = widgets.Label(value="Gene Edit")
gene_edit_btns = HBox([edit_lbl, edit_btns])

# Half-width of the edited region around the selected position;
# 0 edits only the selected position.
window_size_dd = widgets.Dropdown(
    options=list(range(0, 100001, 10000)),
    value=0,
    description="Window Size",
    # The ipywidgets trait is ``disabled``; the old ``disable`` was ignored.
    disabled=False,
)

In [None]:
from io import BytesIO
from ipywidgets import HBox
import pandas as pd

from PIL import Image
from matplotlib import cm

from gtp.dataloading.path_collectors import get_post_processed_phenotype_directory


def get_proj_matrix():
    """Load the PCA projection matrix for the dashboard's species/wing/colour.

    Reads ``<default_root>/dna/projection_matrices/<species>_<wing>_<color>.csv``
    and returns its contents as a NumPy array.
    """
    state = gene_editing_dashboard_state
    csv_name = f"{state.species}_{state.wing}_{state.color}.csv"
    csv_path = (
        Path(state.configs.io.default_root) / "dna" / "projection_matrices" / csv_name
    )
    # NOTE: pd.read_csv treats the first CSV row as a header by default.
    return pd.read_csv(csv_path).to_numpy()


def create_proj_img_bytes(pca_w, pca_vector, image_shape=(300, 300)):
    """Project a PCA vector back to image space and render it as a PNG.

    The projection is binarised: non-positive pixels map to 0 and positive
    pixels to 1. The result is then coloured with matplotlib's ``bwr`` colormap
    (0 -> blue, 1 -> red).

    Args:
        pca_w: projection matrix; ``pca_w @ pca_vector.T`` must have
            ``prod(image_shape)`` elements.
        pca_vector: PCA coefficients, e.g. a ``(1, D)`` array.
        image_shape: (height, width) of the reconstructed image.

    Returns:
        BytesIO holding the PNG, rewound to the start.
    """
    proj_img_m = (pca_w @ pca_vector.T).reshape(image_shape)
    proj_img_m[proj_img_m <= 0] = 0
    proj_img_m[proj_img_m > 0] = 1
    im = Image.fromarray(np.uint8(cm.bwr(proj_img_m) * 255))

    im_bytes = BytesIO()
    im.save(im_bytes, format="PNG")
    # Rewind so consumers can read() directly; getvalue() is unaffected.
    im_bytes.seek(0)
    return im_bytes


def get_org_pca_vector(camid):
    """Return the stored PCA phenotype vector for one sample.

    Reads ``<post-processed phenotype dir>/<species>_<wing>_<color>/data.csv``.
    It returns the first row whose ``camid`` column equals ``camid``, as a
    ``(1, D)`` NumPy array. The row's first column is dropped; this is
    presumably the camid column itself (confirm).

    Raises:
        ValueError: if ``camid`` has no row in the phenotype table.
    """
    cfgs = gene_editing_dashboard_state.configs
    species = gene_editing_dashboard_state.species
    wing = gene_editing_dashboard_state.wing
    color = gene_editing_dashboard_state.color
    phenotype_folder = get_post_processed_phenotype_directory(cfgs.io)
    csv_path = phenotype_folder / f"{species}_{wing}_{color}" / "data.csv"
    pca_df = pd.read_csv(csv_path)
    results = pca_df.loc[pca_df.camid == camid]
    if results.empty:
        # Fail here with a clear message instead of returning an empty array
        # that later breaks reshape/indexing in the projection views.
        raise ValueError(f"CAMID {camid!r} not found in {csv_path}")
    pca_vector = results.iloc[:1, 1:].to_numpy()
    return pca_vector


def get_proj_img(pca_vector):
    """Render ``pca_vector`` through the dashboard's projection matrix.

    Returns the PNG ``BytesIO`` produced by ``create_proj_img_bytes``.
    """
    return create_proj_img_bytes(get_proj_matrix(), pca_vector)


class PCAProjectionView:
    """Base view: a table of a PCA vector next to its projected image.

    ``_update_data`` must set two attributes, which ``update_plot`` renders into
    ``display_box``:
    - ``self.pca_output``: a 1-D PCA vector;
    - ``self.proj_img_bytes``: a PNG ``BytesIO``.

    Subclasses override ``_update_data`` to change what is shown.
    """

    def __init__(self):
        self.display_box = HBox([])
        self.update_data()

    def _update_data(self):
        """Show the stored PCA vector for the selected CAMID."""
        camid = camid_dropdown_widget.value
        pca_output = get_org_pca_vector(camid)
        self.proj_img_bytes = get_proj_img(pca_output)
        # update_plot reads pca_output, so the base view must set it too.
        self.pca_output = pca_output[0]

    def update_data(self, *args):
        """Refresh data and re-render; ``*args`` absorbs observer payloads."""
        self._update_data()
        self.update_plot()

    def update_plot(self):
        """Render the PCA table and projection image into ``display_box``."""
        data_list = [round(x, 2) for x in self.pca_output.tolist()]
        self.figure_widget = go.FigureWidget(
            [
                go.Table(
                    header=dict(values=["PCA"]),
                    cells=dict(values=[data_list]),
                )
            ]
        )

        self.proj_img = widgets.Image(
            value=self.proj_img_bytes.getvalue(),
            format="png",
        )

        self.display_box.children = [self.figure_widget, self.proj_img]


class OrgPCAProjectionView(PCAProjectionView):
    """Ground-truth view: the stored PCA vector for the selected CAMID."""

    def _update_data(self):
        camid = camid_dropdown_widget.value
        pca_output = get_org_pca_vector(camid)
        self.proj_img_bytes = get_proj_img(pca_output)
        # Table shows the single row as a 1-D vector.
        self.pca_output = pca_output[0]


class ModelPCAProjectionView(PCAProjectionView):
    """Model-prediction view for the selected CAMID.

    The model output overwrites the leading entries of the stored PCA vector.
    The remaining entries keep their stored values, and the combined vector is
    projected.
    """

    def _get_model_input(self, camid):
        # Hook: subclasses override this to perturb the model input.
        return get_model_input(camid)

    def _update_data(self):
        camid = camid_dropdown_widget.value
        prediction = (
            get_model_output(self._get_model_input(camid)).detach().cpu().numpy()
        )
        n_predicted = prediction.shape[0]

        # NOTE(review): assumes the model predicts the PCA components starting
        # at index 0 (out_dims_start_idx=0 when the model is loaded) — confirm.
        pca_vector = get_org_pca_vector(camid)
        pca_vector[0, :n_predicted] = prediction

        self.proj_img_bytes = get_proj_img(pca_vector)
        self.pca_output = pca_vector[0]


class ModelEditPCAProjectionView(ModelPCAProjectionView):
    """Model prediction after editing the genotype around the selected gene.

    The edit is defined by three controls:
    - position: ``current_gene_state.index``, set by clicking the Manhattan plot;
    - replacement genotype: ``edit_btns``;
    - half-width of the edited region: ``window_size_dd``.
    """

    # One-hot encodings ordered [aa, aA/Aa, AA], matching the decoding used
    # by the Manhattan plot's click handler. Any other choice ("zero-out")
    # writes the all-zero vector.
    _EDIT_VECTORS = {
        "aa": (1, 0, 0),
        "aA/Aa": (0, 1, 0),
        "AA": (0, 0, 1),
    }
    _ZERO_VECTOR = (0, 0, 0)

    def _get_model_input(self, camid):
        """Return the model input for ``camid`` with the selected edit applied."""
        input_x = super()._get_model_input(camid)
        # Position clicked in the Manhattan plot.
        loc = int(current_gene_state.index)
        # Build the edit on the input's device/dtype so the assignment needs
        # no cross-device copy or int->float conversion.
        edit_vec = torch.tensor(
            self._EDIT_VECTORS.get(edit_btns.value, self._ZERO_VECTOR),
            dtype=input_x.dtype,
            device=input_x.device,
        )

        ws = window_size_dd.value
        if ws == 0:
            # Edit only the selected position.
            input_x[0, 0, loc] = edit_vec
        else:
            # Edit positions [loc - ws, loc + ws), clipped to the sequence;
            # edit_vec broadcasts across every position in the slice.
            min_idx = max(0, loc - ws)
            max_idx = min(input_x.shape[2], loc + ws)
            input_x[0, 0, min_idx:max_idx] = edit_vec

        return input_x


class ModelEditDiffPCAProjectionView(PCAProjectionView):
    """Shows how the gene edit changes the model's prediction.

    The table holds the edited-minus-unedited PCA vector. The image highlights
    the pixels whose rendered projection differs between the two model views.
    This view reads the current state of the two views passed in, so it must
    be refreshed after them.
    """

    def __init__(self, model_pca_proj_view, model_edit_pca_proj_view):
        self.model_pca_proj_view: PCAProjectionView = model_pca_proj_view
        self.model_edit_pca_proj_view: PCAProjectionView = model_edit_pca_proj_view
        # Source views must be set first: super().__init__() renders immediately.
        super().__init__()

    def _update_data(self):
        model_pca = self.model_pca_proj_view.pca_output
        model_edit_pca = self.model_edit_pca_proj_view.pca_output
        self.pca_output = model_edit_pca - model_pca

        # Decode both PNGs and widen to float BEFORE subtracting: decoded
        # images are uint8, where subtraction silently wraps around
        # (e.g. 0 - 255 == 1).
        model_img = np.array(
            Image.open(self.model_pca_proj_view.proj_img_bytes)
        ).astype(np.float64)
        edit_model_img = np.array(
            Image.open(self.model_edit_pca_proj_view.proj_img_bytes)
        ).astype(np.float64)

        # Per-pixel change magnitude summed over channels. The absolute value
        # is required: projections are binarised to the two bwr endpoint
        # colours (blue/red), whose channel sums are equal, so a signed sum
        # would cancel to zero.
        diff_img = np.abs(model_img - edit_model_img).sum(-1)
        # Scale to [0, 1]: unchanged -> 0 (blue), largest change -> 1 (red).
        # No min-shift, otherwise an edit that changes every pixel would
        # render as "no change".
        max_diff = diff_img.max()
        if max_diff != 0:
            diff_img /= max_diff

        diff_img = Image.fromarray(np.uint8(cm.bwr(diff_img) * 255))
        self.proj_img_bytes = BytesIO()
        diff_img.save(self.proj_img_bytes, format="png")


# Build the four projection views and wire each to the controls it depends on.
# traitlets calls a widget's observers in registration order, so the diff view
# (which reads the model and edited-model views' state) is registered last.

# Stored (ground-truth) PCA vector: depends only on the selected CAMID.
original_pca_projection = OrgPCAProjectionView()
camid_dropdown_widget.observe(original_pca_projection.update_data, "value")

# Unedited model prediction: depends only on the selected CAMID.
model_pca_projection = ModelPCAProjectionView()
camid_dropdown_widget.observe(model_pca_projection.update_data, "value")

# Edited model prediction: also depends on the edit choice, the window size and
# the clicked gene (signalled through the current-gene text widget).
model_edit_pca_projection = ModelEditPCAProjectionView()
camid_dropdown_widget.observe(model_edit_pca_projection.update_data, "value")
edit_btns.observe(model_edit_pca_projection.update_data, "value")
window_size_dd.observe(model_edit_pca_projection.update_data, "value")
current_gene_text_widget.observe(model_edit_pca_projection.update_data, "value")

# Difference view: listens to the same controls and must run after the two
# views above have refreshed.
model_edit_diff_pca_projection = ModelEditDiffPCAProjectionView(
    model_pca_projection, model_edit_pca_projection
)
camid_dropdown_widget.observe(model_edit_diff_pca_projection.update_data, "value")
edit_btns.observe(model_edit_diff_pca_projection.update_data, "value")
window_size_dd.observe(model_edit_diff_pca_projection.update_data, "value")
current_gene_text_widget.observe(model_edit_diff_pca_projection.update_data, "value")

In [None]:
# edit_btns.observe()

In [None]:
from ipywidgets import TwoByTwoLayout

# 2x2 grid of the projection views:
# - top-left: stored PCA vector;
# - top-right: model prediction;
# - bottom-left: prediction after the gene edit;
# - bottom-right: edit-induced difference.
projection_grid = TwoByTwoLayout(
    top_left=original_pca_projection.display_box,
    top_right=model_pca_projection.display_box,
    bottom_left=model_edit_pca_projection.display_box,
    bottom_right=model_edit_diff_pca_projection.display_box,
)

In [None]:
# Render the dashboard top to bottom: sample selector, attribution plot,
# selected-gene summary, edit controls, window size and the projection grid.
# (``display`` is provided by the IPython/Jupyter environment.)
display(camid_dropdown_widget)
display(manhattan_plot.display_box)
display(current_gene_text_widget)
display(gene_edit_btns)
display(window_size_dd)
display(projection_grid)