# Implementing the `Jores21CNN` class

**Authorship:**
Adam Klie, *07/31/2022*
***
**Description:**
Notebook for implementing the `Jores21CNN` class.

In [7]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from eugene.models.base import BaseModel, BasicFullyConnectedModule, BasicConv1D

# Autoreload extension
if 'autoreload' not in get_ipython().extension_manager.loaded:
    %load_ext autoreload
%autoreload 2

# `Jores21CNN`
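See Jores et al. (2021), *Synthetic promoter designs enabled by a comprehensive analysis of plant core promoters* (Nature Plants), for details of the original architecture. **TODO:** add links to the paper and the original implementation.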

## `BiConv1D` module
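A `torch.nn.Module` used as the first layer of the `Jores21CNN` class. It runs each convolution twice: once with the kernel as-is, and once with the kernel flipped along both the channel and position axes. For one-hot DNA in A/C/G/T channel order, the flipped kernel is the reverse complement. The two results are summed.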

In [91]:
class BiConv1D(nn.Module):
    """Bidirectional stack of 1D convolutions with shared weights.

    Two streams are computed from the same input: a forward stream that uses
    each kernel as-is, and a reverse stream that uses the kernel flipped along
    both the input-channel and position axes (the reverse complement when the
    input is one-hot DNA in A/C/G/T channel order -- assumed, not checked).
    Each layer is conv (+ per-filter bias) -> ReLU -> dropout; the two streams
    are summed after the last layer.

    Args:
        filters: Output channels of every layer.
        kernel_size: Convolution window length.
        input_size: Input channels of the first layer (default 4).
        layers: Number of stacked layers; must be >= 1.
        stride: Convolution stride. PyTorch only supports ``padding="same"``
            with stride 1.
        dropout_rate: Dropout probability after each ReLU, in [0, 1].

    Raises:
        ValueError: If ``layers < 1`` or ``dropout_rate`` is outside [0, 1].

    Shape:
        Input ``(batch, input_size, length)``; output ``(batch, filters, length)``.
    """

    def __init__(self, filters, kernel_size, input_size = 4, layers = 2, stride = 1, dropout_rate = 0.15):
        super().__init__()
        self.input_size = input_size
        self.filters = filters
        self.kernel_size = kernel_size
        if layers < 1:
            raise ValueError("At least one layer needed")
        self.layers = layers
        if (dropout_rate < 0) or (dropout_rate > 1):
            raise ValueError("Dropout rate must be a float between 0 and 1")
        self.dropout_rate = dropout_rate
        self.stride = stride

        # ParameterList registers every weight with the module so it is trained,
        # moved by .to()/.cuda(), and included in state_dict().
        self.kernels = nn.ParameterList()
        self.biases = nn.ParameterList()
        in_channels = input_size
        for _ in range(layers):
            kernel = nn.Parameter(torch.empty(filters, in_channels, kernel_size))
            nn.init.xavier_uniform_(kernel)
            self.kernels.append(kernel)
            self.biases.append(nn.Parameter(torch.zeros(filters)))
            # every layer after the first maps filters -> filters
            in_channels = filters

    def _conv_relu_dropout(self, x, kernel, bias):
        """Apply one layer: same-padded conv with bias, ReLU, then dropout.

        Dropout is only active in training mode (``self.training``).
        """
        x = F.conv1d(x, kernel, bias=bias, stride=self.stride, padding="same")
        return F.dropout(F.relu(x), p=self.dropout_rate, training=self.training)

    def forward(self, x):
        """Run the forward and flipped-kernel streams and return their sum."""
        x_fwd, x_rev = x, x
        for kernel, bias in zip(self.kernels, self.biases):
            x_fwd = self._conv_relu_dropout(x_fwd, kernel, bias)
            # same weights, flipped along channel and position axes
            x_rev = self._conv_relu_dropout(x_rev, torch.flip(kernel, dims=[1, 2]), bias)
        return torch.add(x_fwd, x_rev)

In [92]:
biconv_demo = BiConv1D(filters=128, kernel_size=13)
x = torch.randn(10, 4, 170)  # random batch (batch, channels, length); `x` is reused further down
biconv_demo(x).shape

## `Jores21CNN` class definition

The model class is named `jores21` in the code below.

In [94]:
class jores21(BaseModel):
    """Jores21 CNN: BiConv1D -> Conv1d -> ReLU -> dropout -> dense -> BatchNorm -> ReLU -> dense.

    Args:
        input_len, output_dim, strand, task, aggr: Passed unchanged to ``BaseModel``.
        filters: Channels of the BiConv1D and Conv1d layers (default 128).
        kernel_size: Window length of both convolutions (default 13).
        layers: Number of BiConv1D layers (default 2).
        dropout_rate: Dropout probability inside BiConv1D and after the Conv1d
            (default 0.15).
        hidden_dim: Width of the hidden dense layer (default 64).
    """

    def __init__(self, input_len, output_dim, strand="ss", task="regression", aggr=None,
                 filters=128, kernel_size=13, layers=2, dropout_rate=0.15, hidden_dim=64):
        super().__init__(input_len, output_dim, strand, task, aggr)
        self.biconv = BiConv1D(filters=filters, kernel_size=kernel_size, layers=layers, dropout_rate=dropout_rate)
        self.conv = nn.Conv1d(in_channels=filters, out_channels=filters, kernel_size=kernel_size, stride=1, padding="same")
        self.dropout = nn.Dropout(p=dropout_rate)
        # NOTE(review): the flattened size assumes 170-bp inputs and does not use
        # input_len. TODO: switch to filters * input_len once callers pass the
        # sequence length as input_len.
        self.fc = nn.Linear(in_features=filters * 170, out_features=hidden_dim)
        self.batchnorm = nn.BatchNorm1d(num_features=hidden_dim)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, x):
        """Map inputs (assumed ``(batch, 4, 170)``) to ``(batch, output_dim)``."""
        x = self.biconv(x)
        x = self.conv(x)
        x = F.relu(x)
        x = self.dropout(x)
        # flatten channels x positions per sample for the dense head
        x = self.fc(x.flatten(start_dim=1))
        x = self.batchnorm(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x

In [96]:
# input_len is presumably the sequence length (170 bp for `x`), not the 4 input
# channels -- TODO confirm against BaseModel.
model = jores21(input_len=170, output_dim=1, strand="ss", task="regression", aggr=None)
# Wrap the replacement kernel in nn.Parameter so it stays a trainable weight.
model.biconv.kernels[0] = nn.Parameter(torch.randn(128, 4, 13))

In [100]:
# Inspect the first-layer BiConv1D kernel assigned in the previous cell (128 x 4 x 13).
# The full tensor repr is long; `.shape` or a small slice is usually enough.
model.biconv.kernels[0]

tensor([[[ 1.8953, -0.4749,  0.6269,  ..., -1.2313,  0.6483, -1.2316],
         [-2.4813, -0.7270,  0.8044,  ...,  1.4520, -0.4763,  0.0513],
         [-1.7531,  0.3280, -0.1239,  ..., -0.5653,  0.4183,  0.4763],
         [-0.3167, -2.6283,  1.0020,  ...,  0.3277,  0.7104, -1.1900]],

        [[ 0.5870, -1.8096,  1.1668,  ...,  0.1083, -0.7349, -0.0243],
         [ 0.8162,  0.1150, -0.2222,  ..., -0.1172, -0.0473,  1.6646],
         [ 1.0849,  0.7717,  0.7226,  ...,  1.4296, -0.2696, -2.5712],
         [ 1.4951,  0.9446, -0.7659,  ...,  0.7915, -0.8546,  0.8496]],

        [[-0.6582, -0.8332, -0.4257,  ...,  0.7393, -1.0129,  1.5116],
         [ 0.8167,  2.0377, -0.2003,  ...,  1.0718,  0.6252,  0.6872],
         [-0.0168,  0.9187, -0.0466,  ..., -1.0867,  0.7338, -0.1525],
         [-0.2133,  0.9298, -0.1380,  ...,  0.0164, -2.2652,  0.1844]],

        ...,

        [[ 1.1549, -0.2447,  1.3656,  ...,  0.1596, -0.7859,  1.7132],
         [ 0.1077, -0.9730, -0.2881,  ..., -0.2443,  1.65

# `Motif` and `MinimalMEME` integration 

## `Motif` and `MinimalMEME` class definitions

In [103]:
import os
import numpy as np
import pandas as pd
import re
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from dataclasses import dataclass
from typing import Optional
from io import TextIOBase

@dataclass
class Motif:
    """One MOTIF entry from a minimal MEME file.

    ``MinimalMEME._parse_motif`` constructs these from the parsed file.
    """
    # motif identifier: the first token after "MOTIF"
    identifier: str
    # letter-probability matrix; MinimalMEME stacks one row per motif position,
    # one column per alphabet letter
    pfm: np.ndarray
    # letters per matrix row ("alength=" in the header, otherwise inferred from rows)
    alphabet_length: int
    # number of motif positions ("w=" in the header, otherwise the pfm row count)
    length: int
    # optional alternate name: the token after the identifier on the MOTIF line
    name: Optional[str] = None
    # "nsites=" from the matrix header, if present
    source_sites: Optional[int] = None
    # "E=" from the matrix header, if present
    source_evalue: Optional[float] = None
    
    def __len__(self) -> int:
        """Return the motif length (number of positions)."""
        return self.length
    
    
class MinimalMEME:
    """Parser for motif files in the minimal MEME format.

    Format reference: http://meme-suite.org/doc/meme-format.html

    Supported sections: the version line (must be the first line),
    ``ALPHABET=`` with a built-in alphabet, ``strands:``, the
    ``Background letter frequencies`` block, and ``MOTIF`` entries with a
    ``letter-probability matrix``. Any other line (e.g. ``URL``) is skipped.

    Args:
        path: Path to the MEME file.

    Attributes:
        version: Version string from the first line, e.g. ``"4"`` or ``"5.4.1"``.
        alphabet: Text after ``ALPHABET= ``, or None if the file has none.
        strands: ``"+"`` or ``"+ -"``, or None if the file has none.
        background: Mapping residue -> background frequency, or None.
        background_source: Source named in ``(from ...)`` on the background
            header line, or None.
        motifs: Mapping motif identifier -> ``Motif``.

    Raises:
        RuntimeError: On malformed, duplicated or inconsistent entries.
        NotImplementedError: For custom ``ALPHABET`` definitions.
    """

    # Integer ("4") or dotted ("5.4.1") versions. Every header regex tolerates
    # trailing whitespace, including "\r\n" line endings.
    __version_regex = re.compile(r'^MEME version ([0-9]+(?:\.[0-9]+)*)\s*$')
    # Optional "(from <source>)" and an optional trailing colon.
    __background_regex = re.compile(r'^Background letter frequencies(?: \(from (.+)\))?:?\s*$')
    __background_sum_error = 0.00001
    # E-values may use exponent notation with either sign, e.g. "4.1e-009" or "1.2e+003".
    __pfm_header_regex = re.compile(
        r'^letter-probability matrix:'
        r'(?: alength= ([0-9]+))?'
        r'(?: w= ([0-9]+))?'
        r'(?: nsites= ([0-9]+))?'
        r'(?: E= ([0-9.eE+-]+))?'
        r'\s*$'
    )
    version = None
    alphabet = None
    strands = None
    background = None
    background_source = None
    motifs = None

    def __init__(self, path):
        self.motifs = {}

        with open(path) as minimal_meme_file:
            # The first line must be the version line.
            self.version = self._parse_version(minimal_meme_file.readline())

            line = minimal_meme_file.readline()
            # Section parsers consume their own lines and return the first line
            # they did not handle ('' at EOF), so the loop never skips a header.
            while line:
                if line.startswith('ALPHABET'):
                    if self.alphabet is not None:
                        raise RuntimeError("Multiple alphabet definitions encountered in MEME file")
                    self.alphabet = self._parse_alphabet(line)
                    line = minimal_meme_file.readline()
                elif line.startswith('strands: '):
                    if self.strands is not None:
                        raise RuntimeError("Multiple strand definitions encountered in MEME file")
                    self.strands = self._parse_strands(line)
                    line = minimal_meme_file.readline()
                elif line.startswith('Background letter frequencies'):
                    if self.background is not None:
                        raise RuntimeError("Multiple background frequency definitions encountered in MEME file")
                    line = self._parse_background(line, minimal_meme_file)
                elif line.startswith('MOTIF'):
                    line = self._parse_motif(line, minimal_meme_file)
                else:
                    line = minimal_meme_file.readline()

    def _parse_version(self, line: str) -> str:
        """Return the version number from a ``MEME version N`` line."""
        match = self.__version_regex.match(line)
        if match is None:
            raise RuntimeError("Minimal MEME file missing version string on first line")
        return match.group(1)

    def _parse_alphabet(self, line: str) -> str:
        """Return the alphabet from an ``ALPHABET= XYZ`` line.

        Custom alphabet definitions (``ALPHABET "name" ...``) are not supported.
        """
        if line.startswith('ALPHABET '):
            raise NotImplementedError("Alphabet definitions not supported")
        elif line.startswith('ALPHABET= '):
            return line.rstrip()[10:]
        else:
            raise RuntimeError('Unable to parse alphabet line')

    def _parse_strands(self, line: str) -> str:
        """Return ``"+"`` or ``"+ -"`` from a ``strands: ...`` line."""
        strands = line.rstrip()[9:]
        if not ((strands == '+') or (strands == '+ -')):
            raise RuntimeError("Invalid strand specification")
        return strands

    def _parse_background(self, line: str, handle: TextIOBase) -> str:
        """Parse the background block into ``self.background``.

        Frequencies are whitespace-separated ``residue frequency`` pairs, possibly
        over several lines. The block ends at a blank line, a MOTIF line or EOF,
        and the frequencies must sum to 1 within ``__background_sum_error``.

        Returns:
            The first line after the block ('' at EOF).
        """
        match = self.__background_regex.match(line)
        if match is None:
            raise RuntimeError("Unable to parse background frequency line")
        if match.group(1) is not None:
            self.background_source = match.group(1)

        self.background = {}
        line = handle.readline()
        while line and line.strip() and not line.startswith('MOTIF'):
            # split() tolerates repeated or leading whitespace between tokens
            line_freqs = line.split()
            if len(line_freqs) % 2 != 0:
                raise RuntimeError("Invalid background frequency definition")
            for residue, freq in zip(line_freqs[0::2], line_freqs[1::2]):
                self.background[residue] = float(freq)
            line = handle.readline()

        if abs(1 - sum(self.background.values())) > self.__background_sum_error:
            raise RuntimeError("Background frequencies do not sum to 1")
        return line

    def _parse_motif(self, line: str, handle: TextIOBase) -> str:
        """Parse one MOTIF entry into ``self.motifs``.

        Blank lines between the MOTIF line and the matrix header are skipped.
        The matrix ends at a blank line, the next MOTIF line or EOF, so a motif
        at the very end of the file is kept.

        Returns:
            The first line after the matrix ('' at EOF).
        """
        # MOTIF line: "MOTIF <identifier> [<alternate name>]"
        line_split = line.split()
        if not 2 <= len(line_split) <= 3:
            raise RuntimeError("Invalid motif name line")
        motif_identifier = line_split[1]
        motif_name = line_split[2] if len(line_split) == 3 else None

        line = handle.readline()
        while line and not line.strip():
            line = handle.readline()
        if not line.startswith('letter-probability matrix:'):
            raise RuntimeError("No letter-probability matrix header line in motif entry")
        match = self.__pfm_header_regex.match(line)
        if match is None:
            raise RuntimeError("Unable to parse letter-probability matrix header")
        alength, width, nsites, evalue = match.groups()
        motif_alphabet_length = int(alength) if alength is not None else None
        motif_length = int(width) if width is not None else None
        motif_source_sites = int(nsites) if nsites is not None else None
        motif_source_evalue = float(evalue) if evalue is not None else None

        pfm_rows = []
        line = handle.readline()
        while line and line.strip() and not line.startswith('MOTIF'):
            row = line.split()
            if motif_alphabet_length is None:
                motif_alphabet_length = len(row)
            elif motif_alphabet_length != len(row):
                raise RuntimeError("Letter-probability matrix row length doesn't equal alphabet length")
            pfm_rows.append(np.array([float(s) for s in row]))
            line = handle.readline()

        if motif_identifier in self.motifs:
            raise RuntimeError("Motif identifiers not unique within file")
        if not pfm_rows:
            raise RuntimeError("Empty letter-probability matrix in motif entry")
        pfm = np.stack(pfm_rows)
        if motif_length is None:
            motif_length = pfm.shape[0]
        elif motif_length != pfm.shape[0]:
            raise RuntimeError("Provided motif length is not consistent with the letter-probability matrix shape")
        self.motifs[motif_identifier] = Motif(
            identifier = motif_identifier,
            pfm = pfm,
            alphabet_length = motif_alphabet_length,
            length = motif_length,
            name = motif_name,
            source_sites = motif_source_sites,
            source_evalue = motif_source_evalue
        )
        return line

## Initializing a layer with `Motif` objects

In [104]:
# Directory with the Jores et al. 2021 motif files. Set the JORES21_DATA_DIR
# environment variable to use a different checkout.
JORES21_DATA_DIR = os.environ.get(
    "JORES21_DATA_DIR",
    '/cellar/users/aklie/projects/EUGENe/tests/_data/datasets/jores21',
)
core_promoter_elements = MinimalMEME(os.path.join(JORES21_DATA_DIR, 'CPEs.meme'))
tf_groups = MinimalMEME(os.path.join(JORES21_DATA_DIR, 'TF-clusters.meme'))
# Merge both motif sets; on a duplicate identifier the TF-cluster motif wins.
all_motifs = {**core_promoter_elements.motifs, **tf_groups.motifs}

In [122]:
# Start from the same Xavier-uniform init BiConv1D uses for its first-layer kernel.
kernel = torch.empty(128, 4, 13)
nn.init.xavier_uniform_(kernel)
n_filters, kernel_len = kernel.shape[0], kernel.shape[2]
if len(all_motifs) > n_filters:
    raise ValueError(f"{len(all_motifs)} motifs do not fit into {n_filters} filters")

# Overwrite the leading positions of the first len(all_motifs) filters with motif
# probabilities. Each is divided by a uniform 0.25 background (an odds ratio; no
# log is taken). Motifs longer than the kernel keep only their first kernel_len
# positions.
for i, motif in enumerate(all_motifs.values()):
    width = min(len(motif), kernel_len)
    # pfm rows are positions, so transpose to (channels, positions)
    kernel[i, :, :width] = torch.tensor(motif.pfm[:width, :] / 0.25).transpose(0, 1)

In [124]:
# Install the motif-initialized kernel as BiConv1D's first-layer weight. Wrapping
# it in nn.Parameter keeps it a trainable parameter rather than a plain tensor.
model.biconv.kernels[0] = nn.Parameter(kernel)

In [125]:
# Forward pass on the random batch `x` (10 x 4 x 170) from the BiConv1D smoke test.
# NOTE(review): model.eval() is never called, so the model is in training mode.
# Dropout and batch-statistics BatchNorm therefore make this output vary between
# runs; call model.eval() first for a deterministic check.
model(x)

tensor([[-0.0013],
        [ 0.0682],
        [ 0.0221],
        [-0.0021],
        [ 0.0844],
        [ 0.1157],
        [ 0.0566],
        [ 0.2570],
        [-0.1326],
        [-0.1769]], grad_fn=<AddmmBackward0>)