### **Make a new dataset**
- target seq
- binder seq
- motif seq
- cluster

In [None]:
from transformers import T5Tokenizer, T5EncoderModel
import torch
import re
!pip install sentencepiece
import sentencepiece
import torch
from torch import nn
from transformers import T5ForConditionalGeneration, T5Tokenizer
from torch.utils.data import DataLoader, Dataset

Collecting sentencepiece
  Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: sentencepiece
Successfully installed sentencepiece-0.1.99


In [None]:
import os
# Switch the working directory to the per-residue dataset folder so the CSV
# files read later in this notebook can be referenced by bare file name.
# NOTE(review): the Google Drive path is hard-coded; it only resolves when the
# drive is mounted at /content/drive.
os.chdir('/content/drive/MyDrive/Programmable Biology Group/Srikar/Code/proteins/flamingo-ppi-gen/data_dump/per-residue-dataset')

In [None]:
import pandas as pd
import re

def preprocess_snp_data(file_path):
    """Load one dataset CSV and add normalised score strings plus length columns.

    The ``energy_scores`` column is expected to hold whitespace-separated
    numbers wrapped in square brackets (one string per row). Each string is
    rewritten as a comma-separated list, and the following columns are added:

    * ``energy_scores_lengths`` - number of scores in the row.
    * ``peptide_source_RCSB_lengths``, ``protein_RCSB_lengths``,
      ``protein_derived_seq_length``, ``peptide_derived_seq_length`` -
      ``len()`` of the corresponding column.
    * ``matching_lengths_count`` - a single number (how many rows have as many
      scores as peptide residues) repeated in every row.

    Args:
        file_path: Path of the CSV file to read.

    Returns:
        The DataFrame read from ``file_path`` with the columns above added.
    """
    snp_df = pd.read_csv(file_path)

    def transform_energy_scores(energy_scores):
        """Turn each whitespace-separated score string into a comma-separated one."""
        transformed_scores = []
        for score in energy_scores:
            # Collapse every run of whitespace (newlines included) to one comma.
            score = re.sub(r'\s+', ',', score)
            # The step above leaves stray separators wherever the text had
            # padding next to a bracket ("[ 1 2 ]") or around the whole list;
            # remove them so they are not counted as extra elements.
            score = re.sub(r'\[,+', '[', score)
            score = re.sub(r',+\]', ']', score)
            score = score.strip(',')
            transformed_scores.append(score)
        return transformed_scores

    def count_scores(score_string):
        """Count the non-empty comma-separated entries between the brackets."""
        entries = score_string.strip('[]').split(',')
        return sum(1 for entry in entries if entry)

    snp_df['energy_scores'] = transform_energy_scores(snp_df['energy_scores'])
    # Counting entries (rather than commas + 1) also gives 0 for an empty list.
    snp_df['energy_scores_lengths'] = snp_df['energy_scores'].apply(count_scores)

    # Lengths of the sequence columns.
    snp_df['peptide_source_RCSB_lengths'] = snp_df['peptide_source_RCSB'].apply(len)
    snp_df['protein_RCSB_lengths'] = snp_df['protein_RCSB'].apply(len)
    snp_df['protein_derived_seq_length'] = snp_df['protein_derived_sequence'].apply(len)
    snp_df['peptide_derived_seq_length'] = snp_df['peptide_derived_sequence'].apply(len)

    # Scalar sanity-check value, broadcast to every row: the number of rows
    # whose score count equals the derived peptide length.
    snp_df['matching_lengths_count'] = (
        snp_df['energy_scores_lengths'] == snp_df['peptide_derived_seq_length']
    ).sum()

    return snp_df

# Run the preprocessing pipeline on each split. The file names are relative,
# so they are resolved against the current working directory.
test_snp = preprocess_snp_data('testing_dataset.csv')
train_snp = preprocess_snp_data('training_dataset.csv')
val_snp = preprocess_snp_data('validation_dataset.csv')


In [None]:
# Every distinct peptide/protein sequence across the three splits, in order of
# first appearance (train, then test, then validation).
unique_seqs = pd.concat([
    split[column]
    for split in (train_snp, test_snp, val_snp)
    for column in ('peptide_derived_sequence', 'protein_derived_sequence')
]).unique()

In [None]:
# Length of the longest sequence found in any split.
max_length = len(max(unique_seqs, key=len))
print(max_length)

984


In [None]:
import torch
import re
import pickle
from torch.utils.data import Dataset
from torch.nn.functional import pad

class ProteinInteractionDataset(Dataset):
    """Dataset view over a preprocessed DataFrame of peptide/protein pairs.

    Each item is a tuple ``(labels, peptide_seq, protein_seq)`` where ``labels``
    is a tensor built from the row's ``energy_scores`` string after it has been
    passed through ``one_hot_encode_energy_scores``.
    """

    def __init__(self, dataframe):
        """Keep a reference to ``dataframe``; rows are read lazily per item."""
        self.dataframe = dataframe

    def __len__(self):
        """Return the number of rows in the wrapped DataFrame."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return ``(labels, peptide_seq, protein_seq)`` for positional row ``idx``."""
        row = self.dataframe.iloc[idx]
        peptide_seq = row['peptide_derived_sequence']
        protein_seq = row['protein_derived_sequence']

        # Pull every numeric literal (optionally signed / with exponent) out of
        # the score string, then convert the values to per-position labels.
        numeric_tokens = re.findall(r'-?\d+\.?\d*(?:e[-+]?\d+)?', row['energy_scores'])
        labels = one_hot_encode_energy_scores([float(token) for token in numeric_tokens])

        return torch.tensor(labels), peptide_seq, protein_seq

def one_hot_encode_energy_scores(scores, threshold=-1):
    """Binarise energy scores: 1 where ``score <= threshold``, otherwise 0.

    Despite the name, the result is one 0/1 label per input score, not a
    one-hot matrix.

    Args:
        scores: Iterable of numeric energy scores.
        threshold: Inclusive cut-off; defaults to -1, the value that was
            previously hard-coded.

    Returns:
        A list of ints (0 or 1) with one entry per score.
    """
    return [1 if score <= threshold else 0 for score in scores]

In [None]:
# Wrap each preprocessed split in a ProteinInteractionDataset. No tokenizer is
# involved at this stage: items carry the raw sequence strings.
train_dataset = ProteinInteractionDataset(train_snp)
test_dataset = ProteinInteractionDataset(test_snp)
val_dataset = ProteinInteractionDataset(val_snp)


In [None]:
from torch.utils.data import DataLoader

# Batch size 1 everywhere: items contain variable-length tensors and raw
# strings, and no padding/collate function is supplied.
train_batch_size = 1
test_batch_size = 1
val_batch_size = 1

# Create the DataLoaders (only the training loader is shuffled)
train_dataloader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=test_batch_size)
val_dataloader = DataLoader(val_dataset, batch_size=val_batch_size)


### **Background Functions**

General

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import torch.nn.functional as F
# from transformers import RobertaModel  # Assuming use of Hugging Face's transformer models

# Helper Functions
def exists(val: object) -> bool:
    """Return True when ``val`` is anything other than ``None``."""
    return val is not None

def set_module_requires_grad_(module, requires_grad):
    """Set ``requires_grad`` on every parameter of ``module`` in place.

    Args:
        module: The ``nn.Module`` whose parameters are updated.
        requires_grad: Flag value to assign to each parameter.
    """
    # nn.Module.requires_grad_ walks module.parameters() and sets the flag on
    # each one; its return value (the module) is deliberately discarded.
    module.requires_grad_(requires_grad)

def freeze_model_and_make_eval_(model):
    """Put ``model`` in evaluation mode and stop gradients for all its parameters."""
    # model.eval() returns the model itself, so it can be handed straight on.
    set_module_requires_grad_(model.eval(), False)

# LayerNorm class (hand-written variant)
# NOTE: a second class named LayerNorm is defined further down in this same
# cell and replaces this one once the cell has run.
class LayerNorm(nn.Module):
    """Normalise the last dimension and scale it by a learnable gain.

    Uses ``Tensor.std`` for the spread and adds ``eps`` to the standard
    deviation (not the variance); there is no learnable bias.
    """

    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.gain = nn.Parameter(torch.ones(features))
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        spread = x.std(-1, keepdim=True)
        return self.gain * centered / (spread + self.eps)

# Residual class
class Residual(nn.Module):
    """Wrap a callable ``fn`` so that its output is added back onto its first input."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        # Extra positional/keyword arguments are forwarded to fn untouched.
        transformed = self.fn(x, *args, **kwargs)
        return transformed + x

# SwiGLU activation function (first variant)
# NOTE: a second class named SwiGLU is defined further down in this same cell
# and replaces this one; that version applies SiLU to the *second* half.
class SwiGLU(nn.Module):
    """Split the last dimension in two and return ``silu(first) * second``."""

    def forward(self, x):
        split_at = x.shape[-1] // 2
        activated = F.silu(x[..., :split_at])
        passthrough = x[..., split_at:]
        return activated * passthrough


class LayerNorm(nn.Module):
    """Thin wrapper that delegates to ``torch.nn.LayerNorm(dim)``.

    Defined after an earlier class of the same name in this cell, so this is
    the definition that remains bound to ``LayerNorm``.
    """

    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        normalised = self.norm(x)
        return normalised

class SwiGLU(nn.Module):
    """Chunk the last dimension in two and return ``silu(second) * first``.

    Defined after an earlier class of the same name in this cell, so this is
    the definition that remains bound to ``SwiGLU``.
    """

    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim=-1)
        return F.silu(gate) * value



Main Perceiver + Cross-Attention

In [None]:
!pip install einops
!pip install einops-exts

Collecting einops
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m798.2 kB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.7.0
Collecting einops-exts
  Downloading einops_exts-0.0.4-py3-none-any.whl (3.9 kB)
Installing collected packages: einops-exts
Successfully installed einops-exts-0.0.4


In [None]:
# Configure PyTorch's CUDA caching allocator via its environment variable.
# NOTE(review): relies on `os` having been imported in an earlier cell, and
# presumably must run before the first CUDA allocation to take effect - confirm.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:4096'

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops_exts import rearrange_many, repeat_many

def exists(val: object) -> bool:
    """Return True when ``val`` is anything other than ``None``."""
    return val is not None

def FeedForward(dim, mult = 4):
    """Build a bias-free LayerNorm -> Linear -> GELU -> Linear block.

    Args:
        dim: Input and output feature size.
        mult: Hidden-layer width multiplier; the hidden size is ``int(dim * mult)``.

    Returns:
        An ``nn.Sequential`` holding the four layers in that order.
    """
    hidden_dim = int(dim * mult)
    stages = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden_dim, bias=False),
        nn.GELU(),
        nn.Linear(hidden_dim, dim, bias=False),
    ]
    return nn.Sequential(*stages)

class PerceiverAttention(nn.Module):
    """Attention layer in which latents query the input features plus themselves.

    Queries come from ``latents``; keys and values come from the
    concatenation of the (normalised) inputs ``x`` and the latents along the
    sequence axis. ``concatenated_dim`` is accepted but not used.
    """

    def __init__(self, *, dim, concatenated_dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """Return updated latents with the same shape as ``latents``.

        Both inputs are rearranged with the pattern ``'b n (h d)'``, i.e. they
        are treated as (batch, sequence, feature) tensors.
        """
        media = self.norm_media(x)
        latents = self.norm_latents(latents)

        def split_heads(tensor):
            # (batch, n, heads * dim_head) -> (batch, heads, n, dim_head)
            return rearrange(tensor, 'b n (h d) -> b h n d', h=self.heads)

        queries = split_heads(self.to_q(latents)) * self.scale

        # Keys/values see the inputs followed by the latents themselves.
        keys_values_source = torch.cat((media, latents), dim=1)
        keys, values = self.to_kv(keys_values_source).chunk(2, dim=-1)
        keys = split_heads(keys)
        values = split_heads(values)

        scores = einsum('... i d, ... j d -> ... i j', queries, keys)
        weights = scores.softmax(dim=-1)
        mixed = einsum('... i j, ... j d -> ... i d', weights, values)

        return self.to_out(rearrange(mixed, 'b h n d -> b n (h d)'))

class PerceiverResampler(nn.Module):
    """Compress a variable-length feature sequence into ``num_latents`` vectors.

    Holds a learned latent array and ``depth`` layers, each made of a
    PerceiverAttention followed by a FeedForward, both applied with a residual
    connection.
    """

    def __init__(self, *, dim, depth, dim_head=64, heads=8, num_latents=64, concatenated_dim=1536):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        self.layers = nn.ModuleList([
            nn.ModuleList([
                PerceiverAttention(dim=dim, concatenated_dim=concatenated_dim, dim_head=dim_head, heads=heads),
                FeedForward(dim=dim),
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        """Return latents of shape (x.shape[0], num_latents, dim)."""
        batch_size = x.shape[0]
        # Give every batch element its own copy of the learned latents.
        latents = repeat(self.latents, 'n d -> b n d', b=batch_size)

        for cross_attend, feed_forward in self.layers:
            latents = cross_attend(x, latents) + latents
            latents = feed_forward(latents) + latents

        return latents

class MaskedCrossAttention(nn.Module):
    """Multi-head cross-attention from a sequence ``x`` onto ``media`` features.

    Queries are computed from the layer-normalised ``x``; keys and values from
    ``media``. ``concatenated_dim`` and ``only_attend_immediate_media`` are
    accepted for interface compatibility; the latter is stored but neither
    changes the computation.

    The optional ``media_locations`` argument previously raised inside
    ``rearrange`` (a 4-D tensor was fed to a 2-D pattern) and, had it run,
    would have filled whole attention rows with ``-inf`` and produced NaNs.
    It is now applied after the softmax, which keeps the output finite.
    """

    def __init__(self, *, dim, concatenated_dim=1536, dim_head=64, heads=8, only_attend_immediate_media=True):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)
        self.only_attend_immediate_media = only_attend_immediate_media

    def _split_heads(self, tensor):
        """Reshape (batch, n, heads * dim_head) to (batch, heads, n, dim_head)."""
        batch, length, _ = tensor.shape
        return tensor.reshape(batch, length, self.heads, -1).transpose(1, 2)

    def forward(self, x, media, media_locations=None):
        """Attend from ``x`` to ``media``.

        Args:
            x: Tensor of shape (batch, t, dim).
            media: Tensor of shape (batch, m, dim).
            media_locations: Optional tensor of shape (batch, t). Positions
                whose entry equals 0 receive no attention output (their
                attention weights are zeroed).
                NOTE(review): interpreted as a per-query-position flag, which
                is the layout the original ``'b n -> b 1 n 1'`` pattern
                produced - confirm against the intended caller.

        Returns:
            Tensor of shape (batch, t, dim).

        Raises:
            ValueError: If ``media_locations`` is not of shape (batch, t).
        """
        b, t, _ = x.shape

        x = self.norm(x)
        q = self._split_heads(self.to_q(x))

        k, v = self.to_kv(media).chunk(2, dim=-1)
        k = self._split_heads(k)
        v = self._split_heads(v)

        q = q * self.scale
        sim = einsum('... i d, ... j d -> ... i j', q, k)

        # Subtract the row maximum for numerical stability before the softmax.
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        if media_locations is not None:
            if tuple(media_locations.shape) != (b, t):
                raise ValueError(
                    f"media_locations must have shape {(b, t)}, "
                    f"got {tuple(media_locations.shape)}"
                )
            # (batch, t) -> (batch, 1, t, 1): broadcast over heads and media.
            # Zeroing after the softmax avoids all -inf rows (NaN output).
            without_media = (media_locations == 0)[:, None, :, None]
            attn = attn.masked_fill(without_media, 0.)

        out = einsum('... i j, ... j d -> ... i d', attn, v)
        # (batch, heads, t, dim_head) -> (batch, t, heads * dim_head)
        out = out.transpose(1, 2).reshape(b, t, -1)

        return self.to_out(out)


class GatedCrossAttentionBlock(nn.Module):
    """Cross-attention plus feed-forward, each scaled by a tanh gate.

    Both gates are scalar parameters initialised to 0, so ``tanh(gate)`` is 0
    at construction and the block starts out as an identity mapping on ``x``.
    """

    def __init__(self, *, dim, dim_head=64, heads=8, ff_mult=4, only_attend_immediate_media=True):
        super().__init__()
        self.attn = MaskedCrossAttention(
            dim=dim,
            concatenated_dim=1536,
            dim_head=dim_head,
            heads=heads,
            only_attend_immediate_media=only_attend_immediate_media,
        )
        self.attn_gate = nn.Parameter(torch.tensor([0.]))
        self.ff = FeedForward(dim, mult=ff_mult)
        self.ff_gate = nn.Parameter(torch.tensor([0.]))

    def forward(self, x, media, media_locations=None):
        """Return ``x`` after the gated attention and gated feed-forward updates."""
        attended = self.attn(x, media, media_locations=media_locations)
        x = attended * self.attn_gate.tanh() + x

        transformed = self.ff(x)
        x = transformed * self.ff_gate.tanh() + x
        return x



### **ProtFlamingo**

In [None]:
from transformers import GPT2Tokenizer

# Load the tokenizer for the ProtGPT2 model from Hugging Face.
# This notebook-level `tokenizer` is also read as a global by train_epoch below.
tokenizer = GPT2Tokenizer.from_pretrained('nferruz/ProtGPT2')

# Get the number of tokens in the tokenizer (len() of the tokenizer object)
num_tokens = len(tokenizer)

print(f"The ProtGPT2 model has {num_tokens} tokens in its vocabulary.")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


vocab.json:   0%|          | 0.00/655k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/314k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/357 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.07M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/850 [00:00<?, ?B/s]

The ProtGPT2 model has 50257 tokens in its vocabulary.


In [None]:
import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel, GPT2Tokenizer

class ProtFlamingo(nn.Module):
    """ProtGPT2 decoder with gated cross-attention onto per-residue motif labels.

    The pretrained GPT-2 blocks are reused as-is; a GatedCrossAttentionBlock is
    inserted after every block whose index ``i`` satisfies
    ``i % cross_attn_every == 0 and i != 0``. Motif labels are embedded,
    compressed by a PerceiverResampler, and attended to by those blocks.

    Args:
        model_path: Hugging Face model id or path passed to ``from_pretrained``.
        cross_attn_every: Spacing of the inserted cross-attention blocks.
        dim_head: Per-head width of the added attention layers.
        heads: Number of heads in the added attention layers.
        motif_embedding_dim: Number of rows in the motif embedding table, i.e.
            motif ids must be integers below this value.
            NOTE(review): the dataset in this notebook yields only 0/1 labels,
            so presumably just the first two rows are ever used - confirm.
        perceiver_depth: Number of layers in the PerceiverResampler.
        perceiver_num_latents: Number of latent vectors it produces.
    """

    def __init__(self, model_path, cross_attn_every=1, dim_head=64, heads=8, motif_embedding_dim=1280, perceiver_depth=2, perceiver_num_latents=64):
        super().__init__()

        # Load ProtGPT2 model and its tokenizer
        self.protGPT2_model = GPT2LMHeadModel.from_pretrained(model_path)
        self.protGPT2_tokenizer = GPT2Tokenizer.from_pretrained(model_path)

        # Embedding table that maps integer motif labels to GPT-2-width vectors
        self.motif_embedding = nn.Embedding(num_embeddings=motif_embedding_dim, embedding_dim=self.protGPT2_model.config.n_embd)

        # Define Perceiver Resampler
        self.perceiver_resampler = PerceiverResampler(dim=self.protGPT2_model.config.n_embd, depth=perceiver_depth, dim_head=dim_head, heads=heads, num_latents=perceiver_num_latents)

        # Access the decoder blocks from ProtGPT2 model
        self.decoder_blocks = self.protGPT2_model.transformer.h

        # Intersperse GatedCrossAttentionBlocks within the GPT2 decoder blocks
        self.layers = nn.ModuleList([])
        for i, block in enumerate(self.decoder_blocks):
            self.layers.append(block)
            if i % cross_attn_every == 0 and i != 0:
                self.layers.append(GatedCrossAttentionBlock(dim=self.protGPT2_model.config.n_embd, dim_head=dim_head, heads=heads))

    def forward(self, target_seqs, one_hot_motifs):
        """Return next-token logits for ``target_seqs`` conditioned on the motifs.

        Args:
            target_seqs: Sequence text handed to the ProtGPT2 tokenizer. No
                padding is requested, so pass a single string (or strings that
                tokenize to the same length).
            one_hot_motifs: Integer tensor of motif labels, shaped
                (batch, motif_len); it is moved to the model's device here.

        Returns:
            Logits of shape (batch, num_tokens, vocab_size).

        The token count must not exceed the GPT-2 position-embedding table
        (``config.n_positions``), otherwise the position lookup fails.
        """
        transformer = self.protGPT2_model.transformer
        # Use the device the weights live on rather than a notebook-level
        # `device` global, so the module works wherever it has been moved.
        model_device = transformer.wte.weight.device

        # Tokenize target sequences
        inputs = self.protGPT2_tokenizer(target_seqs, return_tensors="pt")
        input_ids = inputs['input_ids'].to(model_device)
        attention_mask = inputs['attention_mask'].to(model_device)

        # Token + position embeddings, as in GPT2Model.forward. The pretrained
        # blocks were trained with learned position embeddings; without them
        # the decoder has no notion of token order.
        position_ids = torch.arange(input_ids.size(-1), device=model_device).unsqueeze(0)
        embeddings = transformer.wte(input_ids) + transformer.wpe(position_ids)

        # GPT-2 blocks add the mask to the attention scores, so convert the
        # tokenizer's 1/0 mask to an additive one: 0 = keep, large negative = drop.
        extended_mask = attention_mask[:, None, None, :].to(dtype=embeddings.dtype)
        extended_mask = (1.0 - extended_mask) * torch.finfo(embeddings.dtype).min

        # Embed motif labels and compress them to a fixed number of latents
        motif_embeddings = self.motif_embedding(one_hot_motifs.to(model_device))
        processed_motif_embeddings = self.perceiver_resampler(motif_embeddings)

        # Pass through layers (GPT2 Blocks and GatedCrossAttentionBlocks)
        for layer in self.layers:
            if isinstance(layer, GatedCrossAttentionBlock):
                embeddings = layer(embeddings, processed_motif_embeddings)
            else:
                # GPT-2 blocks return a tuple whose first entry is the hidden state.
                embeddings = layer(embeddings, attention_mask=extended_mask)[0]

        # Final layer norm precedes the LM head in GPT-2; the pretrained head
        # expects normalised hidden states.
        embeddings = transformer.ln_f(embeddings)
        return self.protGPT2_model.lm_head(embeddings)



### **Initialize Model**

In [None]:
# Parameters for initializing the model
# A GatedCrossAttentionBlock is inserted after each GPT2 block whose index i
# satisfies i % cross_attn_every == 0 and i != 0 (i.e. after blocks 3, 6, 9, ...).
cross_attn_every = 3
dim_head = 64         # Dimension of each head in multi-head attention
heads = 8             # Number of heads in multi-head attention
perceiver_num_latents = 64  # Number of latents in PerceiverResampler
perceiver_depth = 2        # Depth of the PerceiverResampler

# Initialize the ProtFlamingo model (downloads/loads the ProtGPT2 weights)
model = ProtFlamingo(model_path='nferruz/ProtGPT2',
                             cross_attn_every=cross_attn_every,
                             dim_head=dim_head,
                             heads=heads,
                             perceiver_num_latents=perceiver_num_latents,
                             perceiver_depth=perceiver_depth)

# If using a GPU, move the model to GPU
if torch.cuda.is_available():
    model = model.cuda()

import torch

def init_weights(m):
    """Xavier-initialise a linear layer; intended for ``nn.Module.apply``.

    Modules that are not ``torch.nn.Linear`` (or a subclass) are left
    untouched. For linear layers the weight gets Xavier-uniform values and the
    bias, when present, is set to 0.01.
    """
    # isinstance (rather than an exact type comparison) also covers subclasses
    # of nn.Linear.
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0.01)

# Re-initialise only the newly added modules. Calling model.apply(init_weights)
# on the whole model would also hit the pretrained GPT-2 `lm_head`, which is an
# nn.Linear whose weight is shared with the token embedding table, and so would
# overwrite pretrained weights with random values.
model.perceiver_resampler.apply(init_weights)
for layer in model.layers:
    if isinstance(layer, GatedCrossAttentionBlock):
        layer.apply(init_weights)

# NOTE: the training cell further down creates its own optimizer and rebinds
# this name.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

pytorch_model.bin:   0%|          | 0.00/3.13G [00:00<?, ?B/s]

In [None]:
tokenizer

GPT2Tokenizer(name_or_path='nferruz/ProtGPT2', vocab_size=50257, model_max_length=1000000000000000019884624838656, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
}

### **Train Loop**

In [None]:
import torch
import torch.nn as nn

def train_epoch(model, data_loader, optimizer, device, clip_value=1.0):
    """Train ``model`` for one pass over ``data_loader`` with a next-token loss.

    Each batch is expected to unpack into
    ``(one_hot_motifs, target_seqs_tuple, binder_seqs_tuple)``; only the first
    element of ``target_seqs_tuple`` is used (batch size 1).
    NOTE(review): ProteinInteractionDataset yields
    ``(labels, peptide_seq, protein_seq)``, so ``target_seqs`` here is the
    peptide sequence - confirm this is the intended sequence to model.

    The loss is cross-entropy between the logits at position ``t`` and the
    token at position ``t + 1``. Previously logits and labels were compared at
    the same position, which lets the model score perfectly by copying its
    input.

    Args:
        model: Callable as ``model(target_seqs, one_hot_motifs)`` returning
            logits of shape (batch, num_tokens, vocab_size). If it has a
            ``protGPT2_tokenizer`` attribute, that tokenizer produces the
            labels; otherwise the notebook-level ``tokenizer`` is used.
        data_loader: Iterable of batches as described above.
        optimizer: Optimizer over the model's parameters.
        device: Device the labels and motif tensors are moved to.
        clip_value: Maximum gradient norm for clipping.

    Returns:
        Mean loss over the batches that were trained on, or ``nan`` when there
        were none (empty loader, or only sequences shorter than two tokens).
    """
    model.train()
    loss_fn = nn.CrossEntropyLoss()

    # Labels must come from the same tokenizer the model uses for its inputs.
    label_tokenizer = getattr(model, 'protGPT2_tokenizer', None)
    if label_tokenizer is None:
        label_tokenizer = tokenizer

    total_loss = 0.0
    total_batches = 0

    for batch in data_loader:
        one_hot_motifs, target_seqs_tuple, _binder_seqs_tuple = batch
        target_seqs = target_seqs_tuple[0]

        label_ids = label_tokenizer(target_seqs, return_tensors="pt")['input_ids'].to(device)
        if label_ids.size(-1) < 2:
            # Nothing to predict: a next-token loss needs at least two tokens.
            continue

        one_hot_motifs = one_hot_motifs.to(device)

        optimizer.zero_grad()

        # Forward pass
        outputs = model(target_seqs, one_hot_motifs)

        # Shift so that the logits at position t are scored against token t + 1.
        shifted_logits = outputs[:, :-1, :]
        shifted_labels = label_ids[:, 1:]
        loss = loss_fn(
            shifted_logits.reshape(-1, shifted_logits.size(-1)),
            shifted_labels.reshape(-1),
        )
        total_loss += loss.item()
        print(loss)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
        optimizer.step()

        total_batches += 1

    # Avoid ZeroDivisionError when no batch contributed a loss.
    average_loss = total_loss / total_batches if total_batches else float('nan')
    print(average_loss)
    return average_loss

# Training run: pick the device, move the model there, and train for a fixed
# number of epochs on the training DataLoader.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# This rebinds `optimizer`, replacing the lr=1e-4 Adam created when the model
# was initialised; training therefore runs with lr=0.001.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

num_epochs = 5
for epoch in range(num_epochs):
    train_loss = train_epoch(model, train_dataloader, optimizer, device)
    print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {train_loss:.4f}")


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
tensor(19.5016, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(5.7615, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(14.0471, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0., device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(6.6839, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0., device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0., device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(8.4813, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(7.5723, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0., device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0., device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.9544, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(8.0467, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(13.6778, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0., device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0., device='cuda:0', grad_fn=<NllLossBackward0>)


In [None]:
epoch

4

In [None]:
# Save the model's state_dict (weights only, not the optimizer state) to Drive.
# NOTE(review): hard-coded Google Drive path; an existing file is overwritten.
torch.save(model.state_dict(), '/content/drive/MyDrive/Programmable Biology Group/Srikar/Code/proteins/flamingo-ppi-gen/flamingo-gpt2-v1.pth')

In [None]:
train_loss

6.051890324297871