In [1]:
from collections import Counter, OrderedDict
import numpy as np
from pathlib import Path
import pandas as pd
import sys
import textwrap
from tqdm import tqdm
import torch
from typing import Dict, List, Set, Union

import evaluate
from transformers import Seq2SeqTrainingArguments
import bert_score

import unicodedata
import uuid

from aic_nlp_utils.batch import batch_apply
from aic_nlp_utils.encoding import nfc
from aic_nlp_utils.json import read_jsonl, read_json, write_json, write_jsonl
from aic_nlp_utils.fever import fever_detokenize, import_fever_corpus_from_sqlite
from simpletransformers.seq2seq import Seq2SeqModel, Seq2SeqArgs
import stanza
# stanza.download("en")

from zshot_fact_verify.models.arguments import ModelArguments, DataTrainingArguments
from zshot_fact_verify.models.load import load_tokenizer_and_model, find_last_checkpoint

%load_ext autoreload
%autoreload 2

# Combine and Fix QA2D Datasets

In [8]:
def combine_qa2ds(split_files, out_file, seed=1234):
    rng = np.random.RandomState(seed)
    data = []
    for sfile in split_files:
        assert len(sfile.items()) == 1
        for lang, path_ in sfile.items():
            pass
        split = read_jsonl(path_)
        for s in split:
            ta = s["turker_answer"]
            # temporal fix for lower case starting letters, should be fixed in all individual datasets
            # see below :)
            s["turker_answer"] = ta[:1].upper() + ta[1:]
            s["lang"] = lang
        data += split
    rng.shuffle(data)
    print(f"writing {len(data)} records to {out_file}")
    write_jsonl(out_file, data, mkdir=True)

combine_qa2ds([
    {"cs": "/mnt/data/factcheck/qa2d/cs/dev.jsonl"},
    {"en": "/mnt/data/factcheck/qa2d/en/dev.jsonl"},
    {"pl": "/mnt/data/factcheck/qa2d/pl/dev.jsonl"},
    {"sk": "/mnt/data/factcheck/qa2d/sk/dev.jsonl"}], 
    "/mnt/data/factcheck/qa2d/cs_en_pl_sk/dev.jsonl")

combine_qa2ds([
    {"cs": "/mnt/data/factcheck/qa2d/cs/train.jsonl"},
    {"en": "/mnt/data/factcheck/qa2d/en/train.jsonl"},
    {"pl": "/mnt/data/factcheck/qa2d/pl/train.jsonl"},
    {"sk": "/mnt/data/factcheck/qa2d/sk/train.jsonl"}], 
    "/mnt/data/factcheck/qa2d/cs_en_pl_sk/train.jsonl")

writing 41376 records to /mnt/data/factcheck/qa2d/cs_en_pl_sk/dev.jsonl
writing 242840 records to /mnt/data/factcheck/qa2d/cs_en_pl_sk/train.jsonl


In [20]:
def fix_qa2ds(split_files):
    # answers sometimes started with lowe-case letters
    for sfile in split_files:
        lines = read_jsonl(sfile)
        assert not Path(sfile + ".orig").is_file(), f"already exists: '{sfile}.orig'"
        Path(sfile).rename(sfile + ".orig") # backup
        for r in lines:
            rb = r["rule-based"]
            ta = r["turker_answer"]
            r["rule-based"] = rb[:1].upper() + rb[1:]
            r["turker_answer"] = ta[:1].upper() + ta[1:]
        print(f"writing {len(lines)} records to {sfile}")
        write_jsonl(sfile, lines, mkdir=True)

fix_qa2ds([
    "/mnt/data/factcheck/qa2d/cs/dev.jsonl",
    "/mnt/data/factcheck/qa2d/en/dev.jsonl",
    "/mnt/data/factcheck/qa2d/pl/dev.jsonl",
    "/mnt/data/factcheck/qa2d/sk/dev.jsonl", 
    "/mnt/data/factcheck/qa2d/cs/train.jsonl",
    "/mnt/data/factcheck/qa2d/en/train.jsonl",
    "/mnt/data/factcheck/qa2d/pl/train.jsonl",
    "/mnt/data/factcheck/qa2d/sk/train.jsonl", 
])

writing 10344 records to /mnt/data/factcheck/qa2d/cs/dev.jsonl
writing 10344 records to /mnt/data/factcheck/qa2d/en/dev.jsonl
writing 10344 records to /mnt/data/factcheck/qa2d/pl/dev.jsonl
writing 10344 records to /mnt/data/factcheck/qa2d/sk/dev.jsonl
writing 60710 records to /mnt/data/factcheck/qa2d/cs/train.jsonl
writing 60710 records to /mnt/data/factcheck/qa2d/en/train.jsonl
writing 60710 records to /mnt/data/factcheck/qa2d/pl/train.jsonl
writing 60710 records to /mnt/data/factcheck/qa2d/sk/train.jsonl


# Test Models

In [4]:
MODEL_NAME_ALL = "/mnt/personal/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qa2d/google/umt5-base_all"
MODEL_NAME_ALL2 = "/mnt/personal/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qa2d/google/mt5-large_all/checkpoint-156000"
MODEL_NAME_CS1 = "/home/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qa2d/google/mt5-large_cs_CZ/checkpoint-76000"
MODEL_NAME_CS2 = "/home/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qa2d/facebook/mbart-large-cc25_cs_CZ/checkpoint-26000"
MODEL_NAME_EN1 = "/home/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qa2d/facebook/mbart-large-cc25_en_US/checkpoint-30000"
MODEL_NAME_EN2 = "/home/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qa2d/google/mt5-large_en_US/checkpoint-94000"
MODEL_NAME_PL1 = "/mnt/personal/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qa2d/google/mt5-large_pl_PL/checkpoint-85000"
MODEL_NAME_SK1 = "/mnt/personal/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qa2d/google/mt5-large_sk_SK/checkpoint-85000"

DEV_FILE_ALL = "/mnt/data/factcheck/qa2d/cs_en_pl_sk/dev.jsonl"
DEV_FILE_CS = "/mnt/data/factcheck/qa2d/cs/dev.jsonl"
DEV_FILE_EN = "/mnt/data/factcheck/qa2d/en/dev.jsonl"
DEV_FILE_PL = "/mnt/data/factcheck/qa2d/pl/dev.jsonl"
DEV_FILE_SK = "/mnt/data/factcheck/qa2d/sk/dev.jsonl"

In [None]:
from simpletransformers.seq2seq import Seq2SeqModel, Seq2SeqArgs

def predict_split_original(model, data):
    # use batches for faster
    T = []
    Y = []
    X = [nfc(sample["answer"] + "[SEP]" + sample["question"]) for sample in data]
    Y = model.predict(X)
    T = [sample["turker_answer"] for sample in data]
    return Y, T

def evaluate_original_model(cfgs, out_json):
    rouge = evaluate.load("rouge")
    model_args = Seq2SeqArgs()
    model_args.max_length = 64
    original_model = Seq2SeqModel(
                encoder_decoder_type="bart", 
                encoder_decoder_name="/home/drchajan/devel/python/FC/Zero-shot-Fact-Verification/dependencies/QA2D_model",
                cuda_device=0,
                args=model_args
            )

    results = []
    for cfg in cfgs:
        lang = cfg['lang']
        data_file = cfg["data_file"]
        print(f"lang: {lang}, data file: {data_file}")

        data = read_jsonl(data_file)
        print(f"  loaded {len(data)} samples")
        
        Y, T = predict_split_original(original_model, data)
        
        ev = rouge.compute(predictions=Y, references=T)
        bsP, bsR, bsF1 = bert_score.score(Y, T, model_type="bert-base-multilingual-cased")
        ev["bert_score_P"] = bsP.mean().item()
        ev["bert_score_R"] = bsR.mean().item()
        ev["bert_score_F1"] = bsF1.mean().item()
        
        print(f"  EVAL = {ev}")
        print("---------------------------------------")
        res = cfg.copy()
        res["eval"] = ev
        res["Y"] = Y
        res["T"] = T
        results.append(res)
        write_jsonl(out_json, [res], append=True)
    return results

cfgs = [
    {"lang": "cs_CZ", "model": "original", "data_file": DEV_FILE_CS},
    {"lang": "en_US", "model": "original", "data_file": DEV_FILE_EN},
    {"lang": "pl_PL", "model": "original", "data_file": DEV_FILE_PL},
    {"lang": "sk_SK", "model": "original", "data_file": DEV_FILE_SK},
]

results = evaluate_original_model(cfgs, "/mnt/personal/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qa2d/results.jsonl")

In [6]:
def predict(model, tokenizer, inputs, max_source_length=1024, padding=True, device="cuda"):
    model_inputs = tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=True, return_tensors="pt")
    input_ids = model_inputs["input_ids"].to(device)
    attention_mask = model_inputs["attention_mask"].to(device)
    with torch.no_grad():
        Y = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=768)
        predictions = tokenizer.batch_decode(
            Y, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
    return predictions

def predict_split(model, tokenizer, data, batch_size=16):
    # use batches for faster
    T = []
    Y = []
    X = [nfc(sample["answer"] + "</s>" + sample["question"]) for sample in data]
    pfunc = lambda batch: predict(model, tokenizer, batch)
    Y = batch_apply(pfunc, X, batch_size=batch_size, show_progress=True)
    # some turker answers start with lower case letters; ROUGE ignores this but anyway,...
    T = [nfc(sample["turker_answer"][0:1].upper() + sample["turker_answer"][1:]) for sample in data]
    return Y, T

def evaluate_quality(cfgs, out_json):
    rouge = evaluate.load("rouge")
    results = []
    for cfg in cfgs:
        lang = cfg['lang']
        data_file = cfg["data_file"]
        model_name = cfg["model"]
        model_short = "/".join(Path(model_name).parts[8:])
        print(f"lang: {lang}, model: {model_short}, data file: {data_file}")

        data = read_jsonl(data_file)
        print(f"  loaded {len(data)} samples")
        
        model_args = ModelArguments(model_name_or_path=model_name)
        tokenizer, model, data_collator = load_tokenizer_and_model(model_args, lang=lang, fp16=True)
        model.to("cuda")
        model.eval();

        Y, T = predict_split(model, tokenizer, data, batch_size=16)
        ev = rouge.compute(predictions=Y, references=T)

        bsP, bsR, bsF1 = bert_score.score(Y, T, model_type="bert-base-multilingual-cased")
        ev["bert_score_P"] = bsP.mean().item()
        ev["bert_score_R"] = bsR.mean().item()
        ev["bert_score_F1"] = bsF1.mean().item()
        
        print(f"  EVAL = {ev}")
        print("---------------------------------------")
        res = cfg.copy()
        res["eval"] = ev
        res["Y"] = Y
        res["T"] = T
        results.append(res)
        write_jsonl(out_json, [res], append=True)
    return results

In [7]:
cfgs = [
    # {"lang": "cs_CZ", "model": MODEL_NAME_CS1, "data_file": DEV_FILE_CS},
    # {"lang": "cs_CZ", "model": MODEL_NAME_CS2, "data_file": DEV_FILE_CS},
    # {"lang": "en_US", "model": MODEL_NAME_EN1, "data_file": DEV_FILE_EN},
    # {"lang": "en_US", "model": MODEL_NAME_EN2, "data_file": DEV_FILE_EN},
    {"lang": "pl_PL", "model": MODEL_NAME_PL1, "data_file": DEV_FILE_PL},
    {"lang": "sk_SK", "model": MODEL_NAME_SK1, "data_file": DEV_FILE_SK},
]

results = evaluate_quality(cfgs, 
                         "/mnt/personal/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qa2d/results.jsonl")

lang: pl_PL, model: experiments/qa2d/google/mt5-large_pl_PL/checkpoint-85000, data file: /mnt/data/factcheck/qa2d/pl/dev.jsonl
  loaded 10344 samples


  0%|          | 0/647 [00:00<?, ?it/s]

  EVAL = {'rouge1': 0.7671426770279015, 'rouge2': 0.6167445636811792, 'rougeL': 0.7122150031245609, 'rougeLsum': 0.7121087817666008, 'bert_score_P': 0.923403263092041, 'bert_score_R': 0.9220861792564392, 'bert_score_F1': 0.9225190281867981}
---------------------------------------
lang: sk_SK, model: experiments/qa2d/google/mt5-large_sk_SK/checkpoint-85000, data file: /mnt/data/factcheck/qa2d/sk/dev.jsonl
  loaded 10344 samples


  0%|          | 0/647 [00:00<?, ?it/s]

  EVAL = {'rouge1': 0.7876781118384524, 'rouge2': 0.6400520532950926, 'rougeL': 0.7195672844868037, 'rougeLsum': 0.7193888996809157, 'bert_score_P': 0.9229689836502075, 'bert_score_R': 0.9209535717964172, 'bert_score_F1': 0.9217352867126465}
---------------------------------------


In [None]:
cfgs = [
    # {"lang": "all", "model": MODEL_NAME_ALL, "data_file": DEV_FILE_ALL},
    {"lang": "cs_CZ", "model": MODEL_NAME_ALL, "data_file": DEV_FILE_CS},
    {"lang": "en_US", "model": MODEL_NAME_ALL, "data_file": DEV_FILE_EN},
    {"lang": "pl_PL", "model": MODEL_NAME_ALL, "data_file": DEV_FILE_PL},
    {"lang": "sk_SK", "model": MODEL_NAME_ALL, "data_file": DEV_FILE_SK},
]

results = evaluate_quality(cfgs, 
                         "/mnt/personal/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qa2d/results.jsonl")

In [7]:
cfgs = [
    # {"lang": "all", "model": MODEL_NAME_ALL2, "data_file": DEV_FILE_ALL},
    {"lang": "cs_CZ", "model": MODEL_NAME_ALL2, "data_file": DEV_FILE_CS},
    {"lang": "en_US", "model": MODEL_NAME_ALL2, "data_file": DEV_FILE_EN},
    {"lang": "pl_PL", "model": MODEL_NAME_ALL2, "data_file": DEV_FILE_PL},
    {"lang": "sk_SK", "model": MODEL_NAME_ALL2, "data_file": DEV_FILE_SK},
]

results = evaluate_quality(cfgs, 
                         "/mnt/personal/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qa2d/results.jsonl")

lang: cs_CZ, model: experiments/qa2d/google/mt5-large_all/checkpoint-156000, data file: /mnt/data/factcheck/qa2d/cs/dev.jsonl
  loaded 10344 samples


  0%|          | 0/647 [00:00<?, ?it/s]

  EVAL = {'rouge1': 0.7829778956800801, 'rouge2': 0.6310825936825517, 'rougeL': 0.71030435160657, 'rougeLsum': 0.7104467150286194, 'bert_score_P': 0.9223543405532837, 'bert_score_R': 0.920818567276001, 'bert_score_F1': 0.9213548302650452}
---------------------------------------
lang: en_US, model: experiments/qa2d/google/mt5-large_all/checkpoint-156000, data file: /mnt/data/factcheck/qa2d/en/dev.jsonl
  loaded 10344 samples


  0%|          | 0/647 [00:00<?, ?it/s]

  EVAL = {'rouge1': 0.9348856529519217, 'rouge2': 0.8577113688879252, 'rougeL': 0.8872454655274765, 'rougeLsum': 0.887249012321926, 'bert_score_P': 0.9657933115959167, 'bert_score_R': 0.9641131162643433, 'bert_score_F1': 0.9648184776306152}
---------------------------------------
lang: pl_PL, model: experiments/qa2d/google/mt5-large_all/checkpoint-156000, data file: /mnt/data/factcheck/qa2d/pl/dev.jsonl
  loaded 10344 samples


  0%|          | 0/647 [00:00<?, ?it/s]

  EVAL = {'rouge1': 0.7630884193918397, 'rouge2': 0.6101584128950617, 'rougeL': 0.7071874956951201, 'rougeLsum': 0.7073298532364101, 'bert_score_P': 0.9224919676780701, 'bert_score_R': 0.9208790063858032, 'bert_score_F1': 0.9214576482772827}
---------------------------------------
lang: sk_SK, model: experiments/qa2d/google/mt5-large_all/checkpoint-156000, data file: /mnt/data/factcheck/qa2d/sk/dev.jsonl
  loaded 10344 samples


  0%|          | 0/647 [00:00<?, ?it/s]

  EVAL = {'rouge1': 0.7855878556292287, 'rouge2': 0.636764481628465, 'rougeL': 0.7169630516951414, 'rougeLsum': 0.7170890600080113, 'bert_score_P': 0.9222889542579651, 'bert_score_R': 0.9202671051025391, 'bert_score_F1': 0.9210512042045593}
---------------------------------------


In [8]:
def compare_results_qa2d(result_jsonls):
    data = []
    for rjsonl in result_jsonls:
        data += read_jsonl(rjsonl)
    # for d in data:
    #     t = d["eval"]["bert_score_R1"]
    #     d["eval"]["bert_score_R"] = t
    #     del d["eval"]["bert_score_R1"]
    # write_jsonl(result_jsonls[0], data)
    # return
    df = pd.DataFrame(data)
    models = ['/'.join(m.split("/")[-3:]) for m in df.model]
    df["model"] = models
    df["rouge1"] = [100*e["rouge1"] for e in df["eval"]]
    df["rouge2"] = [100*e["rouge2"] for e in df["eval"]]
    df["rougeL"] = [100*e["rougeL"] for e in df["eval"]]
    df["rougeLsum"] = [100*e["rougeLsum"] for e in df["eval"]]
    df["bert_score_P"] = [100*e["bert_score_P"] for e in df["eval"]]
    df["bert_score_R"] = [100*e["bert_score_R"] for e in df["eval"]]
    df["bert_score_F1"] = [100*e["bert_score_F1"] for e in df["eval"]]
    df = df[["lang", "model", "rouge1", "rouge2", "rougeL", "rougeLsum", "bert_score_P", "bert_score_R", "bert_score_F1"]]
    df.sort_values("lang", inplace=True)
    return df

df = compare_results_qa2d([
    "/mnt/personal/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qa2d/results.jsonl"
])

df

Unnamed: 0,lang,model,rouge1,rouge2,rougeL,rougeLsum,bert_score_P,bert_score_R,bert_score_F1
0,cs_CZ,original,64.518723,44.506546,53.023387,53.024811,83.560187,84.120321,83.777118
12,cs_CZ,google/mt5-large_all/checkpoint-156000,78.29779,63.108259,71.030435,71.044672,92.235434,92.081857,92.135483
5,cs_CZ,facebook/mbart-large-cc25_cs_CZ/checkpoint-26000,77.367544,61.625559,69.675564,69.668337,91.769564,91.5932,91.656566
4,cs_CZ,google/mt5-large_cs_CZ/checkpoint-76000,78.527929,63.534912,71.298343,71.297282,92.314482,92.084652,92.175812
8,cs_CZ,qa2d/google/umt5-base_all,77.822632,62.331959,70.382139,70.396227,92.12839,91.813821,91.946673
6,en_US,facebook/mbart-large-cc25_en_US/checkpoint-30000,93.051458,84.618493,87.766511,87.768816,96.241087,96.064985,96.138746
7,en_US,google/mt5-large_en_US/checkpoint-94000,93.505732,85.850939,88.753483,88.740754,96.591437,96.395785,96.47994
9,en_US,qa2d/google/umt5-base_all,93.332366,85.458207,88.547557,88.547558,96.519303,96.288776,96.389705
1,en_US,original,91.945534,82.951799,85.792845,85.790928,95.469755,95.382667,95.406002
13,en_US,google/mt5-large_all/checkpoint-156000,93.488565,85.771137,88.724547,88.724901,96.579331,96.411312,96.481848


In [9]:
print(df[~df.model.str.contains("umt5")][['lang', 'model', 'rouge1', 'rouge2', 'rougeL', 'bert_score_F1']].to_latex(float_format="%.1f", index=False))

\begin{tabular}{llrrrr}
\toprule
 lang &                                            model &  rouge1 &  rouge2 &  rougeL &  bert\_score\_F1 \\
\midrule
cs\_CZ &                                         original &    64.5 &    44.5 &    53.0 &           83.8 \\
cs\_CZ &           google/mt5-large\_all/checkpoint-156000 &    78.3 &    63.1 &    71.0 &           92.1 \\
cs\_CZ & facebook/mbart-large-cc25\_cs\_CZ/checkpoint-26000 &    77.4 &    61.6 &    69.7 &           91.7 \\
cs\_CZ &          google/mt5-large\_cs\_CZ/checkpoint-76000 &    78.5 &    63.5 &    71.3 &           92.2 \\
en\_US & facebook/mbart-large-cc25\_en\_US/checkpoint-30000 &    93.1 &    84.6 &    87.8 &           96.1 \\
en\_US &          google/mt5-large\_en\_US/checkpoint-94000 &    93.5 &    85.9 &    88.8 &           96.5 \\
en\_US &                                         original &    91.9 &    83.0 &    85.8 &           95.4 \\
en\_US &           google/mt5-large\_all/checkpoint-156000 &    93.5 &    85.8 &    

  print(df[~df.model.str.contains("umt5")][['lang', 'model', 'rouge1', 'rouge2', 'rougeL', 'bert_score_F1']].to_latex(float_format="%.1f", index=False))


In [None]:
data = read_jsonl("/mnt/personal/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qacg/google/mt5_results.jsonl")

In [None]:
[{'lang': 'all',
  'model': '/mnt/personal/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qa2d/google/umt5-base_all',
  'data_file': '/mnt/data/factcheck/qa2d/cs_en_pl_sk/dev.jsonl',
  'rouge': {'rouge1': 0.8121849974992059,
   'rouge2': 0.6774419510196882,
   'rougeL': 0.7497730778117646,
   'rougeLsum': 0.7497340248321841}},
 {'lang': 'cs_CZ',
  'model': '/mnt/personal/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qa2d/google/umt5-base_all',
  'data_file': '/mnt/data/factcheck/qa2d/cs/dev.jsonl',
  'rouge': {'rouge1': 0.7783607331568722,
   'rouge2': 0.6231575608427122,
   'rougeL': 0.7039034412670073,
   'rougeLsum': 0.703985133472832}},
 {'lang': 'en_US',
  'model': '/mnt/personal/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qa2d/google/umt5-base_all',
  'data_file': '/mnt/data/factcheck/qa2d/en/dev.jsonl',
  'rouge': {'rouge1': 0.9333374676443688,
   'rouge2': 0.8544504708653158,
   'rougeL': 0.88537866791236,
   'rougeLsum': 0.8855242981238776}},
 {'lang': 'pl_PL',
  'model': '/mnt/personal/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qa2d/google/umt5-base_all',
  'data_file': '/mnt/data/factcheck/qa2d/pl/dev.jsonl',
  'rouge': {'rouge1': 0.7579418551444153,
   'rouge2': 0.6037844541036594,
   'rougeL': 0.70104065486047,
   'rougeLsum': 0.7011065700120231}},
 {'lang': 'sk_SK',
  'model': '/mnt/personal/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qa2d/google/umt5-base_all',
  'data_file': '/mnt/data/factcheck/qa2d/sk/dev.jsonl',
  'rouge': {'rouge1': 0.7796807605964733,
   'rouge2': 0.6282711920363941,
   'rougeL': 0.7087832206401827,
   'rougeLsum': 0.7089273080145163}}]

In [3]:
model_args = ModelArguments(model_name_or_path="/home/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qa2d/facebook/mbart-large-cc25_cs_CZ/BEST/checkpoint-26000")
tokenizer, model, data_collator = load_tokenizer_and_model(model_args, lang="cs_CZ", fp16=True)

In [11]:
model.to("cuda");

In [20]:
def predict(model, tokenizer, inputs, max_source_length=1024, padding=False):
    model_inputs = tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=True, return_tensors="pt")
    model_inputs = {k: model_inputs[k].to("cuda") for k in model_inputs.keys()}
    with torch.no_grad():
        Y = model.generate(**model_inputs, max_new_tokens=768)
        predictions = tokenizer.batch_decode(
            Y, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
    return predictions


sample = data[6]
question = sample["question"]
answer = sample["answer"]
# question = "V kolika letech zemřel Petr?"
# answer = "25"
print(textwrap.fill(question))
print(answer)
predict(model, tokenizer, [answer + "</s>" + question])

Proč organismy dědí vlastnosti svých rodičů?
buňky potomků obsahují kopie genů z buněk jejich rodičů


['organismy dědí vlastnosti svých rodičů, protože buňky potomků obsahují kopie genů z buněk jejich rodičů.']