In [75]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForMaskedLM
import torch

# Class index -> label for the TCR-BERT sequence-classification head used below.
# The labels are presumably the antigen peptides the classifier was trained on
# (TODO confirm against the model card); index 44, "other", is the catch-all class.
class_labels = dict(enumerate([
    # 0-4
    "LLWNGPMAV", "RPRGEVRFL", "ATDALMTGY", "HSKKKCDEL", "KAFSPEVIPMF",
    # 5-9
    "KRWIILGLNK", "KRWIIMGLNK", "TPQDLNTML", "EIYKRWII", "ISPRTL-W",
    # 10-14
    "FLKEKGGL", "HPKVSSEVHI", "IIKDYGKQM", "LPPIVAKEI", "RFPLTFGWCF",
    # 15-19
    "RYPLTFGWCF", "TPGPGVRYPL", "TQGYFPDWQNY", "FPRPWLHGL", "RYPLTFGW",
    # 20-24
    "ELRRKMMYM", "QIKVRVDMV", "QIKVRVKMV", "VLEETSVML", "FPTKDVAL",
    # 25-29
    "NLVPMVATV", "RPHERNGFTVL", "TPRVTGGGAM", "VTEHDTLLY", "YLAMPFATPMEAELARRSLA",
    # 30-34
    "GLCTLVAML", "YVLDHLIVV", "EPLPQGQLTAY", "RAKFKQLL", "HPVGEADYFEY",
    # 35-39
    "FLRGRAYGL", "AVFDRKSDAK", "IVTDFSVIK", "NFIRMVISNPAAT", "KRGIVEQSSTSISSL",
    # 40-44
    "ENPVVHFFKNIVTPR", "GILGFVFTL", "PQPELPYPQPE", "FWIDLFETIG", "other",
]))

tokenizer = AutoTokenizer.from_pretrained("wukevin/tcr-bert")
model = AutoModelForSequenceClassification.from_pretrained("wukevin/tcr-bert")

# NOTE(review): this English sentence is placeholder input; TCR-BERT is presumably
# meant to receive (space-separated) amino-acid TCR sequences — confirm against the
# model card before interpreting the probabilities below.
input_text = "I like you. I love you".upper()
inputs = tokenizer(input_text, return_tensors="pt")

# Predict (inference only, so no autograd graph is needed)
with torch.no_grad():
    outputs = model(**inputs)

logits = outputs.logits
probabilities = torch.softmax(logits, dim=1)

# The label table must line up one-to-one with the classifier head; otherwise
# zipping labels with probabilities would silently drop the surplus entries.
num_classes = probabilities.shape[-1]
assert num_classes == len(class_labels), (
    f"model has {num_classes} output classes but class_labels has {len(class_labels)} entries"
)

# (label, probability) pairs for the single input, most probable class first.
# Labels are looked up by class index rather than relying on dict iteration order.
result = sorted(
    ((class_labels[class_idx], prob) for class_idx, prob in enumerate(probabilities[0].tolist())),
    key=lambda pair: pair[1],
    reverse=True,
)
result

[('TPGPGVRYPL', 0.44939708709716797),
 ('GILGFVFTL', 0.1271347850561142),
 ('FPRPWLHGL', 0.07400770485401154),
 ('LLWNGPMAV', 0.035007625818252563),
 ('TPQDLNTML', 0.027008509263396263),
 ('EIYKRWII', 0.018733853474259377),
 ('HSKKKCDEL', 0.018440475687384605),
 ('TQGYFPDWQNY', 0.018129343166947365),
 ('NLVPMVATV', 0.01716885156929493),
 ('PQPELPYPQPE', 0.017030414193868637),
 ('RPRGEVRFL', 0.016987716779112816),
 ('HPVGEADYFEY', 0.015107231214642525),
 ('RFPLTFGWCF', 0.013964857906103134),
 ('other', 0.012723141349852085),
 ('RYPLTFGWCF', 0.011348765343427658),
 ('FPTKDVAL', 0.010398641228675842),
 ('EPLPQGQLTAY', 0.01029637549072504),
 ('HPKVSSEVHI', 0.009545115754008293),
 ('NFIRMVISNPAAT', 0.00897930283099413),
 ('VLEETSVML', 0.007365161087363958),
 ('FLRGRAYGL', 0.007004470098763704),
 ('VTEHDTLLY', 0.006441055331379175),
 ('QIKVRVDMV', 0.006390884984284639),
 ('LPPIVAKEI', 0.006246190518140793),
 ('FWIDLFETIG', 0.006086317822337151),
 ('KRWIILGLNK', 0.005950759630650282),
 ('IIKD

In [99]:
# Token index -> symbol table (amino-acid letters followed by the symbols $ . ? | *).
# NOTE(review): not referenced by the prediction code in this cell, which decodes
# ids via `tokenizer.decode`; it also rebinds the `class_labels` name that the
# classification cell above relies on.
class_labels = dict(enumerate("RHKDESTNQCUGPAVILMFYW$.?|*"))

tokenizer = AutoTokenizer.from_pretrained("model/mlm-only/tokenizer")
model = AutoModelForMaskedLM.from_pretrained("model/mlm-only/model")

# NOTE(review): an English sentence is only a smoke test; this local MLM appears to
# predict amino-acid letters (see the printed predictions) — confirm the intended
# input format before reading anything into the output.
input_text = "Paris is the [MASK] of France."
inputs = tokenizer(input_text, return_tensors="pt")

# Predict (inference only, so no autograd graph is needed)
with torch.no_grad():
    outputs = model(**inputs)

predictions = outputs.logits
# Sequence positions of every mask token in the single encoded input.
mask_token_index = torch.where(inputs.input_ids == tokenizer.mask_token_id)[1]
if mask_token_index.numel() == 0:
    # Without this guard the later `[0]` indexing fails with an opaque IndexError.
    raise ValueError(
        f"No mask token ({tokenizer.mask_token!r}) found in the encoded input {input_text!r}"
    )
mask_token_logits = predictions[0, mask_token_index, :]

# Apply softmax over the vocabulary to convert logits to probabilities
token_probabilities = torch.softmax(mask_token_logits, dim=-1)

# Top-5 predicted tokens and their probabilities for every masked position
top_5_tokens = torch.topk(token_probabilities, 5, dim=-1)

# Convert predicted token IDs to words and print them with probabilities.
# A per-position header is printed only when there are several masks, so the
# single-mask output is exactly as before.
show_positions = len(mask_token_index) > 1
for mask_position, top_5_token_ids, top_5_probabilities in zip(
    mask_token_index.tolist(), top_5_tokens.indices.tolist(), top_5_tokens.values.tolist()
):
    if show_positions:
        print(f"Mask at token position {mask_position}:")
    for token_id, probability in zip(top_5_token_ids, top_5_probabilities):
        word = tokenizer.decode([token_id])
        print(f"Predicted word: {word}, Probability: {probability:.4f}")

Predicted word: F, Probability: 0.9051
Predicted word: T, Probability: 0.0153
Predicted word: L, Probability: 0.0133
Predicted word: Y, Probability: 0.0123
Predicted word: Q, Probability: 0.0059
