# ProtFlamingo with ESM-2 Embeddings (Pure PPI)

## Data + Preprocess

In [1]:
from google.colab import drive
drive.mount('/content/drive')
import os

Mounted at /content/drive


In [2]:
os.chdir('/content/drive/MyDrive/Programmable Biology Group/Srikar/Code/flamingo-pep-gen/ppi-mlo915')

In [3]:
!ls

ppigpt_test_merged_MI0915_LTPHTP_oct3_2023.csv	ppigpt_train_merged_MI0915_LTPHTP_oct3_2023.csv


In [4]:
import pandas as pd
train = pd.read_csv('ppigpt_train_merged_MI0915_LTPHTP_oct3_2023.csv',index_col=0)
test = pd.read_csv('ppigpt_test_merged_MI0915_LTPHTP_oct3_2023.csv',index_col=0)

train = train[['seq_1','seq_2']]
test = test[['seq_1','seq_2']]

In [5]:
test.head(5)

Unnamed: 0,seq_1,seq_2
1,MKKWSSTDLGAAADPLQKDTCPDPLDGDPNSRPPPAKPQLSTAKSR...,MPGARDALCHQALQLLAELCARGALEHDSCQDFIYHLRDRARPRLR...
7,MQAEIKADIIVEAMEVLVNHILYVRGIYPSHIFKMKRMYNSPIYVS...,MGSALENYVNQVRTLSASGSYRELAEELPESLSLLARNWSILDNVL...
31,MTYTTRQIGAKNTLEYKVYIEKDGKPVSAFHDIPLYADKENNIFNM...,MVNQGQPQPNLYDKHINMFPPARARESSHKLGNANSDRHGLPAQNI...
52,MTDETAHPTQSASKQESAALKQTGDDQQESQQQRGYTNYNNGSNYT...,MWNPILLDTSSFSFQKHVSGVFLQVRNATKRAAGSRTSMKDSAGRR...
57,MRSVTNAFGNSGELNDQVDETGYRKFDIHEGILFCIELSETMFKES...,MNENEYDNFDDLDDLLDEDPTKLDEAEPDDVQAKGSVYNDSENKEK...


## Helper Functions + Gated Cross Attn + Perceiver Resampler

In [6]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import torch.nn.functional as F
# from transformers import RobertaModel  # Assuming use of Hugging Face's transformer models



# Helper Functions
def exists(val):
    return val is not None

def set_module_requires_grad_(module, requires_grad):
    for param in module.parameters():
        param.requires_grad = requires_grad

def freeze_model_and_make_eval_(model):
    model.eval()
    set_module_requires_grad_(model, False)

# LayerNorm class
class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.gain = nn.Parameter(torch.ones(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.gain * (x - mean) / (std + self.eps)

# Residual class
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x

# SwiGLU activation function
class SwiGLU(nn.Module):
    def forward(self, x):
        return F.silu(x[..., :x.shape[-1] // 2]) * x[..., x.shape[-1] // 2:]

# Transformer Block class
class TransformerBlock(nn.Module):
    def __init__(self, dim, heads, mlp_dim):
        super().__init__()
        self.ln1 = LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads)
        self.ln2 = LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            SwiGLU(),
            nn.Linear(mlp_dim, dim)
        )
        self.residual = Residual(self.ln1)
        self.feedforward = Residual(self.ln2)

    def forward(self, x):
        x = self.residual(self.attn(x, x, x)[0])
        x = self.feedforward(self.mlp(x))
        return x

In [7]:
!pip install transformers



In [8]:
!pip install einops-exts

Collecting einops-exts
  Downloading einops_exts-0.0.4-py3-none-any.whl (3.9 kB)
Collecting einops>=0.4 (from einops-exts)
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/44.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops, einops-exts
Successfully installed einops-0.7.0 einops-exts-0.0.4


In [9]:
import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops_exts import rearrange_many, repeat_many

def exists(val):
    return val is not None

def FeedForward(dim, mult = 4):
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias = False),
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias = False)
    )

class PerceiverAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x, latents):
        """
        einstein notation
        b - batch
        t - time
        n - sequence
        d - dimension
        """
        x = self.norm_media(x)
        latents = self.norm_latents(latents)

        b, m, h = *x.shape[:2], self.heads

        q = self.to_q(latents)

        # the paper differs from Perceiver in which they also concat the key / values derived from the latents to be attended to
        kv_input = torch.cat((x, latents), dim = -2)
        k, v = self.to_kv(kv_input).chunk(2, dim = -1)

        q, k, v = rearrange_many((q, k, v), 'b t n (h d) -> b h t n d', h = h)

        q = q * self.scale

        # attention

        sim = einsum('... i d, ... j d  -> ... i j', q, k)

        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
        attn = sim.softmax(dim = -1)

        out = einsum('... i j, ... j d -> ... i d', attn, v)
        out = rearrange(out, 'b h t n d -> b t n (h d)', h = h)
        return self.to_out(out)

class PerceiverResampler(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_latents = 64,
        num_media_embeds = 4,
        ff_mult = 4
    ):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        self.media_pos_emb = nn.Parameter(torch.randn(num_media_embeds, 1, dim))


        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads),
                FeedForward(dim = dim, mult = ff_mult)
            ]))

        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        if x.ndim == 3:
            x = rearrange(x, 'b n d -> b 1 n d')
            print(x.shape)

        times = x.shape[1]
        print(self.media_pos_emb.shape)
        x = x + self.media_pos_emb[:times]
        # print(x.shape)
        # print(self.media_pos_emb[:times].shape)

        latents = repeat(self.latents, 'n d -> b m n d', b = x.shape[0], m = x.shape[1])

        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        return self.norm(latents)

# gated cross attention

class MaskedCrossAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        only_attend_immediate_media = True
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

        # whether for text to only attend to immediate preceding image, or all images

        self.only_attend_immediate_media = only_attend_immediate_media

    def forward(
        self,
        x,
        media,
        media_locations = None
    ):
        b, t, m = media.shape[:3]
        h = self.heads

        x = self.norm(x)

        q = self.to_q(x)
        media = rearrange(media, 'b t n d -> b (t n) d')

        k, v = self.to_kv(media).chunk(2, dim = -1)
        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h)

        q = q * self.scale

        sim = einsum('... i d, ... j d -> ... i j', q, k)

        if exists(media_locations):
            text_time = media_locations.cumsum(dim = -1) # at each boolean of True, increment the time counter (relative to media time)
            media_time = torch.arange(t, device = x.device) + 1

            # text time must equal media time if only attending to most immediate image
            # otherwise, as long as text time is greater than media time (if attending to all previous images / media)
            mask_op = torch.eq if self.only_attend_immediate_media else torch.ge

            text_to_media_mask = mask_op(rearrange(text_time, 'b i -> b 1 i 1'), repeat(media_time, 'j -> 1 1 1 (j m)', m = m))
            sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)

        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
        attn = sim.softmax(dim = -1)

        if exists(media_locations) and self.only_attend_immediate_media:
            # any text without a preceding media needs to have attention zeroed out
            text_without_media_mask = text_time == 0
            text_without_media_mask = rearrange(text_without_media_mask, 'b i -> b 1 i 1')
            attn = attn.masked_fill(text_without_media_mask, 0.)

        out = einsum('... i j, ... j d -> ... i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class GatedCrossAttentionBlock(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        only_attend_immediate_media = True
    ):
        super().__init__()
        self.attn = MaskedCrossAttention(dim = dim, dim_head = dim_head, heads = heads, only_attend_immediate_media = only_attend_immediate_media)
        self.attn_gate = nn.Parameter(torch.tensor([0.]))

        self.ff = FeedForward(dim, mult = ff_mult)
        self.ff_gate = nn.Parameter(torch.tensor([0.]))

    def forward(
        self,
        x,
        media,                  # media tensor, encoded by perceiver resample - (batch, time, latents, dim)
        media_locations = None  # boolean tensor indicating positions of media - (batch, sequence)
    ):
        x = self.attn(x, media, media_locations = media_locations) * self.attn_gate.tanh() + x
        x = self.ff(x) * self.ff_gate.tanh()  + x
        return x

## ProtFlamingo and Model Train

In [10]:
!pip install transformers



In [11]:
!pip install fair-esm

Collecting fair-esm
  Downloading fair_esm-2.0.0-py3-none-any.whl (93 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/93.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.1/93.1 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fair-esm
Successfully installed fair-esm-2.0.0


In [12]:
!pip install sentencepiece

Collecting sentencepiece
  Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.3 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.1/1.3 MB[0m [31m2.6 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m1.3/1.3 MB[0m [31m21.7 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m18.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: sentencepiece
Successfully installed sentencepiece-0.1.99


In [13]:
import numpy as np

In [14]:
train = train.sample(n=4, random_state=np.random.RandomState())
test = test.sample(n=4, random_state = np.random.RandomState())

In [15]:
import sentencepiece

In [16]:
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
from transformers import T5ForConditionalGeneration, T5Tokenizer
import esm
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer


# Load ProtT5 model
protT5_model = T5ForConditionalGeneration.from_pretrained("Rostlab/prot_t5_xl_bfd")
protT5_tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_bfd")

# Load ESM-2 model for embeddings
esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
esm_model.eval()
if torch.cuda.is_available():
    esm_model.cuda()

# Function to generate embeddings for a list of sequences
def generate_esm_embeddings(rbp_seqs):
    rbp_seqs_dict = {}
    for seq in tqdm(rbp_seqs):
        # Convert sequence to model input format
        batch_labels, batch_strs, batch_tokens = batch_converter([("", seq)])
        batch_tokens = batch_tokens.to('cuda' if torch.cuda.is_available() else 'cpu')
        with torch.no_grad():
            results = esm_model(batch_tokens, repr_layers=[33])
        token_representations = results["representations"][33]
        seq_len = (batch_tokens != alphabet.padding_idx).sum(1).item()
        rbp_seqs_dict[seq] = token_representations[0, 1:seq_len-1].cpu()
    return rbp_seqs_dict

# Custom Dataset
class ProteinInteractionDataset(Dataset):
    def __init__(self, dataframe, esm_embeddings):
        self.dataframe = dataframe
        self.esm_embeddings = esm_embeddings

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        prot1_seq, prot2_seq = self.dataframe.iloc[idx]['seq_1'], self.dataframe.iloc[idx]['seq_2']
        prot1_embedding = self.esm_embeddings[prot1_seq]
        prot2_tokenized = protT5_tokenizer.encode(prot2_seq, return_tensors="pt").squeeze()
        return prot1_embedding, prot2_tokenized


config.json:   0%|          | 0.00/457 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/11.3G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/24.0 [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/238k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/1.79k [00:00<?, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t33_650M_UR50D.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm2_t33_650M_UR50D-contact-regression.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D-contact-regression.pt


In [17]:
from torch.nn.functional import pad

def collate_fn(batch):
    # Find the longest sequence in the batch
    max_length = max(max(prot1_emb.size(0), prot2_tok.size(0)) for prot1_emb, prot2_tok in batch)

    prot1_embeddings_padded = []
    prot2_tokenized_padded = []

    for prot1_emb, prot2_tok in batch:
        # Pad each sequence to the max_length
        prot1_emb_padded = pad(prot1_emb, (0, 0, 0, max_length - prot1_emb.size(0)))
        prot2_tok_padded = pad(prot2_tok, (0, max_length - prot2_tok.size(0)))

        prot1_embeddings_padded.append(prot1_emb_padded)
        prot2_tokenized_padded.append(prot2_tok_padded)

    return torch.stack(prot1_embeddings_padded), torch.stack(prot2_tokenized_padded)

# Prepare  train and test data
train_df = train # train DataFrame
test_df = test  # test DataFrame

esm_embeddings = generate_esm_embeddings(pd.concat([train_df['seq_1'], test_df['seq_1']]).unique())
train_dataset = ProteinInteractionDataset(train_df, esm_embeddings)
test_dataset = ProteinInteractionDataset(test_df, esm_embeddings)

# Use collate_fn in DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=2, collate_fn=collate_fn)




100%|██████████| 8/8 [00:03<00:00,  2.09it/s]


In [18]:
!pip install einops



In [19]:

class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.norm(x)

class SwiGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * x

class ParallelTransformerBlock(nn.Module):
    def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
        super().__init__()
        self.norm = LayerNorm(dim)

        attn_inner_dim = dim_head * heads
        ff_inner_dim = dim * ff_mult

        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads)
        self.ff = nn.Sequential(
            nn.Linear(dim, 2* ff_mult * dim),
            SwiGLU(),
            nn.Linear(ff_mult * dim, dim)
        )

    def forward(self, x):
        print("Input to ParallelTransformerBlock:", x.shape)

        x = self.norm(x)
        print("After LayerNorm:", x.shape)

        x = x.permute(1, 0, 2)  # Rearrange for nn.MultiheadAttention
        print("After permute for MultiheadAttention:", x.shape)

        attn_output, _ = self.attn(x, x, x)
        print("After MultiheadAttention:", attn_output.shape)

        x = attn_output + x
        print("After adding attn_output:", x.shape)

        x = x.permute(1, 0, 2)  # Rearrange back
        print("After permute back:", x.shape)

        # ff_output = self.ff(x)
        # print("After FeedForward:", ff_output.shape)
        ff_output = x
        for layer in self.ff:
            if isinstance(layer, nn.Linear):
                print("Input to Linear Layer:", ff_output.shape)
                ff_output = layer(ff_output)
                print("Output from Linear Layer:", ff_output.shape)
            else:
                # Assuming SwiGLU or other non-linear layers don't change shape
                ff_output = layer(ff_output)

        output = ff_output + x
        print("Output from ParallelTransformerBlock:", output.shape)

        return output



In [20]:
T5ForConditionalGeneration.from_pretrained("Rostlab/prot_t5_xl_bfd").config.d_model

1024

In [21]:

class ProtFlamingo(nn.Module):
    def __init__(
        self,
        *,
        num_tokens,
        depth,
        dim_head=64,
        heads=8,
        ff_mult=4,
        cross_attn_every=3,
        perceiver_num_latents=64,
        perceiver_depth=2,
        only_attend_immediate_media=True,
        protein_mode=False, motif_mode=False
    ):
        super().__init__()

        # Add flags for protein and motif modes
        self.protein_mode = protein_mode
        self.motif_mode = motif_mode

        # Apply layer freezing based on the mode
        self._apply_layer_freezing()


        # ProtT5 model
        self.protT5 = T5ForConditionalGeneration.from_pretrained("Rostlab/prot_t5_xl_bfd")
        self.prot_dim = self.protT5.config.d_model
        self.dim = train_dataloader.dataset[0][0].shape[1]
        self.tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_bfd")


        # Check if adjustment is needed and add a linear layer for embedding size adjustment
        if self.protT5.config.d_model != 1280:
            self.embedding_adjustment = nn.Linear(self.protT5.config.d_model, 1280)
        else:
            self.embedding_adjustment = None

        # Perceiver Resampler for processing ESM-2 embeddings
        self.perceiver_resampler = PerceiverResampler(
            dim=self.dim,
            depth=perceiver_depth,
            dim_head=dim_head,
            heads=heads,
            num_latents=perceiver_num_latents
        )
        print('perceiver done...')
        print(self.dim)

        # Flamingo-like layers with Gated Cross Attention and ParallelTransformerBlock
        self.layers = nn.ModuleList([])
        for ind in range(depth):
            parallel_transformer_block = ParallelTransformerBlock(dim=self.dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult)
            self.layers.append(parallel_transformer_block)

            if ind % cross_attn_every == 0:
                gated_cross_attention_block = GatedCrossAttentionBlock(
                    dim=self.dim,
                    dim_head=dim_head,
                    heads=heads,
                    only_attend_immediate_media=only_attend_immediate_media
                )
                self.layers.append(gated_cross_attention_block)

        print('gated cross attn done..')


        # Output layer
        self.to_logits = nn.Linear(self.dim, num_tokens)

    def _apply_layer_freezing(self):
        # Check if in protein or motif mode
        if self.protein_mode or self.motif_mode:
            # In protein/motif mode, freeze everything but perceiver and gated cross attention
            self._freeze_all_layers()
            self._unfreeze_layers(self.perceiver_resampler)
            for _, layer in self.layers:
                if isinstance(layer, GatedCrossAttentionBlock):
                    self._unfreeze_layers(layer)
        else:
            # Unfreeze all layers in other modes
            self._unfreeze_layers(self)

    def _freeze_all_layers(self):
        for param in self.parameters():
            param.requires_grad = False

    def _unfreeze_layers(self, module):
        for param in module.parameters():
            param.requires_grad = True

    def forward(self, prot1_embeddings, generated_sequence):
        # Process prot1 embeddings through perceiver resampler
        processed_prot1_embeddings = self.perceiver_resampler(prot1_embeddings)
        print('processed_prot1 embeddings')

        # Prepare the generated sequence for input to the model
        #generated_sequence_tensor = torch.tensor([[generated_sequence]], dtype=torch.long).to(prot1_embeddings.device)
        generated_sequence_tensor = generated_sequence.to(prot1_embeddings.device)

        print('prepare gen seq')


        # Embed the generated sequence using the ProtT5 decoder's embedding layer
        sequence_emb = self.protT5.decoder.embed_tokens(generated_sequence_tensor) * (self.dim ** 0.5)
        # Apply the embedding size adjustment if the layer was initialized
        if self.embedding_adjustment is not None:
            sequence_emb = self.embedding_adjustment(sequence_emb)
        print('embed generated seq')

        # Process the sequence through the layers of the model
        for layer in self.layers:
            if isinstance(layer, ParallelTransformerBlock):
                print('parelleltransfomr')
                print(sequence_emb.shape)
                sequence_emb = layer(sequence_emb)
            elif isinstance(layer, GatedCrossAttentionBlock):
                print('gated cross attn')
                print(sequence_emb.shape)
                sequence_emb = layer(sequence_emb, processed_prot1_embeddings)
        print('process gen seq thru model layers')

        # Output the logits for the next token
        next_token_logits = self.to_logits(sequence_emb[:, -1, :])

        return next_token_logits


In [22]:
# def train_epoch(model, data_loader, optimizer, criterion, device):
#     model.train()
#     total_loss = 0

#     for input_protein_embedding, target_sequences in data_loader:
#         input_protein_embedding = input_protein_embedding.to(device).squeeze(1)  # Adjust dimensions if needed

#         batch_loss = 0

#         for target_sequence in target_sequences:  # Iterate over each sequence in the batch
#             optimizer.zero_grad()
#             loss = 0

#             # Start with a start token
#             generated_sequence = 3 ## token ids
#             print(generated_sequence)

#             for i in range(1, target_sequence.size(0)):  # Iterate over each token in the sequence
#                 target_token = target_sequence[i].item()  # Now this should be a single element

#                 # Predict the next token
#                 next_token_logits = model(input_protein_embedding, generated_sequence)

#                 # Compute loss for the current step
#                 loss += criterion(next_token_logits, torch.tensor([target_token], dtype=torch.long).to(device))

#                 # Update the generated sequence
#                 generated_sequence.append(target_token)

#             loss.backward()
#             optimizer.step()

#             batch_loss += loss.item() / target_sequence.size(0)  # Normalize by sequence length

#         total_loss += batch_loss / len(target_sequences)  # Normalize by batch size

#     return total_loss / len(data_loader)


In [23]:
def train_epoch(model, data_loader, optimizer, criterion, device):
    model.train()
    total_loss = 0

    for input_protein_embedding, target_sequences in data_loader:
        # Move input and target sequences to the device
        input_protein_embedding = input_protein_embedding.to(device).squeeze(1)
        target_sequences = target_sequences.to(device).squeeze(1)

        # Initialize generated_sequence for the entire batch with the start token
        start_token_id = 3  # Replace with your actual start token ID
        batch_size = input_protein_embedding.size(0)
        generated_sequence = torch.full((batch_size, 1), start_token_id, dtype=torch.long).to(device)  # Shape [batch_size, 1]

        optimizer.zero_grad()
        loss = 0

        for i in range(1, target_sequences.size(1)):  # Iterate over each token in the sequence
            # Predict the next token
            next_token_logits = model(input_protein_embedding, generated_sequence)

            # Flatten output for loss calculation
            next_token_logits = next_token_logits.view(-1, next_token_logits.size(-1))

            # Compute loss for the current step
            current_token = target_sequences[:, i]  # Shape [batch_size]
            loss += criterion(next_token_logits, current_token)

            # Update the generated sequence for the next iteration
            current_token = current_token.unsqueeze(1)  # Add sequence length dimension
            generated_sequence = torch.cat([generated_sequence, current_token], dim=1)

        loss.backward()
        optimizer.step()

        batch_loss = loss.item() / target_sequences.size(1)  # Normalize by sequence length
        total_loss += batch_loss

    return total_loss / len(data_loader)


In [24]:
# Instantiate model, optimizer, and other training components
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Example parameters
num_tokens = protT5_tokenizer.vocab_size
depth = 2  # Adjust based on model complexity and computational resources

# Instantiate the model
model = ProtFlamingo(
    num_tokens=num_tokens,
    depth=depth
).to(device)


optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

perceiver done...
1280
gated cross attn done..


In [25]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Hyperparameters
learning_rate = 1e-4
batch_size = 2
num_epochs = 10


In [26]:

optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

In [27]:
train_dataloader.dataset[0][0].shape[1]

1280

In [28]:


num_epochs = 1

# Training loop
for epoch in range(num_epochs):
    train_loss = train_epoch(model, train_dataloader,optimizer,criterion,device)
    print(f"Epoch {epoch}: Training Loss: {train_loss}")

# Save the model after training
torch.save(model.state_dict(), 'prot_flamingo_model.pth')

# Define inference function for generating protein sequences


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
torch.Size([2, 1, 523, 1280])
torch.Size([4, 1, 1280])
processed_prot1 embeddings
prepare gen seq
embed generated seq
parelleltransfomr
torch.Size([2, 62, 1280])
Input to ParallelTransformerBlock: torch.Size([2, 62, 1280])
After LayerNorm: torch.Size([2, 62, 1280])
After permute for MultiheadAttention: torch.Size([62, 2, 1280])
After MultiheadAttention: torch.Size([62, 2, 1280])
After adding attn_output: torch.Size([62, 2, 1280])
After permute back: torch.Size([2, 62, 1280])
Input to Linear Layer: torch.Size([2, 62, 1280])
Output from Linear Layer: torch.Size([2, 62, 10240])
Input to Linear Layer: torch.Size([2, 62, 5120])
Output from Linear Layer: torch.Size([2, 62, 1280])
Output from ParallelTransformerBlock: torch.Size([2, 62, 1280])
gated cross attn
torch.Size([2, 62, 1280])
parelleltransfomr
torch.Size([2, 62, 1280])
Input to ParallelTransformerBlock: torch.Size([2, 62, 1280])
After LayerNorm: torch.Size([2, 62, 1280

OutOfMemoryError: ignored

In [None]:
protT5_tokenizer