In [3]:
import torch
from torch import nn
import torch.nn.functional as F

In [12]:
# hyperparameters

# -- data / batching --
n_genes = 500     # leading gene columns kept from the expression matrix (reassigned later to the vocabulary size)
n_perts = 2       # perturbation tokens per cell
time_dim = 10     # response tokens per cell (k of the top-k selection)
batch_size = 64   # cells sampled per training step

# -- model --
n_embed = 256     # embedding / residual-stream width
n_heads = 8       # attention heads per block
n_blocks = 4      # number of stacked transformer blocks
dropout = 0.2     # dropout probability used throughout the model

In [13]:

import scanpy as sc
import matplotlib.pyplot as plt

# Load the annotated expression matrix from disk.
adata = sc.read("/Users/atamb/Downloads/NormanWeissman2019_filtered.h5ad")
# Normalise every cell to a total of 1e4 counts, then log1p-transform.
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

# Keep only the first n_genes gene columns.
adata_subset = adata[:, :n_genes]

# Dense (cells x genes) expression tensor.
dataset = torch.tensor(adata_subset.X.toarray())

# For every cell, take the column indices of its time_dim most expressed genes.
top_vals, top_inds = torch.topk(dataset, time_dim, dim=1)
# Shift by one so that token id 0 stays reserved for the mask token. This is
# the same convention gene_encoder uses (gene_lst.index(name) + 1), so a gene
# gets the same id whether it appears as a perturbation or as a response.
top_inds = (top_inds + 1).to(device="mps")

# Vocabulary: the kept gene names followed by the "control" pseudo-gene.
gene_lst = list(adata_subset.var_names)
gene_lst.append("control")

# Per-cell perturbation labels.
pertlst = list(adata_subset.obs.perturbation)

# Filled below with one pair of token ids per cell.
pert_dataset = []

def gene_encoder(x):
    """Turn a perturbation label into a pair of token ids.

    The label is split on "_". The first part is always encoded; the second
    part is encoded when present, otherwise "control" takes its place. Parts
    beyond the second are ignored.

    Ids are the position in the module-level ``gene_lst`` plus one, because
    id 0 is reserved for the mask token.

    Raises:
        ValueError: if a looked-up name is not in ``gene_lst``.
    """
    parts = x.split("_")
    first = parts[0]
    second = parts[1] if len(parts) > 1 else "control"
    return [gene_lst.index(first) + 1, gene_lst.index(second) + 1]


# Encode every cell's perturbation label, growing the vocabulary on the fly.
for pert in pertlst:
    # Only the first two "_"-separated names are registered, matching what
    # gene_encoder reads; unseen names are appended to the end of gene_lst.
    for name in pert.split("_")[:2]:
        if name not in gene_lst:
            gene_lst.append(name)

    pert_dataset.append(gene_encoder(pert))

# (cells, 2) tensor of perturbation token ids.
pert_dataset_tensor = torch.tensor(pert_dataset, device="mps")

# From here on n_genes is the vocabulary size (genes + "control" + any
# perturbation names that were appended above).
n_genes = len(gene_lst)



In [14]:
def gen_batch():
    """Sample a random training batch (with replacement).

    Draws ``batch_size`` row indices uniformly from the module-level
    ``pert_dataset_tensor`` and returns the matching rows of the
    perturbation tokens and of ``top_inds`` (response tokens).

    The index tensor is created on the device the data lives on rather than
    on a hard-coded one, so the function works wherever the dataset tensors
    were placed.

    Returns:
        (perts, resp): rows of ``pert_dataset_tensor`` and ``top_inds`` for
        the same sampled cells.
    """
    data_device = pert_dataset_tensor.device
    index_tensor = torch.randint(
        0, len(pert_dataset_tensor), (batch_size,), device=data_device
    )
    perts = pert_dataset_tensor[index_tensor]
    # No-op when both tensors share a device.
    resp = top_inds[index_tensor.to(top_inds.device)]
    return perts, resp


In [15]:

# block_sizes = [128, 64, 32, 16]


class Head(nn.Module):
    """All attention heads of one layer, computed with a single fused projection.

    One bias-free linear layer produces queries, keys and values for every
    head at once. No attention mask is applied, so every position attends to
    every other position. The number of heads and the dropout probability are
    read from the module-level ``n_heads`` and ``dropout``.

    Args:
        n_embed: width of the input features.
        head_size: width of each individual head.
    """

    def __init__(self, n_embed, head_size):
        super().__init__()
        self.head_size = head_size
        # Output is laid out as [Q | K | V], each n_heads * head_size wide.
        self.batch_qkv_matrices = nn.Linear(n_embed, 3 * n_heads * head_size, bias=False)
        self.dropout = nn.Dropout(dropout)

    def _to_heads(self, t, B, T):
        """Reshape (B, T, n_heads * head_size) -> (B, n_heads, T, head_size)."""
        return t.view(B, T, n_heads, self.head_size).transpose(1, 2)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (B, T, C); returns (B, T, C).

        The final reshape requires n_heads * head_size == C.
        """
        fused = self.batch_qkv_matrices(x)
        q, k, v = fused.split(self.head_size * n_heads, dim=-1)

        B, T, C = x.shape
        q, k, v = (self._to_heads(t, B, T) for t in (q, k, v))

        # Scaled dot-product scores, normalised over the key axis.
        scores = (q @ k.transpose(-2, -1)) * (self.head_size ** -0.5)
        attn = F.softmax(scores, dim=-1)

        # Regularise the attention weights before mixing the values.
        attn = self.dropout(attn)
        mixed = attn @ v

        # (B, n_heads, T, head_size) -> (B, T, n_heads, head_size) -> (B, T, C)
        return mixed.transpose(1, 2).contiguous().view(B, T, C)


class MHAttention(nn.Module):
    """Multi-head self-attention followed by an output projection and dropout.

    Args:
        n_embed: width of the input and output features.
        head_size: width of each attention head, forwarded to ``Head``.
    """

    def __init__(self, n_embed, head_size):
        super().__init__()
        self.att_heads = Head(n_embed=n_embed, head_size=head_size)
        self.projection = nn.Linear(n_embed, n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over ``x``, project the result, then apply dropout."""
        attended = self.att_heads(x)
        projected = self.projection(attended)
        return self.dropout(projected)

class Feedforward(nn.Module):
    """Position-wise MLP: two expand/contract stages followed by dropout.

    The hidden width is 6 * n_embed. Layer order is
    Linear -> ReLU -> Linear -> ReLU -> Linear -> ReLU -> Linear -> Dropout,
    with the dropout probability taken from the module-level ``dropout``.

    Args:
        n_embed: width of the input and output features.
    """

    def __init__(self, n_embed) -> None:
        super().__init__()
        hidden = n_embed * 6

        def widen():
            return nn.Linear(n_embed, hidden)

        def narrow():
            return nn.Linear(hidden, n_embed)

        # Layers are created left to right, so parameter initialisation
        # order and Sequential indices follow the order written here.
        self.ff = nn.Sequential(
            widen(), nn.ReLU(),
            narrow(), nn.ReLU(),
            widen(), nn.ReLU(),
            narrow(),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Run ``x`` through the MLP; the last dimension is preserved."""
        return self.ff(x)


class Block(nn.Module):
    """Pre-norm transformer block: attention and feed-forward, each residual.

    Args:
        n_embed: width of the residual stream.
        n_heads: used here only to derive the per-head width
            (n_embed // n_heads); ``Head`` itself reads the module-level
            ``n_heads`` for the number of heads.
    """

    def __init__(self, n_embed, n_heads) -> None:
        super().__init__()
        per_head = n_embed // n_heads
        self.ff = Feedforward(n_embed)
        self.mhatt = MHAttention(n_embed, per_head)
        self.layer_norm1 = nn.LayerNorm(n_embed)
        self.layer_norm2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        """Apply LayerNorm -> attention and LayerNorm -> MLP, adding each to the input."""
        after_attention = x + self.mhatt(self.layer_norm1(x))
        return after_attention + self.ff(self.layer_norm2(after_attention))


    
# Bidirectional Encoder representations from transformers for RNA-seq (BERNA)
class BERNA(nn.Module):
    """Masked-token model over a [perturbation tokens | response tokens] sequence.

    Each forward pass masks one random position per row (replacing its token
    with id 0), runs the sequence through the transformer blocks and returns
    the cross-entropy loss for predicting the original token at the masked
    position.

    Token ids: 0 is the mask token; real tokens use ids up to ``n_genes``
    (``gene_encoder`` emits ``gene_lst.index(name) + 1``). Both the embedding
    table and the output projection therefore have ``n_genes + 1`` entries, so
    every id that can be a target is a valid class index.

    All sizes are read from the module-level hyperparameters (``n_genes``,
    ``n_embed``, ``batch_size``, ``n_perts``, ``time_dim``, ``n_heads``,
    ``n_blocks``).
    """

    def __init__(self):
        super(BERNA, self).__init__()
        # 0th embedding is for mask!!!
        vocab_size = n_genes + 1
        self.embed_table = nn.Embedding(vocab_size, n_embed)
        # Two segment embeddings: id 0 for perturbation slots, id 1 for response slots.
        self.pos_embed = nn.Embedding(2, n_embed)
        # Segment ids are kept as buffers so they are not rebuilt on every call.
        # They are created without an explicit device and follow the module
        # when it is moved with .to(...).
        self.register_buffer(
            "perts_pos_embed",
            torch.zeros(size=(batch_size, n_perts), dtype=torch.long),
        )
        self.register_buffer(
            "responses_pos_embed",
            torch.ones(size=(batch_size, time_dim), dtype=torch.long),
        )

        # Attention blocks
        self.blocks = nn.Sequential(*[Block(n_embed, n_heads) for _ in range(n_blocks)])
        self.layernorm = nn.LayerNorm(n_embed)
        # One logit per token id (including id 0), matching embed_table.
        self.final_proj = nn.Linear(n_embed, vocab_size)

    def forward(self, perts, responses):
        """Compute the masked-token prediction loss for one batch.

        Args:
            perts: integer tensor of shape (B, n_perts) with perturbation token ids.
            responses: integer tensor of shape (B, time_dim) with response token ids.

        Returns:
            Scalar cross-entropy loss over the B masked positions.
        """
        x = torch.cat([perts, responses], dim=-1)  # (B, n_perts + time_dim)

        # Use the actual batch size and device of the input instead of the
        # global batch_size / a fixed device, so any batch can be processed.
        n_rows = x.size(0)
        row_idx = torch.arange(n_rows, device=x.device)

        # Pick one position per row to mask and remember its original token.
        rand_pos = torch.randint(0, time_dim + n_perts, size=(n_rows,), device=x.device)
        target_token = x[row_idx, rand_pos]

        # Zero out the chosen positions (0 is the mask token id).
        mask = torch.ones_like(x)
        mask[row_idx, rand_pos] = 0
        x = x * mask

        # re-split after masking into perts and responses
        perts_mod = x[:, :n_perts]
        responses_mod = x[:, n_perts:]

        # A single row of segment ids broadcasts over the batch dimension.
        perts_embed = self.embed_table(perts_mod) + self.pos_embed(self.perts_pos_embed[:1])
        responses_embed = self.embed_table(responses_mod) + self.pos_embed(self.responses_pos_embed[:1])

        # re-concatenate and pass through blocks
        x = torch.cat([perts_embed, responses_embed], dim=1)
        x = self.final_proj(self.layernorm(self.blocks(x)))

        # Logits at the masked position of every row.
        logits = x[row_idx, rand_pos]

        loss = F.cross_entropy(logits, target_token)
        return loss


# b = BERNA().to(device="mps")
# test_perts = torch.randint(0, n_genes, (batch_size,2,)).to(device="mps")
# test_responses = torch.randint(0, n_genes, (batch_size,10)).to(device="mps")

# resp = b(test_perts, test_responses)

# resp.shape


In [16]:
# Build the model, move it (parameters and buffers) to the MPS device, and
# create an Adam optimiser over all of its parameters.
model = BERNA()
model = model.to(device="mps")
optim = torch.optim.Adam(params=model.parameters(), lr=3e-4)


In [17]:
# Number of optimisation steps (one random batch per step).
n_epochs = 5000

# Losses collected since the last report.
loss_lst = []

for i in range(n_epochs):
    optim.zero_grad()

    perts, resp = gen_batch()
    loss = model(perts, resp)
    loss.backward()
    optim.step()

    loss_lst.append(loss.item())
    # Every 50 steps print the mean loss of the current window and start a
    # new one (the first report, at step 0, covers a single batch).
    if not i % 50:
        print(sum(loss_lst) / len(loss_lst))
        loss_lst = []

6.436053276062012
3.87002788066864
2.5784301781654357
2.2551401090621948
2.1555699348449706
2.209217324256897
2.1970620560646057
2.1959004855155944
2.155860161781311
2.1045082664489745
2.0870564794540405
2.105900583267212
2.1085131669044497
2.151959528923035
2.114898271560669
2.100181882381439
2.0332925009727476
2.1372216892242433
2.0057388377189636
2.0102581429481505
2.060205237865448
2.0840709137916567
2.03759836435318
2.0780045223236083
2.009292631149292
1.9633903074264527
2.0514724683761596
2.0436591863632203
1.9800955486297607
1.9493935227394104
2.0009888863563536
1.9492456603050232
1.9421072936058044
1.9988496494293213
1.9725536632537841
1.9248845195770263
1.9107662296295167
1.951020085811615
1.935927131175995
1.9327314972877503
1.8982454490661622
1.9213224864006042
1.942727143764496
1.9726701283454895
1.927194182872772
1.9363679385185242
1.8983551263809204
1.9042508149147033
1.9356699061393738
1.8861456489562989
1.9115185022354126
1.8997963905334472
1.8955616855621338
1.85365439