In [1]:
import os
import sys
sys.path.append(f"{os.environ['MINERVA_HOME']}/code")
from mil_model import MILModel
from mil_dataset import MILTwitterDataset, get_acled_labels
from transformers import AutoTokenizer
import torch
from tqdm import tqdm
import jsonlines

In [2]:
# Experiment configuration (edit here; later cells read these names).
instance_model = "vinai/bertweet-base"  # checkpoint passed to AutoTokenizer / MILModel
seed = 42  # passed to the evaluation dataset as its random_seed
results_dir = "{}/models/test".format(os.environ["MINERVA_HOME"])  # where predictions are written

In [3]:
tokenizer = AutoTokenizer.from_pretrained(instance_model)

# Set up dataset
# Ground truth labels from ACLED
# Set up model
model = MILModel(
    instance_model_path=instance_model,
    key_instance_ratio=0.2,
    finetune_instance_model=False
)
tokenizer = AutoTokenizer.from_pretrained(instance_model)

# Set up dataset
positive_bags = get_acled_labels()
eval_dataset = MILTwitterDataset.from_glob(
    f"{os.environ['MINERVA_HOME']}/data/tweets_en/2017_.*.gz",
    positive_bags, 
    tokenizer,
    samples_per_file=10,
    shuffle_samples=False, 
    random_seed=seed
)


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Some weights of the model checkpoint at vinai/bertweet-base were not used when initializing RobertaForSequenceClassification: ['lm_head.dense.bias', 'roberta.pooler.dense.bias', 'roberta.pooler.dense.weight', 'lm_head.bias', 'lm_head.dense.weight', 'lm_head.decoder.bias', 'lm_head.layer_norm.weight', 'lm_head.decoder.weight', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSe

In [4]:
# Evaluation loader. shuffle=False keeps bags in dataset order, so the rows
# written later line up with the dataset's file order.
dataloader = torch.utils.data.DataLoader(
    eval_dataset,
    shuffle=False,
    batch_size=2,
    collate_fn=eval_dataset.collate_function
)

# Sanity check: inspect the first collated batch. Pulling one item from a fresh
# iterator avoids starting a progress bar over the whole loader and breaking out.
first_batch = next(iter(dataloader))
print(first_batch["bag_id"], first_batch["instance_ids"])


  0% 0/7240 [00:00<?, ?it/s]

['/home/aadelucia/MIL-civil-unrest/data/tweets_en/2017_07_24_AO.gz', '/home/aadelucia/MIL-civil-unrest/data/tweets_en/2017_04_27_TD.gz'] [['889523638193573888', '889531632641822720', '889531804818059264', '889542912157077504', '889557710219431940', '889569562320138241', '889569722232123392', '889570348034904065', '889570862365630465', '889572339616186371'], ['857454518342213632', '857494792439291904', '857496537626181632', '857643764478861313']]





In [6]:
# Run the MIL model over every evaluation bag and write one record per bag.
# The file holds JSON Lines (one JSON object per line) despite its ".json" suffix.
# NOTE(review): the saved output of this cell stops inside a leftover
# `pdb.set_trace()` in mil_model.py `forward`; remove it there before a full run.
output_file = f"{results_dir}/eval_predictions.json"
decision_threshold = 0.5  # a bag is predicted positive when its score exceeds this

# Opening the output file does not create missing parent directories.
os.makedirs(results_dir, exist_ok=True)

model = model.eval()
with jsonlines.open(output_file, "w") as writer:
    with torch.inference_mode():
        for batch in tqdm(dataloader, ncols=0):
            output = model(**batch)

            # One row per bag: zip walks the per-bag entries of each field in lockstep.
            for filename, label, ids, text, prob, key_ids in zip(
                batch["bag_id"],
                batch["labels"],
                batch["instance_ids"],
                batch["instance_text"],
                output.logits,
                output.key_instances,
            ):
                # NOTE(review): `logits` is reported and thresholded as a probability;
                # confirm MILModel applies a sigmoid before returning it.
                score = prob.item()

                # Map instance id -> text once per bag. setdefault keeps the first
                # occurrence, matching the previous `ids.index(t)` lookup.
                id_to_text = {}
                for tweet_id, tweet_text in zip(ids, text):
                    id_to_text.setdefault(tweet_id, tweet_text)

                row = {
                    "filename": filename,
                    "probability": score,
                    "prediction": 1 if score > decision_threshold else 0,
                    "label": int(label.item()),
                    "key_tweet_ids": key_ids,
                    "key_tweets": [id_to_text[t] for t in key_ids],
                }
                writer.write(row)


  0% 0/7240 [00:00<?, ?it/s]

> [0;32m/home/aadelucia/MIL-civil-unrest/code/mil_model.py[0m(141)[0;36mforward[0;34m()[0m
[0;32m    139 [0;31m        [0mkey_instances[0m [0;34m=[0m [0;34m[[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    140 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m;[0m[0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 141 [0;31m        [0;32mfor[0m [0mb_idx[0m[0;34m,[0m [0mb_id[0m [0;32min[0m [0mzip[0m[0;34m([0m[0mkey_instance_idx[0m[0;34m,[0m [0mkwargs[0m[0;34m[[0m[0;34m"instance_ids"[0m[0;34m][0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    142 [0;31m            [0mkey_instances[0m[0;34m.[0m[0mappend[0m[0;34m([0m[0;34m[[0m[0mb_id[0m[0;34m[[0m[0mi[0m[0;34m][0m [0;32mfor[0m [0mi[0m [0;32min[0m [0mb_idx[0m[0;34m][0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    143 [0;31m        [0;31m# Calculate loss[0m[0;34m[0m[0;34m[0m[0;3

  0% 1/7240 [00:02<4:07:37,  2.05s/it]

> [0;32m/home/aadelucia/MIL-civil-unrest/code/mil_model.py[0m(141)[0;36mforward[0;34m()[0m
[0;32m    139 [0;31m        [0mkey_instances[0m [0;34m=[0m [0;34m[[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    140 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m;[0m[0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 141 [0;31m        [0;32mfor[0m [0mb_idx[0m[0;34m,[0m [0mb_id[0m [0;32min[0m [0mzip[0m[0;34m([0m[0mkey_instance_idx[0m[0;34m,[0m [0mkwargs[0m[0;34m[[0m[0;34m"instance_ids"[0m[0;34m][0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    142 [0;31m            [0mkey_instances[0m[0;34m.[0m[0mappend[0m[0;34m([0m[0;34m[[0m[0mb_id[0m[0;34m[[0m[0mi[0m[0;34m][0m [0;32mfor[0m [0mi[0m [0;32min[0m [0mb_idx[0m[0;34m][0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    143 [0;31m        [0;31m# Calculate loss[0m[0;34m[0m[0;34m[0m[0;3

  0% 2/7240 [00:05<5:49:08,  2.89s/it]

> [0;32m/home/aadelucia/MIL-civil-unrest/code/mil_model.py[0m(141)[0;36mforward[0;34m()[0m
[0;32m    139 [0;31m        [0mkey_instances[0m [0;34m=[0m [0;34m[[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    140 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m;[0m[0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 141 [0;31m        [0;32mfor[0m [0mb_idx[0m[0;34m,[0m [0mb_id[0m [0;32min[0m [0mzip[0m[0;34m([0m[0mkey_instance_idx[0m[0;34m,[0m [0mkwargs[0m[0;34m[[0m[0;34m"instance_ids"[0m[0;34m][0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    142 [0;31m            [0mkey_instances[0m[0;34m.[0m[0mappend[0m[0;34m([0m[0;34m[[0m[0mb_id[0m[0;34m[[0m[0mi[0m[0;34m][0m [0;32mfor[0m [0mi[0m [0;32min[0m [0mb_idx[0m[0;34m][0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    143 [0;31m        [0;31m# Calculate loss[0m[0;34m[0m[0;34m[0m[0;3

  0% 2/7240 [00:35<35:55:08, 17.87s/it]


BdbQuit: 