In [19]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForMaskedLM
import torch

# Class index -> label string for the classification head: 44 peptide-like
# sequences plus a catch-all "other" class (45 entries in total).
# NOTE(review): presumably mirrors the label order of the "wukevin/tcr-bert"
# checkpoint — confirm against the checkpoint's config / model card.
class_labels = {
    0: "LLWNGPMAV",
    1: "RPRGEVRFL",
    2: "ATDALMTGY",
    3: "HSKKKCDEL",
    4: "KAFSPEVIPMF",
    5: "KRWIILGLNK",
    6: "KRWIIMGLNK",
    7: "TPQDLNTML",
    8: "EIYKRWII",
    9: "ISPRTL-W",
    10: "FLKEKGGL",
    11: "HPKVSSEVHI",
    12: "IIKDYGKQM",
    13: "LPPIVAKEI",
    14: "RFPLTFGWCF",
    15: "RYPLTFGWCF",
    16: "TPGPGVRYPL",
    17: "TQGYFPDWQNY",
    18: "FPRPWLHGL",
    19: "RYPLTFGW",
    20: "ELRRKMMYM",
    21: "QIKVRVDMV",
    22: "QIKVRVKMV",
    23: "VLEETSVML",
    24: "FPTKDVAL",
    25: "NLVPMVATV",
    26: "RPHERNGFTVL",
    27: "TPRVTGGGAM",
    28: "VTEHDTLLY",
    29: "YLAMPFATPMEAELARRSLA",
    30: "GLCTLVAML",
    31: "YVLDHLIVV",
    32: "EPLPQGQLTAY",
    33: "RAKFKQLL",
    34: "HPVGEADYFEY",
    35: "FLRGRAYGL",
    36: "AVFDRKSDAK",
    37: "IVTDFSVIK",
    38: "NFIRMVISNPAAT",
    39: "KRGIVEQSSTSISSL",
    40: "ENPVVHFFKNIVTPR",
    41: "GILGFVFTL",
    42: "PQPELPYPQPE",
    43: "FWIDLFETIG",
    44: "other",
}

# Load the tokenizer and the sequence-classification checkpoint by model id.
tokenizer = AutoTokenizer.from_pretrained("wukevin/tcr-bert")
model = AutoModelForSequenceClassification.from_pretrained("wukevin/tcr-bert")

# NOTE(review): this is upper-cased English text rather than a TCR amino-acid
# sequence — presumably a smoke test of the pipeline; confirm it is intentional.
input_text = "I like you. I love you".upper()
inputs = tokenizer(input_text, return_tensors="pt")

# Predict (inference only, so no autograd graph is recorded).
with torch.no_grad():
    outputs = model(**inputs)

logits = outputs.logits
probabilities = torch.softmax(logits, dim=1)  # normalise over the class axis

# Probabilities of the single sequence in the batch, one per class index.
class_probabilities = probabilities[0].tolist()

# A positional zip() would silently drop classes if the head size and the
# label table disagreed, so fail loudly instead of printing a partial table.
assert len(class_probabilities) == len(class_labels), (
    f"model returned {len(class_probabilities)} classes "
    f"but class_labels has {len(class_labels)} entries"
)

# Look each label up by its class index rather than relying on the dict's
# insertion order happening to match the index order.
result = [
    (class_labels[class_index], class_probability)
    for class_index, class_probability in enumerate(class_probabilities)
]
result.sort(key=lambda x: -x[1])  # most probable class first
for word, probability in result:
    word = '"' + word + '"'
    print(f"P(Y = {word:25} | x = x) = {probability:.5f}")

P(Y = "TPGPGVRYPL"              | x = x) = 0.44940
P(Y = "GILGFVFTL"               | x = x) = 0.12713
P(Y = "FPRPWLHGL"               | x = x) = 0.07401
P(Y = "LLWNGPMAV"               | x = x) = 0.03501
P(Y = "TPQDLNTML"               | x = x) = 0.02701
P(Y = "EIYKRWII"                | x = x) = 0.01873
P(Y = "HSKKKCDEL"               | x = x) = 0.01844
P(Y = "TQGYFPDWQNY"             | x = x) = 0.01813
P(Y = "NLVPMVATV"               | x = x) = 0.01717
P(Y = "PQPELPYPQPE"             | x = x) = 0.01703
P(Y = "RPRGEVRFL"               | x = x) = 0.01699
P(Y = "HPVGEADYFEY"             | x = x) = 0.01511
P(Y = "RFPLTFGWCF"              | x = x) = 0.01396
P(Y = "other"                   | x = x) = 0.01272
P(Y = "RYPLTFGWCF"              | x = x) = 0.01135
P(Y = "FPTKDVAL"                | x = x) = 0.01040
P(Y = "EPLPQGQLTAY"             | x = x) = 0.01030
P(Y = "HPKVSSEVHI"              | x = x) = 0.00955
P(Y = "NFIRMVISNPAAT"           | x = x) = 0.00898
P(Y = "VLEETSVML"              

In [21]:
# Id -> token table with 26 entries (20+ residue letters and the symbols
# "$", ".", "?", "|", "*").
# NOTE(review): nothing in this cell reads it (decoding goes through the
# tokenizer) and it rebinds the name `class_labels`; presumably kept as a
# reference copy of the vocabulary — confirm before removing or renaming.
class_labels = {0: "R", 1: "H", 2: "K", 3: "D", 4: "E", 5: "S", 6: "T", 7: "N", 8: "Q", 9: "C", 10: "U", 11: "G", 12: "P", 13: "A", 14: "V", 15: "I", 16: "L", 17: "M", 18: "F", 19: "Y", 20: "W", 21: "$", 22: ".", 23: "?", 24: "|", 25: "*"}

# Load the locally saved tokenizer and masked-LM checkpoint.
tokenizer = AutoTokenizer.from_pretrained("./model/mlm-only/tokenizer")
model = AutoModelForMaskedLM.from_pretrained("./model/mlm-only/model")

# One character per token, separated by spaces.
# NOTE(review): "." appears to be the tokenizer's mask token, and the trailing
# "# V" presumably records the residue hidden behind it — confirm both.
input_text = " ".join(list("CALSDVEGAQKL.F")) # V
inputs = tokenizer(input_text, return_tensors="pt")

# Predict (inference only, so no autograd graph is recorded).
with torch.no_grad():
    outputs = model(**inputs)

predictions = outputs.logits
# Column indices (token positions) whose id equals the mask token id.
mask_token_index = torch.where(inputs.input_ids == tokenizer.mask_token_id)[1]

# Without a masked position the indexing below would die with an opaque
# IndexError; report the real problem instead.
if len(mask_token_index) == 0:
    raise ValueError(
        f"input_text {input_text!r} contains no mask token "
        f"({tokenizer.mask_token!r}); nothing to predict"
    )

mask_token_logits = predictions[0, mask_token_index, :]

# Apply softmax to convert logits to probabilities
token_probabilities = torch.softmax(mask_token_logits, dim=1)

# Get the top 5 predicted tokens and their probabilities.
# Only the first masked position (row 0) is reported, even if the input
# contains several mask tokens.
top_5_tokens = torch.topk(token_probabilities, 5, dim=1)
top_5_token_ids = top_5_tokens.indices[0].tolist()
top_5_probabilities = top_5_tokens.values[0].tolist()

# Convert predicted token IDs to words and print them with probabilities
for token_id, probability in zip(top_5_token_ids, top_5_probabilities):
    word = tokenizer.decode([token_id])
    print(f"Predicted word: {word}, Probability: {probability:.4f}")

Predicted word: S, Probability: 0.9094
Predicted word: G, Probability: 0.0593
Predicted word: R, Probability: 0.0181
Predicted word: W, Probability: 0.0024
Predicted word: I, Probability: 0.0021


In [3]:
# Display the module tree of the BERT encoder wrapped by `model`.
# NOTE(review): relies on `model` already being bound in the kernel by an
# earlier cell, and this cell's execution count is out of sequence with the
# cells above — confirm it still works after Restart Kernel -> Run All.
model.bert

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(26, 768, padding_idx=21)
    (position_embeddings): Embedding(64, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
     

In [4]:
# Shape of the encoder's final hidden states for the current `inputs`.
# Only the shape is needed, so run the forward pass under no_grad (as the
# prediction cells do) instead of recording an autograd graph.
with torch.no_grad():
    last_hidden_state_shape = model.bert(**inputs).last_hidden_state.shape
last_hidden_state_shape

torch.Size([1, 16, 768])

In [6]:
# Shape of the full model's output logits for the current `inputs`.
# Only the shape is needed, so run the forward pass under no_grad (as the
# prediction cells do) instead of recording an autograd graph.
with torch.no_grad():
    logits_shape = model(**inputs).logits.shape
logits_shape

torch.Size([1, 16, 26])