# Fine-tuning a masked language model (PyTorch)

## import functions

In [1]:
from functions import load_model, load_dataset, fine_tune, accuracy

## set variables

In [4]:
# google-bert/bert-base-uncased
# google-bert/bert-large-uncased

# FacebookAI/roberta-base

# albert/albert-base-v2

# distilbert/distilbert-base-cased
# distilbert/distilroberta-base

# model_checkpoint = "google-bert/bert-base-uncased"
model_checkpoint = "distilbert/distilbert-base-uncased"
model_checkpoint = "distilbert/distilbert-base-uncased_one_to_three_digit"


model_name = model_checkpoint.split("/")[-1]

train_name = "one_digit_ads_sub"
twodigit_addsub_name = "two_digit_ads_sub"
onedigit_addition_test_name = "onedigit_addition_test"
twodigit_addition_test_name = "twodigit_addition_test"
onedigit_subtraction_name = "onedigit_subtraction_test"
twodigit_subtraction_test_name = "twodigit_subtraction_test"
ds_name = "one_to_three_digit"

# model_name = ''

model_path = f"../models/{model_name}"
train_dataset_path = f"datasets/{train_name}"
two_digit_ads_sub_dataset_path = f"datasets/{twodigit_addsub_name}"
onedigit_addition_test_dataset_path = f"datasets/{onedigit_addition_test_name}"
twodigit_addition_test_dataset_path = f"datasets/{twodigit_addition_test_name}"
onedigit_subtraction_dataset_path = f"datasets/{onedigit_subtraction_name}"
twodigit_subtraction_test_dataset_path = f"datasets/{twodigit_subtraction_test_name}"
ds = f"../datasets/{ds_name}"

## load the model

In [5]:
model, tokenizer = load_model(model_path)

'>>> ../models/distilbert-base-uncased_one_to_three_digit is loaded.
The number of parameters: 67M'


## load the dataset

In [None]:
train_dataset = load_dataset(train_dataset_path)

## tokenize the dataset

In [None]:
def tokenize_label(example: dict) -> dict:
    """
    Tokenize the label columns
    """

    return tokenizer(example["unmasked"])


tokenized_dataset = train_dataset.map(
    tokenize_label, batched=True, remove_columns=["unmasked", "masked"]
)

## fine-tune the model

In [None]:
# onedigit_addition_accuracy, twodigit_addition_accuracy, one_digit_subtraction_accuracy, twodigit_subtraction_accuracy = [], [], [], []
two_digit_ads_sub_accuracy = []

In [None]:
model = fine_tune(model, tokenizer, tokenized_dataset, lr=2e-5, num_epochs=1, batch_size=64)

## calculate accuracy

In [None]:
two_digit_ads_sub_dataset = load_dataset(two_digit_ads_sub_dataset_path)
# onedigit_addition_test_dataset = load_dataset(onedigit_addition_test_dataset_path)
# twodigit_addition_test_dataset = load_dataset(twodigit_addition_test_dataset_path)
# onedigit_subtraction_test_dataset = load_dataset(onedigit_subtraction_dataset_path)
# twodigit_subtraction_test_dataset = load_dataset(twodigit_subtraction_test_dataset_path)

In [None]:
two_digit_ads_sub_accuracy.append(accuracy(model, tokenizer, two_digit_ads_sub_dataset, test_set="test"))
# onedigit_addition_accuracy.append(accuracy(model, tokenizer, onedigit_addition_test_dataset, test_set="test"))
# twodigit_addition_accuracy.append(accuracy(model, tokenizer, twodigit_addition_test_dataset, test_set="test"))
# one_digit_subtraction_accuracy.append(accuracy(model, tokenizer, onedigit_subtraction_test_dataset, test_set="test"))
# twodigit_subtraction_accuracy.append(accuracy(model, tokenizer, twodigit_subtraction_test_dataset, test_set="test"))

In [None]:
# print(f"Train Losses: {train_losses}")
# print(f"Eval Losses: {eval_losses}")
# print(f"One Digit Addition Accuracy: {onedigit_addition_accuracy}")
# print(f"Two Digit Addition Accuracy: {twodigit_addition_accuracy}")
# print(f"One Digit Subtraction Accuracy: {one_digit_subtraction_accuracy}")
# print(f"Two Digit Subtraction Accuracy: {twodigit_subtraction_accuracy}")
print(f"Two Digit Addition and Subtraction Accuracy: {two_digit_ads_sub_accuracy}")

In [None]:
batch_size = 8
lr = [7e-5, 6e-5, 5e-5, 5e-5, 5e-5, 4e-5, 4e-5, 4e-5, 4e-5, 3e-5, 3e-5, 2e-5, 2e-5, 2e-5, 2e-5, 2e-5, 1e-5]
train_loss = [1.5503794352213542, 0.8744051106770834, 0.6071566263834636, 0.4575591532389323, 0.34520345052083334, 0.21754842122395834, 0.14869319915771484, 0.06719978332519531, 0.07871321360270182, 0.04905564308166504, 0.03357593218485514, 0.02027098814646403, 0.013109755516052247, 0.014046912193298339, 0.005731958548227946, 0.005177744229634603, 0.005364036957422893, 0.0051414374510447185]
eval_loss = [0.7209213972091675, 0.5388258695602417, 0.3167385756969452, 0.24366582930088043, 0.17933863401412964, 0.08913823962211609, 0.059695053845644, 0.05614165589213371, 0.033642660826444626, 0.008912994526326656, 0.005345964804291725, 0.0022574884351342916, 0.0015304312109947205, 0.00018909828213509172, 0.0003367669996805489, 0.0001413973222952336, 0.00016261845303233713, 0.00017344774096272886]

In [None]:
batch_size = 64
lr = [2e-4, 1e-4, 9e-5, 8e-5, 7e-5, 5e-6, 5e-5, 4e-5, 3e-5]
train_loss = [1.1188555908203126, 0.46731974283854166, 0.2867759958902995, 0.18037309010823568, 0.11367528279622396, 0.07348385492960612, 0.028819942474365236, 0.013884073893229166, 0.002491823434829712, 0.0010240906476974487]
eval_loss = [0.44242188334465027, 0.20009556412696838, 0.12076643854379654, 0.0605129599571228, 0.028032438829541206, 0.0066362833604216576, 0.0015798092354089022, 0.0005591694498434663, 0.00011961638665525243, 2.1034184101154096e-05]
onedigit_addition_accuracy = [3.5833333333333335, 77.75, 92.5, 91.58333333333333, 96.41666666666667, 95.66666666666667, 98.08333333333333, 98.0, 99.25, 99.08333333333333]
twodigit_addition_accuracy = [0.16666666666666666, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
one_digit_subtraction_accuracy = [7.7272727272727275, 45.15151515151515, 46.06060606060606, 43.333333333333336, 45.90909090909091, 44.09090909090909, 47.121212121212125, 44.54545454545455, 49.39393939393939, 47.121212121212125]
twodigit_subtraction_accuracy = [0.8333333333333334, 5.166666666666667, 6.0, 5.833333333333333, 6.5, 6.0, 6.666666666666667, 6.5, 6.833333333333333, 5.833333333333333]

In [None]:
# show the progress of the training with graph
# show the graph for onedigit_addition_accuracy, twodigit_addition_accuracy
import matplotlib.pyplot as plt

# show values up to 10 on vertical axis
plt.ylim(-5, 10)
# plt.plot(onedigit_addition_accuracy, label="One Digit Addition")
plt.plot(twodigit_addition_accuracy, label="Two Digit Addition")
# plt.plot(one_digit_subtraction_accuracy, label="One Digit Subtraction")
plt.plot(twodigit_subtraction_accuracy, label="Two Digit Subtraction")
# plot losses
# plt.plot(train_loss, label="Train Loss")
# plt.plot(eval_loss, label="Eval Loss")
plt.legend()

plt.show()


lr = [7e-5, 6e-5, 5e-5, 5e-5, 5e-5, 4e-5, 4e-5, 4e-5, 4e-5, 3e-5, 3e-5, 2e-5, 2e-5, 2e-5, 2e-5, 2e-5, 1e-5]
train_loss = [1.5503794352213542, 0.8744051106770834, 0.6071566263834636, 0.4575591532389323, 0.34520345052083334, 0.21754842122395834, 0.14869319915771484, 0.06719978332519531, 0.07871321360270182, 0.04905564308166504, 0.03357593218485514, 0.02027098814646403, 0.013109755516052247, 0.014046912193298339, 0.005731958548227946, 0.005177744229634603, 0.005364036957422893, 0.0051414374510447185]
eval_loss = [0.7209213972091675, 0.5388258695602417, 0.3167385756969452, 0.24366582930088043, 0.17933863401412964, 0.08913823962211609, 0.059695053845644, 0.05614165589213371, 0.033642660826444626, 0.008912994526326656, 0.005345964804291725, 0.0022574884351342916, 0.0015304312109947205, 0.00018909828213509172, 0.0003367669996805489, 0.0001413973222952336, 0.00016261845303233713, 0.00017344774096272886]
onedigit_addition_accuracy = [4.666666666666667, 74.0, 82.0, 91.6667, 94.0, 94.0, 97.3333, 96.6667, 97.3333, 97.6667, 99.3333, 98.3333, 99.0, 99.0, 98.6667, 99.6667, 99.0, 99.0, 99.0]
twodigit_addition_accuracy = [0.0, 0.0, 0.6667, 0.6667, 0.0, 0.0, 0.0, 0.0, 0.3333, 0.0, 0.3333, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
one_digit_subtraction_accuracy = [6.0606060606060606, 31.5152, 30.3030, 30.9091, 31.5152, 28.4848, 27.2727, 29.0910, 27.8788, 24.4848, 33.9394, 33.9394, 34.5455, 32.7273, 33.3333, 32.7273, 33.3333, 33.3333, 33.3333]
twodigit_subtraction_accuracy = [3.3333333333333335, 1.3333, 0.6667, 0.6667, 0.6667, 0.3333, 0.3333, 0.3333, 0.0, 0.3333, 1.0, 0.6667, 0.6667, 0.6667, 0.3333, 0.3333, 0.3333, 0.3333, 0.3333]

In [None]:
onedigit_addition_accuracy = [4.666666666666667, 74.0, 82.0, 91.6667, 94.0, 94.0, 97.3333, 96.6667, 97.3333, 97.6667, 99.3333, 98.3333, 99.0, 99.0, 98.6667, 99.6667, 99.0, 99.0, 99.0]
twodigit_addition_accuracy = [0.0, 0.0, 0.6667, 0.6667, 0.0, 0.0, 0.0, 0.0, 0.3333, 0.0, 0.3333, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
one_digit_subtraction_accuracy = [6.0606060606060606, 31.5152, 30.3030, 30.9091, 31.5152, 28.4848, 27.2727, 29.0910, 27.8788, 24.4848, 33.9394, 33.9394, 34.5455, 32.7273, 33.3333, 32.7273, 33.3333, 33.3333, 33.3333]
twodigit_subtraction_accuracy = [3.3333333333333335, 1.3333, 0.6667, 0.6667, 0.6667, 0.3333, 0.3333, 0.3333, 0.0, 0.3333, 1.0, 0.6667, 0.6667, 0.6667, 0.3333, 0.3333, 0.3333, 0.3333, 0.3333]

In [None]:
# save the model
model.save_pretrained("distilbert_onedigit_addition_v2_4800_8epochs")

# save the tokenizer
tokenizer.save_pretrained("distilbert_onedigit_addition_v2_4800_8epochs")

In [None]:
test_accuracy = accuracy(model, tokenizer, test_dataset, test_set="test")

print(f"The accuracy on the test set is {test_accuracy} percent.")

In [None]:
from transformers import AutoModelForMaskedLM
from transformers import AutoTokenizer
from datasets import DatasetDict
import torch

def acc(
        model: AutoModelForMaskedLM,
        tokenizer: AutoTokenizer,
        dataset: DatasetDict,
        test_set: str, plus=False, equals=False) -> float:
    """
    Calculate the accuracy of the model on the test dataset.
    :param model: the model
    :param tokenizer: the tokenizer
    :param dataset: the dataset
    :param test_set: the test set
    :param plus: whether to replace '+' with 'plus'
    :param equals: whether to replace '=' with 'equals'
    :return: the accuracy
    """

    accuracy = 0

    # loop through the test dataset
    for row in dataset[test_set]:
        row['masked'] = row['masked'].replace(".", " .")
        row['unmasked'] = row['unmasked'].replace(".", " .")
        if plus:
            # reoplace '+' with 'plus'
            row['masked'] = row['masked'].replace("+", "plus")
        if equals:
            # replace '=' with 'equals'
            row['masked'] = row['masked'].replace("=", "equals")

        # get the index of the masked token
        print(row['masked'])
        idx = row['masked'].split().index("[MASK]")
        label = row['unmasked'].split()[idx]
        print(idx)
        print(label)

        # tokenize the text
        inputs = tokenizer(row['masked'], return_tensors="pt")

        # get the model outputs
        outputs = model(**inputs)

        # get the predicted token
        predictions = torch.argmax(outputs.logits, dim=-1)
        masked_idx = torch.where(inputs.input_ids == tokenizer.mask_token_id)[1]
        correct_token_id = tokenizer.convert_tokens_to_ids(label)

        print(predictions[0, masked_idx], correct_token_id)
        print(predictions[0, masked_idx] == correct_token_id)
        print(accuracy)
        # check if the predicted token is correct
        accuracy += (predictions[0, masked_idx] == correct_token_id).item()
        print(accuracy)
        print('\n')

    return 100 * accuracy / len(dataset[test_set])

In [None]:
acc(model, tokenizer, test_dataset, test_set="test")

In [None]:
# test set accuracy:       40.00% - 44.00% - 48.00% - 52.00% - 52.00% - 52.00% - 64.00% - 56.00% - 60.00% - 60.00% - 60.00% - 60.00% - 60.00% - 64.00%
#                          64.00% - 60.00% - 60.00% - 60.00% - 60.00% - 60.00%
# train set accuracy:      39.02% - 38.80% - 39.47% - 39.02% - 39.69% - 40.13% - 40.80% - 40.58% - 41.91% - 41.91% - 40.80% - 40.80% - 40.58% - 41.46%
#                          41.24% - 40.80% - 39.69% - 40.13% - 39.25% - 39.47%
# validation set accuracy: 29.17% - 25.00% - 29.17% - 33.33% - 29.17% - 20.83% - 33.33% - 33.33% - 33.33% - 33.33% - 33.33% - 29.17% - 29.17% - 37.50%
#                          37.50% - 33.33% - 25.00% - 20.83% - 25.00% - 25.00%
train_acc = accuracy(model, tokenizer, dataset, test_set="train")
val_acc = accuracy(model, tokenizer, dataset, test_set="validation")
# test_acc = accuracy(model, tokenizer, dataset, test_set="test")

print(f"train, validation and test accuracies are respectively: {train_acc}, {val_acc}, and {train_acc}.")
# lr 7, 5, 4
# 66.11 - 58.33, 66.25 - 58.33, 66.25 - 58.33

In [None]:
accs = []
for i in range(20):
    model = fine_tune(model, tokenizer, tokenized_dataset)
    acc = accuracy(model, tokenizer, dataset, test_set="test")
    accs.append(acc)
    print(f"Accuracy: {acc}")

for distilbert on 1d addition by random, 200 total sample with batch_size of 1:
    test:  38.000 + 47.500 + 51.500 + 54.500 + 57.500 + 60.000 + 63.500 - 62.000 - 61.500 - 60.000
    train: 42.105 + 51.579 + 56.842 + 57.105 + 62.763 + 63.684 + 66.447 - 65.789 + 66.447 - 65.263
    valid: 32.500 + 50.000 + 55.000 + 60.000 + 70.000 + 72.500 + 75.000 - 72.500 > 72.500 - 70.000

for distilbert on 1d addition, 100 total sample with batch_size of 1 and weight_decay of 0.01:

12, 11, 15, 16, 18, 19, 20, 17, 17, 17, 16, 19, 17, 15, 19, 17, 15, 18, 19, 20, 20, 20

for distilbert on 1d addition, 100 total sample with batch_size of 8 and weight_decay of 0, learning rate of 2e-5:
11, 16, 18, 18, 19, 21, 21, 21, 21, 21, 22, 23, 23, 25, 19, 20, 28, 29, 29, 24

for distilbert on 1d addition, 100 total sample with batch_size of 8 and weight_decay of 0 and learning rate of 5e-5:
14, 23, 19, 20, 16, 16, 20, 23, 22, 27, 27, 29, 27, 30, 27, 27, 27, 28, 29, 27

for distilbert on 1d addition, 100 total sample with batch_size of 8 and weight_decay of 0 and learning rate of 7e-5:
13, 21, 18, 16, 19, 17, 26, 28, 29, 29, 23, 33, 27, 33, 32, 32, 29, 33, 31, 31

In [None]:
tokenizer.special_tokens_map

In [None]:
tokenizer.all_special_tokens

In [None]:
tokenizer.all_special_ids

In [None]:
tokenizer.get_special_tokens_mask([100, 102, 0, 101, 103, 110, 111, 245], already_has_special_tokens=True)

In [None]:
tokenizer._pad_token

In [None]:
# encode a sequence
sequence = ['2 + 4 = 6.']
labels = tokenizer(sequence, padding=True, max_length=16, truncation=True, return_tensors="pt")['input_ids'][0]
print(labels)
print(tokenizer.convert_ids_to_tokens(labels))
print(torch.full(labels.shape, 0.15))
print(labels.tolist())
print([val for val in labels.tolist()])
print(tokenizer.get_special_tokens_mask([100, 102, 0, 101, 103, 110, 111, 245], already_has_special_tokens=True))
# special_tokens_mask = [tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()]

In [None]:
[tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in tokenized_dataset['test'].to_list()]

In [None]:
class DataCollatorForLanguageModeling(DataCollatorMixin):
        return inputs, labels



In [None]:
import torch

def show_example(i):
    # tokenize the text
    row = onedigit_addition_test_dataset['test'][i]

    # print(row['label'])
    # print(row['text'])
    
    idx = row['masked'][:-1].split().index("[MASK]")
    label = row['unmasked'].split()[idx]
    inputs = tokenizer(row['masked'], return_tensors="pt")

    # get the model outputs
    outputs = model(**inputs)

    # get the predicted token
    predictions = torch.argmax(outputs.logits, dim=-1)
    masked_idx = torch.where(inputs.input_ids == tokenizer.mask_token_id)[1]
    correct_token_id = tokenizer.convert_tokens_to_ids(label)

    # print(predictions[0, masked_idx])
    # print(correct_token_id)
    # print the decoded token
    # print(tokenizer.decode(predictions[0, masked_idx]))

    # replace `[MASK]` with `tokenizer.decode(predictions[0, masked_idx])` in the `text`
    text = row['masked'].replace("[MASK]", tokenizer.decode(predictions[0, masked_idx]))

    if text != row['unmasked']:
        print(row['unmasked'])
        print(text)
        print(row['masked'])

In [None]:
i = 0

In [None]:
for i in range(len(onedigit_addition_test_dataset['test'])):
    show_example(i)

4 + [MASK] = 7. (5)
[MASK] + 8 = 10. (3)
5 + 7 = [MASK]. (14)
9 + 6 = [MASK]. (17)
5 + 4 = [MASK]. (11)
8 + [MASK] = 10. (3)
9 + 6 = [MASK]. (17)
4 + [MASK] = 7. (5)


In [None]:
def search_dataset(dataset, text, subset='train'):
    count = 0
    for i in range(len(dataset[subset])):
        if dataset[subset][i]['label'] == text:
            count += 1
    return count

In [None]:
train_heatmap = []
for i in range(10):
    train_heatmap.append([search_dataset(dataset, f"{i} + {j} = {i+j}.") for j in range(10)])

test_heatmap = []
for i in range(10):
    test_heatmap.append([search_dataset(dataset, f"{i} + {j} = {i+j}.", subset='test') for j in range(10)])

In [None]:
# plt the heatmap of the train and test dataset side by side
import matplotlib.pyplot as plt
import seaborn as sns

fig, ax = plt.subplots(1, 2, figsize=(10, 4))
sns.heatmap(train_heatmap, annot=True, ax=ax[0])
ax[0].set_title('Train Dataset')
sns.heatmap(test_heatmap, annot=True, ax=ax[1])
ax[1].set_title('Test Dataset')
plt.show()

In [None]:
model, tokenizer = load_model(model_name)

token_length(tokenizer, 10000, 1, number_floor=1001, ignore=[342, 346, 347, 348, 349])

# 342, 346, 3461

## `tokenize_label`: tokenizes the 'label' column <a class="anchor" id="tokenize_label"></a>

In [None]:
arithmetic_dataset = load_dataset(name="arithmetic_dataset_100_50")

sample = arithmetic_dataset["train"].shuffle(seed=42).select(range(3))

for row in sample:
    print(f"\n'>>> Masked: {row['text']}'")
    print(f"'>>> Unmasked: {row['label']}'")

In [None]:
type(arithmetic_dataset)

## Load the model and tokenizer

In [None]:
# model_name = 'arithmetic_model_100_50_4'

In [None]:
model, tokenizer = load_model(model_name)

## load the dataset

In [None]:
# set Hugging Face Hub user and password
from huggingface_hub import notebook_login

notebook_login()  # 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

In [None]:
# push the trained model to the Hugging Face Hub
model.push_to_hub("arithmetic_model")

In [None]:
# push the tokenizer to the Hugging Face Hub
tokenizer.push_to_hub("arithmetic_model")

In [None]:
import re


def convert_2d_num_to_1d(lst):
    """
    convert a 2d numerals in a list to 1d numerals
    by adding spaces between the numerals
    if the numeral is 1d add 0 as the first character
    :param lst: a list of numerals
    :return: a list of numerals
    """

    new_lst = []
    for txt in lst:
        # in case of 1d numeral add 0 as the first character
        txt = re.sub(r"\b(\d)\b", r"0 \1", txt)
        txt = re.sub(r"\b(\d)(\d)\b", r"\1 \2", txt)
        new_lst.append(txt)
    
    return new_lst

In [None]:
convert_2d_num_to_1d(["1 + 12 = 13.", "2 + 3 = 5."])

In [None]:
tokenizer('0 1 2 3 4 5 6 7 8 9', return_tensors="pt")

In [None]:
mi

In [None]:
import torch

# tokenize the text
text = "2 5 - 2 9 = [MASK] [MASK] ."

# idx = text.split().index("[MASK]")
# label = row['unmasked'].split()[idx]
inputs = tokenizer(text, return_tensors="pt")

# get the model outputs
outputs = model(**inputs)

# print(outputs)

# get the predicted token
predictions = torch.argmax(outputs.logits, dim=-1)

print(predictions)
print(text)
decoded = [tokenizer.decode(predictions[0][i]) for i in range(len(predictions[0]))]
print(decoded)


In [None]:
t1 = load_dataset('datasets/test1')

In [None]:
t2 = load_dataset('datasets/test2')

In [None]:
from datasets import concatenate_datasets

# Assuming dd1 and dd2 are your DatasetDict objects
t1 = {split: concatenate_datasets([t1[split], t2[split]])
                     for split in t1.keys()}


In [6]:
import torch

def show_example(i):
    # tokenize the text
    row = ds['test'][i]

    # print(row['label'])
    # print(row['text'])
    
    idx = row['masked'][:-1].split().index("[MASK]")
    label = row['unmasked'].split()[idx]
    inputs = tokenizer(row['masked'], return_tensors="pt")

    # get the model outputs
    outputs = model(**inputs)

    # get the predicted token
    predictions = torch.argmax(outputs.logits, dim=-1)
    masked_idx = torch.where(inputs.input_ids == tokenizer.mask_token_id)[1]
    correct_token_id = tokenizer.convert_tokens_to_ids(label)

    # print(predictions[0, masked_idx])
    # print(correct_token_id)
    # print the decoded token
    # print(tokenizer.decode(predictions[0, masked_idx]))

    # replace `[MASK]` with `tokenizer.decode(predictions[0, masked_idx])` in the `text`
    text = row['masked'].replace("[MASK]", tokenizer.decode(predictions[0, masked_idx]))

    if text != row['unmasked']:
        print(row['unmasked'])
        print(text)
        print(row['masked'])

In [58]:
inputs = tokenizer("1 1 9 - 1 1 = [MASK].", return_tensors="pt")
print(inputs)
outputs = model(**inputs)
pred = torch.argmax(outputs.logits, dim=-1)
print(pred)
tokenizer.decode(pred[0, -3:-2])

{'input_ids': tensor([[ 101, 1015, 1015, 1023, 1011, 1015, 1015, 1027,  103, 1012,  102]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
tensor([[1012, 1015, 1015, 1023, 1011, 1015, 1015, 1014, 1020, 1012, 1015]])


'6'

In [22]:
[tokenizer.convert_tokens_to_ids(int(token)) for token in pred]

TypeError: 'int' object is not iterable