In [1]:
kwargs = {
    "seed": 42,
    "data_dir": "data/",
    "train_dir": "outputs/multi_task_model",
    "model_file": "outputs/pytorch_model.bin",
    "model_id": "1U6Ek3c75RjxypFAj7_B-yfQ9NyDNk-eS",
    "num_past_utterances": 10,
    "num_future_utterances": 0,
    "speaker_in_context": False,
    "epoch": 6,
    "learning_rate": 1e-5,
    "batch_size": 8,
    "do_train": True,
    "checkpoint": "markussagen/xlm-roberta-longformer-base-4096",
    "train_dataset": ["MELD"],
    "train_task": ["Emotion"],
    "eval_dataset": "MELD",
    "eval_task": "Emotion",
    'output_file': 'outputs/predictions.out',
    "result_file": "results/scores.out",
}

###Initialize

In [2]:
from google.colab import drive
if kwargs["do_train"]:
  drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [3]:
!mkdir -p data/MELD
!cd data/MELD && wget https://raw.githubusercontent.com/UW-ling573-2022/data/main/MELD/dev_sent_emo.csv
!cd data/MELD && wget https://raw.githubusercontent.com/UW-ling573-2022/data/main/MELD/test_sent_emo.csv
!cd data/MELD && wget https://raw.githubusercontent.com/UW-ling573-2022/data/main/MELD/train_sent_emo.csv
!mkdir -p data/EMORYNLP
!cd data/EMORYNLP && wget https://raw.githubusercontent.com/UW-ling573-2022/data/main/EMORYNLP/dev.csv
!cd data/EMORYNLP && wget https://raw.githubusercontent.com/UW-ling573-2022/data/main/EMORYNLP/train.csv
!cd data/EMORYNLP && wget https://raw.githubusercontent.com/UW-ling573-2022/data/main/EMORYNLP/test.csv
!mkdir -p data/MPDD
!cd data/MPDD && wget https://raw.githubusercontent.com/UW-ling573-2022/data/main/MPDD/dev_mpdd.csv
!cd data/MPDD && wget https://raw.githubusercontent.com/UW-ling573-2022/data/main/MPDD/train_mpdd.csv
!cd data/MPDD && wget https://raw.githubusercontent.com/UW-ling573-2022/data/main/MPDD/test_mpdd.csv

--2022-05-29 08:19:14--  https://raw.githubusercontent.com/UW-ling573-2022/data/main/MELD/dev_sent_emo.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.108.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 120527 (118K) [text/plain]
Saving to: ‘dev_sent_emo.csv.8’


2022-05-29 08:19:14 (8.53 MB/s) - ‘dev_sent_emo.csv.8’ saved [120527/120527]

--2022-05-29 08:19:14--  https://raw.githubusercontent.com/UW-ling573-2022/data/main/MELD/test_sent_emo.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 291912 (285K) [text/plain]
Saving to: ‘test_sent_emo.csv.8’


2022-05-29 08:19:14 (15.

In [4]:
!mkdir -p outputs
import gdown
if kwargs["model_id"] and not kwargs["do_train"]:
  gdown.download(id=kwargs["model_id"], output=kwargs["model_file"], quiet=False)

In [5]:
!pip install datasets
!pip install transformers
!pip install sentencepiece

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [6]:
import datasets
import numpy as np
import torch
import torch.nn as nn
from datasets import ClassLabel, load_metric, Dataset
from datasets import load_dataset, concatenate_datasets
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments
from transformers import PreTrainedModel, PretrainedConfig
from transformers import Trainer
from transformers import is_datasets_available
from transformers.trainer_pt_utils import IterableDatasetShard

###MTL/data.py

In [7]:
class SingleTaskDataLoader:
    def __init__(self, task, **kwargs):
        self.task = task
        self.data_loader = DataLoader(**kwargs)
        self.batch_size = self.data_loader.batch_size
        self.dataset = self.data_loader.dataset
        
    def __len__(self) -> int:
        return len(self.data_loader)
    
    def __iter__(self):
        for batch in self.data_loader:
            batch["task"] = self.task
            yield batch
    
class MultiTaskDataLoader:
    def __init__(self, task_data_loaders):
        self.task_data_loaders = task_data_loaders
        self.dataset = [None] * sum([len(dl.dataset) for dl in task_data_loaders.values()])
        
    def __len__(self) -> int:
        return sum([len(dl) for dl in self.task_data_loaders.values()])
    
    def __iter__(self):
        task_choices = []
        for task, dl in self.task_data_loaders.items():
            task_choices.extend([task] * len(dl))
        task_choices = np.array(task_choices)
        np.random.shuffle(task_choices)
        for task in task_choices:
            yield next(iter(self.task_data_loaders[task]))

###MTL/model.py

In [8]:
class MultiTaskModel(PreTrainedModel):
    def __init__(self, encoder, task_models):
        super(MultiTaskModel, self).__init__(PretrainedConfig())
        self.encoder = encoder
        self.task_models = nn.ModuleDict(task_models)
        
    @classmethod
    def from_task_models(cls, task_models):
        shared_encoder = None
        for model in task_models.values():
            if shared_encoder is None:
                shared_encoder = getattr(model, cls.get_encoder_attr_name(model))
            else:
                setattr(model, cls.get_encoder_attr_name(model), shared_encoder)
        return cls(shared_encoder, task_models)
                  
    @staticmethod
    def get_encoder_attr_name(model):
        model_name = model.__class__.__name__
        if model_name.startswith('Bert'):
            return 'bert'
        elif model_name.startswith('Roberta') or model_name.startswith("XLMRoberta"):
            return 'roberta'
        elif model_name.startswith('Albert'):
            return 'albert'
        else:
            raise ValueError('Unsupported model: {}'.format(model_name))
        
    def forward(self, task, input_ids, attention_mask, **kwargs):
        model = self.task_models[task]
        return model(input_ids, attention_mask, **kwargs)

###MLT/train.py

In [9]:
class MultiTaskTrainer(Trainer):

    def get_single_task_dataloader(self, task, dataset, description):
        if description == "training" and self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        elif description == "evaluation" and dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")

        if is_datasets_available() and isinstance(dataset, Dataset):
            dataset = self._remove_unused_columns(dataset, description=description)

        if isinstance(dataset, torch.utils.data.IterableDataset):
            if self.args.world_size > 1:
                dataset = IterableDatasetShard(
                    dataset,
                    batch_size=self.args.train_batch_size,
                    drop_last=self.args.dataloader_drop_last,
                    num_processes=self.args.world_size,
                    process_index=self.args.process_index,
                )

            return SingleTaskDataLoader(
                task,
                dataset=dataset,
                batch_size=self.args.per_device_train_batch_size,
                collate_fn=self.data_collator,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.dataloader_pin_memory,
            )

        if description == "training":
            self.train_dataset, dataset = dataset, self.train_dataset
            sampler = self._get_train_sampler()
            self.train_dataset, dataset = dataset, self.train_dataset
            batch_size = self.args.train_batch_size
        else:
            sampler = self._get_eval_sampler(dataset)
            batch_size = self.args.eval_batch_size

        return SingleTaskDataLoader(
            task,
            dataset=dataset,
            batch_size=batch_size,
            sampler=sampler,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )

    def get_train_dataloader(self):
        return MultiTaskDataLoader({
            task: self.get_single_task_dataloader(task, dataset, description="training")
            for task, dataset in self.train_dataset.items()
        })

    def get_eval_dataloader(self, eval_dataset=None):
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        task_to_eval = eval_dataset["task"]
        return self.get_single_task_dataloader(task_to_eval, eval_dataset[task_to_eval], description="evaluation")

    def get_test_dataloader(self, test_dataset):
        task_to_test = test_dataset["task"]
        return self.get_single_task_dataloader(task_to_test, test_dataset[task_to_test], description="test")


###preprocess.py

In [10]:
def preprocess(tokenizer, dataset_labels, **kwargs):

    meld_files = {
        "train": kwargs["data_dir"] + "MELD/train_sent_emo.csv", 
        "validation": kwargs["data_dir"] + "MELD/dev_sent_emo.csv",
        "test": kwargs["data_dir"] + "MELD/test_sent_emo.csv"
    }
    
    emorynlp_files = {
        "train": kwargs["data_dir"] + "EMORYNLP/train.csv",
        "validation": kwargs["data_dir"] + "EMORYNLP/dev.csv",
        "test": kwargs["data_dir"] + "EMORYNLP/test.csv"
    }
    
    mpdd_files = {
        "train": kwargs["data_dir"] + "MPDD/train_mpdd.csv", 
        "validation": kwargs["data_dir"] + "MPDD/dev_mpdd.csv",
        "test": kwargs["data_dir"] + "MPDD/test_mpdd.csv"
    }
    
    datasets = {"MELD": load_dataset("csv", data_files=meld_files),
                "EmoryNLP": load_dataset("csv", data_files=emorynlp_files),
                "MPDD": load_dataset("csv", data_files=mpdd_files)}
    
    def encode_label(example, labels):
        for task, label in labels.items():
            if task == "Speaker":
                example[task] = label.str2int(example[task]) \
                    if example[task] in label.names else label.str2int("Others")
            else:
                example[task] = label.str2int(example[task])
        return example

    for name, dataset in datasets.items():
        datasets[name] = dataset.map(lambda e: encode_label(e, dataset_labels[name]))

    def add_context(example, idx, dataset, labels):
        example["Past"] = ""
        example["Future"] = ""

        if example["Utterance_ID"] != 0:
            i = 1
            while idx - i >= 0:
                past = dataset[idx - i]
                past_utterance = past["Utterance"]
                if "Speaker" in labels and kwargs["speaker_in_context"]:
                    past_speaker = labels["Speaker"].int2str(past["Speaker"])
                    example["Past"] = past_speaker + ":" + past_utterance + " " + example["Past"]
                else:
                    example["Past"] = past_utterance + " " + example["Past"]
                if past["Utterance_ID"] == 0 or i >= kwargs["num_past_utterances"]:
                    break
                i += 1

        if idx + 1 < len(dataset) and dataset[idx + 1]["Utterance_ID"] != 0:
            i = 1
            while idx + i < len(dataset):
                future = dataset[idx + i]
                future_utterance = future["Utterance"]
                if "Speaker" in labels and kwargs["speaker_in_context"]:
                    future_speaker = labels["Speaker"].int2str(future["Speaker"])
                    example["Future"] += " " + future_speaker + ":" + future_utterance
                else:
                    example["Future"] += " " + future_utterance
                i += 1
                if idx + i < len(dataset) and dataset[idx + i]["Utterance_ID"] == 0 \
                    or i >= kwargs["num_future_utterances"]:
                    break

        return example

    for name, dataset in datasets.items():
        for split, ds in dataset.items():
            dataset[split] = ds.map(lambda e, i: add_context(e, i, ds, dataset_labels[name]), with_indices=True)

    def tokenize(example, add_past, add_future):
        if add_past:
            return tokenizer(example["Past"], example["Utterance"])
        elif add_future:
            return tokenizer(example["Utterance"], example["Future"])
        else:
            return tokenizer(example["Utterance"])
        
    for name, dataset in datasets.items():
        cx_datasets = {}
        cx_datasets["with_past"] = dataset.map(
            lambda e: tokenize(e, add_past=True, add_future=False), batched=True)
        cx_datasets["with_future"] = dataset.map(
            lambda e: tokenize(e, add_past=False, add_future=True), batched=True)
        cx_datasets["no_context"] = dataset.map(
            lambda e: tokenize(e, add_past=False, add_future=False), batched=True)

        tasks = list(dataset_labels[name].keys())
        for cx in cx_datasets:
            cols_to_keep = ["input_ids", "attention_mask"] + tasks
            cols_to_remove = [c for c in cx_datasets[cx]["train"].column_names if c not in cols_to_keep]
            cx_datasets[cx] = cx_datasets[cx].remove_columns(cols_to_remove)
            task_datasets = {}
            for task in tasks:
                label = dataset_labels[name][task]
                ds = cx_datasets[cx]
                ds = ds.cast_column(task, label)
                ds = ds.remove_columns([t for t in tasks if t != task])
                ds = ds.rename_column(task, "labels")
                ds.set_format()
                task_datasets[task] = (ds, label)
            cx_datasets[cx] = task_datasets
        datasets[name] = cx_datasets

    return datasets


###pipeline.py

In [11]:
def prepare_datasets(datasets, **kwargs):
    for dataset_name, cx_datasets in datasets.items():
        task_dataset = {}
        for split in ["train", "validation", "test"]:
            task_dataset[split] = {}
            for cx in cx_datasets:
                if cx == "with_past" and kwargs["num_past_utterances"] == 0:
                    continue
                elif cx == "with_future" and kwargs["num_future_utterances"] == 0:
                    continue
                elif cx == "no_context" and kwargs["num_past_utterances"] + kwargs["num_future_utterances"] > 0:
                    continue
                else:
                    for task, (ds, _) in cx_datasets[cx].items():
                        if split == "train" and task not in kwargs["train_task"]:
                            continue
                        if task not in task_dataset[split]:
                            task_dataset[split][task] = ds[split]
                        else:
                            ds_to_concat = [task_dataset[split][task], ds[split]]
                            task_dataset[split][task] = concatenate_datasets(ds_to_concat)

        train_dataset = task_dataset["train"]
        eval_dataset = task_dataset["validation"]
        test_dataset = task_dataset["test"]
        
        datasets[dataset_name] = {"train": train_dataset, "validation": eval_dataset, "test": test_dataset}
        
    train_dataset = {dataset_name + "_" + task: datasets[dataset_name]["train"][task] 
                     for dataset_name in datasets if dataset_name in kwargs["train_dataset"]
                     for task in datasets[dataset_name]["train"] if task in kwargs["train_task"]}

    eval_dataset_task = kwargs["eval_dataset"] + "_" + kwargs["eval_task"]
    eval_dataset = {eval_dataset_task: datasets[kwargs["eval_dataset"]]["validation"][kwargs["eval_task"]]}
    eval_dataset["task"] = eval_dataset_task
    test_dataset = {eval_dataset_task: datasets[kwargs["eval_dataset"]]["test"][kwargs["eval_task"]]}
    test_dataset["task"] = eval_dataset_task

    return train_dataset, eval_dataset, test_dataset


In [12]:
dataset_labels = {
    "MELD": 
    {
        "Speaker": ClassLabel(
            num_classes=7,
            names=["Chandler", "Joey", "Monica", "Rachel", "Ross", "Phoebe", "Others"]),
        "Emotion": ClassLabel(
            num_classes=7,
            names=["anger", "disgust", "fear", "joy", "neutral", "sadness", "surprise"]),
        "Sentiment": ClassLabel(
            num_classes=3,
            names=["positive", "neutral", "negative"])
    },
    "EmoryNLP": 
    {   
        "Speaker": ClassLabel(
            num_classes=7,
            names=["Chandler", "Joey", "Monica", "Rachel", "Ross", "Phoebe", "Others"]),
        "Emotion": ClassLabel(
            num_classes=7,
            names=["Sad", "Mad", "Scared", "Powerful", "Peaceful", "Joyful", "Neutral"])
    },
    "MPDD": 
    {
        "Emotion": ClassLabel(
            num_classes=7,
            names=["angry", "disgust", "fear", "joy", "neutral", "sadness", "surprise"])
    },
}
    
tokenizer = AutoTokenizer.from_pretrained(kwargs["checkpoint"])
datasets = preprocess(tokenizer, dataset_labels, **kwargs)
train_dataset, eval_dataset, test_dataset = prepare_datasets(datasets, **kwargs)

Using custom data configuration default-0460ef81eef09c38
Reusing dataset csv (/root/.cache/huggingface/datasets/csv/default-0460ef81eef09c38/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519)


  0%|          | 0/3 [00:00<?, ?it/s]

Using custom data configuration default-200f268ae07e8030
Reusing dataset csv (/root/.cache/huggingface/datasets/csv/default-200f268ae07e8030/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519)


  0%|          | 0/3 [00:00<?, ?it/s]

Using custom data configuration default-2bc3f446dfa1ff03
Reusing dataset csv (/root/.cache/huggingface/datasets/csv/default-2bc3f446dfa1ff03/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /root/.cache/huggingface/datasets/csv/default-0460ef81eef09c38/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519/cache-02c5f2c5b05f1aa4.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/csv/default-0460ef81eef09c38/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519/cache-2681517591946c2a.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/csv/default-0460ef81eef09c38/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519/cache-abcaaa7183473aaa.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/csv/default-200f268ae07e8030/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519/cache-a6b81c6507fda93a.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/csv/default-200f268ae07e8030/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519/cache-b3ba612d6a80ae5e.arrow
Loadi

In [13]:
tasks = {task: dataset_labels[task.split("_")[0]][task.split("_")[1]] for task in train_dataset.keys()}
task_models = {
        task: AutoModelForSequenceClassification.from_pretrained(
            kwargs["checkpoint"], 
            num_labels=label.num_classes,
            max_length=1024)
        for task, label in tasks.items()
}
multi_task_model = MultiTaskModel.from_task_models(task_models)

Some weights of the model checkpoint at markussagen/xlm-roberta-longformer-base-4096 were not used when initializing XLMRobertaForSequenceClassification: ['roberta.encoder.layer.0.attention.self.key_global.bias', 'roberta.encoder.layer.3.attention.self.key_global.weight', 'roberta.encoder.layer.0.attention.self.query_global.weight', 'roberta.encoder.layer.5.attention.self.query_global.bias', 'roberta.encoder.layer.10.attention.self.value_global.bias', 'roberta.encoder.layer.8.attention.self.query_global.bias', 'roberta.encoder.layer.5.attention.self.key_global.bias', 'roberta.encoder.layer.11.attention.self.key_global.bias', 'roberta.encoder.layer.3.attention.self.value_global.weight', 'lm_head.layer_norm.bias', 'roberta.encoder.layer.4.attention.self.key_global.weight', 'roberta.encoder.layer.10.attention.self.key_global.bias', 'lm_head.dense.weight', 'roberta.encoder.layer.10.attention.self.value_global.weight', 'roberta.encoder.layer.11.attention.self.query_global.bias', 'lm_head.bi

In [14]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


In [15]:
if not kwargs["do_train"]:
    multi_task_model.load_state_dict(
        torch.load(kwargs["model_file"], map_location=torch.device(device)))

In [16]:
multi_task_model.to(device)

MultiTaskModel(
  (encoder): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(250002, 768, padding_idx=1)
      (position_embeddings): Embedding(4098, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0): RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm(

In [17]:
def compute_metrics(eval_preds):
    metric = load_metric("f1")
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels, average="weighted")


training_args = TrainingArguments(
    output_dir=kwargs["train_dir"],
    seed=kwargs["seed"],
    overwrite_output_dir=True,
    label_names=["labels"],
    learning_rate=kwargs["learning_rate"],
    num_train_epochs=kwargs["epoch"],
    per_device_train_batch_size=kwargs["batch_size"],
    logging_strategy="epoch",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    remove_unused_columns=False,
    load_best_model_at_end=True,
    metric_for_best_model="eval_f1"
)

trainer = MultiTaskTrainer(
    multi_task_model,
    training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [18]:
if kwargs["do_train"]:
    trainer.train()

***** Running training *****
  Num examples = 9989
  Num Epochs = 6
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 7494


Epoch,Training Loss,Validation Loss,F1
1,1.4092,1.304453,0.520504


***** Running Evaluation *****
  Num examples = 1109
  Batch size = 8


Downloading builder script:   0%|          | 0.00/2.32k [00:00<?, ?B/s]

Saving model checkpoint to outputs/multi_task_model/checkpoint-1249
Configuration saved in outputs/multi_task_model/checkpoint-1249/config.json
Model weights saved in outputs/multi_task_model/checkpoint-1249/pytorch_model.bin
tokenizer config file saved in outputs/multi_task_model/checkpoint-1249/tokenizer_config.json
Special tokens file saved in outputs/multi_task_model/checkpoint-1249/special_tokens_map.json


KeyboardInterrupt: ignored

In [None]:
pred = trainer.predict(test_dataset)
f1 = pred.metrics['test_f1']
print("Weighted F1:", f1)

In [None]:
pred_labels = dataset_labels[kwargs["eval_dataset"]][kwargs["eval_task"]].int2str(pred.predictions.argmax(axis=-1))
true_labels = dataset_labels[kwargs["eval_dataset"]][kwargs["eval_task"]].int2str(pred.label_ids)
inputs = tokenizer.batch_decode(test_dataset[kwargs["eval_task"]]["input_ids"])
f = open(kwargs["output_file"], "w")
f.write("Input\tPredicted\tTrue\n")
f.write("\n".join(["\t".join([input, pred_label, true_label]) 
                    for input, pred_label, true_label 
                    in zip(inputs, pred_labels, true_labels)]))
f.close()

In [None]:
!mkdir -p results
import json

f = open(kwargs["result_file"], "a+")
f.write(json.dumps(kwargs))
f.write("\nWeighted F1: {}\n".format(f1))
f.close()

In [None]:
from google.colab import files
files.download(kwargs["result_file"])
files.download(kwargs["output_file"])

In [None]:
trainer.save_model()

In [None]:
!cp "/content/outputs/multi_task_model/pytorch_model.bin" "/content/gdrive/MyDrive/pytorch_model.bin"

In [None]:
def get_random_sample(ds_test, tokenizer, idx=None, max_tokens=512):
    while True:
        if idx is None:
            idx_ = np.random.randint(0, len(ds_test))
        else:
            idx_ = idx
        random_sample = ds_test[idx_]
        input_ids, attention_mask, labelid = (
            random_sample["input_ids"],
            random_sample["attention_mask"],
            random_sample["labels"],
        )
        break

    decoded = tokenizer.decode(input_ids)

    input_ids = torch.tensor(input_ids).view(-1, len(input_ids))
    attention_mask = torch.tensor(attention_mask).view(-1, len(attention_mask))
    labelid = torch.tensor(labelid).view(-1, 1)

    return idx_, input_ids, attention_mask, labelid, decoded

def return_coeffs(
    tokenizer,
    input_ids,
    attentions,
    BATCH_IDX=0,
    LAYER=-1,
    QUERY_TOKEN_IDX=0,
    annoying_char="Ġ",
):
    tokens = tokenizer.convert_ids_to_tokens(input_ids[BATCH_IDX].tolist())
    QUERY_TOKEN = tokens[QUERY_TOKEN_IDX].split(annoying_char)[-1]

    coeffs = (
        attentions[LAYER][BATCH_IDX].cpu().detach().numpy().sum(axis=0)[QUERY_TOKEN_IDX]
    )
    coeffs /= coeffs.sum()

    idx_token_coeffs = [
        (idx, token.split(annoying_char)[-1], coeffs[idx])
        for idx, token in enumerate(tokens)
    ]

    assert len(coeffs) == len(tokens) == len(idx_token_coeffs)

    return QUERY_TOKEN, coeffs, tokens, idx_token_coeffs

In [None]:
import pprint
ds_test = test_dataset[kwargs["evaluation"]]
idx, input_ids, attention_mask, labelid, decoded = get_random_sample(
  ds_test, tokenizer
)

_, _, _, labelid_speaker, _ = get_random_sample(
  test_dataset["Speaker"], tokenizer, idx=idx
)

pprint.pprint(f"{decoded}")
print()

outputs = multi_task_model(
  **{"input_ids": input_ids.to(device), "attention_mask": attention_mask.to(device)},
  labels=labelid.to(device),
  output_attentions=True,
  output_hidden_states=True,
  task="Emotion"
)

#outputs_speaker = multi_task_model(
#  **{"input_ids": input_ids.to(device), "attention_mask": attention_mask.to(device)},
#  labels=labelid.to(device),
#  output_attentions=True,
#  output_hidden_states=True,
#  task="Speaker"
#)

attentions = outputs.attentions
pred = labels["MELD"]["Emotion"].int2str(int(outputs.logits.argmax().cpu().numpy()))
truth = labels["MELD"]["Emotion"].int2str(int(labelid[0][0].numpy()))

#pred_speaker = labels["MELD"]["Speaker"].int2str(int(outputs_speaker.logits.argmax().cpu().numpy()))
#truth_speaker = labels["MELD"]["Speaker"].int2str(int(labelid_speaker[0][0].numpy()))

pprint.pprint(f"data_idx: {idx}")
pprint.pprint(f"pred: {pred}")
pprint.pprint(f"truth: {truth}")
#pprint.pprint(f"pred: {pred_speaker}")
#pprint.pprint(f"truth: {truth_speaker}")
pprint.pprint(f"number of tokens in the input: {input_ids.shape[1]}")
print()

QUERY_TOKEN, coeffs, tokens, idx_token_coeffs = return_coeffs(
    tokenizer, input_ids, attentions, LAYER=-1, QUERY_TOKEN_IDX=0
)

top_10 = sorted(idx_token_coeffs, key=lambda x: -x[2])[:10]
print(top_10)
print()