<a href="https://colab.research.google.com/github/archiebenn/BIOLM0050_kaggle/blob/master/protein_embedding_protT5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -q torch transformers sentencepiece h5py

In [2]:
import h5py
import numpy as np
import pandas as pd
import os
import torch

from transformers import T5EncoderModel, T5Tokenizer

## workflow aim:
protein sequence → truncate long sequences, keeping the N- and C-terminal ends for localisation info → ProtT5 embedding → single vector → classifier (other script)


## Load ProtT5

In [4]:
# Hugging Face checkpoint shared by the tokeniser and the encoder model.
PROT_T5_CHECKPOINT = "Rostlab/prot_t5_xl_half_uniref50-enc"

# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Tokeniser: keep the input case as-is (amino-acid letters are upper case).
tokeniser = T5Tokenizer.from_pretrained(PROT_T5_CHECKPOINT, do_lower_case=False)

# Encoder-only T5 model, moved onto the selected device.
model = T5EncoderModel.from_pretrained(PROT_T5_CHECKPOINT)
model = model.to(device)

# Inference mode (disables dropout); as the last expression it also displays the model.
model.eval()

tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/238k [00:00<?, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/656 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/2.42G [00:00<?, ?B/s]

Loading weights:   0%|          | 0/196 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/2.42G [00:00<?, ?B/s]



T5EncoderModel(
  (shared): Embedding(128, 1024)
  (encoder): T5Stack(
    (embed_tokens): Embedding(128, 1024)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=1024, out_features=4096, bias=False)
              (k): Linear(in_features=1024, out_features=4096, bias=False)
              (v): Linear(in_features=1024, out_features=4096, bias=False)
              (o): Linear(in_features=4096, out_features=1024, bias=False)
              (relative_attention_bias): Embedding(32, 32)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=1024, out_features=16384, bias=False)
              (wo): Linear(in_features=16384, out_features=1024, bias=False)
              (dropout): Dropo

In [5]:
# Upload the input CSV from the local machine via google.colab's upload helper.
# The data-setup cell reads "test_trimmed.csv", so the uploaded file must have that name.
# NOTE(review): `uploaded` is not referenced by the other cells visible in this notebook.
from google.colab import files
uploaded = files.upload()


Saving test_trimmed.csv to test_trimmed.csv


In [6]:
# Load the uploaded CSV into a DataFrame.
# NOTE: the variable is called df_train, but the file being read is "test_trimmed.csv".
df_train = pd.read_csv("test_trimmed.csv")

# Report the row count, then preview the first rows.
print(f'Train DF length: {df_train.shape[0]}')
df_train.head()


Train DF length: 4377


Unnamed: 0.1,Unnamed: 0,Id,acc,partition,sequence,length,molecular_weight,pI,gravy,aromaticity,...,aa_frac_M,aa_frac_N,aa_frac_P,aa_frac_Q,aa_frac_R,aa_frac_S,aa_frac_T,aa_frac_V,aa_frac_W,aa_frac_Y
0,0,5,P08516,4,MSVSALSSTRFTGSISGFLQVASVLGLLLLLVKAVQFYLQRQWLLK...,509,58214.494,9.064,-0.189,0.114,...,0.026,0.037,0.059,0.053,0.051,0.077,0.039,0.059,0.02,0.033
1,1,9,Q23544,4,MSVRRRTHSDDFSYLLEKTRRPSKLNVVQEDPKSAPPQGYSLTTVI...,508,56978.625,6.777,-0.309,0.059,...,0.033,0.043,0.059,0.03,0.071,0.077,0.057,0.075,0.004,0.033
2,2,14,Q96HR9,4,MDGLRQRVEHFLEQRNLVTEVLGALEAKTGVEKRYLAAGAVTLLSL...,211,23418.019,8.744,0.244,0.118,...,0.019,0.024,0.057,0.028,0.062,0.052,0.043,0.081,0.019,0.047
3,3,15,Q9JM62,4,MDGLRQRFERFLEQKNVATEALGALEARTGVEKRYLAAGALALLGL...,201,22203.619,6.824,0.375,0.134,...,0.015,0.015,0.05,0.02,0.055,0.05,0.04,0.06,0.02,0.05
4,4,17,Q8WUH2,4,MMSIKAFTLVSAVERELLMGDKERVNIECVECCGRDLYVGTNDCFV...,860,97156.698,6.103,-0.116,0.086,...,0.019,0.033,0.043,0.056,0.052,0.056,0.044,0.074,0.008,0.036


In [7]:
# Only the protein Id and its amino-acid sequence are needed for embedding.
df_seq_id = df_train.loc[:, ["Id", "sequence"]]
df_seq_id.head()

Unnamed: 0,Id,sequence
0,5,MSVSALSSTRFTGSISGFLQVASVLGLLLLLVKAVQFYLQRQWLLK...
1,9,MSVRRRTHSDDFSYLLEKTRRPSKLNVVQEDPKSAPPQGYSLTTVI...
2,14,MDGLRQRVEHFLEQRNLVTEVLGALEAKTGVEKRYLAAGAVTLLSL...
3,15,MDGLRQRFERFLEQKNVATEALGALEARTGVEKRYLAAGALALLGL...
4,17,MMSIKAFTLVSAVERELLMGDKERVNIECVECCGRDLYVGTNDCFV...


In [8]:
# Find the longest sequences in the loaded CSV.
# Length of each sequence in characters (one character per residue).
df_train["sequence_lengths"] = df_train["sequence"].astype(str).str.len()

# Longest sequences first, in a separate frame.
df_sorted = df_train.sort_values(by="sequence_lengths", ascending=False)

# Show the 25 largest lengths.
df_sorted["sequence_lengths"].iloc[:25]

Unnamed: 0,sequence_lengths
3008,6620
1207,5654
3094,5560
1103,5154
952,4749
2555,4743
1301,4684
951,4678
1979,4563
738,4377


## embedding protein function

In [9]:
def embed_protein(seq):
    """Embed one protein sequence with ProtT5 and mean-pool it into a single vector.

    Uses the notebook-level ``tokeniser``, ``model`` and ``device``.

    Args:
        seq: Amino-acid sequence as a string of one-letter codes (no separators).

    Returns:
        1-D float32 numpy array: the mean of the encoder's last hidden state
        over the residue positions. Returns None when ``seq`` is None or empty,
        so callers can skip such entries.
    """
    import re  # local import: only needed for the residue clean-up below

    if seq is None:
        return None
    seq = str(seq).strip()
    if not seq:
        # Nothing to embed; callers test for None to skip empty sequences.
        return None

    # ProtT5 preprocessing: map the rare/ambiguous residues U, Z, O, B to X and
    # separate residues with single spaces so each one becomes its own token.
    seq = " ".join(re.sub(r"[UZOB]", "X", seq))

    inputs = tokeniser(seq, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    # The T5 tokeniser appends an end-of-sequence token; drop that final
    # position so the pooled vector only averages real residues.
    residue_states = outputs.last_hidden_state[0, :-1]

    # Pool in float32 to avoid precision loss if the model runs in half precision.
    return residue_states.float().mean(dim=0).cpu().numpy()


### Chunking and pooling sequences
As some of the sequences are far too long for ProtT5 (up to 34k AAs), we will instead split them into chunks, embed each chunk with ProtT5 and pool the chunk embeddings:

## embedding chunked protein function


In [10]:
from IPython.terminal.embed import embed
def embed_chunky_protein(seq, chunk_size=1024, overlap=50):
    """
    Embed a protein sequence in overlapping chunks and pool into one vector.

    The sequence is cut into windows of at most ``chunk_size`` residues; each
    window starts ``chunk_size - overlap`` residues after the previous one and
    chunking stops as soon as a window reaches the end of the sequence. Every
    window is embedded with ``embed_protein`` and the per-chunk vectors are
    averaged (unweighted mean over chunks).

    Args:
        seq: Protein sequence; converted with ``str()``.
        chunk_size: Maximum number of residues per chunk (must be > 0).
        overlap: Residues shared by consecutive chunks
            (must satisfy ``0 <= overlap < chunk_size``).

    Returns:
        numpy array of shape [hidden_dim], or None when the sequence is
        None/empty or no chunk could be embedded.

    Raises:
        ValueError: if ``chunk_size``/``overlap`` would not advance the window.
    """
    if seq is None:
        return None
    seq_str = str(seq)

    # Skip empty sequences
    if len(seq_str) == 0:
        return None

    if chunk_size <= 0 or not 0 <= overlap < chunk_size:
        # A non-positive stride would never reach the end of the sequence.
        raise ValueError(
            f"Need chunk_size > 0 and 0 <= overlap < chunk_size "
            f"(got chunk_size={chunk_size}, overlap={overlap})"
        )

    # Chunk sequence. Stop once a chunk touches the end: stepping again would
    # only yield a tail that lies entirely inside the previous chunk's overlap
    # and would be given a full vote in the mean below.
    stride = chunk_size - overlap
    chunks = []
    start = 0
    while True:
        end = min(start + chunk_size, len(seq_str))
        chunks.append(seq_str[start:end])
        if end == len(seq_str):
            break
        start += stride

    # Embed each chunk; a failing chunk is reported and left out (best effort).
    chunk_vectors = []
    for chunk in chunks:
        try:
            emb = embed_protein(chunk)
        except Exception as e:
            print(f"Skipping a chunk due to error: {e}")
            continue
        finally:
            # release cached GPU memory between chunks
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        if emb is not None:
            chunk_vectors.append(torch.tensor(emb))

    if len(chunk_vectors) == 0:
        return None

    # mean pool over chunks to get one vector per protein:
    protein_vector = torch.mean(torch.stack(chunk_vectors), dim=0)
    return protein_vector.cpu().numpy()

## Truncating sequences
Chunking and pooling takes too long, and the middle of a long sequence may contribute unhelpful information. Instead, we keep only the first and last 2048 AAs of each sequence (so the longest truncated sequence is 4096 AAs), as the two ends are likely to contain the most useful 'postcode' (localisation) information.

In [11]:
# Keep only the two ends of long sequences (default: 2048 residues from each end).
def truncate_prot(seq, chunk_len = 2048):
    """Return ``seq`` as a string, shortened to its first and last ``chunk_len`` characters.

    Sequences of at most ``2 * chunk_len`` characters are returned whole;
    longer ones have their middle removed.
    """
    text = str(seq)
    too_long = len(text) > 2 * chunk_len
    return text[:chunk_len] + text[-chunk_len:] if too_long else text


In [12]:
# just adding to shorten dataset for testing
#df_seq_id = df_seq_id.head(100)

In [13]:
# Write one embedding per protein to an HDF5 file (one dataset per Id).
# The file is opened in append mode, so re-running this cell resumes a previous
# run: Ids that already have a dataset in the file are skipped.

with h5py.File("test_protT5_half_2048aa.h5", "a") as f:
    for raw_id, raw_seq in zip(df_seq_id["Id"], df_seq_id["sequence"]):

        # 1. HDF5 dataset names must be strings
        prot_id = str(raw_id)

        # skip Ids already in the file (duplicate Id in the input, or written by an earlier run)
        if prot_id in f:
            print(f"Skipping {prot_id}, dataset already exists in the file (duplicate ID in input data).")
            continue

        # a missing sequence would otherwise be embedded as the literal text "nan"
        if pd.isna(raw_seq):
            print(f"Skipping {prot_id}, empty sequence")
            continue

        # 2. shorten long sequences: truncate_prot keeps at most its default
        #    2048 residues from each end (max. 4096 AAs in total)
        seq = truncate_prot(raw_seq)

        try:
            emb = embed_protein(seq)

            if emb is None:
                print(f"Skipping {prot_id}, empty sequence")
            else:
                # fixed numeric type for HDF5
                f.create_dataset(prot_id, data=np.asarray(emb, dtype=np.float32))
                # progress log: only reported when the embedding was actually written
                print(f"Completed Id {prot_id}")

        except RuntimeError as e:
            # e.g. CUDA out-of-memory on a long sequence; carry on with the next protein
            print(f"Skipping {prot_id} due to runtime error: {e}")

        finally:
            # free cached GPU memory between proteins
            torch.cuda.empty_cache()

Completed Id 5
Completed Id 9
Completed Id 14
Completed Id 15
Completed Id 17
Completed Id 18
Completed Id 25
Completed Id 28
Completed Id 31
Completed Id 39
Completed Id 53
Completed Id 54
Completed Id 56
Completed Id 57
Completed Id 61
Completed Id 62
Completed Id 68
Completed Id 71
Completed Id 84
Completed Id 92
Completed Id 95
Completed Id 96
Completed Id 115
Completed Id 129
Completed Id 133
Completed Id 134
Completed Id 135
Completed Id 140
Completed Id 142
Completed Id 143
Completed Id 150
Completed Id 155
Completed Id 156
Completed Id 163
Completed Id 164
Completed Id 167
Completed Id 170
Completed Id 174
Completed Id 176
Completed Id 177
Completed Id 180
Completed Id 189
Completed Id 191
Completed Id 192
Completed Id 200
Completed Id 202
Completed Id 203
Completed Id 204
Completed Id 225
Completed Id 227
Completed Id 236
Completed Id 267
Completed Id 280
Completed Id 281
Completed Id 302
Completed Id 303
Completed Id 310
Completed Id 311
Completed Id 312
Completed Id 322
Comp