# SnP EDA + Data Preparation

## preprocessing SnP PPI data

In [None]:
from google.colab import drive
drive.mount('/content/drive')
import os

Mounted at /content/drive


In [None]:
os.chdir('/content/drive/MyDrive/Programmable Biology Group/Srikar/Code/proteins/flamingo-ppi-gen/data_dump/old_dat/')

In [None]:
!ls

'all_species_PPIs_famdiversity_06_11_2023 (1).csv'   ppi-snp.csv
 all_species_PPIs_famdiversity_06_11_2023.csv	     receptor_seqs_dict.pkl
 binder_seqs_dict.pkl				     testing_dataset.csv
 complex.csv					     testing_dataset.gsheet
 human-ppi-uniprot.csv				     training_dataset.csv
 peptide_complex_test_dataset.csv		     validation_dataset.csv
 peptide_complex_train_dataset.csv


In [None]:
import pandas as pd

In [None]:
# Load the three raw split CSVs from the current working directory for inspection.
# NOTE(review): the same three names are reassigned further down by
# `preprocess_snp_data(...)`, which re-reads these files, so the raw frames are
# only live until that cell runs.
test_snp = pd.read_csv('testing_dataset.csv')
train_snp = pd.read_csv('training_dataset.csv')
val_snp = pd.read_csv('validation_dataset.csv')

In [None]:
import pandas as pd
import re

def preprocess_snp_data(file_path):
    """Load an SnP split CSV, normalise its energy-score strings and add length columns.

    Args:
        file_path: Path of a CSV that contains at least the columns
            ``energy_scores``, ``peptide_source_RCSB``, ``protein_RCSB``,
            ``protein_derived_sequence`` and ``peptide_derived_sequence``.

    Returns:
        pandas.DataFrame: the loaded frame with ``energy_scores`` rewritten to a
        comma-separated form and these columns added:
        ``energy_scores_lengths``, ``peptide_source_RCSB_lengths``,
        ``protein_RCSB_lengths``, ``protein_derived_seq_length``,
        ``peptide_derived_seq_length`` and ``matching_lengths_count``.

    Raises:
        TypeError: if a value in ``energy_scores`` or one of the measured
            columns is not a string (for example a missing value).
    """
    # Read the dataset
    snp_df = pd.read_csv(file_path)

    def transform_energy_scores(energy_scores):
        """Turn whitespace-separated score lists into clean comma-separated strings."""
        transformed_scores = []
        for score in energy_scores:
            # Strip first so leading/trailing whitespace never becomes a stray comma.
            score = re.sub(r'\s+', ',', score.strip())
            # Input that was already comma-separated ("1, 2") yields doubled commas.
            score = re.sub(r',{2,}', ',', score)
            # Drop the separator right after '[' and right before ']'; the latter
            # used to survive and made every such row count one score too many.
            score = re.sub(r'\[,', '[', score)
            score = re.sub(r',\]', ']', score)
            # Remove leading commas
            score = score.lstrip(',')
            transformed_scores.append(score)
        return transformed_scores

    def count_scores(text):
        """Count the non-empty comma-separated entries, so '[]' counts as zero."""
        return sum(1 for token in text.strip('[]').split(',') if token.strip())

    # Apply transformations
    snp_df['energy_scores'] = transform_energy_scores(snp_df['energy_scores'])
    snp_df['energy_scores_lengths'] = snp_df['energy_scores'].apply(count_scores)

    # Calculate lengths for other columns
    snp_df['peptide_source_RCSB_lengths'] = snp_df['peptide_source_RCSB'].apply(len)
    snp_df['protein_RCSB_lengths'] = snp_df['protein_RCSB'].apply(len)
    snp_df['protein_derived_seq_length'] = snp_df['protein_derived_sequence'].apply(len)
    snp_df['peptide_derived_seq_length'] = snp_df['peptide_derived_sequence'].apply(len)

    # Dataset-wide count of rows whose score count equals the peptide length.
    # This is a single scalar, so every row receives the same value.
    snp_df['matching_lengths_count'] = (snp_df['energy_scores_lengths'] == snp_df['peptide_derived_seq_length']).sum()

    return snp_df

# Run the preprocessing pipeline over the three split files; tuple unpacking
# consumes the map in order, so the files are processed test -> train -> val.
test_snp, train_snp, val_snp = map(
    preprocess_snp_data,
    ('testing_dataset.csv', 'training_dataset.csv', 'validation_dataset.csv'),
)


## dataset + dataloaders (includes 1-hot encoding of energy scores, i.e. a binary 0/1 label per score using a threshold of -1)

In [None]:
def one_hot_encode_energy_scores(scores):
    """Binarise energy scores: 1 where a score is <= -1, otherwise 0.

    Args:
        scores: iterable of numeric energy score values.

    Returns:
        list[int]: one 0/1 flag per input score, in the same order.
    """
    flags = []
    for value in scores:
        flags.append(1 if value <= -1 else 0)
    return flags

In [None]:
from torch.utils.data import DataLoader, Dataset

In [None]:
import re

class ProteinInteractionDataset(Dataset):
    """Map-style dataset yielding one peptide/protein pair per dataframe row.

    Args:
        dataframe: frame with the columns ``peptide_derived_sequence``,
            ``protein_derived_sequence`` and ``energy_scores`` (a string that
            contains the numeric scores).
        protT5_embeddings: mapping from raw sequence string to its embedding.
        tokenizer: tokenizer exposing ``encode(text, add_special_tokens=...,
            return_tensors="pt")``.
    """

    def __init__(self, dataframe, protT5_embeddings, tokenizer):
        self.dataframe = dataframe
        self.protT5_embeddings = protT5_embeddings
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return ``(peptide_embedding, protein_embedding, score_flags, protein_token_ids)``.

        Raises:
            KeyError: if either sequence has no entry in ``protT5_embeddings``.
        """
        # Fetch the row once instead of re-indexing the frame for every column.
        row = self.dataframe.iloc[idx]
        peptide_seq = row['peptide_derived_sequence']
        protein_seq = row['protein_derived_sequence']
        energy_scores = row['energy_scores']

        # ProtT5 tokenizes one residue per whitespace-separated token, which is
        # how `generate_protT5_embeddings` feeds it; passing the raw, unspaced
        # string would not be split into residues. Existing spaces are removed
        # first so already-spaced input is not double-spaced.
        spaced_protein_seq = " ".join(protein_seq.replace(" ", ""))
        # squeeze(0) removes only the batch axis, so even a one-token encoding
        # stays a 1-D tensor that can be padded by a collate function.
        tokenized_protein_seq = self.tokenizer.encode(
            spaced_protein_seq, add_special_tokens=True, return_tensors="pt"
        ).squeeze(0)

        # Use regular expression to split the energy_scores string
        energy_scores = re.findall(r'-?\d+\.?\d*(?:e[-+]?\d+)?', energy_scores)

        # Convert the split strings to floats
        energy_scores = [float(score) for score in energy_scores]

        # Binarise the energy scores (1 where score <= -1)
        encoded_scores = one_hot_encode_energy_scores(energy_scores)

        # Embeddings are keyed by the raw (unspaced) sequence strings.
        peptide_embedding = self.protT5_embeddings[peptide_seq]
        protein_embedding = self.protT5_embeddings[protein_seq]

        # An explicit dtype keeps an empty score list integer-typed like the others.
        score_tensor = torch.tensor(encoded_scores, dtype=torch.long)

        return peptide_embedding, protein_embedding, score_tensor, tokenized_protein_seq


In [None]:
from transformers import T5Tokenizer, T5EncoderModel
import torch
import re
!pip install sentencepiece
import sentencepiece
import torch
from torch import nn
from transformers import T5ForConditionalGeneration, T5Tokenizer
from torch.utils.data import DataLoader, Dataset

Collecting sentencepiece
  Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m11.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: sentencepiece
Successfully installed sentencepiece-0.1.99


In [None]:

# Load tokenizer and model
# Both come from the encoder-only "Rostlab/prot_t5_xl_half_uniref50-enc" checkpoint.
tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)
model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc")
# Prefer the first CUDA device when one is available, otherwise run on CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)
# Precision switch left disabled in this cell.
# NOTE(review): `nn.Module` has no `.full()` method; the fp32 cast is `.float()`.
# model = model.half() if device.type == 'cuda' else model.full()

from tqdm import tqdm

def generate_protT5_embeddings(sequences):
    """Embed each sequence with the module-level ProtT5 encoder, mean-pooled over tokens.

    Relies on the notebook globals ``tokenizer``, ``model`` and ``device``.

    Args:
        sequences: iterable of amino-acid sequence strings.

    Returns:
        dict: maps each original (unprocessed) sequence string to a 1-D tensor
        of size ``hidden_dim`` that lives on ``device``. The mean is taken over
        every output position, including any special tokens the tokenizer adds.
    """
    embeddings_dict = {}
    # Wrap the sequence iteration with tqdm for a progress bar
    for sequence in tqdm(sequences, desc="Generating embeddings"):
        # A repeated sequence would yield the same entry again; skip the forward pass.
        if sequence in embeddings_dict:
            continue
        # Map the residues U/Z/O/B to X and separate residues with spaces.
        processed_seq = " ".join(re.sub(r"[UZOB]", "X", sequence))
        # Tokenize and encode
        ids = tokenizer(processed_seq, add_special_tokens=True, return_tensors="pt")
        input_ids = ids['input_ids'].to(device)
        attention_mask = ids['attention_mask'].to(device)
        # Generate embeddings
        with torch.no_grad():
            embedding_repr = model(input_ids=input_ids, attention_mask=attention_mask)
        # squeeze(0) drops only the batch axis. A bare squeeze() would also drop
        # the token axis when there is a single token, and the mean over dim 0
        # would then collapse the hidden dimension to a scalar.
        embeddings_dict[sequence] = embedding_repr.last_hidden_state.squeeze(0).mean(dim=0)
    return embeddings_dict


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/238k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/1.79k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/656 [00:00<?, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


pytorch_model.bin:   0%|          | 0.00/2.42G [00:00<?, ?B/s]

In [None]:
from torch.utils.data import DataLoader

# Generate ProtT5 embeddings
# Combine unique sequences from peptide and protein derived sequences
# unique_seqs = pd.concat([train_snp['peptide_derived_sequence'], train_snp['protein_derived_sequence'],
#                          test_snp['peptide_derived_sequence'], test_snp['protein_derived_sequence'],
#                          val_snp['peptide_derived_sequence'], val_snp['protein_derived_sequence']]).unique()
# protT5_embeddings = generate_protT5_embeddings(unique_seqs)


In [None]:
os.chdir('/content/drive/MyDrive/Programmable Biology Group/Srikar/Code/proteins/flamingo-ppi-gen/data_dump/')

In [None]:
import pickle

In [None]:
# Load the cached embedding dictionary from the current working directory.
# Presumably this is the file written by the commented-out save cell in the
# "save as .pkl" section, which uses the same filename -- confirm.
# NOTE(review): `pickle.load` can execute arbitrary code, so only load pickle
# files that you created yourself.
file_path = 'protT5_embeddings.pkl'
with open(file_path, 'rb') as file:
    protT5_embeddings = pickle.load(file)

In [None]:

# Create datasets with tokenizer
# Each dataset wraps one preprocessed split and shares the embedding lookup and tokenizer.
train_dataset = ProteinInteractionDataset(train_snp, protT5_embeddings, tokenizer)
test_dataset = ProteinInteractionDataset(test_snp, protT5_embeddings, tokenizer)
val_dataset = ProteinInteractionDataset(val_snp, protT5_embeddings, tokenizer)

# Create DataLoaders
# No `collate_fn` is passed, so PyTorch's default collation is used; only the
# training loader shuffles.
# NOTE(review): default collation stacks tensors, which will likely fail when
# score or token lengths differ inside a batch -- confirm before iterating these loaders.
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=2)
val_dataloader = DataLoader(val_dataset, batch_size=2)




## save as .pkl

In [None]:
# import pickle

# # protT5_embeddings is a dictionary of tensors
# with open('protT5_embeddings.pkl', 'wb') as f:
#     pickle.dump(protT5_embeddings, f)


In [None]:
type(protT5_embeddings)

dict

In [None]:
# Persist each dataset object (dataframe, embeddings and tokenizer included)
# to its own pickle file in the current working directory.
for pkl_name, split_dataset in (
    ('train_dataset.pkl', train_dataset),
    ('test_dataset.pkl', test_dataset),
    ('val_dataset.pkl', val_dataset),
):
    with open(pkl_name, 'wb') as f:
        pickle.dump(split_dataset, f)

# Reload Data

In [None]:
from google.colab import drive
drive.mount('/content/drive')
import os

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
os.chdir('/content/drive/MyDrive/Programmable Biology Group/Srikar/Code/proteins/flamingo-ppi-gen/data_dump/')

In [None]:
from transformers import T5Tokenizer, T5EncoderModel
import torch
import re
!pip install sentencepiece
import sentencepiece
import torch
from torch import nn
from transformers import T5ForConditionalGeneration, T5Tokenizer
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModel, AutoTokenizer



In [None]:
def one_hot_encode_energy_scores(scores):
    """Binarise energy scores: 1 where a score is <= -1, otherwise 0.

    Args:
        scores: iterable of numeric energy score values.

    Returns:
        list[int]: one 0/1 flag per input score, in the same order.
    """
    flags = []
    for value in scores:
        flags.append(1 if value <= -1 else 0)
    return flags

In [None]:
from torch.utils.data import Dataset
import pickle

In [None]:
import re

class ProteinInteractionDataset(Dataset):
    """Map-style dataset yielding one peptide/protein pair per dataframe row.

    This definition must exist in the session before the pickled datasets are
    loaded, because pickle restores instances by looking the class up by name.

    Args:
        dataframe: frame with the columns ``peptide_derived_sequence``,
            ``protein_derived_sequence`` and ``energy_scores`` (a string that
            contains the numeric scores).
        protT5_embeddings: mapping from raw sequence string to its embedding.
        tokenizer: tokenizer exposing ``encode(text, add_special_tokens=...,
            return_tensors="pt")``.
    """

    def __init__(self, dataframe, protT5_embeddings, tokenizer):
        self.dataframe = dataframe
        self.protT5_embeddings = protT5_embeddings
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return ``(peptide_embedding, protein_embedding, score_flags, protein_token_ids)``.

        Raises:
            KeyError: if either sequence has no entry in ``protT5_embeddings``.
        """
        # Fetch the row once instead of re-indexing the frame for every column.
        row = self.dataframe.iloc[idx]
        peptide_seq = row['peptide_derived_sequence']
        protein_seq = row['protein_derived_sequence']
        energy_scores = row['energy_scores']

        # ProtT5 tokenizes one residue per whitespace-separated token, which is
        # how `generate_protT5_embeddings` feeds it; passing the raw, unspaced
        # string would not be split into residues. Existing spaces are removed
        # first so already-spaced input is not double-spaced.
        spaced_protein_seq = " ".join(protein_seq.replace(" ", ""))
        # squeeze(0) removes only the batch axis, so even a one-token encoding
        # stays a 1-D tensor that can be padded by a collate function.
        tokenized_protein_seq = self.tokenizer.encode(
            spaced_protein_seq, add_special_tokens=True, return_tensors="pt"
        ).squeeze(0)

        # Use regular expression to split the energy_scores string
        energy_scores = re.findall(r'-?\d+\.?\d*(?:e[-+]?\d+)?', energy_scores)

        # Convert the split strings to floats
        energy_scores = [float(score) for score in energy_scores]

        # Binarise the energy scores (1 where score <= -1)
        encoded_scores = one_hot_encode_energy_scores(energy_scores)

        # Embeddings are keyed by the raw (unspaced) sequence strings.
        peptide_embedding = self.protT5_embeddings[peptide_seq]
        protein_embedding = self.protT5_embeddings[protein_seq]

        # An explicit dtype keeps an empty score list integer-typed like the others.
        score_tensor = torch.tensor(encoded_scores, dtype=torch.long)

        return peptide_embedding, protein_embedding, score_tensor, tokenized_protein_seq


In [None]:
from torch.nn.functional import pad

def collate_fn(batch, verbose=False):
    """Zero-pad and stack a batch of ``(emb_a, emb_b, score_flags, token_ids)`` samples.

    Both embeddings and the score flags are padded along dim 0 to the largest
    embedding ``size(0)`` in the batch; token ids are padded to the longest
    token sequence in the batch. Padding uses the constant 0.

    Args:
        batch: list of 4-tuples of tensors as returned by the dataset.
        verbose: when True, print the per-batch and per-sample sizes that this
            function used to print unconditionally.

    Returns:
        tuple of four stacked tensors, in the same order as the sample fields.

    Raises:
        ValueError: if ``batch`` is empty (raised by ``max``).
    """
    max_length_embeddings = max(max(prot1_emb.size(0), prot2_emb.size(0)) for prot1_emb, prot2_emb, _, _ in batch)
    max_length_tokenized = max(tok_seq.size(0) for _, _, _, tok_seq in batch)

    if verbose:
        print("Max length for embeddings:", max_length_embeddings)
        print("Max length for tokenized sequences:", max_length_tokenized)

    # Empty score vectors borrow dtype/device from the first non-empty one so the
    # final stack sees homogeneous tensors.
    reference_scores = next((scores for _, _, scores, _ in batch if scores.nelement() != 0), None)

    prot1_embeddings_padded = []
    prot2_embeddings_padded = []
    one_hot_scores_padded_list = []
    tokenized_seqs_padded = []

    for prot1_emb, prot2_emb, one_hot_scores, tok_seq in batch:
        if verbose:
            print("Original sizes:", prot1_emb.size(), prot2_emb.size(), one_hot_scores.size(), tok_seq.size())

        prot1_emb_padded = pad(prot1_emb, (0, max_length_embeddings - prot1_emb.size(0)), "constant", 0)
        prot2_emb_padded = pad(prot2_emb, (0, max_length_embeddings - prot2_emb.size(0)), "constant", 0)

        if one_hot_scores.nelement() != 0:
            # NOTE: a score vector longer than the embedding length gets a negative
            # pad amount, which crops it rather than raising.
            one_hot_scores_padded = pad(one_hot_scores, (0, max_length_embeddings - one_hot_scores.size(0)), "constant", 0)
        else:
            # Previously the empty tensor was passed through unchanged, so
            # torch.stack failed whenever another sample had scores.
            template = reference_scores if reference_scores is not None else one_hot_scores
            one_hot_scores_padded = template.new_zeros(max_length_embeddings)

        tok_seq_padded = pad(tok_seq, (0, max_length_tokenized - tok_seq.size(0)), "constant", 0)

        if verbose:
            print("Padded sizes:", prot1_emb_padded.size(), prot2_emb_padded.size(), one_hot_scores_padded.size(), tok_seq_padded.size())

        prot1_embeddings_padded.append(prot1_emb_padded)
        prot2_embeddings_padded.append(prot2_emb_padded)
        one_hot_scores_padded_list.append(one_hot_scores_padded)
        tokenized_seqs_padded.append(tok_seq_padded)

    return torch.stack(prot1_embeddings_padded), torch.stack(prot2_embeddings_padded), torch.stack(one_hot_scores_padded_list), torch.stack(tokenized_seqs_padded)


In [None]:
# Load the ProtT5 encoder and its tokenizer, then choose the precision for the device.
tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)
model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc")
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)
# Use fp16 on GPU and fp32 on CPU. `nn.Module` has no `.full()` method, so the
# previous CPU branch raised AttributeError; `.float()` is the fp32 cast.
model = model.half() if device.type == 'cuda' else model.float()

from tqdm import tqdm

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [None]:
import pickle
# Load the datasets
# Unpickling restores `ProteinInteractionDataset` instances, so that class must
# already be defined in this session (it is redefined earlier in this section).
# NOTE(review): `pickle.load` can execute arbitrary code; only load files you wrote yourself.
with open('train_dataset.pkl', 'rb') as f:
    train_dataset = pickle.load(f)

with open('test_dataset.pkl', 'rb') as f:
    test_dataset = pickle.load(f)

with open('val_dataset.pkl', 'rb') as f:
    val_dataset = pickle.load(f)

# Assuming these are the batch sizes you want to use
train_batch_size = 2
test_batch_size = 2
val_batch_size = 2


In [None]:
train_dataset[0]

(tensor([ 0.0517, -0.0203,  0.0323,  ..., -0.0052,  0.0547,  0.0482],
        device='cuda:0', dtype=torch.float16),
 tensor([ 0.0432, -0.0195,  0.0399,  ...,  0.0008,  0.0552,  0.0399],
        device='cuda:0', dtype=torch.float16),
 tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0

In [None]:
from torch.utils.data import DataLoader

# Create the DataLoaders
# Each loader keeps a reference to the `collate_fn` object that exists when this
# cell runs. A later redefinition of `collate_fn` in this notebook does not
# affect these loaders unless this cell is executed again.
train_dataloader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=test_batch_size, collate_fn=collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=val_batch_size, collate_fn=collate_fn)


# Motif-guided ProtFlamingo

## Helper Functions + Gated Cross Attn + Perceiver Resampler

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import torch.nn.functional as F
# from transformers import RobertaModel  # Assuming use of Hugging Face's transformer models

# Helper Functions
def exists(val):
    """Return True for any value other than None (falsy values such as 0 or '' count)."""
    return not (val is None)

def set_module_requires_grad_(module, requires_grad):
    """Set ``requires_grad`` on every parameter of ``module`` in place.

    Args:
        module: object exposing ``parameters()`` (e.g. an ``nn.Module``).
        requires_grad: value assigned to each parameter's ``requires_grad``.
    """
    for weight in module.parameters():
        weight.requires_grad = requires_grad

def freeze_model_and_make_eval_(model):
    """Freeze ``model`` in place: disable gradients on all parameters and switch to eval mode."""
    # The two steps are independent of each other, so their order does not matter.
    set_module_requires_grad_(model, False)
    model.eval()

# LayerNorm class
class LayerNorm(nn.Module):
    """Gain-only layer normalisation over the last dimension (no bias term).

    Uses ``Tensor.std`` with its default settings and adds ``eps`` to the
    standard deviation rather than to the variance.

    NOTE: a second class named ``LayerNorm`` is defined later in this notebook
    and replaces this name once that cell has run.
    """

    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.gain = nn.Parameter(torch.ones(features))

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        spread = x.std(-1, keepdim=True) + self.eps
        return self.gain * centered / spread

# Residual class
class Residual(nn.Module):
    """Wrap a callable so that its output is added back onto its first argument."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        # Extra positional/keyword arguments are forwarded to the wrapped callable untouched.
        branch = self.fn(x, *args, **kwargs)
        return branch + x

# SwiGLU activation function
class SwiGLU(nn.Module):
    """Gated activation: SiLU of the first half of the last dim times the second half.

    The output's last dimension is half the size of the input's.

    NOTE: a second class named ``SwiGLU`` is defined later in this notebook; it
    applies SiLU to the *second* half instead and replaces this name once run.
    """

    def forward(self, x):
        half = x.shape[-1] // 2
        activated = F.silu(x[..., :half])
        passthrough = x[..., half:]
        return activated * passthrough

# Transformer Block class
class TransformerBlock(nn.Module):
    """Self-attention followed by a SwiGLU MLP, each combined through a `Residual` wrapper.

    Note that ``self.residual`` and ``self.feedforward`` wrap the *layer norms*:
    ``self.residual(a)`` computes ``ln1(a) + a`` on the attention output, and
    ``self.feedforward(m)`` computes ``ln2(m) + m`` on the MLP output. The block
    input itself is not added back.

    Args:
        dim: model width; also the embedding size of the attention layer.
        heads: number of attention heads.
        mlp_dim: width of the first MLP projection. `SwiGLU` halves it, so the
            second projection maps ``mlp_dim // 2 -> dim``.
    """

    def __init__(self, dim, heads, mlp_dim):
        super().__init__()
        self.ln1 = LayerNorm(dim)
        # Default nn.MultiheadAttention layout: inputs are (seq, batch, embed).
        self.attn = nn.MultiheadAttention(dim, heads)
        self.ln2 = LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            SwiGLU(),
            nn.Linear(mlp_dim // 2, dim)
        )
        self.residual = Residual(self.ln1)
        self.feedforward = Residual(self.ln2)
        self.expand_dim = nn.Linear(dim, 2 * dim)  # Project to a higher dimension

    def forward(self, x):
        """Run the block; inputs with fewer than 3 dims are first lifted to 3-D.

        A ``[B, dim]`` input is projected to ``[B, 2 * dim]`` and reshaped to
        ``[2 * B, 1, dim]``.
        """
        if x.dim() < 3:
            # Apply the expansion transformation if x has less than 3 dimensions
            x_expanded = self.expand_dim(x)
            # Derive the shape from the input instead of hard-coding (4, 1, 1024).
            # For batch 2 / dim 1024 the result is identical; other batch sizes
            # (e.g. a short final batch) and widths no longer raise.
            x_expanded = x_expanded.view(-1, 1, x.shape[-1])
            print('x transformed shape in transformer block:', x_expanded.shape)
            x = self.residual(self.attn(x_expanded, x_expanded, x_expanded)[0])
        else:
            x = self.residual(self.attn(x, x, x)[0])
        print("Shape after attention and residual:", x.shape)  # Debug print
        x = self.feedforward(self.mlp(x))
        print("Shape after feedforward:", x.shape)  # Debug print
        return x


In [None]:
!pip install transformers



In [None]:
!pip install einops-exts



In [None]:
# Ask PyTorch's CUDA caching allocator to use `max_split_size_mb:4096`.
# NOTE(review): this setting is presumably only read when the CUDA allocator is
# first initialised. Earlier cells already moved a model to `device`, so on a GPU
# runtime it may have no effect unless set before the first CUDA use -- confirm.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:4096'

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops_exts import rearrange_many, repeat_many

def exists(val):
    """Return True for any value other than None (falsy values such as 0 or '' count)."""
    return not (val is None)

def FeedForward(dim, mult = 4):
    """Build a pre-norm MLP: LayerNorm -> Linear -> GELU -> Linear, both linears bias-free.

    Args:
        dim: input and output width.
        mult: hidden width multiplier; the hidden size is ``int(dim * mult)``.

    Returns:
        nn.Sequential containing the four layers in that order.
    """
    hidden = int(dim * mult)
    stages = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden, bias = False),
        nn.GELU(),
        nn.Linear(hidden, dim, bias = False),
    ]
    return nn.Sequential(*stages)

class PerceiverAttention(nn.Module):
    """Cross-attention in which latents attend over the media features plus themselves.

    Queries come from the latents; keys and values come from the media input
    concatenated with the latents along dim 2.

    Args:
        dim: width of the latents and of the output.
        concatenated_dim: width of the media input ``x``.
        dim_head: width of each attention head.
        heads: number of attention heads.
    """

    def __init__(self, *, dim, concatenated_dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        # Adjusted to use concatenated_dim for normalization
        self.norm_media = nn.LayerNorm(concatenated_dim)
        self.norm_latents = nn.LayerNorm(dim)

        # Lifts latents to the media width. This used to be a brand-new nn.Linear
        # created inside forward() on every call: randomly re-initialised each
        # time, never registered, and therefore never trained or saved.
        self.latents_to_media_dim = (
            nn.Linear(dim, concatenated_dim) if dim != concatenated_dim else nn.Identity()
        )

        # Adjusted dimensions for the larger concatenated input
        self.to_q = nn.Linear(concatenated_dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(concatenated_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """Attend from ``latents`` over ``x`` and the latents.

        Args:
            x: media features whose last dim is ``concatenated_dim``; a 3-D
                input gets a singleton axis inserted at dim 2.
            latents: 3-D or 4-D tensor whose last dim is ``dim``; a 3-D input
                gets a singleton axis inserted at dim 1.

        Returns:
            Tensor whose last dim is ``dim``.

        Raises:
            ValueError: if ``latents`` is neither 3-D nor 4-D.
        """
        # Normalize x and latents
        x = self.norm_media(x)
        latents = self.norm_latents(latents)

        print("Original Shape of x:", x.shape)
        print("Original Shape of latents:", latents.shape)

        # Expand x to 4D if it's 3D: [B, T, D] -> [B, T, 1, D]
        if x.dim() == 3:
            x = x.unsqueeze(2)
            print("Expanded Shape of x:", x.shape)

        # Check dimensions of latents and expand if necessary
        if latents.dim() == 3:
            latents = latents.unsqueeze(1)  # Add time dimension
            print("Expanded Shape of latents:", latents.shape)
        elif latents.dim() != 4:
            raise ValueError(f"Unexpected number of dimensions in latents: {latents.dim()}")

        # Ensure x and latents have the same last dimension
        if x.size(-1) != latents.size(-1):
            # Both widths are fixed by the layer norms above, so the registered
            # projection always has matching in/out features here.
            latents = self.latents_to_media_dim(latents)
            print("Transformed latents with linear layer:", latents.shape)

        # Concatenate x and latents along dim 2
        kv_input = torch.cat((x, latents), dim=2)
        print("Shape of concatenated kv_input:", kv_input.shape)

        # Generate queries from latents, and keys and values from the concatenated input
        q = self.to_q(latents)
        q = rearrange(q, 'b t m (h d) -> b h t m d', h=self.heads)
        print("Shape of query tensors:", q.shape)

        k, v = self.to_kv(kv_input).chunk(2, dim=-1)
        k, v = rearrange_many((k, v), 'b t n (h d) -> b h t n d', h=self.heads)

        q = q * self.scale

        # attention
        sim = einsum('... i d, ... j d  -> ... i j', q, k)

        # Subtract the row max before softmax for numerical stability.
        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
        attn = sim.softmax(dim = -1)

        out = einsum('... i j, ... j d -> ... i d', attn, v)
        out = rearrange(out, 'b h t n d -> b t n (h d)', h = self.heads)
        return self.to_out(out)


class PerceiverResampler(nn.Module):
    """Stack of (PerceiverAttention, FeedForward) layers that refine learned latents.

    Args:
        dim: width of the latents and of the output.
        depth: number of attention/feed-forward layer pairs.
        dim_head: width of each attention head.
        heads: number of attention heads.
        num_latents: number of learned latent vectors.
        num_media_embeds: number of learned media position embeddings.
        ff_mult: hidden-width multiplier of each feed-forward layer.
        concatenated_dim: width of the media input.
    """

    def __init__(self, *, dim, depth, dim_head=64, heads=8, num_latents=64, num_media_embeds=4, ff_mult=4, concatenated_dim=2048):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        self.media_pos_emb = nn.Parameter(torch.randn(num_media_embeds, 1, concatenated_dim))

        # One [attention, feed-forward] pair per depth level.
        layer_pairs = [
            nn.ModuleList([
                PerceiverAttention(dim=dim, concatenated_dim=concatenated_dim, dim_head=dim_head, heads=heads),
                FeedForward(dim=dim, mult=ff_mult),
            ])
            for _ in range(depth)
        ]
        self.layers = nn.ModuleList(layer_pairs)

        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """Return the normalised latents after attending over ``x``.

        A 3-D ``x`` gets a singleton axis inserted at dim 1. The position
        embedding slice is added by broadcasting against whatever shape ``x`` has.
        """
        if x.ndim == 3:
            x = rearrange(x, 'b n d -> b 1 n d')

        num_media = x.shape[1]
        x = x + self.media_pos_emb[:num_media]

        # Tile the learned latents over the first two axes of x.
        batch_size, media_count = x.shape[0], x.shape[1]
        latents = repeat(self.latents, 'n d -> b m n d', b=batch_size, m=media_count)

        for layer_pair in self.layers:
            cross_attn, feed_forward = layer_pair
            latents = cross_attn(x, latents) + latents
            latents = feed_forward(latents) + latents

        return self.norm(latents)


class MaskedCrossAttention(nn.Module):
    """Multi-head cross-attention from ``x`` (queries) onto flattened ``media`` (keys/values).

    Args:
        dim: width of ``x``, of ``media`` and of the output.
        concatenated_dim: accepted but not used by this class.
        dim_head: width of each attention head.
        heads: number of attention heads.
        only_attend_immediate_media: stored on the instance; not read in ``forward``.
    """

    def __init__(self, *, dim, concatenated_dim=2048, dim_head=64, heads=8, only_attend_immediate_media=True):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)
        self.only_attend_immediate_media = only_attend_immediate_media

    def forward(self, x, media, media_locations=None):
        """Attend from ``x`` over ``media``.

        Args:
            x: 3-D query tensor whose last dim is ``dim``.
            media: 4-D tensor; its two middle axes are flattened into one.
            media_locations: optional mask; score entries where it equals 0 are
                set to -inf before the softmax.

        Returns:
            Tensor with the same leading shape as ``x`` and last dim ``dim``.
        """
        # The values are unused, but the unpack requires media to have at least 3 dims.
        _batch, _times, _per_time = media.shape[:3]
        num_heads = self.heads
        print('x before norm (masked cross attn):',x.shape)
        print('media before reaarange (masked cross attn):',media.shape)

        queries = self.to_q(self.norm(x))
        flat_media = rearrange(media, 'b t n d -> b (t n) d')
        keys, values = self.to_kv(flat_media).chunk(2, dim=-1)
        print("Shape of q (masked cross attn):", queries.shape)
        print("Shape of k (masked cross attn):", keys.shape)
        print("Shape of v (masked cross attn):", values.shape)

        queries, keys, values = rearrange_many((queries, keys, values), 'b n (h d) -> b h n d', h=num_heads)

        queries = queries * self.scale
        sim = einsum('... i d, ... j d -> ... i j', queries, keys)
        if media_locations is not None:
            # Modify attention mask based on motif presence
            keep = media_locations.unsqueeze(1).unsqueeze(2)
            sim = sim.masked_fill(keep == 0, float('-inf'))

        # Subtract the row max before softmax for numerical stability.
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        weights = sim.softmax(dim=-1)
        attended = einsum('... i j, ... j d -> ... i d', weights, values)
        attended = rearrange(attended, 'b h n d -> b n (h d)')
        return self.to_out(attended)


class GatedCrossAttentionBlock(nn.Module):
    """Cross-attention and feed-forward sub-layers, each scaled by a learned tanh gate.

    Both gates are initialised to 0, so ``tanh(gate) == 0`` and a freshly
    constructed block returns its (possibly reshaped) input unchanged.

    Args:
        dim: model width.
        dim_head: width of each attention head.
        heads: number of attention heads.
        ff_mult: hidden-width multiplier of the feed-forward sub-layer.
        only_attend_immediate_media: forwarded to `MaskedCrossAttention`.
    """

    def __init__(self, *, dim, dim_head=64, heads=8, ff_mult=4, only_attend_immediate_media=True):
        super().__init__()
        self.attn = MaskedCrossAttention(dim=dim, dim_head=dim_head, heads=heads, only_attend_immediate_media=only_attend_immediate_media)
        self.attn_gate = nn.Parameter(torch.tensor([0.]))
        self.ff = FeedForward(dim, mult=ff_mult)
        self.ff_gate = nn.Parameter(torch.tensor([0.]))

        self.expand_dim = nn.Linear(dim, 2 * dim)  # Project to a higher dimension

    def forward(self, x, media, media_locations=None):
        """Apply gated cross-attention over ``media`` and then the gated feed-forward.

        Inputs ``x`` with fewer than 3 dims are first lifted: ``[B, dim]`` is
        projected to ``[B, 2 * dim]`` and reshaped to ``[2 * B, 1, dim]``.
        """
        gate = self.attn_gate.tanh()
        print('defined gate (tanh)')
        print('gate shape:',gate.shape)
        print('x shape initially input to gate:', x.shape)

        if x.dim() < 3:
            # Apply the expansion transformation if x has less than 3 dimensions
            x_expanded = self.expand_dim(x)
            # Derive the shape from the input instead of hard-coding (4, 1, 1024).
            # For batch 2 / dim 1024 the result is identical; other batch sizes
            # and widths no longer raise in the reshape.
            x_expanded = x_expanded.view(-1, 1, x.shape[-1])
            print('x transformed shape in gated cross attn:', x_expanded.shape)
            x = self.attn(x_expanded, media, media_locations=media_locations) * gate + x_expanded
        else:
            # If x has 3 dimensions, use it as is
            x = self.attn(x, media, media_locations=media_locations) * gate + x
        print('self attn of x *gate +x :',x.shape)
        x = self.ff(x) * self.ff_gate.tanh() + x
        print('ff*gate +x of x:', x.shape)
        return x


## ProtFlamingo

In [None]:

class LayerNorm(nn.Module):
    """Thin wrapper that delegates to ``nn.LayerNorm(dim)``.

    NOTE: this redefines the gain-only ``LayerNorm`` class from earlier in the
    notebook; any code that looks the name up after this cell runs gets this version.
    """

    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        normalized = self.norm(x)
        return normalized

class SwiGLU(nn.Module):
    """Gated activation: split the last dim in two and return ``silu(second) * first``.

    NOTE: this redefines the earlier ``SwiGLU`` class in this notebook, which
    applied SiLU to the first half instead.
    """

    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim=-1)
        return F.silu(gate) * value

class ParallelTransformerBlock(nn.Module):
    """LayerNorm -> self-attention (+ skip) -> SwiGLU feed-forward (+ skip).

    Both skip connections add the *normalised* tensor, not the raw block input.

    Args:
        dim: model width.
        dim_head: accepted for signature compatibility; not used to build layers.
        heads: number of attention heads.
        ff_mult: the feed-forward hidden width is ``ff_mult * dim`` after gating.
    """

    def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
        super().__init__()
        self.norm = LayerNorm(dim)

        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads)
        # The first linear doubles the hidden width because SwiGLU halves it again.
        self.ff = nn.Sequential(
            nn.Linear(dim, 2* ff_mult * dim),
            SwiGLU(),
            nn.Linear(ff_mult * dim, dim)
        )

    def forward(self, x):
        """Run the block on a 3-D input; the first two axes are swapped around attention."""
        print("Input to ParallelTransformerBlock:", x.shape)

        x = self.norm(x)
        print("After LayerNorm:", x.shape)

        # Swap the first two axes for nn.MultiheadAttention.
        x = x.permute(1, 0, 2)
        print("After permute for MultiheadAttention:", x.shape)

        attn_output = self.attn(x, x, x)[0]
        print("After MultiheadAttention:", attn_output.shape)

        x = attn_output + x
        print("After adding attn_output:", x.shape)

        # Swap the axes back.
        x = x.permute(1, 0, 2)
        print("After permute back:", x.shape)

        # Walk the feed-forward layers one by one so that shapes around each
        # linear layer can be logged.
        ff_output = x
        for layer in self.ff:
            is_linear = isinstance(layer, nn.Linear)
            if is_linear:
                print("Input to Linear Layer:", ff_output.shape)
            ff_output = layer(ff_output)
            if is_linear:
                print("Output from Linear Layer:", ff_output.shape)

        output = ff_output + x
        print("Output from ParallelTransformerBlock:", output.shape)

        return output



In [None]:
from transformers import T5Tokenizer, T5EncoderModel
import torch
import re

In [None]:
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
from transformers import T5ForConditionalGeneration, T5Tokenizer
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer


In [None]:
T5ForConditionalGeneration.from_pretrained("Rostlab/prot_t5_xl_bfd").config.d_model

1024

In [None]:
import torch
from torch import nn
from transformers import T5ForConditionalGeneration, T5Tokenizer

class ProtFlamingo(nn.Module):
    """Perceiver-resampled conditioning fed into interleaved transformer / gated cross-attention layers.

    Args:
        num_tokens: size of the output vocabulary (width of the logits).
        depth: number of `TransformerBlock` layers.
        dim_head, heads, ff_mult: attention / feed-forward hyper-parameters.
        cross_attn_every: a `GatedCrossAttentionBlock` follows every transformer
            block whose index is a multiple of this value.
        perceiver_num_latents, perceiver_depth: `PerceiverResampler` settings.
        only_attend_immediate_media, protein_mode, motif_mode: accepted but not
            used anywhere in this class.

    NOTE(review): ``self.protT5`` is loaded in ``__init__`` but never referenced
    in ``forward``; its weights are still part of ``self.parameters()``.
    """

    def __init__(self, num_tokens, depth, dim_head=64, heads=8, ff_mult=4, cross_attn_every=3, perceiver_num_latents=64, perceiver_depth=2, only_attend_immediate_media=True, protein_mode=False, motif_mode=False):
        super().__init__()

        self.protT5 = T5ForConditionalGeneration.from_pretrained("Rostlab/prot_t5_xl_bfd")
        self.dim = 1024  # Assuming the embedding dimension
        self.to_logits = nn.Linear(self.dim, num_tokens)

        # concatenated_dim is the width of the concatenated conditioning input.
        self.perceiver_resampler = PerceiverResampler(dim=self.dim, depth=perceiver_depth, dim_head=dim_head, heads=heads, num_latents=perceiver_num_latents,concatenated_dim=2048)

        stack = []
        for layer_idx in range(depth):
            stack.append(TransformerBlock(dim=self.dim, heads=heads, mlp_dim=self.dim * ff_mult))
            if layer_idx % cross_attn_every == 0:
                stack.append(GatedCrossAttentionBlock(dim=self.dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult))
        self.layers = nn.ModuleList(stack)

    def forward(self, protein_embeddings, motif_encodings, target_sequence):
        """Return logits for ``target_sequence`` conditioned on the two other inputs.

        ``protein_embeddings`` and ``motif_encodings`` are concatenated along
        the last dim and passed through the perceiver resampler; the result is
        the media input of every gated cross-attention layer.
        """
        # Print the shape of inputs
        print("Protein Embeddings Shape:", protein_embeddings.shape)
        print("Motif Encodings Shape:", motif_encodings.shape)
        print("Target Sequence Shape:", target_sequence.shape)

        # Concatenate protein_embeddings and motif_encodings
        combined_input = torch.cat((protein_embeddings, motif_encodings), dim=-1)
        print("Combined Input Shape:", combined_input.shape)

        # Process the combined input through the perceiver resampler
        processed_input = self.perceiver_resampler(combined_input)
        print("Processed Input Shape:", processed_input.shape)

        hidden = target_sequence
        for index, layer in enumerate(self.layers):
            print(f"Layer {index} type: {type(layer).__name__}")

            # Only the gated cross-attention layers take the conditioning input.
            needs_media = isinstance(layer, GatedCrossAttentionBlock)
            hidden = layer(hidden, processed_input) if needs_media else layer(hidden)

            # Print the shape after each layer
            print(f"Post Layer {index} Shape:", hidden.shape)

        # Generate logits
        logits = self.to_logits(hidden)
        print("Logits Shape:", logits.shape)

        return logits


## Train Model

In [None]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    """Pad each field of a batch to its own longest member and stack batch-first.

    Embeddings and score flags are padded with 0; token ids are padded with the
    module-level ``tokenizer.pad_token_id``.

    NOTE: this redefines the earlier ``collate_fn`` in this notebook. DataLoaders
    that were constructed before this cell runs keep using the earlier function.

    Args:
        batch: list of ``(seq1_embedding, seq2_embedding, one_hot_scores, token_ids)`` tuples.

    Returns:
        tuple of four padded tensors, in the same field order.
    """
    columns = list(zip(*batch))
    seq1_embeddings, seq2_embeddings, one_hot_scores, tokenized_seqs = columns

    padded_seq1 = pad_sequence(seq1_embeddings, batch_first=True)
    padded_seq2 = pad_sequence(seq2_embeddings, batch_first=True)
    padded_scores = pad_sequence(one_hot_scores, batch_first=True)
    padded_tokens = pad_sequence(tokenized_seqs, batch_first=True, padding_value=tokenizer.pad_token_id)

    return padded_seq1, padded_seq2, padded_scores, padded_tokens


In [None]:
import torch
import torch.nn as nn
from torch.nn.functional import kl_div, log_softmax


In [None]:
import torch

# Assuming 'model', 'train_dataloader', 'val_dataloader', 'test_dataloader', and 'criterion' are already defined
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# `nn.Module` has no `.full()` method, so the CPU branch used to raise
# AttributeError; `.float()` is the float32 ("full precision") cast.
# NOTE(review): on CUDA the weights become float16 while the train/eval loops
# in this notebook cast every input with `.float()` -- confirm that dtype mix
# is intended before training on GPU.
model = model.half() if device.type == 'cuda' else model.float()


# NOTE(review): the optimizer captures the parameters of the `model` object that
# exists right now; a later cell re-instantiates `model`, so the optimizer has to
# be rebuilt after that cell or it will update the wrong parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)



In [None]:
# Display the dataset object that the training DataLoader wraps (sanity check).
train_dataloader.dataset

<__main__.ProteinInteractionDataset at 0x7a3ee2422f80>

In [None]:
# Load the ProtT5 model and its tokenizer from the same pretrained checkpoint.
PROT_T5_CHECKPOINT = "Rostlab/prot_t5_xl_bfd"
protT5_model = T5ForConditionalGeneration.from_pretrained(PROT_T5_CHECKPOINT)
protT5_tokenizer = T5Tokenizer.from_pretrained(PROT_T5_CHECKPOINT)

In [None]:
# Select the compute device: CUDA when a GPU is visible, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Example hyperparameters
depth = 12  # tune to the desired model complexity and the available compute
num_tokens = protT5_tokenizer.vocab_size  # vocabulary size reported by the ProtT5 tokenizer

In [None]:
# Constructor arguments for ProtFlamingo, gathered in one mapping so the
# configuration can be inspected or logged before the model is built.
flamingo_kwargs = dict(
    num_tokens=num_tokens,
    depth=depth,
    dim_head=64,
    heads=8,
    ff_mult=4,
    cross_attn_every=2,
    perceiver_num_latents=64,
    perceiver_depth=2,
    only_attend_immediate_media=True,
)

# Build the model, then place it on the selected device.
model = ProtFlamingo(**flamingo_kwargs)
model = model.to(device)


In [None]:
def train_epoch_kl(model, data_loader, optimizer, device, verbose=False):
    """Run one optimisation pass over ``data_loader`` with a KL-divergence loss.

    Args:
        model: Module called as
            ``model(seq1_embeddings, seq2_embeddings, one_hot_scores)``; its
            output is log-softmaxed over the last dimension.
        data_loader: Iterable yielding ``(seq1_embeddings, seq2_embeddings,
            one_hot_scores, tokenized_seqs)`` batches. It does not need to
            implement ``__len__``.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        device: Device every batch tensor is moved to.
        verbose: When True, print the tensor shapes of every batch. Off by
            default so a full epoch does not flood the notebook output.

    Returns:
        The mean of the per-batch losses as a float, or ``nan`` when the
        loader yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for seq1_embeddings, seq2_embeddings, one_hot_scores, tokenized_seqs in data_loader:
        # Cast every tensor to float32 and move it to the target device.
        seq1_embeddings = seq1_embeddings.float().to(device)
        seq2_embeddings = seq2_embeddings.float().to(device)
        one_hot_scores = one_hot_scores.float().to(device)
        tokenized_seqs = tokenized_seqs.float().to(device)

        optimizer.zero_grad()

        model_output = model(seq1_embeddings, seq2_embeddings, one_hot_scores)
        # `log_softmax` / `kl_div` are the names imported from
        # torch.nn.functional by the import cell above; the `F` alias is not
        # imported in the cells visible here.
        log_probs = log_softmax(model_output, dim=-1)
        if verbose:
            print('model output shape:', model_output.shape)
            print('log probs shape:', log_probs.shape)
            print('tokenized_seqs shape:', tokenized_seqs.shape)

        # NOTE(review): the KL target here is `tokenized_seqs` cast to float.
        # `kl_div` expects a probability distribution shaped like `log_probs`
        # -- confirm the dataset really provides that rather than token ids.
        loss = kl_div(log_probs, tokenized_seqs, reduction='batchmean')
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Count batches instead of calling len(): works for loaders without
    # __len__ and avoids a ZeroDivisionError when nothing was yielded.
    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches

# Run a single training epoch and report the loss it returns.
train_loss = train_epoch_kl(
    model=model,
    data_loader=train_dataloader,
    optimizer=optimizer,
    device=device,
)
print(f"Training Epoch: Loss = {train_loss}")

def validate_epoch_kl(model, data_loader, device):
    """Evaluate ``model`` over ``data_loader`` with a KL-divergence loss.

    The model is put in eval mode and no gradients are recorded.

    Args:
        model: Module called as
            ``model(seq1_embeddings, seq2_embeddings, one_hot_scores)``; its
            output is log-softmaxed over the last dimension.
        data_loader: Iterable yielding ``(seq1_embeddings, seq2_embeddings,
            one_hot_scores, tokenized_seqs)`` batches. It does not need to
            implement ``__len__``.
        device: Device every batch tensor is moved to.

    Returns:
        The mean of the per-batch losses as a float, or ``nan`` when the
        loader yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for seq1_embeddings, seq2_embeddings, one_hot_scores, tokenized_seqs in data_loader:
            # Cast every tensor to float32 and move it to the target device.
            seq1_embeddings = seq1_embeddings.float().to(device)
            seq2_embeddings = seq2_embeddings.float().to(device)
            one_hot_scores = one_hot_scores.float().to(device)
            tokenized_seqs = tokenized_seqs.float().to(device)

            model_output = model(seq1_embeddings, seq2_embeddings, one_hot_scores)
            # `log_softmax` / `kl_div` are the names imported from
            # torch.nn.functional by the import cell above; the `F` alias is
            # not imported in the cells visible here.
            log_probs = log_softmax(model_output, dim=-1)

            # NOTE(review): the KL target is `tokenized_seqs` cast to float;
            # confirm it is a probability distribution shaped like `log_probs`.
            loss = kl_div(log_probs, tokenized_seqs, reduction='batchmean')
            total_loss += loss.item()
            num_batches += 1

    # Count batches instead of calling len(): works for loaders without
    # __len__ and avoids a ZeroDivisionError when nothing was yielded.
    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches

# Run a single pass over the validation loader and report the loss it returns.
val_loss = validate_epoch_kl(
    model=model,
    data_loader=val_dataloader,
    device=device,
)
print(f"Validation Epoch: Loss = {val_loss}")


def test_epoch_kl(model, data_loader, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for seq1_embeddings, seq2_embeddings, one_hot_scores, tokenized_seqs in data_loader:
            # seq1_embeddings, seq2_embeddings, one_hot_scores, tokenized_seqs = \
            #     # Convert all tensors to FloatTensor
            seq1_embeddings = seq1_embeddings.float().to(device)
            seq2_embeddings = seq2_embeddings.float().to(device)
            one_hot_scores = one_hot_scores.float().to(device)
            tokenized_seqs = tokenized_seqs.float().to(device)  # Convert tokenized_seqs to float

            model_output = model(seq1_embeddings, seq2_embeddings, one_hot_scores)
            log_probs = F.log_softmax(model_output, dim=-1)

            loss = F.kl_div(log_probs, tokenized_seqs, reduction='batchmean')
            total_loss += loss.item()

    return total_loss / len(data_loader)

# Run a single pass over the test loader and report the loss it returns.
test_loss = test_epoch_kl(
    model=model,
    data_loader=test_dataloader,
    device=device,
)
print(f"Test Epoch: Loss = {test_loss}")


Max length for embeddings: 1024
Max length for tokenized sequences: 2
Original sizes: torch.Size([1024]) torch.Size([1024]) torch.Size([173]) torch.Size([2])
Padded sizes: torch.Size([1024]) torch.Size([1024]) torch.Size([1024]) torch.Size([2])
Original sizes: torch.Size([1024]) torch.Size([1024]) torch.Size([301]) torch.Size([2])
Padded sizes: torch.Size([1024]) torch.Size([1024]) torch.Size([1024]) torch.Size([2])
Protein Embeddings Shape: torch.Size([2, 1024])
Motif Encodings Shape: torch.Size([2, 1024])
Target Sequence Shape: torch.Size([2, 1024])
Combined Input Shape: torch.Size([2, 2048])
Original Shape of x: torch.Size([4, 2, 2048])
Original Shape of latents: torch.Size([4, 2, 64, 1024])
Expanded Shape of x: torch.Size([4, 2, 1, 2048])
Transformed latents with linear layer: torch.Size([4, 2, 64, 2048])
Shape of concatenated kv_input: torch.Size([4, 2, 65, 2048])
Shape of query tensors: torch.Size([4, 8, 2, 64, 64])
Original Shape of x: torch.Size([4, 2, 2048])
Original Shape of 

RuntimeError: ignored

In [None]:
protT5_tokenizer