In [1]:
from transformers import T5Tokenizer, T5EncoderModel
import torch
import pickle
import pandas as pd
import re
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [13]:
BATCH_SIZE = 8
DATAFRAME_PATH = "../custom_fragments2_bkp/datasets/random_equal_distribution2/real_fragments30.csv"
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(DEVICE)

cuda


In [18]:
df = pd.read_csv(DATAFRAME_PATH)
df

Unnamed: 0,source_accession,sequence,fragment_length,is_fragment
0,B3EWP6,HLLQFGDLINKIARRNGILYYSFYGCYCGLGGRGRPQDATDRCCFV...,60,True
1,C0HJQ3,MSGRGKTGGKARAKAKTRSSRAGLQFPVGRVHRLLRKGNYAQRVGA...,76,True
2,C0HLV7,FQLNVIANNNNHTMLTQTTIHCHGFFQGTNSADGHAFVNCCPIASG...,237,True
3,C7E9W0,ASVTFWTLDNVDRTLVFTGNPGSAAIETITVGPAENTTVEFPGSWV...,144,True
4,F1NBL0,CSTWGGGHFSTFDKYQYDFTGTCNYIFATVCDESSPDFNIQFRRGL...,1185,True
...,...,...,...,...
4898,Q9R4P1,VQIFVRDNNVDQALKALK,18,True
4899,Q9R4P4,TKIADLRSQTVDQLSDXLXKL,21,True
4900,Q9R4P6,VQIFVXDNNVDQALK,15,True
4901,Q9R5V8,MKATELREKSAQQLNXQLL,19,True


In [8]:
tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', dtype=torch.float16)
model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc", dtype=torch.float16)
model.to(DEVICE)
model.eval()

T5EncoderModel(
  (shared): Embedding(128, 1024)
  (encoder): T5Stack(
    (embed_tokens): Embedding(128, 1024)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=1024, out_features=4096, bias=False)
              (k): Linear(in_features=1024, out_features=4096, bias=False)
              (v): Linear(in_features=1024, out_features=4096, bias=False)
              (o): Linear(in_features=4096, out_features=1024, bias=False)
              (relative_attention_bias): Embedding(32, 32)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=1024, out_features=16384, bias=False)
              (wo): Linear(in_features=16384, out_features=1024, bias=False)
              (dropout): Dropo

In [19]:
length_cut_df = df[df["fragment_length"] <= 1024].copy()
length_cut_df.head()

Unnamed: 0,source_accession,sequence,fragment_length,is_fragment
0,B3EWP6,HLLQFGDLINKIARRNGILYYSFYGCYCGLGGRGRPQDATDRCCFV...,60,True
1,C0HJQ3,MSGRGKTGGKARAKAKTRSSRAGLQFPVGRVHRLLRKGNYAQRVGA...,76,True
2,C0HLV7,FQLNVIANNNNHTMLTQTTIHCHGFFQGTNSADGHAFVNCCPIASG...,237,True
3,C7E9W0,ASVTFWTLDNVDRTLVFTGNPGSAAIETITVGPAENTTVEFPGSWV...,144,True
7,O02827,FRLVEKKTGKVWAGKFFKAYSAKEKENIRQEISIMNCLHHPKLVQC...,438,True


In [20]:
length_cut_df["sequence_split"] = length_cut_df["sequence"].apply(
    lambda seq: " ".join(list(re.sub(r"[UZOB]", "X", seq)))
)
length_cut_df.head()

Unnamed: 0,source_accession,sequence,fragment_length,is_fragment,sequence_split
0,B3EWP6,HLLQFGDLINKIARRNGILYYSFYGCYCGLGGRGRPQDATDRCCFV...,60,True,H L L Q F G D L I N K I A R R N G I L Y Y S F ...
1,C0HJQ3,MSGRGKTGGKARAKAKTRSSRAGLQFPVGRVHRLLRKGNYAQRVGA...,76,True,M S G R G K T G G K A R A K A K T R S S R A G ...
2,C0HLV7,FQLNVIANNNNHTMLTQTTIHCHGFFQGTNSADGHAFVNCCPIASG...,237,True,F Q L N V I A N N N N H T M L T Q T T I H C H ...
3,C7E9W0,ASVTFWTLDNVDRTLVFTGNPGSAAIETITVGPAENTTVEFPGSWV...,144,True,A S V T F W T L D N V D R T L V F T G N P G S ...
7,O02827,FRLVEKKTGKVWAGKFFKAYSAKEKENIRQEISIMNCLHHPKLVQC...,438,True,F R L V E K K T G K V W A G K F F K A Y S A K ...


In [21]:
tokenized_df = pd.DataFrame(data=length_cut_df[["source_accession", "is_fragment", "sequence_split"]])
tokenized_df.head()

Unnamed: 0,source_accession,is_fragment,sequence_split
0,B3EWP6,True,H L L Q F G D L I N K I A R R N G I L Y Y S F ...
1,C0HJQ3,True,M S G R G K T G G K A R A K A K T R S S R A G ...
2,C0HLV7,True,F Q L N V I A N N N N H T M L T Q T T I H C H ...
3,C7E9W0,True,A S V T F W T L D N V D R T L V F T G N P G S ...
7,O02827,True,F R L V E K K T G K V W A G K F F K A Y S A K ...


In [22]:
tokenized_input = []

In [24]:
for i in tqdm(range(0, len(tokenized_df), BATCH_SIZE)):
    batch = tokenized_df.loc[i:i+BATCH_SIZE - 1]

    encoded = tokenizer.batch_encode_plus(
        batch["sequence_split"],
        add_special_tokens=True,
        padding="longest",
        return_tensors="pt"
    )

    output = []

    for acc_id, is_fragment, input_ids, attention_mask in zip(batch["source_accession"], batch["is_fragment"], encoded["input_ids"], encoded["attention_mask"]):
        output.append((acc_id, is_fragment, input_ids, attention_mask))

    tokenized_input.append(output)

100%|██████████| 607/607 [00:00<00:00, 839.94it/s] 


In [26]:
len(tokenized_input)

607

In [27]:
all_embeddings = []

In [28]:
import os
os.makedirs("./embeddings", exist_ok=True)

In [30]:
with torch.no_grad():
    for (batch_num, entries) in tqdm(enumerate(tokenized_input), total=len(tokenized_input)):
        for acc_id, is_fragment, input_ids, attention_mask in entries:
            embedding = model(
                input_ids=input_ids.unsqueeze(0).to(DEVICE),
                attention_mask=attention_mask.unsqueeze(0).to(DEVICE)
            ).last_hidden_state
            print(f"Accession ID: {acc_id}")
            print(f"Fragment: {is_fragment}")
            print(f"Embedding shape: {embedding.shape}")  # Should be [1, seq_len, 1024]
            print(f"Attention mask shape: {attention_mask.shape}")  # Should be [seq_len]

            break
        break

  0%|          | 0/607 [00:00<?, ?it/s]

Accession ID: B3EWP6
Fragment: True
Embedding shape: torch.Size([1, 439, 1024])
Attention mask shape: torch.Size([439])





In [31]:
with torch.no_grad():
    for (batch_num, entries) in tqdm(enumerate(tokenized_input), total=len(tokenized_input)):
        # Collect batch data
        acc_ids = []
        fragment_status = []
        #fragment_types = []
        batch_input_ids = []
        batch_attention_masks = []

        for (acc_id, is_fragment, input_ids, attention_mask) in entries:
            acc_ids.append(acc_id)
            fragment_status.append(is_fragment)
            #fragment_types.append(fragment_type)
            batch_input_ids.append(input_ids)
            batch_attention_masks.append(attention_mask)

        # Stack into proper batches [batch_size, seq_len]
        batch_input_ids = torch.stack(batch_input_ids).to(DEVICE)
        batch_attention_masks = torch.stack(batch_attention_masks).to(DEVICE)

        # Get embeddings [batch_size, seq_len, embedding_dim]
        embeddings = model(
            input_ids=batch_input_ids,
            attention_mask=batch_attention_masks
        ).last_hidden_state

        # Mean pooling: average over sequence length, weighted by attention mask
        # Expand attention mask to match embedding dimensions
        mask = batch_attention_masks.unsqueeze(-1).expand(embeddings.size()).float()

        # Sum embeddings where attention mask is 1
        summed = torch.sum(embeddings * mask, dim=1)
        # Count how many positions were summed (sequence lengths)
        lengths = torch.sum(mask, dim=1)
        # Average: [batch_size, embedding_dim]
        prot_embeds = summed / lengths

        # Store embeddings with their accession IDs
        for i, (acc_id, is_fragment) in enumerate(zip(acc_ids, fragment_status)):
            all_embeddings.append((acc_id, is_fragment, prot_embeds[i].cpu()))

        # Progress logging
        if batch_num % 100 == 0:
            print(f"Processed {batch_num * BATCH_SIZE} sequences...")

        # Periodic saving
        if batch_num % 1000 == 0 and batch_num > 0:
            with open(f"./embeddings/embeddings_{batch_num}.pkl", "wb") as p:
                pickle.dump(all_embeddings, p)
            print(f"Saved checkpoint at batch {batch_num}")

# Save final embeddings
with open(f"./embeddings/final_embeddings.pkl", "wb") as p:
    pickle.dump(all_embeddings, p)

print(f"\nComplete! Generated {len(all_embeddings)} embeddings.")

  0%|          | 1/607 [00:00<02:06,  4.78it/s]

Processed 0 sequences...


 17%|█▋        | 103/607 [00:29<00:39, 12.67it/s]

Processed 800 sequences...


 33%|███▎      | 201/607 [00:48<02:26,  2.77it/s]

Processed 1600 sequences...


 50%|████▉     | 301/607 [01:09<00:36,  8.37it/s]

Processed 2400 sequences...


 66%|██████▌   | 402/607 [01:23<00:08, 24.37it/s]

Processed 3200 sequences...


 83%|████████▎ | 501/607 [01:36<00:29,  3.64it/s]

Processed 4000 sequences...


100%|█████████▉| 604/607 [01:44<00:00, 37.84it/s]

Processed 4800 sequences...


100%|██████████| 607/607 [01:44<00:00,  5.81it/s]


Complete! Generated 4806 embeddings.



