In [108]:
import torch
from torch import nn
import torch.nn.functional as F

In [109]:
# hyperparameters
n_embed = 128  # embedding width used by the embedding tables and every transformer block
n_genes = 500  # number of leading gene columns kept from the expression matrix; the data-prep cell later reassigns it to the vocabulary size
batch_size = 64  # rows drawn per call to gen_batch
time_dim = 20  # number of top-expressed gene indices kept per cell (response sequence length)
n_heads = 8  # attention heads per block
dropout = 0  # dropout probability passed to every nn.Dropout
n_perts = 2  # perturbation tokens per cell
n_blocks = 6  # number of stacked transformer blocks

In [110]:

import scanpy as sc
import matplotlib.pyplot as plt

adata = sc.read("/Users/atamb/Downloads/NormanWeissman2019_filtered.h5ad")
# normalise every cell to a total of 1e4 counts, then log1p-transform
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

# keep only the first n_genes gene columns
adata_subset = adata[:, :n_genes]

# Response tokens: for each cell, the column indices of its time_dim largest
# expression values. Token 0 is reserved for [MASK], so gene index i is stored
# as token i + 1 -- the same offset gene_encoder applies to perturbation
# tokens. Without the offset, gene 0 would be indistinguishable from [MASK]
# and a gene would get different ids as a response and as a perturbation.
dataset = torch.tensor(adata_subset.X.toarray())
top_vals, top_inds = torch.topk(dataset, time_dim, dim=1)
top_inds = (top_inds + 1).to(device="mps")

# Vocabulary: expression genes first, then "control", then every perturbation
# name that is not already present (appended in order of first appearance).
gene_lst = list(adata_subset.var_names)
gene_lst.append("control")

# name -> position in gene_lst, so lookups are O(1) instead of a list scan per
# cell. setdefault keeps the first occurrence, matching list.index.
gene_to_index = {}
for _position, _name in enumerate(gene_lst):
    gene_to_index.setdefault(_name, _position)

pertlst = list(adata_subset.obs.perturbation)


def _gene_index(name):
    """Return the position of `name` in gene_lst (ValueError if absent)."""
    idx = gene_to_index.get(name)
    if idx is None:
        # Fall back to the list so names appended to gene_lst elsewhere still
        # resolve, and unknown names raise ValueError exactly as list.index does.
        idx = gene_lst.index(name)
        gene_to_index[name] = idx
    return idx


def gene_encoder(x):
    """Encode a perturbation label as a pair of token ids.

    Args:
        x: perturbation label; gene names joined by "_". Only the first two
            names are used.

    Returns:
        A two-element list of token ids. A label with a single name is paired
        with the "control" token.

    Raises:
        ValueError: if a name is not in gene_lst.
    """
    x = x.split("_")
    if len(x) == 1:
        # add one in order to account for mask being at index 0
        return [_gene_index(x[0]) + 1, _gene_index("control") + 1]
    else:
        return [_gene_index(x[0]) + 1, _gene_index(x[1]) + 1]


pert_dataset = []
for pert in pertlst:
    # grow the vocabulary with perturbation names that are not expression genes
    for name in pert.split("_")[:2]:
        if name not in gene_to_index:
            gene_to_index[name] = len(gene_lst)
            gene_lst.append(name)

    pert_dataset.append(gene_encoder(pert))

pert_dataset_tensor = torch.tensor(pert_dataset, device="mps")

# NOTE: this overwrites the n_genes hyperparameter with the vocabulary size.
# Re-run the hyperparameter cell before re-running this one, otherwise the
# column slice above uses the enlarged value.
n_genes = len(gene_lst)



In [111]:
def gen_batch():
    """Sample one random training batch (rows drawn with replacement).

    Returns:
        (perts, resp): the rows of pert_dataset_tensor and the rows of
        top_inds selected by the same batch_size random row indices.
    """
    # Draw the indices on whichever device the dataset lives on rather than
    # hard-coding "mps", so the sampler follows the data.
    index_tensor = torch.randint(
        0, len(pert_dataset_tensor), (batch_size,), device=pert_dataset_tensor.device
    )
    perts = pert_dataset_tensor[index_tensor]
    resp = top_inds[index_tensor]
    return perts, resp

In [112]:

# fraction of the concatenated (perts + responses) sequence length that BERNA.forward draws as masked positions
mask_factor = 0.2

class Head(nn.Module):
    """Multi-head self-attention with all heads fused into one QKV projection.

    No attention mask is applied: every position attends to every position.

    Args:
        n_embed: width of the input features.
        head_size: width of each attention head.
        num_heads: number of heads; defaults to the notebook-level n_heads.
        p_drop: dropout probability on the attention weights; defaults to the
            notebook-level dropout.

    forward(x) takes x of shape (B, T, n_embed) and returns a tensor of shape
    (B, T, num_heads * head_size).
    """

    def __init__(self, n_embed, head_size, num_heads=None, p_drop=None):
        super().__init__()
        self.head_size = head_size
        # Capture the head count at construction so that changing the notebook
        # global afterwards cannot desynchronise forward() from the weights.
        self.n_heads = n_heads if num_heads is None else num_heads
        self.batch_qkv_matrices = nn.Linear(n_embed, head_size * self.n_heads * 3, bias=False)
        self.dropout = nn.Dropout(dropout if p_drop is None else p_drop)

    def forward(self, x):
        B, T, _ = x.shape
        inner = self.n_heads * self.head_size

        # one matmul for Q, K and V; each is (B, T, inner)
        q, k, v = self.batch_qkv_matrices(x).split(inner, dim=-1)

        # (B, T, inner) -> (B, n_heads, T, head_size) so the matmuls batch over heads
        k = k.view(B, T, self.n_heads, self.head_size).transpose(1, 2)
        q = q.view(B, T, self.n_heads, self.head_size).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_size).transpose(1, 2)

        # scaled dot-product attention weights, (B, n_heads, T, T)
        weight_mat = (q @ k.transpose(-2, -1)) * (self.head_size ** -0.5)
        weight_mat = F.softmax(weight_mat, dim=-1)

        # regularisation
        weight_mat = self.dropout(weight_mat)

        # weighted sum of the values, (B, n_heads, T, head_size)
        res = weight_mat @ v

        # Concatenate the heads back along the feature axis. The width is
        # n_heads * head_size, which is not necessarily the input width.
        res = res.transpose(1, 2).contiguous().view(B, T, inner)

        return res


class MHAttention(nn.Module):
    """Self-attention (Head) followed by a linear output projection and dropout.

    Args:
        n_embed: feature width of the input and of the projected output.
        head_size: per-head width handed to Head.
    """

    def __init__(self, n_embed, head_size):
        super().__init__()
        self.att_heads = Head(n_embed=n_embed, head_size=head_size)
        self.projection = nn.Linear(n_embed, n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        attended = self.att_heads(x)
        projected = self.projection(attended)
        return self.dropout(projected)

class Feedforward(nn.Module):
    """Position-wise MLP: n_embed -> 6x -> n_embed -> 6x -> n_embed, then dropout.

    ReLU follows each of the first three linear layers.
    """

    def __init__(self, n_embed) -> None:
        super().__init__()
        scale_factor = 6
        hidden = n_embed * scale_factor
        layers = [
            nn.Linear(n_embed, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embed),
            nn.ReLU(),
            nn.Linear(n_embed, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embed),
            nn.Dropout(dropout),
        ]
        self.ff = nn.Sequential(*layers)

    def forward(self, x):
        out = self.ff(x)
        return out


class Block(nn.Module):
    """Pre-norm transformer block: attention and feed-forward, each with a residual add.

    Args:
        n_embed: feature width.
        n_heads: used only to derive the per-head width (n_embed // n_heads);
            the head count itself is not forwarded to MHAttention.
    """

    def __init__(self, n_embed, n_heads) -> None:
        super().__init__()
        per_head = n_embed // n_heads
        self.ff = Feedforward(n_embed)
        self.mhatt = MHAttention(n_embed, per_head)
        self.layer_norm1 = nn.LayerNorm(n_embed)
        self.layer_norm2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        # attention sub-layer on the normalised input, added back to the input
        attn_out = self.mhatt(self.layer_norm1(x))
        x = x + attn_out
        # feed-forward sub-layer, same pattern
        ff_out = self.ff(self.layer_norm2(x))
        return x + ff_out


    
# Bidirectional Encoder representations from transformers for RNA-seq (BERNA)
class BERNA(nn.Module):
    """Masked-token model over [perturbation tokens | response tokens].

    forward(perts, responses) concatenates the two integer token tensors,
    replaces a random subset of positions with token 0 ([MASK]), runs the
    sequence through the transformer blocks and returns the cross-entropy
    loss of predicting the original tokens at the masked positions.

    Notes:
        * Masking happens on every call, including in eval mode.
        * Masked positions are drawn with replacement, so a row may end up
          with fewer distinct masked positions than requested.
    """

    def __init__(self):
        super().__init__()
        # Token ids run from 0 ([MASK]) to n_genes inclusive, so both the
        # embedding table and the output head need n_genes + 1 entries. With
        # only n_genes outputs, a masked token with id n_genes would be an
        # out-of-range cross-entropy target.
        vocab_size = n_genes + 1
        self.embed_table = nn.Embedding(vocab_size, n_embed)
        # two embeddings, unique for both perts and responses
        self.pos_embed = nn.Embedding(2, n_embed)
        # Constant segment ids (0 = perturbation, 1 = response), kept as
        # buffers so they move with the module. No device is given here:
        # .to(device) places them, so the model can be built on any backend.
        self.register_buffer("perts_pos_embed", torch.zeros(size=(batch_size, n_perts)).long())
        self.register_buffer("responses_pos_embed", torch.ones(size=(batch_size, time_dim)).long())

        # Attention blocks
        self.blocks = nn.Sequential(*[Block(n_embed, n_heads) for _ in range(n_blocks)])
        self.layernorm = nn.LayerNorm(n_embed)
        self.final_proj = nn.Linear(n_embed, vocab_size)

    def forward(self, perts, responses):
        """Return the masked-token cross-entropy loss for one batch.

        Args:
            perts: integer token ids, shape (B, number of perturbation tokens).
            responses: integer token ids, shape (B, number of response tokens).

        Returns:
            Scalar loss tensor.
        """
        x = torch.cat([perts, responses], dim=-1)  # (B, L) token ids, L = perts + responses
        B, L = x.shape
        n_pert_tokens = perts.shape[-1]

        # Number of positions to mask per row; at least one, because an empty
        # target set would make the loss NaN.
        n_masked = max(1, int(L * mask_factor))
        # Sizes and device come from the input instead of notebook globals,
        # so any batch size and any backend work.
        rand_pos = torch.randint(0, L, size=(B, n_masked), device=x.device)

        # original tokens at the masked positions: these are the prediction targets
        target_token = torch.gather(x, 1, rand_pos)

        # zero out the chosen positions; token 0 is [MASK]
        mask = torch.ones_like(x)
        mask.scatter_(1, rand_pos, 0)
        x = x * mask

        # re-split after masking into perts and responses
        perts_mod = x[:, :n_pert_tokens]
        responses_mod = x[:, n_pert_tokens:]

        # Every entry of a segment buffer is the same id, so a single (1, 1)
        # slice broadcasts over any batch size and sequence length.
        perts_embed = self.embed_table(perts_mod) + self.pos_embed(self.perts_pos_embed[:1, :1])
        responses_embed = self.embed_table(responses_mod) + self.pos_embed(self.responses_pos_embed[:1, :1])

        # re-concatenate and pass through blocks -> (B, L, vocab) logits
        x = torch.cat([perts_embed, responses_embed], dim=1)
        x = self.final_proj(self.layernorm(self.blocks(x)))

        # pick the logits at the masked positions: (B, n_masked, vocab)
        modified_rand_pos = rand_pos.unsqueeze(-1).expand(-1, -1, x.shape[-1])
        logits = torch.gather(x, 1, modified_rand_pos)

        loss = F.cross_entropy(logits.reshape(-1, logits.shape[-1]), target_token.reshape(-1))
        return loss


# Smoke test (uncomment to run): forward() returns a scalar loss.
# b = BERNA().to(device="mps")
# test_perts = torch.randint(0, n_genes, (batch_size, n_perts)).to(device="mps")
# test_responses = torch.randint(0, n_genes, (batch_size, time_dim)).to(device="mps")
# resp = b(test_perts, test_responses)
# print(resp)
# resp.shape


In [113]:
# instantiate the model, move it to the "mps" device, and optimise all of its parameters with Adam
model = BERNA().to(device="mps")
optim = torch.optim.Adam(model.parameters(), lr=3e-4)

In [114]:
n_epochs = 10000  # number of optimisation steps; each step uses one freshly sampled batch

loss_lst = []

for step in range(n_epochs):
    perts, resp = gen_batch()

    optim.zero_grad()
    loss = model(perts, resp)
    loss.backward()
    optim.step()

    loss_lst.append(loss.item())
    if step % 50 != 0:
        continue
    # every 50th step: report the mean loss since the previous report, then reset the window
    print(sum(loss_lst) / len(loss_lst))
    loss_lst = []

6.65017032623291
5.094583158493042
4.221706733703614
3.76480459690094
3.6407118749618532
3.5485582733154297
3.5103693056106566
3.46583740234375
3.4792668771743775
3.436477928161621
3.4386022424697877
3.4280880975723265
3.409325571060181
3.441022834777832
3.4124304246902466
3.382338213920593
3.4216264963150023
3.381232476234436
3.3870992183685305
3.3705470514297486
3.350314302444458
3.373988699913025
3.3576741361618043
3.3615059900283812
3.3876878595352173
3.349146499633789
3.33632559299469
3.3634633827209472
3.358061981201172
3.3396395111083983
3.356384139060974
3.3108683109283445
3.315969452857971
3.3217591857910156
3.3273305559158324
3.327773718833923
3.340449113845825
3.3109439611434937
3.341012692451477
3.3267406368255616
3.318300557136536
3.345765314102173
3.314155926704407
3.309562129974365
3.3262577152252195
3.292046389579773
3.298604898452759
3.2796576976776124
3.310287389755249
3.2892987298965455
3.2725835990905763
3.274227681159973
3.2691804933547974
3.2929777479171753
3.2688

In [116]:
# save weights: writes the model's state_dict to "berna.pt" (relative path, so in the current working directory)
torch.save(model.state_dict(), "berna.pt")

In [126]:
model.eval()
with torch.no_grad():
    # get embeddings: encode the gene name, then look its tokens up in the embedding table
    token = gene_encoder("OR4F5")
    print(token)
    token_batch = torch.tensor(token, device="mps").unsqueeze(0)  # add a leading batch axis
    vec = model.embed_table(token_batch)
    print(vec.shape)

[3, 501]
torch.Size([1, 2, 128])


In [134]:
# single-name label: gene_encoder pairs the gene's token with the "control" token
gene_encoder("RP11-34P13.7")

[4, 501]

In [128]:
# full vocabulary list: expression genes, then "control", then appended perturbation names
gene_lst

['RP11-34P13.3',
 'FAM138A',
 'OR4F5',
 'RP11-34P13.7',
 'RP11-34P13.8',
 'RP11-34P13.14',
 'RP11-34P13.9',
 'FO538757.3',
 'FO538757.2',
 'AP006222.2',
 'RP5-857K21.15',
 'RP4-669L17.2',
 'RP4-669L17.10',
 'OR4F29',
 'RP5-857K21.4',
 'RP5-857K21.2',
 'OR4F16',
 'RP11-206L10.4',
 'RP11-206L10.9',
 'FAM87B',
 'LINC00115',
 'FAM41C',
 'RP11-54O7.16',
 'RP11-54O7.1',
 'RP11-54O7.2',
 'RP11-54O7.3',
 'SAMD11',
 'NOC2L',
 'KLHL17',
 'PLEKHN1',
 'PERM1',
 'RP11-54O7.17',
 'HES4',
 'ISG15',
 'RP11-54O7.11',
 'AGRN',
 'RP11-54O7.18',
 'RNF223',
 'C1orf159',
 'LINC01342',
 'RP11-465B22.8',
 'TTLL10-AS1',
 'TTLL10',
 'TNFRSF18',
 'TNFRSF4',
 'SDF4',
 'B3GALT6',
 'FAM132A',
 'RP5-902P8.12',
 'UBE2J2',
 'RP5-902P8.10',
 'SCNN1D',
 'ACAP3',
 'PUSL1',
 'CPSF3L',
 'CPTP',
 'TAS1R3',
 'DVL1',
 'MXRA8',
 'AURKAIP1',
 'CCNL2',
 'MRPL20',
 'RP4-758J18.13',
 'ANKRD65',
 'RP4-758J18.7',
 'TMEM88B',
 'RP4-758J18.10',
 'VWA1',
 'ATAD3C',
 'ATAD3B',
 'ATAD3A',
 'TMEM240',
 'SSU72',
 'RP5-832C2.5',
 'AL645728.

In [131]:
# list position of the "control" entry; gene_encoder reports it as this value + 1
gene_lst.index("control")

500