In [1]:
import torch.nn as nn


In [2]:
import numpy as np
import torch
import math
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
import warnings
from termcolor import colored
import torch.nn.functional as F
from einops import rearrange, repeat

# Optional dependency: flash-attn is only needed when FlashAttention2d is used.
# `load_flash_attn_check` records whether the import worked, so it is defined on
# both paths (previously it only existed after a failure).
try:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func, flash_attn_unpadded_kvpacked_func
    from flash_attn.bert_padding import unpad_input, pad_input
    load_flash_attn_check = True
except Exception:
    # Deliberately broad (the import is best-effort and may fail for reasons other
    # than a missing package), but unlike a bare `except:` this no longer swallows
    # KeyboardInterrupt / SystemExit.
    warnings.warn(colored('Could not import flash_attn', 'magenta'))
    load_flash_attn_check = False


class FlashAttention2d(nn.Module):
    """Multi-head self-attention along axis 2 of a 4-D pair tensor, computed with flash-attn.

    Every (batch, axis-1) slice is handled as an independent sequence: the input is
    flattened to ``(batch * seqlen, seqlen, ...)``, masked positions are stripped with
    ``unpad_input``, the packed-QKV kernel is run, and the result is padded back.

    NOTE(review): ``forward`` reads one ``seqlen`` from ``pair_act.shape[1]`` and uses
    it for both middle axes, so the pair tensor is assumed to be square.

    Args:
        model_dim: feature size of the input and output; must be divisible by ``num_head``.
        num_head: number of attention heads.
        softmax_scale: truthy -> pass ``key_dim ** -0.5`` to the kernel, falsy -> pass ``None``.
            NOTE(review): ``None`` is handed straight to flash-attn; confirm against the
            flash-attn API how ``None`` is interpreted (it may select a library default
            rather than disable scaling).
        zero_init: if true, the output projection weight starts at zero.
        use_bias: whether both linear layers carry a bias.
        initializer_range: std of the normal weight initialisation.
        n_layers: network depth, used to shrink the output-projection init std.
    """

    def __init__(self, model_dim, num_head, softmax_scale,
                 zero_init, use_bias, initializer_range, n_layers):
        super().__init__()
        assert model_dim % num_head == 0
        self.key_dim = model_dim // num_head
        self.value_dim = model_dim // num_head

        self.causal = False
        # Never switched on inside this class; only takes effect if a caller sets it.
        self.checkpointing = False

        if softmax_scale:
            self.softmax_scale = self.key_dim ** (-0.5)
        else:
            self.softmax_scale = None

        self.num_head = num_head

        # One fused projection producing query, key and value.
        self.Wqkv = nn.Linear(model_dim, 3 * model_dim, bias=use_bias)

        self.out_proj = nn.Linear(model_dim, model_dim, bias=use_bias)

        self.initialize(zero_init, use_bias, initializer_range, n_layers)

    def initialize(self, zero_init, use_bias, initializer_range, n_layers):
        """Initialise weights: normal for Wqkv, zero or depth-scaled normal for out_proj."""

        nn.init.normal_(self.Wqkv.weight, mean=0.0, std=initializer_range)

        if use_bias:
            nn.init.constant_(self.Wqkv.bias, 0.0)
            nn.init.constant_(self.out_proj.bias, 0.0)

        if zero_init:
            nn.init.constant_(self.out_proj.weight, 0.0)
        else:
            nn.init.normal_(self.out_proj.weight, mean=0.0, std=initializer_range / math.sqrt(2 * n_layers))

    def forward(self, pair_act, attention_mask):
        """Attend along axis 2 of ``pair_act``.

        Args:
            pair_act: 4-D activations ``(batch, seqlen, seqlen, model_dim)``.
            attention_mask: boolean tensor ``(batch, seqlen, seqlen)`` where True marks
                positions to exclude, or ``None`` to attend to every position.

        Returns:
            Tensor with the same leading shape as ``pair_act`` and ``model_dim`` features.
        """

        batch_size = pair_act.shape[0]
        seqlen = pair_act.shape[1]
        extended_batch_size = batch_size * seqlen

        qkv = self.Wqkv(pair_act)

        # unpad_input expects a mask of positions to KEEP, i.e. the inverse of
        # attention_mask. A missing mask means "keep everything"; this mirrors
        # Attention2d, which also accepts attention_mask=None.
        if attention_mask is None:
            not_attention_mask = torch.ones(qkv.shape[:3], dtype=torch.bool, device=qkv.device)
        else:
            not_attention_mask = torch.logical_not(attention_mask)

        x_qkv = rearrange(qkv, 'b s f ... -> (b s) f ...', b=batch_size, f=seqlen, s=seqlen)
        key_padding_mask = rearrange(not_attention_mask, 'b s f ... -> (b s) f ...', b=batch_size, f=seqlen, s=seqlen)

        x_unpad, indices, cu_seqlens, max_s = unpad_input(x_qkv, key_padding_mask)
        x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=self.num_head)

        if self.training and self.checkpointing:
            # Positional arguments after max_s: dropout 0.0, softmax scale, causal flag, and a
            # trailing False (presumably return_attn_probs -- confirm against flash-attn).
            # use_reentrant=True matches the checkpoint call in TriangleAttention.
            output_unpad = checkpoint(flash_attn_unpadded_qkvpacked_func,
                                      x_unpad, cu_seqlens, max_s, 0.0, self.softmax_scale,
                                      self.causal, False,
                                      use_reentrant=True)
        else:
            output_unpad = flash_attn_unpadded_qkvpacked_func(
                x_unpad, cu_seqlens, max_s, 0.0,
                softmax_scale=self.softmax_scale, causal=self.causal
            )

        # Merge heads, restore the stripped positions, then undo the (b s) flattening.
        pre_pad_latent = rearrange(output_unpad, 'nnz h d -> nnz (h d)')
        padded_latent = pad_input(pre_pad_latent, indices, extended_batch_size, seqlen)
        output = rearrange(padded_latent, 'b f (h d) -> b f h d', h=self.num_head)

        output = rearrange(output, '(b s) f h d -> b s f (h d)', b=batch_size, f=seqlen, s=seqlen)

        return self.out_proj(output)


class Attention2d(nn.Module):
    """Plain PyTorch multi-head self-attention along axis 2 of a 4-D pair tensor.

    Args:
        model_dim: feature size of the input and output; must be divisible by ``num_head``.
        num_head: number of attention heads.
        softmax_scale: truthy -> divide the attention logits by ``sqrt(key_dim)``.
        precision: ``"fp32"``, ``32`` or ``"bf16"`` select a mask fill value of -1e9;
            ``"fp16"`` or ``16`` select -1e4.
        zero_init: if true, the output projection weight starts at zero.
        use_bias: whether both linear layers carry a bias.
        initializer_range: std of the normal weight initialisation.
        n_layers: network depth, used to shrink the output-projection init std.

    Raises:
        UserWarning: if ``precision`` is not one of the recognised values.
    """

    def __init__(self, model_dim, num_head, softmax_scale,
                 precision, zero_init, use_bias,
                 initializer_range, n_layers):
        super().__init__()
        assert model_dim % num_head == 0
        self.key_dim = model_dim // num_head
        self.value_dim = model_dim // num_head

        if softmax_scale:
            # A Python float rather than a shape-(1,) float32 tensor: dividing by a
            # dimensioned float32 tensor promotes half/bfloat16 logits to float32
            # (and required a device copy on every forward pass).
            self.softmax_scale = math.sqrt(self.key_dim)
        else:
            self.softmax_scale = False

        self.num_head = num_head
        self.model_dim = model_dim

        # The fill value for masked logits is chosen from the declared precision.
        if precision == "fp32" or precision == 32 or precision == "bf16":
            self.mask_bias = -1e9
        elif precision == "fp16" or precision == 16:
            self.mask_bias = -1e4
        else:
            raise UserWarning(f"unknown precision: {precision} . Please us fp16, fp32 or bf16")

        # One fused projection producing query, key and value.
        self.Wqkv = nn.Linear(model_dim, 3 * model_dim, bias=use_bias)
        self.out_proj = nn.Linear(model_dim, model_dim, bias=use_bias)

        self.initialize(zero_init, use_bias, initializer_range, n_layers)

    def initialize(self, zero_init, use_bias, initializer_range, n_layers):
        """Initialise weights: normal for Wqkv, zero or depth-scaled normal for out_proj."""

        nn.init.normal_(self.Wqkv.weight, mean=0.0, std=initializer_range)

        if use_bias:
            nn.init.constant_(self.Wqkv.bias, 0.0)
            nn.init.constant_(self.out_proj.bias, 0.0)

        if zero_init:
            nn.init.constant_(self.out_proj.weight, 0.0)
        else:
            nn.init.normal_(self.out_proj.weight, mean=0.0, std=initializer_range / math.sqrt(2 * n_layers))

    def forward(self, pair_act, attention_mask):
        """Attend along axis 2 of ``pair_act``.

        Args:
            pair_act: activations of shape ``(batch, N_seq, N_res, model_dim)``.
            attention_mask: boolean tensor ``(batch, N_seq, N_res)`` where True marks
                key positions to suppress, or ``None`` for no masking.

        Returns:
            Tensor of shape ``(batch, N_seq, N_res, model_dim)``.
        """

        batch_size = pair_act.size(0)
        N_seq = pair_act.size(1)
        N_res = pair_act.size(2)

        query, key, value = self.Wqkv(pair_act).split(self.model_dim, dim=3)

        # (batch, N_seq, head, N_res, dim) for query/value; key is laid out transposed
        # so that query @ key gives (batch, N_seq, head, N_res, N_res) logits.
        query = query.view(batch_size, N_seq, N_res, self.num_head, self.key_dim).permute(0, 1, 3, 2, 4)
        key = key.view(batch_size, N_seq, N_res, self.num_head, self.value_dim).permute(0, 1, 3, 4, 2)
        value = value.view(batch_size, N_seq, N_res, self.num_head, self.value_dim).permute(0, 1, 3, 2, 4)

        attn_weights = torch.matmul(query, key)

        if self.softmax_scale:
            attn_weights = attn_weights / self.softmax_scale

        if attention_mask is not None:
            # Broadcast over heads and query positions; the last axis indexes keys.
            attention_mask = attention_mask[:, :, None, None, :]
            attn_weights.masked_fill_(attention_mask, self.mask_bias)
        attn_weights = F.softmax(attn_weights, dim=-1)

        weighted_avg = torch.matmul(attn_weights, value).permute(0, 1, 3, 2, 4)

        output = self.out_proj(weighted_avg.reshape(batch_size, N_seq, N_res, self.num_head * self.value_dim))
        return output


class TriangleAttention(nn.Module):
    """LayerNorm followed by 2-D attention, applied per row or per column of a pair tensor.

    ``orientation='per_column'`` is implemented by swapping the two middle axes of the
    activations (and the two axes of the mask) before the attention module and swapping
    the activations back afterwards.

    Args:
        model_dim: feature size.
        num_head: number of attention heads.
        orientation: ``'per_row'`` or ``'per_column'``.
        softmax_scale, precision, zero_init, use_bias, initializer_range, n_layers:
            forwarded to the attention module.
        flash_attn: truthy selects ``FlashAttention2d``, otherwise ``Attention2d``.
    """

    def __init__(self, model_dim, num_head, orientation, softmax_scale,
                 precision, zero_init, use_bias, flash_attn,
                 initializer_range, n_layers):
        super().__init__()

        self.model_dim = model_dim
        self.num_head = num_head

        assert orientation in ['per_row', 'per_column']
        self.orientation = orientation

        self.input_norm = nn.LayerNorm(model_dim, eps=1e-6)

        # Only Attention2d takes the `precision` argument.
        self.attn = (
            FlashAttention2d(model_dim, num_head, softmax_scale, zero_init, use_bias,
                             initializer_range, n_layers)
            if flash_attn else
            Attention2d(model_dim, num_head, softmax_scale,
                        precision, zero_init, use_bias, initializer_range, n_layers)
        )

    def forward(self, pair_act, pair_mask, cycle_infer=False):
        """Run normalised attention over ``pair_act``.

        Args:
            pair_act: 4-D activation tensor.
            pair_mask: mask passed on to the attention module, or ``None``.
            cycle_infer: when true, activation checkpointing is skipped even in training mode.

        Returns:
            The attention output, in the same axis layout as the input.
        """

        assert pair_act.dim() == 4

        column_wise = self.orientation == 'per_column'

        if column_wise:
            pair_act = torch.swapaxes(pair_act, -2, -3)
            if pair_mask is not None:
                pair_mask = torch.swapaxes(pair_mask, -1, -2)

        normed = self.input_norm(pair_act)

        # Checkpoint the attention module during ordinary training passes only.
        if self.training and not cycle_infer:
            attended = checkpoint(self.attn, normed, pair_mask, use_reentrant=True)
        else:
            attended = self.attn(normed, pair_mask)

        return torch.swapaxes(attended, -2, -3) if column_wise else attended
    
class RNAformerBlock(nn.Module):
    """One residual block: row attention, column attention, then a pair transition.

    Each sub-layer is added to its input through dropout. The two attention branches
    use half of ``config.resi_dropout``; the transition branch uses the full rate.

    NOTE(review): ``FeedForward`` (used when ``config.ff_kernel`` is falsy) is not
    defined in the part of the notebook visible here -- confirm it exists before
    relying on that branch.
    """

    def __init__(self, config):
        super().__init__()

        ff_dim = int(config.ff_factor * config.model_dim)

        def build_attention(orientation):
            # Both attention sub-layers differ only in their orientation.
            return TriangleAttention(config.model_dim, config.num_head, orientation, config.softmax_scale,
                                     config.precision, config.zero_init, config.use_bias, config.flash_attn,
                                     config.initializer_range, config.n_layers)

        self.attn_pair_row = build_attention('per_row')
        self.attn_pair_col = build_attention('per_column')

        attention_dropout = config.resi_dropout / 2
        self.pair_dropout_row = nn.Dropout(p=attention_dropout)
        self.pair_dropout_col = nn.Dropout(p=attention_dropout)

        # Keyword arguments shared by both transition implementations.
        transition_kwargs = dict(use_bias=config.use_bias,
                                 initializer_range=config.initializer_range,
                                 zero_init=config.zero_init,
                                 n_layers=config.n_layers)
        if config.ff_kernel:
            self.pair_transition = ConvFeedForward(config.model_dim, ff_dim,
                                                   kernel=config.ff_kernel, **transition_kwargs)
        else:
            self.pair_transition = FeedForward(config.model_dim, ff_dim,
                                               glu=config.use_glu, **transition_kwargs)

        self.res_dropout = nn.Dropout(p=config.resi_dropout)

    def forward(self, pair_act, pair_mask, cycle_infer=False):
        """Apply the three residual sub-layers in order and return the updated activations."""

        row_update = self.attn_pair_row(pair_act, pair_mask, cycle_infer)
        pair_act = pair_act + self.pair_dropout_row(row_update)

        col_update = self.attn_pair_col(pair_act, pair_mask, cycle_infer)
        pair_act = pair_act + self.pair_dropout_col(col_update)

        transition_update = self.pair_transition(pair_act)
        return pair_act + self.res_dropout(transition_update)
    
class ConvFeedForward(nn.Module):
    """Two-layer convolutional feed-forward block operating on channels-last input.

    The input is moved to channels-first, normalised with a single-group GroupNorm,
    passed through conv -> SiLU -> conv, and moved back to channels-last.

    Args:
        model_dim: number of input and output channels.
        ff_dim: number of hidden channels.
        use_bias: whether both convolutions carry a bias.
        initializer_range: std of the normal weight initialisation.
        n_layers: network depth, used to shrink the second convolution's init std.
        kernel: convolution kernel size; padding is ``(kernel - 1) // 2``.
        zero_init: if true, the second convolution's weight starts at zero.
    """

    def __init__(self, model_dim, ff_dim, use_bias, initializer_range, n_layers, kernel, zero_init=True):
        super().__init__()

        self.zero_init = zero_init

        self.input_norm = nn.GroupNorm(1, model_dim)

        # A 1x1 kernel needs no padding; larger kernels are padded by (kernel - 1) // 2.
        if kernel == 1:
            conv_options = dict(kernel_size=1, bias=use_bias)
        else:
            conv_options = dict(bias=use_bias, kernel_size=kernel, padding=(kernel - 1) // 2)

        self.conv1 = nn.Conv2d(model_dim, ff_dim, **conv_options)
        self.conv2 = nn.Conv2d(ff_dim, model_dim, **conv_options)

        self.act = nn.SiLU()

        self.initialize(zero_init, use_bias, initializer_range, n_layers)

    def initialize(self, zero_init, use_bias, initializer_range, n_layers):
        """Initialise weights: normal for conv1, zero or depth-scaled normal for conv2."""

        nn.init.normal_(self.conv1.weight, mean=0.0, std=initializer_range)

        if use_bias:
            for conv in (self.conv1, self.conv2):
                nn.init.constant_(conv.bias, 0.0)

        if not zero_init:
            scaled_std = initializer_range / math.sqrt(2 * n_layers)
            nn.init.normal_(self.conv2.weight, mean=0.0, std=scaled_std)
        else:
            nn.init.constant_(self.conv2.weight, 0.0)

    def forward(self, x):
        """Map a channels-last 4-D tensor to a channels-last tensor with ``model_dim`` channels."""

        channels_first = x.permute(0, 3, 1, 2)

        hidden = self.act(self.conv1(self.input_norm(channels_first)))
        projected = self.conv2(hidden)

        return projected.permute(0, 2, 3, 1)
    
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class PosEmbedding(nn.Module):
    """Token embedding plus a learned projection of a positional signal.

    With ``rel_pos_enc`` true the positional signal is a one-hot encoding of the
    position index (clipped to ``max_len - 1``); otherwise it is a fixed sinusoidal
    table stored in the ``pe`` buffer. Either signal goes through the bias-free linear
    layer ``embed_pair_pos`` and is added to the scaled token embedding.

    Args:
        vocab: vocabulary size of the token embedding.
        model_dim: embedding size.
        max_len: number of distinct positions (one-hot width or sinusoidal table length).
        rel_pos_enc: selects the one-hot (true) or sinusoidal (false) positional signal.
        initializer_range: std of the normal initialisation of both weight matrices.
    """

    def __init__(self, vocab, model_dim, max_len, rel_pos_enc, initializer_range):

        super().__init__()

        self.rel_pos_enc = rel_pos_enc
        self.max_len = max_len

        self.embed_seq = nn.Embedding(vocab, model_dim)

        # Fixed (non-trainable) multiplier applied to the token embedding.
        self.scale = nn.Parameter(torch.sqrt(torch.FloatTensor([model_dim // 2])), requires_grad=False)

        if rel_pos_enc:
            self.embed_pair_pos = nn.Linear(max_len, model_dim, bias=False)
        else:
            self.embed_pair_pos = nn.Linear(model_dim, model_dim, bias=False)

            pe = torch.zeros(max_len, model_dim)
            position = torch.arange(0, max_len).unsqueeze(1).type(torch.FloatTensor)
            div_term = torch.exp(
                torch.arange(0, model_dim, 2).type(torch.FloatTensor) * -(math.log(10000.0) / model_dim))
            pe[:, 0::2] = torch.sin(position * div_term)
            pe[:, 1::2] = torch.cos(position * div_term)
            pe = pe.unsqueeze(0)
            # Registered as a plain tensor: a buffer is already non-trainable and moves
            # with the module, so the former nn.Parameter wrapper served no purpose.
            self.register_buffer('pe', pe)

        self.initialize(initializer_range)

    def initialize(self, initializer_range):
        """Draw both weight matrices from a zero-mean normal with std ``initializer_range``."""

        nn.init.normal_(self.embed_seq.weight, mean=0.0, std=initializer_range)
        nn.init.normal_(self.embed_pair_pos.weight, mean=0.0, std=initializer_range)

    def relative_position_encoding(self, src_seq):
        """Project a one-hot encoding of each position index.

        The index is the position ``0 .. L-1`` along axis 1 of ``src_seq``; indices
        at or beyond ``max_len`` are clipped to ``max_len - 1``.
        """

        residue_index = torch.arange(src_seq.size()[1], device=src_seq.device).expand(src_seq.size())
        rel_pos = F.one_hot(torch.clip(residue_index, min=0, max=self.max_len - 1), self.max_len)

        # Match the projection weight's dtype directly. The previous isinstance checks
        # against legacy torch.cuda.*Tensor classes only recognised CUDA half/bfloat16
        # weights and fell back to float32 for everything else (e.g. CPU float64).
        rel_pos = rel_pos.to(self.embed_pair_pos.weight.dtype)

        return self.embed_pair_pos(rel_pos)

    def forward(self, src_seq):
        """Embed integer tokens of shape ``(batch, length)`` to ``(batch, length, model_dim)``."""

        seq_embed = self.embed_seq(src_seq) * self.scale

        if self.rel_pos_enc:
            seq_embed = seq_embed + self.relative_position_encoding(src_seq)
        else:
            seq_embed = seq_embed + self.embed_pair_pos(self.pe[:, :src_seq.size(1)])

        return seq_embed


class EmbedSequence2Matrix(nn.Module):
    """Turn a token sequence into a pairwise representation.

    Two independent ``PosEmbedding`` modules embed the same sequence; one result is
    broadcast along axis 1 and the other along axis 2, the two are summed, and the sum
    is layer-normalised.
    """

    def __init__(self, config):
        super().__init__()

        # Both embeddings are built from the same hyper-parameters.
        embedding_args = (config.seq_vocab_size, config.model_dim, config.max_len,
                          config.rel_pos_enc, config.initializer_range)

        self.src_embed_1 = PosEmbedding(*embedding_args)
        self.src_embed_2 = PosEmbedding(*embedding_args)

        self.norm = nn.LayerNorm(config.model_dim)

    def forward(self, src_seq):
        """Map ``(batch, length)`` tokens to ``(batch, length, length, model_dim)``."""
        first = self.src_embed_1(src_seq)
        second = self.src_embed_2(src_seq)

        # (batch, 1, L, D) + (batch, L, 1, D) broadcasts to (batch, L, L, D).
        summed = first.unsqueeze(1) + second.unsqueeze(2)

        return self.norm(summed)
    
class RNAformerStack(nn.Module):
    """``config.n_layers`` RNAformerBlocks applied in sequence, followed by a LayerNorm."""

    def __init__(self, config):
        super().__init__()

        self.output_ln = nn.LayerNorm(config.model_dim)

        blocks = [RNAformerBlock(config=config) for _ in range(config.n_layers)]
        self.layers = nn.ModuleList(blocks)

    def forward(self, pair_act, pair_mask, cycle_infer=False):
        """Pass ``pair_act`` through every block, then normalise the result.

        Args:
            pair_act: pair activations handed to the first block.
            pair_mask: mask forwarded unchanged to every block.
            cycle_infer: forwarded unchanged to every block.
        """

        for block in self.layers:
            pair_act = block(pair_act, pair_mask, cycle_infer=cycle_infer)

        return self.output_ln(pair_act)



In [3]:
class CONFIG:
    """Hyper-parameters read by the model classes defined in this notebook."""

    # --- architecture ---
    model_dim = 256  # hidden dimension of transformer
    n_layers = 6  # number of transformer layers
    num_head = 4  # number of heads per layer
    ff_factor = 4  # hidden dim * ff_factor = size of feed-forward layer
    ff_kernel = 3  # truthy -> RNAformerBlock builds a ConvFeedForward with this kernel size
    use_glu = False  # only forwarded to FeedForward (the ff_kernel-falsy branch)
    use_bias = True  # bias on the linear / conv layers
    flash_attn = False  # True -> FlashAttention2d, False -> Attention2d
    softmax_scale = True  # scale attention logits by the head dimension

    # --- embedding ---
    seq_vocab_size = 5  # size of the token embedding table
    max_len = 88  # positional table width; PosEmbedding clips larger position indices
    rel_pos_enc = True  # relative position encoding

    # --- regularisation / initialisation ---
    resi_dropout = 0.1  # residual dropout; the attention branches use half of it
    initializer_range = 0.02  # std of the normal weight initialisation
    zero_init = False  # True -> output projections start at zero
    precision = 16  # selects Attention2d's mask fill value (16 -> -1e4)

    # --- not read by any code visible in this notebook ---
    cycling = False
    embed_dropout = 0.1
    head_bias = False
    ln_eps = 1e-5
    key_dim_scaler = True
    gating = False

In [5]:
import pandas as pd 
# Load the pickled test records and turn them into a DataFrame.
# NOTE(review): allow_pickle=True unpickles arbitrary Python objects while loading,
# which can execute code -- only use this with a data file from a trusted source.
file = "../data/all_test_data.npy"
df = pd.DataFrame(np.load(file, allow_pickle=True).tolist())

In [6]:
from typing import List
import torch
import numpy as np


class IndexDataset(torch.utils.data.Dataset):
    """Map-style dataset that hands back entries of the wrapped index collection unchanged."""

    def __init__(self, dataset_indices):
        # Any object supporting len() and item lookup works here.
        self.dataset_indices = dataset_indices

    def __len__(self):
        """Number of stored entries."""
        return len(self.dataset_indices)

    def __getitem__(self, index):
        """Return the entry stored at ``index``."""
        return self.dataset_indices[index]


class TokenBasedRandomSampler:
    """Batch sampler that groups sample indices under a padded-token budget.

    A batch's cost is ``len(batch) * longest_sample_in_batch``; indices are appended
    until that cost reaches ``batch_token_size``.

    Args:
        dataset: iterable of samples; only used to compute per-sample token lengths.
        token_key_fn: maps a sample to its token length.
        batch_token_size: token budget of one batch.
        batching: if false, every batch holds exactly one index.
        repeat: if true, iteration cycles over the data forever and batches are built
            lazily; if false, batches are computed once in the constructor and the
            same list is replayed on every iteration.
        sort_samples: together with ``batching``, ``shuffle`` and a non-zero
            ``shuffle_pool_size``, sorts pools of samples by length before batching.
        shuffle: shuffle the index order (and the precomputed batches).
        shuffle_pool_size: number of indices collected before a pool is sorted.
        drop_last: if true, a trailing batch that never reached the budget is dropped.
        seed: seed of the internal NumPy random generator.
    """

    def __init__(self, dataset, token_key_fn, batch_token_size, batching, repeat, sort_samples, shuffle,
                 shuffle_pool_size, drop_last, seed=1):

        super().__init__()
        self.token_length = [token_key_fn(s) for s in dataset]

        self.batch_token_size = batch_token_size

        self.repeat = repeat
        self.batching = batching
        self.sort_samples = sort_samples
        self.drop_last = drop_last

        self.shuffle = shuffle
        self.shuffle_pool_size = shuffle_pool_size

        self.rng = np.random.default_rng(seed=seed)

        # Sort direction alternates from one pool to the next.
        self.reverse = False

        if not self.repeat:
            self.minibatches = self.precompute_minibatches()

    def __len__(self):
        # NOTE: with batching=True and repeat=True no batches are precomputed, so
        # `self.minibatches` does not exist and this raises AttributeError.
        if not self.batching:
            return len(self.token_length)
        else:
            return len(self.minibatches)

    def get_index_list(self):
        """Return all sample indices, shuffled when ``shuffle`` is set."""
        index_list = list(range(len(self.token_length)))
        if self.shuffle:
            self.rng.shuffle(index_list)
        return index_list

    def get_index_iter(self):
        """Yield sample indices forever, drawing a fresh index list for every pass."""
        while True:
            yield from self.get_index_list()

    def _drain_sorted(self, pool):
        """Sort ``pool`` by token length, flip the sort direction, and empty it from the back."""
        pool.sort(key=lambda x: self.token_length[x], reverse=self.reverse)
        self.reverse = not self.reverse
        while pool:
            yield pool.pop()

    def pool_and_sort(self, sample_iter):
        """Re-order ``sample_iter`` so that pools of ``shuffle_pool_size`` indices come out length-sorted."""
        pool = []
        for sample in sample_iter:
            if not self.sort_samples:
                yield sample
                continue
            pool.append(sample)
            if len(pool) >= self.shuffle_pool_size:
                yield from self._drain_sorted(pool)
        if pool:
            yield from self._drain_sorted(pool)

    def get_minibatches(self, index_iter):
        """Group indices from ``index_iter`` into batches; never yields an empty batch."""

        minibatch, max_size_in_batch = [], 0

        if self.batching and self.shuffle and self.shuffle_pool_size and self.sort_samples:
            index_iter = self.pool_and_sort(index_iter)

        for sample in index_iter:

            if not self.batching:
                yield [sample]
                continue

            minibatch.append(sample)
            max_size_in_batch = max(max_size_in_batch, self.token_length[sample])
            size_so_far = len(minibatch) * max_size_in_batch

            if size_so_far == self.batch_token_size:
                yield minibatch
                minibatch, max_size_in_batch = [], 0
            elif size_so_far > self.batch_token_size:
                # The newest sample overflowed the budget: emit what came before it and
                # start the next batch with it. When it is the only entry (a single
                # sample larger than the budget) there is nothing to emit yet --
                # yielding the empty prefix here used to hand empty batches to
                # consumers in repeat mode.
                if len(minibatch) > 1:
                    yield minibatch[:-1]
                minibatch = minibatch[-1:]
                max_size_in_batch = self.token_length[minibatch[0]]

        if (not self.drop_last) and len(minibatch) > 0:
            yield minibatch

    def precompute_minibatches(self):
        """Build the full batch list once (used when ``repeat`` is false)."""
        index_iter = self.get_index_list()

        minibatches = [m for m in self.get_minibatches(index_iter) if len(m) > 0]
        if self.shuffle:
            self.rng.shuffle(minibatches)
        return minibatches

    def __iter__(self):
        if self.repeat:
            yield from self.get_minibatches(self.get_index_iter())
        else:
            yield from self.minibatches


class CollatorRNA:
    """Collate per-sample dicts of tensors into one padded batch dict.

    Args:
        pad_index: fill value for padded source positions.
        ignore_index: fill value for padded target positions.
    """

    def __init__(self, pad_index, ignore_index):
        self.ignore_index = ignore_index
        self.pad_index = pad_index

    def __call__(self, samples, neg_samples=False) -> dict:
        """Pad and stack ``samples``.

        Args:
            samples: non-empty list of dicts. Each needs ``'src_seq'`` (1-D tensor) and
                ``'length'`` (0-dim tensor). If the first sample has ``'pos1id'``, every
                sample must also provide ``'pos2id'``, ``'trg_seq'`` and ``'trg_struct'``.
            neg_samples: accepted but not used.

        Returns:
            Dict with ``'length'`` and ``'src_seq'``; ``'pdb_sample'`` when present in
            the samples; and, when ``'pos1id'`` is present, ``'pos1id'``, ``'pos2id'``,
            ``'trg_seq'``, ``'trg_struct'``, ``'src_mat'`` and ``'trg_mat'``.
        """

        with torch.no_grad():
            batch_dict = {k: [dic[k] for dic in samples] for k in samples[0] if k in ['length', 'pdb_sample', 'pos1id']}

            batch_dict['length'] = torch.stack(batch_dict['length'])

            if 'pdb_sample' in batch_dict:
                batch_dict['pdb_sample'] = torch.stack(batch_dict['pdb_sample'])

            # Plain int so it can be used directly as a tensor size.
            max_len = int(batch_dict['length'].max())
            batch_size = len(samples)

            has_pairs = 'pos1id' in batch_dict
            # NOTE(review): the key filter above never admits 'src_struct', so this is
            # always False and the structure branch below never runs. The condition is
            # kept as written; confirm whether 'src_struct' should be collected.
            has_struct = 'src_struct' in batch_dict

            src_seq = torch.full((batch_size, max_len), self.pad_index)

            if has_struct:
                src_struct = torch.full((batch_size, max_len), self.pad_index)

            if has_pairs:
                max_pos = max(pos.shape[0] for pos in batch_dict['pos1id'])
                pos1id = torch.full((batch_size, max_pos), self.pad_index)
                pos2id = torch.full((batch_size, max_pos), self.pad_index)

                trg_seq = torch.full((batch_size, max_len), self.ignore_index)
                trg_struct = torch.full((batch_size, max_len), self.ignore_index)

                # The O(batch * L^2) matrices are only returned on this path, so they
                # are only allocated here.
                src_mat = torch.full((batch_size, max_len, max_len), self.pad_index, dtype=torch.long)
                trg_mat = torch.full((batch_size, max_len, max_len), self.ignore_index, dtype=torch.long)

            for b_id, sample in enumerate(samples):
                src_seq[b_id, :sample['src_seq'].size(0)] = sample['src_seq']

                if has_struct:
                    src_struct[b_id, :sample['src_struct'].size(0)] = sample['src_struct']

                if has_pairs:
                    pos1id[b_id, :sample['pos1id'].size(0)] = sample['pos1id']
                    pos2id[b_id, :sample['pos2id'].size(0)] = sample['pos2id']
                    trg_seq[b_id, :sample['trg_seq'].size(0)] = sample['trg_seq']
                    trg_struct[b_id, :sample['trg_struct'].size(0)] = sample['trg_struct']

                    # Inside the true length everything starts as "no pair" (0) ...
                    length = int(batch_dict['length'][b_id])
                    src_mat[b_id, :length, :length] = 0
                    trg_mat[b_id, :length, :length] = 0

                    # ... and each listed pair is marked symmetrically.
                    src_mat[b_id, sample['pos1id'], sample['pos2id']] = 1
                    src_mat[b_id, sample['pos2id'], sample['pos1id']] = 1
                    trg_mat[b_id, sample['pos1id'], sample['pos2id']] = 1
                    trg_mat[b_id, sample['pos2id'], sample['pos1id']] = 1

            batch_dict['src_seq'] = src_seq

            if has_struct:
                batch_dict['src_struct'] = src_struct

            if has_pairs:
                batch_dict['pos1id'] = pos1id
                batch_dict['pos2id'] = pos2id
                batch_dict['trg_seq'] = trg_seq
                batch_dict['trg_struct'] = trg_struct
                batch_dict['src_mat'] = src_mat
                batch_dict['trg_mat'] = trg_mat

        return batch_dict
    
def make_pair_mask(src, src_len):
    """Build the padding mask for a pairwise (L x L) representation.

    Args:
        src: tensor whose first two axes are ``(batch, L)``; only its shape and device are used.
        src_len: 1-D tensor of true sequence lengths, one per batch entry.

    Returns:
        Bool tensor ``(batch, L, L)`` that is True wherever at least one of the two
        positions lies at or beyond the sequence length (i.e. True marks padding).
    """
    positions = torch.arange(src.shape[1], device=src.device).expand(src.shape[:2])
    encode_mask = positions < src_len.unsqueeze(1)
    # A pair is valid only if both of its positions are valid.
    pair_mask = encode_mask[:, None, :] & encode_mask[:, :, None]
    # dtype check instead of isinstance against legacy torch.BoolTensor /
    # torch.cuda.BoolTensor, which rejected tensors on any other device.
    assert pair_mask.dtype == torch.bool
    return torch.bitwise_not(pair_mask)

In [18]:
# Work on a small subset: the `[:100]` slice of the 'sequence' column of the loaded frame.
sequence = df['sequence'][:100]

In [19]:
IGNORE_INDEX = -100
# NOTE: PAD_INDEX shares the value 0 with the token id of 'A' below, so padded positions
# of src_seq cannot be told apart from 'A' without the sequence lengths.
PAD_INDEX = 0

# The vocabulary does not depend on the sequence, so it is built once instead of on
# every loop iteration.
seq_vocab = ['A', 'C', 'G', 'U', 'N']
seq_stoi = dict(zip(seq_vocab, range(len(seq_vocab))))

input_samples = []
for seq in sequence:
    # Direct indexing makes an out-of-vocabulary character fail with a KeyError naming
    # it; `.get` produced None, which only failed later inside torch.LongTensor.
    int_sequence = [seq_stoi[nucleotide] for nucleotide in seq]
    src_seq = torch.LongTensor(int_sequence)
    # 'length' is stored as a 0-dim long tensor because CollatorRNA stacks the lengths.
    input_sample = {'src_seq': src_seq, 'length': torch.tensor(len(src_seq), dtype=torch.long)}
    input_samples.append(input_sample)

collator = CollatorRNA(PAD_INDEX, IGNORE_INDEX)
batch = collator(input_samples)

In [24]:
batch

{'length': tensor([ 88,  71,  73,  93,  66,  62,  86,  85, 151,  76,  94,  69,  75, 160,
          72,  69,  82,  53,  62,  71,  77,  47,  65,  76,  72,  52,  63, 137,
          62,  90,  80,  60,  71,  90,  72,  90, 112,  71,  66,  71,  70, 166,
         159,  77,  72,  67,  50,  61,  55,  52,  52, 114,  71,  68,  52,  56,
          66,  84,  69, 141,  71,  79,  77,  64, 165,  56,  84,  60,  71,  93,
          92,  67,  68,  93,  85,  75,  60,  65,  56,  72,  74,  55,  90,  92,
          97,  74,  55,  48,  78,  68,  75,  68,  51,  56,  69,  52,  73,  92,
          76,  57]),
 'src_seq': tensor([[3, 2, 0,  ..., 0, 0, 0],
         [3, 0, 2,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [1, 2, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [2, 1, 2,  ..., 0, 0, 0]])}

In [27]:
# Padding mask for the pair representation: make_pair_mask returns True for every
# (i, j) pair in which at least one position lies beyond the sequence length.
pair_mask = make_pair_mask(batch['src_seq'], batch['length'])
pair_mask.shape

torch.Size([100, 166, 166])

In [28]:
# Embed the padded token batch into a pair representation using a freshly constructed
# (randomly initialised) EmbedSequence2Matrix.
pair_latent = EmbedSequence2Matrix(CONFIG)(batch['src_seq'])
# Zero every padded pair in place; the trailing None broadcasts the mask over the feature axis.
pair_latent.masked_fill_(pair_mask[:, :, :, None], 0.0)
pair_latent.shape

torch.Size([100, 166, 166, 256])

In [17]:
# Shape smoke test: push the pair representation through a freshly constructed
# (randomly initialised) RNAformerStack.
# NOTE(review): the output recorded below is (1, 88, 88, 256), which does not match the
# (100, 166, 166, 256) shape recorded for pair_latent above, and this cell's execution
# count (17) is lower than that of the cells that build its inputs -- the recorded
# output looks stale; re-run the notebook top to bottom to confirm.
RNAformerStack(CONFIG)(pair_latent, pair_mask).shape

torch.Size([1, 88, 88, 256])

In [40]:
import torch
import torch.nn as nn

# Toy input: two length-4 signals with a singleton channel axis -> shape (2, 1, 4)
data = torch.tensor(
    [[1, 0, 2, 4],
     [2, 8, 5, 6]],
    dtype=torch.float32,
)[:, None, :]

# Boolean masks, one row per signal; True marks the entries to blank out
masks = torch.tensor(
    [[True, True, False, False],
     [True, False, False, False]],
    dtype=torch.bool,
)

# Zero the masked entries (mask broadcast over the channel axis); shape stays (2, 1, 4)
masked_data = data.masked_fill(masks[:, None, :], 0.0)

# 1 input channel -> 2 output channels; kernel 3 with padding 1 keeps the length at 4
conv1d_layer = nn.Conv1d(in_channels=1, out_channels=2, kernel_size=3, padding=1)

# Convolve the masked signals
output = conv1d_layer(masked_data)
print(output)


tensor([[[ 0.4216,  1.5679,  3.5777,  1.1813],
         [-0.3246, -0.4616, -1.2953, -0.7179]],

        [[ 5.0070,  6.7411,  2.1508,  0.5941],
         [-0.8725, -3.4541,  1.5236,  0.0857]]],
       grad_fn=<ConvolutionBackward0>)


In [41]:
masked_data

tensor([[[0., 0., 2., 4.]],

        [[0., 8., 5., 6.]]])

In [42]:
masked_data.shape

torch.Size([2, 1, 4])

In [43]:
data.shape

torch.Size([2, 1, 4])

In [44]:
output.shape

torch.Size([2, 2, 4])