### **Make a new dataset**
- target seq
- binder seq
- motif seq
- cluster

In [1]:
from transformers import T5Tokenizer, T5EncoderModel
import torch
import re
!pip install sentencepiece
import sentencepiece
import torch
from torch import nn
from transformers import T5ForConditionalGeneration, T5Tokenizer
from torch.utils.data import DataLoader, Dataset



In [2]:
import os
# Work from the Drive folder that holds the per-residue dataset CSVs, so the
# relative file names used in later cells resolve.
PER_RESIDUE_DATASET_DIR = '/content/drive/MyDrive/Programmable Biology Group/Srikar/Code/proteins/flamingo-ppi-gen/data_dump/per-residue-dataset/'
os.chdir(PER_RESIDUE_DATASET_DIR)

In [3]:
# Load the ProtT5 encoder checkpoint and its tokenizer.
tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)
model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc")
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)
# Use fp16 on GPU and fp32 on CPU. nn.Module has no `.full()` method, so the
# CPU branch previously raised AttributeError; `.float()` is the fp32 cast.
model = model.half() if device.type == 'cuda' else model.float()

from tqdm import tqdm

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [4]:

from tqdm import tqdm

def generate_protT5_tokens(sequences):
    """Tokenize each amino-acid sequence with the notebook-level ProtT5 ``tokenizer``.

    Args:
        sequences: iterable of amino-acid strings (one letter per residue).

    Returns:
        dict mapping each original sequence string to its list of token ids
        (special tokens included). The value is always a list, even when the
        encoding contains a single token.
    """
    tokens_dict = {}
    # Wrap the sequence iteration with tqdm for a progress bar
    for sequence in tqdm(sequences, desc="Generating..."):
        # Map the residues U/Z/O/B to X and space-separate the residues
        # before handing the string to the tokenizer.
        processed_seq = " ".join(re.sub(r"[UZOB]", "X", sequence))
        ids = tokenizer(processed_seq, add_special_tokens=True, return_tensors="pt", padding='longest')
        # Index the single batch row instead of squeeze(): squeeze() collapses a
        # one-token encoding to a 0-d tensor whose tolist() is a bare int, which
        # breaks later len() calls. No device transfer is needed just to build a list.
        tokens_dict[sequence] = ids['input_ids'][0].tolist()
    return tokens_dict


In [5]:
import pandas as pd
import re

def preprocess_snp_data(file_path):
    """Load one dataset split and normalise its ``energy_scores`` column.

    The CSV must contain the columns ``energy_scores``, ``peptide_source_RCSB``,
    ``protein_RCSB``, ``protein_derived_sequence`` and ``peptide_derived_sequence``.

    Args:
        file_path: path (or buffer) accepted by ``pandas.read_csv``.

    Returns:
        The DataFrame with ``energy_scores`` rewritten as comma-separated strings
        plus the added columns ``energy_scores_lengths``, ``*_lengths`` /
        ``*_seq_length`` and ``matching_lengths_count``.
    """
    # Read the dataset
    snp_df = pd.read_csv(file_path)

    def transform_energy_scores(energy_scores):
        """Rewrite whitespace-separated score strings as comma-separated ones."""
        transformed_scores = []
        for score in energy_scores:
            # Replace sequences of spaces/newlines with a comma
            score = re.sub(r'[\s\n]+', ',', score)
            # Remove a comma after an opening square bracket
            score = re.sub(r'\[\s*,', '[', score)
            # Remove a comma before a closing square bracket; it appears when the
            # last value is followed by padding whitespace (e.g. "[1.5 2. ]") and
            # would otherwise inflate the element count by one.
            score = re.sub(r',\s*\]', ']', score)
            # Remove leading and trailing commas/whitespace
            score = re.sub(r'^[\s,]+', '', score)
            score = re.sub(r'[\s,]+$', '', score)
            transformed_scores.append(score)
        return transformed_scores

    def count_energy_scores(score):
        """Number of comma-separated values in a normalised score string (0 if empty)."""
        inner = score.strip('[]')
        return inner.count(',') + 1 if inner else 0

    # Apply transformations
    snp_df['energy_scores'] = transform_energy_scores(snp_df['energy_scores'])
    snp_df['energy_scores_lengths'] = snp_df['energy_scores'].apply(count_energy_scores)

    # Calculate lengths for other columns
    snp_df['peptide_source_RCSB_lengths'] = snp_df['peptide_source_RCSB'].apply(len)
    snp_df['protein_RCSB_lengths'] = snp_df['protein_RCSB'].apply(len)
    snp_df['protein_derived_seq_length'] = snp_df['protein_derived_sequence'].apply(len)
    snp_df['peptide_derived_seq_length'] = snp_df['peptide_derived_sequence'].apply(len)

    # Total number of rows whose score count equals the peptide length. This is a
    # single scalar, so every row of the column holds the same value.
    snp_df['matching_lengths_count'] = (snp_df['energy_scores_lengths'] == snp_df['peptide_derived_seq_length']).sum()

    return snp_df

# Run the same preprocessing over the three split CSVs (testing, training,
# validation — in that order) and bind one DataFrame per split.
test_snp, train_snp, val_snp = (
    preprocess_snp_data(csv_name)
    for csv_name in ('testing_dataset.csv', 'training_dataset.csv', 'validation_dataset.csv')
)


In [6]:

# Load tokenizer and model
# This repeats the loading already done in an earlier cell (same checkpoint
# names), so the `tokenizer`, `model` and `device` globals are rebound here.
tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)
model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc")
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)
# The half-precision cast is disabled in this cell, so the model keeps the dtype it was loaded with.
# NOTE(review): nn.Module presumably has no `.full()` method (`.float()` is the fp32 cast) — fix before re-enabling.
# model = model.half() if device.type == 'cuda' else model.full()


T5EncoderModel(
  (shared): Embedding(128, 1024)
  (encoder): T5Stack(
    (embed_tokens): Embedding(128, 1024)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=1024, out_features=4096, bias=False)
              (k): Linear(in_features=1024, out_features=4096, bias=False)
              (v): Linear(in_features=1024, out_features=4096, bias=False)
              (o): Linear(in_features=4096, out_features=1024, bias=False)
              (relative_attention_bias): Embedding(32, 32)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=1024, out_features=16384, bias=False)
              (wo): Linear(in_features=16384, out_features=1024, bias=False)
              (dropout): Dropo

In [7]:
# Gather the peptide and protein sequence columns of every split (train, test,
# val — peptide before protein within each) and de-duplicate them so each
# distinct sequence is tokenized once.
sequence_columns = [
    split_frame[column_name]
    for split_frame in (train_snp, test_snp, val_snp)
    for column_name in ('peptide_derived_sequence', 'protein_derived_sequence')
]
unique_seqs = pd.concat(sequence_columns).unique()

protT5_tokens = generate_protT5_tokens(unique_seqs)

Generating...: 100%|██████████| 9665/9665 [00:10<00:00, 915.49it/s] 


In [8]:
import torch
import re
import pickle
from torch.utils.data import Dataset
from torch.nn.functional import pad

class ProteinInteractionDataset(Dataset):
    """Dataset yielding padded energy-score labels and token ids for one split.

    Each item is a tuple ``(energy_scores, peptide_tokens, protein_tokens, cluster)``
    where the first three are 1-D ``torch.long`` tensors of length
    ``max_length_tokenized`` (zero-padded on the right, truncated if longer) and
    ``cluster`` is the raw value of the row's ``cluster`` column.

    Args:
        dataframe: frame with the columns ``peptide_derived_sequence``,
            ``protein_derived_sequence``, ``energy_scores`` and ``cluster``.
        protT5_tokens: mapping from sequence string to its list of token ids.
        max_length: optional fixed output length. When omitted it is the longest
            tokenized sequence across BOTH sequence columns, so no sequence of
            this split is truncated.

    No ProtT5 model or tokenizer is loaded here: the dataset only performs
    look-ups in the precomputed ``protT5_tokens`` mapping.
    """

    _SEQUENCE_COLUMNS = ('peptide_derived_sequence', 'protein_derived_sequence')

    def __init__(self, dataframe, protT5_tokens, max_length=None):
        self.dataframe = dataframe
        self.protT5_tokens = protT5_tokens
        if max_length is None:
            # Use both columns: sizing from the peptide column alone gives a
            # negative pad for any longer protein sequence.
            max_length = max(
                (len(protT5_tokens[seq])
                 for column in self._SEQUENCE_COLUMNS
                 for seq in dataframe[column]),
                default=0,
            )
        self.max_length_tokenized = max_length

    def __len__(self):
        return len(self.dataframe)

    def _fit_to_max_length(self, values):
        """Return ``values`` as a long tensor of exactly ``max_length_tokenized`` items."""
        tensor = torch.tensor(values, dtype=torch.long)[:self.max_length_tokenized]
        return pad(tensor, (0, self.max_length_tokenized - tensor.numel()), "constant", 0)

    def __getitem__(self, idx):
        # Fetch the row once instead of re-indexing the frame per column.
        row = self.dataframe.iloc[idx]
        peptide_seq = row['peptide_derived_sequence']
        protein_seq = row['protein_derived_sequence']
        clusters = row['cluster']

        # Pull every numeric literal out of the stored string, then binarise.
        raw_scores = re.findall(r'-?\d+\.?\d*(?:e[-+]?\d+)?', row['energy_scores'])
        energy_scores = one_hot_encode_energy_scores([float(score) for score in raw_scores])

        energy_scores_padded = self._fit_to_max_length(energy_scores)
        tokenized_peptide_seq_padded = self._fit_to_max_length(self.protT5_tokens[peptide_seq])
        tokenized_protein_seq_padded = self._fit_to_max_length(self.protT5_tokens[protein_seq])
        return energy_scores_padded, tokenized_peptide_seq_padded, tokenized_protein_seq_padded, clusters


def one_hot_encode_energy_scores(scores):
    """Binarise energy scores: 1 where the score is <= -1, otherwise 0."""
    return [1 if score <= -1 else 0 for score in scores]

In [9]:
# Build one dataset per split; all three share the precomputed token lookup.
train_dataset, test_dataset, val_dataset = (
    ProteinInteractionDataset(split_frame, protT5_tokens)
    for split_frame in (train_snp, test_snp, val_snp)
)


In [10]:
from torch.utils.data import DataLoader

# Every split currently uses the same batch size.
train_batch_size = test_batch_size = val_batch_size = 2

# Only the training loader is shuffled; test and validation keep dataset order.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=train_batch_size, shuffle=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=test_batch_size)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=val_batch_size)


### **Background Functions**

General

In [11]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import torch.nn.functional as F
# from transformers import RobertaModel  # Assuming use of Hugging Face's transformer models

# Helper Functions
def exists(val):
    """Return True for any value other than ``None`` (falsy values included)."""
    if val is None:
        return False
    return True

def set_module_requires_grad_(module, requires_grad):
    """Set the ``requires_grad`` flag of every parameter of ``module`` in place."""
    for weight in module.parameters():
        weight.requires_grad_(requires_grad)

def freeze_model_and_make_eval_(model):
    """Switch ``model`` to eval mode and turn off gradients for all its parameters."""
    # nn.Module.eval() returns the module itself, so the mode switch and the
    # freeze can be chained; eval() still runs first.
    set_module_requires_grad_(model.eval(), False)

# LayerNorm class
class LayerNorm(nn.Module):
    """Gain-only layer normalisation over the last dimension.

    Output is ``gain * (x - mean) / (std + eps)`` with ``std`` taken from
    ``Tensor.std`` defaults. Note that a second ``LayerNorm`` class is defined
    further down in this cell and replaces this one by name once the cell runs.
    """

    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.gain = nn.Parameter(torch.ones(features))

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        spread = x.std(-1, keepdim=True) + self.eps
        return self.gain * centered / spread

# Residual class
class Residual(nn.Module):
    """Wrap a callable so that its input is added back onto its output."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        transformed = self.fn(x, *args, **kwargs)
        return transformed + x

# SwiGLU activation function
class SwiGLU(nn.Module):
    """Gated activation: SiLU of the first half of the last dim times the second half.

    Note that a second ``SwiGLU`` class is defined further down in this cell and
    replaces this one by name once the cell runs.
    """

    def forward(self, x):
        half = x.shape[-1] // 2
        activated, passthrough = x[..., :half], x[..., half:]
        return F.silu(activated) * passthrough


class LayerNorm(nn.Module):
    """Thin wrapper that delegates to ``torch.nn.LayerNorm`` over ``dim`` features."""

    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(normalized_shape=dim)

    def forward(self, x):
        normalized = self.norm(x)
        return normalized

class SwiGLU(nn.Module):
    """Gated activation: the first chunk of the last dim scaled by SiLU of the second chunk."""

    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim=-1)
        return F.silu(gate) * value



Main Perceiver + Cross Attn

In [12]:
!pip install einops
!pip install einops-exts



In [13]:
# Set the `max_split_size_mb` option of PyTorch's CUDA allocator through its
# configuration environment variable.
# NOTE(review): presumably this only takes effect if it is set before the first CUDA
# allocation, and earlier cells already move a model to `device` — confirm it is applied.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:4096'

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops_exts import rearrange_many, repeat_many

def exists(val):
    """Tell whether ``val`` holds something, i.e. is not ``None``."""
    return not (val is None)

def FeedForward(dim, mult = 4):
    """Build a pre-norm feed-forward stack: LayerNorm -> Linear -> GELU -> Linear.

    Both linear layers are bias-free; the hidden width is ``int(dim * mult)``.
    """
    hidden_dim = int(dim * mult)
    stages = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden_dim, bias = False),
        nn.GELU(),
        nn.Linear(hidden_dim, dim, bias = False),
    ]
    return nn.Sequential(*stages)

class PerceiverAttention(nn.Module):
    """Attention in which latent tokens query the media tokens plus themselves.

    Queries come from the (normalised) latents; keys and values come from the
    normalised media concatenated with the normalised latents along dim 1.
    ``concatenated_dim`` is accepted but not used by this class.
    """

    def __init__(self, *, dim, concatenated_dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        media = self.norm_media(x)
        latent_tokens = self.norm_latents(latents)

        # Split the projected latents into heads and pre-scale the queries.
        queries = rearrange(self.to_q(latent_tokens), 'b n (h d) -> b h n d', h=self.heads)
        queries = queries * self.scale

        # Keys/values see both the media tokens and the latents themselves.
        keys, values = self.to_kv(torch.cat((media, latent_tokens), dim=1)).chunk(2, dim=-1)
        keys = rearrange(keys, 'b n (h d) -> b h n d', h=self.heads)
        values = rearrange(values, 'b n (h d) -> b h n d', h=self.heads)

        scores = einsum('... i d, ... j d -> ... i j', queries, keys)
        weights = scores.softmax(dim=-1)
        mixed = einsum('... i j, ... j d -> ... i d', weights, values)

        # Merge the heads back before the output projection.
        return self.to_out(rearrange(mixed, 'b h n d -> b n (h d)'))

class PerceiverResampler(nn.Module):
    """Map an input sequence onto ``num_latents`` learned latent tokens.

    Holds a learned ``(num_latents, dim)`` latent table and ``depth`` layers,
    each a ``PerceiverAttention`` followed by a ``FeedForward``, both applied
    with a residual connection onto the latents.
    """

    def __init__(self, *, dim, depth, dim_head=64, heads=8, num_latents=64, concatenated_dim=2048):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        self.layers = nn.ModuleList([])

        for _layer_index in range(depth):
            attention = PerceiverAttention(dim=dim, concatenated_dim=concatenated_dim, dim_head=dim_head, heads=heads)
            feed_forward = FeedForward(dim=dim)
            self.layers.append(nn.ModuleList([attention, feed_forward]))

    def forward(self, x):
        # One copy of the latent table per batch element of x.
        batch_size = x.shape[0]
        latents = repeat(self.latents, 'n d -> b n d', b=batch_size)

        for attention, feed_forward in self.layers:
            latents = attention(x, latents) + latents
            latents = feed_forward(latents) + latents

        return latents

class MaskedCrossAttention(nn.Module):
    """Cross-attention from a token sequence ``x`` onto ``media`` tokens.

    Args (constructor, keyword-only):
        dim: feature size of ``x`` and ``media``.
        concatenated_dim: accepted for signature compatibility; unused here.
        dim_head, heads: per-head width and number of heads.
        only_attend_immediate_media: stored on the instance; not consulted by
            ``forward``.

    ``forward(x, media, media_locations=None)`` returns a tensor shaped like ``x``.
    When ``media_locations`` is given, query positions whose entry is 0/False
    receive an all-zero attention output.
    """

    def __init__(self, *, dim, concatenated_dim=2048, dim_head=64, heads=8, only_attend_immediate_media=True):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)
        self.only_attend_immediate_media = only_attend_immediate_media

    def forward(self, x, media, media_locations=None):
        h = self.heads

        x = self.norm(x)
        q = rearrange(self.to_q(x), 'b n (h d) -> b h n d', h=h)

        # No need to reshape media as it's already 3D
        k, v = self.to_kv(media).chunk(2, dim=-1)
        k = rearrange(k, 'b n (h d) -> b h n d', h=h)
        v = rearrange(v, 'b n (h d) -> b h n d', h=h)

        q = q * self.scale
        sim = einsum('... i d, ... j d -> ... i j', q, k)

        keep = None
        if media_locations is not None:
            # assumes media_locations is (batch, text_len), one flag per query
            # position — TODO confirm. It is broadcast over heads and media
            # tokens as (b, 1, n, 1). The previous code unsqueezed to 4-D and
            # then applied a 2-D rearrange pattern, which cannot succeed.
            keep = media_locations.bool()[:, None, :, None]
            # Use a large finite negative instead of -inf: a row filled with
            # -inf becomes NaN after the max-subtraction and softmax below.
            sim = sim.masked_fill(~keep, -torch.finfo(sim.dtype).max)

        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        if keep is not None:
            # Fully masked query rows would otherwise attend uniformly.
            attn = attn.masked_fill(~keep, 0.)

        out = einsum('... i j, ... j d -> ... i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)', h=self.heads)

        return self.to_out(out)


class GatedCrossAttentionBlock(nn.Module):
    """Cross-attention plus feed-forward, each behind a learned tanh gate.

    Both gates are initialised to 0, so ``tanh(gate)`` is 0 and the block
    returns its input unchanged until the gates are trained away from zero.
    """

    def __init__(self, *, dim, dim_head=64, heads=8, ff_mult=4, only_attend_immediate_media=True):
        super().__init__()
        self.attn = MaskedCrossAttention(
            dim=dim,
            concatenated_dim=2048,
            dim_head=dim_head,
            heads=heads,
            only_attend_immediate_media=only_attend_immediate_media,
        )
        self.attn_gate = nn.Parameter(torch.tensor([0.]))
        self.ff = FeedForward(dim, mult=ff_mult)
        self.ff_gate = nn.Parameter(torch.tensor([0.]))

    def forward(self, x, media, media_locations=None):
        attended = self.attn(x, media, media_locations=media_locations)
        x = attended * self.attn_gate.tanh() + x
        return self.ff(x) * self.ff_gate.tanh() + x



### **ProtFlamingo**
- input = tokenized target,binder & motif encoding
- protT5 embed tokenized AA seqs (text), motif emb (image)
- goal: complete binder seq (text completion)

In [14]:
import torch
import torch.nn as nn
from transformers import T5ForConditionalGeneration, T5Tokenizer

class ProtFlamingoLearnedEmbedding(nn.Module):
    """ProtT5 decoder blocks interleaved with gated cross-attention onto motif latents.

    The constructor loads the ``Rostlab/prot_t5_xl_bfd`` model and tokenizer, builds
    a layer list from the model's decoder blocks with ``GatedCrossAttentionBlock``s
    inserted between them, and creates a ``PerceiverResampler`` for the motif
    embeddings.

    Of the constructor arguments, ``num_tokens``, ``depth`` and ``ff_mult`` are
    accepted but not used anywhere in this class.
    """

    def __init__(self, num_tokens, depth, dim_head=64, heads=8, ff_mult=4, cross_attn_every=3, perceiver_num_latents=64, perceiver_depth=2):
        super().__init__()
        self.motif_embedding_projection = nn.Embedding(2, 1024)  # Binary one-hot encoding to 1024 dimensions

        # Load ProtT5 model and tokenizer
        self.protT5_model = T5ForConditionalGeneration.from_pretrained("Rostlab/prot_t5_xl_bfd")
        self.protT5_tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_bfd")

        # Access the decoder blocks from ProtT5 model
        self.decoder_blocks = self.protT5_model.decoder.block

        # Intersperse GatedCrossAttentionBlocks within the T5 decoder blocks:
        # a gated block is appended after decoder block i whenever i is a
        # non-zero multiple of cross_attn_every.
        self.layers = nn.ModuleList([])
        for i, block in enumerate(self.decoder_blocks):
            self.layers.append(block)
            if i % cross_attn_every == 0 and i != 0:
                self.layers.append(GatedCrossAttentionBlock(dim=self.protT5_model.config.d_model, dim_head=dim_head, heads=heads))

        self.perceiver_resampler = PerceiverResampler(dim=self.protT5_model.config.d_model, depth=perceiver_depth, dim_head=dim_head, heads=heads, num_latents=perceiver_num_latents)
        # NOTE(review): in_features is dim_head and out_features is a hard-coded 983, yet
        # forward() applies this layer to the block outputs and feeds the result to
        # lm_head — the dimensions look inconsistent with d_model; confirm.
        self.expand_seq_len = nn.Linear(dim_head, 983)

    def forward(self, tokenized_target_seq, tokenized_binder_seq, motif_one_hot):
        """Run the interleaved stack and return the argmax token ids.

        ``motif_one_hot`` is cast to long and used as indices into the 2-row
        motif embedding table.
        """
        motif_embeddings = self.motif_embedding_projection(motif_one_hot.long())
        target_embeddings = self.generate_protT5_embeddings(tokenized_target_seq)
        # binder_embeddings is computed here but not used in the rest of this method.
        binder_embeddings = self.generate_protT5_embeddings(tokenized_binder_seq)
        processed_motif_embeddings = self.perceiver_resampler(motif_embeddings)

        # Pass through layers (T5Blocks and GatedCrossAttentionBlocks)
        # NOTE(review): `T5Block` is not imported in any cell visible in this notebook —
        # confirm where the name comes from.
        # NOTE(review): the return value of a T5 block is assigned directly as the next
        # hidden state; presumably these blocks return a tuple — verify against transformers.
        for layer in self.layers:
            if isinstance(layer, T5Block):
                target_embeddings = layer(target_embeddings)
            elif isinstance(layer, GatedCrossAttentionBlock):
                target_embeddings = layer(target_embeddings, processed_motif_embeddings)

        expanded_sequence = self.expand_seq_len(target_embeddings)
        logits = self.protT5_model.lm_head(expanded_sequence)
        # NOTE(review): the training cells pass this return value to CrossEntropyLoss as
        # if it were logits, but argmax ids are returned here — confirm what is intended.
        predicted_token_ids = logits.argmax(-1)

        return predicted_token_ids

    def generate_protT5_embeddings(self, tokenized_seqs):
        """Run ``protT5_model`` on ``tokenized_seqs`` without gradients and slice its output."""
        with torch.no_grad():
            # NOTE(review): only input_ids are passed to the conditional-generation model;
            # presumably it also needs decoder inputs and its output may not expose
            # `last_hidden_state` — verify against the transformers T5 documentation.
            outputs = self.protT5_model(input_ids=tokenized_seqs)
        # Takes index 0 on the first axis and the first len(tokenized_seqs) entries of
        # the second axis.
        # NOTE(review): if tokenized_seqs is (batch, seq_len) this keeps only the first
        # batch item and slices by the batch size — confirm.
        embeddings = outputs.last_hidden_state[0, :len(tokenized_seqs)]
        print("Generated embeddings shape:", embeddings.shape)
        return embeddings


### **Initialize Model**

In [15]:
import torch

# Assuming 'train_dataloader', 'val_dataloader' and 'test_dataloader' are already defined
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Example parameters
# NOTE(review): the encoder printed in an earlier cell shows Embedding(128, 1024);
# confirm whether 28 is the intended vocabulary size.
num_tokens = 28 # protT5 vocab size
depth = 3  # Adjust based on model complexity and computational resources

model = ProtFlamingoLearnedEmbedding(
    num_tokens=num_tokens,
    depth=depth,
    dim_head=64,
    heads=8,
    ff_mult=4,
    cross_attn_every=2,
    perceiver_num_latents=64,
    perceiver_depth=2
).to(device)

# Build the optimizer only after the ProtFlamingo model exists. Previously it was
# created from the parameters of the encoder loaded in an earlier cell, so
# optimizer.step() would never have updated the model being trained. The casts
# applied to that old encoder (including a call to the non-existent
# nn.Module.full()) had no effect on the new model and are dropped.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)



### **Train Loop**

In [30]:
import torch
import torch.nn as nn
from tqdm import tqdm

def train_epoch_ce(model, data_loader, optimizer, device):
    """Run one optimisation pass over ``data_loader`` with cross-entropy loss.

    Each batch is unpacked as (motifs, target tokens, binder tokens, ignored).
    The model output is flattened to ``(-1, last_dim)`` and scored against the
    flattened binder tokens. Prints and returns the mean per-batch loss.
    """
    model.train()  # Set the model to training mode
    running_loss = 0
    progress = tqdm(data_loader, desc="Training")
    for motif_batch, target_batch, binder_batch, _ in progress:
        target_batch = target_batch.long().to(device)
        binder_batch = binder_batch.long().to(device)
        motif_batch = motif_batch.float().to(device)

        optimizer.zero_grad()
        predictions = model(target_batch, binder_batch, motif_batch)
        flat_predictions = predictions.view(-1, predictions.size(-1))
        flat_labels = binder_batch.view(-1)
        batch_loss = nn.CrossEntropyLoss()(flat_predictions, flat_labels)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    average_loss = running_loss / len(data_loader)
    print(f"Training Loss: {average_loss:.4f}")
    return average_loss

def _evaluate_epoch_ce(model, data_loader, device, progress_desc, loss_name):
    """Shared no-grad evaluation loop behind validate_epoch_ce and test_epoch_ce.

    Args:
        model: module called as ``model(target_tokens, binder_tokens, motifs)``.
        data_loader: yields (motifs, target tokens, binder tokens, ignored) batches.
        device: device the batch tensors are moved to.
        progress_desc: label shown on the tqdm progress bar.
        loss_name: prefix used in the printed "<loss_name> Loss" line.

    Returns:
        Mean per-batch cross-entropy loss over ``data_loader``.
    """
    model.eval()  # Set the model to evaluation mode
    criterion = nn.CrossEntropyLoss()  # built once instead of once per batch
    total_loss = 0
    with torch.no_grad():
        for one_hot_motifs, tokenized_target_seq, tokenized_binder_seq, _ in tqdm(data_loader, desc=progress_desc):
            tokenized_target_seq = tokenized_target_seq.to(device).long()
            tokenized_binder_seq = tokenized_binder_seq.to(device).long()
            one_hot_motifs = one_hot_motifs.to(device).float()

            model_output = model(tokenized_target_seq, tokenized_binder_seq, one_hot_motifs)
            loss = criterion(model_output.view(-1, model_output.size(-1)), tokenized_binder_seq.view(-1))
            total_loss += loss.item()

    average_loss = total_loss / len(data_loader)
    print(f"{loss_name} Loss: {average_loss:.4f}")
    return average_loss


def validate_epoch_ce(model, data_loader, device):
    """Evaluate ``model`` on the validation loader; prints and returns the mean loss."""
    return _evaluate_epoch_ce(model, data_loader, device, "Validation", "Validation")


def test_epoch_ce(model, data_loader, device):
    """Evaluate ``model`` on the test loader; prints and returns the mean loss."""
    return _evaluate_epoch_ce(model, data_loader, device, "Testing", "Test")


In [31]:
import matplotlib.pyplot as plt

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

num_epochs = 1

# Per-epoch loss history for each split
train_losses, val_losses, test_losses = [], [], []

# Training Loop: one training pass, then validation and test evaluation, per epoch
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")

    train_losses.append(train_epoch_ce(model, train_dataloader, optimizer, device))
    val_losses.append(validate_epoch_ce(model, val_dataloader, device))
    test_losses.append(test_epoch_ce(model, test_dataloader, device))

# Plotting: one curve per split on a single set of axes
fig, ax = plt.subplots(figsize=(10, 6))
for loss_history, curve_label in (
    (train_losses, 'Training Loss'),
    (val_losses, 'Validation Loss'),
    (test_losses, 'Test Loss'),
):
    ax.plot(loss_history, label=curve_label)
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_title('Training, Validation, and Test Losses Over Epochs')
ax.legend()
plt.show()


Epoch 1/1


Training:   0%|          | 0/2500 [00:00<?, ?it/s]


RuntimeError: ignored