In [None]:
from transformers import (
    BertForMaskedLM,
    T5EncoderModel,
    AutoTokenizer,
    T5Tokenizer,
)

from antibody_commonness.pseudo_likelihood import calculate_pll
from antibody_commonness.data import AntibodyPLLDataset

In [None]:
import ipywidgets as widgets
from IPython.display import display

# Render long labels in full instead of truncating them at the default label width.
_label_style = {"description_width": "initial"}

# Choice of antibody masked language model used by the scoring cells below.
# NOTE: the next cell reads `.value` when it runs, so under "Run All" the default
# ('IgBert', 128) is always used; after changing a widget, re-run the following cells.
dropdown = widgets.Dropdown(
    options=['IgBert', 'IgT5'],
    value='IgBert',
    description='Select an antibody MLM:',
    style=_label_style,
)
display(dropdown)

# Batch size forwarded to calculate_pll; validated (must be >= 1) in the next cell.
batch_size_input = widgets.IntText(
    value=128,
    description='Batch Size:',
    style=_label_style,
)
display(batch_size_input)

In [None]:
import torch

# Model class, tokenizer class and Hugging Face checkpoint for each dropdown option.
# NOTE(review): T5EncoderModel is the bare encoder (hidden states only, no LM head);
# confirm that calculate_pll can derive pseudo-likelihoods for IgT5 from it.
MODEL_SPECS = {
    "IgBert": (BertForMaskedLM, AutoTokenizer, "Exscientia/IgBert_unpaired"),
    "IgT5": (T5EncoderModel, T5Tokenizer, "Exscientia/IgT5_unpaired"),
}

model_str = dropdown.value
if model_str not in MODEL_SPECS:
    # Fail fast rather than silently reusing a `model` left over from an earlier run.
    raise ValueError(
        f"Unknown model {model_str!r}; expected one of {sorted(MODEL_SPECS)}"
    )
model_cls, tokenizer_cls, checkpoint = MODEL_SPECS[model_str]
model = model_cls.from_pretrained(checkpoint)
tokenizer = tokenizer_cls.from_pretrained(checkpoint)

# Prefer CUDA, then Apple-silicon MPS, then fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

batch_size = batch_size_input.value
if batch_size < 1:
    # IntText accepts any integer, including 0 and negatives, which cannot form batches.
    raise ValueError(f"Batch size must be a positive integer, got {batch_size}")

print(f"Will run model [{model_str}] on device [{device}] with batch size [{batch_size}]")

In [None]:
import pandas as pd

# Thera-SAbDab therapeutic antibody table; downloaded over the network each time this cell runs.
sab_dab_link = "https://opig.stats.ox.ac.uk/webapps/sabdab-sabpred/static/downloads/TheraSAbDab_SeqStruc_OnlineDownload.csv"
Thera_SAbDab = pd.read_csv(sab_dab_link)

# Heavy chains only, one entry per table row, in table order.
# NOTE(review): empty cells in this column would arrive as NaN floats; confirm
# AntibodyPLLDataset / the tokenizer tolerate them, or filter them out beforehand.
heavy_column = Thera_SAbDab.loc[:, "HeavySequence"]
sequences = heavy_column.tolist()
dataset = AntibodyPLLDataset(sequences)

In [None]:
# Run the selected model over every sequence in `dataset`; the returned object
# exposes the raw and length-normalised scores read in the next cell.
score_storage = calculate_pll(
    model,
    tokenizer,
    dataset=dataset,
    batch_size=batch_size,
    device=device,
)

In [None]:
# Raw and length-normalised pseudo-likelihood scores; the trailing tuple is the cell output.
raw_pll = score_storage.get_pll()
normalised_pll = score_storage.get_length_normalised_pll()
raw_pll, normalised_pll