In [1]:
import numpy as np
from tqdm import tqdm

from collections import defaultdict, OrderedDict, Counter
from dataclasses import dataclass
import datetime as dt
from itertools import chain
import math
import os
import pathlib
from pathlib import Path
import pandas as pd
import unicodedata as ud
from time import time
from typing import Dict, Type, Callable, List, Union
import sys
import ujson

import torch
import transformers
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
from datasets import load_dataset

from aic_nlp_utils.json import read_jsonl, read_json, write_json, write_jsonl
from aic_nlp_utils.encoding import nfc
from aic_nlp_utils.fever import fever_detokenize

%load_ext autoreload
%autoreload 2

**TODO** move elsewhere NLI models should be covered in own package. Currently it is here for convenience only.

In [2]:
# APPROACH = "full" # all generated data
# APPROACH = "balanced" # balanced classes
APPROACH = "balance_shuf" # balanced classes, shuffled
# APPROACH = "fever_size" # QACG data subsampled to Cs/EnFEVER dataset size

LANG, NER_DIR = "cs", "PAV-ner-CNEC"
# LANG, NER_DIR = "en", "stanza"
# LANG, NER_DIR = "sk", "crabz_slovakbert-ner" 
# LANG, NER_DIR = "pl", "stanza"
DATE = "20230801"

QG_DIR = "mt5-large_all-cp126k"
QACG_DIR = "mt5-large_all-cp156k"

# BELOW configuration is language-agnostic

WIKI_ROOT = f"/mnt/data/factcheck/wiki/{LANG}/{DATE}"
WIKI_CORPUS = f"{WIKI_ROOT}/paragraphs/{LANG}wiki-{DATE}-paragraphs.jsonl"

QACG_ROOT = f"{WIKI_ROOT}/qacg"

NLI_DIR = Path("nli", NER_DIR, QG_DIR, QACG_DIR)
NLI_ROOT = Path(QACG_ROOT, NLI_DIR)

SPLIT_DIR = Path("splits", NER_DIR, QG_DIR, QACG_DIR)
SPLIT_ROOT = Path(QACG_ROOT, SPLIT_DIR)

In [3]:
SPLIT_ROOT

PosixPath('/mnt/data/factcheck/wiki/cs/20230801/qacg/splits/PAV-ner-CNEC/mt5-large_all-cp126k/mt5-large_all-cp156k')

In [4]:
def import_corpus(corpus_file):
    # it already has correct format
    raw = read_jsonl(corpus_file, show_progress=True)
    for e in raw:
        e["id"] = nfc(e["id"])
        e["did"] = nfc(e["did"])
        e["text"] = nfc(e["text"])
    return raw


def generate_original_id2pid_mapping(corpus):
    original_id2pid = {}
    for pid, r in enumerate(corpus):
        original_id = r["id"]
        # assert original_id not in original_id2pid, f"original ID not unique! {original_id}"
        if original_id in original_id2pid:
            print(f"original ID not unique! {pid} {original_id}, previous pid: {original_id2pid[original_id]}")
        original_id2pid[original_id] = pid
    return original_id2pid

corpus = import_corpus(WIKI_CORPUS)
original_id2pid = generate_original_id2pid_mapping(corpus)

0.00it [00:00, ?it/s]

In [5]:
def prepare_nli_data(src_file, dst_file, corpus, original_id2pid, seed=1234):
    # imports data created for Evidence retrieval (ColBERTv2:prepare_data_wiki.ipynb)
    rng = np.random.RandomState(seed)
    recs = []
    counts = Counter()
    data = read_jsonl(src_file)
    for sample in tqdm(data):
        claim = sample["claim"]
        label = sample["label"]
        evidence_bids = sample["evidence"]
        assert len(evidence_bids) == 1, "More than single evidence not impemented (yet)" 
        context = corpus[original_id2pid[evidence_bids[0]]]["text"]
        recs.append({"claim": claim, "context": context, "label": label})
        counts[label] += 1
    rng.shuffle(recs)
    print(f"exporting {len(recs)}, label counts: {counts} to:\n {str(dst_file)}")
    write_jsonl(dst_file, recs, mkdir=True)

prepare_nli_data(Path(SPLIT_ROOT, f"train_{APPROACH}.jsonl"), Path(NLI_ROOT, f"train_{APPROACH}.jsonl"), corpus, original_id2pid, seed=1234)
prepare_nli_data(Path(SPLIT_ROOT, f"dev_{APPROACH}.jsonl"), Path(NLI_ROOT, f"dev_{APPROACH}.jsonl"), corpus, original_id2pid, seed=1235)
prepare_nli_data(Path(SPLIT_ROOT, f"test_{APPROACH}.jsonl"), Path(NLI_ROOT, f"test_{APPROACH}.jsonl"), corpus, original_id2pid, seed=1236)

100%|██████████| 107330/107330 [00:00<00:00, 733903.03it/s]


exporting 107330, label counts: Counter({'s': 53542, 'n': 35639, 'r': 18149}) to:
 /mnt/data/factcheck/wiki/cs/20230801/qacg/nli/PAV-ner-CNEC/mt5-large_all-cp126k/mt5-large_all-cp156k/train_fever_size.jsonl


100%|██████████| 9999/9999 [00:00<00:00, 908632.59it/s]


exporting 9999, label counts: Counter({'n': 3333, 'r': 3333, 's': 3333}) to:
 /mnt/data/factcheck/wiki/cs/20230801/qacg/nli/PAV-ner-CNEC/mt5-large_all-cp126k/mt5-large_all-cp156k/dev_fever_size.jsonl


100%|██████████| 9999/9999 [00:00<00:00, 963402.69it/s]


exporting 9999, label counts: Counter({'s': 3333, 'r': 3333, 'n': 3333}) to:
 /mnt/data/factcheck/wiki/cs/20230801/qacg/nli/PAV-ner-CNEC/mt5-large_all-cp126k/mt5-large_all-cp156k/test_fever_size.jsonl


# Combine Languages

In [6]:
def create_combined_split(src_files, dst_files, rng):
    # combined dataset has the same size as each individual language dataset
    data = [read_jsonl(src_file[1]) for src_file in src_files]
    langs = [src_file[0] for src_file in src_files]
    n = len(data)
    lens = list(set([len(d) for d in data]))
    if not all(l == lens[0] for l in lens[1:]):
        print(f"Warning differing lengths: {lens}")
    l = lens[0]
    psize = math.ceil(l/n)
    over = n*psize - l
    print(f"sampling {n} x {psize} = {n*psize}, {over} over to {l} to: {dst_files}")
    recs = []
    for lang, d in zip(langs, data):
        indices = rng.choice(len(d), psize, replace=False)
        rec = []
        for idx in indices:
            r = d[idx]
            r["lang"] = lang
            r["orig_idx"] = idx # the index in the original language claim file
            rec.append(r)
        recs += list(rec)
    del recs[:over]
    rng.shuffle(recs)
    assert len(recs) == l
    write_jsonl(dst_files, recs, mkdir=True)


APPROACH = "balanced_shuf"
rng = np.random.RandomState(1234)
for split in [f"train_{APPROACH}.jsonl", f"dev_{APPROACH}.jsonl", f"test_{APPROACH}.jsonl"]:
    create_combined_split([
        ("cs", Path("/mnt/data/factcheck/wiki/cs/20230801/qacg/splits/PAV-ner-CNEC/mt5-large_all-cp126k/mt5-large_all-cp156k", split)),
        ("en", Path("/mnt/data/factcheck/wiki/en/20230801/qacg/splits/stanza/mt5-large_all-cp126k/mt5-large_all-cp156k", split)),
        ("pl", Path("/mnt/data/factcheck/wiki/pl/20230801/qacg/splits/stanza/mt5-large_all-cp126k/mt5-large_all-cp156k", split)),
        ("sk", Path("/mnt/data/factcheck/wiki/sk/20230801/qacg/splits/crabz_slovakbert-ner/mt5-large_all-cp126k/mt5-large_all-cp156k", split)),
        ],
        Path("/mnt/data/factcheck/wiki/cs_en_pl_sk/20230801/qacg/splits", split),
        rng=rng)

sampling 4 x 46623 = 186492, 3 over to 186489 to: /mnt/data/factcheck/wiki/cs_en_pl_sk/20230801/qacg/splits/train_balanced_shuf.jsonl
sampling 4 x 4688 = 18752, 2 over to 18750 to: /mnt/data/factcheck/wiki/cs_en_pl_sk/20230801/qacg/splits/dev_balanced_shuf.jsonl
sampling 4 x 4537 = 18148, 2 over to 18146 to: /mnt/data/factcheck/wiki/cs_en_pl_sk/20230801/qacg/splits/test_balanced_shuf.jsonl


In [7]:
def create_sum_split(src_files, dst_files, rng):
    # the sum dataset simply concatenates (and shuffles) all source language datasets
    data = [read_jsonl(src_file[1]) for src_file in src_files]
    langs = [src_file[0] for src_file in src_files]
    recs = []
    for lang, d in zip(langs, data):
        indices = range(len(d))
        rec = []
        for idx in indices:
            r = d[idx]
            r["lang"] = lang
            r["orig_idx"] = idx # the index in the original language claim file
            rec.append(r)
        recs += list(rec)
    rng.shuffle(recs)
    write_jsonl(dst_files, recs, mkdir=True)


APPROACH = "balanced_shuf"
rng = np.random.RandomState(1234)
for split in [f"train_{APPROACH}.jsonl", f"dev_{APPROACH}.jsonl", f"test_{APPROACH}.jsonl"]:
    create_sum_split([
        ("cs", Path("/mnt/data/factcheck/wiki/cs/20230801/qacg/splits/PAV-ner-CNEC/mt5-large_all-cp126k/mt5-large_all-cp156k", split)),
        ("en", Path("/mnt/data/factcheck/wiki/en/20230801/qacg/splits/stanza/mt5-large_all-cp126k/mt5-large_all-cp156k", split)),
        ("pl", Path("/mnt/data/factcheck/wiki/pl/20230801/qacg/splits/stanza/mt5-large_all-cp126k/mt5-large_all-cp156k", split)),
        ("sk", Path("/mnt/data/factcheck/wiki/sk/20230801/qacg/splits/crabz_slovakbert-ner/mt5-large_all-cp126k/mt5-large_all-cp156k", split)),
        ],
        Path("/mnt/data/factcheck/wiki/sum_cs_en_pl_sk/20230801/qacg/splits", split),
        rng=rng)

In [6]:
def prepare_nli_data_combined(src_files, dst_files, lang2fcorpus, seed=1234):
    # imports data created for Evidence retrieval (ColBERTv2:prepare_data_wiki.ipynb)
    srcs = [read_jsonl(src_file) for src_file in tqdm(src_files, desc="reading sources")]

    for lang, fcorpus in lang2fcorpus.items():
        print(f"loading corpus for {lang.upper()} from '{fcorpus}'")
        corpus = import_corpus(fcorpus)
        original_id2pid = generate_original_id2pid_mapping(corpus)
        for src in srcs:
            for sample in src:
                if sample["lang"] == lang:
                    evidence_bids = sample["evidence"]
                    assert len(evidence_bids) == 1, "More than single evidence not impemented (yet)" 
                    context = corpus[original_id2pid[evidence_bids[0]]]["text"]
                    sample["context"] = context
    for src, dst_file in zip(srcs, dst_files):
        print(f"exporting {len(src)} to:\n {str(dst_file)}")
        write_jsonl(dst_file, src, mkdir=True)

In [None]:
DATE = '20230801'
prepare_nli_data_combined(
    src_files=[
        f"/mnt/data/factcheck/wiki/cs_en_pl_sk/{DATE}/qacg/splits/dev_balanced.jsonl",
        f"/mnt/data/factcheck/wiki/cs_en_pl_sk/{DATE}/qacg/splits/test_balanced.jsonl",
        f"/mnt/data/factcheck/wiki/cs_en_pl_sk/{DATE}/qacg/splits/train_balanced.jsonl",
    ],
    dst_files=[
        f"/mnt/data/factcheck/wiki/cs_en_pl_sk/{DATE}/qacg/nli/dev_balanced.jsonl",
        f"/mnt/data/factcheck/wiki/cs_en_pl_sk/{DATE}/qacg/nli/test_balanced.jsonl",
        f"/mnt/data/factcheck/wiki/cs_en_pl_sk/{DATE}/qacg/nli/train_balanced.jsonl",
    ],
    lang2fcorpus={
        "cs": f"/mnt/data/factcheck/wiki/cs/{DATE}/paragraphs/cswiki-{DATE}-paragraphs.jsonl",
        "en": f"/mnt/data/factcheck/wiki/en/{DATE}/paragraphs/enwiki-{DATE}-paragraphs.jsonl",
        "pl": f"/mnt/data/factcheck/wiki/pl/{DATE}/paragraphs/plwiki-{DATE}-paragraphs.jsonl",
        "sk": f"/mnt/data/factcheck/wiki/sk/{DATE}/paragraphs/skwiki-{DATE}-paragraphs.jsonl",
    })

In [9]:
# use the same procedure for the SUM dataset
DATE = '20230801'
prepare_nli_data_combined(
    src_files=[
        f"/mnt/data/factcheck/wiki/sum_cs_en_pl_sk/{DATE}/qacg/splits/dev_balanced.jsonl",
        f"/mnt/data/factcheck/wiki/sum_cs_en_pl_sk/{DATE}/qacg/splits/test_balanced.jsonl",
        f"/mnt/data/factcheck/wiki/sum_cs_en_pl_sk/{DATE}/qacg/splits/train_balanced.jsonl",
    ],
    dst_files=[
        f"/mnt/data/factcheck/wiki/sum_cs_en_pl_sk/{DATE}/qacg/nli/dev_balanced.jsonl",
        f"/mnt/data/factcheck/wiki/sum_cs_en_pl_sk/{DATE}/qacg/nli/test_balanced.jsonl",
        f"/mnt/data/factcheck/wiki/sum_cs_en_pl_sk/{DATE}/qacg/nli/train_balanced.jsonl",
    ],
    lang2fcorpus={
        "cs": f"/mnt/data/factcheck/wiki/cs/{DATE}/paragraphs/cswiki-{DATE}-paragraphs.jsonl",
        "en": f"/mnt/data/factcheck/wiki/en/{DATE}/paragraphs/enwiki-{DATE}-paragraphs.jsonl",
        "pl": f"/mnt/data/factcheck/wiki/pl/{DATE}/paragraphs/plwiki-{DATE}-paragraphs.jsonl",
        "sk": f"/mnt/data/factcheck/wiki/sk/{DATE}/paragraphs/skwiki-{DATE}-paragraphs.jsonl",
    })

reading sources: 100%|██████████| 3/3 [00:07<00:00,  2.60s/it]

loading corpus for CS from '/mnt/data/factcheck/wiki/cs/20230801/paragraphs/cswiki-20230801-paragraphs.jsonl'





0.00it [00:00, ?it/s]

loading corpus for EN from '/mnt/data/factcheck/wiki/en/20230801/paragraphs/enwiki-20230801-paragraphs.jsonl'


0.00it [00:00, ?it/s]

loading corpus for PL from '/mnt/data/factcheck/wiki/pl/20230801/paragraphs/plwiki-20230801-paragraphs.jsonl'


0.00it [00:00, ?it/s]

loading corpus for SK from '/mnt/data/factcheck/wiki/sk/20230801/paragraphs/skwiki-20230801-paragraphs.jsonl'


0.00it [00:00, ?it/s]

exporting 120348 to:
 /mnt/data/factcheck/wiki/sum_cs_en_pl_sk/20230801/qacg/nli/dev_balanced.jsonl
exporting 113760 to:
 /mnt/data/factcheck/wiki/sum_cs_en_pl_sk/20230801/qacg/nli/test_balanced.jsonl
exporting 1180836 to:
 /mnt/data/factcheck/wiki/sum_cs_en_pl_sk/20230801/qacg/nli/train_balanced.jsonl


# Test Models

In [8]:
NLI_ROOT

PosixPath('/mnt/data/factcheck/wiki/sk/20230801/qacg/nli/crabz_slovakbert-ner/mt5-large_all-cp126k/mt5-large_all-cp156k')

In [3]:
raw_nli = load_dataset("json", data_files={
    "train": str(Path(NLI_ROOT, "train_balanced.jsonl")),
    "dev": str(Path(NLI_ROOT, "dev_balanced.jsonl")),
    "test": str(Path(NLI_ROOT, "test_balanced.jsonl"))
    })

# raw_nli = load_dataset("json", data_files={
#     "train": str(Path(CLAIM_ROOT, "train_nli_sr.jsonl")),
#     "dev": str(Path(CLAIM_ROOT, "dev_nli_sr.jsonl")),
#     "test": str(Path(CLAIM_ROOT, "test_nli_sr.jsonl"))
#     })

Found cached dataset json (/home/drchajan/.cache/huggingface/datasets/json/default-696677cf81c28f6c/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)


  0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
for a in raw_nli.:
    print(a)

train
dev
test


In [13]:
for i, c in enumerate(raw_nli["train"]["context"]):
    if len(c) < 30:
        print(i, c)

In [14]:
# model_id = "/home/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/nli/google/flan-t5-base_cs_CZ/checkpoint-896"
# model_id = "/home/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/nli/google/flan-t5-large_cs_CZ/checkpoint-1568"
# model_id = "/home/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/nli/google/flan-t5-large_cs_CZ-20230801/checkpoint-6144"
# model_id = "/home/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/nli/google/flan-t5-large_cs_CZ-20230801/checkpoint-23936"

# SUPPORT/REFUTE only models
# model_id = "/home/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/nli/google/umt5-base_cs_CZ-20230801_sr/checkpoint-23936"
# model_id = "/home/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/nli/google/flan-t5-large_cs_CZ-20230801_sr/checkpoint-256"

# tokenizer = AutoTokenizer.from_pretrained(model_id)
# model = AutoModelForSeq2SeqLM.from_pretrained(model_id, device_map="auto")

# Encoder (Softmax) models
# model_id = "ctu-aic/xlm-roberta-large-xnli-csfever"
# model_id = "/home/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/nli/deepset/xlm-roberta-large-squad2_cs_CZ-20230801_lr1e-6/checkpoint-41792"
# model_id = "/home/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/nli/deepset/xlm-roberta-large-squad2_en_US-20230220_lr1e-6/checkpoint-48416"
# model_id = "/home/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/nli_fever/deepset/xlm-roberta-large-squad2_en_US_lr1e-6/checkpoint-132864"
model_id = "/home/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/nli/deepset/xlm-roberta-large-squad2_sk_SK-20230801_balanced_lr1e-6/checkpoint-202112"

id2label = {0: "s", 1: "r", 2: "n"}
label2id = {"s": 0, "r": 1, "n": 1}
tokenizer = AutoTokenizer.from_pretrained(model_id)
# model = AutoModelForSequenceClassification.from_pretrained(model_id, device_map="auto", id2label=id2label, label2id=label2id)
model = AutoModelForSequenceClassification.from_pretrained(model_id, device_map="auto")


  warn("The installed version of bitsandbytes was compiled without GPU support. "


/home/drchajan/devel/python/FC/fc_env_hflarge/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so: undefined symbol: cadam32bit_grad_fp32




In [15]:
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score
from aic_nlp_utils.batch import batch_apply

def split_predict(model, split, batch_size=128, device="cuda", max_length=128):
    def predict(inputs):
        X = tokenizer(inputs, max_length=max_length, padding=True, truncation=True, return_tensors="pt")
        input_ids = X["input_ids"].to(device)
        attention_mask = X["attention_mask"].to(device)
        with torch.no_grad():
            Y = model(input_ids=input_ids, attention_mask=attention_mask).logits
            return Y
        
    inputs = [[claim, context] for claim, context in zip(split["claim"],  split["context"])]
    # inputs = [[context, claim] for claim, context in zip(split["claim"],  split["context"])] # SWITCHED CTX and CLAIM!!!
    Ys = batch_apply(predict, inputs, batch_size=batch_size, show_progress=True)
    Y = torch.vstack(Ys)
    C = [model.config.id2label[id_.item()] for id_ in Y.argmax(dim=1)]
    T = [l for l in split["label"]]
    return Y, C, T

# Y, C, T = split_predict(model, raw_nli["test"], device="cuda", max_length=128) # FAST
# Y, C, T = split_predict(model, raw_nli["test"], device="cuda", max_length=512) # CORRECT, set to model maximum input length
Y, C, T = split_predict(model, raw_nli["test"], device="cpu", max_length=512) # CORRECT, set to model maximum input length

  0%|          | 0/223 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [58]:
print(f"acc: {accuracy_score(T, C)}")
print(f"F1: {f1_score(T, C, average='macro')}")
# print(f"cm:\n{confusion_matrix(T, C)}")
print()
print(f"cm:\n{confusion_matrix(T, C, labels=['s', 'r', 'n'])}")
print()
print(f"C={Counter(C)}")
print(f"T={Counter(T)}")

acc: 0.6492715709710883
F1: 0.5975708877212096

cm:
[[7657  233  210]
 [3508 1488 1310]
 [ 591 1924 5250]]

C=Counter({'s': 11756, 'n': 6770, 'r': 3645})
T=Counter({'s': 8100, 'n': 7765, 'r': 6306})


In [53]:
print(f"acc: {accuracy_score(T, C)}")
print(f"F1: {f1_score(T, C, average='macro')}")
# print(f"cm:\n{confusion_matrix(T, C)}")
print()
print(f"cm:\n{confusion_matrix(T, C, labels=['s', 'r', 'n'])}")
print()
print(f"C={Counter(C)}")
print(f"T={Counter(T)}")

acc: 0.9010419015831491
F1: 0.8978291959432556

cm:
[[7504  568   28]
 [1120 5171   15]
 [ 195  268 7302]]

C=Counter({'s': 8819, 'n': 7345, 'r': 6007})
T=Counter({'s': 8100, 'n': 7765, 'r': 6306})


In [7]:
def preprocess_function(examples):
    max_length = 128
    claims = examples["claim"]
    contexts = examples["context"]
    targets = examples["label"]
    inputs = [claim + "</s>" + context for claim, context in zip(claims, contexts)]
    # inputs = [context + "</s>" + claim for claim, context in zip(claims, contexts)]
    model_inputs = tokenizer(inputs, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt")
    # labels = tokenizer(targets, max_length=3, padding="max_length", truncation=True, return_tensors="pt")
    # labels = labels["input_ids"]
    # labels[labels == tokenizer.pad_token_id] = -100
    # model_inputs["labels"] = labels
    return model_inputs

# tokenized_nli = raw_nli.map(preprocess_function, batched=True,  
#                     #   remove_columns=raw_nli["train"].column_names,
#                       load_from_cache_file=False)
# tokenized_nli.set_format(type='torch', columns=['input_ids', 'attention_mask'])

In [37]:
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score

def split_predict(model, split, batch_size=256):
    def predict_batch(inputs):
        Y = model(input_ids=inputs["input_ids"].to("cuda"), attention_mask=inputs["attention_mask"].to("cuda")).logits.argmax(dim=1)
        C = [model.config.id2label[id_.item()] for id_ in Y]
        out = {"pred": C}
        return out
    
    split = split.map(predict_batch, batch_size=batch_size, batched=True)
    C = split["pred"]
    T = tokenized_nli["dev"]["label"]
    return C, T

def split_predict_generate(model, split, batch_size=256):
    def predict_batch(inputs):
        Y = model.generate(input_ids=inputs["input_ids"].to("cuda"), attention_mask=inputs["attention_mask"].to("cuda"))
        out = {"pred": tokenizer.batch_decode(Y, skip_special_tokens=True)}
        return out
    
    split = split.map(predict_batch, batch_size=batch_size, batched=True)
    C = split["pred"]
    T = tokenized_nli["dev"]["label"]
    return C, T

# C, T = split_predict(model, tokenized_nli["dev"], batch_size=64)
# C, T = split_predict_generate(model, tokenized_nli["dev"])

In [47]:
# old type of models used in FactSearch
from prediction.nli import SupportRefuteNEIModel # Make OBSOLETE
from sentence_transformers.cross_encoder import CrossEncoder

model_id = "ctu-aic/xlm-roberta-large-xnli-csfever"
id2label = {0: "s", 1: "r", 2: "n"}
label2id = {"s": 0, "r": 1, "n": 1}
model = CrossEncoder(model_id, device="cuda")

In [49]:
from sklearn.metrics import accuracy_score, confusion_matrix
from aic_nlp_utils.batch import batch_apply

def split_predict_generate(model, split, batch_size=128):
    def predict(inputs):
        X = tokenizer(inputs, max_length=128, padding="max_length", truncation=True, return_tensors="pt")
        input_ids = X["input_ids"].to("cuda")
        attention_mask = X["attention_mask"].to("cuda")
        Y = model.generate(input_ids=input_ids, attention_mask=attention_mask)
        C = tokenizer.batch_decode(Y, skip_special_tokens=True)
        return C

def split_predict_crossencoder(model, split, batch_size=10*128):
    def predict(inputs):
        print(len(inputs))
        Y = model.predict(inputs).argmax(axis=1)
        C = [id2label[id_.item()] for id_ in Y]
        return C
        
    # SWITCHED CTX and CLAIM!!!
    inputs = [[context, claim] for claim, context in zip(split["claim"],  split["context"])]
    C = batch_apply(predict, inputs, batch_size=batch_size)
    T = [l for l in split["label"]]
    return C, T

# C, T = split_predict_generate(model, raw_nli["dev"])
C, T = split_predict_crossencoder(model, raw_nli["dev"])
print(f"acc: {accuracy_score(T, C)}")
print(f"cm:\n{confusion_matrix(T, C)}")

1280
1280
1280
1280
1280
1280
1280
1280
1280
1280
1280
1280
1280
1280
1280
1280
1280
1280
1280
647
acc: 0.446749709616694
cm:
[[   0    0    0]
 [2359 1251 8553]
 [2504  397 9903]]


In [31]:
C, T = split_predict(model, raw_nli["test"])
print(f"acc: {accuracy_score(T, C)}")
print(f"cm:\n{confusion_matrix(T, C)}")

acc: 0.5060786993741457
cm:
[[4183   68  188]
 [2349  826 1184]
 [2531  546 2026]]


In [26]:
C, T = split_predict(model, raw_nli["train"])
print(f"acc: {accuracy_score(T, C)}")
print(f"cm:\n{confusion_matrix(T, C)}")

acc: 0.49263804969579006
cm:
[[    0     0     0     0     0     0     0     0     0     0     0]
 [    0     0     0     0     0     0     0     0     0     0     0]
 [    0     0     0     0     0     0     0     0     0     0     0]
 [    0     0     0     0     0     0     0     0     0     0     0]
 [    0     0     0     0     0     0     0     0     0     0     0]
 [    0     0     0     0     0     0     0     0     0     0     0]
 [    0     0     0     0     0     0     0     0     0     0     0]
 [    0     0     0     0     0     0     0     0     0     0     0]
 [    0     0     1     0     0     0     0     0 41839  1438  4126]
 [    1     0     0     1     0     0     0     0 19317  7391 17376]
 [    1     1     1     0     1     1     2     1 23257  6941 21133]]
