In [1]:
# |default_exp modelsconv


In [2]:


#| export
import sys
sys.path.append('/opt/slh/rna/')
import torch
from torch import nn, einsum
from einops import rearrange, repeat
from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb

import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import math
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
from torch_geometric.utils import degree
from torch_geometric.data import Data, Batch
import numpy as np
from torch_geometric.utils import to_dense_batch
from x_transformers import ContinuousTransformerWrapper, Encoder, TransformerWrapper
from torch_geometric.nn import GATConv, GCNConv
from rnacomp.models import CombinationTransformerEncoderV1, Block_conv, CombinationTransformerEncoderV29
from x_transformers import ContinuousTransformerWrapper, Encoder, TransformerWrapper
import matplotlib.pyplot as plt



def good_luck():
    """Return ``True`` unconditionally."""
    result = 1 == 1
    return result

In [3]:
# | export
class conv_block(nn.Module):
    """Two stacked (3x3 Conv2d -> BatchNorm2d -> ReLU) stages.

    Stride 1 with padding 1 keeps the spatial size; only the channel count
    changes from ``ch_in`` to ``ch_out`` in the first stage.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        layers = []
        # The first stage maps ch_in -> ch_out, the second ch_out -> ch_out.
        for stage_in in (ch_in, ch_out):
            layers += [
                nn.Conv2d(stage_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
                nn.BatchNorm2d(ch_out),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)


class up_conv(nn.Module):
    """2x ``nn.Upsample`` (default mode) followed by 3x3 Conv2d -> BatchNorm2d -> ReLU."""

    def __init__(self, ch_in, ch_out):
        super().__init__()
        upsample = nn.Upsample(scale_factor=2)
        conv = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True)
        self.up = nn.Sequential(upsample, conv, nn.BatchNorm2d(ch_out), nn.ReLU(inplace=True))

    def forward(self, x):
        return self.up(x)


class U_Net(nn.Module):
    """Five-level 2-D U-Net whose output is multiplied by its own transpose.

    Args:
        img_ch: channels of the input image.
        output_ch: channels produced by the final 1x1 convolution.
        CH_FOLD2: width multiplier applied to the base widths 32/64/128/256/512.

    ``forward`` takes ``(batch, img_ch, H, W)``. H and W should be divisible by
    16 so the four max-pool / upsample stages line up for the skip
    concatenations. The 1x1 head output is squeezed on dim 1 (a no-op unless
    ``output_ch == 1``). It is then multiplied elementwise by its transpose over
    the last two axes, so the result is symmetric in those axes.
    """

    def __init__(self, img_ch=3, output_ch=1, CH_FOLD2=1):
        super().__init__()
        w1, w2, w3, w4, w5 = (int(base * CH_FOLD2) for base in (32, 64, 128, 256, 512))

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder. Modules are created in the same order as before, so
        # parameter initialisation under a fixed seed is unchanged.
        self.Conv1 = conv_block(ch_in=img_ch, ch_out=w1)
        self.Conv2 = conv_block(ch_in=w1, ch_out=w2)
        self.Conv3 = conv_block(ch_in=w2, ch_out=w3)
        self.Conv4 = conv_block(ch_in=w3, ch_out=w4)
        self.Conv5 = conv_block(ch_in=w4, ch_out=w5)

        # Decoder: each Up_conv consumes [skip, upsampled], hence its wider input.
        self.Up5 = up_conv(ch_in=w5, ch_out=w4)
        self.Up_conv5 = conv_block(ch_in=w5, ch_out=w4)
        self.Up4 = up_conv(ch_in=w4, ch_out=w3)
        self.Up_conv4 = conv_block(ch_in=w4, ch_out=w3)
        self.Up3 = up_conv(ch_in=w3, ch_out=w2)
        self.Up_conv3 = conv_block(ch_in=w3, ch_out=w2)
        self.Up2 = up_conv(ch_in=w2, ch_out=w1)
        self.Up_conv2 = conv_block(ch_in=w2, ch_out=w1)

        self.Conv_1x1 = nn.Conv2d(w1, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # Encoder: keep every stage's output for the skip connections.
        skips = [self.Conv1(x)]
        for encode in (self.Conv2, self.Conv3, self.Conv4, self.Conv5):
            skips.append(encode(self.Maxpool(skips[-1])))

        # Decoder: upsample, put the matching skip first on channels, then fuse.
        d = skips.pop()
        stages = (
            (self.Up5, self.Up_conv5),
            (self.Up4, self.Up_conv4),
            (self.Up3, self.Up_conv3),
            (self.Up2, self.Up_conv2),
        )
        for upsample, fuse in stages:
            d = fuse(torch.cat((skips.pop(), upsample(d)), dim=1))

        logits = self.Conv_1x1(d).squeeze(1)
        return logits.transpose(-1, -2) * logits


class UnetWrapper2D(nn.Module):
    """Runs a 2-D model on groups of same-length samples at their own crop size.

    Args:
        md: the wrapped 2-D model. Its output must have ``output_chans`` channels.
        output_chans: channel count of the zero canvas each group is written into.
    """

    def __init__(self, md, output_chans=2):
        super().__init__()
        self.md = md
        self.output_chans = output_chans

    def do_forward(self, x, crop16, crop):
        """Run ``md`` on the top-left ``crop16`` square of ``x``.

        Only the top-left ``crop`` x ``crop`` region of the prediction is kept.
        It is written into a zero tensor of shape
        ``(x.shape[0], output_chans, x.shape[2], x.shape[3])`` on ``x.device``.
        """
        batch, height, width = x.shape[0], x.shape[2], x.shape[3]
        canvas = torch.zeros(batch, self.output_chans, height, width, device=x.device)
        pred = self.md(x[:, :, :crop16, :crop16])
        canvas[:, :, :crop, :crop] = pred[:, :, :crop, :crop]
        return canvas

    def forward(self, xs, crop_to_16, crop_original, original_order):
        """Process each group, concatenate along the batch axis, then reindex.

        ``original_order`` is applied as an index on the concatenated batch.
        An inverse permutation, such as the one returned by
        ``generate_batches``, restores the callers' sample order.
        """
        outputs = [
            self.do_forward(group, size16, size)
            for group, size16, size in zip(xs, crop_to_16, crop_original)
        ]
        return torch.cat(outputs)[original_order]


def generate_batches(tensor, nn_out):
    """Group the rows of ``nn_out`` by equal values of ``tensor``.

    Args:
        tensor: 1-D per-sample keys (the caller passes per-sample sequence lengths).
        nn_out: tensor whose first axis is aligned with ``tensor``.

    Returns:
        ``(batches, restore_order)``:

        - ``batches`` is a list of slices of ``nn_out``, one per distinct key,
          in ascending key order.
        - ``restore_order`` is the inverse permutation, so
          ``torch.cat(batches)[restore_order]`` reproduces ``nn_out``.

        When all keys are equal (or ``tensor`` is empty), the result is
        ``[nn_out]`` and the identity order.
    """
    # Only the index bookkeeping runs without autograd. The gather of
    # ``nn_out`` must stay differentiable: running the whole function under
    # no_grad used to detach every grouped batch.
    with torch.no_grad():
        n_unique = tensor.unique().numel()
    if n_unique <= 1:
        return [nn_out], torch.arange(len(tensor), device=tensor.device)

    with torch.no_grad():
        # A stable sort keeps the within-group order deterministic.
        sorted_keys, order = tensor.sort(stable=True)
        # Group sizes are computed on the input's own device; the old code mixed
        # in CPU-created tensors, which broke torch.cat on CUDA.
        _, counts = torch.unique_consecutive(sorted_keys, return_counts=True)
        restore_order = order.argsort()

    grouped = nn_out.index_select(0, order)
    return list(torch.split(grouped, counts.tolist())), restore_order


def make_pair_mask(seq, seq_len):
    """Build a pairwise padding mask.

    Args:
        seq: tensor whose first two axes are ``(batch, length)``. Only its shape
            and device are used.
        seq_len: per-sample valid lengths, shape ``(batch,)``.

    Returns:
        Bool tensor ``(batch, length, length)``. Entry ``[b, i, j]`` is True
        when position ``i`` or position ``j`` is padding for sample ``b``.
    """
    positions = torch.arange(seq.shape[1], device=seq.device).expand(seq.shape[:2])
    valid = positions < seq_len.unsqueeze(1)
    # `&` on bool tensors always yields bool, on any device. The former
    # isinstance check against torch.(cuda.)BoolTensor rejected MPS/XLA tensors.
    pair_valid = valid[:, None, :] & valid[:, :, None]
    return ~pair_valid


class ScaledSinuEmbedding(nn.Module):
    """Sinusoidal position table, laid out as ``[sin | cos]`` halves and scaled by one learnable scalar.

    ``forward(x)`` reads only ``x.shape[1]`` (the sequence length) and
    ``x.device``. It returns a ``(length, 2 * ceil(dim / 2))`` table without a
    batch axis, so callers rely on broadcasting when adding it.
    """

    def __init__(self, dim):
        super().__init__()
        # One learnable multiplier shared by every position and channel.
        self.scale = nn.Parameter(torch.ones(1))
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents))

    def forward(self, x):
        positions = torch.arange(x.shape[1], device=x.device).type_as(self.inv_freq)
        angles = positions[:, None] * self.inv_freq[None, :]
        table = torch.cat([angles.sin(), angles.cos()], dim=-1)
        return table * self.scale


class CustomEmbedding(nn.Module):
    """Token embedding plus a learnably scaled sinusoidal positional encoding.

    Presumably ``x`` holds integer token ids of shape ``(batch, length)`` (the
    positional table is sized from ``x.shape[1]``); the result is
    ``(batch, length, dim)``.
    """

    def __init__(self, dim, vocab=4):
        super().__init__()
        self.embed_seq = nn.Embedding(vocab, dim)
        self.pos_enc = ScaledSinuEmbedding(dim)

    def forward(self, x):
        tokens = self.embed_seq(x)
        return tokens + self.pos_enc(tokens)


class SeqToImage(nn.Module):
    """Lift a token sequence to a pairwise "image" by outer-summing two embeddings.

    ``forward(seq, mask)``:

    - ``seq`` is assumed to be ``(batch, L)`` token ids.
    - ``mask`` is a bool ``(batch, L, L)`` tensor whose True entries are zeroed.
    - The result is channels-first: ``(batch, dim, L, L)``.
    """

    def __init__(self, dim, vocab=4):
        super().__init__()
        self.embed_h = CustomEmbedding(dim=dim, vocab=vocab)
        self.embed_w = CustomEmbedding(dim=dim, vocab=vocab)
        self.norm = nn.LayerNorm(dim)

    def forward(self, seq, mask):
        rows = self.embed_h(seq).unsqueeze(1)  # (bs, 1, L, dim)
        cols = self.embed_w(seq).unsqueeze(2)  # (bs, L, 1, dim)
        pair = self.norm(rows + cols)  # (bs, L, L, dim)
        pair.masked_fill_(mask.unsqueeze(3), 0.0)
        return pair.permute(0, 3, 1, 2)  # bs, dim, h, w


class Attn_pool(nn.Module):
    """Attention pooling over the last axis of a ``(batch, n, H, W)`` feature map.

    Two 1x1 convolutions produce values (``emb``) and logits (``attn``). The
    logits are softmaxed over W and used to sum ``emb`` over W, giving
    ``(batch, n, H)``. The softmaxed weights are returned as well.
    """

    def __init__(self, n):
        super().__init__()
        self.conv = nn.Conv2d(n, n, 1)
        self.attn = nn.Conv2d(n, n, 1)

    def forward(self, x, key_padding_mask=None):
        """Pool ``x`` over its last axis.

        Args:
            x: ``(batch, n, H, W)`` features.
            key_padding_mask: optional bool ``(batch, H, W)``. True marks
                entries to exclude. It is broadcast over the channel axis.

        Returns:
            ``(pooled, weights)``, where ``pooled`` is ``(batch, n, H)``.
            Masked weights are exactly 0. Rows whose W entries are all masked
            pool to 0 instead of NaN.
        """
        emb = self.conv(x)
        attn = self.attn(x)

        if key_padding_mask is None:
            attn = attn.softmax(dim=-1)
        else:
            mask = key_padding_mask.unsqueeze(1)
            # Filling with -inf made fully masked rows evaluate softmax as
            # 0/0 = NaN. That NaN reached the output and, via emb * attn, the
            # conv gradients. A finite minimum yields exact zeros for masked
            # entries in partially valid rows and never produces NaN.
            attn = attn.masked_fill(mask, torch.finfo(attn.dtype).min)
            attn = attn.softmax(dim=-1)
            # Fully masked rows came out uniform above; zero them explicitly.
            attn = attn.masked_fill(mask, 0.0)

        x = (emb * attn).sum(-1)
        return x, attn


class FeedForwardV5(nn.Module):
    """Pre-norm two-layer MLP head: LayerNorm -> Linear -> GELU -> Dropout -> Linear.

    Args:
        dim: input feature size (last axis).
        hidden_dim: width of the hidden layer.
        dropout: dropout probability after the activation.
        out: output feature size.
    """

    def __init__(self, dim, hidden_dim, dropout=0.2, out=2):
        super().__init__()
        layers = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, out),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


class RnaModelConvV0(nn.Module):
    """Pairwise 2-D U-Net model producing 2 outputs per sequence position.

    The token sequence is lifted to an L x L pair "image" (SeqToImage). Samples
    are then grouped by true length, so each group runs through the U-Net at its
    own 16-aligned crop. The result is attention-pooled back to one vector per
    position and projected to 2 channels by FeedForwardV5.
    """

    def __init__(self, embed_size, conv_out=8, vecob_size=4):
        # embed_size: channels of the pair image (U-Net input channels).
        # conv_out: U-Net output channels; also the pooling and head width.
        # vecob_size: token vocabulary size.
        super().__init__()
        self.seq_to_image = SeqToImage(embed_size, vocab=vecob_size)
        self.md = UnetWrapper2D(U_Net(embed_size, conv_out), conv_out)
        self.attnpool = Attn_pool(conv_out)
        self.out = FeedForwardV5(conv_out, conv_out, out=2)

    def forward(self, batch):
        # Assumes batch["mask"] is (bs, L0) with truthy entries for valid tokens
        # and batch["seq"] is (bs, L0) token ids — TODO confirm against the dataloader.
        L_seq = batch["mask"].sum(1)
        L0 = batch["mask"].shape[1]

        crop_to_original = L_seq.unique()  # sorted unique lengths
        crop_to_16 = [
            ((i // 16) + 1) * 16 for i in crop_to_original
        ]  # for each unique length, the next multiple of 16 strictly above it (an exact multiple of 16 is bumped up by a further 16)
        seq = batch["seq"][:, : crop_to_16[-1]]  # crop to the largest 16-multiple
        # NOTE(review): if L0 < crop_to_16[-1], this slice is narrower than the
        # crop, and the U-Net then sees a size that is not a multiple of 16. That
        # presumably breaks its skip concatenations — confirm the inputs are
        # padded far enough.

        # Pair mask of shape [bs, seq.shape[1], seq.shape[1]]; True marks pairs involving padding.
        square_mask = make_pair_mask(seq, L_seq)

        x = self.seq_to_image(seq, square_mask)
        # Group samples by length; idc restores the original sample order.
        x, idc = generate_batches(L_seq, x)
        x = self.md(x, crop_to_16, crop_to_original, idc)
        # Pool each row over its last axis -> (bs, conv_out, seq.shape[1]).
        x, attn = self.attnpool(x, square_mask)
        x = self.out(x.permute(0, 2, 1))
        # Pad (or crop, if negative) the position axis back to L0.
        x = F.pad(x, (0, 0, 0, L0 - x.shape[1], 0, 0))
        return x


class SELayer(nn.Module):
    """Squeeze-and-excitation gate for ``(batch, channels, length)`` tensors.

    Args:
        inp: only sets the hidden width, ``int(inp // reduction)``.
        oup: number of channels of the gated input.
        reduction: divisor applied to ``inp`` for the hidden width.

    The channel means are passed through Linear -> SiLU -> Linear -> Sigmoid.
    The resulting per-channel gate multiplies the input.
    """

    def __init__(self, inp, oup, reduction=4):
        super().__init__()
        hidden = int(inp // reduction)
        self.fc = nn.Sequential(
            nn.Linear(oup, hidden),
            nn.SiLU(),
            nn.Linear(hidden, oup),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _ = x.size()
        squeezed = x.view(b, c, -1).mean(dim=2)
        gate = self.fc(squeezed).view(b, c, 1)
        return x * gate


class Conv1D(nn.Conv1d):
    """``nn.Conv1d`` that takes channel-last input ``(batch, length, channels)``.

    Positions flagged in ``src_key_padding_mask`` (shape ``(batch, length)``,
    truthy = padding) are zeroed before the convolution. A mask passed to
    ``forward`` is stored on the module and reused by later calls that omit it.
    Parent ``nn.Sequential`` containers rely on this, assigning the attribute
    directly before calling ``forward(x)``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.src_key_padding_mask = None

    def forward(self, x, src_key_padding_mask=None):
        if src_key_padding_mask is not None:
            self.src_key_padding_mask = src_key_padding_mask
        pad = self.src_key_padding_mask
        if pad is not None:
            pad = pad.unsqueeze(-1).expand(-1, -1, x.shape[-1]).bool()
            x = torch.where(pad, torch.zeros_like(x), x)
        channels_first = x.permute(0, 2, 1)
        return super().forward(channels_first).permute(0, 2, 1)


class ResBlock(nn.Sequential):
    """Pre-norm residual block on channel-last input: ``x + Conv1D(GELU(LayerNorm(x)))``.

    The padding mask for the inner Conv1D comes from the ``forward`` argument
    when given. Otherwise it comes from ``self.src_key_padding_mask``, which
    parents such as ``Extractor`` assign directly.
    """

    def __init__(self, d_model):
        super().__init__(
            nn.LayerNorm(d_model),
            nn.GELU(),
            Conv1D(d_model, d_model, 3, padding=1),
        )
        self.src_key_padding_mask = None

    def forward(self, x, src_key_padding_mask=None):
        if src_key_padding_mask is None:
            mask = self.src_key_padding_mask
        else:
            mask = src_key_padding_mask
        self[-1].src_key_padding_mask = mask
        return x + super().forward(x)


class Extractor(nn.Sequential):
    """Token stem: Embedding(in_ch, d_model // 4) -> Conv1D(k=7) up to d_model -> ResBlock.

    Presumably ``x`` is ``(batch, length)`` token ids. The output is
    channel-last ``(batch, length, d_model)``. The padding mask is pushed into
    the Conv1D and the ResBlock on every call (``None`` clears it).
    """

    def __init__(self, d_model, in_ch=4):
        super().__init__(
            nn.Embedding(in_ch, d_model // 4),
            Conv1D(d_model // 4, d_model, 7, padding=3),
            ResBlock(d_model),
        )

    def forward(self, x, src_key_padding_mask=None):
        conv, res = self[1], self[2]
        conv.src_key_padding_mask = src_key_padding_mask
        res.src_key_padding_mask = src_key_padding_mask
        return super().forward(x)


class LocalBlock(nn.Module):
    """Channels-first ``Conv1d(ks, "same", no bias) -> BatchNorm1d -> activation()``.

    Args:
        in_ch: input channels.
        ks: kernel size.
        activation: activation class; it is instantiated here.
        out_ch: output channels (defaults to ``in_ch``).
    """

    def __init__(self, in_ch, ks, activation, out_ch=None):
        super().__init__()
        self.in_ch = in_ch
        self.out_ch = in_ch if out_ch is None else out_ch
        self.ks = ks

        conv = nn.Conv1d(
            in_channels=in_ch,
            out_channels=self.out_ch,
            kernel_size=ks,
            padding="same",
            bias=False,
        )
        self.block = nn.Sequential(conv, nn.BatchNorm1d(self.out_ch), activation())

    def forward(self, x):
        return self.block(x)


class EffBlock(nn.Module):
    """EfficientNetV2 / LegNet-style inverted-bottleneck block on ``(batch, C, L)`` input.

    Layers:

    1. 1x1 expand to ``inner_dim``, then BN and activation.
    2. Grouped ``ks`` conv (``filter_per_group`` channels per group), then BN
       and activation.
    3. Squeeze-excitation gate.
    4. 1x1 projection back to ``in_ch``, then BN and activation.

    The output always has ``in_ch`` channels; ``out_ch`` only influences
    ``inner_dim``. ``se_type`` is accepted but unused.

    ``inner_dim_calculation`` selects the base for ``inner_dim``:

    - ``"out"``: ``out_ch * resize_factor`` (the original LegNet rule).
    - ``"in"``: ``in_ch * resize_factor`` (the original EfficientNetV2 rule).
    """

    def __init__(
        self,
        in_ch,
        ks,
        resize_factor,
        filter_per_group,
        activation,
        out_ch=None,
        se_reduction=None,
        se_type="simple",
        inner_dim_calculation="out",
    ):
        super().__init__()
        self.in_ch = in_ch
        self.out_ch = in_ch if out_ch is None else out_ch
        self.resize_factor = resize_factor
        self.se_reduction = resize_factor if se_reduction is None else se_reduction
        self.ks = ks
        self.inner_dim_calculation = inner_dim_calculation

        if inner_dim_calculation == "out":
            base_channels = self.out_ch
        elif inner_dim_calculation == "in":
            base_channels = self.in_ch
        else:
            raise Exception(f"Wrong inner_dim_calculation: {inner_dim_calculation}")
        self.inner_dim = base_channels * self.resize_factor

        self.filter_per_group = filter_per_group
        wide = self.inner_dim

        # Creation order matches the original, so seeded initialisation is unchanged.
        self.block = nn.Sequential(
            nn.Conv1d(self.in_ch, wide, kernel_size=1, padding="same", bias=False),
            nn.BatchNorm1d(wide),
            activation(),
            nn.Conv1d(
                wide,
                wide,
                kernel_size=ks,
                groups=wide // self.filter_per_group,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm1d(wide),
            activation(),
            # NOTE(review): SELayer's hidden width is derived from in_ch, not inner_dim.
            SELayer(self.in_ch, wide, reduction=self.se_reduction),
            nn.Conv1d(wide, self.in_ch, kernel_size=1, padding="same", bias=False),
            nn.BatchNorm1d(self.in_ch),
            activation(),
        )

    def forward(self, x):
        return self.block(x)


"""
The `activation()` in the optimized architecture simply equals `nn.Identity`
In the original LegNet approach it was `nn.SiLU`
"""


class MappingBlock(nn.Module):
    """Channels-first 1x1 ``Conv1d`` (with bias) followed by ``activation()``."""

    def __init__(self, in_ch, out_ch, activation):
        super().__init__()
        pointwise = nn.Conv1d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=1,
            padding="same",
        )
        self.block = nn.Sequential(pointwise, activation())

    def forward(self, x):
        return self.block(x)


class ResidualConcat(nn.Module):
    """Concatenate ``fn(x, **kwargs)`` with ``x`` along dim 1; ``fn``'s output comes first."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        transformed = self.fn(x, **kwargs)
        return torch.cat((transformed, x), dim=1)


import torch.nn.functional as F

from typing import Type


class LegNet(nn.Module):
    """LegNet-style 1-D CNN trunk: a LocalBlock stem plus EffBlock/ResidualConcat stages.

    Input is channel-last ``(batch, length, input_size)``. The output is
    channel-last ``(batch, length, block_sizes[-1])``.

    Args:
        input_size: channels of the incoming features.
        use_single_channel, use_reverse_channel, res_block_type: stored only;
            not used by this implementation.
        block_sizes: width of the stem and of each subsequent stage.
        ks: convolution kernel size.
        resize_factor: EffBlock channel expansion factor.
        activation: activation class instantiated inside every block.
        final_activation: activation of ``self.mapper``.
        filter_per_group: channels per group in the EffBlock grouped conv.
        se_reduction: squeeze-excitation reduction passed to every EffBlock.
        se_type, inner_dim_calculation: forwarded to every EffBlock.
    """

    # TorchScript expects an iterable of attribute names; a bare string would
    # be iterated character by character.
    __constants__ = ["resize_factor"]

    def __init__(
        self,
        input_size: int,
        use_single_channel: bool,
        use_reverse_channel: bool,
        block_sizes: list[int] = [256, 128, 128, 64, 64, 64, 64],
        ks: int = 7,
        resize_factor: int = 4,
        activation: Type[nn.Module] = nn.SiLU,
        final_activation: Type[nn.Module] = nn.Identity,
        filter_per_group: int = 1,
        se_reduction: int = 4,
        res_block_type: str = "concat",
        se_type: str = "simple",
        inner_dim_calculation: str = "in",
    ):
        super().__init__()
        self.input_size = input_size
        self.block_sizes = block_sizes
        self.resize_factor = resize_factor
        self.se_reduction = se_reduction
        self.use_single_channel = use_single_channel
        self.use_reverse_channel = use_reverse_channel
        self.filter_per_group = filter_per_group
        self.final_ch = 2  # number of bins in the competition
        self.inner_dim_calculation = inner_dim_calculation
        self.res_block_type = res_block_type

        residual = ResidualConcat

        self.stem_block = LocalBlock(
            in_ch=self.in_channels, out_ch=block_sizes[0], ks=ks, activation=activation
        )

        blocks = []
        for prev_sz, sz in zip(block_sizes[:-1], block_sizes[1:]):
            block = nn.Sequential(
                residual(
                    EffBlock(
                        in_ch=prev_sz,
                        out_ch=sz,
                        ks=ks,
                        # These were hard-coded (resize_factor=4, se_reduction
                        # left to default), silently ignoring the constructor
                        # arguments. With the defaults (4 and 4) the result is
                        # unchanged.
                        resize_factor=self.resize_factor,
                        se_reduction=self.se_reduction,
                        activation=activation,
                        filter_per_group=self.filter_per_group,
                        se_type=se_type,
                        inner_dim_calculation=inner_dim_calculation,
                    )
                ),
                # EffBlock projects back to prev_sz channels; concatenating with
                # its input doubles that width.
                LocalBlock(in_ch=2 * prev_sz, out_ch=sz, ks=ks, activation=activation),
            )
            blocks.append(block)

        self.main = nn.Sequential(*blocks)

        # NOTE: built (and present in state_dict) but not applied in forward();
        # callers project block_sizes[-1] channels themselves.
        self.mapper = MappingBlock(
            in_ch=block_sizes[-1], out_ch=self.final_ch, activation=final_activation
        )

    @property
    def in_channels(self) -> int:
        """Number of input feature channels (alias of ``input_size``)."""
        return self.input_size

    def forward(self, x):
        """Map ``(batch, length, input_size)`` to ``(batch, length, block_sizes[-1])``."""
        x = x.permute(0, 2, 1)  # channels-first for the Conv1d stack
        x = self.stem_block(x)
        x = self.main(x)
        x = x.permute(0, 2, 1)
        return x


class RnaModelConvV1(nn.Module):
    """Extractor -> LegNet trunk -> Linear(block_sizes[-1], 2) per position.

    ``forward`` reads ``x0["mask"]`` and ``x0["seq"]``. It crops both to the
    longest masked length in the batch, then zero-pads the ``(batch, length, 2)``
    prediction back to the mask's original length.
    """

    def __init__(self, dim=192, block_sizes=[256, 128, 128, 64, 64, 64, 64]):
        super().__init__()
        self.extractor = Extractor(dim)
        self.cnn_model = LegNet(
            input_size=dim,
            use_single_channel=False,
            use_reverse_channel=False,
            block_sizes=block_sizes,
        )
        self.proj_out = nn.Sequential(nn.Linear(block_sizes[-1], 2))

    def forward(self, x0):
        full_mask = x0["mask"]
        L0 = full_mask.shape[1]
        Lmax = full_mask.sum(-1).max()
        mask = full_mask[:, :Lmax]
        tokens = x0["seq"][:, :Lmax]

        # Padding mask convention inside the trunk: True = padding.
        feats = self.extractor(tokens, src_key_padding_mask=~mask)
        preds = self.proj_out(self.cnn_model(feats))
        return F.pad(preds, (0, 0, 0, L0 - Lmax, 0, 0))


class EffBlockV2(nn.Sequential):
    """Channel-last, LayerNorm variant of EffBlock with dropout and padding masks.

    Layout (four layers per stage):

        0 Conv1D 1x1 expand   -> 1 LayerNorm  -> 2 act  -> 3 Dropout
        4 Conv1D ks grouped   -> 5 LayerNorm  -> 6 act  -> 7 Dropout
        8 Conv1D 1x1 project  -> 9 LayerNorm  -> 10 act -> 11 Dropout

    Input and output are ``(batch, length, channels)``. The output has ``in_ch``
    channels; ``out_ch`` only influences the expansion width. ``se_reduction``
    and ``se_type`` are accepted for signature compatibility but unused.
    """

    def __init__(
        self,
        in_ch,
        ks,
        resize_factor,
        filter_per_group,
        activation,
        out_ch=None,
        se_reduction=None,
        se_type="simple",
        inner_dim_calculation="out",
        dropout=0.2,
    ):
        # Plain (non-Module) attributes may be set before Module.__init__.
        self.in_ch = in_ch
        self.out_ch = self.in_ch if out_ch is None else out_ch
        self.resize_factor = resize_factor
        self.se_reduction = resize_factor if se_reduction is None else se_reduction
        self.ks = ks
        self.inner_dim_calculation = inner_dim_calculation

        if inner_dim_calculation == "out":
            self.inner_dim = self.out_ch * self.resize_factor
        elif inner_dim_calculation == "in":
            self.inner_dim = self.in_ch * self.resize_factor
        else:
            raise Exception(f"Wrong inner_dim_calculation: {inner_dim_calculation}")

        self.filter_per_group = filter_per_group

        super().__init__(
            Conv1D(
                in_channels=self.in_ch,
                out_channels=self.inner_dim,
                kernel_size=1,
                padding="same",
                bias=False,
            ),
            nn.LayerNorm(self.inner_dim),
            activation(),
            nn.Dropout(dropout),
            Conv1D(
                in_channels=self.inner_dim,
                out_channels=self.inner_dim,
                kernel_size=ks,
                groups=self.inner_dim // self.filter_per_group,
                padding="same",
                bias=False,
            ),
            nn.LayerNorm(self.inner_dim),
            activation(),
            nn.Dropout(dropout),
            Conv1D(
                in_channels=self.inner_dim,
                out_channels=self.in_ch,
                kernel_size=1,
                padding="same",
                bias=False,
            ),
            nn.LayerNorm(self.in_ch),
            activation(),
            nn.Dropout(dropout),
        )

        self.src_key_padding_mask = None

    def forward(self, x, src_key_padding_mask=None):
        """Run the block, zeroing padded positions before every convolution.

        The mask is reassigned on every call (including ``None``), so a mask
        from a previous batch is never reused.
        """
        # Give the mask to every Conv1D child (indices 0, 4 and 8). The former
        # hard-coded [0, 3, 6] reached a Conv1D, a Dropout and an activation.
        # The grouped ks-wide conv and the projection therefore ran unmasked,
        # letting padded positions (which can be non-zero after LayerNorm's
        # bias) leak into valid ones.
        for layer in self:
            if isinstance(layer, Conv1D):
                layer.src_key_padding_mask = src_key_padding_mask
        return super().forward(x)


class LocalBlockV2(nn.Sequential):
    """Channel-last ``Conv1D(ks, "same", no bias) -> LayerNorm -> activation() -> Dropout``.

    The padding mask passed to ``forward`` is assigned to the Conv1D on every
    call (``None`` clears it).
    """

    def __init__(self, in_ch, ks, activation, dropout=0.2, out_ch=None):
        self.in_ch = in_ch
        self.out_ch = in_ch if out_ch is None else out_ch
        self.ks = ks

        conv = Conv1D(
            in_channels=in_ch,
            out_channels=self.out_ch,
            kernel_size=ks,
            padding="same",
            bias=False,
        )
        super().__init__(conv, nn.LayerNorm(self.out_ch), activation(), nn.Dropout(dropout))

    def forward(self, x, src_key_padding_mask=None):
        self[0].src_key_padding_mask = src_key_padding_mask
        return super().forward(x)


class ConvolutionConcatBlockV2(nn.Module):
    """Channel-last ``EffBlockV2`` whose output is concatenated with its input and fused by ``LocalBlockV2``.

    EffBlockV2 returns ``in_ch`` channels. The concatenation therefore has
    ``2 * in_ch`` channels, and the fusion maps them to ``out_ch``.
    """

    def __init__(
        self,
        in_ch=256,
        ks=7,
        resize_factor=4,
        filter_per_group=1,
        activation=nn.GELU,
        out_ch=None,
        dropout=0.2,
    ):
        super().__init__()
        self.effblock = EffBlockV2(
            in_ch=in_ch,
            ks=ks,
            resize_factor=resize_factor,
            filter_per_group=filter_per_group,
            activation=activation,
            out_ch=out_ch,
            dropout=dropout,
        )
        self.localblock = LocalBlockV2(
            in_ch=2 * in_ch,
            ks=ks,
            activation=activation,
            out_ch=out_ch,
            dropout=dropout,
        )

    def forward(self, x, src_key_padding_mask=None):
        branch = self.effblock(x, src_key_padding_mask=src_key_padding_mask)
        fused_input = torch.cat([branch, x], dim=-1)
        return self.localblock(fused_input, src_key_padding_mask=src_key_padding_mask)


class CustomConvdV2(nn.Module):
    """Channel-last CNN trunk: a LocalBlockV2 stem plus a chain of ConvolutionConcatBlockV2 stages.

    Args:
        dim: input channels.
        block_sizes: width of the stem and of each subsequent stage.
        ks: convolution kernel size.
        resize_factor: EffBlockV2 expansion factor.
        activation: activation class used in every block.
        dropout: dropout probability for the stem and every stage.
    """

    def __init__(
        self,
        dim: int,
        block_sizes: list[int] = [256, 128, 128, 64, 64, 64, 64],
        ks: int = 7,
        resize_factor: int = 4,
        activation: Type[nn.Module] = nn.SiLU,
        dropout=0.2,
    ):
        super().__init__()

        # `dropout` is now forwarded to the stem as well. It used to fall back
        # to LocalBlockV2's own default of 0.2 even when callers passed 0.0.
        # The default here is also 0.2, so default construction is unchanged.
        self.stem_block = LocalBlockV2(
            in_ch=dim, out_ch=block_sizes[0], ks=ks, activation=activation, dropout=dropout
        )

        self.blocks = nn.ModuleList(
            ConvolutionConcatBlockV2(
                in_ch=prev_sz,
                out_ch=sz,
                ks=ks,
                resize_factor=resize_factor,
                activation=activation,
                dropout=dropout,
            )
            for prev_sz, sz in zip(block_sizes[:-1], block_sizes[1:])
        )

    def forward(self, x, src_key_padding_mask=None):
        """Apply the stem and every stage with the same padding mask."""
        x = self.stem_block(x, src_key_padding_mask=src_key_padding_mask)
        for block in self.blocks:
            x = block(x, src_key_padding_mask=src_key_padding_mask)
        return x


class RnaModelConvV2(nn.Module):
    """Extractor -> CustomConvdV2 trunk -> pair-aware transformer blocks -> Linear(., 2).

    ``forward`` reads these keys from ``x0``: ``"mask"``, ``"seq"``,
    ``"bb_matrix_full_prob"``, ``"bb_matrix_full_prob_extra"`` and ``"ss_adj"``.
    Everything is cropped to the longest masked length in the batch. The
    prediction is zero-padded back to the mask's original length.
    """

    def __init__(
        self,
        dim=192,
        resize_factor=4,
        ks=7,
        activation=nn.SiLU,
        head_size=32,
        drop_pat_dropout=0.2,
        dropout=0.2,
        bpp_transfomer_depth=3,
    ):
        super().__init__()
        block_sizes = [dim, dim, dim, dim, dim]
        self.extractor = Extractor(dim)
        self.cnn_model = CustomConvdV2(
            dim=dim,
            block_sizes=block_sizes,
            resize_factor=resize_factor,
            ks=ks,
            activation=activation,
            dropout=0.0,
        )

        # Stochastic depth grows linearly from 0 to drop_pat_dropout. The
        # max(..., 1) guard avoids ZeroDivisionError when depth == 1 (the single
        # block then gets drop_path 0). Deeper stacks are unchanged.
        denom = max(bpp_transfomer_depth - 1, 1)
        self.bb_comb_blocks = nn.ModuleList(
            [
                CombinationTransformerEncoderV1(
                    dim // 2,
                    head_size=head_size,
                    dropout=dropout,
                    drop_path=drop_pat_dropout * (i / denom),
                )
                for i in range(bpp_transfomer_depth)
            ]
        )

        # NOTE(review): the CNN trunk emits `dim` channels, but the transformer
        # blocks and this head are built for `dim // 2`. Confirm that
        # CombinationTransformerEncoderV1 accounts for that width.
        self.proj_out = nn.Linear(dim // 2, 2)

    def forward(self, x0):
        mask = x0["mask"]
        L0 = mask.shape[1]
        Lmax = mask.sum(-1).max()
        mask = mask[:, :Lmax]
        x = x0["seq"][:, :Lmax]

        bpp = x0["bb_matrix_full_prob"][:, :Lmax, :Lmax]
        bpp_extra = x0["bb_matrix_full_prob_extra"][:, :Lmax, :Lmax].float()
        ss = x0["ss_adj"][:, :Lmax, :Lmax].float()

        # The CNN stages expect True = padding; the transformer blocks get the raw mask.
        x = self.extractor(x, src_key_padding_mask=~mask)
        x = self.cnn_model(x, src_key_padding_mask=~mask)

        for blk in self.bb_comb_blocks:
            x = blk(x, bpp, bpp_extra, ss, mask)

        x = self.proj_out(x)
        x = F.pad(x, (0, 0, 0, L0 - Lmax, 0, 0))
        return x


class ConvolutionConcatBlockV3(nn.Module):
    """Residual block: ``x + t_block(EffBlock(x))`` on channel-last ``(batch, length, in_ch)``.

    EffBlock runs channels-first (hence the permutes) and returns ``in_ch``
    channels. ``t_block`` is a CombinationTransformerEncoderV1 fed the pair
    features ``bpp``, ``bpp_extra``, ``ss`` and the ``mask``.

    Args:
        resize_factor: EffBlock expansion factor.
        dropout: accepted for signature compatibility; unused.
        dropout_trasnfomer: dropout passed to the transformer block.
        drop_path: stochastic-depth rate passed to the transformer block.
    """

    def __init__(
        self,
        in_ch=256,
        ks=7,
        resize_factor=4,
        filter_per_group=1,
        activation=nn.GELU,
        out_ch=None,
        dropout=0.2,
        head_size=32,
        dropout_trasnfomer=0.2,
        drop_path=0.2,
    ):
        super().__init__()
        self.effblock = EffBlock(
            in_ch=in_ch,
            out_ch=out_ch,
            ks=ks,
            # Was hard-coded to 4, ignoring the argument; the default is still 4.
            resize_factor=resize_factor,
            activation=activation,
            filter_per_group=filter_per_group,
            se_type="simple",
            inner_dim_calculation="out",
        )
        self.t_block = CombinationTransformerEncoderV1(
            in_ch,
            head_size=head_size,
            dropout=dropout_trasnfomer,
            drop_path=drop_path,
        )

    def forward(self, x, bpp, bpp_extra, ss, mask):
        res = x
        x = self.effblock(x.permute(0, 2, 1)).permute(0, 2, 1)
        x = self.t_block(x, bpp, bpp_extra, ss, mask)
        return x + res

# class RnaModelConvV3(nn.Module):
#     def __init__(
#         self,
#         dim=192,
#         resize_factor=4,
#         ks=7,
#         activation=nn.SiLU,
#         head_size=32,
#         drop_pat_dropout=0.2,
#         dropout=0.2,
#         transformer_depth=6,
#     ):
#         super().__init__()
#         block_sizes = [dim, dim, dim, dim, dim]
#         self.extractor = Extractor(dim)

#         self.blocks = nn.ModuleList()
#         for ind, (prev_sz, sz) in enumerate(zip(block_sizes[:-1], block_sizes[1:])):
#             block = ConvolutionConcatBlockV3(
#                 in_ch=prev_sz,
#                 out_ch=sz,
#                 ks=ks,
#                 resize_factor=resize_factor,
#                 activation=activation,
#                 dropout=dropout,
#                 head_size=32,
#                 dropout_trasnfomer=0.2,
#                 drop_path=drop_pat_dropout * (ind / (len(block_sizes) - 1)),
#             )
#             self.blocks.append(block)

#         self.enc = ContinuousTransformerWrapper(
#             dim_in=dim * 2,
#             dim_out=2,
#             max_seq_len=512,
#             attn_layers=Encoder(
#                 dim=dim,
#                 depth=transformer_depth,
#                 attn_flash=True,
#                 rotary_pos_emb=True,
#                 attn_gate_values=True,
#                 attn_head_scale=True,
#                 ff_post_act_ln=True,
#                 attn_qk_norm=True,
#                 attn_qk_norm_dim_scale=True,
#                 layer_dropout=drop_pat_dropout,  # stochastic depth - dropout entire layer
#                 attn_dropout=dropout,  # dropout post-attention
#             ),
#         )

#     def forward(self, x0):
#         mask = x0["mask"]
#         L0 = mask.shape[1]
#         Lmax = mask.sum(-1).max()
#         mask = mask[:, :Lmax]
#         x = x0["seq"][:, :Lmax]

#         bpp = x0["bb_matrix_full_prob"][:, :Lmax, :Lmax]
#         bpp_extra = x0["bb_matrix_full_prob_extra"][:, :Lmax, :Lmax].float()
#         ss = x0["ss_adj"][:, :Lmax, :Lmax].float()

#         x = self.extractor(x, src_key_padding_mask=~mask)
#         res = x
#         for i, blk in enumerate(self.blocks):
#             x = blk(x, bpp, bpp_extra, ss, mask)
#         x = torch.concat([x, res], -1)
#         x = self.enc(x, mask=mask)
#         x = F.pad(x, (0, 0, 0, L0 - Lmax, 0, 0))
#         return x

class RnaModelConvV3(nn.Module):
    """Extractor -> residual EffBlock/transformer stages -> x-transformers encoder -> 2 outputs.

    ``forward`` reads these keys from ``x0``: ``"mask"``, ``"seq"``,
    ``"bb_matrix_full_prob"``, ``"bb_matrix_full_prob_extra"`` and ``"ss_adj"``.
    Everything is cropped to the longest masked length in the batch. The
    encoder's ``dim_out=2`` prediction is zero-padded back to the mask's
    original length.
    """

    def __init__(
        self,
        dim=192,
        resize_factor=4,
        ks=7,
        activation=nn.SiLU,
        head_size=32,
        drop_pat_dropout=0.2,
        dropout=0.2,
        transformer_depth=6,
    ):
        super().__init__()
        block_sizes = [dim, dim, dim, dim, dim]
        self.extractor = Extractor(dim)

        self.blocks = nn.ModuleList()
        for ind, (prev_sz, sz) in enumerate(zip(block_sizes[:-1], block_sizes[1:])):
            block = ConvolutionConcatBlockV3(
                in_ch=prev_sz,
                out_ch=sz,
                ks=ks,
                resize_factor=resize_factor,
                activation=activation,
                dropout=dropout,
                # Was hard-coded to 32, ignoring the constructor argument;
                # the default is still 32.
                head_size=head_size,
                # NOTE(review): transformer dropout stays fixed at 0.2,
                # independent of `dropout`.
                dropout_trasnfomer=0.2,
                drop_path=drop_pat_dropout * (ind / (len(block_sizes) - 1)),
            )
            self.blocks.append(block)

        self.enc = ContinuousTransformerWrapper(
            dim_in=dim,
            dim_out=2,
            max_seq_len=512,
            attn_layers=Encoder(
                dim=dim,
                depth=transformer_depth,
                attn_flash=True,
                rotary_pos_emb=True,
                attn_gate_values=True,
                attn_head_scale=True,
                ff_post_act_ln=True,
                attn_qk_norm=True,
                attn_qk_norm_dim_scale=True,
                layer_dropout=drop_pat_dropout,  # stochastic depth - dropout entire layer
                attn_dropout=dropout,  # dropout post-attention
            ),
        )

    def forward(self, x0):
        mask = x0["mask"]
        L0 = mask.shape[1]
        Lmax = mask.sum(-1).max()
        mask = mask[:, :Lmax]
        x = x0["seq"][:, :Lmax]

        bpp = x0["bb_matrix_full_prob"][:, :Lmax, :Lmax]
        bpp_extra = x0["bb_matrix_full_prob_extra"][:, :Lmax, :Lmax].float()
        ss = x0["ss_adj"][:, :Lmax, :Lmax].float()

        x = self.extractor(x, src_key_padding_mask=~mask)
        res = x  # long skip from the extractor to the encoder input
        for blk in self.blocks:
            x = blk(x, bpp, bpp_extra, ss, mask)
        x = self.enc(x + res, mask=mask)
        x = F.pad(x, (0, 0, 0, L0 - Lmax, 0, 0))
        return x
    
    
    
    
class ConvolutionConcatBlockV4(nn.Module):
    """Transformer layer followed by an ``EffBlock`` convolution, wrapped in a residual.

    Args:
        in_ch: feature width of the input sequence.
        ks: convolution kernel size forwarded to ``EffBlock``.
        resize_factor: expansion factor forwarded to ``EffBlock``. The original
            code passed a literal ``4`` and ignored this argument.
        filter_per_group: forwarded to ``EffBlock``.
        activation: activation class forwarded to ``EffBlock``.
        out_ch: output width of ``EffBlock``. It must equal ``in_ch`` for the
            residual addition in ``forward`` to be shape-compatible.
        dropout: accepted for interface compatibility; not used by this block.
        head_size: per-head width for ``CombinationTransformerEncoderV29``.
        dropout_trasnfomer: dropout passed to the transformer layer.
        drop_path: stochastic-depth rate passed to the transformer layer.
    """

    def __init__(
        self,
        in_ch=256,
        ks=7,
        resize_factor=4,
        filter_per_group=1,
        activation=nn.GELU,
        out_ch=None,
        dropout=0.2,
        head_size=32,
        dropout_trasnfomer=0.2,
        drop_path=0.2,
    ):
        super().__init__()
        self.effblock = EffBlock(
            in_ch=in_ch,
            out_ch=out_ch,
            ks=ks,
            resize_factor=resize_factor,
            activation=activation,
            filter_per_group=filter_per_group,
            se_type="simple",
            inner_dim_calculation="out",
        )
        self.t_block = CombinationTransformerEncoderV29(
            in_ch,
            head_size=head_size,
            dropout=dropout_trasnfomer,
            drop_path=drop_path,
        )

    def forward(self, x, bpp, mask):
        """Apply transformer then convolution; add the block input back.

        ``x`` has its dims 1 and 2 swapped around ``EffBlock``. NOTE(review):
        presumably ``x`` is (B, L, C) and ``EffBlock`` expects (B, C, L); confirm.
        """
        residual = x
        x = self.t_block(x, bpp, mask)
        x = self.effblock(x.permute(0, 2, 1)).permute(0, 2, 1)
        return x + residual


class RnaModelConvV4(nn.Module):
    """Extractor -> 5 conv/attention blocks -> ``Block_conv`` stack -> 2-way linear head.

    Args:
        dim: model width used throughout.
        resize_factor: expansion factor forwarded to each ``ConvolutionConcatBlockV4``.
        ks: convolution kernel size for the conv blocks.
        activation: activation class for the conv blocks.
        head_size: per-head width. It is used both for the conv blocks' attention
            (previously hard-coded to 32 there) and for ``t_blocks``
            (``num_heads = dim // head_size``).
        drop_pat_dropout: maximum stochastic-depth rate. It is ramped linearly
            from 0 across the conv blocks and across ``t_blocks``.
        dropout: dropout for ``t_blocks``; also forwarded to the conv blocks.
        transformer_depth: number of ``Block_conv`` layers.
        transformer_dropout: dropout inside each conv block's transformer layer
            (previously a hard-coded 0.2; the default keeps that value).
    """

    def __init__(
        self,
        dim=192,
        resize_factor=4,
        ks=7,
        activation=nn.SiLU,
        head_size=32,
        drop_pat_dropout=0.2,
        dropout=0.2,
        transformer_depth=12,
        transformer_dropout=0.2,
    ):
        super().__init__()
        block_sizes = [dim, dim, dim, dim, dim, dim]
        self.extractor = Extractor(dim)

        self.blocks = nn.ModuleList()
        n_conv_steps = len(block_sizes) - 1
        for ind, (prev_sz, sz) in enumerate(zip(block_sizes[:-1], block_sizes[1:])):
            block = ConvolutionConcatBlockV4(
                in_ch=prev_sz,
                out_ch=sz,
                ks=ks,
                resize_factor=resize_factor,
                activation=activation,
                dropout=dropout,
                head_size=head_size,
                dropout_trasnfomer=transformer_dropout,
                drop_path=drop_pat_dropout * (ind / n_conv_steps),
            )
            self.blocks.append(block)

        # Linear stochastic-depth ramp. The max() guards transformer_depth == 1,
        # which previously divided by zero.
        ramp_denominator = max(transformer_depth - 1, 1)
        self.t_blocks = nn.ModuleList(
            [
                Block_conv(
                    dim=dim,
                    num_heads=dim // head_size,
                    mlp_ratio=4,
                    drop_path=drop_pat_dropout * (i / ramp_denominator),
                    init_values=1,
                    drop=dropout,
                )
                for i in range(transformer_depth)
            ]
        )

        # Kept as a Sequential so existing checkpoints ("proj_out.0.*") still load.
        self.proj_out = nn.Sequential(nn.Linear(dim, 2))

    def forward(self, x0):
        """Predict 2 values per position from a padded batch dict.

        Reads ``x0["mask"]``, ``x0["seq"]`` and ``x0["bb_matrix_full_prob_extra"]``.
        Tensors are trimmed to the longest real sequence, processed, and the
        output is zero-padded back to the original length along dim 1.

        NOTE(review): assumes ``mask`` is a boolean (B, L) tensor (``~mask`` is
        used as a padding mask). Confirm against the dataloader.
        """
        mask = x0["mask"]
        L0 = mask.shape[1]
        # Plain int: used for slicing and as an F.pad amount.
        Lmax = int(mask.sum(-1).max())
        mask = mask[:, :Lmax]
        x = x0["seq"][:, :Lmax]
        bpp_extra = x0["bb_matrix_full_prob_extra"][:, :Lmax, :Lmax].float()

        x = self.extractor(x, src_key_padding_mask=~mask)
        for blk in self.blocks:
            x = blk(x, bpp_extra, mask)
        for blk in self.t_blocks:
            x = blk(x, key_padding_mask=~mask)

        x = self.proj_out(x)
        return F.pad(x, (0, 0, 0, L0 - Lmax, 0, 0))

In [4]:
# Smoke test: build a double-width RnaModelConvV4 (dim=384) in eval mode and run
# one gradient-free forward pass on a batch saved to disk.
embed_size = 16  # NOTE(review): not used by any visible cell; confirm before removing
out_chan_unet = 32  # NOTE(review): not used by any visible cell; confirm before removing
md = RnaModelConvV4(dim=192*2).eval()
# NOTE(review): torch.load unpickles the file; only load trusted data.
batch = torch.load("batch.pt")
with torch.no_grad():
    out = md(batch)

In [5]:
#|hide
#|eval: false
from nbdev.doclinks import nbdev_export
# Write every `#| export` cell of this notebook into the library module (default_exp: modelsconv).
nbdev_export()

In [6]:
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input

In [7]:
class FlashAttention2d(nn.Module):
    """Multi-head self-attention run independently along each row of a square pair tensor.

    ``pair_act`` of shape (B, S, S, model_dim) is reshaped into ``B * S``
    sequences of length ``S``. Masked positions are removed with
    ``unpad_input``, the variable-length FlashAttention kernel is applied, and
    the result is scattered back with ``pad_input`` before the output
    projection.

    Args:
        model_dim: feature width; must be divisible by ``num_head``.
        num_head: number of attention heads.
        softmax_scale: if truthy, use ``head_dim ** -0.5``; otherwise pass
            ``None`` to the kernel.
        zero_init: if True, initialise ``out_proj.weight`` to zero.
        use_bias: whether the linear layers carry biases (initialised to 0).
        initializer_range: std of the normal init for ``Wqkv``. ``out_proj``
            uses ``initializer_range / sqrt(2 * n_layers)`` when not zero-initialised.
        n_layers: depth used to scale the ``out_proj`` init.
    """

    def __init__(
        self,
        model_dim,
        num_head,
        softmax_scale,
        zero_init,
        use_bias,
        initializer_range,
        n_layers,
    ):
        super().__init__()
        assert model_dim % num_head == 0, "model_dim must be divisible by num_head"
        self.key_dim = model_dim // num_head
        self.value_dim = model_dim // num_head

        self.causal = False
        self.checkpointing = False

        if softmax_scale:
            self.softmax_scale = self.key_dim ** (-0.5)
        else:
            self.softmax_scale = None

        self.num_head = num_head

        self.Wqkv = nn.Linear(model_dim, 3 * model_dim, bias=use_bias)

        self.out_proj = nn.Linear(model_dim, model_dim, bias=use_bias)

        self.initialize(zero_init, use_bias, initializer_range, n_layers)

    def initialize(self, zero_init, use_bias, initializer_range, n_layers):
        """(Re-)initialise the projection weights and biases; see the class docstring."""
        nn.init.normal_(self.Wqkv.weight, mean=0.0, std=initializer_range)

        if use_bias:
            nn.init.constant_(self.Wqkv.bias, 0.0)
            nn.init.constant_(self.out_proj.bias, 0.0)

        if zero_init:
            nn.init.constant_(self.out_proj.weight, 0.0)
        else:
            nn.init.normal_(
                self.out_proj.weight,
                mean=0.0,
                std=initializer_range / math.sqrt(2 * n_layers),
            )

    def forward(self, pair_act, attention_mask):
        """Attend along the last spatial axis of ``pair_act``.

        Args:
            pair_act: (B, S, S, model_dim) tensor.
            attention_mask: (B, S, S) boolean tensor that is True at positions
                to EXCLUDE. It is inverted here because ``unpad_input`` expects
                positions to keep.

        Returns:
            (B, S, S, model_dim) tensor after ``out_proj``.
        """
        batch_size = pair_act.shape[0]
        seqlen = pair_act.shape[1]
        extended_batch_size = batch_size * seqlen

        qkv = self.Wqkv(pair_act)
        keep_mask = torch.logical_not(attention_mask)

        x_qkv = rearrange(
            qkv, "b s f ... -> (b s) f ...", b=batch_size, f=seqlen, s=seqlen
        )
        keep_mask = rearrange(
            keep_mask,
            "b s f ... -> (b s) f ...",
            b=batch_size,
            f=seqlen,
            s=seqlen,
        )

        # Newer flash_attn releases return an extra trailing value from
        # unpad_input; only the first four are needed here.
        x_unpad, indices, cu_seqlens, max_s = unpad_input(x_qkv, keep_mask)[:4]
        x_unpad = rearrange(
            x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=self.num_head
        )

        output_unpad = flash_attn_varlen_qkvpacked_func(
            x_unpad,
            cu_seqlens,
            max_s,
            0.0,  # attention dropout
            softmax_scale=self.softmax_scale,
            causal=self.causal,
        )

        pre_pad_latent = rearrange(output_unpad, "nnz h d -> nnz (h d)")
        padded_latent = pad_input(pre_pad_latent, indices, extended_batch_size, seqlen)
        output = rearrange(
            padded_latent, "(b s) f hd -> b s f hd", b=batch_size, s=seqlen, f=seqlen
        )

        return self.out_proj(output)


class TriangleAttention(nn.Module):
    """LayerNorm followed by ``FlashAttention2d`` over either rows or columns of a pair tensor.

    Args:
        model_dim: feature width of the pair tensor.
        num_head: number of attention heads.
        orientation: ``"per_row"`` attends along the last spatial axis.
            ``"per_column"`` swaps the two spatial axes (and the mask) before and
            after attention.
        softmax_scale, zero_init, use_bias, initializer_range, n_layers:
            forwarded to ``FlashAttention2d``.
        precision: accepted for interface compatibility; not used by this module.
    """

    def __init__(
        self,
        model_dim,
        num_head,
        orientation,
        softmax_scale,
        precision,
        zero_init,
        use_bias,
        initializer_range,
        n_layers,
    ):
        super().__init__()

        self.model_dim = model_dim
        self.num_head = num_head

        assert orientation in ["per_row", "per_column"]
        self.orientation = orientation

        self.input_norm = nn.LayerNorm(model_dim, eps=1e-6)

        self.attn = FlashAttention2d(
            model_dim,
            num_head,
            softmax_scale,
            zero_init,
            use_bias,
            initializer_range,
            n_layers,
        )

    def forward(self, pair_act, pair_mask, cycle_infer=False):
        """Normalise then attend; returns a tensor shaped like ``pair_act``.

        In training mode (unless ``cycle_infer``), the attention call is wrapped
        in activation checkpointing to trade compute for memory.
        """
        # The notebook binds ``checkpoint`` to the torch.utils.checkpoint *module*
        # (``import torch.utils.checkpoint as checkpoint``). Calling that module
        # raised TypeError in training, so import the function explicitly.
        from torch.utils.checkpoint import checkpoint as activation_checkpoint

        assert len(pair_act.shape) == 4

        if self.orientation == "per_column":
            pair_act = torch.swapaxes(pair_act, -2, -3)
            if pair_mask is not None:
                pair_mask = torch.swapaxes(pair_mask, -1, -2)

        pair_act = self.input_norm(pair_act)

        if self.training and not cycle_infer:
            pair_act = activation_checkpoint(
                self.attn, pair_act, pair_mask, use_reentrant=True
            )
        else:
            pair_act = self.attn(pair_act, pair_mask)

        if self.orientation == "per_column":
            pair_act = torch.swapaxes(pair_act, -2, -3)

        return pair_act


class ConvFeedForward(nn.Module):
    """Convolutional feed-forward layer for channels-last 2-D pair tensors.

    ``forward`` takes (B, H, W, model_dim) and does the following:
    permute to channels-first, GroupNorm(1 group), Conv2d(model_dim -> ff_dim),
    SiLU, Conv2d(ff_dim -> model_dim), then permute back.

    Both convolutions use padding ``(kernel - 1) // 2`` (0 for a 1x1 kernel).
    This preserves H and W for odd kernel sizes.

    Args:
        model_dim: input/output channel count.
        ff_dim: hidden channel count.
        use_bias: whether the convolutions carry biases (initialised to 0).
        initializer_range: std of the normal init for ``conv1``. ``conv2`` uses
            ``initializer_range / sqrt(2 * n_layers)`` unless ``zero_init``.
        n_layers: depth used to scale the ``conv2`` init.
        kernel: square kernel size for both convolutions.
        zero_init: if True, ``conv2.weight`` is initialised to zero.
    """

    def __init__(
        self,
        model_dim,
        ff_dim,
        use_bias,
        initializer_range,
        n_layers,
        kernel,
        zero_init=True,
    ):
        super().__init__()

        self.zero_init = zero_init
        self.input_norm = nn.GroupNorm(1, model_dim)

        conv_kwargs = dict(
            kernel_size=kernel,
            padding=(kernel - 1) // 2,
            bias=use_bias,
        )
        self.conv1 = nn.Conv2d(model_dim, ff_dim, **conv_kwargs)
        self.conv2 = nn.Conv2d(ff_dim, model_dim, **conv_kwargs)

        self.act = nn.SiLU()

        self.initialize(zero_init, use_bias, initializer_range, n_layers)

    def initialize(self, zero_init, use_bias, initializer_range, n_layers):
        """(Re-)initialise conv weights and biases as described in the class docstring."""
        nn.init.normal_(self.conv1.weight, mean=0.0, std=initializer_range)

        if use_bias:
            for conv in (self.conv1, self.conv2):
                nn.init.constant_(conv.bias, 0.0)

        if zero_init:
            nn.init.constant_(self.conv2.weight, 0.0)
            return

        out_std = initializer_range / math.sqrt(2 * n_layers)
        nn.init.normal_(self.conv2.weight, mean=0.0, std=out_std)

    def forward(self, x):
        channels_first = self.input_norm(x.permute(0, 3, 1, 2))
        hidden = self.act(self.conv1(channels_first))
        return self.conv2(hidden).permute(0, 2, 3, 1)


class BlockTrinagle(nn.Module):
    """Pair-representation block with three residual updates.

    The updates are row-wise attention, then column-wise attention, then a
    convolutional feed-forward layer.

    ``config`` is read for: model_dim, num_head, softmax_scale, precision,
    zero_init, use_bias, initializer_range, n_layers, ff_factor, ff_kernel and
    resi_dropout. The two attention updates use dropout ``resi_dropout / 2``;
    the feed-forward update uses ``resi_dropout``.
    """

    def __init__(self, config):
        super().__init__()
        cfg = config

        attn_args = (
            cfg.softmax_scale,
            cfg.precision,
            cfg.zero_init,
            cfg.use_bias,
            cfg.initializer_range,
            cfg.n_layers,
        )
        self.attn_pair_row = TriangleAttention(
            cfg.model_dim, cfg.num_head, "per_row", *attn_args
        )
        self.attn_pair_col = TriangleAttention(
            cfg.model_dim, cfg.num_head, "per_column", *attn_args
        )

        attn_drop = cfg.resi_dropout / 2
        self.pair_dropout_row = nn.Dropout(p=attn_drop)
        self.pair_dropout_col = nn.Dropout(p=attn_drop)

        self.pair_transition = ConvFeedForward(
            cfg.model_dim,
            int(cfg.ff_factor * cfg.model_dim),
            use_bias=cfg.use_bias,
            kernel=cfg.ff_kernel,
            initializer_range=cfg.initializer_range,
            zero_init=cfg.zero_init,
            n_layers=cfg.n_layers,
        )

        self.res_dropout = nn.Dropout(p=cfg.resi_dropout)

    def forward(self, pair_act, pair_mask, cycle_infer=False):
        row_update = self.attn_pair_row(pair_act, pair_mask, cycle_infer)
        pair_act = pair_act + self.pair_dropout_row(row_update)

        col_update = self.attn_pair_col(pair_act, pair_mask, cycle_infer)
        pair_act = pair_act + self.pair_dropout_col(col_update)

        ff_update = self.pair_transition(pair_act)
        return pair_act + self.res_dropout(ff_update)
    

In [8]:
class config:
  """Hyper-parameters for the BlockTrinagle smoke test in this cell.

  The class object itself (not an instance) is passed to BlockTrinagle below,
  so these are read as class attributes.
  """
  cycling = False  # NOTE(review): not read by any code visible in this notebook
  ff_factor = 4  # feed-forward hidden width = int(ff_factor * model_dim)
  ff_kernel = 3  # kernel size of both ConvFeedForward convolutions
  initializer_range = 0.02  # std of the normal weight init
  model_dim = 32  # pair feature width; must match the embedding width used below
  n_layers= 1  # scales output-projection init std by 1/sqrt(2 * n_layers) when zero_init is False
  num_head= 2  # attention heads; model_dim must be divisible by this
  resi_dropout = 0.1  # feed-forward residual dropout; each attention residual uses half
  softmax_scale = True  # True -> use head_dim ** -0.5 as the attention scale
  use_bias =  True  # biases on attention projections and convolutions
  use_glu =  False  # NOTE(review): not read by any code visible in this notebook
  zero_init = False  # True -> zero-initialise output projections
  precision = 16  # passed to TriangleAttention, which does not use it
  
def make_pair_mask(src, src_len):
    """Build a pairwise padding mask from per-sample sequence lengths.

    Args:
        src: tensor whose first two dims are (batch, length). Only its shape and
            device are used.
        src_len: (batch,) tensor of valid lengths, on the same device as ``src``.

    Returns:
        Bool tensor of shape (batch, length, length). It is True where position
        i or position j lies beyond the sample's length (the pair should be
        masked out) and False for valid pairs.
    """
    positions = torch.arange(src.shape[1], device=src.device)
    valid = positions.expand(src.shape[:2]) < src_len.unsqueeze(1)
    # Logical AND of the two 1-D validity masks. The result is bool on every
    # device. The old isinstance(BoolTensor / cuda.BoolTensor) assertion
    # rejected tensors on any other device.
    pair_valid = valid[:, None, :] & valid[:, :, None]
    return torch.bitwise_not(pair_valid)


# Smoke-test data: trim the saved batch to its longest real sequence and build a
# toy pair representation (outer sum of token embeddings) for BlockTrinagle.
# NOTE(review): torch.load unpickles the file; only load trusted data.
batch = torch.load("batch.pt")
mask = batch["mask"]
L0 = mask.shape[1]
Lmax = mask.sum(-1).max()
mask = mask[:, :Lmax]
x = batch["seq"][:, :Lmax]
# The embedding width must equal config.model_dim for the block below, so it is
# derived from the config rather than hard-coded.
# NOTE(review): assumes seq token ids are in [0, 4); confirm.
embed = nn.Embedding(4, config.model_dim).eval()
with torch.no_grad():
    embeddings = embed(x)

model = BlockTrinagle(config).eval()
pair_latent = embeddings.unsqueeze(1) + embeddings.unsqueeze(2)
# pair_mask is True on padded pairs; zero those entries (out of place, so the
# cell is idempotent).
pair_mask = make_pair_mask(mask, mask.sum(-1))
pair_latent = pair_latent.masked_fill(pair_mask[:, :, :, None], 0.0)
pair_latent.shape

torch.Size([16, 170, 170, 32])

In [9]:
#  with torch.no_grad(),torch.cuda.amp.autocast():
#     latent = model(pair_act=pair_latent, pair_mask=pair_mask, cycle_infer=False)

In [None]:
# Display the masked pair tensor built in the smoke-test cell above (scratch inspection).
pair_latent

torch.Size([16, 170, 170])

In [None]:
# NOTE(review): duplicate of the previous inspection cell; presumably scratch, so consider removing.
pair_latent

In [None]:
# NOTE(review): `f` is not defined in any visible cell. This looks like leftover
# scratch and likely raises NameError on a fresh kernel; confirm and delete.
f

In [None]:
# NOTE(review): same stray `f` as the previous cell; likely NameError on a fresh kernel.
f


In [1]:
import torch_geometric

In [2]:
# Record the installed torch_geometric version (captured output below).
torch_geometric.__version__

'2.4.0'

In [4]:
# Interactive IPython help for CuGraphSAGEConv (exploration only; remove from the final notebook).
torch_geometric.nn.conv.CuGraphSAGEConv?

[0;31mInit signature:[0m
[0mtorch_geometric[0m[0;34m.[0m[0mnn[0m[0;34m.[0m[0mconv[0m[0;34m.[0m[0mCuGraphSAGEConv[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0min_channels[0m[0;34m:[0m [0mint[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mout_channels[0m[0;34m:[0m [0mint[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0maggr[0m[0;34m:[0m [0mstr[0m [0;34m=[0m [0;34m'mean'[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mnormalize[0m[0;34m:[0m [0mbool[0m [0;34m=[0m [0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mroot_weight[0m[0;34m:[0m [0mbool[0m [0;34m=[0m [0;32mTrue[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mproject[0m[0;34m:[0m [0mbool[0m [0;34m=[0m [0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mbias[0m[0;34m:[0m [0mbool[0m [0;34m=[0m [0;32mTrue[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m     
The GraphSAGE operator from the `"Inductive Representation L