In [13]:
import torch
import torch.nn as nn
from transformers import BertModel, BertConfig
from torch.utils.data import DataLoader, Dataset, dataloader
import torch.optim as optim
from Bio import SeqIO
from transformers import AutoTokenizer

In [14]:
def tokenize_kmers(sequence, k=3):
    """Split ``sequence`` into its overlapping windows of length ``k``.

    Windows advance one position at a time, so a sequence of length ``n``
    yields ``n - k + 1`` k-mers. A sequence shorter than ``k`` yields an
    empty list.

    Args:
        sequence: Sliceable sequence (typically a string of residues).
        k: Window length.

    Returns:
        List of slices ``sequence[i:i + k]`` in left-to-right order.
    """
    last_start = len(sequence) - k
    kmers = []
    for start in range(last_start + 1):
        kmers.append(sequence[start:start + k])
    return kmers

In [15]:
# Checkpoint whose tokenizer is used, and a cache so it is loaded only once
# per kernel/process instead of on every call.
_ESM_CHECKPOINT = "facebook/esm2_t6_8M_UR50D"
_ESM_TOKENIZER_CACHE = {}


def esm_tokenizer(sequences, max_length=512):
    """Tokenize protein sequences with the ESM-2 tokenizer.

    Every sequence is padded to exactly ``max_length`` tokens and truncated
    if longer, so the returned tensors have a fixed width.

    The tokenizer object is loaded on first use and cached: this function is
    also used as the per-item tokenizer of a dataset, where reloading the
    tokenizer on every call would dominate the run time.

    Args:
        sequences: A single sequence string or a list of sequence strings.
        max_length: Number of tokens each sequence is padded/truncated to.

    Returns:
        Tuple ``(input_ids, attention_mask)`` of PyTorch tensors as produced
        by the tokenizer with ``return_tensors="pt"``.
    """
    tokenizer = _ESM_TOKENIZER_CACHE.get(_ESM_CHECKPOINT)
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(_ESM_CHECKPOINT)
        _ESM_TOKENIZER_CACHE[_ESM_CHECKPOINT] = tokenizer

    tokenized = tokenizer(
        sequences,
        return_tensors="pt",
        padding='max_length',
        truncation=True,
        max_length=max_length
    )
    return tokenized['input_ids'], tokenized['attention_mask']

In [20]:
class ProteinBERT(nn.Module):
    """BERT encoder for protein token sequences with a linear MLM head.

    The encoder is a ``BertModel`` built from a fresh ``BertConfig`` (no
    pretrained weights are loaded). ``mlm_head`` projects hidden states of
    size ``hidden_size`` to ``vocab_size`` logits.

    Args:
        vocab_size: Size of the token vocabulary (embedding rows and number
            of logits produced by ``mlm_head``).
        hidden_size: Width of the encoder hidden states.
        num_hidden_layers: Number of transformer layers.
        num_attention_heads: Attention heads per layer.
        max_position_embeddings: Longest sequence the encoder can embed.
        pad_token_id: Token id the encoder treats as padding; it should
            match the pad id of the tokenizer that produces ``input_ids``.
    """

    def __init__(self, vocab_size, hidden_size=768, num_hidden_layers=12, num_attention_heads=12,
                 max_position_embeddings=1024, pad_token_id=0):
        super().__init__()
        config = BertConfig(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            intermediate_size=4 * hidden_size,
            max_position_embeddings=max_position_embeddings,
            type_vocab_size=1,
            pad_token_id=pad_token_id,
        )
        self.bert = BertModel(config)
        self.mlm_head = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_ids, attention_mask):
        """Encode a batch of token ids.

        Args:
            input_ids: Token ids shaped ``(batch, seq_len)`` or
                ``(batch, 1, seq_len)``; the singleton axis appears when
                each dataset item already carries a leading batch axis.
            attention_mask: Mask with the same shape as ``input_ids``.

        Returns:
            Tuple ``(last_hidden_state, pooled_output)`` from the encoder.
            The MLM logits are obtained by applying ``self.mlm_head`` to
            ``last_hidden_state``, which is what the training loop in this
            notebook does with the first returned value.
        """
        # Only drop the extra axis when it is really there, so a genuine
        # (batch, 1) input is not collapsed to 1-D.
        if input_ids.dim() == 3:
            input_ids = input_ids.squeeze(1)
        if attention_mask.dim() == 3:
            attention_mask = attention_mask.squeeze(1)

        # Use this module's own encoder; calling the global `model` here
        # would re-enter forward() and never terminate.
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs.last_hidden_state
        pooled_output = outputs.pooler_output

        return last_hidden_state, pooled_output

In [21]:
class ProteinDataset(Dataset):
    """Map-style dataset that tokenizes one protein sequence per item, lazily.

    Args:
        sequences: Indexable collection of sequences.
        tokenizer: Callable invoked as ``tokenizer(sequence, max_length=...)``
            that returns an ``(input_ids, attention_mask)`` pair.
        max_length: Token length forwarded to ``tokenizer`` for every item.
    """

    def __init__(self, sequences, tokenizer, max_length=512):
        self.sequences = sequences
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of sequences."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Tokenize and return ``(input_ids, attention_mask)`` for item ``idx``."""
        sequence = self.sequences[idx]
        # Forward the configured length; previously it was stored but the
        # tokenizer always ran with its own default.
        input_ids, attention_mask = self.tokenizer(sequence, max_length=self.max_length)

        return input_ids, attention_mask

In [None]:
# Load one protein sequence per line; blank lines are dropped so the
# tokenizer never receives an empty sequence.
with open("uniref50_sequences.txt") as f:
    sequences = [line.strip() for line in f if line.strip()]
if not sequences:
    raise ValueError("uniref50_sequences.txt contains no sequences")

# Sanity check: tokenize and decode only the first sequence. Tokenizing the
# whole corpus here is unnecessary because the dataset tokenizes per item.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
input_ids, attention_mask = esm_tokenizer(sequences[:1])
print(f"First token ID: {tokenizer.decode(input_ids[0].tolist())}")

# Prepare dataset and dataloader
dataset = ProteinDataset(sequences, esm_tokenizer)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

# Initialize model, optimizer, and loss function
VOCAB_SIZE = 8000        # size of the model's embedding table / MLM logits
MASK_PROBABILITY = 0.15  # fraction of eligible tokens hidden per batch
NUM_EPOCHS = 3

# Every id the tokenizer can emit must index into the embedding table.
# TODO(review): VOCAB_SIZE could be len(tokenizer) to avoid unused rows.
if len(tokenizer) > VOCAB_SIZE:
    raise ValueError("tokenizer vocabulary does not fit in VOCAB_SIZE")

device = torch.device("cpu")
model = ProteinBERT(vocab_size=VOCAB_SIZE).to(device)
optimizer = optim.Adam(model.parameters(), lr=3e-5)
criterion = nn.CrossEntropyLoss()  # ignores targets equal to -100 by default

mask_token_id = tokenizer.mask_token_id
special_token_ids = tokenizer.all_special_ids

# Training loop
model.train()
for epoch in range(NUM_EPOCHS):
    loss_sum = 0.0
    num_steps = 0

    for batch in dataloader:
        input_ids, attention_mask = batch
        input_ids, attention_mask = input_ids.to(device), attention_mask.to(device)

        # Each dataset item carries its own leading axis, so a batch arrives
        # as (batch, 1, seq_len); flatten it to (batch, seq_len).
        input_ids = input_ids.reshape(-1, input_ids.size(-1))
        attention_mask = attention_mask.reshape(-1, attention_mask.size(-1))

        # Eligible positions: real tokens only. Padding is identified via the
        # attention mask rather than by comparing against id 0, because the
        # tokenizer's pad id is not guaranteed to be 0. Special tokens
        # (<cls>, <eos>, ...) are never masked either.
        eligible = attention_mask.bool()
        for special_id in special_token_ids:
            eligible = eligible & (input_ids != special_id)

        mask = (torch.rand(input_ids.shape, device=device) < MASK_PROBABILITY) & eligible
        if not mask.any():
            # No targets in this batch: the loss would be NaN.
            continue

        # Targets are the original ids at masked positions, -100 elsewhere.
        labels = torch.where(mask, input_ids, torch.full_like(input_ids, -100))
        # Hide the selected tokens BEFORE the forward pass; otherwise the
        # model is asked to predict tokens it can already see.
        masked_input_ids = input_ids.masked_fill(mask, mask_token_id)

        optimizer.zero_grad()

        last_hidden_state, pooled_output = model(masked_input_ids, attention_mask)
        prediction_scores = model.mlm_head(last_hidden_state)

        loss = criterion(
            prediction_scores.reshape(-1, prediction_scores.size(-1)),
            labels.reshape(-1),
        )
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        num_steps += 1

    # Report the mean loss over the epoch (NaN if no batch was trained on).
    mean_loss = loss_sum / num_steps if num_steps else float("nan")
    print(f"Epoch {epoch + 1} Loss: {mean_loss}")


First token ID: <cls> M G R I R V W V G T S I P N P V N A H Q L V Y L K G M A K T K K L I L L L F V A A Q P N F K E W S L D V D A S T L V L T F E A N S V L S V K P D C S K V T I H S T A N G V K N V T L T N S G N G T L D A A N D Q A S C T I D A K D L D N I K L E T T L G T N T T N T F L E V K A G F G T K N G T T E F T Q G S P Y T A A A L V T P D V T A P E I S A T V G F S E F D L N S G R V T I A F T E A V D V S T L K F T K L A F R D A K L T G K T S T T G Y C N V T K D G K C D A A F C K N G A T V V L E V D N V D L N C I K S K R G L C T K D S D C I I T L E E D D F I Q D M A G N K L G K Y E S G T T A N A A E T L L H K F V P D I T S P T L D N F D L D L N A N T L T L E F S E T V D A K T L K A D G L T I Q G N G N T A D V S L Q V K L T S E S T T E S S D S A T I I V D I A P A D G A K L K M S T N I A T K T G D S Y I A V A T S A M N D M S G N A V K P I S S T A A K Q V R R F T N D T S A A V L S K F S L D L N T N Q L T L T F D E P V K V D S L N F T L F T L Q S T A A G G T E V K L T G S T T M T T G T 