In [1]:
from collections import Counter, OrderedDict
import numpy as np
from pathlib import Path
import re
import pandas as pd
import sys
import textwrap
from tqdm import tqdm
import torch
from typing import Dict, List, Set, Union

import evaluate
from transformers import Seq2SeqTrainingArguments
import bert_score

import uuid

from aic_nlp_utils.batch import batch_apply
from aic_nlp_utils.encoding import nfc
from aic_nlp_utils.json import read_jsonl, read_json, write_json, write_jsonl
from aic_nlp_utils.fever import fever_detokenize, import_fever_corpus_from_sqlite

%cd /home/drchajan/devel/python/FC/Zero-shot-Fact-Verification

from zshot_fact_verify.models.arguments import ModelArguments, DataTrainingArguments
from zshot_fact_verify.models.load import load_tokenizer_and_model, find_last_checkpoint
from zshot_fact_verify.qg.question_generation import BatchQuestionGenerator

%load_ext autoreload
%autoreload 2

/home/drchajan/devel/python/FC/Zero-shot-Fact-Verification


This notebook prepares data for training question generation (QG) models in several languages and evaluates the trained models.
The original paper used the T5 model from https://github.com/patil-suraj/question_generation

In [2]:
# At least the sk-quad dataset ships word-tokenized questions; this helper removes
# the whitespace that tokenization inserted around quotes and punctuation.
def word_detokenize(txt: str) -> str:
    """Undo word-level tokenization spacing in ``txt``.

    The steps are applied in this order:

    1. two backticks, two apostrophes and two commas are each replaced by a
       plain double quote,
    2. paired double quotes are attached to the text they enclose,
    3. the space before an apostrophe is removed and " - " becomes "-",
    4. the space before ``. , ? : ;`` is removed,
    5. the space after "(" and before ")" is removed.

    Args:
        txt: Word-tokenized text.

    Returns:
        The text with tokenization whitespace removed.
    """
    def pair_detokenize(txt, s):
        """Attach space-separated paired delimiters ``s`` to the enclosed text.

        Only occurrences of ``s`` with a space on both sides are considered.
        When their count is zero or odd the pairing is ambiguous and ``txt``
        is returned unchanged.
        """
        sub = " " + s + " "
        # ``s`` is a literal delimiter, not a pattern, so escape it for the regex engine.
        idxs = [m.start() for m in re.finditer(re.escape(sub), txt)]
        if len(idxs) == 0 or len(idxs) % 2 != 0:
            return txt

        parts = []
        first = 0
        for i, idx in enumerate(idxs):
            if i % 2 == 0:
                # opening delimiter: keep the space before it, drop the one after
                parts.append(txt[first:idx] + " " + s)
            else:
                # closing delimiter: drop the space before it, keep the one after
                parts.append(txt[first:idx] + s + " ")
            # skip the whole matched " s " span, whatever the delimiter length
            first = idx + len(sub)
        parts.append(txt[first:])
        return "".join(parts)

    txt = txt.replace("``", '"').replace("''", '"').replace(",,", '"')
    txt = pair_detokenize(txt, '"')
    txt = txt.replace(" '", "'")
    txt = txt.replace(" - ", "-")
    for mark in (".", ",", "?", ":", ";"):
        txt = txt.replace(" " + mark, mark)
    txt = txt.replace("( ", "(").replace(" )", ")")
    return txt

In [None]:
def convert_and_fix_squad(fsrc, fdst, word_detokenize_questions=False, verbose=False):
    """Convert a SQuAD-format JSON file to "linear" JSONL for seq2seq training.

    Every (question, distinct answer) pair becomes one record with the keys
    ``title``, ``context``, ``question`` and ``answer``. Entries flagged
    ``is_impossible`` (SQuAD 2.0) are skipped. A trailing "." or "," is removed
    from answers, and questions are normalised (see ``fix_question`` below).
    All text fields are passed through ``nfc``.

    Args:
        fsrc: Path of the source SQuAD JSON file (must contain a top-level "data" key).
        fdst: Path of the JSONL file to write; parent directories are created.
        word_detokenize_questions: When true, questions are passed through
            ``word_detokenize`` to remove tokenization whitespace.
        verbose: When true, print every emitted question. Off by default because
            it produces one output line per record.

    Returns:
        None. The records are written to ``fdst``.
    """
    def fix_question(raw):
        """Return the normalised question, or None when ``raw`` is blank."""
        question = raw.strip()
        if not question:
            # nothing to capitalise or punctuate; the caller skips this entry
            return None
        question = question[0].upper() + question[1:]
        question = question.replace("  ", " ")
        if word_detokenize_questions:
            question = word_detokenize(question)
        if not question.endswith(("?", '?"')):
            # imperative "Name ..." questions are left without a question mark
            if not question.lower().startswith("name"):
                if question[-1] in (".", ":", ">", "/"):  # Probably wrong parsing of original data
                    question = question[:-1] + "?"
                else:
                    question += "?"
        return nfc(question)

    def distinct_answers(qas):
        """Return the sorted, de-duplicated, non-empty answer texts of one QA entry."""
        answer_set = set()
        for ans in qas["answers"]:
            text = ans["text"]
            if text and text[-1] in (".", ","):
                text = text[:-1]
            if text:
                # an empty answer carries no signal for question generation
                answer_set.add(text)
        return sorted(answer_set)

    data = read_json(fsrc)["data"]
    records = []
    for rec in tqdm(data):
        title = nfc(rec["title"])
        for par in rec["paragraphs"]:
            context = nfc(par["context"])
            for qas in par["qas"]:
                if qas.get("is_impossible"):
                    continue
                answers = distinct_answers(qas)
                if not answers:
                    continue
                # the question does not depend on the answer, so normalise it once
                question = fix_question(qas["question"])
                if question is None:
                    continue
                for answer in answers:
                    if verbose:
                        print(question)
                    records.append({"title": title, "context": context, "question": question, "answer": nfc(answer)})
    write_jsonl(fdst, records, mkdir=True)

# Default: questions are used as-is; the sk-quad configuration below switches detokenization on.
WORD_DETOKENIZE = False
# Dataset selection: keep exactly one of the following configuration lines uncommented.
# SQUAD_DIR, SQUAD_TRN, SQUAD_DEV = 'squad-cs', "train-v1.1.json", "dev-v1.1.json"
# SQUAD_DIR, SQUAD_TRN, SQUAD_DEV = 'squad-sk', "train-230321.json", "dev-230321.json"
SQUAD_DIR, SQUAD_TRN, SQUAD_DEV, WORD_DETOKENIZE = 'sk-quad-220614', "sk-quad-220614-train.json", "sk-quad-220614-dev.json", True
# SQUAD_ROOT holds the source SQuAD JSON files, QG_ROOT receives the converted files.
SQUAD_ROOT = Path(f"/mnt/data/factcheck/squad/{SQUAD_DIR}")
QG_ROOT = Path(f"/mnt/data/factcheck/qg/{SQUAD_DIR}")

# NOTE: the outputs reuse the source file names, so they end in ".json" even though
# convert_and_fix_squad writes them with write_jsonl.
convert_and_fix_squad(Path(SQUAD_ROOT, SQUAD_DEV), Path(QG_ROOT, SQUAD_DEV), word_detokenize_questions=WORD_DETOKENIZE)
convert_and_fix_squad(Path(SQUAD_ROOT, SQUAD_TRN), Path(QG_ROOT, SQUAD_TRN), word_detokenize_questions=WORD_DETOKENIZE)

In [None]:
def convert_and_fix_csv_squad(fsrc, fdst, verbose=False):
    """Convert a CSV SQuAD export (e.g. SQuAD-PL) to "linear" JSONL for seq2seq training.

    Only the ``context``, ``question`` and ``answer_text`` columns are used.
    Duplicate rows and rows with any missing value are dropped; per the original
    note this is how impossible (unanswerable) questions are skipped. The CSV has
    no title column, so every record gets ``title = None``. All text fields are
    converted with ``str`` and passed through ``nfc``.

    Args:
        fsrc: Path of the source CSV file.
        fdst: Path of the JSONL file to write; parent directories are created.
        verbose: When true, print every emitted question. Off by default because
            it produces one output line per record.

    Returns:
        None. The records are written to ``fdst``.
    """
    columns = ["context", "question", "answer_text"]
    df_raw = pd.read_csv(fsrc)[columns]
    print(f"#records = {df_raw.shape}")
    # build a new frame instead of mutating the column selection in place
    df_clean = df_raw.drop_duplicates().dropna()
    print(f"after drop #records = {df_clean.shape}")

    records = []
    # total= lets tqdm show a real progress bar instead of a bare counter
    for row in tqdm(df_clean.itertuples(index=False), total=len(df_clean)):
        question = nfc(str(row.question))
        if verbose:
            print(question)
        records.append({
            "title": None,
            "context": nfc(str(row.context)),
            "question": question,
            "answer": nfc(str(row.answer_text)),
        })

    write_jsonl(fdst, records, mkdir=True)


# NOTE: this rebinds SQUAD_DIR / SQUAD_TRN / SQUAD_DEV / SQUAD_ROOT / QG_ROOT, which the
# JSON conversion cell also defines; re-run that cell before converting a JSON dataset again.
# For the CSV dataset the split names are file stems; the extensions are added in the calls below.
SQUAD_DIR, SQUAD_TRN, SQUAD_DEV = 'squad-pl', "train", "test"
SQUAD_ROOT = Path(f"/mnt/data/factcheck/squad/{SQUAD_DIR}")
QG_ROOT = Path(f"/mnt/data/factcheck/qg/{SQUAD_DIR}")

# The "test" CSV split plays the role of the dev split for this dataset.
convert_and_fix_csv_squad(Path(SQUAD_ROOT, f"{SQUAD_DEV}.csv"), Path(QG_ROOT, f"{SQUAD_DEV}.jsonl"))
convert_and_fix_csv_squad(Path(SQUAD_ROOT, f"{SQUAD_TRN}.csv"), Path(QG_ROOT, f"{SQUAD_TRN}.jsonl"))

# Combine SQUAD Datasets

In [16]:
def combine_squads(split_files, out_file, seed=1234):
    """Merge several per-language JSONL files into one shuffled JSONL file.

    Args:
        split_files: List of single-entry dicts, each mapping a language tag to
            the path of a JSONL file. The tag is stored on every record of that
            file under the ``lang`` key.
        out_file: Path of the combined JSONL file; parent directories are created.
        seed: Seed of the ``numpy`` random state used for shuffling.

    Returns:
        None. The shuffled records are written to ``out_file``.
    """
    shuffler = np.random.RandomState(seed)
    merged = []
    for spec in split_files:
        assert len(spec.items()) == 1
        # exactly one (language, path) pair per spec, as asserted above
        ((lang, path_),) = spec.items()
        records = read_jsonl(path_)
        for record in records:
            record["lang"] = lang
        merged += records
    shuffler.shuffle(merged)
    print(f"writing {len(merged)} records to {out_file}")
    write_jsonl(out_file, merged, mkdir=True)

# Combined training set: one converted file per language, tagged with its language code.
# NOTE: the cs and sk inputs end in ".json" but are read with read_jsonl like the others.
combine_squads([
    {"cs": "/mnt/data/factcheck/qg/squad-cs/train-v1.1.json"},
    {"en": "/mnt/data/factcheck/qg/squad-en/train-v1.1.jsonl"},
    {"pl": "/mnt/data/factcheck/qg/squad-pl/train.jsonl"},
    {"sk": "/mnt/data/factcheck/qg/sk-quad-220614/sk-quad-220614-train.json"}], 
    "/mnt/data/factcheck/qg/squad-cs_en_pl_sk/train.jsonl")

# Combined dev set; the Polish part comes from the "test" split.
combine_squads([
    {"cs": "/mnt/data/factcheck/qg/squad-cs/dev-v1.1.json"},
    {"en": "/mnt/data/factcheck/qg/squad-en/dev-v1.1.jsonl"},
    {"pl": "/mnt/data/factcheck/qg/squad-pl/test.jsonl"},
    {"sk": "/mnt/data/factcheck/qg/sk-quad-220614/sk-quad-220614-dev.json"}], 
    "/mnt/data/factcheck/qg/squad-cs_en_pl_sk/dev.jsonl")

writing 252281 records to /mnt/data/factcheck/qg/squad-cs_en_pl_sk/train.jsonl
writing 40933 records to /mnt/data/factcheck/qg/squad-cs_en_pl_sk/dev.jsonl


# Model Evaluation by ROUGE and BERTScore

In [15]:
def predict(model, tokenizer, inputs, max_source_length=1024, padding=True, device="cuda", max_new_tokens=768):
    """Generate text for a batch of input strings.

    NOTE: a later playground cell of this notebook defines another ``predict``
    with a different default for ``padding``; re-run this cell before running
    the evaluation cells if the playground cell has been executed.

    Args:
        model: Model exposing ``generate(input_ids=..., attention_mask=..., max_new_tokens=...)``.
        tokenizer: Callable tokenizer returning "input_ids" and "attention_mask"
            tensors and providing ``batch_decode``.
        inputs: Input strings to tokenize.
        max_source_length: Inputs are truncated to this many tokens.
        padding: Padding argument forwarded to the tokenizer.
        device: Device the tokenized tensors are moved to before generation.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded predictions, one string per input, with special tokens removed.
    """
    model_inputs = tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=True, return_tensors="pt")
    input_ids = model_inputs["input_ids"].to(device)
    attention_mask = model_inputs["attention_mask"].to(device)
    # gradients are never needed for inference
    with torch.no_grad():
        generated = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens)
    return tokenizer.batch_decode(
        generated, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )

def predict_original_paper(data):
    """Generate questions with the QG pipeline used by the original paper.

    Loads the 'valhalla/t5-base-qg-hl' question-generation pipeline in
    "highlight" format on GPU index 0 and runs it over ``data`` in batches of 32.

    Args:
        data: List of samples, each with "context", "answer" and "question" keys.

    Returns:
        A tuple ``(predicted_questions, reference_questions)`` of equally ordered lists.
    """
    # imported lazily so the pipeline module is only needed when this baseline is run
    from zshot_fact_verify.claim_generation.T5_QG import pipeline

    qg_nlp = pipeline("question-generation", model='valhalla/t5-base-qg-hl', qg_format="highlight", gpu_index=0)

    def generate_for_batch(batch):
        contexts = [sample["context"] for sample in batch]
        answers = [sample["answer"] for sample in batch]
        return qg_nlp.batch_qg_with_answer(contexts, answers)

    generated = batch_apply(generate_for_batch, data, batch_size=32, show_progress=True)
    predicted_questions = [item["question"] for item in generated]
    reference_questions = [sample["question"] for sample in data]
    return predicted_questions, reference_questions

def predict_split(model, tokenizer, data, batch_size=128):
    """Generate questions for a whole data split, batch by batch.

    Each model input is the answer, the literal "</s>" separator and the
    context, NFC-normalised.

    Args:
        model: Model passed on to ``predict``.
        tokenizer: Tokenizer passed on to ``predict``.
        data: List of samples with "answer", "context" and "question" keys.
        batch_size: Number of inputs per ``predict`` call.

    Returns:
        A tuple ``(predictions, references)``; references are the NFC-normalised
        gold questions in the order of ``data``.
    """
    sources = []
    for sample in data:
        sources.append(nfc(sample["answer"] + "</s>" + sample["context"]))

    def run_batch(batch):
        return predict(model, tokenizer, batch)

    # batching keeps generation fast compared to one call per sample
    predictions = batch_apply(run_batch, sources, batch_size=batch_size, show_progress=True)
    references = [nfc(sample["question"]) for sample in data]
    return predictions, references

def evaluate_rouge(Y, T):
    """Score predictions ``Y`` against references ``T`` with the "rouge" metric.

    Returns:
        Whatever the metric's ``compute`` call returns for the two lists.
    """
    metric = evaluate.load("rouge")
    return metric.compute(predictions=Y, references=T)


def _short_model_name(model_name, n_parts=3):
    """Return the last ``n_parts`` path components of ``model_name`` for log output."""
    return "/".join(Path(model_name).parts[-n_parts:])


def _predict_with_checkpoint(model_name, lang, data, batch_size=32):
    """Load a fine-tuned checkpoint on the GPU, generate questions and release it again."""
    import gc

    model_args = ModelArguments(model_name_or_path=model_name)
    tokenizer, model, _ = load_tokenizer_and_model(model_args, lang=lang, fp16=True)
    model.to("cuda")
    model.eval()
    try:
        return predict_split(model, tokenizer, data, batch_size=batch_size)
    finally:
        # free the model before the next configuration loads its own checkpoint
        del model, tokenizer
        gc.collect()
        torch.cuda.empty_cache()


def evaluate_quality(cfgs, out_json):
    """Evaluate question-generation models with ROUGE and BERTScore.

    For every configuration the data file is loaded, questions are generated
    (with the original paper's pipeline when the model is "original", otherwise
    with the fine-tuned checkpoint at the given path) and scored against the
    reference questions. Each result is appended to ``out_json`` as soon as it
    is available, so running the same configuration again adds another line.

    Args:
        cfgs: List of dicts with the keys "lang", "model" (a checkpoint path or
            the string "original") and "data_file" (a JSONL file of samples).
        out_json: Path of the JSONL file the results are appended to.

    Returns:
        A list with one dict per configuration: a copy of the configuration
        extended by "eval" (metric values), "Y" (predictions) and "T" (references).
    """
    rouge = evaluate.load("rouge")
    results = []
    for cfg in cfgs:
        lang = cfg["lang"]
        data_file = cfg["data_file"]
        model_name = cfg["model"]
        is_original = model_name == "original"

        # Label by the trailing path components so it does not depend on how deep
        # the checkpoint directory is mounted.
        model_label = "ORIGINAL" if is_original else _short_model_name(model_name)
        print(f"lang: {lang}, model: {model_label}, data file: {data_file}")

        data = read_jsonl(data_file)
        print(f"  loaded {len(data)} samples")

        if is_original:
            Y, T = predict_original_paper(data)
        else:
            Y, T = _predict_with_checkpoint(model_name, lang, data, batch_size=32)

        ev = rouge.compute(predictions=Y, references=T)

        bsP, bsR, bsF1 = bert_score.score(Y, T, model_type="bert-base-multilingual-cased")
        ev["bert_score_P"] = bsP.mean().item()
        ev["bert_score_R"] = bsR.mean().item()
        ev["bert_score_F1"] = bsF1.mean().item()

        print(f"  EVAL = {ev}")
        res = cfg.copy()
        res["eval"] = ev
        res["Y"] = Y
        res["T"] = T
        results.append(res)
        # written per configuration so finished results survive a later failure
        write_jsonl(out_json, [res], append=True)
    return results

In [8]:
# NOTE(review): LANG, LANG_SHORT and HIGHLIGHT are not read by evaluate_quality, which
# only uses the keys of each cfg dict.
LANG = "all"
LANG_SHORT = "all"
# mt5-large checkpoints: one multilingual model (ALL) and one model per language.
MODEL_NAME_ALL = "/mnt/personal/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qg/google/mt5-large_all/checkpoint-126000"
MODEL_NAME_CS = "/home/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qg/google/mt5-large_cs" # FINAL
MODEL_NAME_EN = "/home/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qg/google/mt5-large_en/checkpoint-64000" # FINAL
MODEL_NAME_PL = "/home/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qg/google/mt5-large_pl" # FINAL
MODEL_NAME_SK = "/home/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qg/google/mt5-large_sk/checkpoint-61000" # FINAL

HIGHLIGHT = False
# Evaluation files produced by the conversion and combination cells.
DEV_FILE_ALL = "/mnt/data/factcheck/qg/squad-cs_en_pl_sk/dev.jsonl" #ALL
DEV_FILE_CS = "/mnt/data/factcheck/qg/squad-cs/dev-v1.1.json"
# NOTE(review): the double slash below looks like a typo; it also shows up in the logged paths.
DEV_FILE_EN = "/mnt/data/factcheck/qg/squad-en//dev-v1.1.jsonl"
DEV_FILE_PL = "/mnt/data/factcheck/qg/squad-pl/test.jsonl"
DEV_FILE_SK = "/mnt/data/factcheck/qg/sk-quad-220614/sk-quad-220614-dev.json"


# First the per-language models, then the multilingual model on the same per-language files.
cfgs = [
    {"lang": "cs_CZ", "model": MODEL_NAME_CS, "data_file": DEV_FILE_CS},
    {"lang": "en_US", "model": MODEL_NAME_EN, "data_file": DEV_FILE_EN},
    {"lang": "pl_PL", "model": MODEL_NAME_PL, "data_file": DEV_FILE_PL},
    {"lang": "sk_SK", "model": MODEL_NAME_SK, "data_file": DEV_FILE_SK},
    
    # {"lang": "all", "model": MODEL_NAME_ALL, "data_file": DEV_FILE_ALL}, # NOT USED
    {"lang": "cs_CZ", "model": MODEL_NAME_ALL, "data_file": DEV_FILE_CS},
    {"lang": "en_US", "model": MODEL_NAME_ALL, "data_file": DEV_FILE_EN},
    {"lang": "pl_PL", "model": MODEL_NAME_ALL, "data_file": DEV_FILE_PL},
    {"lang": "sk_SK", "model": MODEL_NAME_ALL, "data_file": DEV_FILE_SK},
]

# evaluate_quality appends to this file, so re-running the cell adds duplicate result lines.
results = evaluate_quality(cfgs, "/home/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qg/results.jsonl")

lang: pl_PL, model: qg/google/mt5-large_pl, data file: /mnt/data/factcheck/qg/squad-pl/test.jsonl
  loaded 3805 samples


  0%|          | 0/119 [00:00<?, ?it/s]

  EVAL = {'rouge1': 0.37488054319641706, 'rouge2': 0.2138141866465533, 'rougeL': 0.3587911539945141, 'rougeLsum': 0.35832123857523335, 'bert_score_P': 0.8176380395889282, 'bert_score_R': 0.8090795874595642, 'bert_score_F1': 0.8125781416893005}
lang: sk_SK, model: qg/google/mt5-large_sk/checkpoint-61000, data file: /mnt/data/factcheck/qg/sk-quad-220614/sk-quad-220614-dev.json
  loaded 7808 samples


  0%|          | 0/244 [00:00<?, ?it/s]

  EVAL = {'rouge1': 0.6171535214446101, 'rouge2': 0.4522694598934476, 'rougeL': 0.5945420489207971, 'rougeLsum': 0.5946052945286857, 'bert_score_P': 0.871673583984375, 'bert_score_R': 0.8658138513565063, 'bert_score_F1': 0.8678244948387146}
lang: cs_CZ, model: experiments/qg/google/mt5-large_all/checkpoint-126000, data file: /mnt/data/factcheck/qg/squad-cs/dev-v1.1.json
  loaded 11722 samples


  0%|          | 0/367 [00:00<?, ?it/s]

  EVAL = {'rouge1': 0.3708815086247411, 'rouge2': 0.1918320137868037, 'rougeL': 0.3460518595746516, 'rougeLsum': 0.3460155261032911, 'bert_score_P': 0.807235598564148, 'bert_score_R': 0.7936057448387146, 'bert_score_F1': 0.799572229385376}
lang: en_US, model: experiments/qg/google/mt5-large_all/checkpoint-126000, data file: /mnt/data/factcheck/qg/squad-en//dev-v1.1.jsonl
  loaded 17598 samples


  0%|          | 0/550 [00:00<?, ?it/s]

  EVAL = {'rouge1': 0.4971248205606317, 'rouge2': 0.2780295951076388, 'rougeL': 0.46379181760387117, 'rougeLsum': 0.46381832576928517, 'bert_score_P': 0.8460949063301086, 'bert_score_R': 0.8336578607559204, 'bert_score_F1': 0.8391814827919006}
lang: pl_PL, model: experiments/qg/google/mt5-large_all/checkpoint-126000, data file: /mnt/data/factcheck/qg/squad-pl/test.jsonl
  loaded 3805 samples


  0%|          | 0/119 [00:00<?, ?it/s]

  EVAL = {'rouge1': 0.3717682297658417, 'rouge2': 0.20860462247647032, 'rougeL': 0.35546015030178746, 'rougeLsum': 0.3553079250726834, 'bert_score_P': 0.8180533051490784, 'bert_score_R': 0.8055732250213623, 'bert_score_F1': 0.8110008239746094}
lang: sk_SK, model: experiments/qg/google/mt5-large_all/checkpoint-126000, data file: /mnt/data/factcheck/qg/sk-quad-220614/sk-quad-220614-dev.json
  loaded 7808 samples


  0%|          | 0/244 [00:00<?, ?it/s]

  EVAL = {'rouge1': 0.612123392896875, 'rouge2': 0.44658048499923364, 'rougeL': 0.5897836393029163, 'rougeLsum': 0.5900234805856639, 'bert_score_P': 0.8711510896682739, 'bert_score_R': 0.8627465963363647, 'bert_score_F1': 0.8660293221473694}


In [10]:
# LANG = "all"
# LANG_SHORT = "all"
# QA2Ds were here!
# MODEL_NAME_ALL = "/mnt/personal/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qa2d/google/umt5-base_all"
# MODEL_NAME_CS = "/mnt/personal/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qa2d/google/umt5-base_cs_CZ" # FINAL 
# MODEL_NAME_EN = "/mnt/personal/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qa2d/google/umt5-base_en_US" # FINAL
# MODEL_NAME_PL = "/mnt/personal/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qa2d/google/umt5-base_pl_PL" # FINAL
# MODEL_NAME_SK = "/mnt/personal/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qa2d/google/umt5-base_sk_SK" # FINAL

# umt5-base QG checkpoints. NOTE: this rebinds the MODEL_NAME_* names that the
# mt5-large evaluation cell also defines.
MODEL_NAME_ALL = "/mnt/personal/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qg/google/umt5-base_all"
MODEL_NAME_CS = "/mnt/personal/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qg/google/umt5-base_cs_CZ" # FINAL
MODEL_NAME_EN = "/mnt/personal/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qg/google/umt5-base_en_US" # FINAL
MODEL_NAME_PL = "/mnt/personal/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qg/google/umt5-base_pl_PL" # FINAL
MODEL_NAME_SK = "/mnt/personal/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qg/google/umt5-base_sk_SK" # FINAL


# NOTE(review): HIGHLIGHT is not read by evaluate_quality.
HIGHLIGHT = False
# Same evaluation files as in the mt5-large evaluation cell.
DEV_FILE_ALL = "/mnt/data/factcheck/qg/squad-cs_en_pl_sk/dev.jsonl" #ALL
DEV_FILE_CS = "/mnt/data/factcheck/qg/squad-cs/dev-v1.1.json"
DEV_FILE_EN = "/mnt/data/factcheck/qg/squad-en//dev-v1.1.jsonl"
DEV_FILE_PL = "/mnt/data/factcheck/qg/squad-pl/test.jsonl"
DEV_FILE_SK = "/mnt/data/factcheck/qg/sk-quad-220614/sk-quad-220614-dev.json"


# First the per-language models, then the multilingual model on the same per-language files.
cfgs = [
    {"lang": "cs_CZ", "model": MODEL_NAME_CS, "data_file": DEV_FILE_CS},
    {"lang": "en_US", "model": MODEL_NAME_EN, "data_file": DEV_FILE_EN},
    {"lang": "pl_PL", "model": MODEL_NAME_PL, "data_file": DEV_FILE_PL},
    {"lang": "sk_SK", "model": MODEL_NAME_SK, "data_file": DEV_FILE_SK}, # done
    
    # {"lang": "all", "model": MODEL_NAME_ALL, "data_file": DEV_FILE_ALL}, # missing!
    {"lang": "cs_CZ", "model": MODEL_NAME_ALL, "data_file": DEV_FILE_CS},
    {"lang": "en_US", "model": MODEL_NAME_ALL, "data_file": DEV_FILE_EN},
    {"lang": "pl_PL", "model": MODEL_NAME_ALL, "data_file": DEV_FILE_PL},
    {"lang": "sk_SK", "model": MODEL_NAME_ALL, "data_file": DEV_FILE_SK},
]

# Results are appended to the same results.jsonl as the other evaluation cells.
results = evaluate_quality(cfgs, "/home/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qg/results.jsonl")

lang: cs_CZ, model: experiments/qg/google/umt5-base_cs_CZ, data file: /mnt/data/factcheck/qg/squad-cs/dev-v1.1.json
  loaded 11722 samples


  0%|          | 0/367 [00:00<?, ?it/s]

  EVAL = {'rouge1': 0.3382812891335044, 'rouge2': 0.1656928084780808, 'rougeL': 0.3154058686951723, 'rougeLsum': 0.3154952471415421, 'bert_score_P': 0.7982181310653687, 'bert_score_R': 0.7836409211158752, 'bert_score_F1': 0.7900723814964294}
lang: en_US, model: experiments/qg/google/umt5-base_en_US, data file: /mnt/data/factcheck/qg/squad-en//dev-v1.1.jsonl
  loaded 17598 samples


  0%|          | 0/550 [00:00<?, ?it/s]

  EVAL = {'rouge1': 0.4803576161922338, 'rouge2': 0.26163355416835743, 'rougeL': 0.44866181202066013, 'rougeLsum': 0.4486326232848925, 'bert_score_P': 0.8420814871788025, 'bert_score_R': 0.8275054693222046, 'bert_score_F1': 0.8340489268302917}
lang: pl_PL, model: experiments/qg/google/umt5-base_pl_PL, data file: /mnt/data/factcheck/qg/squad-pl/test.jsonl
  loaded 3805 samples


  0%|          | 0/119 [00:00<?, ?it/s]

  EVAL = {'rouge1': 0.3121625460931829, 'rouge2': 0.16462217627122833, 'rougeL': 0.2969682902766667, 'rougeLsum': 0.29683810731226695, 'bert_score_P': 0.8020727038383484, 'bert_score_R': 0.7883787155151367, 'bert_score_F1': 0.7943432331085205}
lang: sk_SK, model: experiments/qg/google/umt5-base_sk_SK, data file: /mnt/data/factcheck/qg/sk-quad-220614/sk-quad-220614-dev.json
  loaded 7808 samples


  0%|          | 0/244 [00:00<?, ?it/s]

  EVAL = {'rouge1': 0.5825928038968479, 'rouge2': 0.4170115077812878, 'rougeL': 0.5606353587407773, 'rougeLsum': 0.5607869963702234, 'bert_score_P': 0.8632833957672119, 'bert_score_R': 0.8517776727676392, 'bert_score_F1': 0.8565813899040222}
lang: cs_CZ, model: experiments/qg/google/umt5-base_all, data file: /mnt/data/factcheck/qg/squad-cs/dev-v1.1.json
  loaded 11722 samples


  0%|          | 0/367 [00:00<?, ?it/s]

  EVAL = {'rouge1': 0.35302872397920193, 'rouge2': 0.17738190137332735, 'rougeL': 0.32914259636287213, 'rougeLsum': 0.32920015243461415, 'bert_score_P': 0.8018673062324524, 'bert_score_R': 0.7882830500602722, 'bert_score_F1': 0.7942251563072205}
lang: en_US, model: experiments/qg/google/umt5-base_all, data file: /mnt/data/factcheck/qg/squad-en//dev-v1.1.jsonl
  loaded 17598 samples


  0%|          | 0/550 [00:00<?, ?it/s]

  EVAL = {'rouge1': 0.486302119720683, 'rouge2': 0.26812704244115926, 'rougeL': 0.4538395584860664, 'rougeLsum': 0.45384665783885525, 'bert_score_P': 0.8427534103393555, 'bert_score_R': 0.8297354578971863, 'bert_score_F1': 0.8355132937431335}
lang: pl_PL, model: experiments/qg/google/umt5-base_all, data file: /mnt/data/factcheck/qg/squad-pl/test.jsonl
  loaded 3805 samples


  0%|          | 0/119 [00:00<?, ?it/s]

  EVAL = {'rouge1': 0.3639035589476781, 'rouge2': 0.20241942329968257, 'rougeL': 0.3483770292382762, 'rougeLsum': 0.3484797173067946, 'bert_score_P': 0.8161351680755615, 'bert_score_R': 0.8023266196250916, 'bert_score_F1': 0.8083663582801819}
lang: sk_SK, model: experiments/qg/google/umt5-base_all, data file: /mnt/data/factcheck/qg/sk-quad-220614/sk-quad-220614-dev.json
  loaded 7808 samples


  0%|          | 0/244 [00:00<?, ?it/s]

  EVAL = {'rouge1': 0.5935037219019206, 'rouge2': 0.4272806506920769, 'rougeL': 0.5710327764377565, 'rougeLsum': 0.5710777537603315, 'bert_score_P': 0.866708517074585, 'bert_score_R': 0.8554693460464478, 'bert_score_F1': 0.8601471781730652}


In [16]:
# Evaluation files; only the CS and EN files are used by the configurations below.
DEV_FILE_ALL = "/mnt/data/factcheck/qg/squad-cs_en_pl_sk/dev.jsonl" #ALL
DEV_FILE_CS = "/mnt/data/factcheck/qg/squad-cs/dev-v1.1.json"
DEV_FILE_EN = "/mnt/data/factcheck/qg/squad-en//dev-v1.1.jsonl"
DEV_FILE_PL = "/mnt/data/factcheck/qg/squad-pl/test.jsonl"
DEV_FILE_SK = "/mnt/data/factcheck/qg/sk-quad-220614/sk-quad-220614-dev.json"


# "original" makes evaluate_quality use predict_original_paper instead of a fine-tuned checkpoint.
cfgs = [
    {"lang": "cs_CZ", "model": "original", "data_file": DEV_FILE_CS},
    {"lang": "en_US", "model": "original", "data_file": DEV_FILE_EN},
]

# Results are appended to the same results.jsonl as the other evaluation cells.
results = evaluate_quality(cfgs, "/home/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qg/results.jsonl")

lang: cs_CZ, model: ORIGINAL, data file: /mnt/data/factcheck/qg/squad-cs/dev-v1.1.json
  loaded 11722 samples


You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565, and set the legacy attribute accordingly.


  0%|          | 0/367 [00:00<?, ?it/s]

  EVAL = {'rouge1': 0.09841955317631386, 'rouge2': 0.02848823438655762, 'rougeL': 0.09185948321821162, 'rougeLsum': 0.09185807029510903, 'bert_score_P': 0.6930422782897949, 'bert_score_R': 0.6721824407577515, 'bert_score_F1': 0.6818882822990417}
lang: en_US, model: ORIGINAL, data file: /mnt/data/factcheck/qg/squad-en//dev-v1.1.jsonl
  loaded 17598 samples


  0%|          | 0/550 [00:00<?, ?it/s]

  EVAL = {'rouge1': 0.49879397519319624, 'rouge2': 0.2840094098384389, 'rougeL': 0.4634463180184327, 'rougeLsum': 0.46345589394595976, 'bert_score_P': 0.839036226272583, 'bert_score_R': 0.8369964361190796, 'bert_score_F1': 0.8373371362686157}


In [25]:
def compare_results_qg(result_jsonls):
    """Collect evaluation results from JSONL files into one comparison table.

    Args:
        result_jsonls: Paths of result files as written by ``evaluate_quality``;
            every record needs the keys "lang", "model" and "eval".

    Returns:
        A DataFrame sorted by "lang" with the columns "lang", "model" (shortened
        to the last three "/"-separated components) and one column per metric,
        each scaled to percent.
    """
    metric_names = ["rouge1", "rouge2", "rougeL", "rougeLsum", "bert_score_P", "bert_score_R", "bert_score_F1"]

    data = []
    for rjsonl in result_jsonls:
        data += read_jsonl(rjsonl)

    df = pd.DataFrame(data)
    df["model"] = ['/'.join(m.split("/")[-3:]) for m in df.model]
    for metric in metric_names:
        # metrics are stored as fractions in the nested "eval" dict; report percent
        df[metric] = [100*e[metric] for e in df["eval"]]

    # sort a new frame instead of sorting a column selection in place
    return df[["lang", "model"] + metric_names].sort_values("lang")

# Build the comparison table from the accumulated results file.
# NOTE(review): the evaluation cells write to the same file under /home/...; this
# reads it via /mnt/personal/... - presumably the same location, confirm.
df = compare_results_qg([
    "/mnt/personal/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qg/results.jsonl"
])

df

Unnamed: 0,lang,model,rouge1,rouge2,rougeL,rougeLsum,bert_score_P,bert_score_R,bert_score_F1
0,cs_CZ,qg/google/mt5-large_cs,37.288131,19.425672,34.744723,34.73332,80.542296,79.557836,79.971039
12,cs_CZ,qg/google/umt5-base_all,35.302872,17.73819,32.91426,32.920015,80.186731,78.828305,79.422516
16,cs_CZ,original,9.841955,2.848823,9.185948,9.185807,69.304228,67.218244,68.188828
4,cs_CZ,google/mt5-large_all/checkpoint-126000,37.088151,19.183201,34.605186,34.601553,80.72356,79.360574,79.957223
8,cs_CZ,qg/google/umt5-base_cs_CZ,33.828129,16.569281,31.540587,31.549525,79.821813,78.364092,79.007238
5,en_US,google/mt5-large_all/checkpoint-126000,49.712482,27.80296,46.379182,46.381833,84.609491,83.365786,83.918148
9,en_US,qg/google/umt5-base_en_US,48.035762,26.163355,44.866181,44.863262,84.208149,82.750547,83.404893
1,en_US,google/mt5-large_en/checkpoint-64000,49.954044,28.132323,46.689003,46.679203,84.717369,83.349091,83.961731
13,en_US,qg/google/umt5-base_all,48.630212,26.812704,45.383956,45.384666,84.275341,82.973546,83.551329
17,en_US,original,49.879398,28.400941,46.344632,46.345589,83.903623,83.699644,83.733714


In [26]:
# Print the rows whose model name does not contain "umt5" as a LaTeX table (one decimal place).
print(df[~df.model.str.contains("umt5")][['lang', 'model', 'rouge1', 'rouge2', 'rougeL', 'bert_score_F1']].to_latex(float_format="%.1f", index=False))

\begin{tabular}{llrrrr}
\toprule
 lang &                                  model &  rouge1 &  rouge2 &  rougeL &  bert\_score\_F1 \\
\midrule
cs\_CZ &                 qg/google/mt5-large\_cs &    37.3 &    19.4 &    34.7 &           80.0 \\
cs\_CZ &                               original &     9.8 &     2.8 &     9.2 &           68.2 \\
cs\_CZ & google/mt5-large\_all/checkpoint-126000 &    37.1 &    19.2 &    34.6 &           80.0 \\
en\_US & google/mt5-large\_all/checkpoint-126000 &    49.7 &    27.8 &    46.4 &           83.9 \\
en\_US &   google/mt5-large\_en/checkpoint-64000 &    50.0 &    28.1 &    46.7 &           84.0 \\
en\_US &                               original &    49.9 &    28.4 &    46.3 &           83.7 \\
pl\_PL & google/mt5-large\_all/checkpoint-126000 &    37.2 &    20.9 &    35.5 &           81.1 \\
pl\_PL &                 qg/google/mt5-large\_pl &    37.5 &    21.4 &    35.9 &           81.3 \\
sk\_SK & google/mt5-large\_all/checkpoint-126000 &    61.2 &    44.7 

  print(df[~df.model.str.contains("umt5")][['lang', 'model', 'rouge1', 'rouge2', 'rougeL', 'bert_score_F1']].to_latex(float_format="%.1f", index=False))


### CS
`qg/ctu-aic/flan-t5-large_cs_CZ/bkp/checkpoint-12672`

*{'rouge1': 0.12963591761229037, 'rouge2': 0.023885136378797505, 'rougeL': 0.12183362628568076, 'rougeLsum': 0.1217667170120719}*

Wrong tokenization.

`qg/google/mt5-large_cs_CZ/checkpoint-59000`

*{'rouge1': 0.3694537189887302, 'rouge2': 0.19169694531279952, 'rougeL': 0.34545989041700687, 'rougeLsum': 0.34555287654344513}*

`qg/google/umt5-base_cs_CZ/bkp/checkpoint-6400`

*{'rouge1': 0.30904739385202495, 'rouge2': 0.14175598892317176, 'rougeL': 0.2874681246496394, 'rougeLsum': 0.28755029622694217}*

`qg/google/umt5-base_cs_CZ`

*{'rouge1': 0.3379285112695145, 'rouge2': 0.16586064794677036, 'rougeL': 0.3152727384483949, 'rougeLsum': 0.3154454957579359}*

### EN

`original`

*{'rouge1': 0.4989229792773411, 'rouge2': 0.28414967309411343, 'rougeL': 0.4635031675966017, 'rougeLsum': 0.4634836172778146}*

`qg/google/mt5-large_en/checkpoint-64000`

**{'rouge1': 0.4995418652865221, 'rouge2': 0.2815535664262393, 'rougeL': 0.4668170545426994, 'rougeLsum': 0.46687703737445807}**

`qg/google/flan-t5-large_en_US/checkpoint-7936`

*{'rouge1': 0.5050762371759718, 'rouge2': 0.28730345133694524, 'rougeL': 0.4710796839809428, 'rougeLsum': 0.47103764958291705}*

`qg/google/umt5-base_en_US`

*{'rouge1': 0.48045112431026354, 'rouge2': 0.26173784145711765, 'rougeL': 0.4487262083667449, 'rougeLsum': 0.4487590621176192}*

### PL

`qg/google/mt5-large_pl/checkpoint-34000`

*{'rouge1': 0.3670461463062099, 'rouge2': 0.20664767638201073, 'rougeL': 0.3513302909076887, 'rougeLsum': 0.3514851628255672}*

`qg/google/flan-t5-large_pl_PL`

*{'rouge1': 0.2651964323903675, 'rouge2': 0.10999222537607427, 'rougeL': 0.25368293682015525, 'rougeLsum': 0.2532975070425204}*

`qg/google/umt5-base_pl_PL`

*{'rouge1': 0.31209527473925364, 'rouge2': 0.16440549183090475, 'rougeL': 0.29680613676413997, 'rougeLsum': 0.29705324007033296}*

### SK
`qg/google/mt5-large_sk_SK/checkpoint-37000`

*{'rouge1': 0.2947516032244862, 'rouge2': 0.13924328095469365, 'rougeL': 0.2711889534899373, 'rougeLsum': 0.2710563599995189}*

`qg/google/umt5-base_sk_SK`
*{'rouge1': 0.5825675978707048, 'rouge2': 0.41734228678000224, 'rougeL': 0.5607684287432402, 'rougeLsum': 0.560618385882808}*

### ALL

`qg/google/umt5-base_all`

**ALL** *{'rouge1': 0.4572821393722631, 'rouge2': 0.2663545003075309, 'rougeL': 0.43077904314996096, 'rougeLsum': 0.43072475188146}*

**CS** *{'rouge1': 0.35290062436342096, 'rouge2': 0.17738150128215946, 'rougeL': 0.3292826941280448, 'rougeLsum': 0.3290628070321041}*

**EN** *{'rouge1': 0.48636008671818154,'rouge2': 0.2681378422877396, 'rougeL': 0.4539385644052849, 'rougeLsum': 0.4539682258809208}*

**PL** *{'rouge1': 0.36408309004184375, 'rouge2': 0.20234412375822144, 'rougeL': 0.3484808037947017, 'rougeLsum': 0.3484233231283168}*

**SK** *{'rouge1': 0.5935977294706721, 'rouge2': 0.4272281779017747, 'rougeL': 0.5711202748270763, 'rougeLsum': 0.5709580363216058}*

`qg/google/mt5-large_all/BKP/checkpoint-126000`

**ALL** *{'rouge1': 0.4712879903083035, 'rouge2': 0.2790771005345815,  'rougeL': 0.44412574432094815,  'rougeLsum': 0.4440181809982411}*

**CS** **{'rouge1': 0.37076781953491367, 'rouge2': 0.19188317187558196, 'rougeL': 0.34603606602711284, 'rougeLsum': 0.3458732199471718}**

**EN** *{'rouge1': 0.4972121341475896, 'rouge2': 0.27814731357909395, 'rougeL': 0.46391560371478174, 'rougeLsum': 0.46394229428970657}*

**PL** **{'rouge1': 0.3717109027777702, 'rouge2': 0.20806629045849118, 'rougeL': 0.35534417379213407, 'rougeLsum': 0.35545025872361985}**

**SK** **{'rouge1': 0.6120836094748034, 'rouge2': 0.4468872126953193, 'rougeL': 0.5898635208080331, 'rougeLsum': 0.5898424756234816}**

# Playground for model inference
See `scripts/wiki_qg.py` for question generation.

In [3]:
# Playground model selection. Both `lang` assignments run, so the SK value below wins;
# to use the CZ model, uncomment its model_args line and comment out the SK block.
# CZ
lang = "cs_CZ"
# model_args = ModelArguments(model_name_or_path="/home/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qg/facebook/mbart-large-cc25_cs_CZ/BEST/checkpoint-32000")

# SK
lang = "sk_SK"
model_args = ModelArguments(model_name_or_path="/home/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qg/google/mt5-large_sk_SK/checkpoint-37000")

In [4]:
# Load the selected checkpoint; the model is not moved to a device in this cell.
tokenizer, model, data_collator = load_tokenizer_and_model(model_args, lang=lang, fp16=True)


In [9]:
# Dev samples for the playground; switch the active line to match the selected model.
# CS
# data = read_jsonl("/mnt/data/factcheck/qg/squad-cs/dev-v1.1.json")

# SK
data = read_jsonl("/mnt/data/factcheck/qg/sk-quad-220614/sk-quad-220614-dev.json")

In [10]:
# Peek at the first record to check its fields (title, context, question, answer).
data[0]

{'title': 'Vysoký grúň (Laborecká vrchovina)',
 'context': 'Cez vrch Vysoký grúň vedie hlavná  červená turistická značka, ktorá zároveň vedie po hlavnom karpatskom hrebeni cez najvýchodnejší bod Slovenska – trojmedzie (1207.7 Mnm) na vrchu Kremenec (1221.0 Mnm) a prechádza po slovensko-poľskej štátnej hranici cez viacero vrchov s viacerými panoramatickými vyhliadkami, ako napr. Kamenná lúka (1200.9 Mnm), Jarabá skala (1199.0 Mnm), Ďurkovec (1188.7 Mnm), Pľaša (1162.8 Mnm), ďalej cez Ruské sedlo (801.0 Mnm), vrchy Rypy (1002.7 Mnm), Strop, (1011.2 Mnm), Černiny (929.4 Mnm), Laborecký priesmyk (684.0 Mnm) až k Duklianskemu priesmyku (502.0 Mnm).',
 'question': 'Akú nadmorskú výšku má vrch Kremenec?',
 'answer': '1221.0 Mnm'}

In [12]:
def predict(model, tokenizer, inputs, max_source_length=1024, padding=False, max_new_tokens=768):
    """Generate text for `inputs` and return the decoded predictions.

    Args:
        model: model exposing `generate(**encoded_inputs, max_new_tokens=...)`.
        tokenizer: callable tokenizer that also provides `batch_decode`.
        inputs: text (or list of texts) to encode and feed to the model.
        max_source_length: inputs are truncated to this many tokens.
        padding: forwarded to the tokenizer. Because tensors are requested
            (`return_tensors="pt"`), a batch of inputs with different lengths
            needs padding enabled to be packed into a single tensor.
        max_new_tokens: upper bound on the number of tokens generated per input.

    Returns:
        The list of decoded strings produced by `tokenizer.batch_decode`,
        with special tokens stripped.
    """
    model_inputs = tokenizer(
        inputs,
        max_length=max_source_length,
        padding=padding,
        truncation=True,
        return_tensors="pt",
    )

    # The tokenizer builds its tensors on the CPU; move them next to the model's
    # weights so generation also works when the model lives on an accelerator.
    device = getattr(model, "device", None)
    if device is not None:
        model_inputs = model_inputs.to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        Y = model.generate(**model_inputs, max_new_tokens=max_new_tokens)

    return tokenizer.batch_decode(
        Y, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )


# Pick a dataset record and its answer ...
sample = data[120]
answer = sample["answer"]
# ... then override both inputs with a hard-coded Czech passage and a hand-picked answer span.
# NOTE(review): `sample` and the dataset answer are unused after this override; drop the two
# assignments below to run on the dataset record instead.
context = "Vzkaz na vojenské technice zaujal po sobotní části Pavlovy návštěvy Ukrajiny některá ukrajinská média. Píše o něm také kupříkladu agentura Unian, která zároveň informuje o Pavlově setkání s ukrajinskými vojáky v Dněpropetrovské oblasti, agentura Ukrinform, jež připomněla Pavlovo působení na vrcholné pozici v NATO, nebo server Obozrevatel. Ukrinform mimo jiné zaznamenal jednání prezidenta s místními činiteli o obnově Dněpropetrovské oblasti, nad níž Česko převzalo záštitu. Weby Hromadske nebo Jevropejska pravda upozornily na prezidentovo setkání s vysídlenými Ukrajinci. Ukrajinská média informovala už dříve o pátečním programu Pavla a Čaputové, kromě jejich oficiálních setkání s ukrajinskými činiteli si povšimla mimo jiné faktu, že státníci museli kvůli vzdušnému poplachu v pátek na čas do hotelového krytu. Ukrajinska pravda v souvislosti s návštěvou prezidentů Česka a Slovenska poznamenala, že vysoce postavení zahraniční představitelé od začátku ruské invaze jen zřídka zůstali na Ukrajině přes noc."
answer = "NATO"
print(textwrap.fill(context))
print(answer)
# The model input is the answer and the context joined by the literal "</s>" string;
# the cell displays the list of generated questions.
predict(model, tokenizer, [answer + "</s>" + context])

Vzkaz na vojenské technice zaujal po sobotní části Pavlovy návštěvy
Ukrajiny některá ukrajinská média. Píše o něm také kupříkladu agentura
Unian, která zároveň informuje o Pavlově setkání s ukrajinskými vojáky
v Dněpropetrovské oblasti, agentura Ukrinform, jež připomněla Pavlovo
působení na vrcholné pozici v NATO, nebo server Obozrevatel. Ukrinform
mimo jiné zaznamenal jednání prezidenta s místními činiteli o obnově
Dněpropetrovské oblasti, nad níž Česko převzalo záštitu. Weby
Hromadske nebo Jevropejska pravda upozornily na prezidentovo setkání s
vysídlenými Ukrajinci. Ukrajinská média informovala už dříve o
pátečním programu Pavla a Čaputové, kromě jejich oficiálních setkání s
ukrajinskými činiteli si povšimla mimo jiné faktu, že státníci museli
kvůli vzdušnému poplachu v pátek na čas do hotelového krytu.
Ukrajinska pravda v souvislosti s návštěvou prezidentů Česka a
Slovenska poznamenala, že vysoce postavení zahraniční představitelé od
začátku ruské invaze jen zřídka zůstali na Ukraj

['V ktorom štáte pôsobil prezident Pavlo Pavlov na vrcholnej pozicii?']

# Tokenizer experiments

In [1]:
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)
model_id = "google/flan-t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Load the model with its language-modeling head. The generic AutoModel class returns the
# bare encoder-decoder without that head, so a checkpoint saved or pushed from it would
# lack the weights needed for generation, and the head would not be resized below.
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

# Accented letters of the target languages (lower case, then upper-case variants added).
accents = "áčďéěíňóřšťúůýž" # CS
accents += "ąćęłńóśźż" # PL
accents += "áäčďéíĺľňóôŕšťúýž" # SK
accents += accents.upper()
accents = set(accents)
# Keep only the characters the pretrained vocabulary does not already contain.
new_tokens = accents - set(tokenizer.get_vocab())

# Sort before adding: set iteration order depends on string hashing and changes between
# kernel sessions, which would assign different ids to the new tokens on every run.
tokenizer.add_tokens(sorted(new_tokens))

# Grow the embedding matrix to cover the enlarged vocabulary; the resized embedding is displayed.
model.resize_token_embeddings(len(tokenizer))

Embedding(32150, 1024)

In [3]:
# Interactive Hugging Face Hub login (renders a widget); needed before pushing to the Hub.
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [4]:
# Upload the model to the ctu-aic/flan-t5-large repository on the Hugging Face Hub.
model.push_to_hub("ctu-aic/flan-t5-large")

pytorch_model.bin:   0%|          | 0.00/3.00G [00:00<?, ?B/s]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

CommitInfo(commit_url='https://huggingface.co/ctu-aic/flan-t5-large/commit/44e85df0c017c2ee04a322a13ddd14d0674f357d', commit_message='Upload model', commit_description='', oid='44e85df0c017c2ee04a322a13ddd14d0674f357d', pr_url=None, pr_revision=None, pr_num=None)

In [5]:
# Upload the tokenizer to the same Hub repository as the model so the two stay paired.
tokenizer.push_to_hub("ctu-aic/flan-t5-large")

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/ctu-aic/flan-t5-large/commit/6e995ffb333de1a0238236976bc7a9271ddf1e3a', commit_message='Upload tokenizer', commit_description='', oid='6e995ffb333de1a0238236976bc7a9271ddf1e3a', pr_url=None, pr_revision=None, pr_num=None)

In [2]:
# Round-trip a Czech sentence rich in diacritics through the tokenizer: encode it to ids and
# decode it back, so the displayed string shows how accented characters survive tokenization.
txt = "Nechť již hříšné saxofony ďáblů rozezvučí síň úděsnými tóny waltzu, tanga a quickstepu."
ids = tokenizer(txt)["input_ids"]
tokenizer.decode(ids)

'Nechť již hříš né saxofony ď áblů rozezvučí síň ú dě sný mi tóny waltzu, tanga a quickstepu.</s>'

In [24]:
# Display the raw token ids produced for `txt`.
ids

[1484,
 524,
 2,
 3,
 354,
 23,
 2,
 3,
 107,
 2,
 29,
 154,
 3,
 7,
 9,
 226,
 858,
 106,
 63,
 3,
 2,
 2975,
 115,
 40,
 2,
 3,
 9860,
 457,
 208,
 76,
 2,
 3,
 7,
 2,
 3,
 2,
 26,
 2,
 7,
 29,
 2,
 51,
 23,
 3,
 17,
 15742,
 63,
 3,
 5380,
 17,
 1000,
 6,
 3,
 8967,
 9,
 3,
 9,
 1704,
 7910,
 76,
 5,
 1]

In [14]:

# Display the accented characters that were not found in the tokenizer vocabulary.
new_tokens

{'Á',
 'Ä',
 'Í',
 'Ó',
 'Ô',
 'Ú',
 'Ý',
 'í',
 'ú',
 'ý',
 'Ą',
 'ą',
 'Ć',
 'ć',
 'Č',
 'č',
 'Ď',
 'ď',
 'Ę',
 'ę',
 'Ě',
 'ě',
 'Ĺ',
 'ĺ',
 'Ľ',
 'ľ',
 'Ł',
 'ł',
 'Ń',
 'ń',
 'Ň',
 'ň',
 'Ŕ',
 'ŕ',
 'Ř',
 'ř',
 'Ś',
 'ś',
 'Š',
 'š',
 'Ť',
 'ť',
 'Ů',
 'ů',
 'Ź',
 'ź',
 'Ż',
 'ż',
 'Ž',
 'ž'}

In [20]:
# Add the missing characters in sorted order so the ids assigned to them are reproducible
# (iterating the set directly depends on string hashing, which varies between kernel sessions).
# The call returns the number of tokens actually added.
tokenizer.add_tokens(sorted(new_tokens))

50

Embedding(32100, 1024)