In [1]:
import torch
from torch import nn
import torch.nn.functional as F

In [59]:
# --- Data dimensions ---------------------------------------------------
n_genes = 500      # number of distinct gene tokens (embedding table adds one extra row)
n_perts = 2        # perturbation tokens per sample
time_dim = 10      # response tokens per sample
batch_size = 1     # samples per batch used by the smoke test

# --- Model dimensions --------------------------------------------------
n_embed = 32       # width of every token embedding
n_heads = 4        # attention heads per block
n_blocks = 6       # number of stacked transformer blocks
dropout = 0.2      # dropout probability shared by all dropout layers
# block_sizes = [128, 64, 32, 16]


class Head(nn.Module):
    """Fused multi-head self-attention without mask or output projection.

    All heads share a single bias-free linear layer that produces the
    queries, keys and values in one matrix multiply. Attention is
    bidirectional (no causal mask is applied).

    Args:
        n_embed: Width of the last dimension of the input.
        head_size: Width of each individual attention head.
        num_heads: Number of heads. Defaults to the module-level ``n_heads``.
        dropout_p: Dropout probability applied to the attention weights.
            Defaults to the module-level ``dropout``.
    """

    def __init__(self, n_embed, head_size, num_heads=None, dropout_p=None):
        super().__init__()
        # Fall back to the notebook-level hyperparameters so existing
        # two-argument callers keep their behaviour.
        self.n_heads = n_heads if num_heads is None else num_heads
        drop_p = dropout if dropout_p is None else dropout_p

        self.head_size = head_size
        # One projection emits Q, K and V for every head at once.
        self.batch_qkv_matrices = nn.Linear(
            n_embed, head_size * self.n_heads * 3, bias=False
        )
        self.dropout = nn.Dropout(drop_p)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape (B, T, n_embed).

        Returns:
            Tensor of shape (B, T, n_heads * head_size) holding the
            concatenated per-head outputs.
        """
        B, T, _ = x.shape
        inner = self.n_heads * self.head_size

        # Each of q, k, v has shape (B, T, n_heads * head_size).
        q, k, v = self.batch_qkv_matrices(x).split(inner, dim=-1)

        # (B, T, n_heads, head_size) -> (B, n_heads, T, head_size) so the
        # matmuls below are batched over both batch and head.
        q = q.view(B, T, self.n_heads, self.head_size).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_size).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_size).transpose(1, 2)

        # Scaled dot-product attention weights, shape (B, n_heads, T, T).
        weight_mat = (q @ k.transpose(-2, -1)) * (self.head_size ** -0.5)
        weight_mat = F.softmax(weight_mat, dim=-1)

        # Regularisation on the attention weights.
        weight_mat = self.dropout(weight_mat)

        # Weighted sum of values, shape (B, n_heads, T, head_size).
        res = weight_mat @ v

        # (B, n_heads, T, head_size) -> (B, T, n_heads, head_size), then
        # merge the head axes. The merged width is n_heads * head_size, which
        # is not necessarily the input width, so it must not be taken from x.
        res = res.transpose(1, 2).contiguous().view(B, T, inner)

        return res


class MHAttention(nn.Module):
    """Multi-head attention sub-layer: fused heads, linear projection, dropout.

    Args:
        n_embed: Input and output width of the projection layer; also passed
            to ``Head`` as its input width.
        head_size: Width of each attention head, forwarded to ``Head``.
    """

    def __init__(self, n_embed, head_size):
        super().__init__()
        self.att_heads = Head(n_embed, head_size)
        self.projection = nn.Linear(n_embed, n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run attention over ``x``, then project and apply dropout."""
        attended = self.att_heads(x)
        projected = self.projection(attended)
        return self.dropout(projected)

class Feedforward(nn.Module):
    """Position-wise MLP: expand to 4x width, ReLU, contract, dropout.

    Args:
        n_embed: Input and output width of the MLP.
    """

    def __init__(self, n_embed) -> None:
        super().__init__()
        hidden_dim = n_embed * 4
        layers = [
            nn.Linear(n_embed, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_embed),
            nn.Dropout(dropout),
        ]
        self.ff = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        out = self.ff(x)
        return out


class Block(nn.Module):
    """Pre-LayerNorm transformer block: attention and MLP, each with a residual.

    Args:
        n_embed: Width of the token embeddings flowing through the block.
        n_heads: Number of attention heads, used here to derive the per-head
            width ``n_embed // n_heads``. ``Head`` reads its head count from
            the module-level ``n_heads`` by default, so the two should agree.
    """

    def __init__(self, n_embed, n_heads) -> None:
        super().__init__()
        self.ff = Feedforward(n_embed)
        # The first argument is the embedding width, not the head count:
        # passing n_heads here would build the attention layers for
        # n_heads-wide inputs and fail on n_embed-wide activations.
        self.mhatt = MHAttention(n_embed, n_embed // n_heads)
        self.layer_norm1 = nn.LayerNorm(n_embed)
        self.layer_norm2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        # Normalisation is applied before each sub-layer (pre-norm).
        x = x + self.mhatt(self.layer_norm1(x))
        x = x + self.ff(self.layer_norm2(x))
        return x
    

# Bidirectional Encoder representations from transformers for RNA-seq (BERNA)
class BERNA(nn.Module):
    """BERT-style encoder over perturbation and response gene tokens.

    Work in progress: ``forward`` currently only concatenates the two token
    sequences, masks one random position per sample and prints the masked
    perturbation tokens. The embedding and transformer stages are still
    disabled, so ``forward`` returns ``None``.

    All sizes are read from the module-level hyperparameters.
    """

    def __init__(self):
        super().__init__()
        # Row 0 of the table is reserved for the mask token, hence n_genes + 1.
        self.embed_table = nn.Embedding(n_genes + 1, n_embed)
        # Two rows: the index buffers below use 0 for perturbations and 1 for
        # responses.
        self.pos_embed = nn.Embedding(2, n_embed)
        # Buffers are created on the default device and follow the module
        # through .to(...), so no device is hard-coded here.
        self.register_buffer(
            "perts_pos_embed", torch.zeros(size=(n_perts, 1), dtype=torch.long)
        )
        self.register_buffer(
            "responses_pos_embed", torch.ones(size=(time_dim, 1), dtype=torch.long)
        )
        self.blocks = nn.Sequential(*[Block(n_embed, n_heads) for _ in range(n_blocks)])

    def forward(self, perts, responses):
        """Mask one random token per sample and print the perturbation part.

        Args:
            perts: Integer token ids, expected shape (B, n_perts).
            responses: Integer token ids, expected shape (B, time_dim).

        Returns:
            None (the embedding / transformer stages are not wired up yet).
        """
        x = torch.cat([perts, responses], dim=-1)  # (B, n_perts + time_dim)
        n_rows, seq_len = x.shape

        # Draw one masked position per row, sized from the actual batch and
        # created on the input's device. Indexing with (row, position) pairs
        # masks a single token in each sample; `x[:, rand_pos]` would instead
        # zero every drawn column in every row.
        rand_pos = torch.randint(0, seq_len, size=(n_rows,), device=x.device)
        rows = torch.arange(n_rows, device=x.device)
        x[rows, rand_pos] = 0  # id 0 is the mask token; x is a fresh tensor from cat

        perts_mod = x[:, :n_perts]
        responses_mod = x[:, n_perts:]

        print(perts_mod)

        # TODO: embedding stage, still disabled.
        # NOTE(review): the index buffers have shape (n, 1), so the sums below
        # broadcast to a 4-D tensor; confirm the intended buffer shape before
        # enabling.
        # perts_embed = self.embed_table(perts_mod) + self.pos_embed(self.perts_pos_embed)
        # responses_embed = self.embed_table(responses_mod) + self.pos_embed(self.responses_pos_embed)

        # print(perts_embed.shape)
        # print(responses_embed.shape)


        # x = torch.cat([perts_embed, responses_embed], dim=1)

# Smoke test: push one random batch through the model.
# Use Apple's MPS backend when it is available, otherwise fall back to the CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

b = BERNA().to(device=device)
# Gene ids are drawn from 1..n_genes because id 0 is reserved for the mask
# token in the embedding table. Sequence lengths come from the shared
# hyperparameters instead of repeating the literals.
test_perts = torch.randint(1, n_genes + 1, (batch_size, n_perts)).to(device=device)
test_responses = torch.randint(1, n_genes + 1, (batch_size, time_dim)).to(device=device)

resp = b(test_perts, test_responses)
# resp.shape


tensor([[136, 495]], device='mps:0') tensor([[136, 495]], device='mps:0')


torch.Size([1, 2, 2, 32])
torch.Size([1, 10, 10, 32])
