# <center> BENCHMARKING CLASTER WITH HYENA-DNA AND ENFORMER </center>

## 1. Hyena-DNA:

The following is an adaptation of the open source code provided by the authors of Hyena-DNA, who are to be credited for it. 

The paper can be found as:
Eric Nguyen and Michael Poli and Marjan Faizi and Armin Thomas and Callum Birch-Sykes and Michael Wornow and Aman Patel and Clayton Rabideau and Stefano Massaroli and Yoshua Bengio and Stefano Ermon and Stephen A. Baccus and Chris Ré, _HyenaDNA: Long-Range Genomic Sequence Modeling at Single Nucleotide Resolution_ , 2023.

https://doi.org/10.48550/arXiv.2306.15794

The original public colab notebook can be found here:

https://colab.research.google.com/drive/1wyVEQd4R3HYLTUOXEEQmp_I8aNC_aLhL

**Edits:**
The goal was to add a head to the Hyena-DNA backbone. We tried obtaining only the embeddings and then adding the head on top of those, but we could not predict EU-seq profiles from the pretrained embeddings.

The structure of the backbone remains untouched. We added custom regression heads designed to:
- Perform dimensionality reduction from the high dimensional embeddings.
- Link the resulting sequence representations to our outputs, i.e. EU-seq levels at a kbp resolution.

In [20]:
! pip install  einops torchvision transformers==4.26.1 nvidia-ml-py3 genomic-benchmarks OmegaConf

Collecting transformers==4.26.1
  Using cached transformers-4.26.1-py3-none-any.whl.metadata (100 kB)
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers==4.26.1)
  Downloading tokenizers-0.13.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Using cached transformers-4.26.1-py3-none-any.whl (6.3 MB)
Using cached tokenizers-0.13.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
Installing collected packages: tokenizers, transformers
  Attempting uninstall: tokenizers
    Found existing installation: tokenizers 0.15.2
    Uninstalling tokenizers-0.15.2:
      Successfully uninstalled tokenizers-0.15.2
  Attempting uninstall: transformers
    Found existing installation: transformers 4.39.1
    Uninstalling transformers-4.39.1:
      Successfully uninstalled transformers-4.39.1
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the foll

**Train and test HyenaDNA**

In [2]:
%%writefile Hyena_DNA_Esrum.py 

"""
Python file containing the HyenaDNA ipynb cells stacked:
https://colab.research.google.com/drive/1wyVEQd4R3HYLTUOXEEQmp_I8aNC_aLhL
All credit goes to the original developers.

We added custom heads and modified and added training functions accordingly.
"""

#@title Installs
# ! pip install einops
# ! pip install torchvision
# ! pip install transformers==4.26.1
# ! pip install genomic-benchmarks
# ! pip install OmegaConf


#@title Imports
# for HyenaDNA specifically
from pathlib import Path
import torch
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from einops import rearrange
from typing import Optional
from functools import partial
from torch import Tensor
from torchvision.ops import StochasticDepth
from collections import namedtuple

import os
import requests
import shutil
import subprocess


# Directory used for model checkpoints; it is created (together with any
# missing parent directories) when this file is executed.
checkpoints_path = Path("../checkpoints/")
checkpoints_path.mkdir(parents=True, exist_ok=True)

#@title Hyena layer

def fftconv(u, k, D):
    """
    Convolve `u` with the filter `k` through the Fourier domain (Convolution
    Theorem) and add the skip term `u * D`.

    Both operands are zero-padded to twice the sequence length before the
    transform, and only the first `seqlen` outputs are kept.

    Args:
        u: input signal; its last dimension is the sequence dimension.
        k: filter, with the sequence on its last dimension.
        D: skip weight; it receives a trailing singleton dimension so that it
            broadcasts along the sequence dimension of `u`.

    Returns:
        The convolved signal plus the skip term, cast to the dtype of `u`.
    """
    length = u.shape[-1]
    padded_len = 2 * length

    # The filter spectrum is pre-divided by the FFT size because the inverse
    # transform below runs with norm='forward', which applies no scaling.
    filter_spec = torch.fft.rfft(k, n=padded_len) / padded_len
    # The input is transformed in the filter's dtype.
    signal_spec = torch.fft.rfft(u.to(dtype=k.dtype), n=padded_len)

    # Inputs with more than three dimensions need one extra axis on the filter.
    if len(u.shape) > 3:
        filter_spec = filter_spec.unsqueeze(1)

    conv = torch.fft.irfft(signal_spec * filter_spec, n=padded_len, norm='forward')
    conv = conv[..., :length]

    result = conv + u * D.unsqueeze(-1)
    return result.to(dtype=u.dtype)


@torch.jit.script
def mul_sum(q, y):
    # Element-wise product of q and y, summed over dimension 1.
    product = q * y
    return product.sum(dim=1)

class OptimModule(nn.Module):
    """ Base class for modules whose tensors can carry their own optimizer hyperparameters """

    def register(self, name, tensor, lr=None, wd=0.0):
        """Attach `tensor` to the module under `name`.

        With `lr == 0.0` the tensor becomes a buffer. Otherwise it becomes an
        `nn.Parameter` tagged with an `_optim` dict that holds the `lr` and
        `weight_decay` overrides (an entry is left out when its value is None).
        """
        if lr == 0.0:
            self.register_buffer(name, tensor)
            return

        self.register_parameter(name, nn.Parameter(tensor))

        # Collect only the hyperparameters that were actually specified.
        candidates = (("lr", lr), ("weight_decay", wd))
        overrides = {key: value for key, value in candidates if value is not None}
        setattr(getattr(self, name), "_optim", overrides)


class Sin(nn.Module):
    """The Sin activation function for the Hyena Filter function.

    Computes `sin(freq * x)`, where `freq` has shape (1, dim) and is
    initialised to `w`.

    Args:
        dim: size of the last dimension of the inputs.
        w: initial value of every frequency.
        train_freq: if True, `freq` is a learnable parameter; otherwise it is
            a fixed tensor.
    """
    def __init__(self, dim, w=10, train_freq=True):
        super().__init__()
        freq = w * torch.ones(1, dim)
        if train_freq:
            self.freq = nn.Parameter(freq)
        else:
            # A plain tensor attribute would be ignored by .to()/.cuda()/.half()
            # and end up on a different device/dtype than the inputs. A
            # non-persistent buffer follows the module while keeping the
            # state_dict keys unchanged.
            self.register_buffer("freq", freq, persistent=False)

    def forward(self, x):
        return torch.sin(self.freq * x)


class PositionalEmbedding(OptimModule):
    def __init__(self, emb_dim: int, seq_len: int, lr_pos_emb: float=1e-5, **kwargs):
        """Complex exponential positional embeddings for Hyena filters.

        Builds two tensors for `seq_len` positions:
          * `t`: normalized time in [0, 1], shape (1, seq_len, 1), registered
            as a buffer;
          * `z`: concatenation of `t` with the real and imaginary parts of
            `bands = (emb_dim - 1) // 2` complex exponentials, registered with
            learning rate `lr_pos_emb`.

        Args:
            emb_dim: dimension of the positional features; must be > 1.
            seq_len: number of positions to precompute.
            lr_pos_emb: learning rate stored for `z` (0.0 makes it a buffer).
            **kwargs: accepted and ignored.

        Raises:
            ValueError: if `emb_dim` is not greater than 1 (the number of
                frequency bands would otherwise be undefined).
        """
        super().__init__()

        if emb_dim <= 1:
            raise ValueError(f"emb_dim must be greater than 1, got {emb_dim}")

        self.seq_len = seq_len
        # The time embedding fed to the filters is normalized so that t_f = 1
        t = torch.linspace(0, 1, self.seq_len)[None, :, None] # 1, L, 1

        bands = (emb_dim - 1) // 2
        # To compute the right embeddings we use the "proper" linspace
        t_rescaled = torch.linspace(0, seq_len - 1, seq_len)[None, :, None]
        w = 2 * math.pi * t_rescaled / seq_len # 1, L, 1

        f = torch.linspace(1e-4, bands - 1, bands)[None, None]
        z = torch.exp(-1j * f * w)
        z = torch.cat([t, z.real, z.imag], dim=-1)
        self.register("z", z, lr=lr_pos_emb)
        self.register("t", t, lr=0.0)

    def forward(self, L):
        """Return the positional features and times of the first `L` positions."""
        return self.z[:, :L], self.t[:, :L]


class ExponentialModulation(OptimModule):
    """Exponentially decaying window applied to the output of the (MLP) filter function.

    One decay rate per channel is created, linearly spaced between
    `log(target) / slow_decay_pct` and `log(target) / fast_decay_pct`, and
    registered under the name `deltas` with learning rate `modulation_lr`.
    """
    def __init__(
        self,
        d_model,
        fast_decay_pct=0.3,
        slow_decay_pct=1.5,
        target=1e-2,
        modulation_lr=0.0,
        modulate: bool=True,
        shift: float = 0.05,
        **kwargs
    ):
        super().__init__()
        self.modulate = modulate
        self.shift = shift

        log_target = math.log(target)
        fastest = log_target / fast_decay_pct
        slowest = log_target / slow_decay_pct
        rates = torch.linspace(slowest, fastest, d_model)[None, None]
        self.register("deltas", rates, lr=modulation_lr)

    def forward(self, t, x):
        """Scale `x` by `exp(-t * |deltas|) + shift`; return `x` untouched when modulation is off."""
        if not self.modulate:
            return x
        window = torch.exp(-t * self.deltas.abs())
        return x * (window + self.shift)


class HyenaFilter(OptimModule):
    # Implicitly parametrized long filter: a small MLP maps positional features
    # to filter values, which are windowed by an exponential modulation and
    # applied to the input with an FFT convolution.
    def __init__(
            self,
            d_model,
            emb_dim=3, # dim of input to MLP, augments with positional encoding
            order=16, # width of the implicit MLP
            fused_fft_conv=False,
            seq_len=1024,
            lr=1e-3,
            lr_pos_emb=1e-5,
            dropout=0.0,
            w=1, # frequency of periodic activations
            wd=0, # weight decay of kernel parameters
            bias=True,
            num_inner_mlps=2,
            normalized=False,
            **kwargs
        ):
        """
        Implicit long filter with modulation.

        Args:
            d_model: number of channels in the input
            emb_dim: dimension of the positional encoding (`emb_dim` - 1) // 2 is the number of bands
            order: width of the FFN
            fused_fft_conv: stored on the instance; not read by the methods of this class
            seq_len: number of positions for which positional features are precomputed
            lr: learning rate stored in the `_optim` tag of the implicit-filter parameters
            lr_pos_emb: learning rate handed to the positional embedding
            dropout: probability of `self.dropout`, which the methods of this class do not call
            w: initial frequency of the Sin activation
            wd: weight decay stored in the `_optim` tag of the implicit-filter parameters
            bias: stored as `use_bias`; the `bias` parameter itself is created regardless
            num_inner_mlps: number of inner linear layers inside filter MLP
            normalized: stored on the instance; not read by the methods of this class
            **kwargs: forwarded to ExponentialModulation

        Note:
            filter_dropout is not implemented
        """
        super().__init__()

        self.d_model = d_model
        self.use_bias = bias
        self.fused_fft_conv = fused_fft_conv
        # One randomly initialised bias value per channel.
        self.bias = nn.Parameter(torch.randn(self.d_model))
        self.dropout = nn.Dropout(dropout)

        # A single activation instance is reused at every position of the MLP
        # below, so all positions share the same frequency tensor.
        act = Sin(dim=order, w=w)
        self.emb_dim = emb_dim
        assert emb_dim % 2 != 0 and emb_dim >= 3, "emb_dim must be odd and greater or equal to 3 (time, sine and cosine)"
        self.seq_len = seq_len

        self.pos_emb = PositionalEmbedding(emb_dim, seq_len, lr_pos_emb)

        # MLP: emb_dim -> order -> (order -> order) * num_inner_mlps -> d_model
        self.implicit_filter = nn.Sequential(
            nn.Linear(emb_dim, order),
            act,
        )
        for i in range(num_inner_mlps):
            self.implicit_filter.append(nn.Linear(order, order))
            self.implicit_filter.append(act)

        self.implicit_filter.append(nn.Linear(order, d_model, bias=False))

        self.modulation = ExponentialModulation(d_model, **kwargs)

        self.normalized = normalized
        # Tag every tensor listed in the state_dict of each MLP layer with the
        # optimizer hyperparameters requested for the filter.
        for c in self.implicit_filter.children():
            for name, v in c.state_dict().items():
                optim = {"weight_decay": wd, "lr": lr}
                setattr(getattr(c, name), "_optim", optim)

    def filter(self, L, *args, **kwargs):
        """Materialise the filter for the first `L` positions (last dim is `d_model`)."""
        z, t = self.pos_emb(L)
        h = self.implicit_filter(z)
        # Apply the exponential window over time.
        h = self.modulation(t, h)
        return h

    def forward(self, x, L, k=None, bias=None, *args, **kwargs):
        """Convolve `x` with the filter `k` (computed from `L` when not given).

        NOTE(review): `bias` is passed straight to `fftconv`, which calls
        `.unsqueeze(-1)` on it, so callers are expected to provide a tensor
        even though the default is None.
        """
        if k is None: k = self.filter(L)

        # Ensure compatibility with filters that return a tuple
        k = k[0] if type(k) is tuple else k

        y = fftconv(x, k, bias)
        return y


class HyenaOperator(nn.Module):
    def __init__(
            self,
            d_model,
            l_max,
            order=2,
            filter_order=64,
            dropout=0.0,
            filter_dropout=0.0,
            **filter_args,
        ):
        r"""
        Hyena operator described in the paper https://arxiv.org/pdf/2302.10866.pdf

        Args:
            d_model (int): Dimension of the input and output embeddings (width of the layer)
            l_max: (int): Maximum input sequence length (required)
            order: (int): Depth of the Hyena recurrence. Defaults to 2
            filter_order: (int): Width of the implicit filter MLP. Defaults to 64
            dropout: (float): Dropout probability. Defaults to 0.0
            filter_dropout: (float): Dropout probability for the filter. Defaults to 0.0
            **filter_args: extra keyword arguments forwarded to HyenaFilter
        """
        super().__init__()

        self.d_model = d_model
        self.l_max = l_max
        self.order = order
        # The input projection yields `order + 1` chunks of width d_model.
        inner_width = d_model * (order + 1)
        self.dropout = nn.Dropout(dropout)
        self.in_proj = nn.Linear(d_model, inner_width)
        self.out_proj = nn.Linear(d_model, d_model)

        # Depthwise (groups == channels) convolution with kernel size 3; the
        # padded output is truncated to the sequence length in forward().
        self.short_filter = nn.Conv1d(
            inner_width,
            inner_width,
            3,
            padding=2,
            groups=inner_width
        )
        # One long filter per recurrence step, stored along the channel axis.
        # `channels` is not a named HyenaFilter argument; it lands in its **kwargs.
        self.filter_fn = HyenaFilter(
            d_model * (order - 1),
            order=filter_order,
            seq_len=l_max,
            channels=1,
            dropout=filter_dropout,
            **filter_args
        )

    def forward(self, u, *args, **kwargs):
        """Apply the operator to `u` of shape (batch, length, d_model); the output has the same layout."""
        l = u.size(-2)
        l_filter = min(l, self.l_max)
        u = self.in_proj(u)
        u = rearrange(u, 'b l d -> b d l')

        # Short convolution, keeping only the first l_filter positions.
        uc = self.short_filter(u)[...,:l_filter]
        # Split the channels into chunks of width d_model: the last one is the
        # value `v`, the preceding ones are collected in the list `x`.
        *x, v = uc.split(self.d_model, dim=1)

        # Long filters and biases, reshaped to one entry per recurrence step.
        k = self.filter_fn.filter(l_filter)[0]
        k = rearrange(k, 'l (o d) -> o d l', o=self.order - 1)
        bias = rearrange(self.filter_fn.bias, '(o d) -> o d', o=self.order - 1)

        # Recurrence: element-wise gating followed by a long convolution.
        for o, x_i in enumerate(reversed(x[1:])):
            v = self.dropout(v * x_i)
            v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o])

        # Final gating with the first chunk, back to (batch, length, dim).
        y = rearrange(v * x[0], 'b d l -> b l d')

        y = self.out_proj(y)
        return y

#@title Self-Attention (alternative)

"""
If you'd like to try the HyenaDNA model using attention instead, you can. ie,
use a regular decoder only Transformer.

Borrowed from the FlashAttention library by Tri Dao.
"""

class SelfAttention(nn.Module):
    """Scaled dot-product attention with softmax.
    Arguments
    ---------
        causal: whether to mask out future positions by default.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """
    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, qkv, causal=None, key_padding_mask=None):
        """Compute multi-head softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, S)
        Returns
        -------
            The attended values, laid out as (B, S, H, D).
        """
        batch_size, seqlen = qkv.shape[:2]
        use_causal = self.causal if causal is None else causal
        q, k, v = qkv.unbind(dim=2)

        scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum('bthd,bshd->bhts', q, k * scale)

        if key_padding_mask is not None:
            # Additive mask: 0 where a key is kept, a large negative value otherwise.
            additive = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype,
                                  device=scores.device)
            additive.masked_fill_(key_padding_mask, 0.0)
            scores = scores + rearrange(additive, 'b s -> b 1 1 s')

        if use_causal:
            # "triu_tril_cuda_template" not implemented for 'BFloat16'
            # So the mask is built in float and cast afterwards
            future = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
            scores = scores + future.to(dtype=scores.dtype)

        weights = torch.softmax(scores, dim=-1, dtype=v.dtype)
        # Dropout is only active in training mode.
        drop_prob = self.dropout_p if self.training else 0.0
        weights = F.dropout(weights, drop_prob)
        return torch.einsum('bhts,bshd->bthd', weights, v)

class MHA(nn.Module):
    """Multi-head self-attention layer: fused QKV projection, optional depthwise
    convolution over the projected sequence, softmax attention and an output
    projection.
    """

    def __init__(self, embed_dim, num_heads, bias=True, dropout=0.0,
                 softmax_scale=None, causal=False, layer_idx=None, dwconv=False,return_residual=False,device=None, dtype=None) -> None:
        """
            return_residual: whether to return the input x along with the output. This is for
                performance reason: for post-norm architecture, returning the input allows us
                to fuse the backward of nn.Linear with the residual connection.
        """
        factory_kwargs = dict(device=device, dtype=dtype)
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal
        self.layer_idx = layer_idx
        self.dwconv = dwconv
        self.return_residual = return_residual

        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads

        qkv_width = 3 * embed_dim

        # The residual-returning projection is only needed with return_residual.
        qkv_proj_cls = LinearResidual if self.return_residual else nn.Linear
        self.Wqkv = qkv_proj_cls(embed_dim, qkv_width, bias=bias, **factory_kwargs)

        if self.dwconv:
            self.dwconv_qkv = nn.Conv1d(qkv_width, qkv_width, kernel_size=3, padding=2,
                                        groups=qkv_width)

        self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale,
                                        attention_dropout=dropout)

        # output projection always have the bias (for now)
        self.out_proj = nn.Linear(embed_dim, embed_dim, **factory_kwargs)

    def forward(self, x, key_padding_mask=None, **kwargs):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim).
            key_padding_mask: boolean mask, True means to keep, False means to mask out.
                (batch, seqlen).
            **kwargs: forwarded, together with key_padding_mask, to the inner attention.
        Returns:
            The attention output, or the pair (output, x) when return_residual is set.
        """
        attn_kwargs = {'key_padding_mask': key_padding_mask, **kwargs}

        if self.return_residual:
            qkv, x = self.Wqkv(x)
        else:
            qkv = self.Wqkv(x)

        if self.dwconv:
            # Convolve along the sequence axis and drop the last two positions
            # produced by the padding.
            channels_first = rearrange(qkv, 'b s d -> b d s')
            convolved = self.dwconv_qkv(channels_first)[..., :-2]
            qkv = rearrange(convolved, 'b d s -> b s d').contiguous()

        # Separate the fused projection into query/key/value and heads.
        qkv = rearrange(qkv, '... (three h d) -> ... three h d', three=3, d=self.head_dim)

        context = self.inner_attn(qkv, **attn_kwargs)

        merged_heads = rearrange(context, '... h d -> ... (h d)')
        out = self.out_proj(merged_heads)
        if self.return_residual:
            return out, x
        return out

#@title MLP layer

"""
The MLP layer after the mixer layer (HyenaOperator).
"""

class Mlp(nn.Module):

    def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu,
                 return_residual=False, device=None, dtype=None):
        """
        Two-layer feed-forward network: Linear -> activation -> Linear.

        Missing (or falsy) `hidden_features` / `out_features` fall back to
        `in_features`.

        From https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/modules/mlp.py
        """
        factory_kwargs = dict(device=device, dtype=dtype)
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, hidden_features, **factory_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, **factory_kwargs)

    def forward(self, x):
        """Return the MLP output, or the pair (output, x) when return_residual is set."""
        out = self.fc2(self.activation(self.fc1(x)))
        if self.return_residual:
            return out, x
        return out

#@title Block layer (Hyena + MLP layers)

"""
A block consists of a Mixer layer (Hyena or attention), and a MLP layer.

"""

class LinearResidual(nn.Linear):
    """nn.Linear whose forward also hands back its input, i.e. it returns the
    pair (projection, input). For compatibility with FusedDense.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        projected = super().forward(input)
        return projected, input

class Block(nn.Module):

    def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm,
                 dropout_cls=nn.Dropout, prenorm=True, resid_dropout1=0., resid_dropout2=0.,
                 drop_path1=0., drop_path2=0.,
                 return_residual=False,
                 residual_in_fp32=False):
        """
        From https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/modules/block.py
        For prenorm=True, this Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
        the hidden_states (output of the MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).
        For prenorm=False, this Block has the same structure as a regular postnorm Transformer
        block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.
        return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
        This is for performance reason: for post-norm architecture, returning the input allows us
        to fuse the backward of nn.Linear with the residual connection.

        Args:
            dim: width of the block; passed to the norm and MLP constructors.
            mixer_cls: zero-argument callable building the mixer (defaults to MHA
                with `dim // 64` heads).
            mlp_cls: callable taking `dim` and building the MLP (defaults to Mlp
                with a hidden width of `4 * dim`).
            norm_cls: callable taking `dim` and building a normalization layer.
            dropout_cls: callable taking a probability and building a dropout layer.
            resid_dropout1, resid_dropout2: dropout probabilities before the first
                and second residual additions.
            drop_path1, drop_path2: stochastic-depth probabilities of the two branches.
            residual_in_fp32: keep the residual stream in float32 (prenorm only).
        """
        super().__init__()
        self.prenorm = prenorm
        self.return_residual = return_residual
        self.residual_in_fp32 = residual_in_fp32
        if self.residual_in_fp32:
            assert self.prenorm, 'residual_in_fp32 is only compatible with prenorm=True'
        if mixer_cls is None:
            # The mixer is instantiated without arguments below, so the default
            # must bind the embedding width itself.
            mixer_cls = partial(MHA, embed_dim=dim, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls()
        self.dropout1 = dropout_cls(resid_dropout1)
        self.drop_path1 = StochasticDepth(drop_path1, mode='row')
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        # The second dropout/drop-path/norm only exist when there is a real MLP.
        if not isinstance(self.mlp, nn.Identity):
            self.dropout2 = dropout_cls(resid_dropout2)
            self.drop_path2 = StochasticDepth(drop_path2, mode='row')
            self.norm2 = norm_cls(dim)

    def forward(self, hidden_states, residual = None,
                mixer_subset=None, mixer_kwargs=None):
        r"""Pass the input through the encoder layer.
        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
            mixer_kwargs: optional dict of extra keyword arguments for the mixer.
        Returns:
            (hidden_states, residual) when prenorm is set, otherwise hidden_states only.
        """
        if self.prenorm:
            # Dropout -> Add -> LN, then the mixer.
            dropped = self.drop_path1(self.dropout1(hidden_states))
            residual = (dropped + residual) if residual is not None else dropped
            hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
            if mixer_kwargs is None:
                mixer_kwargs = {}
            if mixer_subset is not None:
                mixer_kwargs['mixer_subset'] = mixer_subset
            hidden_states = self.mixer(hidden_states, **mixer_kwargs)
            if mixer_subset is not None:
                residual = residual[:, mixer_subset]
            if not isinstance(self.mlp, nn.Identity):
                # Dropout -> Add -> LN, then the MLP.
                dropped = self.drop_path2(self.dropout2(hidden_states))
                residual = (dropped + residual) if residual is not None else dropped
                hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)

                hidden_states = self.mlp(hidden_states)
            return hidden_states, residual
        else:
            assert residual is None
            mixer_out = self.mixer(
                hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
            )
            if self.return_residual:  # mixer out is actually a pair here
                mixer_out, hidden_states = mixer_out

            hidden_states = self.norm1((self.drop_path1(self.dropout1(mixer_out))
                                        + hidden_states).to(dtype=self.norm1.weight.dtype))

            if not isinstance(self.mlp, nn.Identity):
                mlp_out = self.mlp(hidden_states)
                if self.return_residual:  # mlp out is actually a pair here
                    mlp_out, hidden_states = mlp_out

                hidden_states = self.norm2((self.drop_path2(self.dropout2(mlp_out))
                                            + hidden_states).to(dtype=self.norm2.weight.dtype))

            return hidden_states

def create_mixer_cls(layer=None,
                     attn_layer_idx=None, attn_cfg=None, layer_idx=None,
                     device=None, dtype=None):
    """Return a zero-argument factory for the mixer of layer `layer_idx`.

    Args:
        layer: mapping of keyword arguments for HyenaOperator; required when
            the layer is not an attention layer.
        attn_layer_idx: collection of layer indices that use attention
            (None means no layer does).
        attn_cfg: optional mapping of keyword arguments for MHA; its 'causal'
            entry defaults to True. The mapping is not modified.
        layer_idx: index of the layer being built.
        device, dtype: forwarded to MHA.

    Returns:
        A `functools.partial` building either an MHA or a HyenaOperator.
    """
    factory_kwargs = {'device': device, 'dtype': dtype}
    if attn_layer_idx is not None and layer_idx in attn_layer_idx:
        # Work on a copy: the same attn_cfg is passed in for every layer, and
        # popping 'causal' from the caller's mapping would make every
        # attention layer after the first fall back to causal=True.
        attn_kwargs = dict(attn_cfg) if attn_cfg is not None else {}
        causal = attn_kwargs.pop('causal', True)

        mixer_cls = partial(MHA, causal=causal, layer_idx=layer_idx,
                            **attn_kwargs, **factory_kwargs)
    else:
        mixer_cls = partial(HyenaOperator, **layer)

    return mixer_cls

def create_mlp_cls(d_model, d_inner=None, device=None, dtype=None):
    """Return a factory for the block MLP.

    The hidden width is `d_inner`, or `4 * d_model` when `d_inner` is None, and
    the activation is the tanh approximation of GELU.
    """
    hidden_width = 4 * d_model if d_inner is None else d_inner
    tanh_gelu = partial(F.gelu, approximate='tanh')
    return partial(Mlp, hidden_features=hidden_width, activation=tanh_gelu,
                   device=device, dtype=dtype)


def create_block(d_model, d_inner=None,
                 layer=None, attn_layer_idx=None,
                 attn_cfg=None, layer_norm_epsilon=1e-5,
                 resid_dropout1=0.0, resid_dropout2=0.0, residual_in_fp32=False,
                 layer_idx=None,
                 device=None, dtype=None):
    """Assemble one prenorm Block (mixer + MLP + LayerNorms) and tag it with `layer_idx`."""
    mixer_factory = create_mixer_cls(
        layer=layer,
        attn_layer_idx=attn_layer_idx,
        attn_cfg=attn_cfg,
        layer_idx=layer_idx,
        device=device,
        dtype=dtype,
    )
    mlp_factory = create_mlp_cls(d_model, d_inner=d_inner, device=device, dtype=dtype)
    norm_factory = partial(nn.LayerNorm, eps=layer_norm_epsilon, device=device, dtype=dtype)

    block = Block(
        d_model,
        mixer_factory,
        mlp_factory,
        norm_cls=norm_factory,
        prenorm=True,
        resid_dropout1=resid_dropout1,
        resid_dropout2=resid_dropout2,
        residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx
    return block


# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True,
                  glu_act=False):
    """Initialise `module` in place (meant to be used with `nn.Module.apply`).

    Linear and Embedding weights are drawn from N(0, initializer_range^2) and
    Linear biases are zeroed. When `rescale_prenorm_residual` is set, the
    parameters named "out_proj.weight", "fc2.weight" and
    "output_linear.0.weight" (relative to `module`) are re-drawn with the
    standard deviation divided by sqrt(2 * n_layer).
    """
    if isinstance(module, (nn.Linear, nn.Embedding)):
        nn.init.normal_(module.weight, std=initializer_range)
    if isinstance(module, nn.Linear) and module.bias is not None:
        nn.init.zeros_(module.bias)

    if not rescale_prenorm_residual:
        return

    # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
    #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
    #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
    #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
    #
    # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
    def scaled_std():
        # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
        return initializer_range / math.sqrt(2 * n_layer)

    for name, p in module.named_parameters():
        if name in ("out_proj.weight", "fc2.weight"):
            nn.init.normal_(p, mean=0.0, std=scaled_std())
        elif name == "output_linear.0.weight":
            if glu_act:
                # With a GLU activation only the first half of the rows is
                # re-drawn, with a doubled std since sigmoid scales it down
                # by 0.5 on average.
                half_rows = p.shape[0] // 2
                nn.init.normal_(p[:half_rows], mean=0.0, std=scaled_std() * 2)
            else:
                nn.init.normal_(p, mean=0.0, std=scaled_std())
#@title Backbone model (stack of blocks)

"""
A backbone model consists of a stack of blocks. If you use attention, positional
embeddings are included. When using Hyena, max_position_embeddings is 0, so no
positional embeddings are added.
"""

class GPT2Embeddings(nn.Module):

    def __init__(self, embed_dim, vocab_size, max_position_embeddings, padding_idx=None,
                 word_embed_proj_dim=None, device=None, dtype=None):
        """
            Token embeddings with optional learned position embeddings.

            If max_position_embeddings <= 0, there are no position embeddings.
            If word_embed_proj_dim is not None (e.g., OPT-350m), tokens are embedded to
                that dimension and then projected up to embed_dim.
        """
        factory_kwargs = dict(device=device, dtype=dtype)
        super().__init__()

        token_dim = embed_dim if word_embed_proj_dim is None else word_embed_proj_dim
        self.word_embeddings = nn.Embedding(vocab_size, token_dim, padding_idx=padding_idx,
                                            **factory_kwargs)
        if word_embed_proj_dim is None:
            self.project_in = None
        else:
            self.project_in = nn.Linear(word_embed_proj_dim, embed_dim, bias=False,
                                        **factory_kwargs)

        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
            self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim,
                                                    **factory_kwargs)

    def forward(self, input_ids, position_ids=None):
        """
            input_ids: (batch, seqlen)
            position_ids: (batch, seqlen); defaults to 0..seqlen-1 when position
                embeddings are enabled
        """
        _, seqlen = input_ids.shape
        embeddings = self.word_embeddings(input_ids)
        if self.project_in is not None:
            embeddings = self.project_in(embeddings)

        # Without position embeddings the token embeddings are the result.
        if self.max_position_embeddings <= 0:
            return embeddings

        if position_ids is None:
            position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
        return embeddings + self.position_embeddings(position_ids)

class LMBackbone(nn.Module):
    """Embedding layer, a stack of ``n_layer`` blocks, and a final LayerNorm.

    Each block (built by ``create_block``) is called as
    ``block(hidden_states, residual)`` and returns the updated pair.  After the
    last block the hidden states go through dropout, are added to the running
    residual, and are normalised by ``ln_f``.
    """

    def __init__(self, d_model: int, n_layer: int, d_inner: int, vocab_size: int,
                 process_group=None, layer=None,
                 attn_layer_idx=None, attn_cfg=None, max_position_embeddings=0,
                 resid_dropout: float = 0.0, embed_dropout: float = 0.1,
                 layer_norm_epsilon: float = 1e-5, initializer_cfg=None,residual_in_fp32=False,
                 device=None, dtype=None, **kwargs) -> None:
        """Create the embeddings, the blocks and the final norm, then initialise weights.

        ``embed_dropout`` is used as the first residual dropout of block 0 only;
        every other dropout uses ``resid_dropout``.  Extra ``**kwargs`` are
        accepted and ignored.
        """
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.process_group = process_group
        self.residual_in_fp32 = residual_in_fp32

        # note max_position_embeddings is 0 for Hyena, and therefore isn't used
        self.embeddings = GPT2Embeddings(
            d_model, vocab_size, max_position_embeddings, **factory_kwargs
        )

        blocks = []
        for layer_idx in range(n_layer):
            first_dropout = embed_dropout if layer_idx == 0 else resid_dropout
            blocks.append(create_block(
                d_model,
                d_inner=d_inner,
                layer=layer,
                attn_layer_idx=attn_layer_idx,
                attn_cfg=attn_cfg,
                layer_norm_epsilon=layer_norm_epsilon,
                resid_dropout1=first_dropout,
                resid_dropout2=resid_dropout,
                residual_in_fp32=residual_in_fp32,
                layer_idx=layer_idx,
                **factory_kwargs,
            ))
        self.layers = nn.ModuleList(blocks)

        self.drop_f = nn.Dropout(resid_dropout)
        self.ln_f = nn.LayerNorm(d_model, eps=layer_norm_epsilon, **factory_kwargs)

        init_overrides = {} if initializer_cfg is None else initializer_cfg
        self.apply(partial(_init_weights, n_layer=n_layer, **init_overrides))

    def forward(self, input_ids, position_ids=None):
        """Return the final, layer-normalised hidden states for ``input_ids``."""
        hidden_states = self.embeddings(input_ids, position_ids=position_ids)
        residual = None

        for block in self.layers:
            hidden_states, residual = block(hidden_states, residual)

        # Final residual connection; with no blocks there is no residual to add.
        dropped = self.drop_f(hidden_states)
        if residual is None:
            residual = dropped
        else:
            residual = dropped + residual

        # Cast to the norm's parameter dtype before normalising.
        return self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))


#@title Decoder head layer

"""
A simple decoder head (using MLP) to predict a sequence level classification.
You have the option to average across all the tokens in a sequence or using the
"last" token to classify.  At least, those 2 worked best for us, but we provide
other "modes" as well.

We only need this for classification.  Otherwise we'll use the hidden
states of the backbone as embeddings.

"""


class SequenceDecoder(nn.Module):
    """Reduce a sequence of hidden states to a few positions and project them.

    The sequence axis (second to last) is restricted to ``l_output`` positions
    according to ``mode`` and the result is passed through ``output_transform``
    (a linear layer to ``d_output``, or the identity when ``d_output`` is None).

    Modes:
        "last"   – keep the last ``l_output`` positions.
        "first"  – keep the first ``l_output`` positions.
        "pool"   – running mean of the sequence, evaluated at the last
                   ``l_output`` positions (``l_output == 1`` is the plain mean).
        "sum"    – running sum, evaluated at the last ``l_output`` positions.
        "ragged" – keep positions up to ``max(lengths)``.
    """

    def __init__(
        self, d_model, d_output=None, l_output=None, use_lengths=False, mode="last"
    ):
        """Configure the decoder.

        Args:
            d_model: Feature size of the incoming hidden states.
            d_output: Output feature size; None keeps the features unchanged.
            l_output: Number of output positions.  None means "decide at call
                time"; 0 means one position that is squeezed away.
            use_lengths: If True, ``forward`` trims every sequence to its own
                length before restricting (not allowed in "ragged" mode).
            mode: One of the modes listed on the class.
        """
        super().__init__()

        self.output_transform = nn.Identity() if d_output is None else nn.Linear(d_model, d_output)

        if l_output is None:
            self.l_output = None
            self.squeeze = False
        elif l_output == 0:
            # Equivalent to getting an output of length 1 and then squeezing
            self.l_output = 1
            self.squeeze = True
        else:
            assert l_output > 0
            self.l_output = l_output
            self.squeeze = False

        self.use_lengths = use_lengths
        self.mode = mode

        if mode == 'ragged':
            assert not use_lengths

    def forward(self, x, state=None, lengths=None, l_output=None):
        """Restrict ``x`` along the sequence axis and apply the output transform.

        Args:
            x: (n_batch, l_seq, d_model)
            state: Unused; accepted for interface compatibility.
            lengths: Per-sequence lengths; required when ``use_lengths`` is set
                or in "ragged" mode.
            l_output: Call-time output length, only honoured when the module was
                built with ``l_output=None``.

        Returns:
            (n_batch, l_output, d_output), or (n_batch, d_output) when the
            module was built with ``l_output=0``.
        """

        if self.l_output is None:
            if l_output is not None:
                assert isinstance(l_output, int)  # Override by pass in
            else:
                # Grab entire output
                l_output = x.size(-2)
            squeeze = False
        else:
            l_output = self.l_output
            squeeze = self.squeeze

        if self.mode == "last":
            def restrict(x):
                return x[..., -l_output:, :]
        elif self.mode == "first":
            def restrict(x):
                return x[..., :l_output, :]
        elif self.mode == "pool":
            def restrict(x):
                # Running mean at the last l_output positions, computed without
                # materialising the full cumulative sum: start from the total
                # sum and subtract the sums of the trailing elements.
                L = x.size(-2)
                s = x.sum(dim=-2, keepdim=True)
                if l_output > 1:
                    c = torch.cumsum(x[..., -(l_output - 1):, :].flip(-2), dim=-2)
                    c = nn.functional.pad(c, (0, 0, 1, 0))
                    s = (s - c).flip(-2)  # (B, l_output, D) prefix sums
                denom = torch.arange(
                    L - l_output + 1, L + 1, dtype=x.dtype, device=x.device
                )
                # denom indexes the sequence axis, so it needs a trailing feature
                # axis; without it the division broadcasts over features.
                return s / denom.unsqueeze(-1)
        elif self.mode == "sum":
            # TODO use same restrict function as pool case
            def restrict(x):
                return torch.cumsum(x, dim=-2)[..., -l_output:, :]
        elif self.mode == 'ragged':
            assert lengths is not None, "lengths must be provided for ragged mode"

            # remove any additional padding (beyond max length of any sequence in the batch)
            def restrict(x):
                return x[..., : max(lengths), :]
        else:
            raise NotImplementedError(
                "Mode must be ['last' | 'first' | 'pool' | 'sum' | 'ragged']"
            )

        # Restrict to actual length of sequence
        if self.use_lengths:
            assert lengths is not None
            x = torch.stack(
                [
                    restrict(out[..., :length, :])
                    for out, length in zip(torch.unbind(x, dim=0), lengths)
                ],
                dim=0,
            )
        else:
            x = restrict(x)

        if squeeze:
            assert x.size(-2) == 1
            x = x.squeeze(-2)

        x = self.output_transform(x)

        return x

    def step(self, x, state=None):
        """Apply only the output transform to ``x`` (all length logic is ignored)."""
        return self.output_transform(x)


class ConvLinearModel(nn.Module):
    """Regression head: average pooling -> 1D convolution -> dropout -> linear -> Softplus.

    The input is expected as (batch, seq_length, depth).  It is pooled with a 2D
    average pool over the (depth, sequence) plane, convolved along the sequence
    axis, flattened, and mapped to ``output_length`` non-negative values.

    NOTE: earlier versions of this head re-laid the input out with
    ``x.view(B, 1, depth, seq)``, which reinterprets memory instead of swapping
    the axes, so the pooling mixed channels rather than averaging along the
    sequence.  The axes are now transposed explicitly; weights trained with the
    old layout will not give the same outputs and should be retrained.
    """

    def __init__(self, input_depth, input_length, output_length,avg_pool_size, kernel_size, stride, padding, dilation, dropout_rate, hidden_depth):
        """Build the head.

        Args:
            input_depth: Feature size of the incoming hidden states.
            input_length: Sequence length of the incoming hidden states.
            output_length: Number of values predicted per sequence.
            avg_pool_size: ``(depth_axis, seqlen_axis)`` pooling window; a single
                int is used for both axes.
            kernel_size, stride, padding, dilation: ``nn.Conv1d`` arguments.
            dropout_rate: Dropout probability applied after the convolution.
            hidden_depth: Number of convolution output channels.
        """
        super().__init__()
        if isinstance(avg_pool_size, int):
            avg_pool_size = (avg_pool_size, avg_pool_size)

        # 2D Average Pooling
        self.avgpool2d = nn.AvgPool2d(kernel_size=avg_pool_size)

        # Dropout
        self.dropout = nn.Dropout(dropout_rate)

        # Adjust the input depth and sequence length based on the 2D pooling
        new_depth = input_depth // avg_pool_size[0]
        new_seq_length = input_length // avg_pool_size[1]

        # 1D Convolution
        self.conv1 = nn.Conv1d(in_channels=new_depth, out_channels=hidden_depth, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation)

        # Calculate the output length after convolution
        self.seq_length_after_conv = self.calculate_output_length(new_seq_length, kernel_size, stride, padding, dilation)

        # Fully connected layer with calculated input size
        self.fc = nn.Linear(hidden_depth * self.seq_length_after_conv, output_length)
        self.softplus = nn.Softplus()

    def calculate_output_length(self, input_length, kernel_size, stride, padding, dilation):
        """Return the output length of a 1D convolution with the given settings."""
        return ((input_length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride) + 1

    def forward(self, x):
        """Map hidden states of shape (batch, seq_length, depth) to (batch, output_length)."""
        # Move features in front of the sequence axis and add the single channel
        # AvgPool2d expects: (batch, 1, depth, seq_length).  A real transpose is
        # required here; a plain view would scramble the two axes.
        x = x.transpose(1, 2).unsqueeze(1)
        x = self.avgpool2d(x)

        # Drop the dummy channel: (batch, new_depth, new_seq_length)
        x = x.squeeze(1)

        # Apply 1D convolution
        x = self.conv1(x)

        # Apply dropout to the inputs of the final layer
        x = self.dropout(x)

        # Flatten channels and positions for the linear layer
        x = x.flatten(1)

        # Pass through the fully connected layer
        x = self.fc(x)

        # Softplus keeps the predictions non-negative
        return self.softplus(x)


#@title Model (backbone + head)


"""
Putting it all together, the model consists of a backbone model
and a decoder head (you can turn off head for embeddings only too).

Here we use a simple head to do multi-classification, but
can also swap the head to do next token prediction too.  We defer to the main
HyenaDNA for that code, since pretraining with next token prediction isn't quite
feasible on colab.

"""

class HyenaDNAModel(nn.Module):
    """HyenaDNA backbone with an optional convolutional regression head.

    With ``use_head=False`` the forward pass returns the backbone's hidden
    states; with ``use_head=True`` those hidden states are fed to a
    ``ConvLinearModel`` head whose hyper-parameters are fixed in ``__init__``.
    """

    def __init__(self, d_model: int, n_layer: int, d_inner: int, vocab_size: int,
                 layer=None, attn_layer_idx=None, attn_cfg=None, max_position_embeddings=0,
                 resid_dropout: float = 0.0, embed_dropout: float = 0.1,
                 layer_norm_epsilon: float = 1e-5, initializer_cfg=None,residual_in_fp32=False,
                 pad_vocab_size_multiple: int = 1, use_head=False, n_classes: int = 2,
                 device=None, dtype=None, **kwargs) -> None:
        """Build the backbone and, optionally, the head.

        Args:
            d_model, n_layer, d_inner, vocab_size: Backbone dimensions.
                ``vocab_size`` is rounded up to a multiple of
                ``pad_vocab_size_multiple``.
            layer: Mixer configuration mapping; ``d_model`` is inserted into it
                (in place) when missing.  Despite the ``None`` default a mapping
                is required, because membership is tested on it.
            use_head: Attach the regression head.
            n_classes: Accepted for interface compatibility; not used by the
                current head.
            **kwargs: Forwarded to ``LMBackbone``.

        The remaining arguments are forwarded unchanged to ``LMBackbone``.
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        if vocab_size % pad_vocab_size_multiple != 0:
            vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)

        self.use_head = use_head

        # check if layer (config) has d_model (HF code differs from main Safari code)
        if 'd_model' not in layer:
            layer['d_model'] = d_model

        self.backbone = LMBackbone(
            d_model=d_model, n_layer=n_layer, d_inner=d_inner, vocab_size=vocab_size,
            layer=layer, attn_layer_idx=attn_layer_idx, attn_cfg=attn_cfg,
            max_position_embeddings=max_position_embeddings,
            resid_dropout=resid_dropout, embed_dropout=embed_dropout,
            layer_norm_epsilon=layer_norm_epsilon,
            initializer_cfg=initializer_cfg, residual_in_fp32=residual_in_fp32,
            **factory_kwargs, **kwargs
        )

        # The head is only needed for prediction; without it the hidden states
        # are returned as embeddings.
        if self.use_head:
            # The head consumes the backbone output, whose feature size is
            # d_model (previously hard-coded to 256).
            input_depth = d_model
            input_length = 32768  # an earlier setup used 160000
            output_length = 15 * 2 + 1  # one condition; an earlier setup used (57*2+1)*2
            avg_pool_size = (1, 128)  # (depth_axis, seqlen_axis)
            kernel_size = 9
            dilation = 2
            stride = 2
            padding = (dilation * (kernel_size - 1)) // 2
            dropout_rate = .3
            hidden_depth = 10

            self.head = ConvLinearModel(input_depth, input_length, output_length,
                                        avg_pool_size, kernel_size, stride, padding,
                                        dilation, dropout_rate, hidden_depth)

        # Initialize weights and apply final processing
        self.apply(partial(_init_weights, n_layer=n_layer,
                           **(initializer_cfg if initializer_cfg is not None else {})))

    def forward(self, input_ids, position_ids=None, state=None): # state for the repo interface
        """Return head predictions when ``use_head`` is set, else the hidden states."""
        hidden_states = self.backbone(input_ids, position_ids=position_ids)
        if self.use_head:
            return self.head(hidden_states)
        return hidden_states


#@title Huggingface Pretrained Wrapper
# for Huggingface integration, we use a wrapper class around the model
# to load weights
import json
import os
import subprocess
import transformers
from transformers import PreTrainedModel, AutoModelForCausalLM, PretrainedConfig
import re

def inject_substring(orig_str):
    """Insert ``.layer`` after every ``.mixer`` and ``.mlp`` in a state-dict key.

    Hack to handle matching keys between models trained with and without
    gradient checkpointing.
    """
    patched = orig_str
    # mixer keys first, then mlp keys
    for submodule in (".mixer", ".mlp"):
        patched = patched.replace(submodule, submodule + ".layer")
    return patched

def load_weights(scratch_dict, pretrained_dict, checkpointing=False):
    """Loads pretrained (backbone only) weights into the scratch state dict.

    Every key of ``scratch_dict`` that contains ``'backbone'`` is looked up in
    ``pretrained_dict`` under the same name prefixed with ``'model.'`` and its
    value is replaced in place.  All other keys (e.g. the head) are left
    untouched.

    scratch_dict: dict, a state dict from a newly initialized HyenaDNA model
    pretrained_dict: dict, a state dict from the pretrained ckpt
    checkpointing: bool, whether the gradient checkpoint flag was used in the
    pretrained model ckpt. This slightly changes state dict keys, so we patch
    that if used.

    return:
    dict, the same ``scratch_dict`` object with the pretrained weights loaded
    (head is scratch)

    raises:
    KeyError, if a backbone key has no counterpart in ``pretrained_dict``
    """
    # Only values of existing keys are reassigned, so iterating is safe.
    for key in scratch_dict:
        if 'backbone' not in key:
            continue

        # the state dicts differ by one prefix, 'model.', so we add that
        key_loaded = 'model.' + key
        # need to add an extra ".layer" in key
        if checkpointing:
            key_loaded = inject_substring(key_loaded)

        try:
            scratch_dict[key] = pretrained_dict[key_loaded]
        except KeyError as err:
            raise KeyError(
                'key mismatch in the state dicts! '
                f'{key!r} expects {key_loaded!r} in the pretrained checkpoint'
            ) from err

    # scratch_dict has been updated
    return scratch_dict

def download_files(url, output_dir):
    """Download ``url`` into ``output_dir`` and return the local file path.

    The file is named after the last path component of ``url``.  The output
    directory is created if needed.

    Raises:
        ValueError: if the server does not answer with HTTP 200.
    """
    # Create the output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Get the filename from the URL
    filename = os.path.join(output_dir, url.split('/')[-1])

    # The context manager releases the connection even if writing fails.
    with requests.get(url, stream=True) as response:
        if response.status_code != 200:
            raise ValueError(f"Failed to download files from {url}")

        # iter_content() undoes gzip/deflate content encoding, which copying
        # from response.raw would not, and keeps memory use bounded.
        with open(filename, 'wb') as f:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                f.write(chunk)

    return filename

def download_pretrained_model(model_name, output_dir):
    """Fetch ``config.json`` and ``weights.ckpt`` of a LongSafari model.

    Both files are downloaded from the Hugging Face hub into ``output_dir``.

    Returns:
        (config, weights_file): the parsed config dict and the local path of
        the downloaded checkpoint.
    """
    hf_url = f'https://huggingface.co/LongSafari/{model_name}/resolve/main/config.json'
    config_file = download_files(hf_url, output_dir)
    # Parse the config before fetching the (much larger) weights, and close the
    # file handle deterministically.
    with open(config_file) as f:
        config = json.load(f)

    weights_url = f'https://huggingface.co/LongSafari/{model_name}/resolve/main/weights.ckpt'
    weights_file = download_files(weights_url, output_dir)

    return config, weights_file


class HyenaDNAPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.

    In practice only the ``from_pretrained`` classmethod is used; it returns a
    plain ``HyenaDNAModel`` rather than an instance of this class.
    """
    base_model_prefix = "hyenadna"

    def __init__(self, config):
        # Intentionally empty: the base-class initialiser is not called.
        pass

    def forward(self, input_ids, **kwargs):
        # NOTE(review): self.model is never assigned in this class — confirm
        # whether this method is reachable.
        return self.model(input_ids, **kwargs)

    @classmethod
    def from_pretrained(cls,
                        path,
                        model_name,
                        download=False,
                        config=None,
                        device='cpu',
                        use_head=False,
                        n_classes=2,
                      ):
        """Build a ``HyenaDNAModel`` and load pretrained backbone weights into it.

        Args:
            path: Directory that holds (or will hold) the model folder.
            model_name: Name of the LongSafari checkpoint / sub-folder of ``path``.
            download: Force a fresh download even if the folder exists.
            config: Optional model config dict; when None the ``config.json``
                of the checkpoint is used.
            device: ``map_location`` for loading the checkpoint.
            use_head, n_classes: Forwarded to ``HyenaDNAModel``.

        Returns:
            The freshly built model with pretrained backbone weights; the head
            (if any) keeps its random initialisation.
        """
        # first check if it is a local path
        pretrained_model_name_or_path = os.path.join(path, model_name)
        use_local_copy = os.path.isdir(pretrained_model_name_or_path) and not download

        if use_local_copy:
            if config is None:
                config_path = os.path.join(pretrained_model_name_or_path, 'config.json')
                with open(config_path) as f:
                    config = json.load(f)
        else:
            downloaded_config, _ = download_pretrained_model(
                model_name, pretrained_model_name_or_path
            )
            # A config passed by the caller takes precedence over the
            # downloaded one, as it does for a local copy.
            if config is None:
                config = downloaded_config

        scratch_model = HyenaDNAModel(**config, use_head=use_head, n_classes=n_classes)  # the new model format

        # NOTE: torch.load unpickles the checkpoint; only load files from a
        # trusted source.
        loaded_ckpt = torch.load(
            os.path.join(pretrained_model_name_or_path, 'weights.ckpt'),
            map_location=torch.device(device)
        )

        # need to load weights slightly different if using gradient checkpointing
        checkpointing = config.get("checkpoint_mixer", False) == True  # noqa: E712

        # grab state dict from both and load weights
        state_dict = load_weights(scratch_model.state_dict(), loaded_ckpt['state_dict'], checkpointing=checkpointing)

        # scratch model has now been updated
        scratch_model.load_state_dict(state_dict)
        print("Loaded pretrained weights ok!")
        return scratch_model

#@title Tokenizer

"""
Just a simple character level tokenizer.

From: https://github.com/dariush-bahrami/character-tokenizer/blob/master/charactertokenizer/core.py

CharacterTokenizer for Hugging Face Transformers.
This is heavily inspired from CanineTokenizer in transformers package.
"""
import json
import os
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Union

from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer

class CharacterTokenizer(PreTrainedTokenizer):
    """Character-level tokenizer: every character of the input is one token."""

    def __init__(self, characters: Sequence[str], model_max_length: int, padding_side: str='left', **kwargs):
        """Character tokenizer for Hugging Face transformers.
        Args:
            characters (Sequence[str]): List of desired characters. Any character which
                is not included in this list will be replaced by a special token called
                [UNK] with id=6. Following are list of all of the special tokens with
                their corresponding ids:
                    "[CLS]": 0
                    "[SEP]": 1
                    "[BOS]": 2
                    "[MASK]": 3
                    "[PAD]": 4
                    "[RESERVED]": 5
                    "[UNK]": 6
                an id (starting at 7) will be assigned to each character.
            model_max_length (int): Model maximum sequence length.
            padding_side (str): Side on which padding is added.
        """
        self.characters = characters
        self.model_max_length = model_max_length

        # Build the vocabulary before calling the base initialiser so that the
        # lookup tables already exist if the base class queries them.
        self._vocab_str_to_int = {
            "[CLS]": 0,
            "[SEP]": 1,
            "[BOS]": 2,
            "[MASK]": 3,
            "[PAD]": 4,
            "[RESERVED]": 5,
            "[UNK]": 6,
            **{ch: i + 7 for i, ch in enumerate(characters)},
        }
        self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}

        bos_token = AddedToken("[BOS]", lstrip=False, rstrip=False)
        sep_token = AddedToken("[SEP]", lstrip=False, rstrip=False)
        cls_token = AddedToken("[CLS]", lstrip=False, rstrip=False)
        pad_token = AddedToken("[PAD]", lstrip=False, rstrip=False)
        unk_token = AddedToken("[UNK]", lstrip=False, rstrip=False)

        mask_token = AddedToken("[MASK]", lstrip=True, rstrip=False)

        super().__init__(
            bos_token=bos_token,
            eos_token=sep_token,  # [SEP] doubles as the end-of-sequence token
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            unk_token=unk_token,
            add_prefix_space=False,
            model_max_length=model_max_length,
            padding_side=padding_side,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        """Number of entries in the vocabulary (special tokens + characters)."""
        return len(self._vocab_str_to_int)

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the token -> id mapping."""
        return dict(self._vocab_str_to_int)

    def _tokenize(self, text: str) -> List[str]:
        """Split ``text`` into its individual characters."""
        return list(text)

    def _convert_token_to_id(self, token: str) -> int:
        """Return the id of ``token``, or the [UNK] id for unknown tokens."""
        return self._vocab_str_to_int.get(token, self._vocab_str_to_int["[UNK]"])

    def _convert_id_to_token(self, index: int) -> str:
        """Return the token for ``index``; raises KeyError for unknown ids."""
        return self._vocab_int_to_str[index]

    def convert_tokens_to_string(self, tokens):
        """Concatenate ``tokens`` back into a string."""
        return "".join(tokens)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Wrap one or two sequences as ``[CLS] A [SEP]`` / ``[CLS] A [SEP] B [SEP]``."""
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        result = cls + token_ids_0 + sep
        if token_ids_1 is not None:
            result += token_ids_1 + sep
        return result

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """Return a mask with 1 at special-token positions and 0 elsewhere.

        The layout matches ``build_inputs_with_special_tokens`` unless
        ``already_has_special_tokens`` is set, in which case the base-class
        implementation is used.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )

        result = [1] + ([0] * len(token_ids_0)) + [1]
        if token_ids_1 is not None:
            result += ([0] * len(token_ids_1)) + [1]
        return result

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Return segment ids: 0 for ``[CLS] A [SEP]``, 1 for ``B [SEP]``."""
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        result = len(cls + token_ids_0 + sep) * [0]
        if token_ids_1 is not None:
            result += len(token_ids_1 + sep) * [1]
        return result

    def get_config(self) -> Dict:
        """Return the JSON-serialisable settings needed to rebuild the tokenizer."""
        return {
            "char_ords": [ord(ch) for ch in self.characters],
            "model_max_length": self.model_max_length,
        }

    @classmethod
    def from_config(cls, config: Dict) -> "CharacterTokenizer":
        """Rebuild a tokenizer from the dict produced by ``get_config``."""
        cfg = {}
        cfg["characters"] = [chr(i) for i in config["char_ords"]]
        cfg["model_max_length"] = config["model_max_length"]
        return cls(**cfg)

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        """Write ``tokenizer_config.json`` into ``save_directory``."""
        cfg_file = Path(save_directory) / "tokenizer_config.json"
        cfg = self.get_config()
        with open(cfg_file, "w") as f:
            json.dump(cfg, f, indent=4)

    @classmethod
    def from_pretrained(cls, save_directory: Union[str, os.PathLike], **kwargs):
        """Load a tokenizer from ``tokenizer_config.json`` in ``save_directory``."""
        cfg_file = Path(save_directory) / "tokenizer_config.json"
        with open(cfg_file) as f:
            cfg = json.load(f)
        return cls.from_config(cfg)


#@title GenomicBenchmark dataset

"""
The GenomicBenchmarks dataset will automatically download to /contents on colab.
There are 8 datasets to choose from.

"""

from random import random
import numpy as np
from pathlib import Path
from torch.utils.data import DataLoader

#from genomic_benchmarks.loc2seq import download_dataset
#from genomic_benchmarks.data_check import is_downloaded


# helper functions
def exists(val):
    """Return True for any value other than None."""
    if val is None:
        return False
    return True

def coin_flip():
    """Return True when a single uniform draw from [0, 1) exceeds 0.5."""
    draw = random()
    return 0.5 < draw


string_complement_map = {'A': 'T', 'C': 'G', 'G': 'C', 'T': 'A', 'a': 't', 'c': 'g', 'g': 'c', 't': 'a'}
# augmentation
def string_reverse_complement(seq):
    """Return the reverse complement of a nucleotide sequence.

    The sequence is reversed and every base found in ``string_complement_map``
    is swapped for its complement (case is preserved).  Any other character,
    e.g. ``N``, is kept as is.
    """
    # join() builds the result in one pass instead of repeated concatenation,
    # which matters for the long sequences used here.
    return ''.join(string_complement_map.get(base, base) for base in seq[::-1])


from random import random
import numpy as np
from pathlib import Path
from torch.utils.data import DataLoader

from genomic_benchmarks.loc2seq import download_dataset
from genomic_benchmarks.data_check import is_downloaded


class GenomicBenchmarkDataset(torch.utils.data.Dataset):

    '''
    Classification dataset backed by a Genomic Benchmarks download.

    The split directory ``dest_path/dataset_name/split`` is expected to contain
    one sub-directory per class, each holding one text file per sequence.
    Items are ``(token_ids, label)`` pairs of ``LongTensor``s.

    Genomic Benchmarks Dataset, from:
    https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks


    '''

    def __init__(
        self,
        split,
        max_length,
        dataset_name='human_enhancers_cohn',
        d_output=2, # default binary classification
        dest_path="./data/", # default for colab
        tokenizer=None,
        tokenizer_name=None,
        use_padding=None,
        add_eos=False,
        rc_aug=False,
        return_augs=False,
    ):
        '''
        Download the dataset if necessary and index the files of ``split``.

        split: name of the split sub-directory to read
        max_length: maximum number of tokens per sequence (longer ones are truncated)
        dataset_name: name of the Genomic Benchmarks dataset
        d_output: number of classes, stored for the decoder to read
        dest_path: directory the dataset is downloaded to / read from
        tokenizer: callable returning a mapping with an "input_ids" list
        use_padding: pad every sequence to ``max_length`` when truthy
        add_eos: append the tokenizer's ``sep_token_id`` to every sequence
        rc_aug: randomly replace a sequence by its reverse complement
        tokenizer_name, return_augs: stored, not used by this class
        '''

        self.max_length = max_length
        self.use_padding = use_padding
        self.tokenizer_name = tokenizer_name
        self.tokenizer = tokenizer
        self.return_augs = return_augs
        self.add_eos = add_eos
        self.d_output = d_output  # needed for decoder to grab
        self.rc_aug = rc_aug

        if not is_downloaded(dataset_name, cache_path=dest_path):
            print("downloading {} to {}".format(dataset_name, dest_path))
            download_dataset(dataset_name, version=0, dest_path=dest_path)
        else:
            print("already downloaded {}-{}".format(split, dataset_name))

        # use Path object
        base_path = Path(dest_path) / dataset_name / split

        self.all_paths = []
        self.all_labels = []

        # iterdir() yields entries in arbitrary order; sort so that a class
        # always gets the same integer label in every split and on every machine.
        label_mapper = {}
        for i, x in enumerate(sorted(base_path.iterdir())):
            label_mapper[x.stem] = i

        for label_type, label in label_mapper.items():
            for x in sorted((base_path / label_type).iterdir()):
                self.all_paths.append(x)
                self.all_labels.append(label)

    def __len__(self):
        return len(self.all_paths)

    def __getitem__(self, idx):
        '''Return ``(token_ids, label)`` for the ``idx``-th sequence file.'''
        txt_path = self.all_paths[idx]
        with open(txt_path, "r") as f:
            x = f.read()
        y = self.all_labels[idx]

        # apply rc_aug here if using
        if self.rc_aug and coin_flip():
            x = string_reverse_complement(x)

        # No special tokens are added here; an optional eos token is appended below.
        seq = self.tokenizer(x,
            add_special_tokens=False,
            padding="max_length" if self.use_padding else None,
            max_length=self.max_length,
            truncation=True,
        )
        seq = seq["input_ids"]  # get input_ids

        # need to handle eos here
        if self.add_eos:
            # append list seems to be faster than append tensor
            seq.append(self.tokenizer.sep_token_id)

        # convert to tensor
        seq = torch.LongTensor(seq)

        # need to wrap in list
        target = torch.LongTensor([y])

        return seq, target

############################## Our Datasets ###############

import torch
import pandas as pd
from Bio import SeqIO
import zlib
import sys


class CLASTER_HYENA_Dataset(torch.utils.data.Dataset):

    '''
    Regression dataset pairing FASTA sequences with continuous target vectors.

    Every FASTA record is stored twice: as is under its ID, and as reverse
    complement under ``<ID>_flipped``.  Both are centre-cropped.  Rows of the
    target CSV (indexed by its ``ID`` column) whose ID has a stored sequence
    make up the dataset.  Items are ``(token_ids, target)`` pairs.
    '''

    def __init__(
        self,
        fasta_file,
        csv_file,
        max_length,
        target_columns,
        d_output=None,
        tokenizer=None,
        tokenizer_name=None,
        use_padding=None,
        add_eos=False,
        rc_aug=False,
        return_augs=False,
        sequence_length=160000,
        crop_length=32768,
    ):
        '''
        fasta_file: FASTA file; the sequence ID is the part of the record
            description before the first "::"
        csv_file: CSV file with an ``ID`` column and the target columns
        max_length: maximum number of tokens per sequence (longer ones are truncated)
        target_columns: CSV columns that form the target vector
        d_output: target dimensionality; defaults to ``len(target_columns)``
        tokenizer: callable returning a mapping with an "input_ids" list
        use_padding: pad every sequence to ``max_length`` when truthy
        add_eos: append the tokenizer's ``sep_token_id`` to every sequence
        tokenizer_name, rc_aug, return_augs: stored, not used by this class
        sequence_length: nominal length of the FASTA sequences
        crop_length: length of the central window to keep;
            ``(sequence_length - crop_length) // 2`` characters are removed
            from each end of every sequence
        '''

        self.max_length = max_length
        self.use_padding = use_padding
        self.tokenizer_name = tokenizer_name
        self.tokenizer = tokenizer
        self.return_augs = return_augs
        self.add_eos = add_eos
        self.rc_aug = rc_aug

        if crop_length > sequence_length:
            raise ValueError("crop_length must not exceed sequence_length")
        # Number of characters removed from each end of a sequence
        crop_dist = (sequence_length - crop_length) // 2

        # Read CSV file
        self.targets_df = pd.read_csv(csv_file).set_index('ID')

        # Extract continuous target vectors
        self.continuous_targets = self.targets_df[target_columns].values

        # Read FASTA file and store sequences and corresponding IDs in a dictionary
        self.sequences = {}

        for record in SeqIO.parse(fasta_file, "fasta"):
            # Keep only the part of the header before the first "::"
            seq_id = record.description.split("::")[0]
            sequence = str(record.seq)

            # Create reverse complement and store it with the "_flipped" identifier
            reverse_complement = string_reverse_complement(sequence)
            flipped_seq_id = seq_id + "_flipped"

            # Explicit end index so that a zero crop keeps the whole sequence
            end = len(sequence) - crop_dist
            self.sequences[seq_id] = sequence[crop_dist:end]
            self.sequences[flipped_seq_id] = reverse_complement[crop_dist:end]

        # Extract IDs from the target CSV file (assuming the ID column is named 'ID')
        self.target_ids = self.targets_df.index

        # Match sequence IDs with target IDs and store their indices
        self.matched_indices = [i for i, seq_id in enumerate(self.target_ids) if seq_id in self.sequences]

        # Determine the number of output dimensions if not provided
        if d_output is None:
            self.d_output = len(target_columns)
        else:
            self.d_output = d_output

    def __len__(self):
        return len(self.matched_indices)

    def __getitem__(self, idx):
        '''Return ``(token_ids, target)`` for the ``idx``-th matched CSV row.'''
        row = self.matched_indices[idx]

        # Get sequence ID and retrieve the sequence
        seq_id = self.target_ids[row]
        sequence = self.sequences[seq_id]

        # Get continuous target vector
        target = self.continuous_targets[row]

        # Tokenize sequence
        encoded_sequence = self.tokenizer(
            sequence,
            add_special_tokens=False,
            padding="max_length" if self.use_padding else None,
            max_length=self.max_length,
            truncation=True,
        )["input_ids"]    # Get input_ids

        # Need to handle eos here
        if self.add_eos:
            # Append list seems to be faster than append tensor
            encoded_sequence.append(self.tokenizer.sep_token_id)

        # Convert to tensor
        encoded_sequence = torch.LongTensor(encoded_sequence)

        # Convert target to tensor
        target = torch.FloatTensor(target)

        return encoded_sequence, target


import torch.optim as optim
import logging

# File that receives the messages logged through the root logger, e.g. the
# per-epoch losses reported by train() and test() below. The path is relative
# to the current working directory.
LOG_FILENAME = "../checkpoints/hyenadna_32k.log"
logging.basicConfig(filename=LOG_FILENAME, level=logging.INFO)  

"""
We provide simple training code for the GenomicBenchmark datasets.
"""

def train(model, device, train_loader, optimizer, epoch, loss_fn):
    """Run one training epoch, log the average loss and save the model.

    Args:
        model: Module to train; called as ``model(data)``.
        device: Device the batches are moved to.
        train_loader: Iterable of ``(data, target)`` batches exposing a
            ``dataset`` attribute with a length.
        optimizer: Optimizer stepping the model parameters.
        epoch: Epoch number, used for logging only.
        loss_fn: Callable ``loss_fn(output, target)`` returning a scalar tensor.

    The summed batch losses are divided by ``len(train_loader.dataset)``.
    NOTE(review): this is a per-sample average only if ``loss_fn`` returns a
    batch sum rather than a batch mean — confirm against the loss in use.

    Side effects: the whole model is pickled to a fixed checkpoint path after
    the epoch.
    """
    model.train()
    train_loss = 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target.squeeze())
        loss.backward()
        optimizer.step()
        # Reuse the loss already computed instead of evaluating it a second time.
        train_loss += loss.item()  # sum up batch loss

    train_loss /= len(train_loader.dataset)
    logging.info('Train Epoch: {}\tAverage Loss: {:.6f}'.format(
        epoch, train_loss))
    torch.save(model,'../checkpoints/hyenadna-small-32k-seqlen/model_32k.pt')

def test(model, device, test_loader, loss_fn):
    """Evaluate the model, log the average loss and dump targets/predictions.

    Args:
        model: Module to evaluate; called as ``model(data)``.
        device: Device the batches are moved to.
        test_loader: Iterable of ``(data, target)`` batches exposing a
            ``dataset`` attribute with a length.
        loss_fn: Callable ``loss_fn(output, target)`` returning a scalar tensor.

    The summed batch losses are divided by ``len(test_loader.dataset)``.
    NOTE(review): this is a per-sample average only if ``loss_fn`` returns a
    batch sum rather than a batch mean — confirm against the loss in use.

    Side effects: the flattened targets and predictions of the whole loader are
    written to ``Hyena_finetunned_targets.npy`` and
    ``Hyena_finetunned_predictions.npy`` in the current working directory.
    """
    model.eval()
    test_loss = 0

    # Collect per-batch arrays and concatenate once at the end instead of
    # re-copying the growing arrays on every batch.
    target_chunks = []
    prediction_chunks = []

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += loss_fn(output, target.squeeze()).item()  # sum up batch loss
            target_chunks.append(torch.flatten(target).cpu().numpy())
            prediction_chunks.append(torch.flatten(output).cpu().numpy())

    test_loss /= len(test_loader.dataset)

    # The leading empty float array keeps the saved dtype (float64) and makes
    # an empty loader produce empty arrays.
    targets = np.concatenate([np.array([]), *target_chunks])
    predictions = np.concatenate([np.array([]), *prediction_chunks])

    np.save("Hyena_finetunned_targets.npy",targets)
    np.save("Hyena_finetunned_predictions.npy",predictions)

    logging.info('\nTest set: Average loss: {:.6f}\t'.format(
        test_loss))

import json
import os
import subprocess
import transformers
from transformers import PreTrainedModel, AutoModelForCausalLM, PretrainedConfig
import nvidia_smi

def run_train():
    '''
    Main entry point for training.  Select the dataset name and metadata, as
    well as model and training args, and you're off to the genomic races!

    Builds the (pretrained) HyenaDNA classifier, the character tokenizer and
    the GenomicBenchmarks train/test datasets, then alternates one training
    epoch and one evaluation pass for ``num_epochs`` epochs.

    ### GenomicBenchmarks Metadata
    # there are 8 datasets in this suite, choose 1 at a time, with their corresponding settings
    # name                                num_seqs        num_classes     median len    std
    # dummy_mouse_enhancers_ensembl       1210            2               2381          984.4
    # demo_coding_vs_intergenomic_seqs    100_000         2               200           0
    # demo_human_or_worm                  100_000         2               200           0
    # human_enhancers_cohn                27791           2               500           0
    # human_enhancers_ensembl             154842          2               269           122.6
    # human_ensembl_regulatory            289061          3               401           184.3
    # human_nontata_promoters             36131           2               251           0
    # human_ocr_ensembl                   174756          2               315           108.1

    '''
    # experiment settings:
    num_epochs = 100  # ~100 seems fine
    max_length = 500  # max len of sequence of dataset (of what you want)
    use_padding = True
    dataset_name = 'human_enhancers_cohn'
    batch_size = 256
    learning_rate = 6e-4  # good default for Hyena
    rc_aug = True  # reverse complement augmentation
    add_eos = False  # add end of sentence token
    weight_decay = 0.1

    # for fine-tuning, only the 'tiny' model can fit on colab
    pretrained_model_name = 'hyenadna-tiny-1k-seqlen'  # use None if training from scratch

    # we need these for the decoder head, if using
    use_head = True
    n_classes = 2

    # you can override with your own backbone config here if you want,
    # otherwise we'll load the HF one by default
    backbone_cfg = None

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("Using device:", device)
    logging.info(f"Using device: {device}")

    # instantiate the model (pretrained here)
    if pretrained_model_name in ['hyenadna-tiny-1k-seqlen']:
        # use the pretrained Huggingface wrapper instead
        model = HyenaDNAPreTrainedModel.from_pretrained(
            './checkpoints',
            pretrained_model_name,
            download=False,
            config=backbone_cfg,
            device=device,
            use_head=use_head,
            n_classes=n_classes,
        )

    # from scratch (requires backbone_cfg to be set to a config dict above)
    else:
        model = HyenaDNAModel(**backbone_cfg, use_head=use_head, n_classes=n_classes)

    # create tokenizer
    tokenizer = CharacterTokenizer(
        characters=['A', 'C', 'G', 'T', 'N'],  # add DNA characters, N is uncertain
        model_max_length=max_length + 2,  # to account for special tokens, like EOS
        add_special_tokens=False,  # we handle special tokens elsewhere
        padding_side='left', # since HyenaDNA is causal, we pad on the left
    )

    # create datasets
    ds_train = GenomicBenchmarkDataset(
        max_length = max_length,
        use_padding = use_padding,
        split = 'train',
        tokenizer=tokenizer,
        dataset_name=dataset_name,
        rc_aug=rc_aug,
        add_eos=add_eos,
    )

    ds_test = GenomicBenchmarkDataset(
        max_length = max_length,
        use_padding = use_padding,
        split = 'test',
        tokenizer=tokenizer,
        dataset_name=dataset_name,
        rc_aug=rc_aug,
        add_eos=add_eos,
    )


    train_loader = DataLoader(ds_train, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(ds_test, batch_size=batch_size, shuffle=False)

    # loss function
    loss_fn = nn.CrossEntropyLoss()

    # create optimizer
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    model.to(device)

    # train() already steps the optimizer once per batch; an extra step here
    # would re-apply the (stale) gradients of the last batch every epoch.
    for epoch in range(num_epochs):
        train(model, device, train_loader, optimizer, epoch, loss_fn)
        test(model, device, test_loader, loss_fn)


def run_train_CLASTER_HYENA():

    '''
    Main entry point for training.  Select the dataset name and metadata, as
    well as model and training args, and you're off to the genomic races!

    Fine-tunes (or, with ``ONLY_TEST = True``, only evaluates) HyenaDNA with a
    regression head on the CLASTER FASTA/CSV data.  In test-only mode the
    model is read from the checkpoint written by ``train``.

    '''
    # experiment settings:
    ONLY_TEST = True
    num_epochs = 200 #50  # ~100 seems fine
    max_length = 32768 #160000  # max len of sequence of dataset (of what you want)
    use_padding = True
    batch_size = 16
    learning_rate = 5e-7  # Super small batch size -> lower learning rate
    rc_aug = True  # reverse complement augmentation
    add_eos = False  # add end of sentence token
    weight_decay = 0.1
    n_classes=2
    # create datasets

    train_fasta_file = "../inputs/DNA_sequences/training_boundaries.fasta"
    test_fasta_file = "../inputs/DNA_sequences/test_boundaries.fasta"
    train_csv_file = "../targets/training_targets.csv"
    test_csv_file = "../targets/test_targets.csv"

    #max_length = 160000
    N_BINS = 15 #57
    target_columns = [f"{i}{cond}" for cond in ["_ctrl"] for i in range(-N_BINS, N_BINS+1)]

    # for fine-tuning, only the 'tiny' model can fit on colab
    pretrained_model_name = 'hyenadna-small-32k-seqlen' #'hyenadna-medium-160k-seqlen'  # use None if training from scratch
    #model_name_or_path = './checkpoints/hyenadna-medium-160k-seqlen-custom/'
    # we need these for the decoder head, if using
    use_head = True

    # you can override with your own backbone config here if you want,
    # otherwise we'll load the HF one by default
    backbone_cfg = None

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("Using device:", device)


    #instantiate the model (pretrained here)
    if pretrained_model_name in ['hyenadna-small-32k-seqlen']: #['hyenadna-medium-160k-seqlen']:
        # use the pretrained Huggingface wrapper instead
        print('Here')
        model = HyenaDNAPreTrainedModel.from_pretrained(
            '../checkpoints/',
            pretrained_model_name,
            download=True,
            config=backbone_cfg,
            device=device,
            use_head=use_head,
            n_classes=n_classes,
        )

    # from scratch (requires backbone_cfg to be set to a config dict above)
    else:
        model = HyenaDNAModel(**backbone_cfg, use_head=use_head, n_classes=n_classes)

    # backbone_cfg = json.load(open(os.path.join(model_name_or_path, 'config.json')))
    # model = HyenaDNAModel(**backbone_cfg, use_head=use_head, n_classes=n_classes)
    #print(torch.cuda.memory_allocated(device=device))

    # create tokenizer
    tokenizer = CharacterTokenizer(
        characters=['A', 'C', 'G', 'T', 'N'],  # add DNA characters, N is uncertain
        model_max_length=max_length + 2,  # to account for special tokens, like EOS
        add_special_tokens=False,  # we handle special tokens elsewhere
        padding_side='left', # since HyenaDNA is causal, we pad on the left
    )


    ds_train = CLASTER_HYENA_Dataset(train_fasta_file,
        train_csv_file,
        max_length,
        target_columns,
        d_output=None, # No need for default here, it will be inferred from target_columns
        tokenizer=tokenizer,
        tokenizer_name=None,
        use_padding=use_padding,
        add_eos=add_eos,
        rc_aug=rc_aug,
        return_augs=False)

    ds_test = CLASTER_HYENA_Dataset(test_fasta_file,
        test_csv_file,
        max_length,
        target_columns,
        d_output=None, # No need for default here, it will be inferred from target_columns
        tokenizer=tokenizer,
        tokenizer_name=None,
        use_padding=use_padding,
        add_eos=add_eos,
        rc_aug=rc_aug,
        return_augs=False)


    train_loader = DataLoader(ds_train, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(ds_test, batch_size=batch_size, shuffle=False)

    # loss function
    loss_fn = nn.SmoothL1Loss()

    # create optimizer
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    model.to(device)

    # nvidia_smi.nvmlInit()
    # deviceCount = nvidia_smi.nvmlDeviceGetCount()
    # for i in range(deviceCount):
    #     handle = nvidia_smi.nvmlDeviceGetHandleByIndex(i)
    #     util = nvidia_smi.nvmlDeviceGetUtilizationRates(handle)
    #     mem = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
    #     print(f"|Device {i}| Mem Free: {mem.free/1024**2:5.2f}MB / {mem.total/1024**2:5.2f}MB | gpu-util: {util.gpu/100.0:3.1%} | gpu-mem: {util.memory/100.0:3.1%} |")

    # print(torch.cuda.memory_allocated(device=device))

    if not ONLY_TEST:
        # train() already steps the optimizer once per batch; an extra step here
        # would re-apply the (stale) gradients of the last batch every epoch.
        for epoch in range(num_epochs):
            train(model, device, train_loader, optimizer, epoch, loss_fn)
            test(model, device, test_loader, loss_fn)

    else:
        # map_location lets a checkpoint written on a GPU be evaluated on a
        # CPU-only machine; the explicit .to(device) keeps model and batches
        # on the same device.
        # NOTE(review): recent PyTorch releases default torch.load to
        # weights_only=True, which rejects whole-model pickles like this one
        # -- confirm against the installed version.
        model = torch.load('../checkpoints/hyenadna-small-32k-seqlen/model_32k.pt', map_location=device)
        model.to(device)
        model.eval()
        test(model, device, test_loader, loss_fn)


######################################################################################################################################
#@title Single example
import json
import os
import subprocess
import transformers
from transformers import PreTrainedModel, AutoModelForCausalLM, PretrainedConfig
from pathlib import Path



def inference_single_sequential(path: Path, filename: str, savepath: Path, checkpoint_folder: str ):

    '''
    Embed every sequence of a FASTA file, and its reverse complement, with a
    pretrained HyenaDNA backbone (no head) and save each embedding as
    ``<savepath>/<id>.npy`` / ``<savepath>/<id>_flipped.npy`` (float16).
    Embeddings whose output file already exists are skipped.

    Args:
        path: Directory containing the FASTA file.
        filename: FASTA file name; every non-header line is treated as one
            complete sequence (i.e. sequences are expected on a single line).
        savepath: Existing directory the ``.npy`` files are written to.
        checkpoint_folder: Folder passed to ``from_pretrained`` for the
            pretrained weights.

    this selects which backbone to use, and grabs weights/ config from HF
    5 options:
      'hyenadna-tiny-1k-seqlen'   # fine-tune on colab ok
      'hyenadna-small-32k-seqlen'
      'hyenadna-medium-160k-seqlen'  # inference only on colab
      'hyenadna-medium-450k-seqlen'  # inference only on colab
      'hyenadna-large-1m-seqlen'  # inference only on colab
    '''

    # you only need to select which model to use here, we'll do the rest!
    pretrained_model_name = 'hyenadna-medium-160k-seqlen' #'hyenadna-tiny-1k-seqlen'# 'hyenadna-medium-450k-seqlen'

    max_lengths = {
        'hyenadna-tiny-1k-seqlen': 1024,
        'hyenadna-small-32k-seqlen': 32768,
        'hyenadna-medium-160k-seqlen': 160000,
        'hyenadna-medium-450k-seqlen': 450000,  # T4 up to here
        'hyenadna-large-1m-seqlen': 1_000_000,  # only A100 (paid tier)
    }

    max_length = max_lengths[pretrained_model_name]  # auto selects

    # data settings:
    use_padding = True
    rc_aug = False  # reverse complement augmentation
    add_eos = False  # add end of sentence token

    # we need these for the decoder head, if using
    use_head = False
    n_classes = 2  # not used for embeddings only

    # you can override with your own backbone config here if you want,
    # otherwise we'll load the HF one in None
    backbone_cfg = None

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("Using device:", device)

    # instantiate the model (pretrained here)
    if pretrained_model_name in ['hyenadna-tiny-1k-seqlen',
                                 'hyenadna-small-32k-seqlen',
                                 'hyenadna-medium-160k-seqlen',
                                 'hyenadna-medium-450k-seqlen',
                                 'hyenadna-large-1m-seqlen']:
        # use the pretrained Huggingface wrapper instead
        model = HyenaDNAPreTrainedModel.from_pretrained(
            checkpoint_folder,
            pretrained_model_name,
            download=True,
            config=backbone_cfg,
            device=device,
            use_head=use_head,
            n_classes=n_classes,
        )

    # from scratch
    elif pretrained_model_name is None:
        model = HyenaDNAModel(**backbone_cfg, use_head=use_head, n_classes=n_classes)

    # create tokenizer
    tokenizer = CharacterTokenizer(
        characters=['A', 'C', 'G', 'T', 'N'],  # add DNA characters, N is uncertain
        model_max_length=max_length + 2,  # to account for special tokens, like EOS
        add_special_tokens=False,  # we handle special tokens elsewhere
        padding_side='left', # since HyenaDNA is causal, we pad on the left
    )

    # prep model once, instead of once per sequence
    model.to(device)
    model.eval()

    #### Get named embeddings ####

    seq_id = None
    with open(path / filename, 'r') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                # blank lines (e.g. at the end of the file) carry no sequence
                continue
            if line.startswith('>'):
                seq_id = line.split(':')[0][1:]
                continue
            if seq_id is None:
                raise ValueError(f"Sequence found before any '>' header in {path / filename}")

            seq = line
            rev_seq = string_reverse_complement(seq)
            for seq_name,sequence in zip([seq_id, seq_id + '_flipped'],[seq, rev_seq]):
                savefile = savepath / f"{seq_name}.npy"
                if savefile.is_file():
                    continue
                #sequence = 'ACTG' * int(max_length/4)
                tok_seq = tokenizer(sequence)
                tok_seq = tok_seq["input_ids"]  # grab ids

                # place on device, convert to tensor
                tok_seq = torch.LongTensor(tok_seq).unsqueeze(0)  # unsqueeze for batch dim
                tok_seq = tok_seq.to(device)

                with torch.inference_mode():
                    embeddings = model(tok_seq)
                    embeddings = embeddings.to(dtype=torch.float16).cpu()
                    print(torch.max(embeddings), torch.min(embeddings))
                    np.save(savefile,embeddings)  # embeddings here!


################################################################################
######################### S C R I P T ##########################################

# Input FASTA location and output directory for the embeddings of the test split.
path = Path("../inputs/DNA_sequences/")
filename : str = "test_boundaries.fasta"
savepath: Path = Path("../inputs/DNA_sequences/test_embeddings_Hyena-DNA/")
# Create the output directory (and any missing parents) up front.
savepath.mkdir(parents=True, exist_ok=True)

# Folder handed to from_pretrained for the pretrained HyenaDNA weights.
checkpoint_folder : str = "../checkpoints/"
#Path(checkpoint_folder).mkdir(parents=True, exist_ok=True)


# Extract pretrained embeddings (currently enabled; comment out to skip):
inference_single_sequential(path, filename, savepath, checkpoint_folder)

# Uncomment if we want to train and test HyenaDNA:
#run_train_CLASTER_HYENA()



Overwriting Hyena_DNA_Esrum.py


## 2. Enformer

The original paper can be found as:

Avsec, Ž., Agarwal, V., Visentin, D. et al. Effective gene expression prediction from sequence by integrating long-range interactions. Nat Methods 18, 1196–1203 (2021). https://doi.org/10.1038/s41592-021-01252-x . For Enformer, we simply extracted the embeddings of our sequences from the pretrained model.


Code to build the Enformer using pytorch and load pretrained weights was obtained from:
https://github.com/lucidrains/enformer-pytorch.

>KUDOS: Huge kudos to Phil Wang (lucidrains) for open sourcing the pytorch version of the Enformer



**Install the Enformer pytorch package:**

In [3]:
! pip install "enformer-pytorch>=0.5"

**Get Enformer embeddings**

In [25]:
%%writefile Enformer_GPU.py

from pathlib import Path
import numpy as np
import torch
from enformer_pytorch import from_pretrained
import torch.nn.functional as F
import os

# Set the device to GPU if available; every function below reads this module-level value.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the model globally and move it to the specified device.
# NOTE: this runs at import time, so importing/running this file loads the pretrained model.
model = from_pretrained('EleutherAI/enformer-official-rough').to(device)

###### Functions #######

def pad_tensor_symmetrically(input_tensor, target_length, pad_value=4, expected_length=160000):
    """Pad a ``(1, L)`` tensor on both sides up to ``target_length``.

    Args:
        input_tensor: Tensor of shape ``(1, expected_length)``.
        target_length: Length of the last dimension after padding.
        pad_value: Fill value for the padded positions (4 is the index that
            ``map_dna_to_numeric_tensor`` assigns to 'N').
        expected_length: Required input length; pass ``None`` to accept any
            ``(1, L)`` tensor. Defaults to the 160000 bp windows used here.

    Returns:
        Tensor of shape ``(1, target_length)``. When the amount of padding is
        odd, the extra position goes to the right-hand side.

    Raises:
        ValueError: If the input does not have the expected shape.
    """
    if expected_length is not None:
        if input_tensor.shape != (1, expected_length):
            raise ValueError(f"Input tensor must have shape (1, {expected_length})")
    elif input_tensor.dim() != 2 or input_tensor.shape[0] != 1:
        raise ValueError("Input tensor must have shape (1, L)")

    total_padding = target_length - input_tensor.shape[1]
    left_padding = total_padding // 2
    # Put any odd remainder on the right so the result is exactly target_length long.
    right_padding = total_padding - left_padding
    padded_tensor = F.pad(input_tensor, (left_padding, right_padding), "constant", pad_value)
    return padded_tensor

def map_dna_to_numeric_tensor(sequence):
    """Encode a DNA string as a ``(1, L)`` tensor of integer indices.

    A/C/G/T/N map to 0/1/2/3/4, case-insensitively, so soft-masked (lowercase)
    nucleotides keep their position instead of being dropped. Any other
    character (whitespace, gaps, ...) is skipped.

    Args:
        sequence: DNA sequence as a string.

    Returns:
        ``torch.long`` tensor of shape ``(1, number_of_recognised_characters)``.
    """
    mapping = {'A': 0, 'C': 1, 'G': 2, 'T': 3, 'N': 4}
    # Accept lowercase input too: the reverse-complement helper in this file
    # explicitly supports it, so the encoder must not silently discard it.
    mapping.update({base.lower(): index for base, index in list(mapping.items())})
    numeric_sequence = [mapping[nucleotide] for nucleotide in sequence if nucleotide in mapping]
    sequence_tensor = torch.tensor(numeric_sequence, dtype=torch.long).unsqueeze(0)
    return sequence_tensor

# Watson-Crick complements for upper- and lowercase nucleotides.
string_complement_map = {
    'A': 'T', 'C': 'G', 'G': 'C', 'T': 'A',
    'a': 't', 'c': 'g', 'g': 'c', 't': 'a',
}

def string_reverse_complement(seq):
    """Return the reverse complement of ``seq``.

    Characters without an entry in ``string_complement_map`` (e.g. 'N') are
    copied through unchanged.
    """
    complemented = []
    for base in seq[::-1]:
        complemented.append(string_complement_map.get(base, base))
    return ''.join(complemented)

def process_sequence(path, seq_id, sequence, savepath):
    """Embed one sequence and its reverse complement with the global Enformer model.

    The forward embedding is saved as ``<savepath>/<seq_id>.npy`` and the
    reverse-complement embedding as ``<savepath>/<seq_id[:-8]>_rev.npy``
    (presumably IDs end in an 8-character strand suffix that ``_rev``
    replaces -- TODO confirm against the FASTA headers). Files that already
    exist are not recomputed.

    Args:
        path: Unused; kept for call compatibility.
        seq_id: Sequence identifier used to build the output file names.
        sequence: DNA sequence string.
        savepath: ``pathlib.Path`` of the output directory.
    """
    rev_seq = string_reverse_complement(sequence)
    for seq_name, seq in zip([seq_id, seq_id[:-8] + '_rev'], [sequence, rev_seq]):
        savefile = savepath / f"{seq_name}.npy"
        if savefile.is_file():
            continue
        seq_tensor = map_dna_to_numeric_tensor(seq)
        seq_tensor = pad_tensor_symmetrically(seq_tensor, target_length=196608)
        seq_tensor = seq_tensor.to(device)  # Move data to GPU
        # Pure inference: skip building the autograd graph to save memory.
        with torch.no_grad():
            _, embeddings = model(seq_tensor, return_embeddings=True)
        embeddings = embeddings.detach().cpu().numpy()  # Move data back to CPU for saving
        np.save(savefile, embeddings.astype(np.float32))

def create_Enformer_embeddings(path, filename, savepath):
    """Create Enformer embeddings for every record of a FASTA file.

    Header lines start with '>'; the identifier is the text between '>' and
    the first ':'. Every other non-blank line is treated as one complete
    sequence belonging to the most recent header.

    Args:
        path: ``pathlib.Path`` of the directory containing the FASTA file.
        filename: Name of the FASTA file.
        savepath: ``pathlib.Path`` of the directory the embeddings go to.

    Raises:
        ValueError: If a sequence line appears before any header line.
    """
    seq_id = None
    with open(path / filename, 'r') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) carry no sequence.
                continue
            if line.startswith('>'):
                seq_id = line.split(':')[0][1:]
                continue
            if seq_id is None:
                raise ValueError(f"Sequence found before any '>' header in {path / filename}")
            process_sequence(path, seq_id, line, savepath)

#### Main script ####

# Directory holding one "<split>_boundaries.fasta" file per data split.
path = Path("../inputs/DNA_sequences/")
split_list = ["training", "validation", "test"]

# Embed each split into its own output directory, created on demand.
for split in split_list:
    filename = f"{split}_boundaries.fasta"
    savepath = Path(f"../inputs/DNA_sequences/{split}_embeddings_Enformer/")
    savepath.mkdir(parents=True, exist_ok=True)
    create_Enformer_embeddings(path, filename, savepath)

Overwriting Enformer_GPU.py


We can now run it as a Python script on a Slurm-based system. We will greatly benefit from the multiprocessing functionality.

```bash
srun -- python Enformer_GPU.py
```


**Train and test head on top of embeddings**

>_Note: The test set is quite large. I predicted it all at once using 100 CPUs, but the code can be split to predict in batches and join the predictions afterwards if needed. Additionally, I removed all reversed (duplicated) samples from the inputs and targets for the test predictions so that they match those of Hyena._

In [4]:
%%writefile Enformer_head.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
#from sklearn.metrics import r2_score
import os
import matplotlib.pyplot as plt
from pathlib import Path
import logging
from tqdm import tqdm

###### Data class and head Model ######################
class EnformerHeadDataset(Dataset):
    """
    This class reads input numpy arrays and target csv files and returns input-target torch tensors.

    Rows of the targets file whose input file ``<data_folder>/<ID>.npy`` does not exist are dropped.

    Args:
        - data_folder: Directory containing our input files stored as separate numpy arrays (sampleID.npy).
        - targets_file: CSV file containing the IDs of the samples (column "ID") and their corresponding targets.
          Columns whose name contains "ctrl" are the control targets, those containing "treated" the treated targets.
        - stack_type: "hstack" or "vstack" for the concatenation method of control and treated target arrays.
        - bin_size: Stored on the instance; not used by this class itself.
    Returns:
        - Input or data tensor (float32), with the leading axis of the stored array removed.
        - Target tensor (float32).
    """
    def __init__(self, data_folder, targets_file, stack_type, bin_size):
        assert os.path.exists(data_folder), f"The specified data folder does not exist: {data_folder}"
        assert stack_type in ["hstack", "vstack"], f"Invalid stack type: {stack_type}"
        self.data_folder = data_folder
        self.stack_type = stack_type
        self.bin_size = bin_size
        targets_df = pd.read_csv(targets_file)

        ctrl_cols = [col for col in targets_df.columns if "ctrl" in col]
        treated_cols = [col for col in targets_df.columns if "treated" in col]

        # Filter out targets for which input files do not exist and separate control and treated targets
        self.sample_info = []
        for _, row in targets_df.iterrows():
            sample_id = row['ID']
            file_path = os.path.join(data_folder, f"{sample_id}.npy")
            if os.path.exists(file_path):
                ctrl_target = row[ctrl_cols].values
                treated_target = row[treated_cols].values
                self.sample_info.append((sample_id, ctrl_target, treated_target))

    def __len__(self):
        """Number of samples that have both a target row and an input file."""
        return len(self.sample_info)

    def __getitem__(self, idx):
        """Load the embedding of sample ``idx`` and build its stacked target."""
        sample_name, ctrl_target, treated_target = self.sample_info[idx]
        data_path = os.path.join(self.data_folder, f"{sample_name}.npy")
        data = np.load(data_path)

        # Stack the targets based on stack_type
        stack = np.hstack if self.stack_type == "hstack" else np.vstack
        target = np.array(stack((ctrl_target, treated_target)), dtype=np.float32)

        # The reshape drops the leading axis of the stored array; it only
        # succeeds when that axis has size 1.
        features = torch.tensor(data.reshape(data.shape[1], data.shape[2]), dtype=torch.float32)
        # Build the target tensor directly from the float32 array: wrapping an
        # existing tensor in torch.tensor() copies it again and emits a UserWarning.
        return features, torch.from_numpy(target)

class DNAConvNet(nn.Module):
    """Single 1D convolution + softplus head on top of sequence embeddings.

    Args:
        input_depth: Embedding depth, i.e. number of input channels of the convolution.
        output_size: Unused; kept for call compatibility.
        n_conditions: Number of output channels.
        kernel_size, stride, padding, dilation: Passed to ``nn.Conv1d``.
        dropout_prob: Dropout probability applied to the input.

    Input is ``(batch, length, depth)``; output is ``(batch, n_conditions, length_out)``.

    NOTE: earlier versions used ``x.view(...)`` to go to channels-first, which
    reinterprets memory instead of transposing and therefore mixed positions
    and channels. Models trained with that layout need to be retrained.
    """
    def __init__(self, input_depth, output_size, n_conditions, kernel_size, stride, padding, dilation, dropout_prob):
        super(DNAConvNet, self).__init__()
        self.dropout = nn.Dropout(p=dropout_prob)
        self.conv1 = nn.Conv1d(in_channels=input_depth, out_channels=n_conditions, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation)
        self.softplus = nn.Softplus()  # Adding the Softplus activation layer as in Enformer

    def forward(self, x):
        # (batch, length, depth) -> (batch, depth, length), e.g. (64, 896, 3072) -> (64, 3072, 896).
        # A real transpose is required here; view() would scramble the axes.
        x = x.transpose(1, 2)
        # Apply dropout to the input
        x = self.dropout(x)

        x = self.conv1(x)
        x = self.softplus(x)  # Applying softplus activation after convolution
        return x

class DNAConvandDense(nn.Module):
    """1D convolution followed by a dense layer and softplus.

    Args:
        input_depth: Embedding depth, i.e. number of input channels of the convolution.
        input_length: Sequence length of the embeddings.
        output_length: Number of output values per sample.
        kernel_size, stride, padding, dilation: Passed to ``nn.Conv1d``.
        dropout_prob: Dropout probability applied to the convolution output.
        hidden_depth: Number of convolution output channels.

    Input is ``(batch, input_length, input_depth)``; output is ``(batch, output_length)``.

    NOTE: earlier versions used ``x.view(...)`` to go to channels-first, which
    reinterprets memory instead of transposing and therefore mixed positions
    and channels. Models trained with that layout need to be retrained.
    """
    def __init__(self, input_depth, input_length, output_length, kernel_size, stride, padding, dilation, dropout_prob, hidden_depth):
        super(DNAConvandDense, self).__init__()
        self.dropout = nn.Dropout(p=dropout_prob)
        self.conv1d = nn.Conv1d(in_channels=input_depth, out_channels=hidden_depth, kernel_size=kernel_size,stride=stride, padding=padding, dilation=dilation)

        # Calculate the output length after convolution
        self.seq_length_after_conv = self.calculate_output_length(input_length, kernel_size, stride, padding, dilation)

        # Fully connected layer with calculated input size
        self.fc = nn.Linear(hidden_depth * self.seq_length_after_conv, output_length)

        # Softplus activation
        self.softplus = nn.Softplus()


    def calculate_output_length(self, input_length, kernel_size, stride, padding, dilation):
        """Output length of a 1D convolution with the given hyper-parameters."""
        return ((input_length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride) + 1

    def forward(self, x):
        # (batch, length, depth) -> (batch, depth, length), e.g. (64, 896, 3072) -> (64, 3072, 896).
        # A real transpose is required here; view() would scramble the axes.
        x = x.transpose(1, 2)
        # Apply 1D convolution
        x = self.conv1d(x)
        # Apply dropout to the convolution output
        x = self.dropout(x)

        # Flatten the convolution output
        x = x.reshape(x.size(0), -1)

        # Pass through the fully connected layer
        x = self.fc(x)

        # Apply Softplus activation
        x = self.softplus(x)

        return x


class ConvLinearModel(nn.Module):
    """2D average pooling -> 1D convolution -> dense layer -> softplus.

    Args:
        input_depth: Embedding depth of the input.
        input_length: Sequence length of the input.
        output_length: Number of output values per sample.
        kernel_size, stride, padding, dilation: Passed to ``nn.Conv1d``.
        dropout_prob: Dropout probability (applied before the convolution and
            before the dense layer).
        hidden_depth: Number of convolution output channels.
        avg_pool_size: ``(depth_factor, length_factor)`` of the average pooling.
            When omitted, a module-level ``avg_pool_size`` is used, which is
            what earlier versions of this class relied on.

    Input is ``(batch, input_length, input_depth)``; output is ``(batch, output_length)``.

    NOTE: earlier versions used ``x.view(...)`` to go to channels-first, which
    reinterprets memory instead of transposing and therefore mixed positions
    and channels. Models trained with that layout need to be retrained.
    """
    def __init__(self, input_depth, input_length, output_length, kernel_size, stride, padding, dilation, dropout_prob, hidden_depth, avg_pool_size=None):
        super(ConvLinearModel, self).__init__()
        if avg_pool_size is None:
            # Backward compatibility with the former implicit global.
            try:
                avg_pool_size = globals()["avg_pool_size"]
            except KeyError:
                raise NameError("name 'avg_pool_size' is not defined; pass avg_pool_size=(depth_factor, length_factor)") from None

        # 2D Average Pooling
        self.avgpool2d = nn.AvgPool2d(kernel_size=avg_pool_size)

        # Dropout (uses the constructor argument rather than an undefined global)
        self.dropout = nn.Dropout(dropout_prob)

        # Adjust the input depth and sequence length based on the 2D pooling
        new_depth = input_depth // avg_pool_size[0]
        new_seq_length = input_length // avg_pool_size[1]

        # 1D Convolution
        self.conv1 = nn.Conv1d(in_channels=new_depth, out_channels=hidden_depth, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation)

        # Calculate the output length after convolution
        self.seq_length_after_conv = self.calculate_output_length(new_seq_length, kernel_size, stride, padding, dilation)

        # Fully connected layer with calculated input size
        self.fc = nn.Linear(hidden_depth * self.seq_length_after_conv, output_length)
        self.softplus = nn.Softplus()

    def calculate_output_length(self, input_length, kernel_size, stride, padding, dilation):
        """Output length of a 1D convolution with the given hyper-parameters."""
        return ((input_length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride) + 1

    def forward(self, x):
        # (batch, length, depth) -> (batch, 1, depth, length) for 2D avg pooling.
        # A real transpose is required here; view() would scramble the axes.
        x = x.transpose(1, 2).unsqueeze(1)
        x = self.avgpool2d(x)

        # Reshape back to (batch_size, new_depth, new_seq_length) for 1D convolution
        x = x.reshape(x.size(0), -1, x.size(3))

        # Apply dropout to the inputs of the convolutions
        x = self.dropout(x)

        # Apply 1D convolution
        x = self.conv1(x)

        # Apply dropout to the inputs of the final layer
        x = self.dropout(x)

        # Reshape x for the linear layer
        x = x.reshape(x.size(0), -1)

        # Pass through the fully connected layer
        x = self.fc(x)

        # Apply Softplus activation
        x = self.softplus(x)

        return x


#### Custom loss #############
class ZIPLoss(nn.Module):
    """Zero-inflated-Poisson-style loss averaged over channels.

    ``inputs`` and ``targets`` are indexed as ``[:, channel, :]``. For every
    channel the same prediction slice serves both as the Poisson rate and as
    the zero-inflation logit. Targets below ``eps`` are treated as zeros, and
    ``alpha`` scales the zero-inflation term at those positions.
    """
    def __init__(self, alpha=.5, eps=.02):
        super(ZIPLoss, self).__init__()
        self.alpha = alpha
        self.eps = eps

    def _channel_loss(self, prediction, observed):
        """Mean loss of a single channel."""
        is_zero = observed < self.eps
        # Values below eps are not considered true signal: clamp them to zero
        observed = torch.where(is_zero, torch.zeros_like(observed), observed)

        # Poisson part
        poisson_part = prediction - observed * torch.log(prediction + 1e-8)

        # Zero-inflation part
        prob_zero = torch.sigmoid(prediction)
        log_if_zero = self.alpha * torch.log(prob_zero + 1e-8)
        log_if_nonzero = torch.log(1 - prob_zero + 1e-8)
        zero_part = -torch.where(is_zero, log_if_zero, log_if_nonzero)

        return (zero_part + poisson_part).mean()

    def forward(self, inputs, targets):
        n_channels = inputs.shape[1]
        total_loss = 0
        for channel in range(n_channels):
            total_loss += self._channel_loss(inputs[:, channel, :], targets[:, channel, :])

        # Average loss over channels
        return total_loss / n_channels

class ZINBLoss(nn.Module):
    """Zero-inflated negative-binomial style loss, accumulated channel by channel.

    NOTE(review): ``forward`` takes ``(true_counts, predicted_counts)``, whereas
    the training loops in this file call ``criterion(outputs, targets)``
    (predictions first) -- confirm the argument order before using this loss
    with them.
    """
    def __init__(self, dispersion=None, zero_inflation_prob=None, channel_weights=None):
        """
        Zero-Inflated Negative Binomial loss module for multi-channel data with predefined extra parameters.

        Parameters:
        - dispersion: Optional, predefined dispersion parameter (alpha) for the negative binomial distribution.
                      Can be a scalar or a tensor with shape (channel,). Defaults to 1.0 when omitted.
        - zero_inflation_prob: Optional, predefined tensor of zero-inflation probabilities with shape (batch_size, channel, seq_length).
                               When omitted, zeros shaped like the predictions are used.
        - channel_weights: Optional tensor with shape (channel,) representing the weight of each channel in the loss calculation.
                           When omitted, every channel has weight 1.
        """
        super(ZINBLoss, self).__init__()
        # Small constant guarding the division and the logarithm below.
        self.epsilon = 1e-10

        # Stored as a buffer so it follows the module across devices.
        if dispersion is not None:
            self.register_buffer('dispersion', torch.tensor(dispersion))
        else:
            self.register_buffer('dispersion', torch.tensor(1.0))

        self.zero_inflation_prob = zero_inflation_prob
        self.channel_weights = channel_weights

    def forward(self, true_counts, predicted_counts):
        """Return the weighted channel losses summed and divided by the number of channels.

        ``true_counts`` must be 3-dimensional (its shape is unpacked into
        batch, channel and sequence length); ``predicted_counts`` is indexed
        the same way.
        """
        batch_size, num_channels, seq_length = true_counts.shape

        if self.channel_weights is None:
            channel_weights = torch.ones(num_channels, device=true_counts.device)
        else:
            channel_weights = self.channel_weights

        if self.zero_inflation_prob is None:
            zero_inflation_prob = torch.zeros_like(predicted_counts)
        else:
            zero_inflation_prob = self.zero_inflation_prob

        total_loss = 0.0

        for channel in range(num_channels):
            # Extract data for the current channel
            true_counts_channel = true_counts[:, channel, :]
            predicted_counts_channel = predicted_counts[:, channel, :]
            predicted_zero_inflation_channel = zero_inflation_prob[:, channel, :]
            # A 0-dim dispersion is shared by all channels; otherwise pick this channel's entry.
            dispersion_channel = self.dispersion[channel] if self.dispersion.ndim > 0 else self.dispersion

            # Negative Binomial term
            # theta is the inverse dispersion.
            theta_channel = 1.0 / (dispersion_channel + self.epsilon)
            # NOTE(review): t1 is the log of the NB normalising coefficient (a log-likelihood
            # term), while t2 has the form of a negative log-likelihood term; they are added
            # and the mean is negated further down -- the sign conventions look inconsistent,
            # confirm against the intended formula.
            t1 = torch.lgamma(true_counts_channel + theta_channel) - torch.lgamma(theta_channel) - torch.lgamma(true_counts_channel + 1)
            t2 = (theta_channel + true_counts_channel) * torch.log1p(predicted_counts_channel / theta_channel) + true_counts_channel * (torch.log(theta_channel) - torch.log(predicted_counts_channel + self.epsilon))

            nb_term = t1 + t2

            # Zero-Inflation term
            # NOTE(review): the value documented as a probability is used here like a logit
            # (inside softplus) -- confirm which one callers are expected to pass.
            zero_inflation_term = predicted_zero_inflation_channel + F.softplus(-predicted_zero_inflation_channel - nb_term) - F.softplus(-nb_term)

            # Combine terms
            # Positions with an exactly-zero target use the zero-inflation term, all others the NB term.
            zero_inflation_mask = (true_counts_channel == 0).type(torch.float32)
            channel_loss = -torch.mean(zero_inflation_term * zero_inflation_mask + nb_term * (1 - zero_inflation_mask))

            # Weighted sum of the channel losses
            total_loss += channel_weights[channel] * channel_loss

        # Average loss over channels
        loss = total_loss / num_channels

        return loss
############## Training and auxilliary functions ##############


def count_parameters(model):
    """Return ``(total, trainable)`` parameter counts of ``model``.

    A parameter counts as trainable when its ``requires_grad`` flag is set.
    """
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        size = param.numel()
        total_params += size
        if param.requires_grad:
            trainable_params += size
    return total_params, trainable_params

def r2_score(outputs, targets):
    """
    Compute the R2 coefficient between predictions and targets.

    Everything is evaluated with torch operations, so the computation stays
    on the device of the inputs (the GPU, if available); only the final scalar
    is moved to the CPU and returned as a numpy value.
    """
    residual_ss = (targets - outputs).pow(2).sum()
    total_ss = (targets - targets.mean()).pow(2).sum()
    score = 1 - residual_ss / total_ss
    return score.detach().cpu().numpy()

# Train and validate the model
def train_with_validation(model, train_loader, val_loader, criterion, optimizer, num_epochs, device, savepath):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    Args:
        model: Module to optimise.
        train_loader: Iterable of ``(inputs, targets)`` training batches.
        val_loader: Iterable of ``(inputs, targets)`` validation batches.
        criterion: Loss called as ``criterion(outputs, targets)``.
        optimizer: Optimizer stepped once per training batch.
        num_epochs: Number of epochs to run.
        device: Device the batches are moved to.
        savepath: ``pathlib.Path`` of the directory the final model
            (``model_smoothl1.pt``) is written to.

    Returns:
        Four lists with one entry per epoch: training losses, validation
        losses, training R^2 scores and validation R^2 scores (all averaged
        over batches).
    """
    train_losses, val_losses = [], []
    train_accuracies, val_accuracies = [], []

    logging.info("Training the model")
    for epoch in range(num_epochs):
        model.train()
        total_loss, total_r2, count = 0, 0, 0
        # Initialize progress bar
        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', leave=False)
        # Iterate through the bar itself so that it actually advances.
        for inputs, targets in progress_bar:
            # Send data to GPU
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            # Calculate R^2 score
            total_r2 += r2_score(outputs, targets)
            count += 1
            # Update progress bar
            progress_bar.set_postfix({'train_loss': loss.item()})

        # Close the progress bar at the end of the epoch
        progress_bar.close()

        avg_train_loss = total_loss / count
        avg_train_r2 = total_r2 / count
        train_losses.append(avg_train_loss)
        train_accuracies.append(avg_train_r2)

        # Validation step
        val_loss, val_r2 = validate(model, val_loader, criterion, device)
        val_losses.append(val_loss)
        val_accuracies.append(val_r2)

        logging.info(f"Epoch [{epoch + 1}/{num_epochs}], Training Loss: {avg_train_loss:.4f}, Training R^2: {avg_train_r2:.4f}, Validation Loss: {val_loss:.4f}, Validation R^2: {val_r2:.4f}")

    # Save the final model
    torch.save(model, savepath / 'model_smoothl1.pt')

    return train_losses, val_losses, train_accuracies, val_accuracies


def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Args:
        model: Network to evaluate; switched to eval mode here.
        val_loader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(avg_loss, avg_r2)``: the loss and the R^2 score, each averaged
        over the batches of ``val_loader``.
    """
    model.eval()
    batch_losses, batch_r2_scores = [], []
    with torch.no_grad():
        for batch_inputs, batch_targets in val_loader:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)
            batch_outputs = model(batch_inputs)
            batch_losses.append(criterion(batch_outputs, batch_targets).item())
            batch_r2_scores.append(r2_score(batch_outputs, batch_targets))
    # Unweighted mean over batches
    n_batches = len(batch_losses)
    return sum(batch_losses) / n_batches, sum(batch_r2_scores) / n_batches

def plot_losses_accuracies(train_losses, val_losses, train_accuracies, val_accuracies, savepath):
    """Draw the loss and R² learning curves side by side and save them.

    Args:
        train_losses: Per-epoch training losses.
        val_losses: Per-epoch validation losses.
        train_accuracies: Per-epoch training R² scores.
        val_accuracies: Per-epoch validation R² scores.
        savepath: ``pathlib.Path`` of the directory that receives
            ``Reconstruction_metrics.png``.
    """
    logging.info("Plotting results")
    fig = plt.figure(figsize=(12, 5))

    # One entry per panel:
    # (training series, its label, validation series, its label, title, y-axis label)
    panels = (
        (train_losses, 'Training Loss', val_losses, 'Validation Loss',
         'Training and Validation Loss', 'Loss'),
        (train_accuracies, 'Training R²', val_accuracies, 'Validation R²',
         'Training and Validation R² Score', 'R² Score'),
    )

    for position, (train_series, train_label, val_series, val_label, title, y_label) in enumerate(panels, start=1):
        axis = plt.subplot(1, 2, position)
        # Epochs are numbered from 1 on the x-axis
        axis.plot(np.arange(1, len(train_series) + 1), train_series, label=train_label)
        axis.plot(np.arange(1, len(val_series) + 1), val_series, label=val_label)
        axis.set_title(title)
        axis.set_xlabel('Epoch')
        axis.set_ylabel(y_label)
        axis.legend()

    fig.savefig(savepath / "Reconstruction_metrics.png")


def test_model_with_predictions(model, test_loader, device):
    model.eval()  # Set the model to evaluation mode
    predictions = []
    targets_list = []  # if you also want to store true targets

    with torch.no_grad():  # No gradient calculation
        for inputs, targets in test_loader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            predictions.append(outputs.cpu().numpy())
            targets_list.append(targets.cpu().numpy())  # if storing targets

    return predictions, targets_list
########################## Script #########################
# Configure the run, then train the head and/or test it on unseen data.

# --- Run switches ---
IS_HYENA = False  # False: keep the Enformer settings below; True: apply the HyenaDNA overrides
TRAIN = False     # fit the head and plot the learning curves
TEST = True       # load the saved model and write test-set predictions

# --- Inputs (Enformer embeddings and their targets) ---
train_data_folder = "../inputs/DNA_sequences/training_embeddings_Enformer/"
val_data_folder = "../inputs/DNA_sequences/validation_embeddings_Enformer/"
test_data_folder = "../inputs/DNA_sequences/test_embeddings_Enformer/"
targets_file = '../targets/training_targets_Enformer.csv'
test_targets_file = '../targets/test_targets_Enformer.csv'

# --- Outputs: created up front so logs, model and arrays can be written ---
savepath = Path("../benchmarks/Enformer/")
savepath.mkdir(parents=True, exist_ok=True)

# --- Head architecture ---
input_length = 896    # NOTE(review): presumably the number of positions per Enformer embedding; confirm
input_depth = 3072    # NOTE(review): presumably the embedding channels per position; confirm
hidden_depth = 10
output_size = 115     # 57 kbp on each side of the central bin
n_conditions = 1      # multiplies output_size to give the model's output length
stack_type = 'hstack' # alternative: 'vstack', when outputting directly from convolutions

# --- Convolution geometry ---
kernel_size = 9
stride = 2
dilation = 2
padding = dilation * (kernel_size - 1) // 2  # half of the dilated kernel span

# --- Optimisation ---
num_epochs = 20
batch_size = 64
dropout_rate = 0.4
lr = 1e-4
#################

if IS_HYENA:
    # HyenaDNA run: every setting above is replaced by its HyenaDNA counterpart.
    train_data_folder = "/projects/cbmr_shared/scratch/Hyena_DNA/training/"
    val_data_folder = "/projects/cbmr_shared/scratch/Hyena_DNA/validation/"
    test_data_folder = "/projects/cbmr_shared/scratch/Hyena_DNA/test/"
    targets_file = '/projects/cbmr_shared/scratch/Hyena_DNA/targets/benchmark_sequence_based_target_arrays_1kbp_57_bins_2_conditions_decareads_abs.csv'
    test_targets_file = '/projects/cbmr_shared/scratch/Hyena_DNA/targets/benchmark_sequence_based_target_arrays_1kbp_57_bins_2_conditions_test_decareads.csv'

    savepath = Path("/projects/cbmr_shared/scratch/Hyena_DNA/results/Hyena/")
    savepath.mkdir(parents=True, exist_ok=True)

    # Head architecture
    input_length = 160002  # embeddings add 2 characters to the 160 kbp long sequences
    input_depth = 256
    hidden_depth = 20
    output_size = 115      # bins from -N kbp to +N kbp, plus the central bin
    n_conditions = 1
    stack_type = 'hstack'

    # Convolution geometry
    kernel_size = 9
    stride = 333
    dilation = 100
    padding = dilation * (kernel_size - 1) // 2
    avg_pool_size = (1, 3)

    # Optimisation
    num_epochs = 200
    batch_size = 64
    dropout_rate = 0.3
    lr = 1e-4


# Setup logging: all progress messages go to <savepath>/training.log
logging.basicConfig(filename=str(savepath / 'training.log'), level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Device:{device}")

# Build only the head that matches the selected embeddings. Constructing the
# Enformer head first and then replacing it would allocate (and move to the
# device) a model that is immediately thrown away in the HyenaDNA case.
if IS_HYENA:
    model = ConvLinearModel(input_depth, input_length, n_conditions*output_size, kernel_size, stride, padding, dilation, dropout_rate, hidden_depth).to(device)
else:
    model = DNAConvandDense(input_depth=input_depth, input_length=input_length, output_length=n_conditions*output_size, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, dropout_prob=dropout_rate, hidden_depth=hidden_depth).to(device)

# Report the model size in the log
total_params, trainable_params = count_parameters(model)
logging.info(f"Total parameters: {total_params}")
logging.info(f"Trainable parameters: {trainable_params}")

if TRAIN:
    logging.info("Beginning training")

    # Training and validation embeddings are both paired with `targets_file`.
    train_dataset, val_dataset = (
        EnformerHeadDataset(folder, targets_file, stack_type, kernel_size)
        for folder in (train_data_folder, val_data_folder)
    )

    # Only the training batches are shuffled; validation keeps dataset order.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Smooth L1 regression loss, optimised with AdamW at the configured learning rate.
    criterion = torch.nn.SmoothL1Loss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    history = train_with_validation(
        model,
        train_loader,
        val_loader,
        criterion,
        optimizer,
        num_epochs=num_epochs,
        device=device,
        savepath=savepath,
    )
    train_losses, val_losses, train_accuracies, val_accuracies = history
    plot_losses_accuracies(train_losses, val_losses, train_accuracies, val_accuracies, savepath)

############# Test predictions #############
if TEST:
    logging.info("Beginning test predictions")

    # Load the checkpoint from the directory train_with_validation() saves to,
    # so the model always matches the selected configuration (the path used to
    # be hard-coded to the Enformer results folder).
    # SECURITY: torch.load unpickles a full module; only load files you created.
    # NOTE(review): recent PyTorch releases default to weights_only=True, which
    # rejects full-model pickles -- confirm against the installed version.
    model = torch.load(savepath / "model_smoothl1.pt", map_location=device)
    model.to(device)

    # The whole test set is served as a single batch.
    test_dataset = EnformerHeadDataset(test_data_folder, test_targets_file, stack_type, kernel_size)
    test_loader = DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=False)

    predictions, targets = test_model_with_predictions(model, test_loader, device)

    # Save predictions and targets in numpy arrays. allow_pickle must be the
    # boolean False: the string "False" is truthy and would leave pickling on.
    np.save(savepath / "enformer_test_predictions.npy", np.array(predictions), allow_pickle=False)
    np.save(savepath / "enformer_test_targets.npy", np.array(targets), allow_pickle=False)

Overwriting Enformer_head.py
