# Reload Data

### defining functions

In [None]:
from google.colab import drive
drive.mount('/content/drive')
import os

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
os.chdir('/content/drive/MyDrive/Programmable Biology Group/Srikar/Code/proteins/flamingo-ppi-gen/data_dump/per-residue-dataset/')

In [None]:
from transformers import T5Tokenizer, T5EncoderModel
import torch
import re
!pip install sentencepiece
import sentencepiece
import torch
from torch import nn
from transformers import T5ForConditionalGeneration, T5Tokenizer
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModel, AutoTokenizer



In [None]:
def one_hot_encode_energy_scores(scores):
    # Assuming 'scores' is a list of energy score values
    return [1 if score <= -1 else 0 for score in scores]

In [None]:
from torch.utils.data import Dataset
import pickle

In [None]:
# Load the protT5_tokens dictionary
with open('protT5_tokens.pkl', 'rb') as file:
    protT5_tokens = pickle.load(file)

In [None]:
len(protT5_tokens)

9665

In [None]:
len(protT5_tokens)

9665

In [None]:
# Load tokenizer and model
tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)
model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc")
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)
model = model.half() if device.type == 'cuda' else model.full()

from tqdm import tqdm

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [None]:
import torch
import re
import pickle
from torch.utils.data import Dataset
from torch.nn.functional import pad

class ProteinInteractionDataset(Dataset):
    def __init__(self, dataframe, protT5_embeddings, protT5_tokens):
        self.dataframe = dataframe
        self.protT5_embeddings = protT5_embeddings
        self.protT5_tokens = protT5_tokens

        # Determine the maximum lengths
        self.max_length_embeddings = max(max(len(self.protT5_embeddings[seq1]), len(self.protT5_embeddings[seq2]))
                                         for seq1, seq2 in zip(dataframe['peptide_derived_sequence'], dataframe['protein_derived_sequence']))
        self.max_length_tokenized = max(len(self.protT5_tokens[seq]) for seq in dataframe['peptide_derived_sequence'])
        self.max_length_scores = max(len(re.findall(r'-?\d+\.?\d*(?:e[-+]?\d+)?', scores)) for scores in dataframe['energy_scores'])

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        peptide_seq = self.dataframe.iloc[idx]['peptide_derived_sequence']
        protein_seq = self.dataframe.iloc[idx]['protein_derived_sequence']
        energy_scores = self.dataframe.iloc[idx]['energy_scores']

        # Process the energy_scores
        energy_scores = re.findall(r'-?\d+\.?\d*(?:e[-+]?\d+)?', energy_scores)
        energy_scores = [float(score) for score in energy_scores]
        energy_scores = one_hot_encode_energy_scores(energy_scores)

        # Pad the energy scores -- max length of all should be equal
        energy_scores_padded = pad(torch.tensor(energy_scores), (0, self.max_length_tokenized - len(energy_scores)), "constant", 0)

        peptide_embedding = self.protT5_embeddings[peptide_seq]
        protein_embedding = self.protT5_embeddings[protein_seq]
        tokenized_peptide_seq = self.protT5_tokens[peptide_seq]

        # print('max_length_embeddings:', self.max_length_embeddings)
        # print('max_length_tokenized:', self.max_length_tokenized)
        # print('max_length_scores:', self.max_length_scores)

        # Pad the sequences
        peptide_embedding_padded = pad(peptide_embedding, (0, 0, 0, self.max_length_embeddings - len(peptide_embedding)), "constant", 0)
        protein_embedding_padded = pad(protein_embedding, (0, 0, 0, self.max_length_embeddings - len(protein_embedding)), "constant", 0)
        tokenized_peptide_seq_padded = pad(torch.tensor(tokenized_peptide_seq, dtype=torch.float), (0, self.max_length_tokenized - len(tokenized_peptide_seq)), "constant", 0)


        # print(peptide_embedding_padded.shape,protein_embedding_padded.shape,energy_scores_padded.shape,tokenized_peptide_seq_padded.shape)
        return peptide_embedding_padded, protein_embedding_padded, energy_scores_padded, tokenized_peptide_seq_padded

# # Usage
# protein_interaction_dataset = ProteinInteractionDataset(dataframe, protT5_embeddings, protT5_tokens)
# protein_interaction_dataloader = DataLoader(protein_interaction_dataset, batch_size=your_batch_size)


In [None]:
with open('protT5_embeddings.pkl', 'rb') as file:
    protT5_embeddings = pickle.load(file)

In [None]:
len(protT5_embeddings)

9665

## preprocessing SnP PPI data

In [None]:
os.chdir('/content/drive/MyDrive/Programmable Biology Group/Srikar/Code/proteins/flamingo-ppi-gen/data_dump/per-residue-dataset/')

In [None]:
!ls

protT5_embeddings.pkl  testing_dataset.csv   validation_dataset.csv
protT5_tokens.pkl      training_dataset.csv


In [None]:
import pandas as pd

In [None]:
test_snp = pd.read_csv('testing_dataset.csv')
train_snp = pd.read_csv('training_dataset.csv')
val_snp = pd.read_csv('validation_dataset.csv')

In [None]:
import pandas as pd
import re

def preprocess_snp_data(file_path):
    # Read the dataset
    snp_df = pd.read_csv(file_path)

    # Function to transform energy scores
    def transform_energy_scores(energy_scores):
        transformed_scores = []
        for score in energy_scores:
            # Replace sequences of spaces/newlines with a comma
            score = re.sub(r'[\s\n]+', ',', score)
            # Remove a comma after an opening square bracket
            score = re.sub(r'\[\s*,', '[', score)
            # Remove leading commas/whitespace
            score = re.sub(r'^[\s,]+', '', score)
            transformed_scores.append(score)
        return transformed_scores

    # Apply transformations
    snp_df['energy_scores'] = transform_energy_scores(snp_df['energy_scores'])
    snp_df['energy_scores_lengths'] = snp_df['energy_scores'].apply(
        lambda x: x.count(',') + 1 - (1 if x.startswith(',') else 0)
    )

    # Calculate lengths for other columns
    snp_df['peptide_source_RCSB_lengths'] = snp_df['peptide_source_RCSB'].apply(len)
    snp_df['protein_RCSB_lengths'] = snp_df['protein_RCSB'].apply(len)
    snp_df['protein_derived_seq_length'] = snp_df['protein_derived_sequence'].apply(len)
    snp_df['peptide_derived_seq_length'] = snp_df['peptide_derived_sequence'].apply(len)

    # Calculate matching lengths count (optional, depending on your needs)
    snp_df['matching_lengths_count'] = (snp_df['energy_scores_lengths'] == snp_df['peptide_derived_seq_length']).sum()

    return snp_df

# Applying the preprocessing pipeline to each dataset
test_snp = preprocess_snp_data('testing_dataset.csv')
train_snp = preprocess_snp_data('training_dataset.csv')
val_snp = preprocess_snp_data('validation_dataset.csv')


In [None]:
train_snp['protein_derived_sequence'][0]

'VWLANPERYGQMQYRYCGKSGLRLPALSLGLWHNFGHVNALESQRAILRKAFDLGITHFDLANNYGPPPGSAEENFGRLLREDFAAYRDELIISTKAGYDMWPGPYGSGGSRKYLLASLDQSLKRMGLEYVDIFYSHRVDENTPMEETASALAHAVQSGKALYVGISSYSPERTQKMVELLREWKIPLLIHQPSYNLLNRWVDKSGLLDTLQNNGVGCIAFTPLAQGLLTGKYLTEANLNSLRLLNEMAQQRGQSMAQMALSWLLKDDRVTSVLIGASRAEQLEENVQALNNLTFSTKELAQIDQHIADGELN'

## create the *datasets*

In [None]:
len(protT5_embeddings)

9665

In [None]:
# Create datasets with tokenizer
train_dataset = ProteinInteractionDataset(train_snp, protT5_embeddings,protT5_tokens)
test_dataset = ProteinInteractionDataset(test_snp, protT5_embeddings,protT5_tokens)
val_dataset = ProteinInteractionDataset(val_snp, protT5_embeddings,protT5_tokens)


In [None]:
os.chdir('/content/drive/MyDrive/Programmable Biology Group/Srikar/Code/proteins/flamingo-ppi-gen/data_dump/flamingo-26-data/')

In [None]:
from torch.utils.data import DataLoader

train_batch_size = 2
test_batch_size = 2
val_batch_size = 2

# Create the DataLoaders
train_dataloader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=test_batch_size)
val_dataloader = DataLoader(val_dataset, batch_size=val_batch_size)


In [None]:
len(train_dataloader),len(test_dataloader),len(val_dataloader)

(2500, 1000, 1000)

# Motif-guided ProtFlamingo

## Helper Functions + Gated Cross Attn + Perceiver Resampler

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import torch.nn.functional as F
# from transformers import RobertaModel  # Assuming use of Hugging Face's transformer models

# Helper Functions
def exists(val):
    return val is not None

def set_module_requires_grad_(module, requires_grad):
    for param in module.parameters():
        param.requires_grad = requires_grad

def freeze_model_and_make_eval_(model):
    model.eval()
    set_module_requires_grad_(model, False)

# LayerNorm class
class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.gain = nn.Parameter(torch.ones(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.gain * (x - mean) / (std + self.eps)

# Residual class
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x

# SwiGLU activation function
class SwiGLU(nn.Module):
    def forward(self, x):
        return F.silu(x[..., :x.shape[-1] // 2]) * x[..., x.shape[-1] // 2:]

# Transformer Block class
class TransformerBlock(nn.Module):
    def __init__(self, dim, heads, mlp_dim):
        super().__init__()
        self.ln1 = LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads)
        self.ln2 = LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            SwiGLU(),
            nn.Linear(mlp_dim // 2, dim)
        )
        self.residual = Residual(self.ln1)
        self.feedforward = Residual(self.ln2)
        self.expand_dim = nn.Linear(dim, 2 * dim)  # Project to a higher dimension

    def forward(self, x):
        if x.dim() < 3: ### do the 1,1,1024 transformation
            # Apply the expansion transformation if x has less than 3 dimensions
            x_expanded = self.expand_dim(x)  # Now [2, 2*desired_dim]
            x_expanded = x_expanded.view(1, 1, 1024)  # Reshape to [1, 1, 1024]
            # x_expanded = nn.LayerNorm(x)
            print('x transformed shape in gated cross attn:', x_expanded.shape)
            x = self.residual(self.attn(x_expanded, x_expanded, x_expanded)[0])
        else:
            x = self.residual(self.attn(x, x, x)[0])
        print("Shape after attention and residual:", x.shape)  # Debug print
        x = self.feedforward(self.mlp(x))
        print("Shape after feedforward:", x.shape)  # Debug print
        return x


In [None]:
!pip install transformers



In [None]:
!pip install einops-exts



In [None]:
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:4096'

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops_exts import rearrange_many, repeat_many

def exists(val):
    return val is not None

def FeedForward(dim, mult = 4):
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias = False),
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias = False)
    )

class PerceiverAttention(nn.Module):
    def __init__(self, *, dim, concatenated_dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)


    def forward(self, x, latents):
        x = self.norm_media(x)
        latents = self.norm_latents(latents)

        # print('x shape perciever attn:', x.shape)
        # print('latents shape perceiver attn', latents.shape)

        q = self.to_q(latents)
        # print('q shape:',q.shape)

        q = rearrange(q, 'b n (h d) -> b h n d', h=self.heads)
        q = q * self.scale


        kv_input = torch.cat((x, latents), dim=1)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        # print('k shape:',k.shape)
        # print('v shape:',v.shape)
        k = rearrange(k, 'b n (h d) -> b h n d', h=self.heads)
        v = rearrange(v, 'b n (h d) -> b h n d', h=self.heads)

        # print('rearrangement in perceiver cross attn complete...')
        # print('q shape:',q.shape)
        # print('k shape:',k.shape)
        # print('v shape:',v.shape)

        sim = einsum('... i d, ... j d -> ... i j', q, k)
        attn = sim.softmax(dim=-1)
        out = einsum('... i j, ... j d -> ... i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')

        return self.to_out(out)

class PerceiverResampler(nn.Module):
    def __init__(self, *, dim, depth, dim_head=64, heads=8, num_latents=64, concatenated_dim=2048):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PerceiverAttention(dim=dim, concatenated_dim=concatenated_dim, dim_head=dim_head, heads=heads),
                FeedForward(dim=dim)
            ]))

    def forward(self, x):
        latents = repeat(self.latents, 'n d -> b n d', b=x.shape[0])

        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        return latents

class MaskedCrossAttention(nn.Module):
    def __init__(self, *, dim, concatenated_dim=2048, dim_head=64, heads=8, only_attend_immediate_media=True):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)
        self.only_attend_immediate_media = only_attend_immediate_media

    def forward(self, x, media, media_locations=None):
        b, t, _ = x.shape
        _, m, _ = media.shape
        h = self.heads

        x = self.norm(x)
        q = self.to_q(x)
        q = rearrange(q, 'b n (h d) -> b h n d', h=h)

        # No need to reshape media as it's already 3D
        k, v = self.to_kv(media).chunk(2, dim=-1)
        k = rearrange(k, 'b n (h d) -> b h n d', h=h)
        v = rearrange(v, 'b n (h d) -> b h n d', h=h)

        q = q * self.scale
        sim = einsum('... i d, ... j d -> ... i j', q, k)

        if media_locations is not None:
            mask = media_locations.unsqueeze(1).unsqueeze(2)
            mask = rearrange(mask, 'b n -> b 1 n 1')
            sim = sim.masked_fill(mask == 0, float('-inf'))

        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)
        out = einsum('... i j, ... j d -> ... i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)', h=self.heads)

        return self.to_out(out)


class GatedCrossAttentionBlock(nn.Module):
    def __init__(self, *, dim, dim_head=64, heads=8, ff_mult=4, only_attend_immediate_media=True):
        super().__init__()
        self.attn = MaskedCrossAttention(dim=dim, concatenated_dim=2048, dim_head=dim_head, heads=heads, only_attend_immediate_media=only_attend_immediate_media)
        self.attn_gate = nn.Parameter(torch.tensor([0.]))
        self.ff = FeedForward(dim, mult=ff_mult)
        self.ff_gate = nn.Parameter(torch.tensor([0.]))

    def forward(self, x, media, media_locations=None):
        gate = self.attn_gate.tanh()
        x = self.attn(x, media, media_locations=media_locations) * gate + x
        x = self.ff(x) * self.ff_gate.tanh() + x
        return x



## ProtFlamingo

In [None]:

class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.norm(x)

class SwiGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * x

class ParallelTransformerBlock(nn.Module):
    def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
        super().__init__()
        self.norm = LayerNorm(dim)

        attn_inner_dim = dim_head * heads
        ff_inner_dim = dim * ff_mult

        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads)
        self.ff = nn.Sequential(
            nn.Linear(dim, 2* ff_mult * dim),
            SwiGLU(),
            nn.Linear(ff_mult * dim, dim)
        )

    def forward(self, x):
        print("Input to ParallelTransformerBlock:", x.shape)

        x = self.norm(x)
        # print("After LayerNorm:", x.shape)

        x = x.permute(1, 0, 2)  # Rearrange for nn.MultiheadAttention
        # print("After permute for MultiheadAttention:", x.shape)

        attn_output, _ = self.attn(x, x, x)
        # print("After MultiheadAttention:", attn_output.shape)

        x = attn_output + x
        # print("After adding attn_output:", x.shape)

        x = x.permute(1, 0, 2)  # Rearrange back
        # print("After permute back:", x.shape)

        # ff_output = self.ff(x)
        # print("After FeedForward:", ff_output.shape)
        ff_output = x
        for layer in self.ff:
            if isinstance(layer, nn.Linear):
                # print("Input to Linear Layer:", ff_output.shape)
                ff_output = layer(ff_output)
                # print("Output from Linear Layer:", ff_output.shape)
            else:
                # Assuming SwiGLU or other non-linear layers don't change shape
                ff_output = layer(ff_output)

        output = ff_output + x
        print("Output from ParallelTransformerBlock:", output.shape)

        return output



In [None]:
from transformers import T5Tokenizer, T5EncoderModel
import torch
import re

In [None]:
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
from transformers import T5ForConditionalGeneration, T5Tokenizer
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer


In [None]:
# T5ForConditionalGeneration.from_pretrained("Rostlab/prot_t5_xl_bfd").config.d_model

In [None]:
# # Load ProtT5 model
# protT5_model = T5ForConditionalGeneration.from_pretrained("Rostlab/prot_t5_xl_bfd")
# protT5_tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_bfd")

In [None]:
import torch
from torch import nn
from transformers import T5ForConditionalGeneration, T5Tokenizer

class ProtFlamingo(nn.Module):
    def __init__(self, num_tokens, depth, dim_head=64, heads=8, ff_mult=4, cross_attn_every=3, perceiver_num_latents=64, perceiver_depth=2, motif_mode=False):
        super().__init__()
        self.motif_embedding_projection = nn.Embedding(2, 1024) # Assuming binary one-hot encoding, projecting to 1024 dimensions

        self.dim = 1024  # Assuming the embedding dimension
        self.to_logits = nn.Linear(self.dim, num_tokens)
        self.protT5_model = T5ForConditionalGeneration.from_pretrained("Rostlab/prot_t5_xl_bfd")
        self.protT5_tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_bfd")

        self.perceiver_resampler = PerceiverResampler(dim=self.dim, depth=perceiver_depth, dim_head=dim_head, heads=heads, num_latents=perceiver_num_latents)
        self.expand_seq_len = nn.Linear(dim_head, 983)

        self.layers = nn.ModuleList([])
        for i in range(depth):
            # Note that we're no longer passing 'mlp_dim'
            self.layers.append(ParallelTransformerBlock(dim=self.dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult))
            if i % cross_attn_every == 0:
                self.layers.append(GatedCrossAttentionBlock(dim=self.dim, dim_head=dim_head, heads=heads))


    def forward(self, protein_embeddings, motif_encodings):

        # Interleave protein and motif embeddings
        combined_embeddings,motif_embeddings = self.interleave_embeddings(protein_embeddings, motif_encodings)
        print('combined embeddings...')

        # Process combined embeddings through the perceiver resampler
        processed_protein_embeddings = self.perceiver_resampler(combined_embeddings)
        print('processed embeddings...')
        print('perceiver resampler output embeddings shape:', processed_protein_embeddings.shape)

        for index, layer in enumerate(self.layers):
            print('index:', index)
            if isinstance(layer, GatedCrossAttentionBlock):
                # Pass motif_encodings as media
                print('output perceiver resampler shape:',processed_protein_embeddings.shape)
                print("input projected motif embedding shape:",motif_embeddings.shape)
                target_sequence = layer(binder_embeddings, processed_protein_embeddings) # cross attn between binder seq (text) and motif+target (image)
                print(f'layer {index} done w gated attn...')
                #print('target sequence:',target_sequence)
                #print('target sequence shape:',target_sequence.shape)
            else:
                target_sequence = layer()
                print(f'layer {index} done w/out gated attn...')
                #print('target sequence:',target_sequence)
                #print('target sequence shape:',target_sequence.shape)


        # Reshape to merge batch and embedding dimensions
        batch_size, seq_length, dim = target_sequence.shape
        target_sequence_reshaped = target_sequence.view(batch_size * dim, seq_length)
        # Apply linear transformation to expand sequence length
        expanded_sequence = self.expand_seq_len(target_sequence_reshaped)
        # Reshape back to separate batch and embedding dimensions
        expanded_sequence = expanded_sequence.view(batch_size, dim, 983).transpose(1, 2)
        print('after last linear transform layer shape:',expanded_sequence.shape)

        # Get the logits from the decoder output
        logits = self.protT5_model.lm_head(expanded_sequence)
        print('logits before argmax:',logits.shape)
        predicted_token_ids = logits.argmax(-1)  # Convert logits to token IDs, resulting in shape [2, 64]
        print('lm head decoding done...')
        print('predidcted token ids:',predicted_token_ids)
        print('predicted token ids shape:',predicted_token_ids.shape)

        return predicted_token_ids

    def interleave_embeddings(self, protein_embeddings, motif_one_hot):
        # Map one-hot encoding to embedding space
        motif_embeddings = self.motif_embedding_projection(motif_one_hot.long()) # Ensure it's long type for indexing
        print('motif_embedding shape projection:',motif_embeddings.shape)
        # Interleave embeddings
        combined_embeddings = torch.zeros(protein_embeddings.size(0), protein_embeddings.size(1) * 2, protein_embeddings.size(2), device=protein_embeddings.device)
        combined_embeddings[:, ::2] = protein_embeddings
        combined_embeddings[:, 1::2] = motif_embeddings
        print('combined_embeddings shape interleaved:',combined_embeddings.shape)
        return combined_embeddings,motif_embeddings



## Train Model

In [None]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    seq1_embeddings, seq2_embeddings, one_hot_scores, tokenized_seqs = zip(*batch)

    seq1_embeddings = pad_sequence(seq1_embeddings, batch_first=True)
    seq2_embeddings = pad_sequence(seq2_embeddings, batch_first=True)
    one_hot_scores = pad_sequence(one_hot_scores, batch_first=True)
    tokenized_seqs = pad_sequence(tokenized_seqs, batch_first=True, padding_value=tokenizer.pad_token_id)

    return seq1_embeddings, seq2_embeddings, one_hot_scores, tokenized_seqs


In [None]:
import torch
import torch.nn as nn
from torch.nn.functional import kl_div, log_softmax


In [None]:
import torch

# Assuming 'model', 'train_dataloader', 'val_dataloader', 'test_dataloader', and 'criterion' are already defined
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model = model.half() if device.type == 'cuda' else model.full()


optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)



In [None]:
train_dataloader.dataset

<__main__.ProteinInteractionDataset at 0x79df26776c80>

In [None]:
# Instantiate model, optimizer, and other training components
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Example parameters
num_tokens = 28 # protT5 vocab size
depth = 3  # Adjust based on model complexity and computational resources

In [None]:
model = ProtFlamingo(
    num_tokens=num_tokens,
    depth=depth,
    dim_head=64,
    heads=8,
    ff_mult=4,
    cross_attn_every=2,
    perceiver_num_latents=64,
    perceiver_depth=2
).to(device)


In [None]:
import torch.nn as nn

def train_epoch_ce(model, data_loader, optimizer, device):
    model.train()  # Ensure the model is in training mode
    total_loss = 0
    loss_function = nn.CrossEntropyLoss()

    for seq1_embeddings, seq2_embeddings, one_hot_scores, tokenized_seqs in data_loader:
        seq1_embeddings = seq1_embeddings.to(device).float()
        seq2_embeddings = seq2_embeddings.to(device).float()
        one_hot_scores = one_hot_scores.to(device).float()
        tokenized_seqs = tokenized_seqs.to(device).long()  # Convert tokenized_seqs to long

        optimizer.zero_grad()

        model_output = model(seq1_embeddings, one_hot_scores).to(device)  # Ensure model_output is on the correct device
        print('Model output shape:', model_output.shape)
        print('Tokenized seqs (target) shape:', tokenized_seqs.shape)

        # Calculate loss using CrossEntropyLoss
        loss = loss_function(model_output.view(-1, model_output.size(-1)), tokenized_seqs.view(-1))
        loss.backward()  # Compute gradients
        optimizer.step()  # Update parameters

        total_loss += loss.item()
        print('Current loss:', loss.item())

    return total_loss / len(data_loader)

# Train for one epoch
train_loss = train_epoch_ce(model, train_dataloader, optimizer, device)
print(f"Training Epoch: Loss = {train_loss}")


motif_embedding shape projection: torch.Size([2, 983, 1024])
combined_embeddings shape interleaved: torch.Size([2, 1966, 1024])
combined embeddings...
processed embeddings...
perceiver resampler output embeddings shape: torch.Size([2, 64, 1024])
index: 0
Input to ParallelTransformerBlock: torch.Size([2, 64, 1024])
Output from ParallelTransformerBlock: torch.Size([2, 64, 1024])
layer 0 done w/out gated attn...
index: 1
output perceiver resampler shape: torch.Size([2, 64, 1024])
input projected motif embedding shape: torch.Size([2, 983, 1024])
layer 1 done w gated attn...
index: 2
Input to ParallelTransformerBlock: torch.Size([2, 64, 1024])
Output from ParallelTransformerBlock: torch.Size([2, 64, 1024])
layer 2 done w/out gated attn...
index: 3
Input to ParallelTransformerBlock: torch.Size([2, 64, 1024])
Output from ParallelTransformerBlock: torch.Size([2, 64, 1024])
layer 3 done w/out gated attn...
index: 4
output perceiver resampler shape: torch.Size([2, 64, 1024])
input projected moti

RuntimeError: ignored

In [None]:
import torch.nn as nn

def train_epoch_ce(model, data_loader, optimizer, device):
    model.train()
    total_loss = 0
    loss_function = nn.CrossEntropyLoss()

    for seq1_embeddings, seq2_embeddings, one_hot_scores, tokenized_seqs in data_loader:
        seq1_embeddings = seq1_embeddings.float().to(device)
        seq2_embeddings = seq2_embeddings.float().to(device)
        one_hot_scores = one_hot_scores.float().to(device)
        tokenized_seqs = tokenized_seqs.float().to(device)  # Ensure tokenized_seqs are LongTensors

        optimizer.zero_grad()

        model_output = model(seq1_embeddings, one_hot_scores).float().to(device)
        print('tokenized seqs:', tokenized_seqs)
        print('tokenized seqs (target) shape:', tokenized_seqs.shape)
        print('model output shape:', model_output.shape)

        #CrossEntropyLoss
        loss = loss_function(model_output, tokenized_seqs)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        print('Current loss:', loss.item())

    return total_loss / len(data_loader)

# Train for one epoch
train_loss = train_epoch_ce(model, train_dataloader, optimizer, device)
print(f"Training Epoch: Loss = {train_loss}")


motif_embedding shape projection: torch.Size([2, 983, 1024])
combined_embeddings shape interleaved: torch.Size([2, 1966, 1024])
combined embeddings...
processed embeddings...
perceiver resampler output embeddings shape: torch.Size([2, 64, 1024])
index: 0
Input to ParallelTransformerBlock: torch.Size([2, 64, 1024])
Output from ParallelTransformerBlock: torch.Size([2, 64, 1024])
layer 0 done w/out gated attn...
index: 1
output perceiver resampler shape: torch.Size([2, 64, 1024])
input projected motif embedding shape: torch.Size([2, 983, 1024])
layer 1 done w gated attn...
index: 2
Input to ParallelTransformerBlock: torch.Size([2, 64, 1024])
Output from ParallelTransformerBlock: torch.Size([2, 64, 1024])
layer 2 done w/out gated attn...
index: 3
Input to ParallelTransformerBlock: torch.Size([2, 64, 1024])
Output from ParallelTransformerBlock: torch.Size([2, 64, 1024])
layer 3 done w/out gated attn...
index: 4
output perceiver resampler shape: torch.Size([2, 64, 1024])
input projected moti

RuntimeError: ignored

In [None]:
# Train for one epoch
train_loss = train_epoch_ce(model, train_dataloader, optimizer, device)
print(f"Training Epoch: Loss = {train_loss}")

In [None]:
protT5_tokenizer