# Import necessary libraries

In [1]:
import torch

In [2]:
from datasets import load_dataset

In [1]:
from transformers import AutoModel, AutoTokenizer, AdamW, AutoConfig

# Download datasets

In [4]:
mnli_path = None

In [5]:
if mnli_path == None:
    mnli = load_dataset("multi_nli")
    mnli_path = "./datasets/mnli"
    mnli.save_to_disk(mnli_path)

Using custom data configuration default
Found cached dataset multi_nli (/home/yz709/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39)


  0%|          | 0/3 [00:00<?, ?it/s]

In [6]:
qnli_path = None

In [7]:
if qnli_path == None:
    qnli = load_dataset("SetFit/qnli")
    qnli_path = "./datasets/qnli"
    qnli.save_to_disk(qnli_path)

Using custom data configuration SetFit--qnli-324fd6914ad1beff
Found cached dataset json (/home/yz709/.cache/huggingface/datasets/SetFit___json/SetFit--qnli-324fd6914ad1beff/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)


  0%|          | 0/3 [00:00<?, ?it/s]

# Inspect datasets

## For qnli dataset

In [8]:
qnli

DatasetDict({
    train: Dataset({
        features: ['text1', 'text2', 'label', 'idx', 'label_text'],
        num_rows: 104743
    })
    test: Dataset({
        features: ['text1', 'text2', 'label', 'idx', 'label_text'],
        num_rows: 5463
    })
    validation: Dataset({
        features: ['text1', 'text2', 'label', 'idx', 'label_text'],
        num_rows: 5463
    })
})

In [9]:
qnli_train = qnli['train']
qnli_test = qnli['test']
qnli_val = qnli['validation']

In [10]:
# label 0 - entailment, label 1 - not entailment
qnli_train[-2:]

{'text1': ['What individual was responsible for law and maintaining order in the county?',
  'How much of the gross domestic product was spent on public health in 2004?'],
 'text2': ['He was the top civil and military leader of the commandery and handled defense, lawsuits, seasonal instructions to farmers and recommendations of nominees for office sent annually to the capital in a quota system first established by Emperor Wu.',
  'Public expenditure health was at 8.9% of the GDP in 2004, whereas private expenditure was at 1.3%.'],
 'label': [1, 0],
 'idx': [104741, 104742],
 'label_text': ['not entailment', 'entailment']}

In [11]:
qnli_train['text1'][-2:]

['What individual was responsible for law and maintaining order in the county?',
 'How much of the gross domestic product was spent on public health in 2004?']

## For mnli dataset

In [12]:
mnli

DatasetDict({
    train: Dataset({
        features: ['promptID', 'pairID', 'premise', 'premise_binary_parse', 'premise_parse', 'hypothesis', 'hypothesis_binary_parse', 'hypothesis_parse', 'genre', 'label'],
        num_rows: 392702
    })
    validation_matched: Dataset({
        features: ['promptID', 'pairID', 'premise', 'premise_binary_parse', 'premise_parse', 'hypothesis', 'hypothesis_binary_parse', 'hypothesis_parse', 'genre', 'label'],
        num_rows: 9815
    })
    validation_mismatched: Dataset({
        features: ['promptID', 'pairID', 'premise', 'premise_binary_parse', 'premise_parse', 'hypothesis', 'hypothesis_binary_parse', 'hypothesis_parse', 'genre', 'label'],
        num_rows: 9832
    })
})

In [13]:
mnli_train = mnli['train']
mnli_val_match = mnli['validation_matched']
mnli_val_mismatch = mnli['validation_mismatched']

In [14]:
mnli_train

Dataset({
    features: ['promptID', 'pairID', 'premise', 'premise_binary_parse', 'premise_parse', 'hypothesis', 'hypothesis_binary_parse', 'hypothesis_parse', 'genre', 'label'],
    num_rows: 392702
})

In [15]:
mnli_train[0]

{'promptID': 31193,
 'pairID': '31193n',
 'premise': 'Conceptually cream skimming has two basic dimensions - product and geography.',
 'premise_binary_parse': '( ( Conceptually ( cream skimming ) ) ( ( has ( ( ( two ( basic dimensions ) ) - ) ( ( product and ) geography ) ) ) . ) )',
 'premise_parse': '(ROOT (S (NP (JJ Conceptually) (NN cream) (NN skimming)) (VP (VBZ has) (NP (NP (CD two) (JJ basic) (NNS dimensions)) (: -) (NP (NN product) (CC and) (NN geography)))) (. .)))',
 'hypothesis': 'Product and geography are what make cream skimming work. ',
 'hypothesis_binary_parse': '( ( ( Product and ) geography ) ( ( are ( what ( make ( cream ( skimming work ) ) ) ) ) . ) )',
 'hypothesis_parse': '(ROOT (S (NP (NN Product) (CC and) (NN geography)) (VP (VBP are) (SBAR (WHNP (WP what)) (S (VP (VBP make) (NP (NP (NN cream)) (VP (VBG skimming) (NP (NN work)))))))) (. .)))',
 'genre': 'government',
 'label': 1}

In [16]:
mnli_train[3]

{'promptID': 37397,
 'pairID': '37397e',
 'premise': 'How do you know? All this is their information again.',
 'premise_binary_parse': '( ( How ( ( ( do you ) know ) ? ) ) ( ( All this ) ( ( ( is ( their information ) ) again ) . ) ) )',
 'premise_parse': '(ROOT (S (SBARQ (WHADVP (WRB How)) (SQ (VBP do) (NP (PRP you)) (VP (VB know))) (. ?)) (NP (PDT All) (DT this)) (VP (VBZ is) (NP (PRP$ their) (NN information)) (ADVP (RB again))) (. .)))',
 'hypothesis': 'This information belongs to them.',
 'hypothesis_binary_parse': '( ( This information ) ( ( belongs ( to them ) ) . ) )',
 'hypothesis_parse': '(ROOT (S (NP (DT This) (NN information)) (VP (VBZ belongs) (PP (TO to) (NP (PRP them)))) (. .)))',
 'genre': 'fiction',
 'label': 0}

# Preprocess dataset QNLI

In [17]:
# roberta-base or roberta-large
PLM = "roberta-base"

In [18]:
# load model
model = AutoModel.from_pretrained(PLM)

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.dense.bias', 'lm_head.bias', 'lm_head.layer_norm.weight', 'lm_head.decoder.weight', 'lm_head.dense.weight', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [19]:
# load tokeniser for the specific model
tokeniser = AutoTokenizer.from_pretrained(PLM)

In [20]:
# test the tokeniser
test_input = "Simple text to test -- tokenizer [roberta-base]"
tokenised_test_input = tokeniser.encode(test_input, return_tensors="pt")
print(tokenised_test_input)

tensor([[    0, 45093,  2788,     7,  1296,   480, 19233,  6315,   646,  1001,
          6747,   102,    12, 11070,   742,     2]])


In [21]:
test_token_id_to_text = tokeniser.convert_ids_to_tokens(tokeniser.encode(test_input))
print(test_token_id_to_text)

['<s>', 'Simple', 'Ġtext', 'Ġto', 'Ġtest', 'Ġ--', 'Ġtoken', 'izer', 'Ġ[', 'ro', 'bert', 'a', '-', 'base', ']', '</s>']


In [22]:
# test the model
test_output = model(tokenised_test_input)

In [23]:
"""
First output:
# (1,16,768): batch_size * #tokens * embedding_size_defined_by_model
# our input is one sentence with 16 tokens
Second output - pooler output, the embedding result of the first token of the sequence <s>
"""
test_output[0].size(), test_output[1].size()

(torch.Size([1, 16, 768]), torch.Size([1, 768]))

In [38]:
def manual_prompt(questions, answers, template):
    prompt = questions + template + answers
    prompt_token_id = tokeniser(prompt, return_tensors="pt").input_ids
    prompt_id_to_text = tokeniser.convert_ids_to_tokens(prompt_token_id[0])
    mask_token_pos = prompt_id_to_text.index(tokeniser.mask_token)
    return prompt_token_id, mask_token_pos

In [34]:
manual_template = " <mask>. "
prompt_token_id, mask_token_pos = manual_prompt(qnli_train[0]['text1'], qnli_train[0]['text2'], manual_template)

When did the third Digimon series begin? <mask>. Unlike the two seasons before it and most of the seasons that followed, Digimon Tamers takes a darker and more realistic approach to its story featuring Digimon who do not reincarnate after their deaths and more complex character development in the original Japanese.


In [26]:
model.eval()
with torch.no_grad():
    predictions = model(prompt_token_id)[0]
print(predictions.size())

torch.Size([1, 68, 768])


In [27]:
values, indices = torch.sort(predictions[0, mask_token_pos], descending=True)

In [28]:
result = list(zip(tokeniser.convert_ids_to_tokens(indices), values))

In [29]:
result[0]

('Ġreal', tensor(10.3457))

In [55]:
result

[('Ġreal', tensor(10.3457)),
 ('k', tensor(0.9194)),
 ('Ġcountries', tensor(0.6951)),
 ('Ġhere', tensor(0.6229)),
 ('Ġwith', tensor(0.5075)),
 ('b', tensor(0.5044)),
 ('Ġ17', tensor(0.4998)),
 ('Ġsome', tensor(0.4802)),
 ('Ġgroup', tensor(0.4780)),
 ('Ġdifferent', tensor(0.4544)),
 ('Ġpublic', tensor(0.4540)),
 ('<unk>', tensor(0.4357)),
 ('os', tensor(0.4248)),
 ('Ġhis', tensor(0.4180)),
 ('Ġwin', tensor(0.4158)),
 ('Ġearly', tensor(0.4154)),
 ('n', tensor(0.4143)),
 ('Ġboth', tensor(0.3981)),
 ('Ġfeel', tensor(0.3941)),
 ('ĠB', tensor(0.3893)),
 ('l', tensor(0.3667)),
 ('Ġthe', tensor(0.3628)),
 ('ĠHowever', tensor(0.3595)),
 ('N', tensor(0.3575)),
 ('Ġtogether', tensor(0.3547)),
 ('ĠG', tensor(0.3540)),
 ('Ġ3', tensor(0.3531)),
 ('Ġits', tensor(0.3499)),
 ('Ġface', tensor(0.3449)),
 ('.', tensor(0.3404)),
 ('Ġforward', tensor(0.3401)),
 ('Ġ21', tensor(0.3373)),
 ('Ġc', tensor(0.3352)),
 ('Ġcame', tensor(0.3332)),
 ('Ġsix', tensor(0.3330)),
 ('Ġsays', tensor(0.3226)),
 ('Ġschool', te

In [30]:
qnli_train[0]['label']

1

In [53]:
def fine_tune(questions, answers, labels, template="? <mask>, ", ent_token="Yes", not_ent_token="No"):
    ent_id = tokeniser.convert_tokens_to_ids(ent_token)
    not_ent_id = tokeniser.convert_tokens_to_ids(not_ent_token)
    
    optimiser = AdamW(model.parameters(), lr=1e-3)
    for q,a,label in zip(questions, answers, labels):
        prompt_token_id, mask_token_pos = manual_prompt(q,a,template)
        pred = model(prompt_token_id)[0]
        pred_prob = pred[0, mask_token_pos][[ent_id, not_ent_id]]
        prob = torch.nn.functional.softmax(pred_prob, dim=0)
        lossFunc = torch.nn.CrossEntropyLoss()
        loss=lossFunc(prob.unsqueeze(0), torch.tensor([label]))
        loss.backward()
        optimiser.step()

In [54]:
fine_tune(qnli_train['text1'], qnli_train['text2'], qnli_train['label'])

9904 3084
torch.Size([768])


IndexError: index 9904 is out of bounds for dimension 0 with size 768