# Testing `Jores21CNN` model class

**Authorship:**
Adam Klie, *07/31/2022*
***
**Description:**
Notebook for testing out the custom `Jores21CNN` model class.

In [1]:
import numpy as np
import pandas as pd
import torch

# Autoreload extension
if 'autoreload' not in get_ipython().extension_manager.loaded:
    %load_ext autoreload
%autoreload 2

In [2]:
import eugene as eu

Global seed set to 13


In [36]:
from eugene.models.base import BaseModel
import torch.nn as nn
import torch.nn.functional as F
from eugene.models.base._utils import GetFlattenDim

class Kopp21CNN(BaseModel):
    def __init__(
        self,
        input_len,
        output_dim,
        strand="ds",
        task="binary_classification",
        aggr="max",
        filters: list = [10, 8],
        conv_kernel_size: list = [11, 3],
        maxpool_kernel_size: int = 30,
        stride: int = 1,
        **kwargs
    ):
        super().__init__(
            input_len, output_dim, strand=strand, task=task, aggr=aggr, **kwargs
        )
        self.conv = nn.Conv1d(4, filters[0], conv_kernel_size[0], stride=stride)
        self.maxpool = nn.MaxPool1d(kernel_size=maxpool_kernel_size, stride=stride)
        self.batchnorm = nn.BatchNorm1d(filters[0])
        self.conv2 = nn.Conv1d(filters[0], filters[1], conv_kernel_size[1], stride=stride)
        self.batchnorm2 = nn.BatchNorm1d(filters[1])
        self.linear = nn.Linear(filters[1], self.output_dim)

    def forward(self, x, x_rev):
        x_fwd = F.relu(self.conv(x))
        x_rev = F.relu(self.conv(x_rev))
        if self.aggr == 'concat':
            x = torch.cat((x_fwd, x_rev), dim=1)
        elif self.aggr == 'max':
            x = torch.max(x_fwd, x_rev)
        elif self.aggr == 'ave':
            x = (x_fwd + x_rev) / 2
        elif self.aggr is None:
            x = torch.cat((x_fwd, x_rev), dim=1) 
        x = self.maxpool(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool1d(x, x.shape[2])
        x = self.batchnorm2(x)
        x = x.view(x.shape[0], -1)
        x = self.linear(x)
        return x

In [37]:
x = torch.randn(10, 4, 500)
x_rev = torch.randn(10, 4, 500)

In [38]:
model = Kopp21CNN(
    input_len=500,
    output_dim=1,
)
model



Kopp21CNN(
  (hp_metric): AUROC()
  (conv): Conv1d(4, 10, kernel_size=(11,), stride=(1,))
  (maxpool): MaxPool1d(kernel_size=30, stride=1, padding=0, dilation=1, ceil_mode=False)
  (batchnorm): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv1d(10, 8, kernel_size=(3,), stride=(1,))
  (batchnorm2): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (linear): Linear(in_features=8, out_features=1, bias=True)
)

In [39]:
model(x, x_rev)

torch.Size([10, 8, 459])
torch.Size([10, 8, 1])
torch.Size([10, 8])


tensor([[ 0.3092],
        [ 0.1936],
        [ 0.0502],
        [ 0.6300],
        [-0.4470],
        [ 0.1794],
        [ 0.1944],
        [ 0.2262],
        [ 0.3051],
        [ 0.5353]], grad_fn=<AddmmBackward0>)

In [28]:
from eugene.utils._motif import _create_kernel_matrix

In [46]:
model.__getattr__("biconv").kernels[0] = kern

In [40]:
model."biconv"

SyntaxError: invalid syntax (2592916639.py, line 1)

In [36]:
kern = model.biconv.kernels[0]

In [37]:
kern = _create_kernel_matrix((128, 4, 13), all_motifs)

In [56]:
model.conv.weight = torch.nn.Parameter(torch.randn(128, 128, 13))

In [57]:
model.conv.weight

Parameter containing:
tensor([[[-1.1160, -0.3325, -0.7744,  ...,  0.0685, -0.1617, -0.7123],
         [ 0.0985, -1.0684,  0.9452,  ...,  0.0340,  0.8200, -0.8049],
         [-0.3100, -1.2196, -0.4721,  ...,  0.7886,  0.2196,  2.7004],
         ...,
         [-1.1247, -0.0458, -0.4971,  ...,  0.4404, -1.2233, -0.1531],
         [ 0.6456, -0.5590, -1.2134,  ..., -0.5988, -0.2380, -0.2558],
         [-0.8898, -1.3874, -0.8575,  ..., -0.4385,  1.1100, -0.2351]],

        [[-0.7794,  0.8648, -0.3815,  ..., -1.6818,  0.1235,  1.6162],
         [ 0.7390, -1.7336,  0.4566,  ...,  0.1805, -0.2668,  0.9915],
         [ 0.9938, -0.0948,  1.4486,  ...,  1.2259, -0.0529,  0.5860],
         ...,
         [-0.9741, -2.5641, -0.8386,  ...,  0.6279, -2.0081, -1.0378],
         [-0.1105,  0.5527, -0.0749,  ..., -0.5626, -0.4269,  0.9372],
         [-0.6103,  2.6745,  0.2042,  ..., -1.4715, -0.0636,  0.8312]],

        [[-1.8713,  0.9585,  1.2868,  ...,  0.8011,  0.1002,  0.1813],
         [-0.8813, -0.4

In [104]:
core_promoter_elements = MinimalMEME('/cellar/users/aklie/projects/EUGENe/tests/_data/datasets/jores21/CPEs.meme')
tf_groups = MinimalMEME('/cellar/users/aklie/projects/EUGENe/tests/_data/datasets/jores21/TF-clusters.meme')
all_motifs = {**core_promoter_elements.motifs, **tf_groups.motifs}

In [122]:
kernel = torch.zeros(128, 4, 13)
nn.init.xavier_uniform_(kernel)

# overwrite part of kernel with pfms from motifs
for i, motif_id in enumerate(all_motifs):
    motif = all_motifs[motif_id]
    # convert PFM to PWM, assume equal background frequency of 0.25
    # truncates motifs longer than 13bp to 13bp
    kernel[i, :, :min(len(motif), kernel.shape[2])] = torch.tensor(motif.pfm[:min(len(motif), kernel.shape[2]), :] / 0.25).transpose(0, 1)

In [124]:
model.biconv.kernels[0] = kernel

In [125]:
model(x)

tensor([[-0.0013],
        [ 0.0682],
        [ 0.0221],
        [-0.0021],
        [ 0.0844],
        [ 0.1157],
        [ 0.0566],
        [ 0.2570],
        [-0.1326],
        [-0.1769]], grad_fn=<AddmmBackward0>)

In [18]:
core_promoter_elements = eu.utils.MinimalMEME('/cellar/users/aklie/projects/EUGENe/tests/_data/datasets/jores21/CPEs.meme')
tf_groups = eu.utils.MinimalMEME('/cellar/users/aklie/projects/EUGENe/tests/_data/datasets/jores21/TF-clusters.meme')
all_motifs = {**core_promoter_elements.motifs, **tf_groups.motifs}

AttributeError: module 'eugene.utils' has no attribute 'MinimalMEME'