In [None]:
if 'autoreload' not in get_ipython().extension_manager.loaded:
    %load_ext autoreload
%autoreload 2

In [None]:
# General imports
import os
import torch
import numpy as np
import pandas as pd
from copy import deepcopy 

# EUGENe imports and settings
from eugene import models
from eugene import train
from eugene import settings
settings.dataset_dir = "/cellar/users/aklie/data/eugene/revision/jores21"
settings.output_dir = "/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/output/revision/jores21"
settings.logging_dir = "/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/logs/revision/jores21"
settings.config_dir = "/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/configs/jores21"

# EUGENe packages
import seqdata as sd
import motifdata as md

# New Jores21CNN model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F 
from eugene.models.base import _layers as layers
from eugene.models.base import _blocks as blocks
from eugene.models.base import _towers as towers


class BiConv1DTower(nn.Module):
    """Stack of weight-shared, strand-aware ("bidirectional") 1D convolutions.

    Every layer is applied twice with the *same* parameters: once with its kernel
    as-is (forward stream) and once with the kernel reverse-complemented (reverse
    stream). The two streams are propagated independently through all layers
    (conv -> ReLU -> dropout each time) and summed at the end.

    Parameters
    ----------
    filters : int
        Number of output channels of every convolution.
    kernel_size : int
        Width of every convolution kernel.
    input_size : int, optional
        Input channels of the first layer (4 for one-hot DNA).
    n_layers : int, optional
        Number of stacked layers; must be >= 1.
    stride : int, optional
        Convolution stride. Convolutions use ``padding="same"``, which PyTorch
        only supports for ``stride == 1``.
    dropout_rate : float, optional
        Dropout probability applied after every ReLU; must lie in [0, 1].

    Raises
    ------
    ValueError
        If ``n_layers < 1`` or ``dropout_rate`` is outside [0, 1].
    """

    def __init__(
        self, 
        filters: int,
        kernel_size: int,
        input_size: int = 4, 
        n_layers: int = 1, 
        stride: int = 1, 
        dropout_rate: float = 0.15
    ):
        super().__init__()
        self.filters = filters
        self.kernel_size = kernel_size
        self.input_size = input_size
        if n_layers < 1:
            raise ValueError("At least one layer needed")
        self.n_layers = n_layers
        if (dropout_rate < 0) or (dropout_rate > 1):
            raise ValueError("Dropout rate must be a float between 0 and 1")
        self.dropout_rate = dropout_rate
        self.stride = stride
        # Flat list of (Conv1d, ReLU, Dropout) triples: layer i occupies indices 3*i .. 3*i+2.
        self.layers = nn.ModuleList()
        for i in range(self.n_layers):
            in_channels = self.input_size if i == 0 else self.filters
            conv = nn.Conv1d(
                in_channels=in_channels,
                out_channels=self.filters,
                kernel_size=self.kernel_size,
                stride=self.stride,
                padding="same",
            )
            # Xavier/Glorot-uniform kernel and zero bias (initialised in place).
            nn.init.xavier_uniform_(conv.weight)
            nn.init.zeros_(conv.bias)
            self.layers.append(conv)
            self.layers.append(nn.ReLU(inplace=False))
            self.layers.append(nn.Dropout(p=self.dropout_rate))

    def _apply_layer(self, x: torch.Tensor, i: int, reverse: bool) -> torch.Tensor:
        """Run layer ``i`` (conv -> ReLU -> dropout), optionally with the reverse-complement kernel."""
        conv = self.layers[3 * i]
        weight = conv.weight
        if reverse:
            # Conv1d weights are (out_channels, in_channels, kernel_size). Reverse-
            # complementing flips the input-channel axis (for A, C, G, T ordering this
            # swaps A<->T and C<->G) and the kernel-width axis. Output channels are
            # left in place so filter j shares bias j on both strands and is summed
            # with itself. The same flip is applied at every layer.
            weight = torch.flip(weight, dims=[1, 2])
        x = F.conv1d(x, weight, bias=conv.bias, stride=self.stride, padding="same")
        x = self.layers[3 * i + 1](x)
        return self.layers[3 * i + 2](x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return forward-stream + reverse-stream activations.

        ``x`` is a conv1d input, (batch, input_size, length); with stride 1 the
        output is (batch, filters, length).
        """
        x_fwd = x
        x_rev = x
        for i in range(self.n_layers):
            x_fwd = self._apply_layer(x_fwd, i, reverse=False)
            x_rev = self._apply_layer(x_rev, i, reverse=True)
        return x_fwd + x_rev

class Jores21CNN(nn.Module):
    """Strand-aware CNN: BiConv1DTower -> Conv1d -> two dense layers.

    Layer order in ``forward``: ``BiConv1DTower`` -> ``Conv1d`` -> ReLU -> Dropout
    -> flatten -> ``Linear(input_len * filters, hidden_dim)`` -> ``BatchNorm1d``
    -> ReLU -> ``Linear(hidden_dim, output_dim)``.

    Parameters
    ----------
    input_len : int
        Sequence length. Sizes the first dense layer, which assumes the
        convolutions preserve length (stride 1 with ``padding="same"``).
    output_dim : int
        Number of outputs.
    filters : int, optional
        Channels of every convolution.
    kernel_size : int, optional
        Width of every convolution kernel.
    layers : int, optional
        Number of layers in the ``BiConv1DTower``.
    stride : int, optional
        Convolution stride.
    dropout : float, optional
        Dropout probability. Stored as ``self.dropout_rate``; ``self.dropout`` is
        the ``nn.Dropout`` module.
    hidden_dim : int, optional
        Width of the hidden dense layer.
    input_size : int, optional
        Number of input channels (4 for one-hot DNA).
    """

    def __init__(
        self,
        input_len: int,
        output_dim: int,
        filters: int = 128,
        kernel_size: int = 13,
        layers: int = 2,
        stride: int = 1,
        dropout: float = 0.15,
        hidden_dim: int = 64,
        input_size: int = 4,
    ):
        super().__init__()

        # Hyperparameters
        self.input_len = input_len
        self.output_dim = output_dim
        self.filters = filters
        self.kernel_size = kernel_size
        self.layers = layers
        self.stride = stride
        self.hidden_dim = hidden_dim
        self.input_size = input_size
        # Keep the numeric rate under its own name; `self.dropout` is the module below.
        self.dropout_rate = dropout

        # Blocks
        self.biconv = BiConv1DTower(
            filters=filters,
            kernel_size=kernel_size,
            input_size=input_size,
            n_layers=layers,
            stride=stride,
            dropout_rate=dropout,
        )
        self.conv = nn.Conv1d(
            in_channels=filters,
            out_channels=filters,
            kernel_size=kernel_size,
            stride=stride,
            padding="same",
        )
        self.relu = nn.ReLU(inplace=False)
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(in_features=input_len * filters, out_features=hidden_dim)
        self.batchnorm = nn.BatchNorm1d(num_features=hidden_dim)
        self.relu2 = nn.ReLU(inplace=False)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, x):
        """Map ``x`` of shape (batch, input_size, input_len) to (batch, output_dim)."""
        x = self.biconv(x)
        x = self.conv(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc(x.flatten(start_dim=1))
        x = self.batchnorm(x)
        x = self.relu2(x)
        x = self.fc2(x)
        return x
    

In [None]:
biconv_tower = BiConv1DTower(
    filters=256,
    kernel_size=13,
    input_size=4,
    n_layers=3,
)

In [None]:
x = torch.randn(10, 4, 170)

In [None]:
biconv_tower(x).shape

In [None]:
model = Jores21CNN(
    input_len=170,
    output_dim=1,
    filters=256,
    kernel_size=13,
    layers=3
)

In [None]:
model

In [None]:
model(x).shape

In [None]:
import torch
import sys
sys.path.append("/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/scripts/jores21")
from jores21_helpers import BiConv1DTower, Jores21CNN

In [None]:
biconv_tower = BiConv1DTower(
    filters=256,
    kernel_size=13,
    input_size=4,
    n_layers=3,
)

In [None]:
biconv_tower

In [None]:
x = torch.randn(10, 4, 170)

In [None]:
biconv_tower(x).shape

In [None]:
model = Jores21CNN(
    input_len=170,
    output_dim=1,
    filters=256,
    kernel_size=13,
    layers=3
)

In [None]:
from eugene import plot as pl

In [None]:
pl.training_summary("/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/logs/revision/jores21/jores21_cnn_nn/leaf_trial_1")

In [None]:
pl.training_summary("/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/logs/revision/jores21/jores21_cnn/leaf_trial_1")

In [None]:
model

In [None]:
model(x).shape

In [None]:
import os
import yaml
import importlib
from eugene import settings, models
settings.config_dir = "/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/configs/jores21"

In [None]:

def load_config_nn(
    config_path, 
    **kwargs
):
    """Build a ``eugene.models`` module around an architecture class from ``jores21_helpers``.

    The YAML config must contain:
      - ``module``: name of a class in ``eugene.models``;
      - ``model``: a mapping with ``arch_name`` (a class name in ``jores21_helpers``)
        and ``arch`` (keyword arguments for that class).
    All remaining top-level config keys, plus ``**kwargs``, are forwarded to the
    module class together with the architecture instance.

    ``jores21_helpers`` must be importable (e.g. its directory on ``sys.path``).

    Parameters
    ----------
    config_path : str
        Path to the YAML file. A bare filename (no directory component) is
        resolved relative to ``settings.config_dir``.
    **kwargs
        Extra keyword arguments for the module constructor (e.g. ``seed``).

    Returns
    -------
    The instantiated module wrapping the architecture.

    Raises
    ------
    KeyError
        If ``module``, ``model``, ``model.arch_name`` or ``model.arch`` is missing.
    """
    # Bare filenames live in the configured config directory.
    if not os.path.dirname(config_path):
        config_path = os.path.join(settings.config_dir, config_path)
    with open(config_path, "r") as f:
        # safe_load only builds plain YAML types (dicts, lists, scalars), which is all
        # a config needs, and never constructs arbitrary Python objects.
        config = yaml.safe_load(f)
    module_name = config.pop("module")
    model_params = config.pop("model")
    arch_name = model_params["arch_name"]
    arch_kwargs = model_params["arch"]
    arch_cls = getattr(importlib.import_module("jores21_helpers"), arch_name)
    arch = arch_cls(**arch_kwargs)
    module_cls = getattr(importlib.import_module("eugene.models"), module_name)
    return module_cls(arch, **config, **kwargs)

In [None]:
model = load_config_nn("jores21_cnn_nn.yaml", seed=13)

In [None]:
model(x).shape

In [None]:
model(x)

In [None]:
import torchinfo

In [None]:
torchinfo.summary(model, input_size=(10, 4, 170))

In [None]:
models.get_layer(model, "arch.biconv.dropouts")[0]

In [None]:
models.list_available_layers(model)

In [None]:
import torch.nn as nn
import torch.nn.functional as F

In [None]:
x[0].T[:10]

In [None]:
F.relu(x)[0].T[:10]

In [None]:
F.dropout(F.relu(x), p=0.3, training=True)[0].T[:10]

In [None]:
model.arch.biconv.dropouts[0](model.arch.biconv.relus[0](x))[0].T[:10]

In [None]:
import os
import yaml
import importlib
import torch
from eugene import settings, models
settings.config_dir = "/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/configs/jores21"

In [None]:
model2 = models.load_config("jores21_cnn.yaml", seed=13)

In [None]:
x = torch.randn(10, 4, 170)

In [None]:
model2(x) == model2(x)

In [None]:
model2(x)

In [None]:
model2(x)

In [None]:
model2.eval()

In [None]:
model2.arch.biconv.training

In [None]:
model2.arch.biconv.training = False

In [None]:
model2.arch.biconv.training

In [None]:
torchinfo.summary(model2, input_size=(10, 4, 170))

# Jores21CNN model filter activations

In [None]:
import importlib
import logging
import os
from typing import Callable, Dict
import pandas as pd
import torch
import torch.nn as nn

class FeatureExtractor(nn.Module):
    """Wrap a model and record the outputs of selected submodules via forward hooks.

    Every submodule whose qualified name (from ``model.named_modules()``) contains
    ``key_word`` gets a forward hook. After each forward pass,
    ``self.features[name]`` holds that submodule's latest output, or
    ``output[index]`` when ``index`` is set. ``key_word=""`` also matches the root
    module (whose name is ``""``).

    Parameters
    ----------
    model : nn.Module
        Model to inspect; hooks are registered on it directly.
    key_word : str
        Substring used to select submodules by name.
    index : int, optional
        If given, store ``output[index]`` instead of the full output.
    """

    def __init__(self, model: nn.Module, key_word: str, index: int = None):
        super().__init__()
        self.model = model
        self.index = index
        # Build the name -> module map once rather than once per selected layer.
        modules = dict(model.named_modules())
        layers = sorted(name for name in modules if key_word in name)
        self.features = {layer: torch.empty(0) for layer in layers}
        self.handles = dict()
        for layer_id in layers:
            hook = self.SaveOutputHook(layer_id, self.index)
            self.handles[layer_id] = modules[layer_id].register_forward_hook(hook)

    def SaveOutputHook(self, layerID: str, index: int = None) -> Callable:
        """Return a forward hook that stores a module's output under ``layerID``.

        ``index`` selects ``output[index]``. When it is None, the hook falls back to
        ``self.index`` (read at call time) and stores the whole output if that is
        also None.
        """
        def fn(layer, input, output):
            idx = self.index if index is None else index
            self.features[layerID] = output if idx is None else output[idx]
        return fn

    def remove_hooks(self) -> None:
        """Detach every forward hook this extractor registered on the model."""
        for handle in self.handles.values():
            handle.remove()

    def forward(self, x, **kwargs) -> tuple:
        """Run the wrapped model; return ``(features, handles, preds)``."""
        preds = self.model(x, **kwargs)
        return self.features, self.handles, preds

In [None]:
layer_name = "arch.conv1d_tower.layers.1"

In [None]:
test = FeatureExtractor(model, layer_name)

In [None]:
sequences = sdata["ohe_seq"].transpose("_sequence", "_ohe", "length").values[:128]
torch_seqs = torch.tensor(sequences, dtype=torch.float32).to("cuda")
torch_seqs.shape

In [None]:
test(torch_seqs)[0][layer_name][0].T

In [None]:
import torch.nn.functional as F
def get_layer(
    model, 
    layer_name,
    index=None
):
    """Fetch a submodule of ``model`` by its qualified name (as in ``named_modules``).

    When ``index`` is given, the found module is itself indexed (e.g. an element
    of a ``ModuleList`` or ``Sequential``) and that element is returned. An
    unknown ``layer_name`` raises ``KeyError``.
    """
    found = dict(model.named_modules())[layer_name]
    return found if index is None else found[index]

In [None]:
layer = get_layer(model, layer_name, index=0)
layer_outs = F.relu(F.conv1d(torch_seqs, layer)).detach().cpu().numpy()

# Filter visualization: checking how padding affects filter activations

In [None]:
# Select the layer you want to interpret
layer_name = "arch.conv1d_tower.layers.0"

In [None]:
from copy import deepcopy

# Grab motifs
core_promoter_elements = md.read_meme(os.path.join(settings.dataset_dir, "CPEs.meme"))
tf_clusters = md.read_meme(os.path.join(settings.dataset_dir, "TF-clusters.meme"))

# Smush them together, make function in the future
all_motifs = deepcopy(core_promoter_elements)
for motif in tf_clusters:
    all_motifs.add_motif(motif)
all_motifs

In [None]:
# Function for instantiating a new randomly initialized model
def prep_new_model(
    config,
    seed,
    motifs=None,
):
    """Instantiate a model from ``config`` and seed its first conv filters with motifs.

    Parameters
    ----------
    config : str
        Config path/filename passed to ``models.load_config``.
    seed : int
        Seed passed to ``models.load_config``.
    motifs : optional
        Motif collection used for filter initialization; defaults to the
        notebook-level ``all_motifs``.

    Returns
    -------
    The initialized model.

    Raises
    ------
    ValueError
        If ``model.arch_name`` has no known first-convolution layer.
    """
    if motifs is None:
        motifs = all_motifs

    # Instantiate the model
    model = models.load_config(config_path=config, seed=seed)

    # Initialize all weights before overwriting the first conv filters with motifs
    models.init_weights(model, initializer="kaiming_normal")

    # Locate the first convolution's weights for each supported architecture
    if model.arch_name == "Jores21CNN":
        layer_name, list_index = "arch.biconv.kernels", 0
    elif model.arch_name in ("CNN", "Hybrid", "DeepSTARR"):
        layer_name, list_index = "arch.conv1d_tower.layers.0", None
    else:
        # Without this, an unknown architecture surfaced as an UnboundLocalError below.
        raise ValueError(f"No motif-initialization layer known for arch_name={model.arch_name!r}")

    models.init_motif_weights(
        model=model,
        layer_name=layer_name,
        list_index=list_index,
        initializer="xavier_uniform",
        motifs=motifs,
        convert_to_pwm=False,
        divide_by_bg=True,
        motif_align="left",
        kernel_align="left"
    )
    return model

# Test the instantiation of each model to make sure this is working properly
model = prep_new_model("jores21_cnn.yaml", seed=0)

In [None]:
# Biconv kernel
kernel = models.get_layer(model, "arch.biconv.kernels")[0]
bias = models.get_layer(model, "arch.biconv.biases")[0]
layer = torch.nn.Conv1d(
    in_channels=kernel.shape[1],
    out_channels=kernel.shape[0],
    kernel_size=kernel.shape[2],
    padding=3,
)
layer.weight = torch.nn.Parameter(kernel)
layer.bias = torch.nn.Parameter(bias)
layer.eval().to("cuda")

In [None]:
# CNN kernel
layer = models.get_layer(model, layer_name)
layer.eval().to("cuda")

In [None]:
layer.weight[0].T

In [None]:
decode_seq(X_np[0])

In [None]:
decode_seq(X_np[0, :, 0:13])

In [None]:
layer(X[0, :, 0:13])[0]

In [None]:
layer(X[0, :, 1:14])[0]

In [None]:
layer(X[0])[0]

In [None]:
X_np = sdata["ohe_seq"].transpose("_sequence", "_ohe", "length").to_numpy()
X = torch.tensor(X_np, dtype=torch.float32).to(device="cuda")
activations = F.relu(layer(X)).detach().cpu().numpy()

In [None]:
from seqexplainer._utils import _k_largest_index_argsort
from seqexplainer.preprocess._preprocess import decode_seq, ohe_seq

In [None]:
single_filter = activations[:, 0, :]
large_inds = _k_largest_index_argsort(single_filter, k=10)
single_filter.shape

In [None]:
for ind in large_inds:
    print(ind)
    print(single_filter[ind[0], ind[1]])

In [None]:
# Print the sequence window around each top activation. The offsets (6 before,
# 7 after -> 13 bp) assume a 13-bp kernel centred on the activation position.
kernel_width = 13
for i, (seq_idx, pos) in enumerate(large_inds):
    print(i)
    # Clamp at 0: a negative start would wrap around in Python slicing.
    start = max(pos - kernel_width // 2, 0)
    end = pos + kernel_width // 2 + 1
    print(decode_seq(X_np[seq_idx])[start:end])

In [None]:
# Generate pfms from filters
interpret.generate_pfms_sdata(
    model,
    sdata,
    seq_key="ohe_seq",
    layer_name=layer_name,
    kernel_size=13,
    activations=activations,
    num_filters=1,
    padding=3,
    seqs=sdata["ohe_seq"].transpose("_sequence", "_ohe", "length").to_numpy(),
    num_seqlets=100
)

In [None]:
# Visualize a filter of choice
pl.filter_viz(
    sdata,
    filter_num=0,
    pfms_key=f"{layer_name}_pfms",
)

# Data loading and train/validation split (TODO: reorganize)

In [None]:
sdata = sd.open_zarr(os.path.join(settings.dataset_dir, "jores21_leaf_train.zarr"))

In [None]:
sdata["ohe_seq"].shape, sdata["train_val"].to_dataframe().value_counts(normalize=True)

In [None]:
seq_key = "ohe_seq"
target_keys = "enrichment"
train_key = "train_val"
seq_transforms = {seq_key: lambda x: torch.tensor(x, dtype=torch.float32).permute(0, 2, 1)}
batch_size = 128
num_workers = 4
drop_last = True

In [None]:
import xarray as xr  # needed below for multi-target concatenation (avoids relying on a later cell)

if isinstance(target_keys, str):
    target_keys = [target_keys]
if len(target_keys) == 1:
    sdata["target"] = sdata[target_keys[0]]
else:
    sdata["target"] = xr.concat([sdata[target_key] for target_key in target_keys], dim="_targets").transpose("_sequence", "_targets")
targs = sdata["target"].values
# Flag a sequence if its (any) target is NaN.
nan_mask = np.isnan(targs) if targs.ndim == 1 else np.isnan(targs).any(axis=1)
print(f"Dropping {nan_mask.sum()} sequences with NaN targets.")
sdata = sdata.isel(_sequence=~nan_mask)

In [None]:
# Load training data into memory
sdata["ohe_seq"].load()
sdata["enrichment"].load()
sdata["train_val"].load()

In [None]:
targs = sdata["enrichment"].values

In [None]:
import xarray as xr

In [None]:
nan_mask = xr.DataArray(np.isnan(targs), dims=["_sequence"])

In [None]:
# Drop NaN-target sequences positionally; `where(..., drop=True)` can promote the dtypes of kept variables (e.g. bool -> object).
sdata = sdata.isel(_sequence=np.flatnonzero(~nan_mask.values))

In [None]:
print(f"Dropping {int(nan_mask.sum().values)} sequences with NaN targets.")

In [None]:
sdata

In [None]:
sdata.where(sdata["train_val"], drop=True)

In [None]:
sdata.where(~sdata["train_val"], drop=True)

In [None]:
sdata.where(~sdata.train_val)

In [None]:
# Split by the boolean `train_key` column: True -> training, False -> validation.
# Negate the boolean mask (not integer indices: `~` on ints is a bitwise NOT, i -> -i-1).
train_mask = sdata[train_key].values.astype(bool)
train_sdata = sdata.isel(_sequence=np.flatnonzero(train_mask))
val_sdata = sdata.isel(_sequence=np.flatnonzero(~train_mask))
train_dataloader = sd.get_torch_dataloader(
    train_sdata,
    sample_dims=["_sequence"],
    variables=[seq_key, "target"],
    transforms=seq_transforms,
    prefetch_factor=2,
    shuffle=True,
    drop_last=drop_last,
    batch_size=batch_size,
    num_workers=num_workers
)
val_dataloader = sd.get_torch_dataloader(
    val_sdata,
    sample_dims=["_sequence"],
    variables=[seq_key, "target"],
    transforms=seq_transforms,
    prefetch_factor=2,
    shuffle=False,
    # Keep every validation sequence; dropping a partial batch silently shrinks the evaluation set.
    drop_last=False,
    batch_size=batch_size,
    num_workers=num_workers
)

In [None]:
train_sdata

In [None]:
val_sdata

In [None]:
batch = next(iter(train_dataloader))
batch_ohe_seq = batch[seq_key]
batch_target = batch["target"]
batch_ohe_seq.shape, batch_target.shape

In [None]:
from tqdm.auto import tqdm

In [None]:
for i, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc="Looping over train dataloader"):
    batch_ohe_seq = batch[seq_key]
    batch_target = batch["target"]

In [None]:
for i, batch in tqdm(enumerate(val_dataloader), total=len(val_dataloader), desc="Looping over val dataloader"):
    batch_ohe_seq = batch[seq_key]
    batch_target = batch["target"]

In [None]:
for i, batch in enumerate(val_dataloader):
    batch_ohe_seq = batch[seq_key]
    batch_target = batch["target"]
    print(batch_ohe_seq.shape, batch_target.shape)
    if i > 10:
        break

In [None]:
to_decode = batch_ohe_seq[0].numpy()

In [None]:
to_decode.shape

In [None]:
DNA = ["A", "C", "G", "T"]
RNA = ["A", "C", "G", "U"]

def _get_vocab(vocab):
    if vocab == "DNA":
        return DNA
    elif vocab == "RNA":
        return RNA
    else:
        raise ValueError("Invalid vocab, only DNA or RNA are currently supported")

# exact concise
def _get_index_dict(vocab):
    """
    Returns a dictionary mapping each token to its index in the vocabulary.
    """
    return {i: l for i, l in enumerate(vocab)}

# modified dinuc_shuffle
def _one_hot2token(one_hot, neutral_value=-1, consensus=False):
    """
    Converts a one-hot encoding into a vector of integers in the range [0, D]
    where D is the number of classes in the one-hot encoding.

    Parameters
    ----------
    one_hot : np.array
        L x D one-hot encoding
    neutral_value : int, optional
        Value to use for neutral values.
    
    Returns
    -------
    np.array
        L-vector of integers in the range [0, D]
    """
    if consensus:
        return np.argmax(one_hot, axis=0)
    tokens = np.tile(neutral_value, one_hot.shape[1])  # Vector of all D
    seq_inds, dim_inds = np.where(one_hot.transpose()==1)
    tokens[seq_inds] = dim_inds
    return tokens

def _sequencize(tvec, vocab="DNA", neutral_value=-1, neutral_char="N"):
    """
    Converts a token vector into a sequence of symbols of a vocab.
    """
    vocab = _get_vocab(vocab) 
    index_dict = _get_index_dict(vocab)
    index_dict[neutral_value] = neutral_char
    return "".join([index_dict[i] for i in tvec])

def decode_seq(arr, vocab="DNA", neutral_value=-1, neutral_char="N"):
    """Convert a single one-hot encoded array back to string"""
    if isinstance(arr, torch.Tensor):
        arr = arr.numpy()
    return _sequencize(
        tvec=_one_hot2token(arr, neutral_value),
        vocab=vocab,
        neutral_value=neutral_value,
        neutral_char=neutral_char,
    )

In [None]:
len(val_dataloader)

In [None]:
len(train_dataloader)

In [None]:
decode_seq(to_decode)

In [None]:
batch_target[0]

# Miscellaneous scratch: motif information content, in silico evolution, shuffling, attributions

In [None]:
pfm_dfs = pfms_to_df_dict(pfms)
ppms = pfms_to_ppms(pfms, pseudocount=1)
pwms = ppms_to_pwms(ppms)
infos = ppms_to_igms(ppms)
ppics = per_position_ic(ppms)
tot_ics = ppics.sum(axis=1)

In [None]:
# Sort by total information content
sort_idx = np.argsort(tot_ics)[::-1]
sort_idx[:5]

In [None]:
from tqdm.auto import tqdm
from seqexplainer import evolution
def evolve_seqs_sdata(
    model: torch.nn.Module, 
    sdata, 
    rounds: int, 
    seq_key: str = "ohe_seq",
    axis_order = ("_sequence", "_ohe", "length"),
    add_seqs=True,
    return_seqs: bool = False, 
    device: str = None, 
    batch_size: int = 128,
    copy: bool = False, 
    **kwargs
):
    """
    In silico evolve a set of sequences that are stored in a SeqData object.

    Adds ``original_score`` and ``evolved_{r}_score`` (r = 1..rounds, cumulative
    per-round deltas) to ``sdata``.

    Parameters
    ----------
    model: torch.nn.Module  
        The model to score the sequences with (must provide ``predict``).
    sdata: SeqData  
        The SeqData object containing the sequences to evolve
    rounds: int
        The number of rounds of evolution to perform (>= 1)
    seq_key: str, optional
        Variable holding the one-hot sequences.
    axis_order: tuple, optional
        Dimension order used to read ``seq_key`` and to store ``evolved_seqs``.
    add_seqs: bool, optional
        Whether to store the evolved sequences as ``evolved_seqs`` (skipped when
        ``return_seqs`` is True).
    return_seqs: bool, optional
        Whether to return the evolved sequences as a float32 tensor instead.
    device: str, optional
        'cpu' or 'cuda'. If None (default), 'cuda' is used when
        ``settings.gpus > 0`` and 'cpu' otherwise.
    batch_size: int, optional
        Batch size for scoring the original sequences.
    copy: bool, optional
        Whether to copy the SeqData object before mutating it
    kwargs: dict, optional
        Additional arguments passed to the evolution function
    
    Returns
    -------
    torch.Tensor if ``return_seqs``; otherwise the evolved SeqData if ``copy``,
    else None (``sdata`` is modified in place).

    Raises
    ------
    ValueError
        If ``rounds < 1``.
    """
    if rounds < 1:
        raise ValueError("rounds must be >= 1")

    sdata = sdata.copy() if copy else sdata

    # Honour an explicitly requested device; only auto-select when none is given.
    if device is None:
        device = "cuda" if settings.gpus > 0 else "cpu"

    # Grab seqs
    ohe_seqs = sdata[seq_key].transpose(*axis_order).to_numpy()
    n_seqs = len(ohe_seqs)
    evolved_seqs = np.zeros(ohe_seqs.shape)
    deltas = np.zeros((n_seqs, rounds))

    # Evolve seqs one at a time; `delta` holds the per-round score changes
    for i, ohe_seq in tqdm(enumerate(ohe_seqs), total=n_seqs, desc="Evolving seqs"):
        evolved_seq, delta, _ = evolution(model, ohe_seq, rounds=rounds, device=device, **kwargs)
        evolved_seqs[i] = evolved_seq
        deltas[i, :] = delta

    # Get original scores
    orig_seqs = torch.tensor(ohe_seqs, dtype=torch.float32).to(device)
    original_scores = model.predict(orig_seqs, batch_size=batch_size, verbose=False).detach().cpu().numpy().squeeze()

    # Score after round r = original score + sum of the first r deltas
    sdata["original_score"] = xr.DataArray(original_scores, dims="_sequence")
    running = original_scores + deltas[:, 0]
    sdata["evolved_1_score"] = xr.DataArray(running, dims="_sequence")
    for r in range(2, rounds + 1):
        running = running + deltas[:, r - 1]
        sdata[f"evolved_{r}_score"] = xr.DataArray(running, dims="_sequence")

    if return_seqs:
        return torch.tensor(evolved_seqs, dtype=torch.float32)
    if add_seqs:
        sdata["evolved_seqs"] = xr.DataArray(evolved_seqs, dims=tuple(axis_order))
    return sdata if copy else None

In [None]:
from seqexplainer.preprocess._preprocess import dinuc_shuffle_seq
consensus, dinuc_shuffle_seq(consensus), k_shuffle(consensus, k=2).tobytes().decode()

In [None]:
# if using naiveISM
sdata["ohe_seq"] = sdata["ohe_seq"].transpose("_sequence", "_ohe", "length")
X1 = sdata["ohe_seq"].values
X2 = (sdata[f"{method}_attrs"]*-1).sum(dim="_ohe").values
X1.shape, X2.shape
# Multiply the one-hot encoded sequence with the saliency scores. X1 has shape 128,4,170 and X2 has shape 128,170.
# We need to expand X2 to 128,4,170 to be able to multiply it with X1.
X2 = np.expand_dims(X2, axis=1)
X2 = np.repeat(X2, 4, axis=1)
X2.shape
X = X1 * X2
sdata[f"{method}_attrs_sum"] = xr.DataArray(X, dims=["_sequence", "_ohe", "length"], coords=sdata["ohe_seq"].coords)