In [3]:
import ast
import math

import pandas as pd
import torch
from transformers import BertForMaskedLM, BertTokenizer, pipeline

## Prediction on masked tokens in the test dataset

In [4]:
# Function to tokenize sequences, preserving `[MASK]` as a single token
def add_spaces_preserve_mask(sequence):
    """Space-separate the characters of `sequence`, keeping `[MASK]` intact.

    Every literal `[MASK]` marker becomes a single token; every other
    character becomes a token of its own. The tokens are joined with single
    spaces and returned as one string.
    """
    mask_marker = "[MASK]"
    pieces = []
    # Splitting on the marker yields the plain stretches between masks; one
    # marker is re-inserted in front of every stretch except the first.
    for position, stretch in enumerate(sequence.split(mask_marker)):
        if position:
            pieces.append(mask_marker)
        pieces.extend(stretch)  # one token per character
    return " ".join(pieces)

# Function to predict the masked tokens from the tokenized sequence
def recover_sequence(sequence, model_unmasker):
    """Replace every `[MASK]` token in `sequence` with the top prediction.

    Args:
        sequence: Space-separated tokens, possibly containing `[MASK]` tokens.
        model_unmasker: Callable that takes the text and returns fill-mask
            predictions (a `transformers` fill-mask pipeline in this notebook).

    Returns:
        The space-joined tokens, with each `[MASK]` replaced by the
        `token_str` of the first (highest-ranked) candidate for that mask.
        A sequence without any `[MASK]` is returned re-joined and unchanged,
        without calling `model_unmasker`.
    """
    tokens = sequence.split()
    mask_positions = [i for i, token in enumerate(tokens) if token == "[MASK]"]
    if not mask_positions:
        # Nothing to predict; fill-mask pipelines reject mask-free input.
        return " ".join(tokens)

    result = model_unmasker(" ".join(tokens))
    # For a single mask the pipeline returns a flat list of candidate dicts
    # instead of one candidate list per mask; wrap it so that
    # result[mask_index][0] works uniformly.
    if result and isinstance(result[0], dict):
        result = [result]

    for mask_index, position in enumerate(mask_positions):
        tokens[position] = result[mask_index][0]["token_str"]

    return " ".join(tokens)


# Function to calculate overall accuracy and recover the masked tokens
def calculate_accuracy_and_recover(test_data, model_unmasker):
    """Recover every masked sequence and score the predictions at masked positions.

    Args:
        test_data: DataFrame with "OriginalSequence" and "MaskedSequence" columns.
        model_unmasker: Fill-mask callable handed on to `recover_sequence`.

    Returns:
        A tuple `(recovered_sequences, accuracy, error_rate)`: the recovered
        (space-separated) sequence for every row, the fraction of `[MASK]`
        positions whose prediction equals the original token (0 when there
        are no masked positions), and `1 - accuracy`.
    """
    recovered_sequences = []
    n_correct = 0
    n_masked = 0

    for _, record in test_data.iterrows():
        truth_spaced = add_spaces_preserve_mask(record["OriginalSequence"])
        masked_spaced = add_spaces_preserve_mask(record["MaskedSequence"])

        prediction = recover_sequence(masked_spaced, model_unmasker)
        recovered_sequences.append(prediction)

        # Walk the three token streams in lockstep; only masked slots count.
        aligned = zip(truth_spaced.split(), masked_spaced.split(), prediction.split())
        for truth_tok, masked_tok, predicted_tok in aligned:
            if masked_tok != "[MASK]":
                continue
            n_masked += 1
            if truth_tok == predicted_tok:
                n_correct += 1

    if n_masked > 0:
        accuracy = n_correct / n_masked
    else:
        accuracy = 0
    error_rate = 1 - accuracy

    return recovered_sequences, accuracy, error_rate

# Functions to calculate region-level accuracy and error rate
def add_region_accuracy(region_accuracies, start, end, region_name, original, masked, recovered):
    """Add one region's masked-token tallies to `region_accuracies` in place.

    Only positions `start:end` of the three parallel token lists are
    inspected. For each position whose masked token is `[MASK]`, the region's
    "total" is incremented, and its "correct" as well when the original and
    recovered tokens are equal. Returns None.
    """
    window = slice(start, end)
    triples = zip(original[window], masked[window], recovered[window])
    for truth, marker, guess in triples:
        if marker != "[MASK]":
            continue
        tally = region_accuracies[region_name]
        tally["total"] += 1
        if truth == guess:
            tally["correct"] += 1


def calculate_accuracy_by_region(test_data, model_unmasker):
    """Compute masked-token accuracy separately for each heavy-chain region.

    Args:
        test_data: DataFrame with "OriginalSequence", "MaskedSequence" and
            "CDRH1_pos" / "CDRH2_pos" / "CDRH3_pos" columns. Each `*_pos`
            cell is a string holding a `(start, end)` literal; the stored end
            index is treated as inclusive.
        model_unmasker: Fill-mask callable handed on to `recover_sequence`.

    Returns:
        Dict mapping each region name (FWH1, CDRH1, FWH2, CDRH2, FWH3, CDRH3,
        FWH4) to `{"correct", "total", "accuracy"}`; accuracy is 0 for a
        region without masked tokens.
    """
    region_names = ("FWH1", "CDRH1", "FWH2", "CDRH2", "FWH3", "CDRH3", "FWH4")
    region_accuracies = {name: {"correct": 0, "total": 0} for name in region_names}

    for _, row in test_data.iterrows():
        original_seq = row["OriginalSequence"]
        masked_seq = row["MaskedSequence"]

        masked_seq_spaced = add_spaces_preserve_mask(masked_seq)
        original_tokens = add_spaces_preserve_mask(original_seq).split()
        masked_tokens = masked_seq_spaced.split()
        recovered_tokens = recover_sequence(masked_seq_spaced, model_unmasker).split()

        # Parse CDRH1-3 positions. The cells come from a CSV file, so they are
        # parsed with ast.literal_eval (literals only) instead of eval(),
        # which would execute arbitrary code found in the data.
        cdr_bounds = []
        for column in ("CDRH1_pos", "CDRH2_pos", "CDRH3_pos"):
            start, end = ast.literal_eval(row[column])
            cdr_bounds.append((start, end + 1))  # inclusive end -> slice end
        (cdrh1_start, cdrh1_end), (cdrh2_start, cdrh2_end), (cdrh3_start, cdrh3_end) = cdr_bounds

        # Framework regions fill the gaps before, between and after the CDRs.
        regions = (
            ("FWH1", 0, cdrh1_start),
            ("CDRH1", cdrh1_start, cdrh1_end),
            ("FWH2", cdrh1_end, cdrh2_start),
            ("CDRH2", cdrh2_start, cdrh2_end),
            ("FWH3", cdrh2_end, cdrh3_start),
            ("CDRH3", cdrh3_start, cdrh3_end),
            ("FWH4", cdrh3_end, len(original_seq)),
        )
        for region_name, start, end in regions:
            add_region_accuracy(
                region_accuracies, start, end, region_name,
                original_tokens, masked_tokens, recovered_tokens,
            )

    # Calculate final accuracy for each region
    for counts in region_accuracies.values():
        total = counts["total"]
        counts["accuracy"] = counts["correct"] / total if total > 0 else 0

    return region_accuracies

In [5]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Load test dataset. Later cells read its "OriginalSequence", "MaskedSequence"
# and "CDRH1_pos" / "CDRH2_pos" / "CDRH3_pos" columns.
file_path = "./data/annotation/updated_test_masked_dataset.csv" 
test_data = pd.read_csv(file_path)

### Model 1: Pretrained ProtBERT

In [50]:
# Load the ProtBERT model and tokenizer from the "Rostlab/prot_bert" checkpoint
# and wrap them in a fill-mask pipeline running on `device`.
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
model = BertForMaskedLM.from_pretrained("Rostlab/prot_bert")
pretrained_unmasker = pipeline('fill-mask', model=model, tokenizer=tokenizer, device=device)

# Run recovery and calculate overall accuracy over every row of test_data
pretrained_recovered_sequences, pretrained_overall_accuracy, pretrained_overall_error_rate = calculate_accuracy_and_recover(test_data, pretrained_unmasker)

# NOTE: this adds (or overwrites) the "RecoveredSequence" column on the shared
# test_data frame before writing it out; each model cell replaces the column.
test_data["RecoveredSequence"] = pretrained_recovered_sequences
test_data.to_csv("recovered_results_vanilla.csv", index=False)

print(f"Pretrained ProtBERT overall accuracy on Masked Tokens: {pretrained_overall_accuracy:.2%}")
print(f"Pretrained ProtBERT overall error Rate on Masked Tokens: {pretrained_overall_error_rate:.2%}")

# Calculate region-level accuracy (this runs the unmasker on each row again).
# NOTE(review): only the first 428 rows are scored per region, while the
# overall accuracy above uses every row; the reason for this cutoff is not
# visible here -- confirm and consider naming the constant.
pretrained_region_accuracies = calculate_accuracy_by_region(test_data.iloc[0:428], pretrained_unmasker)
for region, counts in pretrained_region_accuracies.items():
    print(f"Region {region} Accuracy: {counts['accuracy']:.2%}")

Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Pretrained ProtBERT overall accuracy on Masked Tokens: 70.31%
Pretrained ProtBERT overall error Rate on Masked Tokens: 29.69%
Region FWH1 Accuracy: 86.56%
Region CDRH1 Accuracy: 41.75%
Region FWH2 Accuracy: 73.47%
Region CDRH2 Accuracy: 41.13%
Region FWH3 Accuracy: 86.60%
Region CDRH3 Accuracy: 18.17%
Region FWH4 Accuracy: 88.49%


### Model 2: unfreeze all parameters

In [51]:
# Load fine-tuned model 2 and its tokenizer from the local save directory and
# wrap them in a fill-mask pipeline running on `device`.
# NOTE: `tokenizer`, `model` and `save_directory` are reused (rebound) by each
# model cell, so their values depend on which cell ran last.
save_directory = "./fine_tuned_ProtBERT/model2"
tokenizer = BertTokenizer.from_pretrained(save_directory, do_lower_case=False)
model = BertForMaskedLM.from_pretrained(save_directory)
model2_unmasker = pipeline('fill-mask', model=model, tokenizer=tokenizer, device=device)

# Run recovery and calculate overall accuracy over every row of test_data
model2_recovered_sequences, model2_overall_accuracy, model2_overall_error_rate = calculate_accuracy_and_recover(test_data, model2_unmasker)

# NOTE: overwrites the "RecoveredSequence" column on the shared test_data
# frame before writing it out.
test_data["RecoveredSequence"] = model2_recovered_sequences
test_data.to_csv("recovered_results_model2.csv", index=False)

print(f"Model2 overall accuracy on Masked Tokens: {model2_overall_accuracy:.2%}")
print(f"Model2 overall error Rate on Masked Tokens: {model2_overall_error_rate:.2%}")

# Calculate region-level accuracy (this runs the unmasker on each row again).
# NOTE(review): only the first 428 rows are scored per region, while the
# overall accuracy above uses every row -- confirm why.
model2_region_accuracies = calculate_accuracy_by_region(test_data.iloc[0:428], model2_unmasker)
for region, counts in model2_region_accuracies.items():
    print(f"Region {region} Accuracy: {counts['accuracy']:.2%}")

Model2 overall accuracy on Masked Tokens: 90.25%
Model2 overall error Rate on Masked Tokens: 9.75%
Region FWH1 Accuracy: 95.12%
Region CDRH1 Accuracy: 82.50%
Region FWH2 Accuracy: 94.31%
Region CDRH2 Accuracy: 81.99%
Region FWH3 Accuracy: 94.86%
Region CDRH3 Accuracy: 71.71%
Region FWH4 Accuracy: 97.17%


### Model 3: unfreeze last two encoder layers

In [52]:
# Load fine-tuned model 3 and its tokenizer from the local save directory and
# wrap them in a fill-mask pipeline running on `device`.
save_directory = "./fine_tuned_ProtBERT/model3"
tokenizer = BertTokenizer.from_pretrained(save_directory, do_lower_case=False)
model = BertForMaskedLM.from_pretrained(save_directory)
model3_unmasker = pipeline('fill-mask', model=model, tokenizer=tokenizer, device=device)

# Run recovery and calculate overall accuracy over every row of test_data
model3_recovered_sequences, model3_overall_accuracy, model3_overall_error_rate = calculate_accuracy_and_recover(test_data, model3_unmasker)
# The variable used to be misspelled ("oveall"); the old name is kept as an
# alias so anything still referring to it keeps working.
model3_oveall_error_rate = model3_overall_error_rate

# NOTE: overwrites the "RecoveredSequence" column on the shared test_data
# frame before writing it out.
test_data["RecoveredSequence"] = model3_recovered_sequences
test_data.to_csv("recovered_results_model3.csv", index=False)

print(f"Model3 accuracy on Masked Tokens: {model3_overall_accuracy:.2%}")
print(f"Model3 error Rate on Masked Tokens: {model3_overall_error_rate:.2%}")

# Calculate region-level accuracy (this runs the unmasker on each row again).
# NOTE(review): only the first 428 rows are scored per region, while the
# overall accuracy above uses every row -- confirm why.
model3_region_accuracies = calculate_accuracy_by_region(test_data.iloc[0:428], model3_unmasker)
for region, counts in model3_region_accuracies.items():
    print(f"Region {region} Accuracy: {counts['accuracy']:.2%}")

Model3 accuracy on Masked Tokens: 83.31%
Model3 error Rate on Masked Tokens: 16.69%
Region FWH1 Accuracy: 92.88%
Region CDRH1 Accuracy: 70.75%
Region FWH2 Accuracy: 89.63%
Region CDRH2 Accuracy: 62.63%
Region FWH3 Accuracy: 91.96%
Region CDRH3 Accuracy: 49.02%
Region FWH4 Accuracy: 96.79%


## Pseudo-Perplexity (PPPL)

In [None]:
## Function to calculate PPPL for each BERT-based model
def compute_pppl(sequence, model, tokenizer):
    """Compute the pseudo-perplexity (PPPL) of one sequence under a masked LM.

    Each non-special token is masked in turn, the model scores the original
    token at that position, and the result is `exp(-mean log-probability)`.

    Args:
        sequence: Text handed to `tokenizer` (space-separated tokens in this
            notebook).
        model: Masked language model; calling it on the input ids must return
            an object whose `.logits` is indexed as [batch, position, vocab].
        tokenizer: Callable returning a mapping with "input_ids" whose first
            and last positions are special tokens; must expose
            `mask_token_id`.

    Returns:
        The pseudo-perplexity as a float (`inf` if the exponential overflows).

    Raises:
        ValueError: If the tokenized sequence has no tokens to score.
    """
    input_ids = tokenizer(sequence, return_tensors="pt")["input_ids"]
    # Follow the model's own device rather than a notebook-level global, so
    # the inputs always end up where the weights are.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        input_ids = input_ids.to(first_param.device)

    num_tokens = input_ids.size(1) - 2  # Exclude [CLS] and [SEP]
    if num_tokens <= 0:
        raise ValueError("sequence has no tokens to score")

    total_log_prob = 0.0
    with torch.no_grad():
        for i in range(1, num_tokens + 1):  # skip the leading special token
            masked_input = input_ids.clone()
            masked_input[0, i] = tokenizer.mask_token_id  # Mask the i-th token

            logits = model(masked_input).logits
            # log_softmax stays finite where softmax would underflow to 0.
            # Skipping such tokens (as a `prob > 0` guard does) would score
            # the worst predictions as if they were perfect.
            log_probs = torch.nn.functional.log_softmax(logits[0, i], dim=-1)
            total_log_prob += log_probs[input_ids[0, i]].item()

    try:
        return math.exp(-total_log_prob / num_tokens)
    except OverflowError:
        # Mean negative log-probability too large for a float exponent.
        return float("inf")


def compute_pppl_for_dataset(test_data, model, tokenizer):
    """Return the mean pseudo-perplexity over `test_data["OriginalSequence"]`.

    The model is moved to the notebook-level `device` and switched to eval
    mode; each sequence is then space-separated with
    `add_spaces_preserve_mask` and scored with `compute_pppl`.
    """
    model = model.to(device)
    model.eval()

    per_sequence = [
        compute_pppl(add_spaces_preserve_mask(raw_sequence), model, tokenizer)
        for raw_sequence in test_data["OriginalSequence"]
    ]

    return sum(per_sequence) / len(per_sequence)


### Model 1: Pretrained ProtBERT

In [7]:
# Load the pretrained ProtBERT tokenizer and masked-LM from the
# "Rostlab/prot_bert" checkpoint.
pretrained_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
pretrained_model = BertForMaskedLM.from_pretrained("Rostlab/prot_bert")
# The variable used to be misspelled ("pretained"); the old name is kept as an
# alias so anything still referring to it keeps working.
pretained_model = pretrained_model

# Average pseudo-perplexity over the "OriginalSequence" column of test_data.
pretrained_avg_pppl = compute_pppl_for_dataset(test_data, pretrained_model, pretrained_tokenizer)

print(f"Average PPPL for Pretrained ProtBERT: {pretrained_avg_pppl}")

BertForMaskedLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.
Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another archite

Average PPPL for Pretrained ProtBERT: 2.980614575856263


### Model 2: unfreeze all parameters

In [8]:
# Load fine-tuned model 2 and its tokenizer from the local save directory.
save_directory = "./fine_tuned_ProtBERT/model2"
model2_tokenizer = BertTokenizer.from_pretrained(save_directory, do_lower_case=False)
model2 = BertForMaskedLM.from_pretrained(save_directory)

# Average pseudo-perplexity over the "OriginalSequence" column of test_data.
model2_avg_pppl = compute_pppl_for_dataset(test_data, model2, model2_tokenizer)

print(f"Average PPPL for Model2: {model2_avg_pppl}")

Average PPPL for Model2: 1.4402317786021814


### Model 3: unfreeze last two encoder layers

In [9]:
# Load fine-tuned model 3 and its tokenizer from the local save directory.
save_directory = "./fine_tuned_ProtBERT/model3"
model3_tokenizer = BertTokenizer.from_pretrained(save_directory, do_lower_case=False)
model3 = BertForMaskedLM.from_pretrained(save_directory)

# Average pseudo-perplexity over the "OriginalSequence" column of test_data.
model3_avg_pppl = compute_pppl_for_dataset(test_data, model3, model3_tokenizer)

print(f"Average PPPL for Model3: {model3_avg_pppl}")

Average PPPL for Model3: 1.8949705332376572
