In [1]:
import pandas as pd
from datasets import Dataset, load_dataset
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from pathlib import Path
from metric import compute_metrics

In [2]:
model_name = 'gpt2'
path_to_model = Path('./model')
model = GPT2LMHeadModel.from_pretrained(path_to_model)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

In [3]:
with open('./1_test_str_prompt.txt', 'r', encoding='utf-8') as file:
    test_prompt = file.readlines()

with open('./1_test_str_target.txt', 'r', encoding='utf-8') as file:
    test_target = file.readlines()
    
clean_test_data = {'prompt': [doc.replace('\t', '\n')[:-2] for doc in test_prompt],
                    'target': [doc.replace('\t', '\n')[:-1] for doc in test_target]}

test_df = pd.DataFrame(clean_test_data)
test_dataset = Dataset.from_pandas(test_df)

In [4]:
test_dataset[1]


{'prompt': 'Table: | Description Losses 1939/40 1940/41 1941/42 1942/43 1943/44 1944/45 Total | Direct War Losses 360,000     183,000 543,000 | Murdered 75,000 100,000 116,000 133,000 82,000  506,000 | Deaths In Prisons & Camps 69,000 210,000 220,000 266,000 381,000  1,146,000 | Deaths Outside of Prisons & Camps  42,000 71,000 142,000 218,000  473,000 | Murdered in Eastern Regions      100,000 100,000 | Deaths other countries       2,000 | Total 504,000 352,000 407,000 541,000 681,000 270,000 2,770,000 |\nQuestion: how many people were murdered in 1940/41?\nAnswers:',
 'target': ' | 100,000 |'}

In [5]:
tokenized = tokenizer(test_dataset[15]["prompt"], truncation=True, return_tensors="pt")
tokenized

{'input_ids': tensor([[10962,    25,   930,  7536,  9385,  3471,     2, 10916,     2, 14413,
          3195, 25414, 46502,   590,   930,  2693,   513, 11287,  1906,  1925,
          1078, 42165,     9,  1303,  1157, 15778,  7663,  5595, 18899,    11,
          8355,   220,   370,  1849,  3682,  1906,  1485,  9415,    11, 14454,
           930,  2693,   838, 38684,  1303,  1157, 16754,  1906,    35, 11870,
         10499,  5595,   309, 16241,  7335,  8546,    11,  8355,   449,  3705,
           370,  1849,  1558,  1906,    22,  4317,    11, 10163,   930,  2693,
          1596,   379,  1849, 42007,  6618,  1303,  1065, 36065,  1891, 10499,
          5595,   376, 27067,  4244,    11,  5923,  9738,   370,  1849,  1485,
          1906,    21,  6740,    11, 49352,   930,  2693,  1987, 30941,  1531,
             9,  1303,  1157, 15778,  7663,  5595, 18899,    11,  8355,   220,
           370,  1849,  1238,  1906,   940,  9773,    11, 46636,   930,  3267,
           352,  7859,  1303,  1157, 1

In [6]:
generated_ids = model.generate(**tokenized, do_sample=False, max_new_tokens=100, repetition_penalty=5.0)

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


In [7]:
print(generated_ids)

tensor([[10962,    25,   930,  7536,  9385,  3471,     2, 10916,     2, 14413,
          3195, 25414, 46502,   590,   930,  2693,   513, 11287,  1906,  1925,
          1078, 42165,     9,  1303,  1157, 15778,  7663,  5595, 18899,    11,
          8355,   220,   370,  1849,  3682,  1906,  1485,  9415,    11, 14454,
           930,  2693,   838, 38684,  1303,  1157, 16754,  1906,    35, 11870,
         10499,  5595,   309, 16241,  7335,  8546,    11,  8355,   449,  3705,
           370,  1849,  1558,  1906,    22,  4317,    11, 10163,   930,  2693,
          1596,   379,  1849, 42007,  6618,  1303,  1065, 36065,  1891, 10499,
          5595,   376, 27067,  4244,    11,  5923,  9738,   370,  1849,  1485,
          1906,    21,  6740,    11, 49352,   930,  2693,  1987, 30941,  1531,
             9,  1303,  1157, 15778,  7663,  5595, 18899,    11,  8355,   220,
           370,  1849,  1238,  1906,   940,  9773,    11, 46636,   930,  3267,
           352,  7859,  1303,  1157, 16754,  1906,  

In [8]:
answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=tokenizer.eos_token)

In [9]:
answer[0]

'Table: | Date Opponent# Rank# Site TV Result Attendance | September 3 Tennessee–Chattanooga* #11 Legion Field • Birmingham, AL  W\xa042–13 82,109 | September 10 Vanderbilt #11 Bryant–Denny Stadium • Tuscaloosa, AL JPS W\xa017–7 70,123 | September 17 at\xa0Arkansas #12 Razorback Stadium • Fayetteville, AR ABC W\xa013–6 52,089 | September 24 Tulane* #11 Legion Field • Birmingham, AL  W\xa020–10 81,421 | October 1 Georgia #11 Bryant–Denny Stadium • Tuscaloosa, AL ESPN W\xa029–28 70,123 | October 8 Southern Miss* #11 Bryant–Denny Stadium • Tuscaloosa, AL  W\xa014–6 70,123 | October 15 at\xa0Tennessee #10 Neyland Stadium • Knoxville, TN (Third Saturday in October) ESPN W\xa017–13 96,856 | October 22 Ole Miss #8 Bryant–Denny Stadium • Tuscaloosa, AL (Rivalry) ABC W\xa021–10 70,123 | November 5 at\xa0LSU #6 Tiger Stadium • Baton Rouge, LA (Rivalry) ESPN W\xa035–17 75,453 | November 12 at\xa0#20\xa0Mississippi State #6 Scott Field • Starkville, MS (Rivalry) ABC W\xa029–25 41,358 | November 19

In [10]:
test_dataset[15]["target"]

' | 68 |'

### Cooked data test

In [11]:
prediction = ['17 years', '12', 'Italy', 'United States', '1979', '737522']
label = ['17 years', '15', 'German', 'United States of America', '1979', '15']

In [12]:
compute_metrics(prediction, label)

exact_matches: [1, 0, 0, 0, 1, 0]
character_diff: [1.0, 0.5, 0.18181818181818182, 0.7027027027027027, 1.0, 0.25]
f1_scores: [1.0, 0.0, 0.0, 0.6666666666666666, 1.0, 0.0]
bleu_scores: [0.316227766016838, 0, 0, 0.11633369384516798, 0.1778279410038923, 0]
meteor_scores: [0.9375, 0.0, 0.0, 0.4934210526315789, 0.5, 0.0]


{'exact_match': np.float64(0.3333333333333333),
 'character_diff': np.float64(0.6057534807534808),
 'f1': np.float64(0.4444444444444444),
 'bleu': np.float64(0.10173156681098305),
 'meteor': np.float64(0.3218201754385965)}

In [13]:
prediction[0].split(' ')

['17', 'years']

In [14]:
prediction[1].split(' ')

['12']