# Heliconius Average Edit Viewer

This notebook aggregates the effect of genotype edits over a window within the DNA.

## Initial Configurations
Edit the configuration below.

In [None]:
import os
from pathlib import Path
from dataclasses import dataclass

import torch

from gtp.configs.loaders import load_configs
from gtp.configs.project import GenotypeToPhenotypeConfigs
from gtp.dataloading.path_collectors import (
    get_experiment_directory,
    get_post_processed_genotype_directory,
)
from gtp.dataloading.tools import collect_chromosome
from gtp.models.net import SoyBeanNet
from gtp.options.process_attribution import ProcessAttributionOptions

# SPECIFY GPU
# Enumerate CUDA devices in PCI bus order and expose only device "0" to this
# process. These variables are set before any CUDA call is made below.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Fall back to the CPU when no CUDA device is usable so the notebook still runs
# (slowly) instead of failing when the model is moved to the device.
device = "cuda" if torch.cuda.is_available() else "cpu"


class ExperimentState:
    def __init__(
        self,
        wing: str,
        species: str,
        color: str,
        chromosome: int,
        exp_name: str,
        config_path: str,
    ):
        self.wing = wing
        self.species = species
        self.color = color
        self.chromosome = chromosome
        self.exp_name = exp_name
        self.configs: GenotypeToPhenotypeConfigs = load_configs(config_path)
        self.model = None
        self.camids = None
        self.data = None

        self.load_data_and_model()

    def load_data_and_model(self):
        match self.exp_name:
            case "base":
                drop_out_prob = 0.75
                out_dims = 1
                out_dims_start_idx = 0
                insize = 3
                hidden_dim = 10
            case "pca_10":
                drop_out_prob = 0.75
                out_dims = 10
                out_dims_start_idx = 0
                insize = 3
                hidden_dim = 10
            case _:
                raise NotImplementedError(
                    f"Experiment {self.exp_name} isn't implemented for this visualization. Please create a case and update configs."
                )

        opts: ProcessAttributionOptions = ProcessAttributionOptions(
            drop_out_prob=drop_out_prob,
            out_dims=out_dims,
            out_dims_start_idx=out_dims_start_idx,
            insize=insize,
            hidden_dim=hidden_dim,
            species=self.species,
            chromosome=self.chromosome,
            color=self.color,
            wing=self.wing,
            exp_name=self.exp_name,
        )

        processed_genotype_dir = (
            get_post_processed_genotype_directory(self.configs.io)
            / self.configs.experiment.genotype_scope
        )

        # Collect genotype data
        self.camids, self.data = collect_chromosome(
            processed_genotype_dir, self.species, self.chromosome
        )

        self.model = SoyBeanNet(
            window_size=self.data.shape[1],
            num_out_dims=opts.out_dims,
            insize=opts.insize,
            hidden_dim=opts.hidden_dim,
            drop_out_prob=opts.drop_out_prob,
        )

        experiment_dir = get_experiment_directory(
            self.configs.io,
            species=self.species,
            wing=self.wing,
            color=self.color,
            chromosome=self.chromosome,
            exp_name=self.exp_name,
        )

        self.model.load_state_dict(
            torch.load(experiment_dir / "model.pt", weights_only=True)
        )
        self.model = self.model.to(device)
        self.model.eval()


"""
EDIT HERE
"""
# Select which trained experiment to inspect. The constructor loads the
# configs, the genotype data and the model weights right away.
experiment_state = ExperimentState(
    species="erato",
    wing="forewings",
    color="total",
    chromosome=10,
    exp_name="pca_10",
    config_path=Path("../configs/default.yaml"),
)

## Dashboard
Explore the model below

In [None]:
import ipywidgets as widgets
from copy import copy
from dataclasses import dataclass

from ipywidgets import VBox, HBox
import plotly.graph_objs as go
import numpy as np
from captum.attr import LRP

def get_model_input(camid):
    """Return the genotype entry for ``camid`` as a float tensor on ``device``.

    The first entry of ``experiment_state.camids`` equal to ``camid`` selects
    the row of ``experiment_state.data``; two leading singleton dimensions are
    then added in front of it. Indexing fails with ``IndexError`` when no
    entry matches (assuming ``camids`` is a NumPy array -- TODO confirm).
    """
    matching_rows = np.where(experiment_state.camids == camid)[0]
    sample = experiment_state.data[matching_rows[0]]
    batched = torch.tensor(sample).unsqueeze(0).unsqueeze(0)
    return batched.float().to(device)


def get_model_output(input_x):
    """Run ``experiment_state.model`` on ``input_x`` with gradients disabled."""
    model = experiment_state.model
    with torch.no_grad():
        prediction = model(input_x)
    return prediction

# Smoke test: push the first specimen through the model and display the shape
# of its output. No outer ``no_grad`` is needed: the input is built from raw
# data and ``get_model_output`` already disables gradients itself.
test_camid = experiment_state.camids[0]
test_model_input = get_model_input(test_camid)
test_model_output = get_model_output(test_model_input)
test_model_output.shape

In [None]:
from io import BytesIO
from ipywidgets import HBox
import pandas as pd
from tqdm import tqdm

from PIL import Image
from matplotlib import cm

from gtp.dataloading.path_collectors import get_post_processed_phenotype_directory

# Short aliases for the experiment settings used throughout the cells below.
cfgs = experiment_state.configs
species, wing, color = (
    experiment_state.species,
    experiment_state.wing,
    experiment_state.color,
)

phenotype_folder = get_post_processed_phenotype_directory(cfgs.io)
pca_pheno_df = pd.read_csv(
    phenotype_folder / f"{species}_{wing}_{color}" / "data.csv"
)

# Cache one PCA row per specimen. A camid missing from the CSV ends up with
# an array of zero rows.
camid_pca_map = {}
for camid in tqdm(experiment_state.camids, desc="caching camids"):
    rows_for_camid = pca_pheno_df.loc[pca_pheno_df.camid == camid]
    # Keep at most the first matching row and drop the first column.
    camid_pca_map[camid] = rows_for_camid.iloc[:1, 1:].to_numpy()

# The projection-matrix file for the "total" color is named "color_total".
color = "color_total" if color == "total" else color
proj_matrices_dir = Path(cfgs.io.default_root, "dna/projection_matrices")
pca_df = pd.read_csv(proj_matrices_dir / f"{species}_{wing}_{color}.csv")
pca_w = pca_df.to_numpy()

    
def get_org_pca_vector(camid):
    """Return the PCA array cached for ``camid`` in ``camid_pca_map``.

    The cached object itself is returned (not a copy). Raises ``KeyError``
    when ``camid`` was never cached.
    """
    cached_vector = camid_pca_map[camid]
    return cached_vector


def get_proj_matrix():
    """Return the module-level projection matrix ``pca_w``."""
    projection_matrix = pca_w
    return projection_matrix


def create_proj_img(pca_w, pca_vector, return_raw=False):
    """Project ``pca_vector`` through ``pca_w`` and shape it as a 300x300 image.

    When ``experiment_state.color`` is ``"total"`` the projection is treated as
    three 300x300 channels and returned channels-last; otherwise it is a
    single 300x300 plane.

    Args:
        pca_w: projection matrix, multiplied as ``pca_w @ pca_vector.T``.
        pca_vector: PCA coefficients to project.
        return_raw: if true, return the float array without thresholding.

    Returns:
        The raw float array when ``return_raw`` is true; otherwise a PIL image
        in which strictly positive values are 255 and all others are 0.
    """
    projection = pca_w @ pca_vector.T
    if experiment_state.color == "total":
        # (3, 300, 300) channels-first -> (300, 300, 3) channels-last
        projection = np.transpose(projection.reshape(3, 300, 300), axes=(1, 2, 0))
    else:
        projection = projection.reshape(300, 300)

    # Guard clause: callers that need the un-thresholded values stop here.
    if return_raw:
        return projection

    # Binarize on sign, in place.
    positive = projection > 0
    projection[projection <= 0] = 0
    projection[positive] = 1
    return Image.fromarray(np.uint8(projection * 255))

def get_proj_img(pca_vector, return_raw=False):
    """Project ``pca_vector`` with the notebook's projection matrix.

    Thin wrapper around ``create_proj_img`` that supplies the matrix from
    ``get_proj_matrix()``; ``return_raw`` is forwarded unchanged.
    """
    return create_proj_img(get_proj_matrix(), pca_vector, return_raw=return_raw)

def get_model_pca_view(input_x, camid):
    """Return the specimen's PCA vector with its leading entries predicted.

    The model is run on ``input_x`` and its ``D`` outputs replace the first
    ``D`` columns of the PCA vector stored for ``camid``; the remaining
    columns keep their stored values.

    Args:
        input_x: model input tensor, passed to ``get_model_output``.
        camid: specimen id whose stored PCA vector provides the other columns.

    Returns:
        A new NumPy array; the cached original vector is left unmodified.
    """
    output = get_model_output(input_x)
    pca_output = output.detach().cpu().numpy()
    D = pca_output.shape[-1]
    # Work on a copy. Writing into the cached array would overwrite the
    # specimen's original PCA vector and make every vector returned for the
    # same camid (prediction and each edit) alias one buffer.
    pca_vector = get_org_pca_vector(camid).copy()
    pca_vector[:, :D] = pca_output
    return pca_vector


# Image rebuilt from the stored PCA vector of the test specimen (thresholded
# PIL image, displayed as the cell output).
org_pca_vec = get_org_pca_vector(test_camid)
org_img = get_proj_img(org_pca_vec, return_raw=False)

org_img

In [None]:
# Same specimen, but with the leading PCA entries replaced by the model's
# prediction (thresholded PIL image, displayed as the cell output).
model_pca_vec = get_model_pca_view(test_model_input, test_camid)
model_img = get_proj_img(model_pca_vec, return_raw=False)

model_img

In [None]:
# Index window along the position axis of the model input to overwrite, and
# the genotype state to write there.
window_start = 100
window_end = 2000
edit_value = "aA/Aa"

# Three-element encoding per genotype state; any other value maps to all zeros.
edit_vec = torch.tensor(
    {
        "AA": [0, 0, 1],
        "aa": [1, 0, 0],
        "aA/Aa": [0, 1, 0],
    }.get(edit_value, [0, 0, 0])
)

# Tile the encoding over every position in the window, then write it into a
# copy of the test input so the original tensor stays untouched.
edit_vec = edit_vec[None, :].repeat(window_end - window_start, 1)
edit_input = test_model_input.detach().clone()
edit_input[0, 0, window_start:window_end] = edit_vec

model_edit_pca_vec = get_model_pca_view(edit_input, test_camid)
model_edit_img = get_proj_img(model_edit_pca_vec, return_raw=False)

model_edit_img

In [None]:
def get_diff_img(org, edit):
    """Render the absolute difference between two images as a jet heat map.

    Args:
        org: reference image (anything ``np.asarray`` accepts, e.g. a PIL image).
        edit: image to compare against ``org``; same shape as ``org``.

    Returns:
        A PIL image built from the ``jet`` colormap output, where the largest
        difference maps to the top of the colormap. Multi-channel inputs are
        reduced by summing the per-channel absolute differences.
    """
    diff = np.abs(
        np.asarray(org, dtype=np.float64) - np.asarray(edit, dtype=np.float64)
    )
    if diff.ndim == 3:
        diff = diff.sum(-1)

    # A colormap called with floats expects values in [0, 1]; unnormalised
    # differences would all saturate to the top color.
    peak = diff.max() if diff.size else 0.0
    if peak > 0:
        diff = diff / peak

    return Image.fromarray(np.uint8(cm.jet(diff) * 255))

# Heat map of where the edited prediction differs from the unedited one.
diff_img_tmp = get_diff_img(org=model_img, edit=model_edit_img)
diff_img_tmp

In [None]:
from gtp.dataloading.tools import collect_chromosome_position_metadata
from gtp.dataloading.path_collectors import get_post_processed_genotype_directory

# Locate the post-processed genotype folder for the configured scope and load
# the per-position metadata of the selected chromosome.
genotype_dir = get_post_processed_genotype_directory(experiment_state.configs.io)
dna_dir = genotype_dir / experiment_state.configs.experiment.genotype_scope
print(dna_dir)

position_metadata = collect_chromosome_position_metadata(
    dna_dir,
    experiment_state.species,
    experiment_state.chromosome,
)

# Peek at one entry to see what the metadata looks like.
position_metadata[10]

In [None]:
# Per-color normalizing constants, obtained from the greatest difference image
# at the peak attribution for each color.
normalizing_constant = {
    "color_1": 0.0015756611967540771,
    "color_2": 0.0029504226500927213,
    "color_3": 0.0008137173370284477,
    "color_total": 0.004299188027193457,
}

# Each entry is (base position, window left offset, window right offset).
# The leading number in each comment is the entry's list index.
saved_positions = [
    (4_651_958, -139, 338),  # 0
    (4_569_576, -20, 22),  # 1 - Total Strongest
    (4_619_982, -101, 82),  # 2
    (4_583_043, -559, 432),  # 3 - Control, should be noisy / bad
    (4_674_406, -134, 79),  # 4
    (4_674_406, -134, 79),  # 5 - Color 3 Strongest
    (4_647_619, -223, 626),  # 6 - Color 2 Strongest
    (4_651_658, -314, 161),  # 7 - Color 1 Strongest
    (4_637_657, 0, 4_637_727 - 4_637_657),  # 8 - Phenotype 1
    (4_639_853, 0, 4_641_535 - 4_639_853),  # 9 - Phenotype 2
    (4_657_452, 0, 4_658_207 - 4_657_452),  # 10 - Phenotype 3
    (4_666_909, 0, 4_670_474 - 4_666_909),  # 11 - Phenotype 4
    (4_700_932, 0, 4_708_441 - 4_700_932),  # 12 - Phenotype 5
    # 7-In-Sd guides
    (4_634_681, 0, 4_634_700 - 4_634_681),  # 13 - 7-In-Sd SD-B
    (4_634_860, 0, 4_634_877 - 4_634_860),  # 14 - 7-In-Sd SD-A
    (4_635_132, 0, 4_635_151 - 4_635_132),  # 15 - 7-In-Sd SD-C
    (4_635_525, 0, 4_635_544 - 4_635_525),  # 16 - 7-In-Sd SD-D
    # 8-In-St guides
    (4_645_632, 0, 4_645_650 - 4_645_632),  # 17 - 8-In-St WntA-ST2-gRNA1
    (4_645_503, 0, 4_645_522 - 4_645_503),  # 18 - 8-In-St WntA-ST2-gRNA-Y
    (4_645_649, 0, 4_645_668 - 4_645_649),  # 19 - 8-In-St WntA-ST2-gRNA-Z
    # 9-In-St guides
    (4_647_368, 0, 4_647_387 - 4_647_368),  # 20 - 9-In-St ST1-A1
    (4_647_363, 0, 4_647_382 - 4_647_363),  # 21 - 9-In-St ST1-A2
    (4_647_197, 0, 4_647_216 - 4_647_197),  # 22 - 9-In-St ST1-A7
    (4_647_463, 0, 4_647_482 - 4_647_463),  # 23 - 9-In-St ST1-P2
    # Whole regions spanning the guide groups above
    (4_634_681, 0, 4_635_544 - 4_634_681),  # 24 - 7-In-Sd
    (4_645_503, 0, 4_645_668 - 4_645_503),  # 25 - 8-In-St
    (4_647_197, 0, 4_647_482 - 4_647_197),  # 26 - 9-In-St
]

# Name of the output folder/figure, and the entry analysed by the cells below.
saved_name = "new-location"
sp = saved_positions[1]

In [None]:
def float_to_uint_np(x):
    """Scale array ``x`` by 255 and cast the result to ``uint8``.

    No clipping or rounding is applied before the cast.
    """
    scaled = x * 255
    return scaled.astype(np.uint8)

In [None]:
from tqdm import tqdm
from pathlib import Path

# Output layout: <root>/<saved_name>/{predictions,edits}/ plus a text file
# recording the DNA window that was edited.
image_save_path = Path("/local/scratch/carlyn.1/dna_proj_images") / saved_name
predicition_save_dir = image_save_path / "predictions"
edit_save_dir = image_save_path / "edits"
for out_dir in (predicition_save_dir, edit_save_dir):
    out_dir.mkdir(exist_ok=True, parents=True)

# Turn the saved (base, left offset, right offset) triple into a window.
base_position, left_offset, right_offset = sp
window_start = base_position + left_offset
window_end = base_position + right_offset

(image_save_path / "dna_location.txt").write_text(f"{window_start} -> {window_end}")

# Translate the window [window_start, window_end] into indices of
# `position_metadata` (called "nominal" positions here).
#
# * nominal_start: index of the entry whose position equals window_start, or,
#   if there is no exact match, the entry just before the first position that
#   exceeds it (clamped to 0).
# * nominal_end: index of the first later entry whose position is >=
#   window_end. The search resumes at the entry after the one that fixed
#   nominal_start.
nominal_start = None
nominal_end = None
max_diff = 0
for i, (_, real_pos) in enumerate(position_metadata):
    if nominal_start is None:
        if real_pos >= window_start:
            nominal_start = i if real_pos == window_start else max(i - 1, 0)
        continue
    if real_pos >= window_end:
        nominal_end = i
        break

# Fail with a clear message instead of an obscure indexing error further down
# when the window lies outside the positions available for this chromosome.
if nominal_start is None or nominal_end is None:
    raise ValueError(
        f"Could not locate window ({window_start}, {window_end}) in position_metadata "
        f"(nominal_start={nominal_start}, nominal_end={nominal_end})."
    )

print(f"New window - nominal: ({nominal_start}, {nominal_end})")
print(f"New window - real: ({position_metadata[nominal_start][1]}, {position_metadata[nominal_end][1]})")

print(color)

# Accumulators start as zero images: three channels for the "total" color,
# a single plane otherwise.
accumulator_shape = (300, 300, 3) if color == "color_total" else (300, 300)
zero_arr = np.zeros(accumulator_shape, dtype=np.float64)

# One independent accumulator per edit type, for the absolute, positive-only
# and negative-only differences respectively.
edit_names = ("AA", "aA/Aa", "aa", "zero")
aggregated_diff = {name: zero_arr.copy() for name in edit_names}
aggregated_pos_diff = {name: zero_arr.copy() for name in edit_names}
aggregated_neg_diff = {name: zero_arr.copy() for name in edit_names}

# Number of specimens skipped because they have no stored PCA vector.
bad_camids = 0
# For every specimen: predict its image, then overwrite the nominal window
# with each genotype state in turn and accumulate how the projected image
# changes (absolute, positive-only and negative-only differences).

# The edit patterns depend only on the window length, so build them once
# instead of once per specimen and edit.
edit_window_len = nominal_end - nominal_start
edit_vectors = {
    name: torch.tensor(pattern).unsqueeze(0).repeat(edit_window_len, 1)
    for name, pattern in (
        ("AA", [0, 0, 1]),
        ("aA/Aa", [0, 1, 0]),
        ("aa", [1, 0, 0]),
        ("zero", [0, 0, 0]),
    )
}

tbar = tqdm(experiment_state.camids, desc="Aggregating species edits", colour="#864E04")
for camid in tbar:
    org_pca_vec = get_org_pca_vector(camid)
    if org_pca_vec.shape[0] == 0:
        # No phenotype row was cached for this specimen.
        print(f"Skipping bad CAMID: {camid}")
        bad_camids += 1
        continue
    org_img = get_proj_img(org_pca_vec, return_raw=True)

    with torch.no_grad():
        model_input = get_model_input(camid)
        # Not read again inside this loop; kept so the name stays defined.
        model_output = get_model_output(model_input)

        model_pca_vec = get_model_pca_view(model_input, camid)

        model_img_to_save = get_proj_img(model_pca_vec)
        model_img_to_save.save(predicition_save_dir / f"{camid}.png")

        model_img = get_proj_img(model_pca_vec, return_raw=True)

        for edit_value, edit_vec in edit_vectors.items():
            # Edit a copy so each edit starts from the unmodified input.
            edit_input = model_input.detach().clone()
            edit_input[0][0][nominal_start:nominal_end] = edit_vec

            model_edit_pca_vec = get_model_pca_view(edit_input, camid)
            model_edit_img = get_proj_img(model_edit_pca_vec, return_raw=True)

            diff_raw = model_img - model_edit_img

            # Split the signed difference into its positive part and the
            # magnitude of its non-positive part.
            pos_idx = diff_raw > 0
            pos_img = np.where(pos_idx, diff_raw, 0.0)
            neg_img = np.where(pos_idx, 0.0, -diff_raw)
            diff_img = np.abs(diff_raw)

            max_diff = max(max_diff, diff_img.max())
            aggregated_diff[edit_value] += diff_img
            aggregated_pos_diff[edit_value] += np.abs(pos_img)
            aggregated_neg_diff[edit_value] += np.abs(neg_img)
            tbar.set_postfix({
                "max_diff": max_diff,
                "pos_img_chng": pos_img.sum(),
                "neg_img_chng": neg_img.sum(),
                "img_chng_diff": pos_img.sum() - neg_img.sum(),
            })

            model_edit_img_to_save = get_proj_img(model_edit_pca_vec)
            # "/" is not valid in a file name. Single quotes inside the
            # replacement field keep this f-string valid before Python 3.12.
            edit_file_name = f"{camid}_{edit_value.replace('/', '-')}.png"
            model_edit_img_to_save.save(edit_save_dir / edit_file_name)

print(max_diff)

In [None]:
import matplotlib.pyplot as plt

#threshold = 0.0
# Average absolute change per edit type, each panel scaled to its own peak.
num_imgs = len(experiment_state.camids) - bad_camids
if num_imgs <= 0:
    raise ValueError("No valid CAMIDs were aggregated; nothing to plot.")

fig, axs = plt.subplots(nrows=1, ncols=4, figsize=(10, 10))
for c, ev in enumerate(["AA", "aA/Aa", "aa", "zero"]):
    img_to_show = aggregated_diff[ev].copy()
    # Only multi-channel accumulators have a channel axis to collapse; summing
    # a single-plane accumulator would reduce it to 1-D and break imshow.
    if img_to_show.ndim == 3:
        img_to_show = img_to_show.sum(-1)

    img_to_show /= num_imgs
    peak = img_to_show.max()
    print(peak)
    # Skip the rescale for an all-zero panel to avoid 0/0 -> NaN.
    if peak > 0:
        img_to_show /= peak
    im = axs[c].imshow(img_to_show, vmin=0, vmax=1)
    axs[c].set_title(f"Edit: {ev}")
    axs[c].set_xticks([])
    axs[c].set_yticks([])

plt.tight_layout()
title_text = f"{species.capitalize()} | {wing.capitalize()} | {color.capitalize()} | Chr: {experiment_state.chromosome} | ({position_metadata[nominal_start][1]} - {position_metadata[nominal_end][1]})"
fig.suptitle(title_text, fontsize=18, y=0.68)
plt.show()

In [None]:
import matplotlib.pyplot as plt
from matplotlib.colors import CenteredNorm

# Per-channel signed change: one row per color channel, one column per edit.
# Positive and negative averages are each scaled by their global maximum
# across all edits, then combined as (positive - negative).
edit_names = ["AA", "aA/Aa", "aa", "zero"]

num_imgs = len(experiment_state.camids) - bad_camids
if num_imgs <= 0:
    raise ValueError("No valid CAMIDs were aggregated; nothing to plot.")
if any(aggregated_pos_diff[ev].ndim != 3 for ev in edit_names):
    raise ValueError(
        "The per-channel figure needs 3-channel accumulators (the 'total' color)."
    )

max_pos = 0
max_neg = 0
for ev in edit_names:
    max_pos = max(aggregated_pos_diff[ev].max() / num_imgs, max_pos)
    max_neg = max(aggregated_neg_diff[ev].max() / num_imgs, max_neg)

# Avoid 0/0 -> NaN panels when one side never changed.
pos_scale = max_pos if max_pos > 0 else 1.0
neg_scale = max_neg if max_neg > 0 else 1.0

# Create a CenteredNorm object
# NOTE(review): this one norm instance is shared by every panel and has no
# explicit halfrange, so its range is presumably fixed by the first image
# drawn -- confirm that is the intended color scale.
norm = CenteredNorm(vcenter=0)
fig, axs = plt.subplots(nrows=3, ncols=4, figsize=(15, 12))
for color_row in range(3):
    for c, ev in enumerate(edit_names):
        ax = axs[color_row, c]

        pos_img_to_show = aggregated_pos_diff[ev][:, :, color_row] / num_imgs / pos_scale
        neg_img_to_show = aggregated_neg_diff[ev][:, :, color_row] / num_imgs / neg_scale
        img_to_show = pos_img_to_show - neg_img_to_show

        img = ax.imshow(img_to_show, cmap='RdBu_r', norm=norm)
        # set_aspect returns None, so its result is not assigned to `img`.
        ax.set_aspect("auto")

        if color_row == 0:
            ax.set_title(f"{ev}", fontsize=24, weight="bold")
        ax.set_xticks([])
        ax.set_yticks([])
        for side in ['top', 'bottom', 'right', 'left']:
            ax.spines[side].set_visible(False)

        if c == 0:
            ax.set_ylabel(f"Color {color_row+1}", fontsize=18, weight="bold")

color_title = color if "total" not in color else "total"

title_text = f"{species.capitalize()} - {wing.capitalize()} | {color_title.capitalize()} | [{position_metadata[nominal_start][1]} - {position_metadata[nominal_end][1]}]"
fig.suptitle(title_text, fontsize=28, y=1.01, weight="bold")

plt.subplots_adjust(hspace=0.0, wspace=0.00)
plt.tight_layout(w_pad=0, h_pad=0.0)

plt.savefig(f"/home/carlyn.1/dna-trait-analysis/output/{saved_name}.svg", bbox_inches='tight')

plt.show()