In [None]:
# E-BART System Definition
# Author: Erik Brand, UQ
# Last Updated: 3/12/2021

README:

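To Train: Run sections 1, 2 and 3, then the model-loading cell at the start of section 5, then section 4 and the rest of section 5 (the data collator in section 4 needs `model` to exist)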

To Run Inference: Run sections 1, 2, 3, 4, 6, 7 

To Run Inference With Temperature Scaling: Run sections 1, 2, 3, 4, 6, 8

# 1. Installs

In [None]:
# Install Hugging Face transformers from source at the v4.10.3 tag.
# NOTE: %cd leaves the kernel's working directory inside the cloned repository for later cells.
!git clone https://github.com/huggingface/transformers
%cd transformers
!git checkout v4.10.3
!pip install .

In [None]:
# Pinned release of `datasets`; section 3 imports Dataset and load_metric from it.
!pip install datasets==1.9.0

In [None]:
# Presumably the backend required by load_metric("rouge") in section 4 — verify.
# NOTE(review): unlike the other installs, this version is not pinned.
!pip install rouge_score

# 2. E-BART Definition

In [None]:
from transformers import BartTokenizer, BartModel, BartPretrainedModel, BartConfig, Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers.models.bart.modeling_bart import BartClassificationHead, shift_tokens_right
from transformers.modeling_outputs import Seq2SeqLMOutput
from transformers.generation_logits_process import LogitsProcessorList, MinLengthLogitsProcessor
from transformers.generation_utils import GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput
from transformers.file_utils import ModelOutput

from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss

import numpy as np

ModuleNotFoundError: ignored

In [None]:
class Seq2SeqJointOutput(ModelOutput):
  """
  Output container built by ``BartForJointPrediction.forward`` when ``return_dict`` is true.

  It carries the summarisation logits and the classification logits side by side, together
  with the decoder cache and the optional attention / hidden-state outputs of the BART model.
  """
  # NOTE(review): the ModelOutput subclasses shipped with transformers are declared with
  # @dataclass; this class is not decorated — confirm that this is intended.
  # Logits of the classification head; left as None when forward() does not compute them.
  classification_logits: torch.FloatTensor = None
  # Summarisation loss, classification loss, or their equally weighted sum (None without labels).
  loss: Optional[torch.FloatTensor] = None
  # Language-modelling (summarisation) logits from lm_head.
  logits: torch.FloatTensor = None
  # The remaining fields are copied from the wrapped BartModel's output.
  past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
  decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
  decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
  cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
  encoder_last_hidden_state: Optional[torch.FloatTensor] = None
  encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
  encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None

In [None]:
class BartForJointPrediction(BartPretrainedModel):
    base_model_prefix = "model"
    _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head\.weight"]

    def __init__(self, config: BartConfig):
        """
        Build the shared BART backbone plus the two output heads.

        Args:
            config: BART configuration; ``d_model``, ``num_labels`` and ``classifier_dropout``
                size the language-modelling and classification heads.
        """
        super().__init__(config)
        self.model = BartModel(config)

        vocab_size = self.model.shared.num_embeddings
        hidden_size = config.d_model

        # Summarisation head: bias-free projection to the vocabulary plus a separate bias buffer.
        self.register_buffer("final_logits_bias", torch.zeros((1, vocab_size)))
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

        # Classification head; num_labels defaults to 3 in the BART config (per the original author).
        self.classification_head = BartClassificationHead(
            hidden_size,
            hidden_size,
            config.num_labels,
            config.classifier_dropout,
        )

        self.init_weights()

    def get_encoder(self):
        """Return the encoder of the wrapped ``BartModel``."""
        return self.model.get_encoder()

    def get_decoder(self):
        """Return the decoder of the wrapped ``BartModel``."""
        return self.model.get_decoder()

    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
        """
        Resize the token embeddings through the parent implementation, then resize
        ``final_logits_bias`` to the same vocabulary size.

        Returns:
            The embedding module returned by the parent ``resize_token_embeddings``.
        """
        new_embeddings = super().resize_token_embeddings(new_num_tokens)
        self._resize_final_logits_bias(new_num_tokens)
        return new_embeddings

    def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
        """
        Truncate or zero-pad the ``final_logits_bias`` buffer to ``new_num_tokens`` columns.

        The resized tensor is re-registered under the same buffer name.
        """
        current_bias = self.final_logits_bias
        num_missing = new_num_tokens - current_bias.shape[-1]
        if num_missing > 0:
            # Vocabulary grew: append zeros for the new tokens on the same device.
            padding = torch.zeros((1, num_missing), device=current_bias.device)
            resized_bias = torch.cat([current_bias, padding], dim=1)
        else:
            # Vocabulary shrank or is unchanged: keep the leading columns only.
            resized_bias = current_bias[:, :new_num_tokens]
        self.register_buffer("final_logits_bias", resized_bias)

    def get_output_embeddings(self):
        """Return the language-modelling head (``lm_head``)."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the language-modelling head (``lm_head``) with ``new_embeddings``."""
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        classification_labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """
        Run the shared BART encoder-decoder once and apply both output heads.

        Args:
            labels: Target summary token ids. When given and ``decoder_input_ids`` is not,
                the decoder inputs are derived with ``shift_tokens_right``. Positions set to
                -100 are skipped by the summarisation loss (the default ``ignore_index`` of
                ``CrossEntropyLoss``).
            classification_labels: Targets for the classification head. Treated as regression
                targets when ``config.num_labels == 1`` and as class indices otherwise.
            return_dict: Falls back to ``config.use_return_dict`` when None.
            All other arguments are forwarded unchanged to the wrapped ``BartModel``.

        Returns:
            ``Seq2SeqJointOutput`` when ``return_dict`` is true; otherwise a tuple
            ``(loss?, lm_logits, *model_outputs[1:], classification_logits)`` where the loss is
            only present when it was computed. ``classification_logits`` is None whenever
            ``input_ids`` is None (the autoregressive generation steps).

        Raises:
            ValueError: If the examples in the batch do not all contain the same number of
                ``<eos>`` tokens, or if ``classification_labels`` is given while no
                classification logits could be computed.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Teacher forcing: derive the decoder inputs from the labels unless the caller supplied them.
        if labels is not None and decoder_input_ids is None:
            decoder_input_ids = shift_tokens_right(
                labels, self.config.pad_token_id, self.config.decoder_start_token_id
            )

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]  # last decoder hidden state

        # Summarization Logits
        lm_logits = self.lm_head(hidden_states) + self.final_logits_bias

        # Classification Logits
        classification_logits = None
        if input_ids is not None:   # In all cases except generating autoregressively - we don't want to produce classification logits
            # Can't just use decoder_input_ids all the time as these are already shifted right - might not have ending eos_token_id
            if labels is not None:
                # Training: the hidden state aligned with the label's <eos> is the one that predicts it.
                eos_mask = labels.eq(self.config.eos_token_id)
            else:
                # Inference: mirror the training setup. There the labels are shifted right before
                # entering the decoder, so <eos> is predicted from the token immediately before it.
                # Here the finished summary is passed in directly, so the <eos> mask is shifted
                # left by one position to select that same preceding token.
                eos_mask = decoder_input_ids.eq(self.config.eos_token_id)
                shifted_mask = eos_mask.new_zeros(eos_mask.shape)
                shifted_mask[:, :-1] = eos_mask[:, 1:]
                eos_mask = shifted_mask

            if len(torch.unique(eos_mask.sum(1))) > 1:
                raise ValueError("All examples must have the same number of <eos> tokens.")
            # Keep the hidden state at the last selected position of every example.
            sentence_representation = hidden_states[eos_mask, :].view(
                hidden_states.size(0), -1, hidden_states.size(-1)
            )[:, -1, :]
            classification_logits = self.classification_head(sentence_representation)

        # Summarization Loss
        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

        # Classification Loss
        classification_loss = None
        if classification_labels is not None:
            if classification_logits is None:
                # Previously this surfaced as an AttributeError on None.view(...).
                raise ValueError(
                    "classification_labels were provided, but classification logits are only "
                    "computed when input_ids is given."
                )
            if self.config.num_labels == 1:
                # regression
                loss_fct = MSELoss()
                classification_loss = loss_fct(classification_logits.view(-1), classification_labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                classification_loss = loss_fct(
                    classification_logits.view(-1, self.config.num_labels), classification_labels.view(-1)
                )

        loss = None
        if (labels is not None) and (classification_labels is not None):
            # Joint loss: equal weighting of both tasks
            loss = 0.5 * masked_lm_loss + 0.5 * classification_loss
        elif labels is not None:
            # Only summarisation loss
            loss = masked_lm_loss
        elif classification_labels is not None:
            # Only classification loss
            loss = classification_loss

        if not return_dict:
            # The trailing commas matter: "(x)" is just x, whereas "(x,)" is a 1-tuple that can be
            # concatenated with the tuple returned by the wrapped model.
            output = (lm_logits,) + outputs[1:] + (classification_logits,)
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqJointOutput(
            loss=loss,
            logits=lm_logits,
            classification_logits=classification_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )


    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs
    ):
        """
        Assemble the keyword arguments for one decoding step of ``forward``.

        ``input_ids`` is always passed as None because ``encoder_outputs`` is supplied instead;
        in ``forward`` this also means no classification logits are computed for the step.
        """
        # With a cache available, only the most recent decoder token has to be fed in.
        step_decoder_ids = decoder_input_ids if past is None else decoder_input_ids[:, -1:]

        return dict(
            input_ids=None,  # encoder_outputs is defined. input_ids not needed
            encoder_outputs=encoder_outputs,
            past_key_values=past,
            decoder_input_ids=step_decoder_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,  # change this to avoid caching (presumably for debugging)
        )

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        """Build decoder inputs from ``labels`` with ``shift_tokens_right``, using the config's pad and decoder-start token ids."""
        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

    @staticmethod
    def _reorder_cache(past, beam_idx):
        """
        Reorder the cached states of every layer along dimension 0 according to ``beam_idx``.

        Only the first two entries of each layer's cache are re-indexed; the remaining entries
        (the cached cross-attention states) are always the same and are passed through untouched.
        """
        return tuple(
            tuple(state.index_select(0, beam_idx) for state in layer_cache[:2]) + layer_cache[2:]
            for layer_cache in past
        )
    


    # Override generate() from the generation mixin: summarise first, then classify.
    @torch.no_grad()
    def generate(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        max_length: Optional[int] = None,
        min_length: Optional[int] = None,
        do_sample: Optional[bool] = None,
        early_stopping: Optional[bool] = None,
        num_beams: Optional[int] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        repetition_penalty: Optional[float] = None,
        bad_words_ids: Optional[Iterable[int]] = None,
        bos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        encoder_no_repeat_ngram_size: Optional[int] = None,
        num_return_sequences: Optional[int] = None,
        max_time: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        decoder_start_token_id: Optional[int] = None,
        use_cache: Optional[bool] = None,
        num_beam_groups: Optional[int] = None,
        diversity_penalty: Optional[float] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        forced_bos_token_id: Optional[int] = None,
        forced_eos_token_id: Optional[int] = None,
        remove_invalid_values: Optional[bool] = None,
        synced_gpus: Optional[bool] = None,
        **model_kwargs,
    ) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
        """
        Generate a summary, then classify the input given that summary.

        All arguments are forwarded to the parent ``generate`` except ``return_dict_in_generate``,
        which is accepted for signature compatibility but always forced to True so that the
        generated token ids can be read from the ``'sequences'`` entry.

        Returns:
            A tuple ``(classification_logits, sequences)``. Note that this differs from the
            declared return annotation, which is kept unchanged from the parent signature.
        """
        generation_kwargs = dict(
            input_ids=input_ids,
            max_length=max_length,
            min_length=min_length,
            do_sample=do_sample,
            early_stopping=early_stopping,
            num_beams=num_beams,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            bad_words_ids=bad_words_ids,
            bos_token_id=bos_token_id,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            length_penalty=length_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
            num_return_sequences=num_return_sequences,
            max_time=max_time,
            max_new_tokens=max_new_tokens,
            decoder_start_token_id=decoder_start_token_id,
            use_cache=use_cache,
            num_beam_groups=num_beam_groups,
            diversity_penalty=diversity_penalty,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_scores=output_scores,
            return_dict_in_generate=True,
            forced_bos_token_id=forced_bos_token_id,
            forced_eos_token_id=forced_eos_token_id,
            remove_invalid_values=remove_invalid_values,
            synced_gpus=synced_gpus,
        )

        # First (autoregressive) pass of the model to generate the summary
        summarization_outputs = super().generate(**generation_kwargs, **model_kwargs)
        generated_sequences = summarization_outputs['sequences']

        # Second pass of the model to get the classification output
        classification_outputs = self(
            input_ids=input_ids,    # Do another pass of the encoder
            decoder_input_ids=generated_sequences,  # The final summary
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            **model_kwargs,
        )

        return (classification_outputs['classification_logits'], generated_sequences)


# 3. Preprocessing Data

In [None]:
import pandas as pd
from datasets import Dataset, load_metric
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
from transformers.trainer_utils import IntervalStrategy, SchedulerType

## Load Data

Training:

In [None]:
# e-FEVER training split (JSON Lines), indexed by claim id.
df_efever = pd.read_json(
    "/content/drive/MyDrive/Thesis/System_Development/fever_data/efever_train_set.jsonl",
    orient="columns",
    lines=True,
).set_index("id")

In [None]:
# FEVER training split (JSON Lines), indexed by claim id.
df_fever = pd.read_json(
    "/content/drive/MyDrive/Thesis/System_Development/fever_data/train.jsonl",
    orient="columns",
    lines=True,
).set_index("id")

In [None]:
# Join the FEVER and e-FEVER frames column-wise, keeping only ids present in both.
df = pd.concat([df_fever, df_efever], axis=1, join="inner")
df = df.drop(columns=['evidence', 'verifiable'])
# Convert labels to integer values
df["label"] = df["label"].replace(to_replace={'SUPPORTS':0, 'REFUTES':1, 'NOT ENOUGH INFO':2}, value=None)
# Remove + from retrieved_evidence.
# regex=False makes the literal match explicit: "+" on its own is not a valid regular
# expression, and the default of `regex` differs between pandas versions.
df["retrieved_evidence"] = df["retrieved_evidence"].str.replace("+", "", regex=False)

In [None]:
# To get rid of rows without enough evidence
# Drop rows whose summary is exactly this placeholder sentence (surrounding double quotes included).
df = df[df['summary'] != '"The relevant information about the claim is lacking in the context."']

In [None]:
# Convert to a `datasets.Dataset`, then drop the 'id' column
# (presumably the DataFrame index materialised by from_pandas — verify).
train_dataset = Dataset.from_pandas(df)
train_dataset = train_dataset.remove_columns(['id'])

In [None]:
# Sanity check: show the dataset summary.
print(train_dataset)

Validation:

In [None]:
# e-FEVER development split (JSON Lines), indexed by claim id.
df_efever_val = pd.read_json(
    "/content/drive/MyDrive/Thesis/System_Development/fever_data/efever_dev_set.jsonl",
    orient="columns",
    lines=True,
).set_index("id")

In [None]:
# FEVER development split (JSON Lines), indexed by claim id.
df_fever_val = pd.read_json(
    "/content/drive/MyDrive/Thesis/System_Development/fever_data/dev.jsonl",
    orient="columns",
    lines=True,
).set_index("id")

In [None]:
# Join the FEVER and e-FEVER validation frames column-wise, keeping only ids present in both.
df_val = pd.concat([df_fever_val, df_efever_val], axis=1, join="inner")
df_val = df_val.drop(columns=['evidence', 'verifiable'])
# Convert labels to integer values
df_val["label"] = df_val["label"].replace(to_replace={'SUPPORTS':0, 'REFUTES':1, 'NOT ENOUGH INFO':2}, value=None)
# Remove + from retrieved_evidence.
# regex=False makes the literal match explicit: "+" on its own is not a valid regular
# expression, and the default of `regex` differs between pandas versions.
df_val["retrieved_evidence"] = df_val["retrieved_evidence"].str.replace("+", "", regex=False)

In [None]:
# To get rid of rows without enough evidence
# Drop rows whose summary is exactly this placeholder sentence (surrounding double quotes included).
df_val = df_val[df_val['summary'] != '"The relevant information about the claim is lacking in the context."']

In [None]:
# Convert to a `datasets.Dataset`, then drop the 'id' column
# (presumably the DataFrame index materialised by from_pandas — verify).
val_dataset = Dataset.from_pandas(df_val)
val_dataset = val_dataset.remove_columns(['id'])

In [None]:
# Sanity check: show the dataset summary.
print(val_dataset)

## Preprocess

In [None]:
# Tokenizer of the same checkpoint the model is loaded from in section 5.
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')

In [None]:
def preprocess(data):
  """
  Tokenize a batch for joint summarisation and classification.

  Reads the 'claim', 'retrieved_evidence', 'summary' and 'label' columns of `data`, encodes
  claim and evidence as a sentence pair (truncated to 1024 tokens) and the summary as the
  target (truncated to 128 tokens), both without padding.

  Returns the encoder inputs extended with 'labels' (summary token ids, pad ids replaced by
  -100) and 'classification_labels' (the integer class labels, unchanged).
  """
  claims = data['claim']
  evidences = data['retrieved_evidence']
  summaries = data['summary']
  class_labels = data['label']

  model_inputs = tokenizer(claims, evidences, max_length=1024, truncation=True, padding=False)

  with tokenizer.as_target_tokenizer():
    target_encoding = tokenizer(summaries, max_length=128, truncation=True, padding=False)

  # Ensure padding not included in loss
  masked_targets = []
  for token_ids in target_encoding["input_ids"]:
    masked_targets.append(
        [-100 if token_id == tokenizer.pad_token_id else token_id for token_id in token_ids]
    )

  # Integer class indices are enough: PyTorch's CrossEntropyLoss does not need one-hot targets.
  model_inputs['classification_labels'] = class_labels
  model_inputs['labels'] = masked_targets

  return model_inputs

In [None]:
# Tokenize the training set in batches; every original column is replaced by the
# fields returned from preprocess().
train_dataset = train_dataset.map(
                preprocess,
                batched=True,
                num_proc=None,
                remove_columns=train_dataset.column_names,
                load_from_cache_file=True,
                desc="Running tokenizer on train dataset",
            )

In [None]:
# Tokenize the validation set in batches; every original column is replaced by the
# fields returned from preprocess().
val_dataset = val_dataset.map(
                preprocess,
                batched=True,
                num_proc=None,
                remove_columns=val_dataset.column_names,
                load_from_cache_file=True,
                desc="Running tokenizer on val dataset",
            )

# 4. Preparation

In [None]:
# Data collator
label_pad_token_id = -100
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=label_pad_token_id,
    pad_to_multiple_of=None,
)

In [None]:
# Metric
# ROUGE scorer used by compute_metrics() below.
metric = load_metric("rouge")

def postprocess_text(preds, labels):
    """
    Strip each decoded string and put every sentence on its own line.

    rougeLSum expects a newline after each sentence, so sentences found by
    ``nltk.sent_tokenize`` are re-joined with a newline character.

    Returns:
        A ``(preds, labels)`` tuple of two new lists of strings.
    """
    stripped_preds = [text.strip() for text in preds]
    stripped_labels = [text.strip() for text in labels]

    def _one_sentence_per_line(text):
        return "\n".join(nltk.sent_tokenize(text))

    formatted_preds = [_one_sentence_per_line(text) for text in stripped_preds]
    formatted_labels = [_one_sentence_per_line(text) for text in stripped_labels]

    return formatted_preds, formatted_labels

def compute_metrics(eval_preds):
    """
    Compute ROUGE scores and the mean generated length for one evaluation run.

    Args:
        eval_preds: A ``(predictions, labels)`` pair of token-id arrays. If the predictions
            are a tuple, only its first element is used.

    Returns:
        A dict with the mid F-measure of every ROUGE variant (scaled by 100) and ``gen_len``,
        the mean number of non-pad tokens per prediction; all values rounded to 4 decimals.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    pad_id = tokenizer.pad_token_id
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, pad_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    rouge_scores = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # Extract a few results from ROUGE
    result = {name: score.mid.fmeasure * 100 for name, score in rouge_scores.items()}

    result["gen_len"] = np.mean([np.count_nonzero(pred != pad_id) for pred in preds])
    return {name: round(value, 4) for name, value in result.items()}

# 5. Training

In [None]:
# Initialise the joint model from the pretrained BART-large checkpoint.
# NOTE(review): the data collator in section 4 also references `model`, so this cell has to run before it.
model = BartForJointPrediction.from_pretrained('facebook/bart-large')

In [None]:
# Full, explicit training configuration: 3 epochs, batch size 4 per device, learning rate 5e-5
# with a linear schedule and no warmup, checkpoints every 10000 steps under /tmp/tst-summarization.
training_args = Seq2SeqTrainingArguments(
    adafactor=False,
    adam_beta1=0.9,
    adam_beta2=0.999,
    adam_epsilon=1e-08,
    dataloader_drop_last=False,
    dataloader_num_workers=0,
    dataloader_pin_memory=True,
    ddp_find_unused_parameters=None,
    debug=[],
    deepspeed=None,
    disable_tqdm=False,
    do_eval=True,
    do_predict=False,
    do_train=True,
    eval_accumulation_steps=None,
    eval_steps=500,
    evaluation_strategy='no',  # NOTE(review): presumably disables evaluation during training, making eval_steps unused — confirm
    fp16=False,
    fp16_backend='auto',
    fp16_full_eval=False,
    fp16_opt_level='O1',
    gradient_accumulation_steps=1,
    greater_is_better=None,
    group_by_length=False,
    ignore_data_skip=False,
    label_names=None,
    label_smoothing_factor=0.0,
    learning_rate=5e-05,
    length_column_name='length',
    load_best_model_at_end=False,
    local_rank=-1,
    # log_level=-1,
    # log_level_replica=-1,
    log_on_each_node=True,
    logging_dir='/tmp/tst-summarization/runs/Jul04_02-41-44_9ee3aa777e7a',  # hard-coded run directory
    logging_first_step=False,
    logging_steps=500,
    logging_strategy='steps',
    lr_scheduler_type='linear',
    max_grad_norm=1.0,
    max_steps=-1,
    metric_for_best_model=None,
    # mp_parameters=,
    no_cuda=False,
    num_train_epochs=3.0,
    output_dir='/tmp/tst-summarization',
    overwrite_output_dir=True,
    past_index=-1,
    per_device_eval_batch_size=4,
    per_device_train_batch_size=4,
    predict_with_generate=True,
    prediction_loss_only=False,
    push_to_hub=False,
    push_to_hub_model_id='tst-summarization',
    push_to_hub_organization=None,
    push_to_hub_token=None,
    remove_unused_columns=True,
    report_to=['tensorboard'],
    resume_from_checkpoint=None,
    run_name='/tmp/tst-summarization',
    save_on_each_node=False,
    save_steps=10000,
    save_strategy='steps',
    save_total_limit=None,
    seed=42,
    sharded_ddp=[],
    skip_memory_metrics=True,
    sortish_sampler=False,
    tpu_metrics_debug=False,
    tpu_num_cores=None,
    use_legacy_prediction_loop=False,
    warmup_ratio=0.0,
    warmup_steps=0,
    weight_decay=0.0,
)

In [None]:
# Wire the model, tokenized datasets, collator and ROUGE metric into the stock Seq2SeqTrainer.
trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

In [None]:
# Train from scratch (no checkpoint to resume from), then persist the model.
train_result = trainer.train(resume_from_checkpoint=None)
trainer.save_model()  # Saves the tokenizer too for easy upload

# Record the training-set size alongside the metrics returned by the trainer.
metrics = train_result.metrics
max_train_samples = (
    len(train_dataset)
)
# Both operands are len(train_dataset), so this is simply the full training-set size.
metrics["train_samples"] = min(max_train_samples, len(train_dataset))

trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()

# 6. Custom Trainer Definition (For Evaluation)

In [None]:
from transformers import Trainer
from transformers.utils import logging
from typing import NamedTuple
from transformers.file_utils import is_torch_tpu_available
from transformers.trainer_pt_utils import (
    DistributedTensorGatherer,
    IterableDatasetShard,
    SequentialDistributedSampler,
    find_batch_size,
    nested_concat,
    nested_numpify,
    nested_truncate,
)
from transformers.deepspeed import deepspeed_init, is_deepspeed_zero3_enabled
from transformers.trainer_utils import (
    EvalLoopOutput,
    EvalPrediction,
    PredictionOutput,
    denumpify_detensorize,
    speed_metrics,
)
from transformers.debug_utils import DebugOption

from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset, IterableDataset

import time
import math
import collections
import nltk

In [None]:
# Fetch NLTK's "punkt" resource quietly; postprocess_text() calls nltk.sent_tokenize,
# which presumably depends on it — verify.
nltk.download("punkt", quiet=True)

In [None]:
# Module-level logger obtained from transformers' logging utilities (imported above as `logging`).
logger = logging.get_logger(__name__)

In [None]:
class JointPredictionOutput(NamedTuple):
    """
    Prediction result with separate fields for the classification task.

    NOTE(review): presumably mirrors transformers' ``PredictionOutput`` (imported above) with
    two extra classification fields — confirm against the code that builds it.
    """
    predictions: Union[np.ndarray, Tuple[np.ndarray]]
    label_ids: Optional[np.ndarray]
    metrics: Optional[Dict[str, float]]
    # Classification counterparts of `predictions` and `label_ids`.
    classification_predictions: Union[np.ndarray, Tuple[np.ndarray]]
    classification_label_ids: Optional[np.ndarray]


class JointEvalLoopOutput(NamedTuple):
    """
    Evaluation-loop result with separate fields for the classification task.

    NOTE(review): presumably mirrors transformers' ``EvalLoopOutput`` (imported above) with
    two extra classification fields — confirm against the code that builds it.
    """
    predictions: Union[np.ndarray, Tuple[np.ndarray]]
    label_ids: Optional[np.ndarray]
    metrics: Optional[Dict[str, float]]
    num_samples: Optional[int]
    # Classification counterparts of `predictions` and `label_ids`.
    classification_predictions: Union[np.ndarray, Tuple[np.ndarray]]
    classification_label_ids: Optional[np.ndarray]

In [None]:
class CustomTrainer(Trainer):
    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
    ]:
        """
        Perform one evaluation step on :obj:`model` using :obj:`inputs`, generating the explanation and
        collecting the classification logits in the same pass.

        Every code path returns a 5-tuple so that :obj:`evaluation_loop` can always unpack the result.

        Args:
            model (:obj:`nn.Module`):
                The model to evaluate.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model. Must contain ``input_ids`` and ``attention_mask``;
                ``labels`` and ``classification_labels`` are optional.
            prediction_loss_only (:obj:`bool`):
                Whether or not to return the loss only.
            ignore_keys (:obj:`List[str]`, `optional`):
                Forwarded to the parent implementation when generation is not used.

        Return:
            Tuple ``(loss, generated_tokens, labels, classification_logits, classification_labels)``.
            Elements that are not available on the taken code path are :obj:`None`.
        """
        if not self.args.predict_with_generate or prediction_loss_only:
            # The parent implementation returns (loss, logits, labels); pad it to the 5-tuple
            # contract expected by evaluation_loop.
            loss, logits, labels = super().prediction_step(
                model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
            )
            return (loss, logits, labels, None, None)

        has_labels = "labels" in inputs
        has_classification_labels = "classification_labels" in inputs
        inputs = self._prepare_inputs(inputs)

        # XXX: adapt synced_gpus for fairscale as well
        gen_kwargs = {
            "max_length": self._max_length if self._max_length is not None else self.model.config.max_length,
            "num_beams": self._num_beams if self._num_beams is not None else self.model.config.num_beams,
            "synced_gpus": bool(is_deepspeed_zero3_enabled()),
        }

        # The model's generate() is indexed as (classification_logits, generated_tokens).
        result = self.model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            **gen_kwargs,
        )
        classification_logits = result[0]
        generated_tokens = result[1]

        # in case the batch is shorter than max length, the output should be padded
        if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])

        with torch.no_grad():
            if self.use_amp:
                # Imported here so the mixed-precision path does not depend on an import made elsewhere.
                from torch.cuda.amp import autocast

                with autocast():
                    outputs = model(**inputs)
            else:
                outputs = model(**inputs)

            if has_labels:
                if self.label_smoother is not None:
                    loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
                else:
                    loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
            else:
                loss = None

        if self.args.prediction_loss_only:
            return (loss, None, None, None, None)

        # Unlabeled batches yield None instead of raising KeyError.
        labels = inputs["labels"] if has_labels else None
        classification_labels = inputs["classification_labels"] if has_classification_labels else None
        if labels is not None and labels.shape[-1] < gen_kwargs["max_length"]:
            labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])

        return (loss, generated_tokens, labels, classification_logits, classification_labels)



    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> JointEvalLoopOutput:
        """
        Prediction/evaluation loop, shared by :obj:`evaluate()` and :obj:`predict()`.

        Works both with or without labels. In addition to the generation predictions and labels it
        accumulates the classification logits and classification labels returned by
        :obj:`prediction_step`.

        Args:
            dataloader (:obj:`DataLoader`): Batches to evaluate.
            description (:obj:`str`): Label used in the log messages.
            prediction_loss_only (:obj:`bool`, `optional`):
                Falls back to ``self.args.prediction_loss_only`` when :obj:`None`.
            ignore_keys (:obj:`List[str]`, `optional`): Forwarded to :obj:`prediction_step`.
            metric_key_prefix (:obj:`str`): Prefix applied to every metric key.

        Returns:
            :obj:`JointEvalLoopOutput` with numpy predictions/labels for both tasks, the metrics
            dictionary and the number of samples.
        """
        prediction_loss_only = (
            prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
        )

        # if eval is called w/o train init deepspeed here
        if self.args.deepspeed and not self.deepspeed:

            # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
            # from the checkpoint eventually
            deepspeed_engine, _, _ = deepspeed_init(self, num_training_steps=0, resume_from_checkpoint=None)
            self.model = deepspeed_engine.module
            self.model_wrapped = deepspeed_engine
            self.deepspeed = deepspeed_engine
            # XXX: we don't need optim/sched for inference, but this needs to be sorted out, since
            # for example the Z3-optimizer is a must for zero3 to work even for inference - what we
            # don't need is the deepspeed basic optimizer which is self.optimizer.optimizer
            deepspeed_engine.optimizer.optimizer = None
            deepspeed_engine.lr_scheduler = None

        model = self._wrap_model(self.model, training=False)

        # if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while
        # ``train`` is running, halve it first and then put on device
        if not self.is_in_train and self.args.fp16_full_eval:
            model = model.half().to(self.args.device)

        batch_size = dataloader.batch_size

        logger.info(f"***** Running {description} *****")
        if isinstance(dataloader.dataset, collections.abc.Sized):
            logger.info(f"  Num examples = {self.num_examples(dataloader)}")
        else:
            logger.info("  Num examples: Unknown")
        logger.info(f"  Batch size = {batch_size}")

        model.eval()

        self.callback_handler.eval_dataloader = dataloader
        # Do this before wrapping.
        eval_dataset = dataloader.dataset

        if is_torch_tpu_available():
            # NOTE(review): `pl` (the torch_xla parallel loader) is not imported in the visible
            # part of this notebook - confirm it is available before running on TPU.
            dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)

        if self.args.past_index >= 0:
            self._past = None

        # Accumulators, keyed by stream name.
        #   host:  tensors on GPU/TPU, accumulated for eval_accumulation_steps
        #   final: numpy arrays on CPU (final containers)
        stream_names = ("preds", "labels", "classification_preds", "classification_labels")
        host = dict.fromkeys(("losses",) + stream_names)
        final = dict.fromkeys(("losses",) + stream_names)

        def _flush_host_to_cpu():
            """Move every pending host accumulator into its CPU container and reset the host side."""
            if host["losses"] is not None:
                losses_np = nested_numpify(host["losses"])
                final["losses"] = (
                    losses_np
                    if final["losses"] is None
                    else np.concatenate((final["losses"], losses_np), axis=0)
                )
            for name in stream_names:
                if host[name] is not None:
                    chunk = nested_numpify(host[name])
                    final[name] = (
                        chunk if final[name] is None else nested_concat(final[name], chunk, padding_index=-100)
                    )
            for name in host:
                host[name] = None

        # Will be useful when we have an iterable dataset so don't know its length.
        observed_num_examples = 0

        # Main evaluation loop
        for step, inputs in enumerate(dataloader):
            # Update the observed num examples
            observed_batch_size = find_batch_size(inputs)
            if observed_batch_size is not None:
                observed_num_examples += observed_batch_size

            # Prediction step. Accept both the joint 5-tuple and the plain (loss, logits, labels)
            # 3-tuple, so a step without classification outputs does not break the unpacking.
            step_output = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
            loss, logits, labels = step_output[:3]
            if len(step_output) >= 5:
                classification_logits, classification_labels = step_output[3:5]
            else:
                classification_logits, classification_labels = None, None

            # Update containers on host
            if loss is not None:
                losses = self._nested_gather(loss.repeat(batch_size))
                host["losses"] = losses if host["losses"] is None else torch.cat((host["losses"], losses), dim=0)
            step_streams = (logits, labels, classification_logits, classification_labels)
            for name, tensors in zip(stream_names, step_streams):
                if tensors is None:
                    continue
                tensors = self._pad_across_processes(tensors)
                tensors = self._nested_gather(tensors)
                host[name] = (
                    tensors if host[name] is None else nested_concat(host[name], tensors, padding_index=-100)
                )
            self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
            if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0:
                _flush_host_to_cpu()

        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        # Gather all remaining tensors and put them back on the CPU
        _flush_host_to_cpu()

        all_losses = final["losses"]
        all_preds = final["preds"]
        all_labels = final["labels"]
        all_preds_classification = final["classification_preds"]
        all_labels_classification = final["classification_labels"]

        # Number of samples
        if not isinstance(eval_dataset, IterableDataset):
            num_samples = len(eval_dataset)
        elif isinstance(eval_dataset, IterableDatasetShard):
            num_samples = eval_dataset.num_examples
        else:
            num_samples = observed_num_examples

        # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
        # samplers has been rounded to a multiple of batch_size, so we truncate.
        if all_losses is not None:
            all_losses = all_losses[:num_samples]
        if all_preds is not None:
            all_preds = nested_truncate(all_preds, num_samples)
        if all_labels is not None:
            all_labels = nested_truncate(all_labels, num_samples)
        if all_preds_classification is not None:
            all_preds_classification = nested_truncate(all_preds_classification, num_samples)
        if all_labels_classification is not None:
            all_labels_classification = nested_truncate(all_labels_classification, num_samples)

        # Metrics! Only the generation predictions/labels are passed to compute_metrics.
        if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
            metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
        else:
            metrics = {}

        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)

        if all_losses is not None:
            metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        return JointEvalLoopOutput(
            predictions=all_preds,
            label_ids=all_labels,
            metrics=metrics,
            num_samples=num_samples,
            classification_predictions=all_preds_classification,
            classification_label_ids=all_labels_classification,
        )



    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
        max_length: Optional[int] = None,
        num_beams: Optional[int] = None,
    ) -> Dict[str, float]:
        """
        Run the evaluation loop and return its metrics.

        Args:
            eval_dataset (:obj:`Dataset`, `optional`):
                Dataset handed to :obj:`get_eval_dataloader`; :obj:`None` lets that method choose.
            ignore_keys (:obj:`List[str]`, `optional`):
                Forwarded to the evaluation loop.
            metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
                Prefix for the metric keys, e.g. "bleu" becomes "eval_bleu".
            max_length (:obj:`int`, `optional`):
                Generation length stored on the trainer for :obj:`prediction_step`; when :obj:`None`
                the model config value is used there.
            num_beams (:obj:`int`, `optional`):
                Beam count stored on the trainer for :obj:`prediction_step`; when :obj:`None` the
                model config value is used there.

        Returns:
            The metrics dictionary of the loop, extended with speed metrics.
        """
        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        # Generation settings are read back by prediction_step (pattern taken from Seq2SeqTrainer).
        self._max_length, self._num_beams = max_length, num_beams

        loader = self.get_eval_dataloader(eval_dataset)
        started_at = time.time()

        if self.args.use_legacy_prediction_loop:
            run_loop = self.prediction_loop
        else:
            run_loop = self.evaluation_loop

        # Without compute_metrics there is no point gathering predictions; otherwise pass None so
        # the loop defers to self.args.prediction_loss_only.
        loss_only = None if self.compute_metrics is not None else True
        output = run_loop(
            loader,
            description="Evaluation",
            prediction_loss_only=loss_only,
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix,
        )

        global_batch_size = self.args.eval_batch_size * self.args.world_size
        throughput = speed_metrics(
            metric_key_prefix,
            started_at,
            num_samples=output.num_samples,
            num_steps=math.ceil(output.num_samples / global_batch_size),
        )
        output.metrics.update(throughput)

        self.log(output.metrics)

        if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            # NOTE(review): `xm` and `met` are not imported in the visible part of this notebook - confirm.
            xm.master_print(met.metrics_report())

        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)

        self._memory_tracker.stop_and_update_metrics(output.metrics)

        return output.metrics




    def predict(
        self,
        test_dataset: Dataset,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
        max_length: Optional[int] = None,
        num_beams: Optional[int] = None,
    ) -> JointPredictionOutput:
        """
        Run the prediction loop on :obj:`test_dataset` and return predictions for both tasks.

        Args:
            test_dataset (:obj:`Dataset`):
                Dataset handed to :obj:`get_test_dataloader`.
            ignore_keys (:obj:`List[str]`, `optional`):
                Forwarded to the evaluation loop.
            metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
                Prefix for the metric keys, e.g. "bleu" becomes "eval_bleu".
            max_length (:obj:`int`, `optional`):
                Generation length stored on the trainer for :obj:`prediction_step`.
            num_beams (:obj:`int`, `optional`):
                Beam count stored on the trainer for :obj:`prediction_step`.

        Returns:
            :obj:`JointPredictionOutput` with

            - predictions / label_ids: generation outputs and labels gathered by the loop,
            - metrics: the loop metrics extended with speed metrics,
            - classification_predictions / classification_label_ids: classification logits and labels.
        """
        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        # Generation settings are read back by prediction_step (pattern taken from Seq2SeqTrainer).
        self._max_length, self._num_beams = max_length, num_beams

        loader = self.get_test_dataloader(test_dataset)
        started_at = time.time()

        if self.args.use_legacy_prediction_loop:
            run_loop = self.prediction_loop
        else:
            run_loop = self.evaluation_loop
        loop_output = run_loop(
            loader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
        )

        global_batch_size = self.args.eval_batch_size * self.args.world_size
        throughput = speed_metrics(
            metric_key_prefix,
            started_at,
            num_samples=loop_output.num_samples,
            num_steps=math.ceil(loop_output.num_samples / global_batch_size),
        )
        loop_output.metrics.update(throughput)

        self._memory_tracker.stop_and_update_metrics(loop_output.metrics)

        return JointPredictionOutput(
            predictions=loop_output.predictions,
            label_ids=loop_output.label_ids,
            metrics=loop_output.metrics,
            classification_predictions=loop_output.classification_predictions,
            classification_label_ids=loop_output.classification_label_ids,
        )



    def _pad_tensors_to_max_len(self, tensor, max_length):
        """
        Right-pad a 2-D tensor along its last dimension to ``max_length`` columns.

        The fill value is the tokenizer's pad token id, or its EOS token id when no pad token is
        defined. The result keeps the dtype and device of ``tensor``.

        Raises:
            ValueError: if the trainer was created without a tokenizer.
        """
        tokenizer = self.tokenizer
        if tokenizer is None:
            raise ValueError(
                f"Tensor need to be padded to `max_length={max_length}` but no tokenizer was passed when creating "
                "this `Trainer`. Make sure to create your `Trainer` with the appropriate tokenizer."
            )

        # If PAD token is not defined at least EOS token has to be defined
        fill_id = tokenizer.pad_token_id
        if fill_id is None:
            fill_id = tokenizer.eos_token_id

        target_shape = (tensor.shape[0], max_length)
        canvas = fill_id * torch.ones(target_shape, dtype=tensor.dtype, device=tensor.device)
        # Copy the original values into the leading columns; the remainder stays filled.
        canvas[:, : tensor.shape[-1]] = tensor
        return canvas

# 7. Evaluation

In [None]:
from torch.nn import Softmax

In [None]:
# Copy saved model to this session
!mkdir /content/results
!cp ../drive/MyDrive/Thesis/System_Development/results/Result3.zip /content/results/
%cd ../results
!unzip Result3.zip
%cd ../transformers/

In [None]:
# ../results/tst-summarization/ contains the output from saving the model during the training section
# (extracted from Result3.zip by the previous cell); load the fine-tuned joint model from it.
model = BartForJointPrediction.from_pretrained('../results/tst-summarization/')

In [None]:
# Inference configuration for CustomTrainer. Training is disabled (do_train=False) and only
# trainer.predict() is called in this section, so the optimizer/scheduler/saving values below are
# presumably never exercised here. The settings this section relies on are:
#   - predict_with_generate=True: required for CustomTrainer.prediction_step to call model.generate
#   - prediction_loss_only=False: keeps predictions and classification logits in the output
#   - per_device_eval_batch_size=4
#   - use_legacy_prediction_loop=False: selects CustomTrainer.evaluation_loop
training_args = Seq2SeqTrainingArguments(
    adafactor=False,
    adam_beta1=0.9,
    adam_beta2=0.999,
    adam_epsilon=1e-08,
    dataloader_drop_last=False,
    dataloader_num_workers=0,
    dataloader_pin_memory=True,
    ddp_find_unused_parameters=None,
    debug=[],
    deepspeed=None,
    disable_tqdm=False,
    do_eval=True,
    do_predict=False,
    do_train=False,
    eval_accumulation_steps=None,
    eval_steps=500,
    evaluation_strategy='no',
    fp16=False,
    fp16_backend='auto',
    fp16_full_eval=False,
    fp16_opt_level='O1',
    gradient_accumulation_steps=1,
    greater_is_better=None,
    group_by_length=False,
    ignore_data_skip=False,
    label_names=None,
    label_smoothing_factor=0.0,
    learning_rate=5e-05,
    length_column_name='length',
    load_best_model_at_end=False,
    local_rank=-1,
    # log_level=-1,
    # log_level_replica=-1,
    log_on_each_node=True,
    logging_dir='/tmp/tst-summarization/runs/Jul04_02-41-44_9ee3aa777e7a',
    logging_first_step=False,
    logging_steps=500,
    logging_strategy='steps',
    lr_scheduler_type='linear',
    max_grad_norm=1.0,
    max_steps=-1,
    metric_for_best_model=None,
    # mp_parameters=,
    no_cuda=False,
    num_train_epochs=3.0,
    output_dir='/tmp/tst-summarization',
    overwrite_output_dir=True,
    past_index=-1,
    per_device_eval_batch_size=4,
    per_device_train_batch_size=4,
    predict_with_generate=True,
    prediction_loss_only=False,
    push_to_hub=False,
    push_to_hub_model_id='tst-summarization',
    push_to_hub_organization=None,
    push_to_hub_token=None,
    remove_unused_columns=True,
    report_to=['tensorboard'],
    resume_from_checkpoint=None,
    run_name='/tmp/tst-summarization',
    save_on_each_node=False,
    save_steps=10000,
    save_strategy='steps',
    save_total_limit=None,
    seed=42,
    sharded_ddp=[],
    skip_memory_metrics=True,
    sortish_sampler=False,
    tpu_metrics_debug=False,
    tpu_num_cores=None,
    use_legacy_prediction_loop=False,
    warmup_ratio=0.0,
    warmup_steps=0,
    weight_decay=0.0,
)

In [None]:
# Build the evaluation trainer around the loaded model. No training set is needed; the tokenizer
# is required because CustomTrainer pads generated tokens with its pad/EOS id.
# val_dataset, tokenizer, data_collator and compute_metrics come from earlier sections.
trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=None,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

In [None]:
# Generate explanations (beam search with 4 beams, up to 100 tokens) and collect the
# classification logits for the whole validation set.
predictions = trainer.predict(
            val_dataset, max_length=100, num_beams=4, metric_key_prefix="eval"
        )

In [None]:
# Display the metrics dictionary produced by predict(); its keys are prefixed with "eval_".
predictions.metrics

Evaluation of Classification Accuracy:

In [None]:
# Turn the classification logits into hard label predictions (arg-max over the class axis).
classification_logits = torch.from_numpy(predictions.classification_predictions)
final_layer = Softmax(dim=1)
classification_preds = torch.argmax(final_layer(classification_logits), dim=1).numpy()

# Gold labels of the validation set.
gold_label = np.array(val_dataset['classification_labels'])

# Accuracy over every validation example.
accuracy = np.mean(classification_preds == gold_label)

# Accuracy restricted to the examples whose gold label is not class 2.
idx = gold_label != 2
classification_preds_filtered, gold_label_filtered = classification_preds[idx], gold_label[idx]
filtered_accuracy = np.mean(classification_preds_filtered == gold_label_filtered)

print("FULL", accuracy, sep="\n")
print("WITHOUT CLASS 2", filtered_accuracy, sep="\n")

Evaluation of Explanations:

In [None]:
# Decode the generated token ids back to text, dropping special tokens (this also removes the
# padding added by CustomTrainer when the tokenizer's pad token was used as fill value).
summary = torch.from_numpy(predictions.predictions)
decoded_summaries = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary]

Save Results:

In [None]:
# Attach the predicted labels and decoded explanations to the validation frame and write it to CSV.
# NOTE(review): df_val is defined outside this section - presumably row-aligned with val_dataset; confirm.
df_val['pred_label'] = classification_preds
df_val['pred_explanation'] = decoded_summaries
df_val.to_csv("model_eFEVER_data_eFEVER.csv", index=False)

# 8. Temperature Scaling

In [None]:
# Number of leading validation examples used to fit the temperature parameter; the remaining
# examples are held out for the calibration evaluation (see the split cell below).
training_len = 9999
# Device on which the temperature is optimised. Hard-coded to the first CUDA GPU.
device = 'cuda:0'

In [None]:
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.optim as optim

In [None]:
!mkdir /content/results
!cp ../drive/MyDrive/Thesis/System_Development/results/Result3.zip /content/results/
%cd ../results
!unzip Result3.zip
%cd ../transformers/

In [None]:
# ../results/tst-summarization/ contains the output from saving the model during the training section
# (extracted from Result3.zip by the previous cell); load the fine-tuned joint model from it.
model = BartForJointPrediction.from_pretrained('../results/tst-summarization/')

In [None]:
# Inference configuration for CustomTrainer. Training is disabled (do_train=False) and only
# trainer.predict() is called in this section, so the optimizer/scheduler/saving values below are
# presumably never exercised here. The settings this section relies on are:
#   - predict_with_generate=True: required for CustomTrainer.prediction_step to call model.generate
#   - prediction_loss_only=False: keeps predictions and classification logits in the output
#   - per_device_eval_batch_size=4
#   - use_legacy_prediction_loop=False: selects CustomTrainer.evaluation_loop
training_args = Seq2SeqTrainingArguments(
    adafactor=False,
    adam_beta1=0.9,
    adam_beta2=0.999,
    adam_epsilon=1e-08,
    dataloader_drop_last=False,
    dataloader_num_workers=0,
    dataloader_pin_memory=True,
    ddp_find_unused_parameters=None,
    debug=[],
    deepspeed=None,
    disable_tqdm=False,
    do_eval=True,
    do_predict=False,
    do_train=False,
    eval_accumulation_steps=None,
    eval_steps=500,
    evaluation_strategy='no',
    fp16=False,
    fp16_backend='auto',
    fp16_full_eval=False,
    fp16_opt_level='O1',
    gradient_accumulation_steps=1,
    greater_is_better=None,
    group_by_length=False,
    ignore_data_skip=False,
    label_names=None,
    label_smoothing_factor=0.0,
    learning_rate=5e-05,
    length_column_name='length',
    load_best_model_at_end=False,
    local_rank=-1,
    # log_level=-1,
    # log_level_replica=-1,
    log_on_each_node=True,
    logging_dir='/tmp/tst-summarization/runs/Jul04_02-41-44_9ee3aa777e7a',
    logging_first_step=False,
    logging_steps=500,
    logging_strategy='steps',
    lr_scheduler_type='linear',
    max_grad_norm=1.0,
    max_steps=-1,
    metric_for_best_model=None,
    # mp_parameters=,
    no_cuda=False,
    num_train_epochs=3.0,
    output_dir='/tmp/tst-summarization',
    overwrite_output_dir=True,
    past_index=-1,
    per_device_eval_batch_size=4,
    per_device_train_batch_size=4,
    predict_with_generate=True,
    prediction_loss_only=False,
    push_to_hub=False,
    push_to_hub_model_id='tst-summarization',
    push_to_hub_organization=None,
    push_to_hub_token=None,
    remove_unused_columns=True,
    report_to=['tensorboard'],
    resume_from_checkpoint=None,
    run_name='/tmp/tst-summarization',
    save_on_each_node=False,
    save_steps=10000,
    save_strategy='steps',
    save_total_limit=None,
    seed=42,
    sharded_ddp=[],
    skip_memory_metrics=True,
    sortish_sampler=False,
    tpu_metrics_debug=False,
    tpu_num_cores=None,
    use_legacy_prediction_loop=False,
    warmup_ratio=0.0,
    warmup_steps=0,
    weight_decay=0.0,
)

In [None]:
# Build the evaluation trainer around the loaded model. No training set is needed; the tokenizer
# is required because CustomTrainer pads generated tokens with its pad/EOS id.
# val_dataset, tokenizer, data_collator and compute_metrics come from earlier sections.
trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=None,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

In [None]:
# Generate explanations (beam search with 4 beams, up to 100 tokens) and collect the
# classification logits for the whole validation set; the logits are split below into a
# temperature-fitting part and a held-out part.
predictions = trainer.predict(
            val_dataset, max_length=100, num_beams=4, metric_key_prefix="eval"
        )

Explanations:

In [None]:
# Decode the generated token ids back to text (special tokens dropped) and store the explanations
# on the validation frame.
# NOTE(review): df_val is defined outside this section - presumably row-aligned with val_dataset; confirm.
summary = torch.from_numpy(predictions.predictions)
decoded_summaries = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary]
df_val['pred_explanation'] = decoded_summaries

Preliminary Setup:

In [None]:
# Split validation set into training/testing subsets
training_logits = predictions.classification_predictions[:training_len]
testing_logits = predictions.classification_predictions[training_len:]

training_labels = np.array(val_dataset['classification_labels'])[:training_len]
testing_labels = np.array(val_dataset['classification_labels'])[training_len:]

Evaluation Helper Functions:

In [None]:
def calc_bins(preds, labels_oneh, num_bins=10):
  """Group predicted probabilities into equal-width confidence bins.

  Bin ``i`` collects the predictions in ``[i / num_bins, (i + 1) / num_bins)``; the last bin also
  includes predictions equal to 1.0, so fully saturated probabilities are not dropped.

  Args:
    preds: 1-D array of predicted probabilities (e.g. flattened softmax outputs).
    labels_oneh: 1-D array of 0/1 targets aligned element-wise with ``preds``.
    num_bins: number of equal-width bins over [0, 1].

  Returns:
    Tuple ``(bins, binned, bin_accs, bin_confs, bin_sizes)`` where ``bins`` holds the upper edge of
    each bin, ``binned`` the bin index of every prediction, and the remaining arrays hold the mean
    label, mean prediction and element count per bin (zeros for empty bins).
  """
  preds = np.asarray(preds)
  labels_oneh = np.asarray(labels_oneh)

  # Upper edge of every bin: 1/num_bins, 2/num_bins, ..., 1.
  bins = np.linspace(1.0 / num_bins, 1, num_bins)
  # Digitize against the interior edges only. Using all edges would send preds == 1.0 to index
  # `num_bins`, which lies outside the accumulator arrays and would be ignored.
  binned = np.digitize(preds, bins[:-1])

  # Save the accuracy, confidence and size of each bin
  bin_accs = np.zeros(num_bins)
  bin_confs = np.zeros(num_bins)
  bin_sizes = np.zeros(num_bins)

  for bin_idx in range(num_bins):
    in_bin = binned == bin_idx
    bin_sizes[bin_idx] = in_bin.sum()
    if bin_sizes[bin_idx] > 0:
      bin_accs[bin_idx] = labels_oneh[in_bin].sum() / bin_sizes[bin_idx]
      bin_confs[bin_idx] = preds[in_bin].sum() / bin_sizes[bin_idx]

  return bins, binned, bin_accs, bin_confs, bin_sizes

In [None]:
def get_metrics(preds, labels_oneh):
  """Compute the expected and maximum calibration error of ``preds``.

  ECE is the bin-size-weighted mean of |accuracy - confidence| over the bins returned by
  ``calc_bins``; MCE is the largest such gap.

  Returns:
    Tuple ``(ECE, MCE)``.
  """
  _, _, bin_accs, bin_confs, bin_sizes = calc_bins(preds, labels_oneh)
  total = sum(bin_sizes)

  ECE, MCE = 0, 0
  for acc, conf, size in zip(bin_accs, bin_confs, bin_sizes):
    gap = abs(acc - conf)
    ECE += (size / total) * gap
    MCE = max(MCE, gap)

  return ECE, MCE

In [None]:
import matplotlib.patches as mpatches
import matplotlib

def draw_reliability_graph(preds, labels_oneh, save_path='calibrated_network.png'):
  """Draw a reliability diagram (per-bin accuracy against confidence) and save it as an image.

  Args:
    preds: 1-D array of predicted probabilities.
    labels_oneh: 1-D array of 0/1 targets aligned with ``preds``.
    save_path: file the figure is written to. Pass different paths for the before/after plots,
      otherwise each call overwrites the previous image.

  Note: the font settings are applied through ``matplotlib.rc`` and therefore also affect figures
  drawn afterwards.
  """
  ECE, MCE = get_metrics(preds, labels_oneh)
  bins, _, bin_accs, _, _ = calc_bins(preds, labels_oneh)

  # 'sans-serif' is a valid generic family; the former value 'normal' is not a font family and
  # makes matplotlib emit "font family not found" warnings before falling back to the default.
  font = {'family' : 'sans-serif',
        'weight' : 'bold',
        'size'   : 16}
  matplotlib.rc('font', **font)

  fig = plt.figure(figsize=(8, 8))
  ax = fig.gca()

  # x/y limits
  ax.set_xlim(0, 1.05)
  ax.set_ylim(0, 1)

  # x/y labels
  plt.xlabel('Confidence')
  plt.ylabel('Accuracy')

  # Create grid
  ax.set_axisbelow(True)
  ax.grid(color='gray', linestyle='dashed')

  # Error bars: hatched reference bars whose height equals the bin edge
  plt.bar(bins, bins,  width=0.1, alpha=0.3, edgecolor='black', color='r', hatch='\\')

  # Draw bars and identity line
  plt.bar(bins, bin_accs, width=0.1, alpha=1, edgecolor='black', color='b')
  plt.plot([0,1],[0,1], '--', color='gray', linewidth=2)

  # Equally spaced axes
  plt.gca().set_aspect('equal', adjustable='box')

  # ECE and MCE legend
  ECE_patch = mpatches.Patch(color='green', label='ECE = {:.2f}%'.format(ECE*100))
  MCE_patch = mpatches.Patch(color='red', label='MCE = {:.2f}%'.format(MCE*100))
  plt.legend(handles=[ECE_patch, MCE_patch])

  #plt.show()

  plt.savefig(save_path, bbox_inches='tight')

#draw_reliability_graph(preds)


## Evaluation Before

In [None]:
# Uncalibrated class probabilities for the test split: softmax over dim 1 of
# the raw logits, then flattened to 1-D for the calibration helpers.
# assumes testing_logits is a 2-D (samples, classes) numpy array — TODO confirm
final_layer = nn.Softmax(dim=1)
classification_logits = torch.from_numpy(testing_logits)
class_softmax = final_layer(classification_logits)
class_softmax_flat = class_softmax.numpy().flatten()

# One-hot encode the gold labels (3 classes) and flatten them the same way so
# they line up element-wise with class_softmax_flat.
gold_label = torch.from_numpy(testing_labels)
labels_oneh = torch.nn.functional.one_hot(gold_label, num_classes=3).numpy().flatten()

Save the predicted class labels (before temperature scaling):

In [None]:
# Predicted class per test sample = index of the largest softmax probability.
classification_preds = torch.argmax(final_layer(classification_logits), dim=1).numpy()

# use -1 for subset of validation data used to train temperature parameter
filler = -np.ones(training_len)

# Placeholder rows first, then the real predictions, stored as a new column.
classification_preds = np.concatenate((filler, classification_preds))
df_val['pred_label_orig'] = classification_preds

Reliability diagram (before temperature scaling)

In [None]:
# Reliability diagram for the probabilities computed before temperature scaling.
# NOTE: the figure is saved as 'calibrated_network.png' by default, and the
# post-scaling call further down writes to the same file.
draw_reliability_graph(class_softmax_flat, labels_oneh)

## Apply Temperature Scaling

In [None]:
def T_scaling(logits, args):
  """Return `logits` divided by the temperature stored in `args`.

  Args:
    logits: tensor of logits to rescale.
    args: dict holding the divisor under the key 'temperature'.

  Returns:
    `torch.div(logits, args['temperature'])`; a missing key is passed on to
    `torch.div` as None.
  """
  return torch.div(logits, args.get('temperature'))

In [None]:
# Single learnable temperature, initialised to 1. It is created on `device`
# (the same device the logits are moved to below) rather than hard-coding CUDA.
temperature = nn.Parameter(torch.ones(1, device=device))
args = {'temperature': temperature}
criterion = nn.CrossEntropyLoss()

# Removing strong_wolfe line search results in jump after 50 epochs
optimizer = optim.LBFGS([temperature], lr=0.001, max_iter=10000, line_search_fn='strong_wolfe')

# History of the temperature and loss at every closure evaluation (for plotting)
temps = []
losses = []

logits_list = torch.from_numpy(training_logits).to(device)
labels_list = torch.from_numpy(training_labels).to(device)

def _eval():
  """LBFGS closure: recompute the loss of the temperature-scaled logits."""
  # LBFGS calls the closure many times per step; without clearing, gradients
  # from earlier evaluations would accumulate into temperature.grad.
  optimizer.zero_grad()
  loss = criterion(T_scaling(logits_list, args), labels_list)
  loss.backward()
  temps.append(temperature.item())
  # Store a plain float: keeping the tensor would retain its autograd graph
  # and cannot be plotted when it lives on the GPU.
  losses.append(loss.item())
  return loss


optimizer.step(_eval)

print('Final T_scaling factor: {:.2f}'.format(temperature.item()))

plt.subplot(121)
plt.plot(list(range(len(temps))), temps)
plt.title('Temperature')

plt.subplot(122)
plt.plot(list(range(len(losses))), losses)
plt.title('Loss')
plt.show()

## Evaluation After

In [None]:
args = {'temperature': temperature}
final_layer = nn.Softmax(dim=1)

# Divide the test logits by the temperature (on `device`), then take the
# softmax over dim 1 to obtain the rescaled class probabilities.
classification_logits = T_scaling(torch.from_numpy(testing_logits).to(device), args)

# Detach from the graph and move to host memory before converting to numpy;
# the 2-D array is kept as class_softmax, the 1-D copy feeds the calibration helpers.
class_softmax = final_layer(classification_logits).detach().cpu().numpy()
class_softmax_flat = class_softmax.flatten()

# One-hot encode the gold labels (3 classes) and flatten them so they line up
# element-wise with class_softmax_flat.
gold_label = torch.from_numpy(testing_labels)
labels_oneh = torch.nn.functional.one_hot(gold_label, num_classes=3).numpy().flatten()

In [None]:
# Predicted class per test sample = index of the largest rescaled probability.
classification_preds = torch.argmax(torch.from_numpy(class_softmax), dim=1).numpy()

# -1 placeholder for the first training_len rows, which have no prediction here.
filler = -np.ones(training_len)

# Placeholder rows first, then the real predictions, stored as a new column.
classification_preds = np.concatenate((filler, classification_preds))
df_val['pred_label_final'] = classification_preds

In [None]:
# Reliability diagram for the temperature-scaled probabilities.
# NOTE: by default this writes 'calibrated_network.png', replacing the file
# produced by the pre-scaling call earlier in this section.
draw_reliability_graph(class_softmax_flat, labels_oneh)

## Save Results

In [None]:
# Write class_softmax as comma-separated text to softmax.csv. When the cells
# are run in order this is the temperature-scaled probability array.
np.savetxt("softmax.csv", class_softmax, delimiter=",")

In [None]:
# One 'prob_class_<k>' column per class. The first training_len rows are
# filled with the -1 placeholder, followed by that class's probability for
# each remaining row.
filler = np.ones(training_len) * -1 # use -1 for subset of validation data used to train temperature parameter

# Loop over the columns of class_softmax instead of repeating the same stanza
# once per class; the class count is taken from the array itself.
for class_idx in range(class_softmax.shape[1]):
  classification_probs = np.concatenate((filler, class_softmax[:, class_idx]))
  df_val['prob_class_{}'.format(class_idx)] = classification_probs

In [None]:
# Write df_val, including the columns added to it in the cells above, to CSV
# without the index column.
df_val.to_csv("val_preds_temperature_scaling.csv", index=False)