In [1]:
import pandas as pd

In [12]:
import ast

In [2]:
!ls ../../data/propedia/parsing

binding_regions.csv binding_regions.py


In [18]:
import torch

In [37]:
from sklearn.model_selection import train_test_split

In [16]:
indices = ast.literal_eval(regions.iloc[0]['binding_region_indices'])

In [3]:
regions = pd.read_csv("../../data/propedia/parsing/binding_regions.csv")

In [4]:
lengths = regions.peptide_seq.str.len() 

In [32]:
len(lengths[lengths < 50]) / len(lengths)

0.9968171145316984

In [21]:
err_row = regions[regions['propedia_id'] == '6ghl_E_A']

In [23]:
err_row['partner_seq'].str.len()

8433    338
Name: partner_seq, dtype: int64

In [24]:
err_row['peptide_seq'].str.len()

8433    21
Name: peptide_seq, dtype: int64

In [22]:
err_row['partner_seq'].str.len() + err_row['peptide_seq'].str.len()


8433    359
dtype: int64

tensor([[False,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False, False, False, False,  True,  True,  True],
        [False, False, False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False, False, False, False,  True],
        [False, False, False, False, False, False, False, False, False, False]])

In [33]:
lengths = regions['partner_seq'].str.len() + 50
len(lengths[lengths < 1020]) / len(lengths)

0.9997391077484998

In [1]:
import numpy as np

In [35]:
import torch

In [314]:
import esm

In [389]:
esm_dim = 768

In [391]:
test_esm = torch.tensor(np.load("DLX5_embedding.npy"))[:, :, :768]

In [248]:
class PositionalEmbedding():
    def __init__(self, dim=esm_dim, max_seq_length=1024):        
        # positional embeddings as defined in Attention is All You Need
        positions = torch.arange(max_peptide, dtype=torch.float)
        freqs = 1 / torch.pow(10000, torch.arange(0, dim, 2) / dim)

        self.pos_embeddings = torch.zeros(max_peptide, dim)
        trig_arguments = torch.matmul(positions.unsqueeze(-1), torch.t(freqs.unsqueeze(-1)))
        self.pos_embeddings[:, torch.arange(0, dim, 2)] = torch.sin(trig_arguments)
        self.pos_embeddings[:, torch.arange(1, dim, 2)] = torch.cos(trig_arguments)
    
    def __call__(self, size):
        return self.pos_embeddings[:size, :]
    
    

In [438]:
def get_alphabet():
    alphabet = esm.Alphabet.from_architecture("msa_transformer")
    alphabet.prepend_bos = False
    alphabet.append_eos = True
    alphabet.use_msa = False
    return alphabet
    

In [439]:
alphabet = get_alphabet()

In [440]:
# TODO: make pad, unk all -100 in dataloader

In [441]:
config = {
    "esm_dim": esm_dim,
    "alphabet": get_alphabet(),
    "lr": 1e-3,
    "hidden_dim": 40, # dim used in transformer decoder
    "num_heads": 4, 
    "num_layers": 2,
    "dim_feedforward": 80, # d_model * 2?
    "layer_norm_eps": 1e-5,
    "dropout": 0.1,
    "batch_first": True,
    "norm_first": True,
    "use_esm_token_embedding": True,
    "esm_token_embedding_path": "esm_token_embedding.pt",
    "positional_embed_peptide_only": True
}

In [684]:
test_input['peptide_labels'].shape

torch.Size([1, 8])

In [395]:
test_input = {
    'peptide_input': torch.tensor([2, 4, 4, 5, 3, 0, 0, 0]).reshape(1, -1),
    'partner_embedding': test_esm,
    'peptide_labels': torch.tensor([2, 4, -100, -100, 3, 0, 0, 0]).reshape(1, -1),
}

In [396]:
pe = PeptideEmbedding()

In [682]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence
import pytorch_lightning as pl
import pickle
from math import floor

In [690]:

class PepGenGPT(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        
        self.sep_idx = self.config["alphabet"].get_idx("<sep>")
        self.eos_idx = self.config["alphabet"].get_idx("<eos>")
        self.pad_idx = self.config["alphabet"].get_idx("<pad>")
        
        if self.config['use_esm_token_embedding']:
            self.peptide_embedder = nn.Embedding(
                len(config['alphabet']),
                self.config['esm_dim'],
                padding_idx=self.pad_idx
            )
            self.peptide_embedder.load_state_dict(torch.load(self.config["esm_token_embedding_path"]))
            
            # load esm embedder
            self.peptide_embedding_map = nn.Linear(
                self.config['esm_dim'], 
                self.config['hidden_dim']
            )
        else:
            self.peptide_embedder = nn.Embedding(
                len(config['alphabet']),
                self.config['hidden_dim'],
                padding_idx=self.pad_idx
            )
        
        # reduce dimensionality of input embedding
        self.esm_embedding_map = nn.Linear(self.config['esm_dim'], self.config['hidden_dim'])
        
        self.positional_embedding = PositionalEmbedding(dim=self.config['hidden_dim'])
        
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.config['hidden_dim'],
            nhead=self.config['num_heads'],
            dim_feedforward=self.config['dim_feedforward'],
            dropout=self.config['dropout'],
            layer_norm_eps=self.config['layer_norm_eps'],
            batch_first=self.config['batch_first'],
            norm_first=self.config["norm_first"]
        )
        output_norm = nn.LayerNorm(self.config['hidden_dim'], eps=self.config['layer_norm_eps'])
        self.encoder = nn.TransformerEncoder(encoder_layer, self.config['num_layers'], output_norm)
        self.logit_map = nn.Linear(self.config['hidden_dim'], len(config['alphabet']))
    
    def embed_peptide(self, peptide_input):
        embedding = self.peptide_embedder(peptide_input)
        if self.config['use_esm_token_embedding']:
            embedding = self.peptide_embedding_map(embedding)
        
        return embedding
    
    def get_causal_mask(self, partner_length, peptide_length):
        total_length = partner_length + peptide_length
        mask = torch.ones(total_length, total_length, dtype=torch.bool)
        mask[:, :partner_length] = False
        mask[-peptide_length:, -peptide_length:] = ~torch.tril(torch.ones(peptide_length, peptide_length, dtype=torch.bool))
    
    def forward(self, peptide_input, raw_partner_embedding, padding_mask):
        peptide_length = peptide_input.shape[1]
        
        peptide_embedding = self.embed_peptide(peptide_input)
        partner_embedding = self.esm_embedding_map(raw_partner_embedding)
        
        x = torch.cat([partner_embedding, peptide_embedding], dim=1)

        partner_length = x.shape[1] - peptide_length
        if self.config['positional_embed_peptide_only']:
            x[:, -peptide_length:, :] += self.positional_embedding(peptide_length).unsqueeze(0)
        else:
            x += self.positional_embedding(x.shape[1]).unsqueeze(0)
        
        causal_mask = self.get_causal_mask(partner_length, peptide_length)
        
        x = self.encoder(
            src=x,
            mask=causal_mask,
            src_key_padding_mask=padding_mask,
        )
        
        x = self.logit_map(x)
        
        # we only care about the peptide predictions
        x = x[:, -peptide_length:, :]
        return x
            
    def predict(self, partner_embedding, max_length=50):        
        partner_embedding = self.esm_embedding_map(partner_embedding)
        peptide_seq = torch.tensor([self.sep_idx]).reshape(1, -1)
        
        for i in range(max_length):
            total_length = partner_embedding.shape[1] + peptide_seq.shape[1]
            padding_mask = torch.zeros(1, total_length, dtype=torch.bool)
            logits = self(peptide_seq, partner_embedding, padding_mask)
            next_token = logits.topk(1, dim=2).indices[0, -1, 0]
            peptide_seq = torch.cat([peptide_seq, next_token.reshape(1, -1)], dim=1)
            if next_token.item() == self.eos_idx:
                break

        return peptide_seq
    
    def training_step(self, batch, batch_idx):
        logits = self(batch['peptide_input'], batch['partner_embedding'], batch['padding_mask'])
                
        # B x L x C -> B x C x L
        # permute to agree with cross-entropy dims
        logits = logits.permute(0, 2, 1)
        
        loss = F.cross_entropy(logits, batch['peptide_labels'])
        
        return loss

    def validation_step(self, batch, batch_idx: int) -> None:
        logits = self(batch['peptide_input'], batch['partner_embedding'], batch['padding_mask'])

        # B x L x C -> B x C x L
        # permute to agree with cross-entropy dims
        logits = logits.permute(0, 2, 1)
        pred = logits.argmax(dim=1, keepdim=True)

        loss = F.cross_entropy(logits, batch['peptide_labels'])

        accuracy = pred.eq(target.view_as(pred)).float().mean()
        self.log("val_acc", accuracy)
        self.log("hp_metric", accuracy, on_step=False, on_epoch=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.config['lr'])
        return optimizer

In [543]:
regions = pd.read_csv("../../parsing/propedia/binding_regions.csv")


In [None]:
propedia_id

In [549]:
class TestEmbeddingDict():
    def __init__(self):
        pass
    def __getitem__(self, index):
        return test_esm

In [674]:
class PartnerPeptideDataset(torch.utils.data.Dataset):
    def __init__(self, dataframe, partner_embedding_dict):
        self.dataframe = dataframe[dataframe["propedia_id"].isin(partner_embedding_dict.keys())]
        self.dataframe = self.dataframe.reset_index()
        self.partner_embedding_dict = partner_embedding_dict

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, index):
        # Select row
        row = self.dataframe.iloc[index]
        propedia_id = row["propedia_id"]
        
        return {
            "propedia_id": propedia_id,
            "peptide_seq": row["peptide_seq"],
            "partner_embedding": self.partner_embedding_dict[propedia_id]
        } 

In [675]:
test_dict = {"3p3n_B_A": test_esm, "3si4_I_H": test_esm}

In [676]:
dataset = PartnerPeptideDataset(regions, test_dict)

In [677]:
class PartnerPeptideCollator:
    def __init__(self, alphabet):
        self.alphabet = alphabet
        self.batch_converter = alphabet.get_batch_converter()
        self.sep_idx = self.alphabet.get_idx("<sep>")
        self.eos_idx = self.alphabet.get_idx("<eos>")
        self.pad_idx = self.alphabet.get_idx("<pad>")
    
    def __call__(self, raw_batch):
        n_seq = len(raw_batch)
        batch = {}
        _, _, peptide_tokens = self.batch_converter([(d['propedia_id'], d['peptide_seq']) for d in raw_batch])
        peptide_tokens = torch.cat([torch.full((n_seq, 1), self.sep_idx), peptide_tokens], dim=1)
        batch['peptide_input'] = peptide_tokens[:, :-1]
        batch['peptide_labels'] = peptide_tokens[:, 1:]
        
        batch['partner_embedding'] = pad_sequence(
            [d['partner_embedding'][0, :, :] for d in raw_batch],
            batch_first=True,
            padding_value=0
        )
        
        partner_raw_lengths = [d['partner_embedding'].shape[1] for d in raw_batch]
         
        batch['padding_mask'] = self.get_padding_mask(
            batch['partner_embedding'].shape[1],
            partner_raw_lengths,
            batch['peptide_input']
        )
        return batch
        
    def get_padding_mask(self, partner_batch_length, partner_raw_lengths, peptide_input):
        batch_size = peptide_input.shape[0]
        peptide_length = peptide_input.shape[1]
        total_length = partner_length + peptide_length
        
        print(peptide_input.shape)
        
        mask = torch.ones(batch_size, total_length, dtype=torch.bool)
        
        mask[:, -peptide_length:] = peptide_input == self.pad_idx
        for i, length in enumerate(partner_raw_lengths):
            mask[i, :length] = False
        

In [683]:
class PartnerPeptideDataModule(pl.LightningDataModule):
    def __init__(self,
                 regions_csv_path,
                 partner_embeddings_path,
                 train_frac=0.8,
                 test_frac=0.1,
                 val_frac=0.1,
                 batch_size=4,
                 random_seed=42):
        super().__init__()
        
        self.regions_csv_path = regions_csv_path
        self.partner_embeddings_path = partner_embeddings_path
        self.train_frac = train_frac
        self.test_frac = test_frac
        self.val_frac = val_frac
        self.batch_size = batch_size
        self.random_seed = random_seed

    def setup(self, stage):
        regions = pd.read_csv(self.regions_csv_path)
        embeddings = pickle.load(self.partner_embeddings_path)
        
        raw_dataset = PartnerPeptideDataset(regions, embeddings)
        
        
        train_length = floor(train_frac * len(raw_dataset))
        test_length = floor(test_frac * len(raw_dataset))
        val_length = len(raw_dataset) - train_length - test_length
        
        splits = random_split(raw_dataset, 
                              [train_length, test_length, val_length], 
                              generator=torch.Generator().manual_seed(self.random_seed))
        
        self.train_dataset, self.test_dataset, self.val_dataset = splits
        
        self.alphabet = self.get_alphabet()
        self.collator = PartnerPeptideCollator(alphabet)
        
    def get_alphabet(self):
        alphabet = esm.Alphabet.from_architecture("msa_transformer")
        alphabet.prepend_bos = False
        alphabet.append_eos = True
        alphabet.use_msa = False
        return alphabet
    
    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, collate_fn=self.collator)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=self.collator)
    
    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, collate_fn=self.collator)

        

In [691]:
collator = PartnerPeptideCollator(alphabet)

In [692]:
gpt = PepGenGPT(config)

In [693]:
gpt.training_step(collator(test_raw_batch), 0)

torch.Size([2, 12])


tensor(3.7980, grad_fn=<NllLoss2DBackward0>)

In [694]:
batch = collator(test_raw_batch)

torch.Size([2, 12])


In [695]:
logits = gpt(batch['peptide_input'], batch['partner_embedding'], batch['padding_mask'])

In [697]:
logits.shape

torch.Size([2, 12, 33])

In [698]:
logits = logits.permute(0, 2, 1)


In [704]:
pred = logits.argmax(dim=1)


In [705]:
pred.shape

torch.Size([2, 12])

torch.Size([2, 12])

In [708]:
pred

tensor([[28, 28, 28, 28, 21, 21, 28, 28, 28, 28, 23, 21],
        [28, 12, 28, 28, 23, 28, 28, 28, 28, 28, 21, 28]])

In [709]:
batch['peptide_labels']

tensor([[10,  4,  4,  9,  5,  8,  5, 13,  5, 17,  2,  1],
        [13, 18,  9,  9, 12, 14,  9,  9, 24,  4, 16,  2]])

In [707]:
pred.eq(batch['peptide_labels']).float().mean()

tensor(0.)