# Fine-tune Idefics2 on your own use-case

<p align="center">
    <img src="https://huggingface.co/HuggingFaceM4/idefics-80b/resolve/main/assets/IDEFICS.png" alt="Idefics-Obelics logo" width="300" height="300">
</p>

This tutorial shows how to fine-tune Idefics2 on your own use case.

Idefics2 is an open multimodal model that accepts arbitrary sequences of image and text inputs and produces text outputs. The model can answer questions about images, describe visual content, create stories grounded on multiple images, or simply behave as a pure language model without visual inputs. It improves upon Idefics1, significantly enhancing capabilities around OCR, document understanding and visual reasoning.

Read more about Idefics2 in this [blog post](https://huggingface.co/blog/idefics2) and the [model card](https://huggingface.co/HuggingFaceM4/idefics2-8b).

# Setup

We first set up the environment with the main required libraries and log in to Hugging Face.

In [3]:
!pip install -q git+https://github.com/huggingface/transformers.git
!pip install -q accelerate datasets peft bitsandbytes

In [21]:
!tensorboard --logdir_spec=name1:/kaggle/working/test_model/runs/Jul10_14-12-01_7d4e82ffdb37

Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.16.2 at http://localhost:6006/ (Press CTRL+C to quit)


In [4]:
import os

# Must be set before PyTorch initializes CUDA
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

import json
import torch
from tqdm import tqdm
from huggingface_hub import login
from peft import LoraConfig, PeftModel, PeftConfig
from datasets import load_dataset, Image
from transformers import AutoProcessor, BitsAndBytesConfig, Idefics2ForConditionalGeneration, TrainingArguments, Trainer

# Never hard-code an access token in a notebook: read it from an environment
# variable (or a Kaggle/Colab secret) instead.
token = os.environ["HF_TOKEN"]

login(token=token)

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to /root/.cache/huggingface/token
Login successful


# Loading task dataset

In [5]:
train_dataset = load_dataset("json", data_files="/kaggle/input/hvmemes/HVVMemes/annotations/train.jsonl", data_dir="/kaggle/input/hvmemes/HVVMemes/images/", split="train").shuffle(seed=42)

In [6]:
train_dataset = train_dataset.map(lambda ex: {"image": "/kaggle/input/hvmemes/HVVMemes/images/" + ex["image"]})

train_dataset = train_dataset.cast_column("image", Image())

In [7]:
val_dataset = load_dataset("json", data_files="/kaggle/input/hvmemes/HVVMemes/annotations/dev.jsonl", data_dir="/kaggle/input/hvmemes/HVVMemes/images/", split="train")

val_dataset = val_dataset.map(lambda ex: {"image": "/kaggle/input/hvmemes/HVVMemes/images/" + ex["image"]})

val_dataset = val_dataset.cast_column("image", Image())

In [8]:
entities_column = [h + vil + vic + o for h, vil, vic, o in zip(train_dataset["hero"], train_dataset["villain"], train_dataset["victim"], train_dataset["other"])]

train_dataset = train_dataset.add_column("entities", entities_column)

In [9]:
entities_column = [h + vil + vic + o for h, vil, vic, o in zip(val_dataset["hero"], val_dataset["villain"], val_dataset["victim"], val_dataset["other"])]

val_dataset = val_dataset.add_column("entities", entities_column)
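The same `entities` column can also be built per example with `Dataset.map`, which avoids repeating the list comprehension for every split. A minimal sketch; `merge_entities` is a hypothetical helper name, and the role order (hero, villain, victim, other) matches the comprehension above.

```python
def merge_entities(example):
    """Concatenate the role lists of one HVVMemes row into a single 'entities' list.

    Order matches the list comprehension above: hero, villain, victim, other.
    Returning a dict lets `Dataset.map` add it as a new column.
    """
    return {
        "entities": example["hero"] + example["villain"] + example["victim"] + example["other"]
    }


# Role lists of val_dataset[2] (image and OCR omitted)
row = {
    "hero": [],
    "villain": ["donald trump"],
    "victim": [],
    "other": ["working class", "patton oswalt", "goldman sachs"],
}
print(merge_entities(row)["entities"])
# ['donald trump', 'working class', 'patton oswalt', 'goldman sachs']
```

Applied to a split, this would read `train_dataset = train_dataset.map(merge_entities)`.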

In [28]:
val_dataset[2]

{'OCR': 'HEY, WORKING\nCLASS TRUMP\nVÕTERS: HE JUST\nADDED ANOTHER\nGOLDMAN SACHS\nCEO TO HIS\nCABINET.\nTRUMP\nHATES YOU.\nHATES. YOU.\n- PATTON OSWALT\nAddieting Info\n',
 'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=224x225>,
 'hero': [],
 'villain': ['donald trump'],
 'victim': [],
 'other': ['working class', 'patton oswalt', 'goldman sachs'],
 'entities': ['donald trump',
  'working class',
  'patton oswalt',
  'goldman sachs']}

In [10]:
processor = AutoProcessor.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    do_image_splitting=False
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
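`do_image_splitting=False` keeps the prompts short. As we understand the Idefics2 design (see the blog post linked above), each image is encoded into 64 visual tokens by the perceiver resampler, and image splitting encodes four crops plus the original image, i.e. five times as many tokens per image. A back-of-the-envelope helper under those assumptions (`visual_tokens` is a hypothetical name):

```python
def visual_tokens(n_images, image_splitting, tokens_per_image=64):
    """Approximate number of <image> tokens in a prompt.

    Assumes 64 tokens per encoded image and, with splitting,
    4 crops + the original image = 5 encodings per image.
    """
    encodings_per_image = 5 if image_splitting else 1
    return n_images * encodings_per_image * tokens_per_image


print(visual_tokens(1, image_splitting=False))  # 64
print(visual_tokens(1, image_splitting=True))   # 320
```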


# Loading the model and the dataset

We load the model from the Hugging Face Hub. `idefics2-8b` has gone through instruction fine-tuning on a large mixture of multimodal datasets and, as such, is a strong starting point for fine-tuning on your own use case. We will start from this checkpoint.

In this tutorial, we fine-tune (and evaluate) the model on the HVVMemes dataset: given a meme, its OCR text and the entities it mentions, the model has to group the entities into heroes, villains and victims.

To accommodate the GPU-poor, the default hyperparameters in this tutorial are chosen so that fine-tuning takes less than 32 GB of GPU memory. For instance, a V100 in Google Colab should be sufficient.

If you have more resources, you are encouraged to revisit some of these constraints, in particular:
- 4-bit quantization
- LoRA fine-tuning
- Freezing the vision encoder
- A small batch size compensated by a higher gradient accumulation degree
- Deactivated image splitting
- Flash-attention, which the code below only enables for full fine-tuning
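To see why 4-bit quantization matters on a 32 GB card, here is a rough estimate of the memory taken by the weights alone. It ignores activations, optimizer states, the LoRA adapters and any layers bitsandbytes leaves unquantized, so treat the numbers as lower bounds. The parameter count is the "all params" value reported by `print_trainable_parameters()` later in this notebook.

```python
def weight_memory_gb(n_params, bits_per_param):
    """Memory needed to store n_params weights at the given precision, in GB (1e9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9


n_params = 8_426_094_832  # "all params" reported for idefics2-8b in this notebook
for name, bits in [("fp16", 16), ("4-bit (nf4)", 4)]:
    print(f"{name}: {weight_memory_gb(n_params, bits):.1f} GB")
# fp16: 16.9 GB
# 4-bit (nf4): 4.2 GB
```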

DEVICE = "cuda:0"
USE_LORA = False
USE_QLORA = True


# Three options for training, from the lowest precision training to the highest precision training:
# - QLora
# - Standard Lora
# - Full fine-tuning
if USE_QLORA or USE_LORA:
    lora_config = LoraConfig(
        r=8,
        lora_alpha=8,
        lora_dropout=0.1,
        target_modules='.*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$',
        use_dora=False if USE_QLORA else True,
        init_lora_weights="gaussian"
    )
    if USE_QLORA:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16
        )
    model = Idefics2ForConditionalGeneration.from_pretrained(
        "HuggingFaceM4/idefics2-8b",
        torch_dtype=torch.float16,
        quantization_config=bnb_config if USE_QLORA else None,
    )
    model.add_adapter(lora_config)
    model.enable_adapters()
else:
    model = Idefics2ForConditionalGeneration.from_pretrained(
        "HuggingFaceM4/idefics2-8b",
        torch_dtype=torch.float16,
        _attn_implementation="flash_attention_2",  # requires the flash-attn package and an Ampere (e.g. A100) or newer GPU
    ).to(DEVICE)
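When `target_modules` is a string, PEFT matches it as a regular expression against each module's full name (with `re.fullmatch`). Here it selects the attention and MLP projections of the text model, the modality projection and the perceiver resampler; the vision encoder gets no adapters and is therefore not trained. A quick way to sanity-check the pattern is to test it on a few module names. The names below are illustrative examples of Idefics2's naming scheme (an assumption); list the real ones with `[n for n, _ in model.named_modules()]`.

```python
import re

# Same pattern as in the LoraConfig above
TARGET_MODULES = r'.*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$'

# Illustrative module names (assumed; check against model.named_modules())
names = [
    "model.text_model.layers.0.self_attn.q_proj",
    "model.text_model.layers.0.mlp.down_proj",
    "model.connector.modality_projection.gate_proj",
    "model.connector.perceiver_resampler.layers.0.self_attn.k_proj",
    "model.vision_model.encoder.layers.0.self_attn.q_proj",
    "lm_head",
]

for name in names:
    matched = re.fullmatch(TARGET_MODULES, name) is not None
    print(f"{'LoRA  ' if matched else 'frozen'}  {name}")
```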



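Let's recall what a sample looks like (see `val_dataset[2]` above). Each HVVMemes sample has a meme image, the text extracted from it by OCR, and the entities the meme refers to, grouped by role: `hero`, `villain`, `victim`, and `other` for entities with none of these roles. The `entities` column concatenates all four lists. The model sees the image, the OCR text and the entity list, and must sort the entities into heroes, villains and victims.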



# Training loop

We first define the data collator, which takes a list of samples and returns the input tensors fed to the model. There are four tensor types we are interested in:
- `input_ids`: the input indices fed to the language model
- `attention_mask`: the attention mask for the `input_ids` in the language model
- `pixel_values`: the (pre-processed) pixel values that encode the image(s). Idefics2 processes images at their native resolution (up to 980 pixels) and native aspect ratio
- `pixel_attention_mask`: when multiple images are packed into the same sample (or batch), they can have different sizes and aspect ratios, so they are padded to a common size. This mask tells the vision encoder which pixels are real and which are padding.

In addition, the collator builds the `labels` used for the language-modeling loss.
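The idea behind `pixel_attention_mask` can be illustrated without the processor: images of different sizes are padded to a common height and width, and the mask marks which positions hold real pixels. This sketch works on image sizes only and assumes padding is added at the bottom and right; `pad_masks` is a hypothetical helper, not the processor's implementation.

```python
def pad_masks(sizes):
    """Given (height, width) pairs, return the padded (H, W) and one boolean mask per image.

    True marks real pixels; False marks padding added at the bottom/right.
    """
    max_h = max(h for h, _ in sizes)
    max_w = max(w for _, w in sizes)
    masks = [
        [[row < h and col < w for col in range(max_w)] for row in range(max_h)]
        for h, w in sizes
    ]
    return (max_h, max_w), masks


shape, masks = pad_masks([(2, 3), (3, 1)])
print(shape)  # (3, 3)
for row in masks[1]:
    print(["x" if v else "." for v in row])
```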


In [11]:
class MyDataCollator:
    def __init__(self, processor):
        self.processor = processor
        # Id of the <image> placeholder token, excluded from the loss below
        self.image_token_id = processor.tokenizer.additional_special_tokens_ids[
            processor.tokenizer.additional_special_tokens.index("<image>")
        ]

    @staticmethod
    def gen_answer(ex):
        # Target string, e.g. "{hero: [], villain: ['donald trump'], victim: []}"
        roles = ["hero", "villain", "victim"]
        return "{" + ", ".join(f"{role}: {ex[role]}" for role in roles) + "}"

    @staticmethod
    def gen_question(ex):
        # Prompt listing the OCR text and the entities to be grouped
        entities = ex["entities"]
        text = ex["OCR"]
        return f'Given the meme and its text: "{text}" , group the entities: {entities} in regards to the predefined format.'

    def __call__(self, examples):
        texts = []
        images = []
        for example in examples:
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Answer in the following format: {hero: [], villain: [], victim: []}"},
                        {"type": "image"},
                        {"type": "text", "text": self.gen_question(example)}
                    ]
                },
                {
                    "role": "assistant",
                    "content": [
                        {"type": "text", "text": self.gen_answer(example)}
                    ]
                }
            ]
            text = self.processor.apply_chat_template(messages, add_generation_prompt=False)
            texts.append(text.strip())
            images.append([example["image"]])

        batch = self.processor(text=texts, images=images, return_tensors="pt", padding=True)

        # Padding and <image> placeholder tokens get label -100 so the loss ignores them
        labels = batch["input_ids"].clone()
        labels[labels == self.processor.tokenizer.pad_token_id] = -100
        labels[labels == self.image_token_id] = -100
        batch["labels"] = labels

        return batch

data_collator = MyDataCollator(processor)
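The collator copies `input_ids` into `labels` and sets padding and `<image>` positions to -100. PyTorch's cross-entropy loss skips targets equal to its `ignore_index`, which defaults to -100, so those positions do not contribute to the loss. A toy version on plain lists (the token ids are made up):

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def make_labels(input_ids, pad_token_id, image_token_id):
    """Copy input_ids, replacing padding and image placeholder tokens with IGNORE_INDEX."""
    return [
        IGNORE_INDEX if tok in (pad_token_id, image_token_id) else tok
        for tok in input_ids
    ]


# Hypothetical ids: 0 = <pad>, 32001 = <image>
input_ids = [1, 523, 32001, 32001, 1014, 2, 0, 0]
print(make_labels(input_ids, pad_token_id=0, image_token_id=32001))
# [1, 523, -100, -100, 1014, 2, -100, -100]
```

Note that, as written, the collator does not mask the prompt tokens (instruction, OCR text and entity list), so they also contribute to the loss alongside the answer.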

In [12]:
# Reload the 4-bit base model and attach the LoRA adapter from a previous training run.
# is_trainable=False loads the adapter for inference only, so this model is meant for
# evaluation; pass is_trainable=True instead to continue training it.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    torch_dtype=torch.float16,
    quantization_config=bnb_config,
)
adapter_path = "/kaggle/input/idefics2-train-checkpoints/test_model/checkpoint-5000"
model = PeftModel.from_pretrained(model, adapter_path, is_trainable=False)
# Sanity check: with is_trainable=False this reports 0 trainable parameters
model.print_trainable_parameters()


`low_cpu_mem_usage` was None, now set to True since model is quantized.



trainable params: 0 || all params: 8,426,094,832 || trainable%: 0.0000


We will use the Hugging Face `Trainer`.

In [14]:
training_args = TrainingArguments(
    num_train_epochs=2,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=8,
    warmup_steps=50,
    learning_rate=1e-4,
    weight_decay=0.01,
    logging_steps=25,
    output_dir="/kaggle/working/test_model",
    save_strategy="steps",
    save_steps=500,
    save_total_limit=10,
    eval_strategy="epoch",
    remove_unused_columns=False,
    report_to="tensorboard",
)
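With these arguments, each optimizer step sees `per_device_train_batch_size × gradient_accumulation_steps` = 2 × 8 = 16 examples per device. A rough count of optimizer steps for the 5,552 training memes, assuming a single device and rounding partial steps up (the exact rounding inside `Trainer` may differ between versions):

```python
import math


def optimizer_steps(n_examples, per_device_batch, grad_accum, epochs, n_devices=1):
    """Return (effective batch size, approximate total number of optimizer steps)."""
    effective_batch = per_device_batch * grad_accum * n_devices
    steps_per_epoch = math.ceil(n_examples / effective_batch)
    return effective_batch, steps_per_epoch * epochs


effective, total = optimizer_steps(5552, per_device_batch=2, grad_accum=8, epochs=2)
print(effective, total)  # 16 694
```

Under this estimate, the 50 warmup steps cover about 7% of training, and `save_steps=500` is reached only once.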



In [15]:
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,  # evaluation loss is computed on this set each epoch; this uses some additional GPU memory
)

RuntimeError: Bad StatusOr access: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 64.00M. That was not possible. There are 62.83M free.; (0x0x0_HBM0)

# Training and pushing to the hub

We have all the core building blocks now, so we fine-tune the model!

Training an 8B-parameter model takes a while; the exact duration depends heavily on the hardware you use.

In [None]:
print("Commencing training...")
trainer.train()

In [None]:
trainer.save_model("/kaggle/working/test_model/res")

The fine-tuned adapter is saved to `/kaggle/working/test_model/res`. To share it, it can also be pushed to the Hub with `trainer.push_to_hub()`.

# Evaluation

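Let's evaluate the model on the held-out `dev_test` split. We compute an accuracy score and per-class F1 scores over all test memes, compare them with a trivial baseline, and then inspect individual predictions on the memes that actually contain a hero, villain or victim.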

In [13]:
test_dataset = load_dataset("json", data_files="/kaggle/input/hvmemes/HVVMemes/annotations/dev_test.jsonl", data_dir="/kaggle/input/hvmemes/HVVMemes/images/", split="train")

test_dataset = test_dataset.map(lambda ex: {"image": "/kaggle/input/hvmemes/HVVMemes/images/" + ex["image"]})

test_dataset = test_dataset.cast_column("image", Image())

entities_column = [h + vil + vic + o for h, vil, vic, o in zip(test_dataset["hero"], test_dataset["villain"], test_dataset["victim"], test_dataset["other"])]

test_dataset = test_dataset.add_column("entities", entities_column)

In [14]:
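def batch_over_data(data, batch_size=4):
    """Split an iterable into lists of `batch_size` rows; the last batch may be shorter."""
    batches = []
    batch = []
    for row in data:
        batch.append(row)
        if len(batch) == batch_size:
            batches.append(batch)
            batch = []
    if batch:  # keep the trailing partial batch instead of dropping it
        batches.append(batch)
    return batches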

In [15]:
import ast


def tensorise_input(example):
    """Build the generation prompt (without the answer) for one test example and tokenize it."""
    def gen_question(ex):
        # Same wording as in the training collator
        entities = ex["entities"]
        text = ex["OCR"]
        return f'Given the meme and its text: "{text}" , group the entities: {entities} in regards to the predefined format.'

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Answer in the following format: {hero: [], villain: [], victim: []}"},
                {"type": "image"},
                {"type": "text", "text": gen_question(example)}
            ]
        }
    ]
    # add_generation_prompt=True appends the assistant turn header so the model starts answering
    text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=[text.strip()], images=[example["image"]], return_tensors="pt", padding=True)

    return inputs


def format_pred(pred_str):
    """Parse a generated answer such as "{hero: [], villain: ['x'], victim: []}" into a dict."""
    # Quote the bare keys so the string becomes a Python dict literal
    pred_str = pred_str.strip()
    pred_str = pred_str.replace("{hero: [", "{'hero': [")
    pred_str = pred_str.replace(" villain: [", " 'villain': [")
    pred_str = pred_str.replace(" victim: [", " 'victim': [")
    # literal_eval accepts Python list syntax, including names with apostrophes
    # (rendered as "o'rourke"), which a ' -> " substitution plus json.loads would break
    return ast.literal_eval(pred_str)


def score_accuracy(preds, test):
    """Mean per-sample accuracy over paired generations and gold rows.

    Each sample scores (#gold entities - #mistakes) / #gold entities, clipped at 0, where a
    mistake is a wrongly predicted or a missed entity in any role. A sample without gold
    entities scores 1 only if nothing is predicted.
    """
    roles = ("hero", "villain", "victim")
    scores = []
    for pred_str, gold in zip(preds, test):
        pred = format_pred(pred_str)
        t_sum = sum(len(gold[c]) for c in roles)

        mistakes = 0
        for c in roles:
            mistakes += len([p for p in pred[c] if p not in gold[c]])  # wrong predictions
            mistakes += len([g for g in gold[c] if g not in pred[c]])  # missed gold entities
        acc = t_sum - mistakes

        if acc < 0:
            scores.append(0)
        elif t_sum == 0:  # here acc == 0: no gold entities and nothing predicted
            scores.append(1)
        else:
            scores.append(acc / t_sum)
    return sum(scores) / len(scores)
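Written out, the per-sample score computed by `score_accuracy` is the following, where $G_c$ and $P_c$ are the gold and predicted entities for role $c \in \{\text{hero}, \text{villain}, \text{victim}\}$, $|G| = \sum_c |G_c|$, and entity lists are treated as sets (assuming no duplicate names):

$$
E = \sum_{c} \left( |P_c \setminus G_c| + |G_c \setminus P_c| \right),
\qquad
\text{acc} =
\begin{cases}
1 & \text{if } |G| = 0 \text{ and } E = 0,\\
\max\!\left(0,\ \dfrac{|G| - E}{|G|}\right) & \text{if } |G| > 0,\\
0 & \text{otherwise.}
\end{cases}
$$

An entity assigned to the wrong role counts twice (once as a wrong prediction, once as a missed gold entity), so on a meme with a single gold entity one role swap already scores 0, as in the `donald trump` villain/victim example in the logged predictions below.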

test_dataset[0]

test_dataset.take(2)

torch.cuda.empty_cache()

In [51]:
def f1_for_class(preds, test, class_):
    preds = format_pred(preds)
    tp = len([tru_pred for tru_pred in preds[class_] if tru_pred in test[class_]])
    fp = len(preds[class_]) - tp
    fn = len([fneg for fneg in test[class_] if fneg not in preds[class_]])
    
    if tp + fp == 0:
        precision = 1
    else:
        precision = tp / (tp + fp)
    
    if tp + fn == 0:
        recall = 1
    else:
        recall = tp / (tp + fn)
    
    if precision + recall == 0:
        f1 = 0
    else:
        f1 = ((2 * precision * recall) / (precision + recall))
    
    return f1, precision, recall
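A standalone version of the per-class computation on plain lists makes the conventions explicit: precision is 1 when nothing is predicted for a class, recall is 1 when the class has no gold entities, so a class that is empty in both scores F1 = 1. `prf1` is a hypothetical helper mirroring `f1_for_class`; the check reproduces the villain F1 of 0.6667 for the `libertarians` example in the logs below.

```python
def prf1(pred, gold):
    """Per-class F1, precision and recall with the same empty-set conventions as f1_for_class."""
    tp = len([p for p in pred if p in gold])
    fp = len(pred) - tp
    fn = len([g for g in gold if g not in pred])
    precision = 1 if tp + fp == 0 else tp / (tp + fp)
    recall = 1 if tp + fn == 0 else tp / (tp + fn)
    f1 = 0 if precision + recall == 0 else 2 * precision * recall / (precision + recall)
    return f1, precision, recall


# Villain class of the 'libertarians' example: one of two gold entities predicted
f1, p, r = prf1(["libertarians"], ["libertarians", "libertarian party"])
print(f"f1={f1:.4f} precision={p:.4f} recall={r:.4f}")  # f1=0.6667 precision=1.0000 recall=0.5000
print(prf1([], []))  # (1.0, 1, 1)
```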

In [42]:
model.eval()

classes = ["hero", "villain", "victim"]
f1s = [0, 0, 0]
precs = [0, 0, 0]
recall = [0, 0, 0]

avg_acc = 0
n = 0
broken_preds = 0

for ex in tqdm(test_dataset):
    test_inputs = tensorise_input(ex)
    
    gen_ids = model.generate(**test_inputs, max_new_tokens=64)
    preds = processor.batch_decode(gen_ids[:, test_inputs["input_ids"].size(1):], skip_special_tokens=True)
    
    avg_acc += score_accuracy([preds[0]], [ex])
    for idx, c in enumerate(classes):
        f1c, pc, rc = f1_for_class(preds[0], ex, c)
        f1s[idx] += f1c
        precs[idx] += pc
        recall[idx] += rc
        
    n += 1
    
    if n % 10 == 0:
        torch.cuda.empty_cache()
        print("Average accuracy {:.2f}%".format((avg_acc / (n - broken_preds)) * 100))
        print("Average f1 scores per class:    hero: {:.4f}  ,villain: {:.4f}  ,victim: {:.4f}".format(f1s[0] / n, f1s[1] / n, f1s[2] / n))
    
avg_acc /= (n - broken_preds)
print(avg_acc)

  1%|▏         | 10/718 [00:27<32:28,  2.75s/it]

Average accuracy 55.00%
Average f1 scores per class:    hero: 0.9000  ,villain: 0.6667  ,victim: 0.9000


  3%|▎         | 20/718 [01:00<42:52,  3.68s/it]

Average accuracy 50.00%
Average f1 scores per class:    hero: 0.8000  ,villain: 0.7500  ,victim: 0.8000


  4%|▍         | 30/718 [01:24<27:27,  2.39s/it]

Average accuracy 63.33%
Average f1 scores per class:    hero: 0.8667  ,villain: 0.8000  ,victim: 0.8667


  6%|▌         | 40/718 [01:51<28:59,  2.57s/it]

Average accuracy 60.00%
Average f1 scores per class:    hero: 0.8500  ,villain: 0.7250  ,victim: 0.8750


  7%|▋         | 50/718 [02:23<28:38,  2.57s/it]

Average accuracy 56.00%
Average f1 scores per class:    hero: 0.8600  ,villain: 0.6600  ,victim: 0.8600


  8%|▊         | 60/718 [02:53<34:40,  3.16s/it]

Average accuracy 55.83%
Average f1 scores per class:    hero: 0.8667  ,villain: 0.6444  ,victim: 0.8500


 10%|▉         | 70/718 [03:23<30:24,  2.82s/it]

Average accuracy 57.62%
Average f1 scores per class:    hero: 0.8857  ,villain: 0.6595  ,victim: 0.8429


 11%|█         | 80/718 [03:51<27:07,  2.55s/it]

Average accuracy 57.29%
Average f1 scores per class:    hero: 0.8750  ,villain: 0.6521  ,victim: 0.8125


 13%|█▎        | 90/718 [04:22<33:36,  3.21s/it]

Average accuracy 56.48%
Average f1 scores per class:    hero: 0.8667  ,villain: 0.6463  ,victim: 0.8000


 14%|█▍        | 100/718 [04:50<25:09,  2.44s/it]

Average accuracy 58.33%
Average f1 scores per class:    hero: 0.8700  ,villain: 0.6717  ,victim: 0.7900


 15%|█▌        | 110/718 [05:17<27:50,  2.75s/it]

Average accuracy 59.85%
Average f1 scores per class:    hero: 0.8636  ,villain: 0.6955  ,victim: 0.7909


 17%|█▋        | 120/718 [05:44<27:31,  2.76s/it]

Average accuracy 58.19%
Average f1 scores per class:    hero: 0.8583  ,villain: 0.6764  ,victim: 0.7917


 18%|█▊        | 130/718 [06:10<28:15,  2.88s/it]

Average accuracy 59.10%
Average f1 scores per class:    hero: 0.8692  ,villain: 0.6833  ,victim: 0.8000


 19%|█▉        | 140/718 [06:41<29:14,  3.04s/it]

Average accuracy 59.52%
Average f1 scores per class:    hero: 0.8786  ,villain: 0.6893  ,victim: 0.8000


 21%|██        | 150/718 [07:10<25:13,  2.66s/it]

Average accuracy 61.06%
Average f1 scores per class:    hero: 0.8867  ,villain: 0.7033  ,victim: 0.8000


 22%|██▏       | 160/718 [07:38<24:16,  2.61s/it]

Average accuracy 62.24%
Average f1 scores per class:    hero: 0.8938  ,villain: 0.7094  ,victim: 0.8063


 24%|██▎       | 170/718 [08:06<26:35,  2.91s/it]

Average accuracy 62.70%
Average f1 scores per class:    hero: 0.9000  ,villain: 0.7186  ,victim: 0.8059


 25%|██▌       | 180/718 [08:34<24:52,  2.77s/it]

Average accuracy 63.47%
Average f1 scores per class:    hero: 0.9056  ,villain: 0.7287  ,victim: 0.8093


 26%|██▋       | 190/718 [09:01<25:03,  2.85s/it]

Average accuracy 62.50%
Average f1 scores per class:    hero: 0.9053  ,villain: 0.7167  ,victim: 0.8140


 28%|██▊       | 200/718 [09:29<27:27,  3.18s/it]

Average accuracy 61.88%
Average f1 scores per class:    hero: 0.9050  ,villain: 0.7225  ,victim: 0.8133


 29%|██▉       | 210/718 [09:55<20:36,  2.43s/it]

Average accuracy 62.26%
Average f1 scores per class:    hero: 0.9048  ,villain: 0.7262  ,victim: 0.8175


 31%|███       | 220/718 [10:22<20:19,  2.45s/it]

Average accuracy 63.30%
Average f1 scores per class:    hero: 0.9091  ,villain: 0.7341  ,victim: 0.8167


 32%|███▏      | 230/718 [10:48<22:53,  2.81s/it]

Average accuracy 63.15%
Average f1 scores per class:    hero: 0.9087  ,villain: 0.7283  ,victim: 0.8203


 33%|███▎      | 240/718 [11:18<23:50,  2.99s/it]

Average accuracy 63.85%
Average f1 scores per class:    hero: 0.9083  ,villain: 0.7312  ,victim: 0.8236


 35%|███▍      | 250/718 [11:47<21:33,  2.76s/it]

Average accuracy 64.50%
Average f1 scores per class:    hero: 0.9080  ,villain: 0.7380  ,victim: 0.8307


 36%|███▌      | 260/718 [12:17<24:18,  3.19s/it]

Average accuracy 65.03%
Average f1 scores per class:    hero: 0.9077  ,villain: 0.7417  ,victim: 0.8321


 38%|███▊      | 270/718 [12:45<20:37,  2.76s/it]

Average accuracy 65.59%
Average f1 scores per class:    hero: 0.9111  ,villain: 0.7463  ,victim: 0.8383


 39%|███▉      | 280/718 [13:16<22:38,  3.10s/it]

Average accuracy 66.46%
Average f1 scores per class:    hero: 0.9143  ,villain: 0.7524  ,victim: 0.8440


 40%|████      | 290/718 [13:44<22:36,  3.17s/it]

Average accuracy 65.89%
Average f1 scores per class:    hero: 0.9103  ,villain: 0.7506  ,victim: 0.8425


 42%|████▏     | 300/718 [14:12<19:18,  2.77s/it]

Average accuracy 66.53%
Average f1 scores per class:    hero: 0.9133  ,villain: 0.7567  ,victim: 0.8478


 43%|████▎     | 310/718 [14:38<17:51,  2.63s/it]

Average accuracy 67.28%
Average f1 scores per class:    hero: 0.9161  ,villain: 0.7613  ,victim: 0.8527


 45%|████▍     | 320/718 [15:04<17:24,  2.62s/it]

Average accuracy 67.37%
Average f1 scores per class:    hero: 0.9125  ,villain: 0.7594  ,victim: 0.8573


 46%|████▌     | 330/718 [15:31<17:42,  2.74s/it]

Average accuracy 67.45%
Average f1 scores per class:    hero: 0.9121  ,villain: 0.7606  ,victim: 0.8616


 47%|████▋     | 340/718 [16:00<17:28,  2.77s/it]

Average accuracy 67.08%
Average f1 scores per class:    hero: 0.9118  ,villain: 0.7608  ,victim: 0.8569


 49%|████▊     | 350/718 [16:26<16:01,  2.61s/it]

Average accuracy 67.45%
Average f1 scores per class:    hero: 0.9114  ,villain: 0.7648  ,victim: 0.8610


 50%|█████     | 360/718 [16:54<17:19,  2.90s/it]

Average accuracy 67.38%
Average f1 scores per class:    hero: 0.9111  ,villain: 0.7648  ,victim: 0.8593


 52%|█████▏    | 370/718 [17:21<15:33,  2.68s/it]

Average accuracy 67.18%
Average f1 scores per class:    hero: 0.9135  ,villain: 0.7604  ,victim: 0.8631


 53%|█████▎    | 380/718 [17:48<16:10,  2.87s/it]

Average accuracy 66.60%
Average f1 scores per class:    hero: 0.9132  ,villain: 0.7535  ,victim: 0.8640


 54%|█████▍    | 390/718 [18:17<16:23,  3.00s/it]

Average accuracy 66.43%
Average f1 scores per class:    hero: 0.9103  ,villain: 0.7556  ,victim: 0.8650


 56%|█████▌    | 400/718 [18:52<18:07,  3.42s/it]

Average accuracy 66.31%
Average f1 scores per class:    hero: 0.9117  ,villain: 0.7572  ,victim: 0.8625


 57%|█████▋    | 410/718 [19:23<17:17,  3.37s/it]

Average accuracy 66.28%
Average f1 scores per class:    hero: 0.9138  ,villain: 0.7550  ,victim: 0.8659


 58%|█████▊    | 420/718 [19:52<15:12,  3.06s/it]

Average accuracy 66.61%
Average f1 scores per class:    hero: 0.9159  ,villain: 0.7580  ,victim: 0.8667


 60%|█████▉    | 430/718 [20:19<12:47,  2.67s/it]

Average accuracy 66.22%
Average f1 scores per class:    hero: 0.9178  ,villain: 0.7520  ,victim: 0.8698


 61%|██████▏   | 440/718 [20:46<12:43,  2.75s/it]

Average accuracy 66.31%
Average f1 scores per class:    hero: 0.9197  ,villain: 0.7508  ,victim: 0.8682


 63%|██████▎   | 450/718 [21:15<12:19,  2.76s/it]

Average accuracy 65.80%
Average f1 scores per class:    hero: 0.9170  ,villain: 0.7485  ,victim: 0.8667


 64%|██████▍   | 460/718 [21:46<12:40,  2.95s/it]

Average accuracy 65.74%
Average f1 scores per class:    hero: 0.9188  ,villain: 0.7467  ,victim: 0.8667


 65%|██████▌   | 470/718 [22:13<12:18,  2.98s/it]

Average accuracy 66.12%
Average f1 scores per class:    hero: 0.9184  ,villain: 0.7493  ,victim: 0.8695


 67%|██████▋   | 480/718 [22:39<10:24,  2.62s/it]

Average accuracy 65.99%
Average f1 scores per class:    hero: 0.9181  ,villain: 0.7462  ,victim: 0.8701


 68%|██████▊   | 490/718 [23:09<09:24,  2.48s/it]

Average accuracy 65.80%
Average f1 scores per class:    hero: 0.9177  ,villain: 0.7473  ,victim: 0.8687


 70%|██████▉   | 500/718 [23:39<11:21,  3.13s/it]

Average accuracy 65.48%
Average f1 scores per class:    hero: 0.9153  ,villain: 0.7436  ,victim: 0.8653


 71%|███████   | 510/718 [24:06<09:28,  2.73s/it]

Average accuracy 65.77%
Average f1 scores per class:    hero: 0.9150  ,villain: 0.7467  ,victim: 0.8680


 72%|███████▏  | 520/718 [24:33<08:48,  2.67s/it]

Average accuracy 65.18%
Average f1 scores per class:    hero: 0.9109  ,villain: 0.7471  ,victim: 0.8667


 74%|███████▍  | 530/718 [25:01<08:09,  2.60s/it]

Average accuracy 65.74%
Average f1 scores per class:    hero: 0.9126  ,villain: 0.7512  ,victim: 0.8692


 75%|███████▌  | 540/718 [25:27<07:20,  2.47s/it]

Average accuracy 66.00%
Average f1 scores per class:    hero: 0.9123  ,villain: 0.7521  ,victim: 0.8698


 77%|███████▋  | 550/718 [25:55<07:47,  2.78s/it]

Average accuracy 66.26%
Average f1 scores per class:    hero: 0.9103  ,villain: 0.7556  ,victim: 0.8721


 78%|███████▊  | 560/718 [26:26<08:11,  3.11s/it]

Average accuracy 66.04%
Average f1 scores per class:    hero: 0.9119  ,villain: 0.7522  ,victim: 0.8708


 79%|███████▉  | 570/718 [26:50<06:33,  2.66s/it]

Average accuracy 66.11%
Average f1 scores per class:    hero: 0.9117  ,villain: 0.7548  ,victim: 0.8713


 81%|████████  | 580/718 [27:19<06:22,  2.77s/it]

Average accuracy 66.35%
Average f1 scores per class:    hero: 0.9115  ,villain: 0.7564  ,victim: 0.8736


 82%|████████▏ | 590/718 [27:46<05:45,  2.70s/it]

Average accuracy 66.24%
Average f1 scores per class:    hero: 0.9079  ,villain: 0.7572  ,victim: 0.8723


 84%|████████▎ | 600/718 [28:15<06:09,  3.13s/it]

Average accuracy 66.03%
Average f1 scores per class:    hero: 0.9078  ,villain: 0.7562  ,victim: 0.8694


 85%|████████▍ | 610/718 [28:41<05:23,  2.99s/it]

Average accuracy 66.09%
Average f1 scores per class:    hero: 0.9093  ,villain: 0.7553  ,victim: 0.8699


 86%|████████▋ | 620/718 [29:14<05:01,  3.07s/it]

Average accuracy 65.99%
Average f1 scores per class:    hero: 0.9108  ,villain: 0.7534  ,victim: 0.8688


 88%|████████▊ | 630/718 [29:40<03:50,  2.61s/it]

Average accuracy 66.21%
Average f1 scores per class:    hero: 0.9122  ,villain: 0.7542  ,victim: 0.8693


 89%|████████▉ | 640/718 [30:06<03:15,  2.51s/it]

Average accuracy 66.01%
Average f1 scores per class:    hero: 0.9089  ,villain: 0.7512  ,victim: 0.8714


 91%|█████████ | 650/718 [30:33<02:47,  2.47s/it]

Average accuracy 65.92%
Average f1 scores per class:    hero: 0.9056  ,villain: 0.7504  ,victim: 0.8733


 92%|█████████▏| 660/718 [31:00<02:22,  2.46s/it]

Average accuracy 66.01%
Average f1 scores per class:    hero: 0.9071  ,villain: 0.7522  ,victim: 0.8732


 93%|█████████▎| 670/718 [31:26<02:01,  2.53s/it]

Average accuracy 66.37%
Average f1 scores per class:    hero: 0.9085  ,villain: 0.7544  ,victim: 0.8751


 95%|█████████▍| 680/718 [31:50<01:34,  2.48s/it]

Average accuracy 66.49%
Average f1 scores per class:    hero: 0.9098  ,villain: 0.7551  ,victim: 0.8755


 96%|█████████▌| 690/718 [32:21<01:35,  3.42s/it]

Average accuracy 66.33%
Average f1 scores per class:    hero: 0.9097  ,villain: 0.7543  ,victim: 0.8744


 97%|█████████▋| 700/718 [32:49<00:53,  2.97s/it]

Average accuracy 66.66%
Average f1 scores per class:    hero: 0.9110  ,villain: 0.7564  ,victim: 0.8748


 99%|█████████▉| 710/718 [33:15<00:24,  3.11s/it]

Average accuracy 66.85%
Average f1 scores per class:    hero: 0.9122  ,villain: 0.7570  ,victim: 0.8765


100%|██████████| 718/718 [33:36<00:00,  2.81s/it]

0.6708217270194986
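The F1 scores printed above are averages of per-sample F1, and every meme with no gold and no predicted entities for a class contributes a perfect 1.0 to that class. Since more than half of the test memes (395 of 718, according to the baseline computed in the next cell) have no hero, villain or victim at all, these averages are pulled up by easy empty cases. A corpus-level (micro) F1, which pools true positives, false positives and false negatives over the whole test set, is less sensitive to this. A minimal sketch on already-parsed predictions (`micro_f1` is a hypothetical helper; generations would first be parsed with `format_pred`):

```python
def micro_f1(preds, golds, role):
    """Corpus-level F1 for one role, pooling counts over all samples.

    preds, golds: lists of dicts mapping a role name to a list of entity names.
    Returns None if the role never appears in predictions or gold (F1 undefined).
    """
    tp = fp = fn = 0
    for pred, gold in zip(preds, golds):
        tp += len([p for p in pred[role] if p in gold[role]])
        fp += len([p for p in pred[role] if p not in gold[role]])
        fn += len([g for g in gold[role] if g not in pred[role]])
    if tp + fp + fn == 0:
        return None
    # Equals 2PR / (P + R) computed from the pooled counts when P and R are defined
    return 2 * tp / (2 * tp + fp + fn)


# Toy data with hypothetical entity names: two memes without heroes, one wrong hero
preds = [{"hero": []}, {"hero": []}, {"hero": ["a"]}]
golds = [{"hero": []}, {"hero": []}, {"hero": ["b"]}]
print(micro_f1(preds, golds, "hero"))  # 0.0, while the per-sample average would be 2/3
```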





In [43]:
d = 0
edge_cases = 0
edge_cases_idx = []
dummy_acc = 0
for ex in tqdm(test_dataset):
    sample_acc = score_accuracy(["{hero: [], villain: [], victim: []}"], [ex])
    if sample_acc == 1:
        edge_cases += 1
        edge_cases_idx.append(d)
    dummy_acc += sample_acc
    d += 1
    
dummy_acc /= d
print(dummy_acc)
print(edge_cases)

100%|██████████| 718/718 [00:03<00:00, 230.26it/s]

0.5501392757660167
395
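This baseline is easy to interpret: under `score_accuracy`, the empty prediction scores 1 on a meme whose gold annotation has no hero, villain or victim and 0 on every other meme, so its accuracy is simply the share of such memes, 395 / 718 ≈ 0.5501. The fine-tuned model's 0.6708 should be read against this floor. A standalone sketch of the same computation (`empty_prediction_accuracy` is a hypothetical helper):

```python
def empty_prediction_accuracy(rows):
    """Accuracy of always answering {hero: [], villain: [], victim: []} under score_accuracy's rules.

    Rows with no hero/villain/victim labels score 1; all other rows score 0.
    """
    roles = ("hero", "villain", "victim")
    hits = sum(1 for row in rows if not any(row[c] for c in roles))
    return hits / len(rows)


rows = [
    {"hero": [], "villain": [], "victim": []},
    {"hero": [], "villain": ["x"], "victim": []},
    {"hero": [], "villain": [], "victim": []},
    {"hero": ["y"], "villain": [], "victim": ["z"]},
]
print(empty_prediction_accuracy(rows))  # 0.5
print(395 / 718)                         # 0.5501392757660167, the value printed above
```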





In [59]:
model.eval()

ef1s = [0, 0, 0]
eprecs = [0, 0, 0]
erecall = [0, 0, 0]


avg_acc = 0
n = 0

for idx, ex in enumerate(tqdm(test_dataset)):
    if idx in edge_cases_idx:
        continue
    test_inputs = tensorise_input(ex)
    
    gen_ids = model.generate(**test_inputs, max_new_tokens=64)
    preds = processor.batch_decode(gen_ids[:, test_inputs["input_ids"].size(1):], skip_special_tokens=True)
    
    current_f1s = [0, 0, 0]
    avg_acc += score_accuracy([preds[0]], [ex])
    for ci, c in enumerate(classes):  # `ci` avoids shadowing the sample index `idx`
        f1c, pc, rc = f1_for_class(preds[0], ex, c)
        current_f1s[ci] = f1c
        ef1s[ci] += f1c
        eprecs[ci] += pc
        erecall[ci] += rc
    
    hero = ex["hero"]
    villain = ex["villain"]
    victim = ex["victim"]
    print(f"Prediction: {preds[0]}")
    print(f"Ground truth: hero:{hero}, villain:{villain}, victim:{victim}")
    print("f1 scores per class:    hero: {:.4f}  ,villain: {:.4f}  ,victim: {:.4f}".format(current_f1s[0], current_f1s[1], current_f1s[2]))
    
    
    
    n += 1
    
    if n % 20 == 0:
        torch.cuda.empty_cache()
        print("Average accuracy {:.2f}%".format((avg_acc / n) * 100))
        print("Average f1 scores per class:    hero: {:.4f}  ,villain: {:.4f}  ,victim: {:.4f}".format(ef1s[0] / n, ef1s[1] / n, ef1s[2] / n))
    
    
avg_acc /= n
print(avg_acc)

  0%|          | 1/718 [00:02<31:50,  2.66s/it]

Prediction: {hero: [], villain: ['libertarians'], victim: []}
Ground truth: hero:[], villain:['libertarians', 'libertarian party'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.6667  ,victim: 1.0000


  1%|          | 5/718 [00:05<11:09,  1.06it/s]

Prediction: {hero: [], villain: ['chinese'], victim: []}
Ground truth: hero:[], villain:['chinese'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


  1%|          | 6/718 [00:09<19:08,  1.61s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:['kamala harris', 'joe biden'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 1.0000


  1%|▏         | 10/718 [00:11<12:45,  1.08s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:[], victim:['donald trump']
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 0.0000


  2%|▏         | 11/718 [00:14<16:15,  1.38s/it]

Prediction: {hero: ['bill clinton'], villain: [], victim: []}
Ground truth: hero:[], villain:[], victim:['bill clinton']
f1 scores per class:    hero: 0.0000  ,villain: 1.0000  ,victim: 0.0000


  2%|▏         | 13/718 [00:17<17:06,  1.46s/it]

Prediction: {hero: [], villain: ['republican party'], victim: []}
Ground truth: hero:[], villain:['republican party'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


  2%|▏         | 14/718 [00:21<21:15,  1.81s/it]

Prediction: {hero: [], villain: ['adolf hitler', 'donald trump'], victim: []}
Ground truth: hero:[], villain:['adolf hitler', 'donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


  2%|▏         | 15/718 [00:23<22:05,  1.89s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:['republican'], villain:[], victim:[]
f1 scores per class:    hero: 0.0000  ,villain: 1.0000  ,victim: 1.0000


  2%|▏         | 16/718 [00:25<23:02,  1.97s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:[], victim:['new yorkers']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 0.0000


  2%|▏         | 17/718 [00:30<32:56,  2.82s/it]

Prediction: {hero: ['kamal harris'], villain: ['joe biden'], victim: []}
Ground truth: hero:[], villain:['kamal harris', 'joe biden'], victim:[]
f1 scores per class:    hero: 0.0000  ,villain: 0.6667  ,victim: 1.0000


  3%|▎         | 18/718 [00:34<35:57,  3.08s/it]

Prediction: {hero: [], villain: ['joe biden', 'alexandria cortez'], victim: []}
Ground truth: hero:[], villain:['joe biden', 'alexandria cortez', 'democrats', 'hillary clinton'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.6667  ,victim: 1.0000


  3%|▎         | 19/718 [00:37<35:12,  3.02s/it]

Prediction: {hero: [], villain: ['hillary clinton'], victim: []}
Ground truth: hero:[], villain:['hillary clinton'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


  3%|▎         | 20/718 [00:41<39:10,  3.37s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: ['barack obama']}
Ground truth: hero:[], villain:[], victim:['donald trump']
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 0.0000


  3%|▎         | 21/718 [00:44<37:04,  3.19s/it]

Prediction: {hero: [], villain: ['trump supporters'], victim: []}
Ground truth: hero:[], villain:['trump supporters'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


  4%|▎         | 26/718 [00:46<15:29,  1.34s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:['covid19'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 1.0000


  4%|▍         | 29/718 [00:49<13:19,  1.16s/it]

Prediction: {hero: [], villain: ['china'], victim: []}
Ground truth: hero:[], villain:['china'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


  5%|▍         | 34/718 [00:52<10:09,  1.12it/s]

Prediction: {hero: ['voluntaryist'], villain: [], victim: []}
Ground truth: hero:['voluntaryist'], villain:[], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


  5%|▌         | 38/718 [00:55<09:20,  1.21it/s]

Prediction: {hero: [], villain: ['democrats'], victim: []}
Ground truth: hero:['democrats'], villain:['republicans'], victim:['democrats']
f1 scores per class:    hero: 0.0000  ,villain: 0.0000  ,victim: 0.0000


  6%|▌         | 43/718 [00:57<07:52,  1.43it/s]

Prediction: {hero: [], villain: ['media'], victim: []}
Ground truth: hero:[], villain:['media'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


  6%|▋         | 46/718 [01:00<08:34,  1.31it/s]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:['germs'], victim:['people']
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 0.0000
Average accuracy 45.00%
Average f1 scores per class:    hero: 0.8000  ,villain: 0.6500  ,victim: 0.7000


  7%|▋         | 47/718 [01:02<10:19,  1.08it/s]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:['group text'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 1.0000


  7%|▋         | 50/718 [01:05<10:15,  1.09it/s]

Prediction: {hero: ['donald trump'], villain: [], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 0.0000  ,villain: 0.0000  ,victim: 1.0000


  7%|▋         | 51/718 [01:08<13:32,  1.22s/it]

Prediction: {hero: ['donald trump'], villain: ['barack obama'], victim: []}
Ground truth: hero:['donald trump'], villain:['barack obama'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


  7%|▋         | 52/718 [01:11<16:08,  1.45s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


  8%|▊         | 55/718 [01:13<12:55,  1.17s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:['kim jong-un', 'donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 1.0000


  8%|▊         | 56/718 [01:17<16:49,  1.52s/it]

Prediction: {hero: [], villain: ['barack obama'], victim: ['george bush']}
Ground truth: hero:['barack obama', 'george bush'], villain:['donald trump'], victim:['americans']
f1 scores per class:    hero: 0.0000  ,villain: 0.0000  ,victim: 0.0000


  8%|▊         | 57/718 [01:19<19:02,  1.73s/it]

Prediction: {hero: [], villain: ['democratic party'], victim: []}
Ground truth: hero:[], villain:[], victim:['democratic party']
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 0.0000


  8%|▊         | 59/718 [01:22<18:38,  1.70s/it]

Prediction: {hero: [], villain: ['barack obama'], victim: ['donald trump']}
Ground truth: hero:[], villain:['barack obama'], victim:['donald trump']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


  8%|▊         | 60/718 [01:25<21:15,  1.94s/it]

Prediction: {hero: [], villain: ['hillary clinton'], victim: []}
Ground truth: hero:[], villain:['hillary clinton', 'democratic party'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.6667  ,victim: 1.0000


  9%|▉         | 65/718 [01:28<12:14,  1.12s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:['united states of america']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 0.0000


  9%|▉         | 66/718 [01:31<14:41,  1.35s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


  9%|▉         | 67/718 [01:33<16:26,  1.51s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:['office'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 1.0000


 10%|▉         | 69/718 [01:36<15:44,  1.45s/it]

Prediction: {hero: [], villain: ['democrats'], victim: []}
Ground truth: hero:[], villain:['democrats'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 10%|▉         | 70/718 [01:38<18:02,  1.67s/it]

Prediction: {hero: [], villain: ['islam'], victim: []}
Ground truth: hero:[], villain:['islam', 'democratic national islamic party', 'democratic party'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.5000  ,victim: 1.0000


 10%|█         | 72/718 [01:41<16:48,  1.56s/it]

Prediction: {hero: ['barack obama'], villain: [], victim: []}
Ground truth: hero:['barack obama'], villain:[], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 10%|█         | 73/718 [01:43<18:02,  1.68s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:['work from home'], victim:['plans']
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 0.0000


 10%|█         | 74/718 [01:47<23:11,  2.16s/it]

Prediction: {hero: ['barack obama'], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['barack obama'], victim:['donald trump']
f1 scores per class:    hero: 0.0000  ,villain: 0.0000  ,victim: 0.0000


 10%|█         | 75/718 [01:50<24:37,  2.30s/it]

Prediction: {hero: [], villain: ['joe biden'], victim: []}
Ground truth: hero:[], villain:['joe biden'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 11%|█         | 76/718 [01:52<24:49,  2.32s/it]

Prediction: {hero: [], villain: ['ai'], victim: []}
Ground truth: hero:[], villain:['ai'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 11%|█         | 79/718 [01:55<17:16,  1.62s/it]

Prediction: {hero: [], villain: ['republican party'], victim: []}
Ground truth: hero:[], villain:['republican party'], victim:['poor']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 0.0000
Average accuracy 47.08%
Average f1 scores per class:    hero: 0.8250  ,villain: 0.6042  ,victim: 0.7000


 11%|█         | 80/718 [01:57<18:32,  1.74s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:['texas', 'democrat'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 1.0000


 11%|█▏        | 81/718 [02:00<20:51,  1.96s/it]

Prediction: {hero: [], villain: ['republican party'], victim: []}
Ground truth: hero:[], villain:['republican party'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 12%|█▏        | 86/718 [02:04<13:08,  1.25s/it]

Prediction: {hero: [], villain: ['joe biden'], victim: []}
Ground truth: hero:[], villain:['joe biden'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 12%|█▏        | 87/718 [02:07<15:28,  1.47s/it]

Prediction: {hero: ['president trump'], villain: [], victim: []}
Ground truth: hero:[], villain:['president trump'], victim:['arizona man']
f1 scores per class:    hero: 0.0000  ,villain: 0.0000  ,victim: 0.0000


 12%|█▏        | 88/718 [02:10<17:56,  1.71s/it]

Prediction: {hero: [], villain: ['china'], victim: []}
Ground truth: hero:[], villain:[], victim:['china', 'donald trump']
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 0.0000


 12%|█▏        | 89/718 [02:12<20:11,  1.93s/it]

Prediction: {hero: [], villain: ['democrats'], victim: []}
Ground truth: hero:[], villain:['democrats'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 13%|█▎        | 90/718 [02:16<24:03,  2.30s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:['joe biden'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 1.0000


 13%|█▎        | 91/718 [02:19<24:34,  2.35s/it]

Prediction: {hero: [], villain: ['american politics'], victim: []}
Ground truth: hero:[], villain:['american politics'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 13%|█▎        | 92/718 [02:22<28:13,  2.71s/it]

Prediction: {hero: [], villain: ['leftwing'], victim: ['face mask', 'non-white people']}
Ground truth: hero:[], villain:['leftwing'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 0.0000


 13%|█▎        | 94/718 [02:25<22:42,  2.18s/it]

Prediction: {hero: [], villain: ['millenials'], victim: []}
Ground truth: hero:[], villain:['millenials'], victim:['boomers']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 0.0000


 13%|█▎        | 96/718 [02:29<20:47,  2.01s/it]

Prediction: {hero: [], villain: ['republican party'], victim: ['democrats']}
Ground truth: hero:['republican party'], villain:['democrats'], victim:[]
f1 scores per class:    hero: 0.0000  ,villain: 0.0000  ,victim: 0.0000


 14%|█▍        | 104/718 [02:31<08:26,  1.21it/s]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:['millenials', 'gen x'], villain:[], victim:['boomers']
f1 scores per class:    hero: 0.0000  ,villain: 1.0000  ,victim: 0.0000


 15%|█▍        | 105/718 [02:35<11:41,  1.14s/it]

Prediction: {hero: [], villain: ['hillary clinton', 'donald trump'], victim: ['children']}
Ground truth: hero:['hillary clinton'], villain:['donald trump'], victim:['children']
f1 scores per class:    hero: 0.0000  ,villain: 0.6667  ,victim: 1.0000


 15%|█▍        | 106/718 [02:37<13:54,  1.36s/it]

Prediction: {hero: ['immigrants'], villain: [], victim: []}
Ground truth: hero:['immigrants'], villain:[], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 15%|█▌        | 108/718 [02:41<15:42,  1.54s/it]

Prediction: {hero: ['god'], villain: ['republican party', 'democratic party'], victim: ['america']}
Ground truth: hero:['god'], villain:['republican party', 'democratic party'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 0.0000


 15%|█▌        | 109/718 [02:44<17:29,  1.72s/it]

Prediction: {hero: [], villain: ['trump supporters'], victim: []}
Ground truth: hero:[], villain:['trump supporters', 'donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.6667  ,victim: 1.0000


 15%|█▌        | 111/718 [02:47<16:42,  1.65s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:['donald trump'], villain:['barack obama'], victim:[]
f1 scores per class:    hero: 0.0000  ,villain: 0.0000  ,victim: 1.0000


 16%|█▌        | 112/718 [02:49<17:57,  1.78s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:['combined vaccines'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 1.0000


 16%|█▌        | 116/718 [02:52<11:40,  1.16s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:['centre of disease control (cdc)'], victim:['japan']
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 0.0000


 16%|█▋        | 117/718 [02:55<15:24,  1.54s/it]

Prediction: {hero: [], villain: ['lord voldemort'], victim: ['donald trump']}
Ground truth: hero:[], villain:['lord voldemort', 'donald trump'], victim:['muggles']
f1 scores per class:    hero: 1.0000  ,villain: 0.6667  ,victim: 0.0000
Average accuracy 43.06%
Average f1 scores per class:    hero: 0.8000  ,villain: 0.5861  ,victim: 0.6500


 17%|█▋        | 119/718 [02:58<14:41,  1.47s/it]

Prediction: {hero: [], villain: ['democratic party'], victim: []}
Ground truth: hero:[], villain:['democratic party'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 17%|█▋        | 120/718 [03:01<17:13,  1.73s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:['donald trump'], villain:[], victim:[]
f1 scores per class:    hero: 0.0000  ,villain: 0.0000  ,victim: 1.0000


 17%|█▋        | 122/718 [03:04<15:49,  1.59s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 18%|█▊        | 128/718 [03:07<10:09,  1.03s/it]

Prediction: {hero: [], villain: ['donald trump', 'republican party'], victim: ['americans']}
Ground truth: hero:[], villain:['donald trump'], victim:['republican party']
f1 scores per class:    hero: 1.0000  ,villain: 0.6667  ,victim: 0.0000


 18%|█▊        | 129/718 [03:10<12:22,  1.26s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:['work week'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 1.0000


 18%|█▊        | 130/718 [03:13<14:46,  1.51s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 19%|█▊        | 134/718 [03:16<11:28,  1.18s/it]

Prediction: {hero: [], villain: ['us military'], victim: ['us citizen']}
Ground truth: hero:[], villain:[], victim:['us military']
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 0.0000


 19%|█▉        | 135/718 [03:19<13:36,  1.40s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 19%|█▉        | 137/718 [03:24<15:51,  1.64s/it]

Prediction: {hero: [], villain: ['democrats'], victim: []}
Ground truth: hero:[], villain:['democrats'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 19%|█▉        | 138/718 [03:27<19:07,  1.98s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: ['illegal', 'muslim']}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 0.0000


 19%|█▉        | 140/718 [03:30<16:48,  1.74s/it]

Prediction: {hero: [], villain: ['antifa'], victim: []}
Ground truth: hero:[], villain:['antifa', 'democratic party'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.6667  ,victim: 1.0000


 20%|█▉        | 141/718 [03:34<21:34,  2.24s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: ['democratic party']}
Ground truth: hero:[], villain:['donald trump'], victim:['democratic party']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 20%|█▉        | 142/718 [03:37<22:45,  2.37s/it]

Prediction: {hero: ['gaston'], villain: [], victim: []}
Ground truth: hero:['gaston'], villain:[], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 20%|██        | 144/718 [03:40<19:43,  2.06s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:['estonia', 'lithuania', 'latvia']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 0.0000


 21%|██        | 148/718 [03:44<14:04,  1.48s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: ['mexican']}
Ground truth: hero:[], villain:['donald trump'], victim:['mexican']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 21%|██        | 152/718 [03:47<10:40,  1.13s/it]

Prediction: {hero: [], villain: ['china'], victim: []}
Ground truth: hero:[], villain:[], victim:['china', 'spain']
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 0.0000


 21%|██▏       | 153/718 [03:50<13:04,  1.39s/it]

Prediction: {hero: [], villain: ['republican party'], victim: ['people']}
Ground truth: hero:[], villain:['republican party'], victim:['people']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 22%|██▏       | 155/718 [03:54<15:07,  1.61s/it]

Prediction: {hero: ['libertarian'], villain: ['libertarians', 'republicans', 'democrats'], victim: []}
Ground truth: hero:['libertarian'], villain:[], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 1.0000


 22%|██▏       | 157/718 [03:57<15:02,  1.61s/it]

Prediction: {hero: ['libertarian'], villain: ['democrat'], victim: []}
Ground truth: hero:['libertarian'], villain:['democrat'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 22%|██▏       | 158/718 [04:00<16:48,  1.80s/it]

Prediction: {hero: [], villain: [], victim: ['america']}
Ground truth: hero:[], villain:[], victim:['america']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000
Average accuracy 46.98%
Average f1 scores per class:    hero: 0.8375  ,villain: 0.6187  ,victim: 0.6750


 22%|██▏       | 159/718 [04:03<18:16,  1.96s/it]

Prediction: {hero: [], villain: ['american politics'], victim: []}
Ground truth: hero:[], villain:['american politics'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 23%|██▎       | 164/718 [04:06<10:45,  1.17s/it]

Prediction: {hero: [], villain: ['government'], victim: ['people']}
Ground truth: hero:[], villain:['government'], victim:['people']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 23%|██▎       | 166/718 [04:09<12:16,  1.33s/it]

Prediction: {hero: ['vladimir putin'], villain: [], victim: []}
Ground truth: hero:['vladimir putin'], villain:[], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 23%|██▎       | 167/718 [04:12<14:39,  1.60s/it]

Prediction: {hero: ['barack obama'], villain: [], victim: ['child']}
Ground truth: hero:['barack obama'], villain:[], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 0.0000


 24%|██▎       | 169/718 [04:16<15:20,  1.68s/it]

Prediction: {hero: [], villain: ['barack obama'], victim: ['donald trump']}
Ground truth: hero:[], villain:['barack obama', 'donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.6667  ,victim: 0.0000


 24%|██▍       | 171/718 [04:20<15:47,  1.73s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: ['american people']}
Ground truth: hero:[], villain:['donald trump'], victim:['american people', 'mexico']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 0.6667


 24%|██▍       | 172/718 [04:22<17:14,  1.90s/it]

Prediction: {hero: [], villain: [], victim: ['parents']}
Ground truth: hero:[], villain:[], victim:['parents']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 24%|██▍       | 173/718 [04:25<17:50,  1.96s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:[], victim:['american flag']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 0.0000


 25%|██▍       | 176/718 [04:27<12:35,  1.39s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:['melania trump', 'donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 1.0000


 25%|██▍       | 178/718 [04:31<14:39,  1.63s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 25%|██▌       | 183/718 [04:34<09:27,  1.06s/it]

Prediction: {hero: ['libertarians'], villain: [], victim: []}
Ground truth: hero:[], villain:['libertarians'], victim:[]
f1 scores per class:    hero: 0.0000  ,villain: 0.0000  ,victim: 1.0000


 26%|██▌       | 186/718 [04:36<08:38,  1.03it/s]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:['fox news', 'roger ailes', 'republican party', 'donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 1.0000


 26%|██▌       | 188/718 [04:39<09:49,  1.11s/it]

Prediction: {hero: [], villain: ['government'], victim: []}
Ground truth: hero:[], villain:['government'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 26%|██▋       | 189/718 [04:42<11:49,  1.34s/it]

Prediction: {hero: [], villain: ['government'], victim: []}
Ground truth: hero:[], villain:['government'], victim:['veterans']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 0.0000


 27%|██▋       | 191/718 [04:45<12:24,  1.41s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: ['coronavirus']}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 0.0000


 27%|██▋       | 192/718 [04:48<14:12,  1.62s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 27%|██▋       | 193/718 [04:51<16:30,  1.89s/it]

Prediction: {hero: [], villain: ['bats', 'rats'], victim: []}
Ground truth: hero:[], villain:['bats'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.6667  ,victim: 1.0000


 27%|██▋       | 194/718 [04:54<18:11,  2.08s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:['us federal pandemic response team']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 0.0000


 28%|██▊       | 199/718 [04:57<10:08,  1.17s/it]

Prediction: {hero: [], villain: ['trump supporters'], victim: []}
Ground truth: hero:[], villain:['trump supporters', 'donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.6667  ,victim: 1.0000


 28%|██▊       | 201/718 [05:00<11:18,  1.31s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: ['joe biden']}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 0.0000
Average accuracy 46.75%
Average f1 scores per class:    hero: 0.8600  ,villain: 0.6550  ,victim: 0.6667


 28%|██▊       | 203/718 [05:03<11:26,  1.33s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:['donald trump'], villain:['wuhan virus', 'liberals'], victim:[]
f1 scores per class:    hero: 0.0000  ,villain: 0.0000  ,victim: 1.0000


 30%|██▉       | 212/718 [05:07<06:23,  1.32it/s]

Prediction: {hero: ['donald trump'], villain: ['barack obama'], victim: ['america']}
Ground truth: hero:['donald trump'], villain:['barack obama'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 0.0000


 30%|██▉       | 215/718 [05:09<06:38,  1.26it/s]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 30%|███       | 218/718 [05:12<06:53,  1.21it/s]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 32%|███▏      | 233/718 [05:15<03:16,  2.46it/s]

Prediction: {hero: [], villain: [], victim: ['people']}
Ground truth: hero:[], villain:[], victim:['people']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 33%|███▎      | 234/718 [05:17<04:20,  1.85it/s]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 33%|███▎      | 235/718 [05:20<05:50,  1.38it/s]

Prediction: {hero: [], villain: ['hillary clinton'], victim: []}
Ground truth: hero:[], villain:['hillary clinton'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 33%|███▎      | 236/718 [05:24<07:40,  1.05it/s]

Prediction: {hero: ['donald trump'], villain: [], victim: ['america']}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 0.0000  ,villain: 0.0000  ,victim: 0.0000


 33%|███▎      | 238/718 [05:28<09:34,  1.20s/it]

Prediction: {hero: [], villain: ['black lives matter riots'], victim: ['black lives']}
Ground truth: hero:[], villain:['black lives matter riots'], victim:['black lives']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 34%|███▎      | 242/718 [05:30<07:37,  1.04it/s]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:['police officer'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 1.0000


 34%|███▍      | 243/718 [05:34<10:34,  1.34s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 34%|███▍      | 244/718 [05:37<12:11,  1.54s/it]

Prediction: {hero: [], villain: ['democratic party'], victim: []}
Ground truth: hero:[], villain:['democratic party'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 34%|███▍      | 245/718 [05:41<15:30,  1.97s/it]

Prediction: {hero: [], villain: ['antifa', 'hillary clinton', 'democratic party'], victim: []}
Ground truth: hero:[], villain:['antifa', 'hillary clinton', 'democratic party'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 34%|███▍      | 247/718 [05:45<16:07,  2.05s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 35%|███▍      | 251/718 [05:48<10:48,  1.39s/it]

Prediction: {hero: [], villain: ['joe biden'], victim: []}
Ground truth: hero:[], villain:['joe biden'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 35%|███▌      | 254/718 [05:51<09:58,  1.29s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: ['people in america']}
Ground truth: hero:[], villain:['donald trump'], victim:['people in america']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 36%|███▌      | 255/718 [05:54<11:33,  1.50s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 36%|███▌      | 257/718 [05:58<12:20,  1.61s/it]

Prediction: {hero: [], villain: ['democratic party', 'republicans'], victim: ['black community']}
Ground truth: hero:[], villain:['democratic party'], victim:['republicans', 'black community']
f1 scores per class:    hero: 1.0000  ,villain: 0.6667  ,victim: 0.6667


 36%|███▌      | 258/718 [06:01<14:50,  1.94s/it]

Prediction: {hero: [], villain: ['dr. dr. anthony fauci'], victim: []}
Ground truth: hero:[], villain:['dr. dr. anthony fauci', 'dr. anthony fauci'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.6667  ,victim: 1.0000


 36%|███▌      | 259/718 [06:04<16:07,  2.11s/it]

Prediction: {hero: [], villain: ['democratic party'], victim: []}
Ground truth: hero:[], villain:['democratic party'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000
Average accuracy 51.74%
Average f1 scores per class:    hero: 0.8667  ,villain: 0.6819  ,victim: 0.7028


 36%|███▌      | 260/718 [06:08<18:33,  2.43s/it]

Prediction: {hero: ['barack obama'], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['barack obama'], victim:['donald trump']
f1 scores per class:    hero: 0.0000  ,villain: 0.0000  ,victim: 0.0000


 37%|███▋      | 265/718 [06:11<09:47,  1.30s/it]

Prediction: {hero: [], villain: ['republican party'], victim: []}
Ground truth: hero:[], villain:['republican party'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 37%|███▋      | 267/718 [06:14<10:44,  1.43s/it]

Prediction: {hero: [], villain: ['republican party', 'mike pence'], victim: []}
Ground truth: hero:[], villain:['republican party'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.6667  ,victim: 1.0000


 37%|███▋      | 269/718 [06:17<10:40,  1.43s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 38%|███▊      | 272/718 [06:20<09:21,  1.26s/it]

Prediction: {hero: [], villain: ['eric trump'], victim: []}
Ground truth: hero:[], villain:['eric trump', 'donald trump jr.', 'donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.5000  ,victim: 1.0000


 38%|███▊      | 274/718 [06:23<10:07,  1.37s/it]

Prediction: {hero: ['donald trump'], villain: [], victim: []}
Ground truth: hero:['donald trump'], villain:[], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 38%|███▊      | 276/718 [06:27<10:38,  1.44s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 39%|███▊      | 277/718 [06:31<14:21,  1.95s/it]

Prediction: {hero: [], villain: ['alexandria ocasio-cortez (aoc )'], victim: ['kids']}
Ground truth: hero:[], villain:['alexandria ocasio-cortez (aoc )', 'aoc challenge'], victim:['kids']
f1 scores per class:    hero: 1.0000  ,villain: 0.6667  ,victim: 1.0000


 39%|███▉      | 279/718 [06:35<13:53,  1.90s/it]

Prediction: {hero: [], villain: ['republican party', 'democratic party'], victim: []}
Ground truth: hero:[], villain:['republican party', 'democratic party'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 39%|███▉      | 280/718 [06:37<14:51,  2.04s/it]

Prediction: {hero: [], villain: ['democratic party'], victim: []}
Ground truth: hero:[], villain:['democratic party'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 39%|███▉      | 283/718 [06:41<11:44,  1.62s/it]

Prediction: {hero: [], villain: ['parents'], victim: ['scientists']}
Ground truth: hero:[], villain:[], victim:['parents']
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 0.0000


 40%|███▉      | 286/718 [06:43<09:18,  1.29s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:[], victim:['401k']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 0.0000


 40%|███▉      | 287/718 [06:46<10:48,  1.51s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 40%|████      | 288/718 [06:49<13:05,  1.83s/it]

Prediction: {hero: [], villain: ['parents of unvaccinated children'], victim: []}
Ground truth: hero:[], villain:['parents of unvaccinated children'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 40%|████      | 290/718 [06:52<12:32,  1.76s/it]

Prediction: {hero: [], villain: ['barack obama'], victim: []}
Ground truth: hero:['barack obama'], villain:['nugent', 'doanld trump', 'giuliani', "o'reilly"], victim:[]
f1 scores per class:    hero: 0.0000  ,villain: 0.0000  ,victim: 1.0000


 41%|████      | 291/718 [06:55<14:33,  2.05s/it]

Prediction: {hero: [], villain: ['republican', 'democrat'], victim: []}
Ground truth: hero:[], villain:['republican', 'democrat'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 41%|████      | 294/718 [06:58<10:51,  1.54s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 41%|████      | 295/718 [07:01<12:46,  1.81s/it]

Prediction: {hero: [], villain: ['leftists'], victim: ['donald trump']}
Ground truth: hero:[], villain:['leftists'], victim:['donald trump']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 41%|████▏     | 297/718 [07:05<12:30,  1.78s/it]

Prediction: {hero: [], villain: ['alt-right', 'neo-nazi'], victim: []}
Ground truth: hero:[], villain:['alt-right'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.6667  ,victim: 1.0000


 42%|████▏     | 298/718 [07:08<13:49,  1.97s/it]

Prediction: {hero: [], villain: ['democrats'], victim: []}
Ground truth: hero:[], villain:['democrats'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000
Average accuracy 53.63%
Average f1 scores per class:    hero: 0.8714  ,villain: 0.6952  ,victim: 0.7238
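The per-class values printed above match a set-based F1 with the following conventions:

- An empty prediction scored against an empty ground truth gets 1.0.
- A class with no overlapping entity gets 0.0.
- Otherwise the score is the harmonic mean of precision and recall.

For example, the `democratic party` / `republicans` row scores 0.6667 for both villain and victim, and the eight-entity villain prediction against a single `donald trump` scores 2/9 ≈ 0.2222. The scoring code itself is not part of this section, so `set_f1` and `per_class_f1` are hypothetical names. Exact string matching is also an assumption. It would explain why `dr. dr. anthony fauci` and `dr. anthony fauci` count as different entities in the log.

```python
from typing import Dict, Iterable

ROLES = ("hero", "villain", "victim")


def set_f1(pred: Iterable[str], gold: Iterable[str]) -> float:
    """F1 between two entity lists, compared as exact-string sets."""
    pred_set, gold_set = set(pred), set(gold)
    if not pred_set and not gold_set:
        # Nothing predicted and nothing expected: counted as a perfect match.
        return 1.0
    true_pos = len(pred_set & gold_set)
    if true_pos == 0:
        # Covers "predicted something, expected nothing" and the reverse.
        return 0.0
    precision = true_pos / len(pred_set)
    recall = true_pos / len(gold_set)
    return 2 * precision * recall / (precision + recall)


def per_class_f1(pred: Dict[str, list], gold: Dict[str, list]) -> Dict[str, float]:
    """Score each role independently, as in the 'f1 scores per class' lines."""
    return {role: set_f1(pred.get(role, []), gold.get(role, [])) for role in ROLES}


# Row copied from the log: villain over-predicts, victim under-predicts.
pred = {"hero": [], "villain": ["democratic party", "republicans"], "victim": ["black community"]}
gold = {"hero": [], "villain": ["democratic party"], "victim": ["republicans", "black community"]}
print({k: round(v, 4) for k, v in per_class_f1(pred, gold).items()})
# {'hero': 1.0, 'villain': 0.6667, 'victim': 0.6667}
```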


 42%|████▏     | 300/718 [07:11<12:34,  1.81s/it]

Prediction: {hero: [], villain: ['republican party'], victim: []}
Ground truth: hero:[], villain:['republican party', 'grand old party (gop)'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.6667  ,victim: 1.0000


 42%|████▏     | 303/718 [07:13<09:47,  1.41s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 42%|████▏     | 304/718 [07:16<11:52,  1.72s/it]

Prediction: {hero: [], villain: ['gary johnson'], victim: []}
Ground truth: hero:[], villain:['gary johnson'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 43%|████▎     | 306/718 [07:19<10:58,  1.60s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 43%|████▎     | 307/718 [07:22<12:24,  1.81s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 43%|████▎     | 310/718 [07:25<09:33,  1.41s/it]

Prediction: {hero: ['communist party'], villain: [], victim: []}
Ground truth: hero:['communist party'], villain:[], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 43%|████▎     | 311/718 [07:28<11:20,  1.67s/it]

Prediction: {hero: ['republican party'], villain: [], victim: []}
Ground truth: hero:['republican party'], villain:[], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 44%|████▎     | 314/718 [07:30<09:01,  1.34s/it]

Prediction: {hero: ['rex bell'], villain: [], victim: []}
Ground truth: hero:[], villain:['rex bell'], victim:[]
f1 scores per class:    hero: 0.0000  ,villain: 0.0000  ,victim: 1.0000


 44%|████▍     | 317/718 [07:34<08:59,  1.34s/it]

Prediction: {hero: ['green party', 'jill stein'], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['green party', 'jill stein'], victim:[]
f1 scores per class:    hero: 0.0000  ,villain: 0.0000  ,victim: 1.0000


 45%|████▌     | 326/718 [07:37<04:35,  1.42it/s]

Prediction: {hero: [], villain: ['democratic party'], victim: []}
Ground truth: hero:[], villain:['democratic party'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 46%|████▌     | 327/718 [07:41<06:34,  1.01s/it]

Prediction: {hero: [], villain: ['joe biden'], victim: []}
Ground truth: hero:[], villain:['joe biden'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 46%|████▌     | 330/718 [07:44<06:20,  1.02it/s]

Prediction: {hero: [], villain: ['democrats'], victim: []}
Ground truth: hero:[], villain:['democrats'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 46%|████▌     | 332/718 [07:46<06:28,  1.01s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:[], victim:['courier', 'post']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 0.0000


 46%|████▋     | 333/718 [07:51<09:27,  1.48s/it]

Prediction: {hero: [], villain: ['hunter biden'], victim: []}
Ground truth: hero:[], villain:['hunter biden'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 47%|████▋     | 334/718 [07:53<10:35,  1.66s/it]

Prediction: {hero: ['donald trump'], villain: [], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 0.0000  ,villain: 0.0000  ,victim: 1.0000


 47%|████▋     | 336/718 [07:56<09:55,  1.56s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:['syria']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 0.0000


 47%|████▋     | 337/718 [07:59<11:53,  1.87s/it]

Prediction: {hero: [], villain: ['donald trump', 'republicans'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.6667  ,victim: 1.0000


 47%|████▋     | 338/718 [08:03<13:36,  2.15s/it]

Prediction: {hero: [], villain: ['democratic'], victim: ['george floyd']}
Ground truth: hero:[], villain:['democratic'], victim:['george floyd']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 47%|████▋     | 339/718 [08:05<13:42,  2.17s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:['scaremongering'], victim:['vaccine']
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 0.0000


 47%|████▋     | 340/718 [08:08<14:37,  2.32s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000
Average accuracy 55.05%
Average f1 scores per class:    hero: 0.8688  ,villain: 0.7042  ,victim: 0.7396
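The `Average f1 scores per class` lines appear to be running means of the per-sample scores. The sketch below assumes an unweighted mean over samples; the notebook's averaging code is not shown here, and `average_per_class` is a hypothetical name. Under the empty-versus-empty = 1.0 convention, a role that is absent from most memes collects many automatic 1.0 scores. Hero is one such role in this log, which may partly explain why its average sits well above villain's and victim's.

```python
from typing import Dict, List

ROLES = ("hero", "villain", "victim")


def average_per_class(rows: List[Dict[str, float]]) -> Dict[str, float]:
    """Unweighted mean of per-sample F1 scores for each role."""
    if not rows:
        return {role: 0.0 for role in ROLES}
    return {role: sum(row[role] for row in rows) / len(rows) for role in ROLES}


# Three rows taken from the log above (printed there rounded to 4 decimals).
rows = [
    {"hero": 1.0, "villain": 1.0, "victim": 1.0},      # 'joe biden' row
    {"hero": 1.0, "villain": 2 / 3, "victim": 2 / 3},  # 'democratic party' / 'republicans' row
    {"hero": 0.0, "villain": 0.0, "victim": 0.0},      # 'barack obama' / 'donald trump' row
]
print({k: round(v, 4) for k, v in average_per_class(rows).items()})
# {'hero': 0.6667, 'villain': 0.5556, 'victim': 0.5556}
```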


 48%|████▊     | 347/718 [08:12<06:58,  1.13s/it]

Prediction: {hero: [], villain: ['joe biden'], victim: []}
Ground truth: hero:[], villain:['joe biden'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 48%|████▊     | 348/718 [08:14<07:51,  1.28s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:['sweden'], villain:[], victim:[]
f1 scores per class:    hero: 0.0000  ,villain: 1.0000  ,victim: 1.0000


 49%|████▉     | 354/718 [08:18<05:46,  1.05it/s]

Prediction: {hero: [], villain: ['china', 'donald trump'], victim: []}
Ground truth: hero:[], villain:['china'], victim:['donald trump', 'united states']
f1 scores per class:    hero: 1.0000  ,villain: 0.6667  ,victim: 0.0000


 49%|████▉     | 355/718 [08:21<06:57,  1.15s/it]

Prediction: {hero: [], villain: ['democrat party'], victim: []}
Ground truth: hero:[], villain:['democrat party'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 50%|████▉     | 357/718 [08:24<07:11,  1.19s/it]

Prediction: {hero: [], villain: ['china'], victim: []}
Ground truth: hero:[], villain:['china'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 50%|████▉     | 358/718 [08:26<08:31,  1.42s/it]

Prediction: {hero: [], villain: ['democratic party'], victim: []}
Ground truth: hero:[], villain:['democratic party'], victim:['america']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 0.0000


 50%|█████     | 359/718 [08:29<10:10,  1.70s/it]

Prediction: {hero: [], villain: ['members of libertarian party'], victim: []}
Ground truth: hero:[], villain:['members of libertarian party'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 50%|█████     | 360/718 [08:33<11:43,  1.97s/it]

Prediction: {hero: [], villain: ['democratic party'], victim: []}
Ground truth: hero:['democratic party'], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 0.0000  ,villain: 0.0000  ,victim: 1.0000


 51%|█████     | 365/718 [08:36<06:55,  1.18s/it]

Prediction: {hero: ['libertarian party'], villain: [], victim: []}
Ground truth: hero:['libertarian party'], villain:[], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 51%|█████     | 366/718 [08:38<08:14,  1.40s/it]

Prediction: {hero: [], villain: ['liberalism'], victim: []}
Ground truth: hero:[], villain:['liberalism'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 52%|█████▏    | 371/718 [08:41<05:20,  1.08it/s]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:['joe biden'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 1.0000


 52%|█████▏    | 372/718 [08:43<06:21,  1.10s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:['warsaw ghetto'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 1.0000


 52%|█████▏    | 376/718 [08:46<05:17,  1.08it/s]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:['victim']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 0.0000


 53%|█████▎    | 378/718 [08:49<06:09,  1.09s/it]

Prediction: {hero: ['libertarians'], villain: ['libertarian party'], victim: []}
Ground truth: hero:[], villain:['libertarians'], victim:[]
f1 scores per class:    hero: 0.0000  ,villain: 0.0000  ,victim: 1.0000


 53%|█████▎    | 380/718 [08:52<06:31,  1.16s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 54%|█████▎    | 385/718 [08:54<04:43,  1.18it/s]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:[], victim:['zoom']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 0.0000


 54%|█████▍    | 386/718 [09:00<07:47,  1.41s/it]

Prediction: {hero: [], villain: ['mike pence', 'kamala harris'], victim: []}
Ground truth: hero:['mike pence'], villain:['kamala harris'], victim:[]
f1 scores per class:    hero: 0.0000  ,villain: 0.6667  ,victim: 1.0000


 54%|█████▍    | 387/718 [09:02<08:46,  1.59s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 54%|█████▍    | 388/718 [09:06<10:20,  1.88s/it]

Prediction: {hero: [], villain: ['norway', 'britain'], victim: []}
Ground truth: hero:['norway'], villain:['britain'], victim:[]
f1 scores per class:    hero: 0.0000  ,villain: 0.6667  ,victim: 1.0000


 55%|█████▍    | 392/718 [09:09<07:08,  1.31s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000
Average accuracy 54.49%
Average f1 scores per class:    hero: 0.8556  ,villain: 0.7093  ,victim: 0.7463


 55%|█████▍    | 393/718 [09:13<09:40,  1.79s/it]

Prediction: {hero: [], villain: ['xi jinping', 'democrats'], victim: ['donald trump', 'americans']}
Ground truth: hero:[], villain:['xi jinping', 'democrats'], victim:['donald trump']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 0.6667


 55%|█████▍    | 394/718 [09:17<11:12,  2.08s/it]

Prediction: {hero: [], villain: [], victim: ['architects']}
Ground truth: hero:[], villain:[], victim:['architects']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 55%|█████▌    | 396/718 [09:23<13:19,  2.48s/it]

Prediction: {hero: [], villain: ['donald trump', 'media', 'muslims', 'women', 'debates', 'black lives matter', 'mosques', 'mexicans'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:['media', 'muslims', 'women', 'debates', 'black lives matter', 'mosques', 'mexicans']
f1 scores per class:    hero: 1.0000  ,villain: 0.2222  ,victim: 0.0000


 55%|█████▌    | 397/718 [09:27<15:03,  2.81s/it]

Prediction: {hero: ['alcohol', 'corona beer'], villain: [], victim: []}
Ground truth: hero:['alcohol'], villain:[], victim:[]
f1 scores per class:    hero: 0.6667  ,villain: 1.0000  ,victim: 1.0000


 56%|█████▌    | 399/718 [09:30<12:20,  2.32s/it]

Prediction: {hero: [], villain: ['ronald reagan'], victim: []}
Ground truth: hero:[], villain:['ronald reagan'], victim:['usa']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 0.0000


 56%|█████▌    | 400/718 [09:34<13:27,  2.54s/it]

Prediction: {hero: ['donald trump'], villain: ['joe biden'], victim: []}
Ground truth: hero:['donald trump'], villain:['joe biden'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 56%|█████▌    | 402/718 [09:36<10:37,  2.02s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:['gun ban'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 1.0000


 56%|█████▌    | 403/718 [09:39<12:03,  2.30s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: ['muslim']}
Ground truth: hero:[], villain:['donald trump'], victim:['muslim']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 56%|█████▋    | 404/718 [09:42<12:32,  2.40s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 56%|█████▋    | 405/718 [09:45<13:07,  2.52s/it]

Prediction: {hero: [], villain: ['republican party'], victim: []}
Ground truth: hero:[], villain:['republican party', 'donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.6667  ,victim: 1.0000


 57%|█████▋    | 406/718 [09:49<15:02,  2.89s/it]

Prediction: {hero: [], villain: ['arab muslims', 'democratic party'], victim: ['black people']}
Ground truth: hero:[], villain:['arab muslims', 'democratic party'], victim:['black people']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 57%|█████▋    | 407/718 [09:52<14:50,  2.86s/it]

Prediction: {hero: [], villain: ['democratic party'], victim: []}
Ground truth: hero:[], villain:['democratic party'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 57%|█████▋    | 411/718 [09:54<07:27,  1.46s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 1.0000


 58%|█████▊    | 413/718 [09:58<08:11,  1.61s/it]

Prediction: {hero: [], villain: ['republicans', 'vladimir putin'], victim: ['america']}
Ground truth: hero:[], villain:['republicans', 'vladimir putin'], victim:['america']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 58%|█████▊    | 414/718 [10:00<09:08,  1.80s/it]

Prediction: {hero: [], villain: ['democratic party'], victim: []}
Ground truth: hero:[], villain:['democratic party'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 58%|█████▊    | 415/718 [10:03<09:58,  1.97s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 58%|█████▊    | 417/718 [10:06<09:19,  1.86s/it]

Prediction: {hero: [], villain: ['democratic party'], victim: ['women']}
Ground truth: hero:[], villain:['democratic party'], victim:['women']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 58%|█████▊    | 419/718 [10:10<09:37,  1.93s/it]

Prediction: {hero: [], villain: ['american', 'hillary clinton'], victim: ['donald trump', 'american people']}
Ground truth: hero:[], villain:['american', 'hillary clinton', 'donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.8000  ,victim: 0.0000


 59%|█████▉    | 423/718 [10:13<06:07,  1.25s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:['libertarians'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 1.0000


 59%|█████▉    | 424/718 [10:16<07:24,  1.51s/it]

Prediction: {hero: [], villain: ['libertarian party'], victim: []}
Ground truth: hero:[], villain:['libertarian party'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000
Average accuracy 55.38%
Average f1 scores per class:    hero: 0.8683  ,villain: 0.7168  ,victim: 0.7550


 60%|█████▉    | 429/718 [10:19<05:18,  1.10s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: ['barack obama']}
Ground truth: hero:[], villain:['donald trump'], victim:['barack obama']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 60%|██████    | 431/718 [10:22<05:27,  1.14s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:['joe biden'], victim:['donald trump']
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 0.0000


 60%|██████    | 433/718 [10:25<05:47,  1.22s/it]

Prediction: {hero: [], villain: ['barack obama'], victim: []}
Ground truth: hero:[], villain:[], victim:['barack obama']
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 0.0000


 60%|██████    | 434/718 [10:28<07:30,  1.59s/it]

Prediction: {hero: [], villain: ['joe biden'], victim: []}
Ground truth: hero:[], villain:['joe biden'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 61%|██████    | 436/718 [10:31<07:09,  1.52s/it]

Prediction: {hero: [], villain: ['democratic party'], victim: []}
Ground truth: hero:[], villain:['democratic party'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 61%|██████    | 439/718 [10:34<06:11,  1.33s/it]

Prediction: {hero: [], villain: ['joe biden'], victim: ['america']}
Ground truth: hero:[], villain:['joe biden'], victim:['america']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 61%|██████▏   | 441/718 [10:37<05:51,  1.27s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:['election'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 1.0000


 62%|██████▏   | 442/718 [10:39<06:55,  1.50s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:['melania trump']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 0.0000


 62%|██████▏   | 444/718 [10:42<06:30,  1.43s/it]

Prediction: {hero: [], villain: ['atf'], victim: []}
Ground truth: hero:[], villain:['atf', 'donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.6667  ,victim: 1.0000


 62%|██████▏   | 445/718 [10:46<08:54,  1.96s/it]

Prediction: {hero: [], villain: ['teenage trump supporters', 'hillary clinton', 'liberals'], victim: []}
Ground truth: hero:['teenage trump supporters'], villain:['hillary clinton', 'liberals'], victim:[]
f1 scores per class:    hero: 0.0000  ,villain: 0.8000  ,victim: 1.0000


 63%|██████▎   | 450/718 [10:49<05:13,  1.17s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 63%|██████▎   | 452/718 [10:52<05:45,  1.30s/it]

Prediction: {hero: [], villain: ['bolsanaro'], victim: ['brazil']}
Ground truth: hero:[], villain:['bolsanaro'], victim:['brazil']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 63%|██████▎   | 453/718 [10:57<08:14,  1.86s/it]

Prediction: {hero: [], villain: ['joe biden'], victim: ['joe biden supporters', 'little girls']}
Ground truth: hero:[], villain:['joe biden', 'joe biden supporters'], victim:['little girls']
f1 scores per class:    hero: 1.0000  ,villain: 0.6667  ,victim: 0.6667


 63%|██████▎   | 455/718 [11:00<07:29,  1.71s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 64%|██████▎   | 456/718 [11:03<08:14,  1.89s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 64%|██████▎   | 457/718 [11:05<08:35,  1.98s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:['asteroid', 'cyclone', 'corona', 'anxiety'], victim:['human']
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 0.0000


 64%|██████▍   | 462/718 [11:07<04:33,  1.07s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:['2020'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 1.0000


 64%|██████▍   | 463/718 [11:10<05:30,  1.30s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 65%|██████▌   | 468/718 [11:14<04:15,  1.02s/it]

Prediction: {hero: ['hiliary clinton'], villain: ['donald trump'], victim: ['donald trump']}
Ground truth: hero:[], villain:['hiliary clinton', 'donald trump'], victim:['donald trump']
f1 scores per class:    hero: 0.0000  ,villain: 0.6667  ,victim: 1.0000


 65%|██████▌   | 469/718 [11:18<06:03,  1.46s/it]

Prediction: {hero: ['gary johnson'], villain: ['hillary clinton', 'donald trump'], victim: []}
Ground truth: hero:['gary johnson'], villain:['hillary clinton', 'donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000
Average accuracy 55.80%
Average f1 scores per class:    hero: 0.8712  ,villain: 0.7143  ,victim: 0.7576


 66%|██████▌   | 472/718 [11:22<05:32,  1.35s/it]

Prediction: {hero: [], villain: ['grand old party (gop)'], victim: ['people']}
Ground truth: hero:[], villain:['grand old party (gop)'], victim:['people']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 66%|██████▌   | 474/718 [11:25<05:35,  1.38s/it]

Prediction: {hero: [], villain: ['barack obama'], victim: []}
Ground truth: hero:[], villain:['barack obama'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 66%|██████▌   | 475/718 [11:27<06:09,  1.52s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 1.0000


 66%|██████▋   | 476/718 [11:29<06:36,  1.64s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:['traffic'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 1.0000


 67%|██████▋   | 478/718 [11:32<06:11,  1.55s/it]

Prediction: {hero: [], villain: ['democrats'], victim: []}
Ground truth: hero:[], villain:['democrats'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 67%|██████▋   | 479/718 [11:35<07:29,  1.88s/it]

Prediction: {hero: [], villain: ['hiliary clinton'], victim: ['america']}
Ground truth: hero:['hiliary clinton'], villain:[], victim:[]
f1 scores per class:    hero: 0.0000  ,villain: 0.0000  ,victim: 0.0000


 67%|██████▋   | 481/718 [11:40<08:23,  2.12s/it]

Prediction: {hero: [], villain: ['mitch mcconnell', 'grand old party (gop)'], victim: ['people', 'bernie sanders']}
Ground truth: hero:[], villain:['mitch mcconnell', 'grand old party (gop)'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 0.0000


 67%|██████▋   | 484/718 [11:45<07:26,  1.91s/it]

Prediction: {hero: [], villain: ['illegal', 'liberal'], victim: []}
Ground truth: hero:[], villain:['illegal', 'liberal'], victim:['donald trump']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 0.0000


 68%|██████▊   | 485/718 [11:48<08:23,  2.16s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: ['barack obama']}
Ground truth: hero:[], villain:['donald trump'], victim:['barack obama']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 68%|██████▊   | 487/718 [11:51<06:56,  1.80s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:['ads'], villain:[], victim:[]
f1 scores per class:    hero: 0.0000  ,villain: 1.0000  ,victim: 1.0000


 68%|██████▊   | 490/718 [11:53<05:09,  1.36s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:['libertarian party', 'third party candidates'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 1.0000


 69%|██████▊   | 492/718 [11:55<05:00,  1.33s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:['sarscov2'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 1.0000


 69%|██████▉   | 494/718 [11:58<05:12,  1.39s/it]

Prediction: {hero: [], villain: ['landon spradlin'], victim: []}
Ground truth: hero:[], villain:['landon spradlin'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 69%|██████▉   | 496/718 [12:01<05:11,  1.40s/it]

Prediction: {hero: [], villain: ['mike pence'], victim: []}
Ground truth: hero:[], villain:['mike pence'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 69%|██████▉   | 497/718 [12:05<06:54,  1.87s/it]

Prediction: {hero: [], villain: ['donald trump', 'kamala harris'], victim: ['jewish']}
Ground truth: hero:['donald trump'], villain:['kamala harris'], victim:[]
f1 scores per class:    hero: 0.0000  ,villain: 0.6667  ,victim: 0.0000


 69%|██████▉   | 498/718 [12:11<09:20,  2.55s/it]

Prediction: {hero: ['alexandria ocasio-cortez'], villain: [], victim: ['medicaid', 'illegals']}
Ground truth: hero:[], villain:['alexandria ocasio-cortez'], victim:[]
f1 scores per class:    hero: 0.0000  ,villain: 0.0000  ,victim: 0.0000


 69%|██████▉   | 499/718 [12:14<09:36,  2.63s/it]

Prediction: {hero: [], villain: ['chinese'], victim: []}
Ground truth: hero:[], villain:[], victim:['chinese']
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 0.0000


 70%|██████▉   | 501/718 [12:16<07:42,  2.13s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 71%|███████   | 508/718 [12:20<03:36,  1.03s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 71%|███████   | 509/718 [12:22<04:18,  1.24s/it]

Prediction: {hero: [], villain: [], victim: ['american people']}
Ground truth: hero:[], villain:[], victim:['american people']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000
Average accuracy 55.17%
Average f1 scores per class:    hero: 0.8653  ,villain: 0.7076  ,victim: 0.7528


 71%|███████   | 511/718 [12:25<04:12,  1.22s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:['hajia farouq sadiya'], villain:[], victim:[]
f1 scores per class:    hero: 0.0000  ,villain: 1.0000  ,victim: 1.0000


 71%|███████▏  | 512/718 [12:28<05:21,  1.56s/it]

Prediction: {hero: [], villain: ['conservative', 'liberal'], victim: []}
Ground truth: hero:['conservative'], villain:['liberal'], victim:[]
f1 scores per class:    hero: 0.0000  ,villain: 0.6667  ,victim: 1.0000


 72%|███████▏  | 515/718 [12:31<04:24,  1.30s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:['nation']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 0.0000


 73%|███████▎  | 521/718 [12:34<03:03,  1.08it/s]

Prediction: {hero: [], villain: ['vladimir putin', 'republican party'], victim: []}
Ground truth: hero:[], villain:['vladimir putin', 'republican party'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 73%|███████▎  | 523/718 [12:37<03:17,  1.01s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 73%|███████▎  | 526/718 [12:42<03:41,  1.15s/it]

Prediction: {hero: [], villain: ['joe biden'], victim: []}
Ground truth: hero:[], villain:['joe biden'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 73%|███████▎  | 527/718 [12:44<04:14,  1.33s/it]

Prediction: {hero: [], villain: ['senators'], victim: []}
Ground truth: hero:[], villain:['senators', 'gun lobby'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.6667  ,victim: 1.0000


 74%|███████▍  | 531/718 [12:47<03:27,  1.11s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 74%|███████▍  | 532/718 [12:51<04:19,  1.40s/it]

Prediction: {hero: [], villain: ['barack obama'], victim: ['donald trump']}
Ground truth: hero:[], villain:['barack obama'], victim:['donald trump']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 75%|███████▍  | 535/718 [12:53<03:43,  1.22s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:[], victim:['donald trump']
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 0.0000


 75%|███████▍  | 536/718 [12:56<04:09,  1.37s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:['gen x'], villain:['boomers', 'millenials'], victim:[]
f1 scores per class:    hero: 0.0000  ,villain: 0.0000  ,victim: 1.0000


 75%|███████▍  | 537/718 [12:58<04:53,  1.62s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 75%|███████▌  | 539/718 [13:01<04:32,  1.52s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 76%|███████▌  | 544/718 [13:04<03:00,  1.04s/it]

Prediction: {hero: [], villain: [], victim: ['dudes on dating apps']}
Ground truth: hero:[], villain:[], victim:['dudes on dating apps']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 76%|███████▌  | 546/718 [13:08<03:37,  1.27s/it]

Prediction: {hero: ['anthony fauci'], villain: ['dr. anthony fauci'], victim: []}
Ground truth: hero:[], villain:['anthony fauci', 'dr. anthony fauci', 'covid19', 'donald trump'], victim:[]
f1 scores per class:    hero: 0.0000  ,villain: 0.4000  ,victim: 1.0000


 77%|███████▋  | 551/718 [13:12<02:50,  1.02s/it]

Prediction: {hero: [], villain: ['joe biden'], victim: []}
Ground truth: hero:[], villain:['joe biden'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 77%|███████▋  | 552/718 [13:15<03:34,  1.29s/it]

Prediction: {hero: [], villain: ['reagan', 'conservatives'], victim: []}
Ground truth: hero:[], villain:['reagan', 'conservatives', 'fox news', 'donald trump'], victim:['liberal media']
f1 scores per class:    hero: 1.0000  ,villain: 0.6667  ,victim: 0.0000


 77%|███████▋  | 556/718 [13:18<02:52,  1.07s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 78%|███████▊  | 557/718 [13:21<03:19,  1.24s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:['chris wallace'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 1.0000


 78%|███████▊  | 558/718 [13:23<03:57,  1.49s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000
Average accuracy 55.70%
Average f1 scores per class:    hero: 0.8603  ,villain: 0.7124  ,victim: 0.7603


 79%|███████▊  | 564/718 [13:26<02:10,  1.18it/s]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:[], victim:['economy']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 0.0000


 79%|███████▊  | 565/718 [13:28<02:33,  1.01s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:['barack obama'], villain:[], victim:[]
f1 scores per class:    hero: 0.0000  ,villain: 1.0000  ,victim: 1.0000


 79%|███████▉  | 569/718 [13:30<02:09,  1.15it/s]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 79%|███████▉  | 570/718 [13:33<02:49,  1.15s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: ['america']}
Ground truth: hero:[], villain:['donald trump'], victim:['america']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 80%|███████▉  | 572/718 [13:36<02:54,  1.19s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 80%|███████▉  | 573/718 [13:39<03:30,  1.45s/it]

Prediction: {hero: [], villain: ['state laws'], victim: ['people']}
Ground truth: hero:[], villain:['state laws'], victim:['people']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 80%|███████▉  | 574/718 [13:43<04:28,  1.86s/it]

Prediction: {hero: [], villain: ['adolf hitler', 'donald trump'], victim: []}
Ground truth: hero:[], villain:['adolf hitler', 'donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 80%|████████  | 576/718 [13:46<04:10,  1.77s/it]

Prediction: {hero: [], villain: ['reopen america protestors'], victim: []}
Ground truth: hero:[], villain:['reopen america protestors'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 80%|████████  | 577/718 [13:50<05:23,  2.29s/it]

Prediction: {hero: [], villain: ['barack obama', 'george bush', 'donald trump'], victim: ['americans']}
Ground truth: hero:['barack obama', 'george bush'], villain:['donald trump'], victim:['americans']
f1 scores per class:    hero: 0.0000  ,villain: 0.5000  ,victim: 1.0000


 81%|████████  | 578/718 [13:53<05:26,  2.33s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:['new york times editorial board'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 1.0000


 81%|████████▏ | 584/718 [13:56<02:31,  1.13s/it]

Prediction: {hero: [], villain: ['randy bryce'], victim: []}
Ground truth: hero:[], villain:['randy bryce'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 81%|████████▏ | 585/718 [13:58<02:58,  1.34s/it]

Prediction: {hero: [], villain: ['capitalism'], victim: []}
Ground truth: hero:[], villain:['capitalism'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 82%|████████▏ | 586/718 [14:00<03:14,  1.47s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:['green'], villain:[], victim:[]
f1 scores per class:    hero: 0.0000  ,villain: 1.0000  ,victim: 1.0000


 82%|████████▏ | 588/718 [14:03<03:10,  1.47s/it]

Prediction: {hero: [], villain: ['republican party'], victim: []}
Ground truth: hero:[], villain:['republican party'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 82%|████████▏ | 589/718 [14:07<03:52,  1.80s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: ['barack obama']}
Ground truth: hero:['donald trump'], villain:['barack obama'], victim:[]
f1 scores per class:    hero: 0.0000  ,villain: 0.0000  ,victim: 0.0000


 83%|████████▎ | 594/718 [14:10<02:17,  1.11s/it]

Prediction: {hero: [], villain: ['irish'], victim: ['america']}
Ground truth: hero:[], villain:[], victim:['irish']
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 0.0000


 83%|████████▎ | 598/718 [14:12<01:52,  1.07it/s]

Prediction: {hero: [], villain: ['moslem'], victim: []}
Ground truth: hero:[], villain:['moslem'], victim:['president', 'constitution']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 0.0000


 84%|████████▎ | 601/718 [14:15<01:46,  1.10it/s]

Prediction: {hero: [], villain: ['china'], victim: []}
Ground truth: hero:[], villain:['china'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 84%|████████▍ | 603/718 [14:18<01:57,  1.02s/it]

Prediction: {hero: [], villain: ['democratic party'], victim: []}
Ground truth: hero:[], villain:['democratic party'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 84%|████████▍ | 604/718 [14:21<02:24,  1.27s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000
Average accuracy 56.13%
Average f1 scores per class:    hero: 0.8560  ,villain: 0.7204  ,victim: 0.7631


 84%|████████▍ | 605/718 [14:23<02:42,  1.44s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:['qanon'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 1.0000


 85%|████████▍ | 608/718 [14:25<02:11,  1.20s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:['west nile virus', 'bird flu', 'ebola', 'swine flu', 'sars', 'coronavirus', 'zika virus'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 1.0000


 85%|████████▍ | 610/718 [14:30<02:36,  1.45s/it]

Prediction: {hero: [], villain: [], victim: ['labor jobs']}
Ground truth: hero:[], villain:['labor jobs', 'poor people'], victim:['capitalism']
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 0.0000


 85%|████████▌ | 611/718 [14:32<02:57,  1.66s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 85%|████████▌ | 612/718 [14:36<03:47,  2.14s/it]

Prediction: {hero: ['gary johnson'], villain: ['hillary clinton', 'donald trump'], victim: []}
Ground truth: hero:['gary johnson'], villain:['hillary clinton', 'donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 86%|████████▌ | 615/718 [14:42<03:21,  1.95s/it]

Prediction: {hero: [], villain: ['democratic party', 'nancy pelosi', 'maxine waters', 'elizabeth warrens'], victim: []}
Ground truth: hero:[], villain:['democratic party'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.4000  ,victim: 1.0000


 86%|████████▌ | 618/718 [14:45<02:36,  1.57s/it]

Prediction: {hero: [], villain: ['joe biden'], victim: []}
Ground truth: hero:[], villain:[], victim:['joe biden', 'moderator']
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 0.0000


 86%|████████▌ | 619/718 [14:48<03:03,  1.85s/it]

Prediction: {hero: [], villain: ['joe biden'], victim: []}
Ground truth: hero:[], villain:['joe biden'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 87%|████████▋ | 623/718 [14:51<02:05,  1.32s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: ['people']}
Ground truth: hero:[], villain:['donald trump'], victim:['people']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 87%|████████▋ | 626/718 [14:53<01:46,  1.16s/it]

Prediction: {hero: [], villain: [], victim: ['black people']}
Ground truth: hero:[], villain:['black people', 'poll workers'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 0.0000


 88%|████████▊ | 629/718 [14:56<01:35,  1.07s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 88%|████████▊ | 630/718 [14:59<01:54,  1.31s/it]

Prediction: {hero: [], villain: ['democrat party'], victim: []}
Ground truth: hero:[], villain:['democrat party'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 88%|████████▊ | 631/718 [15:02<02:16,  1.57s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: ['usa']}
Ground truth: hero:[], villain:['donald trump'], victim:['usa']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 88%|████████▊ | 634/718 [15:04<01:43,  1.24s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:['libertarian party'], villain:['republican liberty caucus'], victim:[]
f1 scores per class:    hero: 0.0000  ,villain: 0.0000  ,victim: 1.0000


 89%|████████▊ | 636/718 [15:07<01:44,  1.27s/it]

Prediction: {hero: ['donald trump'], villain: [], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 0.0000  ,villain: 0.0000  ,victim: 1.0000


 89%|████████▊ | 637/718 [15:11<02:21,  1.75s/it]

Prediction: {hero: ['jeffrey epstein'], villain: ['donald trump'], victim: ['13 year old girl']}
Ground truth: hero:[], villain:['jeffrey epstein', 'donald trump'], victim:['13 year old girl']
f1 scores per class:    hero: 0.0000  ,villain: 0.6667  ,victim: 1.0000


 89%|████████▉ | 641/718 [15:14<01:36,  1.26s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 89%|████████▉ | 642/718 [15:17<01:54,  1.51s/it]

Prediction: {hero: ['nancy pelosi'], villain: [], victim: []}
Ground truth: hero:[], villain:['nancy pelosi'], victim:[]
f1 scores per class:    hero: 0.0000  ,villain: 0.0000  ,victim: 1.0000


 91%|█████████ | 651/718 [15:20<00:49,  1.34it/s]

Prediction: {hero: [], villain: ['white house (wh)'], victim: []}
Ground truth: hero:[], villain:['white house (wh)', 'donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.6667  ,victim: 1.0000


 91%|█████████ | 653/718 [15:23<00:59,  1.10it/s]

Prediction: {hero: [], villain: ['barack obama'], victim: ['americans']}
Ground truth: hero:[], villain:['barack obama'], victim:['americans', 'donald trump']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 0.6667
Average accuracy 55.55%
Average f1 scores per class:    hero: 0.8522  ,villain: 0.7082  ,victim: 0.7678


 91%|█████████ | 654/718 [15:26<01:10,  1.10s/it]

Prediction: {hero: [], villain: ['green party'], victim: []}
Ground truth: hero:[], villain:['green party'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 93%|█████████▎| 669/718 [15:29<00:23,  2.12it/s]

Prediction: {hero: ['narendra modi'], villain: [], victim: []}
Ground truth: hero:['narendra modi'], villain:[], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 94%|█████████▎| 672/718 [15:32<00:25,  1.84it/s]

Prediction: {hero: ['genovia'], villain: [], victim: []}
Ground truth: hero:['genovia'], villain:[], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 94%|█████████▍| 674/718 [15:34<00:28,  1.53it/s]

Prediction: {hero: [], villain: ['government'], victim: []}
Ground truth: hero:[], villain:['government'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 94%|█████████▍| 677/718 [15:37<00:29,  1.38it/s]

Prediction: {hero: [], villain: ['republican party'], victim: []}
Ground truth: hero:[], villain:['republican party'], victim:['transportation security administration (tsa)']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 0.0000


 94%|█████████▍| 678/718 [15:40<00:35,  1.14it/s]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:['non commissioned officer (nco)'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 1.0000


 95%|█████████▍| 682/718 [15:42<00:29,  1.22it/s]

Prediction: {hero: [], villain: ['republican party'], victim: []}
Ground truth: hero:['republican party'], villain:[], victim:[]
f1 scores per class:    hero: 0.0000  ,villain: 0.0000  ,victim: 1.0000


 95%|█████████▌| 684/718 [15:45<00:31,  1.08it/s]

Prediction: {hero: [], villain: ['government'], victim: []}
Ground truth: hero:[], villain:['government'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 95%|█████████▌| 685/718 [15:48<00:38,  1.17s/it]

Prediction: {hero: [], villain: ['republican party'], victim: []}
Ground truth: hero:[], villain:['republican party', 'donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.6667  ,victim: 1.0000


 96%|█████████▌| 686/718 [15:50<00:43,  1.36s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:['virus'], victim:['china', 'asia']
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 0.0000


 96%|█████████▌| 687/718 [15:55<00:59,  1.92s/it]

Prediction: {hero: [], villain: ['joe biden'], victim: []}
Ground truth: hero:[], villain:['joe biden', 'liberals'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.6667  ,victim: 1.0000


 96%|█████████▌| 688/718 [15:59<01:10,  2.34s/it]

Prediction: {hero: [], villain: ['democrat party', 'alexandria cortez'], victim: []}
Ground truth: hero:[], villain:['democrat party'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.6667  ,victim: 1.0000


 96%|█████████▌| 690/718 [16:03<01:02,  2.22s/it]

Prediction: {hero: ['gary johnson'], villain: ['hillary clinton'], victim: ['american citizens']}
Ground truth: hero:['gary johnson'], villain:['hillary clinton'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 0.0000


 96%|█████████▌| 691/718 [16:07<01:09,  2.57s/it]

Prediction: {hero: [], villain: ['adolf hitler', 'donald trump'], victim: ['muslims']}
Ground truth: hero:[], villain:['adolf hitler', 'donald trump'], victim:['muslims']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 97%|█████████▋| 697/718 [16:09<00:25,  1.20s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:[], victim:['donald trump']
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 0.0000


 97%|█████████▋| 698/718 [16:13<00:31,  1.56s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: ['economy', 'people']}
Ground truth: hero:[], villain:['donald trump'], victim:['economy', 'people']
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 97%|█████████▋| 699/718 [16:17<00:36,  1.89s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 98%|█████████▊| 704/718 [16:20<00:16,  1.17s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 98%|█████████▊| 705/718 [16:22<00:17,  1.34s/it]

Prediction: {hero: [], villain: [], victim: []}
Ground truth: hero:[], villain:['london'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 0.0000  ,victim: 1.0000


 98%|█████████▊| 706/718 [16:25<00:18,  1.58s/it]

Prediction: {hero: [], villain: ['libertarians'], victim: []}
Ground truth: hero:[], villain:['libertarians'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000
Average accuracy 55.83%
Average f1 scores per class:    hero: 0.8583  ,villain: 0.7077  ,victim: 0.7698


 99%|█████████▉| 712/718 [16:27<00:05,  1.10it/s]

Prediction: {hero: [], villain: ['parents'], victim: []}
Ground truth: hero:[], villain:['parents'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


 99%|█████████▉| 714/718 [16:30<00:04,  1.01s/it]

Prediction: {hero: [], villain: ['hillary clinton'], victim: []}
Ground truth: hero:[], villain:['hillary clinton'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000


100%|██████████| 718/718 [16:33<00:00,  1.38s/it]

Prediction: {hero: [], villain: ['donald trump'], victim: []}
Ground truth: hero:[], villain:['donald trump'], victim:[]
f1 scores per class:    hero: 1.0000  ,villain: 1.0000  ,victim: 1.0000
0.5623839009287926





In [63]:
print("Average accuracy {:.2f}%".format((avg_acc) * 100))
print("Average f1 scores per class:    hero: {:.4f}  ,villain: {:.4f}  ,victim: {:.4f}".format(ef1s[0] / n, ef1s[1] / n, ef1s[2] / n))
print("Macro f1 score: ", ((ef1s[0] / n) + (ef1s[1] / n) + (ef1s[2] / n)) / 3)

Average accuracy 56.24%
Average f1 scores per class:    hero: 0.8596  ,villain: 0.7104  ,victim: 0.7719
Macro f1 score:  0.7806558880862284


In [65]:
classes = [0, 0, 0]
len_td = 0

for idx, en in enumerate(test_dataset):
    if idx in edge_cases_idx:
        continue
    classes[0] += len(en["hero"])
    classes[1] += len(en["villain"])
    classes[2] += len(en["victim"])
    len_td += 1
    
print(f"Hero entities: {classes[0]}, Villain entities: {classes[1]}, Victim entities: {classes[2]}, Test dataset len: {len_td}")

Hero entities: 52, Villain entities: 350, Victim entities: 114, Test dataset len: 323
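The code that produced the per-record F1 lines in the log is not part of this section. The logged values are, however, consistent with a set-based F1 over exact entity strings in which two empty lists score 1.0 and a prediction with no overlap scores 0.0. The following is a minimal sketch under that assumption. It treats each list as a set, and the names `set_f1`, `per_class_f1`, `macro_f1` and `CLASSES` are hypothetical. It reproduces, for example, the 0.4000 villain scores on the Fauci and Democratic Party records, and its `macro_f1` is the same unweighted mean of class averages printed as "Macro f1 score".

```python
from typing import Dict, Iterable, List, Sequence

CLASSES = ("hero", "villain", "victim")


def set_f1(predicted: Iterable[str], gold: Iterable[str]) -> float:
    """F1 between predicted and gold entity strings, using exact string match.

    Assumed convention (consistent with the log): two empty sets score 1.0,
    and a prediction with no overlap with the gold set scores 0.0.
    """
    pred, ref = set(predicted), set(gold)
    if not pred and not ref:
        return 1.0  # nothing to find and nothing predicted
    overlap = len(pred & ref)
    if overlap == 0:
        return 0.0  # also covers "one side empty, the other not"
    precision = overlap / len(pred)
    recall = overlap / len(ref)
    return 2 * precision * recall / (precision + recall)


def per_class_f1(prediction: Dict[str, List[str]], gold: Dict[str, List[str]]) -> List[float]:
    """One F1 score per role, in the order hero, villain, victim."""
    return [set_f1(prediction.get(c, []), gold.get(c, [])) for c in CLASSES]


def macro_f1(class_means: Sequence[float]) -> float:
    """Unweighted mean of the per-class averages."""
    return sum(class_means) / len(class_means)


# A record from the log: the extra "conservative" halves villain precision.
scores = per_class_f1(
    {"hero": [], "villain": ["conservative", "liberal"], "victim": []},
    {"hero": ["conservative"], "villain": ["liberal"], "victim": []},
)
print([round(s, 4) for s in scores])  # [0.0, 0.6667, 1.0]
```

One consequence of this convention: with 52 hero entities across 323 test examples, at least 271 examples have an empty hero label. An empty hero prediction earns 1.0 on each of them, so the relatively high hero average (0.8596) may largely reflect correctly predicting no hero rather than correctly naming one.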


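model.eval()

# `example` is assumed to be a single record from `eval_dataset` (with "image" and "query" fields)
image = example["image"]
query = example["query"]

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Answer briefly."},
            {"type": "image"},
            {"type": "text", "text": query["en"]}
        ]
    }
]
text = processor.apply_chat_template(messages, add_generation_prompt=True)
# Tokenize the prompt, preprocess the image, and move the tensors to the model's device
inputs = processor(text=[text.strip()], images=[image], return_tensors="pt", padding=True).to(model.device)
generated_ids = model.generate(**inputs, max_new_tokens=64)
# generate() returns the prompt followed by the new tokens; decode only the new tokens
generated_texts = processor.batch_decode(generated_ids[:, inputs["input_ids"].size(1):], skip_special_tokens=True)
print(generated_texts)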

During training, we tracked the loss on the evaluation split. It is more informative, however, to measure performance with the metric actually used for DocVQA.

The metric at hand is the *Average Normalized Levenshtein Similarity* (ANLS), proposed by [Biten+ ICCV'19](https://arxiv.org/abs/1905.13648). ANLS handles OCR mistakes smoothly: an answer that points to the right text but contains a few recognition errors is only slightly penalized. It also applies a threshold of 0.5 to the normalized Levenshtein distance: if the distance is below 0.5, the answer scores 1 minus that distance; otherwise it scores 0. The purpose of this threshold is to separate an answer that was correctly selected but imperfectly recognized (partial credit) from an answer that is simply the wrong text (no credit).
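Written out, the definition reads as follows. Here $N$ is the number of questions, $a_{ij}$ is the $j$-th acceptable ground-truth answer to question $i$, $o_{q_i}$ is the model's answer to question $i$, $\mathrm{lev}$ is the Levenshtein edit distance, and $\tau = 0.5$. This is the same rule that the helper functions below implement.

$$
\mathrm{NL}(a, o) = \frac{\mathrm{lev}(a, o)}{\max(|a|, |o|)}
$$

$$
s(a_{ij}, o_{q_i}) =
\begin{cases}
1 - \mathrm{NL}(a_{ij}, o_{q_i}) & \text{if } \mathrm{NL}(a_{ij}, o_{q_i}) < \tau \\
0 & \text{otherwise}
\end{cases}
\qquad
\mathrm{ANLS} = \frac{1}{N} \sum_{i=1}^{N} \max_{j} \, s(a_{ij}, o_{q_i})
$$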

We first define a few utilities to compute the ANLS.

!pip install Levenshtein

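import Levenshtein

def normalized_levenshtein(s1, s2):
    # Edit distance divided by the length of the longer string, so the result lies in [0, 1]
    max_len = max(len(s1), len(s2))
    if max_len == 0:
        return 0.0  # two empty strings are identical
    return Levenshtein.distance(s1, s2) / max_len

def similarity_score(a_ij, o_q_i, tau=0.5):
    # Partial credit (1 - NL) below the threshold, no credit at or above it
    nl = normalized_levenshtein(a_ij, o_q_i)
    return 1 - nl if nl < tau else 0

def average_normalized_levenshtein_similarity(ground_truth, predicted_answers):
    # ground_truth[i] is a list of acceptable answers for question i;
    # predicted_answers[i] is the model's single answer string for question i
    assert len(ground_truth) == len(predicted_answers), "Length of ground_truth and predicted_answers must match."

    N = len(ground_truth)
    total_score = 0

    for i in range(N):
        a_i = ground_truth[i]
        o_q_i = predicted_answers[i]
        if o_q_i == "":
            print("Warning: empty prediction, scored as 0.")
            max_score = 0
        else:
            # Keep the best match over all acceptable ground-truth answers
            max_score = max(similarity_score(a_ij, o_q_i) for a_ij in a_i)

        total_score += max_score

    return total_score / N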
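To sanity-check how the threshold behaves without installing `Levenshtein`, here is a dependency-free sketch. It swaps the library's `Levenshtein.distance` for a textbook dynamic-programming edit distance and applies the same scoring rule as `average_normalized_levenshtein_similarity`, including scoring an empty prediction as 0. The helper names `levenshtein` and `anls` and the toy questions are hypothetical illustrations, not DocVQA data. Like the helpers above, the comparison is case-sensitive.

```python
from typing import List


def levenshtein(a: str, b: str) -> int:
    """Textbook dynamic-programming edit distance (unit-cost insert, delete, substitute)."""
    prev = list(range(len(b) + 1))  # distances from "" to each prefix of b
    for i, ca in enumerate(a, start=1):
        curr = [i]  # distance from a[:i] to ""
        for j, cb in enumerate(b, start=1):
            curr.append(min(
                prev[j] + 1,               # delete ca
                curr[j - 1] + 1,           # insert cb
                prev[j - 1] + (ca != cb),  # substitute (free when the characters match)
            ))
        prev = curr
    return prev[-1]


def anls(ground_truth: List[List[str]], predictions: List[str], tau: float = 0.5) -> float:
    """Same scoring rule as average_normalized_levenshtein_similarity."""
    assert len(ground_truth) == len(predictions)
    total = 0.0
    for answers, pred in zip(ground_truth, predictions):
        if pred == "":
            continue  # an empty prediction contributes 0
        best = 0.0
        for ans in answers:
            nl = levenshtein(ans, pred) / max(len(ans), len(pred))
            best = max(best, 1 - nl if nl < tau else 0.0)
        total += best
    return total / len(ground_truth)


# Hypothetical toy questions (not DocVQA data)
gt = [["1998"], ["Acme Corp", "ACME Corporation"], ["blue"]]
pred = ["1993", "Acme Corp.", "red"]
score = anls(gt, pred)
print(round(score, 4))  # 0.55, i.e. (0.75 + 0.9 + 0) / 3
```

"1993" is one substitution away from "1998" (normalized distance 0.25), so it still earns 0.75, while "red" is too far from "blue" and earns nothing.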


# Release cached GPU memory before running evaluation inference; needed because this environment is memory-constrained
torch.cuda.empty_cache()


from tqdm import tqdm

EVAL_BATCH_SIZE = 1

answers_unique = []
generated_texts_unique = []

for i in tqdm(range(0, len(eval_dataset), EVAL_BATCH_SIZE)):
    examples = eval_dataset[i: i + EVAL_BATCH_SIZE]
    answers_unique.extend(examples["answers"])
    images = [[im] for im in examples["image"]]
    texts = []
    for q in examples["query"]:
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Answer briefly."},
                    {"type": "image"},
                    {"type": "text", "text": q["en"]}
                ]
            }
        ]
        text = processor.apply_chat_template(messages, add_generation_prompt=True)
        texts.append(text.strip())
    # Move the batch to the model's device before generating
    inputs = processor(text=texts, images=images, return_tensors="pt", padding=True).to(model.device)
    generated_ids = model.generate(**inputs, max_new_tokens=64)
    generated_texts = processor.batch_decode(generated_ids[:, inputs["input_ids"].size(1):], skip_special_tokens=True)
    generated_texts_unique.extend(generated_texts)


generated_texts_unique = [g.strip().strip(".") for g in generated_texts_unique]
anls = average_normalized_levenshtein_similarity(
    ground_truth=answers_unique, predicted_answers=generated_texts_unique,
)
print(anls)


We obtain an ANLS score of ~60 (ANLS lies between 0 and 1, so `print(anls)` shows roughly 0.60). This is low compared to well-trained models on DocVQA, but keep in mind that we trained and evaluated on only a small subset of the data, as an exercise.

You should now have all the tools you need to fine-tune Idefics2 on your own dataset!