<a href="https://colab.research.google.com/github/CakeVision/collabs/blob/main/SummaRerankerV0_8.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers
!pip install sentencepiece
!pip install rouge_score

Collecting transformers
  Downloading transformers-4.35.0-py3-none-any.whl (7.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.9/7.9 MB[0m [31m26.9 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.16.4 (from transformers)
  Downloading huggingface_hub-0.19.0-py3-none-any.whl (311 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m311.2/311.2 kB[0m [31m43.1 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers<0.15,>=0.14 (from transformers)
  Downloading tokenizers-0.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.8/3.8 MB[0m [31m66.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting safetensors>=0.3.1 (from transformers)
  Downloading safetensors-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m77.8 MB/s[0m eta [36m0:00:00[0m
Col

In [None]:
# Most of this code comes from https://github.com/ntunlp/SummaReranker (arXiv paper: 2203.06569).
## I had to make changes so that it runs better in the Colab environment (modified some utils and some class components).

import os
import pickle
import random
from shutil import copyfile

import nltk
import numpy as np
import torch
import torch.nn as nn
import tqdm
import transformers
from nltk.tokenize import sent_tokenize
from torch.distributions.normal import Normal
# The helpers in this notebook call `tqdm(...)` directly, so the name has to be
# bound to the progress-bar class; keep this line after `import tqdm`.
from tqdm import tqdm
from transformers import PegasusTokenizer,PegasusForConditionalGeneration,RobertaTokenizerFast, RobertaTokenizer, RobertaModel
from transformers import AdamW, get_linear_schedule_with_warmup

In [None]:
# Data configuration (values picked by the developer, not learned).
# Root folder of the dataset files. It ends with "/" because dataset_builder()
# builds paths from it by plain string concatenation.
data_folder = "/content/root/"
# Index at which dataset_builder() splits the shuffled training set into a
# first and a second half.
data_threshhold = 143000

def seed_everything(seed=42):
    """Seed the random number generators used by this notebook.

    Args:
        seed: integer seed handed to Python's ``random``, NumPy and PyTorch
            (CPU generator and current CUDA device). It is also exported as
            the ``PYTHONHASHSEED`` environment variable.
    """
    # Exported through the environment as a string.
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Python and NumPy global generators.
    for seeder in (random.seed, np.random.seed):
        seeder(seed)

    # PyTorch generators, then ask cuDNN for deterministic kernels.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


seed_everything()

In [None]:
def dataset_builder():
  """Shuffle the training set once and store it on disk as two halves.

  Everything is read from and written to ``data_folder``:

  * ``train_summary.txt`` / ``train_text.txt`` (one example per line) are
    shuffled with a single random permutation, which is pickled to
    ``training_permutation.pkl``. The original order can be recovered from it
    with ``np.argsort``.
  * The first ``data_threshhold`` shuffled examples are written to
    ``fh_shuffled_summary.txt`` / ``fh_shuffled_text.txt``, the remaining ones
    to ``sh_shuffled_summary.txt`` / ``sh_shuffled_text.txt``.
  * The per-example files ``train/<ftype>/train_<ftype>_<i>.txt`` are copied in
    shuffled order to ``fh_train_shuffled/<ftype>/`` and
    ``sh_train_shuffled/<ftype>/``; numbering restarts at 0 in each half.

  Raises:
    ValueError: if the summary and text files do not have the same number of
      lines.
  """

  def read_lines(name):
    with open(data_folder + name, "rb") as f:
      lines = f.readlines()
    # A last line without a newline would fuse with its neighbour once shuffled.
    if lines and not lines[-1].endswith(b"\n"):
      lines[-1] += b"\n"
    return lines

  def write_lines(name, lines):
    with open(data_folder + name, "wb") as f:
      f.writelines(lines)

  train_summaries = read_lines("train_summary.txt")
  train_texts = read_lines("train_text.txt")
  if len(train_summaries) != len(train_texts):
    raise ValueError("train_summary.txt has {} lines but train_text.txt has {}".format(
        len(train_summaries), len(train_texts)))

  perm = np.random.permutation(len(train_texts))
  with open(data_folder + "training_permutation.pkl", "wb") as f:
    pickle.dump(perm, f)
  train_summaries = [train_summaries[i] for i in perm]
  train_texts = [train_texts[i] for i in perm]

  # separation of dataset into 2 halves (mostly for computation efficiency and storage)
  write_lines("fh_shuffled_summary.txt", train_summaries[:data_threshhold])
  write_lines("fh_shuffled_text.txt", train_texts[:data_threshhold])
  write_lines("sh_shuffled_summary.txt", train_summaries[data_threshhold:])
  write_lines("sh_shuffled_text.txt", train_texts[data_threshhold:])

  for ftype in ["summary", "text"]:
    src_dir = data_folder + "train/{}/".format(ftype)
    fh_dir = data_folder + "fh_train_shuffled/{}/".format(ftype)
    sh_dir = data_folder + "sh_train_shuffled/{}/".format(ftype)
    # copyfile() does not create missing destination folders.
    os.makedirs(fh_dir, exist_ok=True)
    os.makedirs(sh_dir, exist_ok=True)
    for i in tqdm(range(len(perm))):
      cur_path = src_dir + "train_{}_{}.txt".format(ftype, perm[i])
      if i < data_threshhold:
        nf_path = fh_dir + "shuffled_train_{}_{}.txt".format(ftype, i)
      else:
        # Second-half files are numbered from 0 again.
        nf_path = sh_dir + "shuffled_train_{}_{}.txt".format(ftype, i - data_threshhold)
      copyfile(cur_path, nf_path)

In [None]:
#Base Data UTILS

def load_data(set, data_folder, individual_txt=False):
  """Load the source texts and reference summaries of one split.

  NOTE(review): a later cell of this notebook defines another ``load_data``
  with a different signature; when the cells run in order, that later
  definition replaces this one.

  Args:
    set: name of the split (used to build file names).
    data_folder: folder that contains the data files.
    individual_txt: if true, read one file per document through
      ``read_data_files_individual``; otherwise read the single
      ``<set>_text.txt`` / ``<set>_summary.txt`` pair.

  Returns:
    A ``(texts, summaries)`` tuple.
  """
  if individual_txt:
    texts, summaries = read_data_files_individual(set, data_folder)
  else:
    text_files, summary_files = prepare_data_files(set, data_folder)
    texts, summaries = read_data_files(set, text_files, summary_files, data_folder)

  print("Total # of texts: {}, # summaries: {}".format(len(texts), len(summaries)))

  # Preview the first example only when there is one; indexing an empty split
  # would raise IndexError.
  if texts and summaries:
    print("\n")
    print(texts[0])
    print(summaries[0])

  return texts, summaries

def read_data_files_individual(set, data_folder):
    """Read a split that is stored as one file per document.

    Texts are read from ``<data_folder>/<set>/text/<set>_text_<i>.txt`` and
    summaries from ``<data_folder>/<set>/summary/<set>_summary_<i>.txt``. The
    number of documents is taken from the number of entries in the text
    folder, and the same count is used for the summaries.

    Args:
        set: name of the split.
        data_folder: folder that contains the ``<set>`` sub-folder.

    Returns:
        A ``(texts, summaries)`` tuple of lists holding whole-file contents.
    """

    def read_whole_file(path):
        # A context manager closes the handle even if reading fails.
        with open(path, "r") as f:
            return f.read()

    set_text_path = data_folder + "/" + set + "/" + "text/"
    set_summary_path = data_folder + "/" + set + "/" + "summary/"
    n_docs = len(os.listdir(set_text_path))
    print("There are {} {} documents".format(n_docs, set))

    texts = [read_whole_file(set_text_path + "{}_text_{}.txt".format(set, i))
             for i in range(n_docs)]
    summaries = [read_whole_file(set_summary_path + "{}_summary_{}.txt".format(set, i))
                 for i in range(n_docs)]

    return texts, summaries

def read_data_files(set, text_files, summary_files, args):
    """Read every line of the given text files and summary files.

    Args:
        set: split name (not used here).
        text_files: paths of the files holding the source texts.
        summary_files: paths of the files holding the summaries.
        args: not used here.

    Returns:
        A ``(texts, summaries)`` tuple; each is the concatenation of the
        lines of the corresponding files, in the order given.
    """

    def collect(paths):
        gathered = []
        for path in paths:
            with open(path, 'r') as handle:
                file_lines = [line for line in tqdm(handle.readlines())]
                print("# lines: {}".format(len(file_lines)))
                gathered.extend(file_lines)
        return gathered

    texts = collect(text_files)
    summaries = collect(summary_files)
    return texts, summaries

def prepare_data_files(set, data_folder):
    """Build the paths of the single-file text and summary data of a split.

    Args:
        set: name of the split.
        data_folder: folder that contains ``<set>_text.txt`` and
            ``<set>_summary.txt``.

    Returns:
        A ``(text_files, summary_files)`` tuple of one-element lists.
    """
    located = {
        kind: [data_folder + "/{}_{}.txt".format(set, kind)]
        for kind in ("text", "summary")
    }
    print(located["text"])
    print(located["summary"])

    return located["text"], located["summary"]


In [None]:
# Data utils for scored summaries

# TODO: construct other datasets and add them as a training option

# Rewritten load_data (replaces the version defined in the previous cell)
def load_data(set, size, args, individual_txt=True, train=False):
    """Load texts, reference summaries and scored summary candidates.

    The per-document reader was added on top of the single-file reader because
    splitting the data into individual files was not the initial plan.

    Args:
        set: split name, or a list of split names when ``train`` is true.
        size: size tag used in the pickle file names, or a list of them
            (aligned with ``set``) when ``train`` is true.
        args: object providing ``dataset``, ``generation_methods``,
            ``scoring_methods``, ``n_beams`` and ``model_name`` (or
            ``train_model_names``, aligned with ``set``, when ``train`` is
            true). ``args.dataset`` is only used to build file paths.
        individual_txt: read one file per document instead of one file per
            split.
        train: whether ``set`` / ``size`` are lists of training subsets.

    Returns:
        ``(texts, summaries, scored_summaries)``. ``scored_summaries`` is the
        output of ``combine_summaries``.
    """
    # texts & summaries - the helpers are called with the signatures they are
    # defined with in this cell.
    if individual_txt:
        texts, summaries = read_data_files_individual(set, train=train)
    else:
        text_files, summary_files = prepare_data_files(set, args.dataset, train=train)
        texts, summaries = read_data_files(text_files, summary_files)

    # One (subset, size, model name) triple per pickle file to read for a
    # given generation method and scoring method.
    if train:
        sources = [(set[i], size[i], args.train_model_names[i]) for i in range(len(set))]
    else:
        sources = [(set, size, args.model_name)]

    path_template = "../../scored_summaries/{}/{}/{}/{}/{}_scored_summaries_{}_{}_beams_{}.pkl"

    # scored summaries - with multiple scores
    all_scored_summaries = []
    for generation_method in args.generation_methods:
        gen_scored_summaries = []
        for scoring_method in args.scoring_methods:
            scored_for_method = []
            for set_, size_, model_name_ in sources:
                if train:
                    print(set_)
                    print(size_)
                    print(model_name_)
                scored_summaries_path = path_template.format(
                    args.dataset, set_, generation_method, scoring_method,
                    set_, model_name_, size_, args.n_beams
                )
                print(scored_summaries_path)
                # NOTE: pickle.load can execute arbitrary code; only load
                # files produced by a trusted pipeline.
                with open(scored_summaries_path, "rb") as f:
                    scored_for_method += pickle.load(f)
            gen_scored_summaries.append(scored_for_method)

        # Regroup per document: (candidates, [scores for each scoring method]).
        scored_summaries = []
        for i in range(len(gen_scored_summaries[0])):
            summaries_i = gen_scored_summaries[0][i][0]
            scores_i = [gen_scored_summaries[j][i][1] for j in range(len(args.scoring_methods))]
            scored_summaries.append((summaries_i, scores_i))
        print(len(scored_summaries), len(scored_summaries[0]), len(scored_summaries[0][0]), len(scored_summaries[0][1]), len(scored_summaries[0][1][0]))
        all_scored_summaries.append(scored_summaries)
    scored_summaries = combine_summaries(all_scored_summaries)

    print("Total # of texts: {}, labels: {}, summary_candidates: {}, # candidates / text: {}".format(
        len(texts), len(summaries), len(scored_summaries), len(scored_summaries[0][0])))

    return texts, summaries, scored_summaries
def read_data_files_individual(set, train=False):
    """Read one or several splits stored as one file per document.

    For every subset, texts are read from
    ``../../data/<dataset>/<subset>/text/<subset>_text_<i>.txt`` and summaries
    from ``../../data/<dataset>/<subset>/summary/<subset>_summary_<i>.txt``.
    The number of documents is the number of entries in the text folder.

    NOTE(review): ``dataset`` is read from module scope and is not defined in
    the cells visible here - confirm it is set before this function is called.

    Args:
        set: split name, or an iterable of split names when ``train`` is true.
        train: whether ``set`` holds several training subsets.

    Returns:
        A ``(texts, summaries)`` tuple of lists holding whole-file contents.
    """

    def read_whole_file(path):
        # A context manager closes the handle even if reading fails.
        with open(path, "r") as f:
            return f.read()

    subsets = set if train else [set]
    texts = []
    summaries = []
    for set_ in subsets:
        # These must be folders: they are listed and used as path prefixes.
        set_text_path = "../../data/{}/{}/text/".format(dataset, set_)
        set_summary_path = "../../data/{}/{}/summary/".format(dataset, set_)
        n_docs = len(os.listdir(set_text_path))
        print("There are {} {} documents".format(n_docs, set_))
        for i in tqdm(range(n_docs)):
            texts.append(read_whole_file(set_text_path + "{}_text_{}.txt".format(set_, i)))
        for i in tqdm(range(n_docs)):
            summaries.append(read_whole_file(set_summary_path + "{}_summary_{}.txt".format(set_, i)))

    return texts, summaries


def prepare_data_files(set, dataset, train):
    """Build the paths of the single-file text and summary data.

    Args:
        set: split name, or an iterable of split names when ``train`` is true.
        dataset: dataset name used as the folder under ``../../data/``.
        train: whether ``set`` holds several training subsets.

    Returns:
        A ``(text_files, summary_files)`` tuple of path lists, one entry per
        subset.
    """
    subsets = set if train else [set]
    text_files = ["../../data/{}/{}_text.txt".format(dataset, name) for name in subsets]
    summary_files = ["../../data/{}/{}_summary.txt".format(dataset, name) for name in subsets]

    print("For set {}, loading the following files:".format(set))
    for listing in (text_files, summary_files):
        print(listing)

    return text_files, summary_files


def read_data_files(text_files, summary_files):
    """Read all lines of the given text files and summary files.

    Args:
        text_files: paths of the files holding the source texts.
        summary_files: paths of the files holding the summaries.

    Returns:
        A ``(texts, summaries)`` tuple; each list concatenates the lines that
        ``read_one_file`` returns for the corresponding paths, in order.
    """
    # read the .txt files
    texts = [line for path in text_files for line in read_one_file(path)]
    summaries = [line for path in summary_files for line in read_one_file(path)]

    return texts, summaries


def read_one_file(file):
    """Return the lines of ``file`` (line endings kept) and print their count.

    Args:
        file: path of the text file to read.

    Returns:
        The list of lines of the file.
    """
    with open(file, 'r') as handle:
        lines = [line for line in tqdm(handle.readlines())]
    print(file, len(lines))

    return lines


def combine_summaries(all_scored_summaries):
    """Pool the candidates of several generation methods per document.

    Args:
        all_scored_summaries: one list per generation method; entry ``i`` of
            each list is a ``(candidates, scores)`` pair for document ``i``,
            where ``scores`` holds one list of candidate scores per metric.

    Returns:
        One ``(candidates, scores)`` pair per document, in which the
        candidates and each metric's scores of all generation methods are
        concatenated in the order of ``all_scored_summaries``.
    """
    pooled = []
    for doc_idx in tqdm(range(len(all_scored_summaries[0]))):
        n_metrics = len(all_scored_summaries[0][doc_idx][1])
        doc_candidates = []
        doc_scores = [[] for _ in range(n_metrics)]
        for per_method in all_scored_summaries:
            doc_candidates.extend(per_method[doc_idx][0])
            for metric_idx, metric_scores in enumerate(per_method[doc_idx][1]):
                doc_scores[metric_idx].extend(metric_scores)
        pooled.append((doc_candidates, doc_scores))

    return pooled


In [None]:
#dataloaders

class MultitaskRerankingDataset:
    """Dataset pairing each source text with its scored summary candidates.

    Args:
        mode: ``"train"`` or any other string; only compared against
            ``"train"`` to fill the ``"mode"`` entry of each item.
        tokenizer: callable tokenizer that also offers ``decode``.
        texts: source documents.
        scored_summaries: per document, a ``(candidates, scores)`` pair where
            ``scores`` holds one list of candidate scores per metric.
        labels: reference summaries, aligned with ``texts``.
        args: object providing ``max_length``, ``max_summary_length`` and
            ``sep_symbol``.
    """

    def __init__(self, mode, tokenizer, texts, scored_summaries, labels, args):
        self.mode = mode
        self.tokenizer = tokenizer
        self.texts = texts
        self.scored_summaries = scored_summaries
        self.labels = labels
        self.args = args

    def __len__(self):
        return len(self.texts)

    @staticmethod
    def _rescale_scores(metric_scores):
        """Return a rescaled copy of one metric's candidate scores.

        Working on a copy keeps the stored scores untouched, so fetching the
        same item again (e.g. in a later epoch) yields the same values instead
        of rescaling them a second time.
        """
        rescaled = list(metric_scores)
        # Re-adjust for BERTScore
        if min(rescaled) > 0.0 and max(rescaled) < 1.0:
            rescaled = [score * 100 for score in rescaled]
        # Re-adjust for BARTScore
        elif min(rescaled) > -10.0 and max(rescaled) < 0.0:
            rescaled = [score * 30 for score in rescaled]
        return rescaled

    def __getitem__(self, item):
        """Build the model inputs for document ``item``.

        Returns:
            A dict with ``"mode"`` (tensor ``[1]`` in train mode, ``[0]``
            otherwise), ``"text"``, ``"label"``, the token ids and attention
            mask of every "text + separator + candidate" string, ``"scores"``
            (one row per metric) and ``"labels"`` (best score per metric).
        """
        text = self.texts[item]
        label = self.labels[item]
        scored_summaries = self.scored_summaries[item]
        summary_candidates = scored_summaries[0]
        summary_scores = [self._rescale_scores(metric_scores) for metric_scores in scored_summaries[1]]

        # The source text is cut to max_length by slicing the token ids.
        text_inputs = self.tokenizer(text, return_tensors="pt", max_length=self.args.max_length, padding='max_length')
        text_inputs["input_ids"] = text_inputs["input_ids"][:, :self.args.max_length]
        text_inputs["attention_mask"] = text_inputs["attention_mask"][:, :self.args.max_length]
        summary_candidates_inputs = self.tokenizer(summary_candidates, return_tensors="pt", truncation=True, max_length=self.args.max_summary_length, padding='max_length')
        summary_candidates_inputs["input_ids"] = summary_candidates_inputs["input_ids"][:,:self.args.max_summary_length]
        summary_candidates_inputs["attention_mask"] = summary_candidates_inputs["attention_mask"][:,:self.args.max_summary_length]

        # Decode the truncated pieces and join them with the separator, then
        # tokenize the joined strings as the model input.
        text_and_summaries = [self.tokenizer.decode(text_inputs["input_ids"][0], skip_special_tokens=True) + " " + self.args.sep_symbol + " " \
                              + self.tokenizer.decode(summary_candidates_inputs["input_ids"][i], skip_special_tokens=True) for i in range(len(summary_candidates_inputs["input_ids"]))]
        text_and_summaries_inputs = self.tokenizer(text_and_summaries, return_tensors="pt", truncation=True, max_length=self.args.max_length + self.args.max_summary_length, padding='max_length')
        text_and_summaries_inputs["input_ids"] = text_and_summaries_inputs["input_ids"][:, :(self.args.max_length + self.args.max_summary_length)]
        text_and_summaries_inputs["attention_mask"] = text_and_summaries_inputs["attention_mask"][:, :(self.args.max_length + self.args.max_summary_length)]

        scores = torch.cat([torch.tensor(summary_scores[i]).unsqueeze(0) for i in range(len(summary_scores))], 0)
        labels = torch.max(scores, dim = 1)[0]
        mode = torch.tensor([1])
        if self.mode != "train":
            mode = torch.tensor([0])

        batch = {
            "mode": mode,
            "text": text,
            "label": label,
            "text_and_summaries_input_ids": text_and_summaries_inputs["input_ids"],
            "text_and_summaries_attn_mask": text_and_summaries_inputs["attention_mask"],
            "scores": scores,
            "labels": labels
        }

        return batch

class MultitaskRerankingDatasetTrain:
  """Dataset pairing each source text with its scored summary candidates.

  Unlike ``MultitaskRerankingDataset``, the items do not carry the raw text
  and the reference summary.

  Args:
    mode: ``"train"`` or any other string; only compared against ``"train"``
      to fill the ``"mode"`` entry of each item.
    tokenizer: callable tokenizer that also offers ``decode``.
    texts: source documents.
    scored_summaries: per document, a ``(candidates, scores)`` pair where
      ``scores`` holds one list of candidate scores per metric.
    labels: reference summaries (stored, not read by ``__getitem__``).
    args: object providing ``max_length``, ``max_summary_length`` and
      ``sep_symbol``.
  """

  def __init__(self, mode, tokenizer, texts, scored_summaries, labels, args):
    self.mode = mode
    self.tokenizer = tokenizer
    self.texts = texts
    self.scored_summaries = scored_summaries
    self.labels = labels
    self.args = args

  def __len__(self):
    return len(self.texts)

  @staticmethod
  def _rescale_scores(metric_scores):
    """Return a rescaled copy of one metric's candidate scores.

    Working on a copy keeps the stored scores untouched, so fetching the same
    item again (e.g. in a later epoch) yields the same values instead of
    rescaling them a second time.
    """
    rescaled = list(metric_scores)
    # re-adjust BERTScore
    if min(rescaled) > 0.0 and max(rescaled) < 1.0:
      rescaled = [score * 100 for score in rescaled]
    # re-adjust BARTScore
    elif min(rescaled) > -10.0 and max(rescaled) < 0.0:
      rescaled = [score * 30 for score in rescaled]
    return rescaled

  def __getitem__(self, item):
    """Build the model inputs for document ``item``.

    Returns:
      A dict with ``"mode"`` (tensor ``[1]`` in train mode, ``[0]``
      otherwise), the token ids and attention mask of every
      "text + separator + candidate" string, ``"scores"`` (one row per
      metric) and ``"labels"`` (best score per metric).
    """
    text = self.texts[item]
    scored_summaries = self.scored_summaries[item]
    summary_candidates = scored_summaries[0]
    summary_scores = [self._rescale_scores(metric_scores) for metric_scores in scored_summaries[1]]

    # The source text is cut to max_length by slicing the token ids.
    text_inputs = self.tokenizer(text, return_tensors="pt", max_length=self.args.max_length, padding='max_length')
    text_inputs["input_ids"] = text_inputs["input_ids"][:, :self.args.max_length]
    text_inputs["attention_mask"] = text_inputs["attention_mask"][:, :self.args.max_length]

    summary_candidates_inputs = self.tokenizer(summary_candidates, return_tensors="pt", truncation=True, max_length=self.args.max_summary_length, padding='max_length')
    summary_candidates_inputs["input_ids"] = summary_candidates_inputs["input_ids"][:,:self.args.max_summary_length]
    summary_candidates_inputs["attention_mask"] = summary_candidates_inputs["attention_mask"][:,:self.args.max_summary_length]

    # Decode the truncated pieces and join them with the separator, then
    # tokenize the joined strings as the model input.
    text_and_summaries = [self.tokenizer.decode(text_inputs["input_ids"][0], skip_special_tokens=True) + " " + self.args.sep_symbol + " " \
                          + self.tokenizer.decode(summary_candidates_inputs["input_ids"][i], skip_special_tokens=True) for i in range(len(summary_candidates_inputs["input_ids"]))]
    text_and_summaries_inputs = self.tokenizer(text_and_summaries, return_tensors="pt", truncation=True, max_length=self.args.max_length + self.args.max_summary_length, padding='max_length')
    text_and_summaries_inputs["input_ids"] = text_and_summaries_inputs["input_ids"][:, :(self.args.max_length + self.args.max_summary_length)]
    text_and_summaries_inputs["attention_mask"] = text_and_summaries_inputs["attention_mask"][:, :(self.args.max_length + self.args.max_summary_length)]

    scores = torch.cat([torch.tensor(summary_scores[i]).unsqueeze(0) for i in range(len(summary_scores))], 0)
    labels = torch.max(scores, dim = 1)[0]
    mode = torch.tensor([1])
    if self.mode != "train":
        mode = torch.tensor([0])

    batch = {
        "mode": mode,
        "text_and_summaries_input_ids": text_and_summaries_inputs["input_ids"],
        "text_and_summaries_attn_mask": text_and_summaries_inputs["attention_mask"],
        "scores": scores,
        "labels": labels
    }

    return batch

def pre_rouge_processing(summary, args):
    """Prepare a summary string for ROUGE scoring.

    Args:
        summary: the summary text.
        args: object providing the flags ``clean_n`` (replace every ``<n>``
            marker with a space) and ``highlights`` (put each sentence found
            by ``sent_tokenize`` on its own line).

    Returns:
        The processed summary string.
    """
    cleaned = summary.replace("<n>", " ") if args.clean_n else summary
    if not args.highlights:
        return cleaned
    return "\n".join(sent_tokenize(cleaned))

In [None]:
#MODEL
class MoE(nn.Module):
  """Sparsely gated mixture-of-experts layer with one gate per task.

  All tasks share one pool of experts (``MLPExpert`` instances); each task has
  its own gating weights ``w_gate[t]`` and noise weights ``w_noise[t]``.

  Args:
  device: device on which the helper tensors of this layer are created
  n_tasks: an integer - number of tasks, i.e. number of gates
  input_size: integer - size of the input
  output_size: integer - size of the output
  num_experts: an integer - number of experts
  hidden_size: an integer - hidden size of the experts
  k: an integer - how many experts to use for each batch element
  """

  def __init__(self, device, n_tasks, input_size, output_size, num_experts, hidden_size, k=4):
    super(MoE, self).__init__()
    self.device = device
    self.n_tasks = n_tasks
    self.num_experts = num_experts
    self.output_size = output_size
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.k = k
    # instantiate experts
    self.experts = nn.ModuleList([MLPExpert(self.input_size, self.output_size, self.hidden_size) for i in range(self.num_experts)])
    self.w_gate = nn.ParameterList([nn.Parameter(torch.zeros(input_size, num_experts), requires_grad=True) for i in range(n_tasks)])
    self.w_noise = nn.ParameterList([nn.Parameter(torch.zeros(input_size, num_experts), requires_grad=True) for i in range(n_tasks)])

    self.softplus = nn.Softplus()
    self.softmax = nn.Softmax(1)
    self.normal = Normal(torch.tensor([0.0]).to(device), torch.tensor([1.0]).to(device))

    assert(self.k <= self.num_experts)

    self.init_tasks_probs()

  def init_tasks_probs(self):
    """Reset the per-task, per-expert lists of collected gate values."""
    self.tasks_probs = [[[] for i in range(self.num_experts)] for j in range(self.n_tasks)]

  def display_tasks_probs(self):
    """Print the mean collected gate value per expert for each task, then reset."""
    print("\nProbability distribution on experts for each task, computed over {} data points:".format(len(self.tasks_probs[0][0])))
    for j in range(self.n_tasks):
      probs = self.tasks_probs[j]
      probs = np.array([np.mean(x) for x in probs])
      prob_std = np.std(probs)
      probs = ["{:.4f}".format(x) for x in probs]
      print("Task {} / {}, distribution across experts: {}, std: {:.4f}".format(j+1, self.n_tasks, probs, prob_std))
    self.init_tasks_probs()

  def cv_squared(self, x):
    """The squared coefficient of variation of a sample.
    Useful as a loss to encourage a positive distribution to be more uniform.
    Epsilons added for numerical stability.
    Returns a zero tensor of shape [1] when x has a single element.
    Args:
    x: a `Tensor`.
    Returns:
    a `Scalar`.
    """
    eps = 1e-10
    # if only num_experts = 1
    if x.shape[0] == 1:
      # Created on the device of x so it can be added to the other loss terms.
      return torch.zeros(1, device=x.device)
    return x.float().var() / (x.float().mean()**2 + eps)

  def _gates_to_load(self, gates):
    """Compute the true load per expert, given the gates.
    The load is the number of examples for which the corresponding gate is >0.
    Args:
    gates: a `Tensor` of shape [batch_size, n]
    Returns:
    an integer `Tensor` of shape [n]
    """
    return (gates > 0).sum(0)

  def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_values):
    """Helper function to NoisyTopKGating.
    Computes the probability that value is in top k, given different random noise.
    This gives us a way of backpropagating from a loss that balances the number
    of times each expert is in the top k experts per example.
    Args:
    clean_values: a `Tensor` of shape [batch, n].
    noisy_values: a `Tensor` of shape [batch, n].  Equal to clean values plus
      normally distributed noise with standard deviation noise_stddev.
    noise_stddev: a `Tensor` of shape [batch, n]
    noisy_top_values: a `Tensor` of shape [batch, m].
        "values" output of topk(noisy values, m).  m >= k+1
    Returns:
    a `Tensor` of shape [batch, n].
    """

    batch = clean_values.size(0)
    m = noisy_top_values.size(1)
    top_values_flat = noisy_top_values.flatten()
    # Position, in the flattened top values, of each row's (k+1)-th largest value.
    threshold_positions_if_in = torch.arange(batch) * m + self.k
    threshold_positions_if_in = threshold_positions_if_in.to(self.device)
    threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in), 1)
    is_in = torch.gt(noisy_values, threshold_if_in)
    threshold_positions_if_out = threshold_positions_if_in - 1
    threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat,0 , threshold_positions_if_out), 1)
    # is each value currently in the top k.
    prob_if_in = self.normal.cdf((clean_values - threshold_if_in)/noise_stddev)
    prob_if_out = self.normal.cdf((clean_values - threshold_if_out)/noise_stddev)
    prob = torch.where(is_in, prob_if_in, prob_if_out)
    return prob

  def noisy_top_k_gating(self, gate_idx, x, train, noise_epsilon=1e-2):
    """Noisy top-k gating.
      See paper: https://arxiv.org/abs/1701.06538.
      Args:
        gate_idx: index of the task whose gate weights are used
        x: input Tensor with shape [batch_size, input_size]
        train: a boolean - we only add noise at training time.
        noise_epsilon: a float
      Returns:
        gates: a Tensor with shape [batch_size, num_experts]
        load: a Tensor with shape [num_experts]
    """
    clean_logits = x @ self.w_gate[gate_idx]
    if train:
      raw_noise_stddev = x @ self.w_noise[gate_idx]
      noise_stddev = ((self.softplus(raw_noise_stddev) + noise_epsilon))
      noisy_logits = clean_logits + ( torch.randn_like(clean_logits) * noise_stddev)
      logits = noisy_logits
    else:
      logits = clean_logits

    # calculate topk + 1 that will be needed for the noisy gates
    top_logits, top_indices = logits.topk(min(self.k + 1, self.num_experts), dim=1)
    top_k_logits = top_logits[:, :self.k]
    top_k_indices = top_indices[:, :self.k]
    top_k_gates = self.softmax(top_k_logits)

    zeros = torch.zeros_like(logits, requires_grad=True)
    gates = zeros.scatter(1, top_k_indices, top_k_gates)

    if train and self.k < self.num_experts:
      load = (self._prob_in_top_k(clean_logits, noisy_logits, noise_stddev, top_logits)).sum(0)
    else:
      load = self._gates_to_load(gates)
    return gates, load

  def forward(self, x, train=True, collect_gates = False, loss_coef=1e-2):
    """Args:
    x: tensor shape [batch_size, input_size]
    train: a boolean scalar.
    collect_gates: a boolean - whether to record the gate values in
      self.tasks_probs (see display_tasks_probs).
    loss_coef: a scalar - multiplier on load-balancing losses
    Returns:
    all_y: a list with one tensor of shape [batch_size, output_size] per task.
    all_loss: the sum over tasks of the load-balancing losses. This should be
      added into the overall training loss of the model. The backpropagation
      of this loss encourages all experts to be approximately equally used
      across a batch.
    """
    all_y = []
    # Created on the layer's device instead of assuming CUDA is available.
    all_loss = torch.tensor(0.0, device=self.device)
    for gate_idx in range(self.n_tasks):
      gates, load = self.noisy_top_k_gating(gate_idx, x, train)
      # calculate importance loss
      importance = gates.sum(0)

      if collect_gates == True:
        t = gates.detach().cpu().numpy()
        for i in range(t.shape[1]):
          self.tasks_probs[gate_idx][i] += list(t[:,i])

      loss = self.cv_squared(importance) + self.cv_squared(load)
      loss *= loss_coef

      dispatcher = SparseDispatcher(self.device, self.num_experts, gates)
      expert_inputs = dispatcher.dispatch(x)
      gates = dispatcher.expert_to_gates()
      expert_outputs = [self.experts[i](expert_inputs[i]) for i in range(self.num_experts)]
      y = dispatcher.combine(expert_outputs)

      all_y.append(y)
      all_loss = all_loss + loss

    return all_y, all_loss
class MLPExpert(nn.Module):
    """Feed-forward expert: linear layer, ReLU, linear layer.

    Args:
        input_size: number of input features.
        output_size: number of output features.
        hidden_size: width of the hidden layer.
    """

    def __init__(self, input_size, output_size, hidden_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``fc2(relu(fc1(x)))``."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)


class MLPTower(nn.Module):
    """Prediction head made of a single linear layer with one output.

    Args:
        input_size: accepted but not used; the hidden layer that would have
            consumed it is disabled.
        hidden_size: number of features of the tensors passed to ``forward``.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # Only the output projection is active (no fc1 / ReLU stage).
        self.fc2 = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Return ``fc2(x)``."""
        return self.fc2(x)


class SparseDispatcher(object):
    """Helper for implementing a mixture of experts.
    The purpose of this class is to create input minibatches for the
    experts and to combine the results of the experts to form a unified
    output tensor.
    There are two main functions:
    dispatch - take an input Tensor and create input Tensors for each expert.
    combine - take output Tensors from each expert and form a combined output
      Tensor.  Outputs from different experts for the same batch element are
      exponentiated, weighted by the provided "gates", summed together and
      mapped back with a logarithm.
    The class is initialized with a "gates" Tensor, which specifies which
    batch elements go to which experts, and the weights to use when combining
    the outputs.  Batch element b is sent to expert e iff gates[b, e] != 0.
    The inputs and outputs are all two-dimensional [batch, depth].
    Caller is responsible for collapsing additional dimensions prior to
    calling this class and reshaping the output to the original shape.
    Example use:
    gates: a float `Tensor` with shape `[batch_size, num_experts]`
    inputs: a float `Tensor` with shape `[batch_size, input_size]`
    experts: a list of length `num_experts` containing sub-networks.
    dispatcher = SparseDispatcher(device, num_experts, gates)
    expert_inputs = dispatcher.dispatch(inputs)
    expert_outputs = [experts[i](expert_inputs[i]) for i in range(num_experts)]
    outputs = dispatcher.combine(expert_outputs)
    The preceding code sets the output for a particular example b to:
    output[b] = log(Sum_i(gates[b, i] * exp(experts[i](inputs[b]))))
    (sums that come out exactly zero are replaced by a small epsilon before
    the logarithm is taken).
    This class takes advantage of sparsity in the gate matrix by including in the
    `Tensor`s for expert i only the batch elements for which `gates[b, i] > 0`.
    """

    def __init__(self, device, num_experts, gates):
        """Create a SparseDispatcher.
        Args:
          device: device on which `combine` allocates its accumulator.
          num_experts: number of experts (stored, not otherwise used here).
          gates: a `Tensor` of shape `[batch_size, num_experts]`; non-zero
            entries mark which batch elements are routed to which experts.
        """

        self.device = device
        self._gates = gates
        self._num_experts = num_experts
        # torch.nonzero yields one (batch, expert) row per non-zero gate;
        # sort(0) sorts each column independently and returns the permutations
        sorted_experts, index_sorted_experts = torch.nonzero(gates).sort(0)
        # keep only the sorted expert column, as a [n_nonzero, 1] tensor
        _, self._expert_index = sorted_experts.split(1, dim=1)
        # batch index of every non-zero gate, ordered by expert
        self._batch_index = sorted_experts[index_sorted_experts[:, 1],0]
        # number of samples that each expert receives
        self._part_sizes = list((gates > 0).sum(0).detach().cpu().numpy())
        # expand gates to match with self._batch_index, then pick the gate
        # value of the expert each row was routed to -> [n_nonzero, 1]
        gates_exp = gates[self._batch_index.flatten()]
        self._nonzero_gates = torch.gather(gates_exp, 1, self._expert_index)

    def dispatch(self, inp):
        """Create one input Tensor for each expert.
        The `Tensor` for a expert `i` contains the slices of `inp` corresponding
        to the batch elements `b` where `gates[b, i] > 0`.
        Args:
          inp: a `Tensor` of shape "[batch_size, <extra_input_dims>]`
        Returns:
          a tuple of `num_experts` `Tensor`s with shapes
            `[expert_batch_size_i, <extra_input_dims>]`.
        """

        # assigns samples to experts whose gate is nonzero

        # expand according to batch index so we can just split by _part_sizes
        # (squeeze(1) only has an effect when dimension 1 has size 1)
        inp_exp = inp[self._batch_index].squeeze(1)
        return torch.split(inp_exp, self._part_sizes, dim=0)

    def combine(self, expert_out, multiply_by_gates=True):
        """Sum together the expert output, weighted by the gates.
        The expert outputs are exponentiated first; the slice corresponding to
        a particular batch element `b` is then computed as the sum over all
        experts `i` of the exponentiated output, weighted by the corresponding
        gate values, and the logarithm of that sum is returned.
        If `multiply_by_gates` is set to False, the gate values are ignored.
        Args:
          expert_out: a list of `num_experts` `Tensor`s, each with shape
            `[expert_batch_size_i, <extra_output_dims>]`.
          multiply_by_gates: a boolean
        Returns:
          a `Tensor` with shape `[batch_size, <extra_output_dims>]`.
        """
        # apply exp to expert outputs, so we are not longer in log space
        stitched = torch.cat(expert_out, 0).exp()

        if multiply_by_gates:
            stitched = stitched.mul(self._nonzero_gates)
        # accumulator with one row per batch element; the width is taken from
        # the last expert's output
        zeros = torch.zeros(self._gates.size(0), expert_out[-1].size(1), requires_grad=True).to(self.device)
        # combine samples that have been processed by the same k experts
        combined = zeros.index_add(0, self._batch_index, stitched.float())
        # add eps to all zero values in order to avoid nans when going back to log space
        combined[combined == 0] = np.finfo(float).eps
        # back to log space
        return combined.log()

    def expert_to_gates(self):
        """Gate values corresponding to the examples in the per-expert `Tensor`s.
        Returns:
          a tuple of `num_experts` float `Tensor`s with shapes
              `[expert_batch_size_i, 1]`
        """
        # split nonzero gates for each expert
        return torch.split(self._nonzero_gates, self._part_sizes, dim=0)

class ModelMultitaskBinary(nn.Module):
    """Multi-task binary reranker over summary candidates.

    Each (text, candidate) sequence is encoded by a pretrained language model;
    the first-token encoding optionally goes through a shared bottom MLP, then
    through a mixture of experts that yields one representation per task, and
    finally through one linear tower per task producing a logit per candidate.
    Training uses binary cross-entropy against labels that mark the
    best-scoring candidate(s) of every task.

    The instance also accumulates bookkeeping lists across forward calls
    (sampled candidate indices, training label statistics and evaluation
    predictions); they are never reset inside this class.
    """

    def __init__(self, pretrained_model, tokenizer, args):
        """Build the reranking heads on top of ``pretrained_model``.

        Args:
            pretrained_model: encoder called with ``input_ids``/``attention_mask``
                whose output exposes ``"last_hidden_state"``.
            tokenizer: stored on the instance; not used inside this class.
            args: configuration namespace; this class reads ``hidden_size``,
                ``bottom_hidden_size``, ``expert_hidden_size``,
                ``tower_hidden_size``, ``num_experts``, ``k``, ``n_tasks``,
                ``device``, ``sharp_pos``, ``use_shared_bottom``,
                ``use_aux_loss`` and ``scoring_methods``.
        """
        super(ModelMultitaskBinary, self).__init__()
        self.tokenizer = tokenizer
        self.args = args

        # LM
        self.pretrained_model = pretrained_model
        # shared bottom
        self.fc1 = nn.Linear(args.hidden_size, args.bottom_hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(args.bottom_hidden_size, args.hidden_size)
        # MoE
        self.moe = MoE(args.device, args.n_tasks, args.hidden_size, args.hidden_size, args.num_experts, args.expert_hidden_size, args.k)
        # towers - one for each task
        self.towers = nn.ModuleList([MLPTower(args.hidden_size, args.tower_hidden_size) for i in range(args.n_tasks)])
        self.sigmoid = nn.Sigmoid()

        self.loss = nn.BCEWithLogitsLoss()

        # sampled candidates (extended on every forward call)
        self.selected_idx = []

        # training labels (per-task statistics, appended only in training mode)
        self.original_training_labels = {}
        self.training_labels = {}
        self.training_scores = {}
        self.training_hits = {}
        for j in range(self.args.n_tasks):
            self.original_training_labels[j] = []
            self.training_labels[j] = []
            self.training_scores[j] = []
            self.training_hits[j] = []

        # multi-summary evaluation (per-task predictions, appended only in evaluation mode)
        self.multi_summary_pred_idx = {}
        self.multi_summary_preds = {}
        for j in range(self.args.n_tasks):
            self.multi_summary_pred_idx[j] = []
            self.multi_summary_preds[j] = []

    def forward(self, mode, text_and_summaries_ids, text_and_summaries_mask, scores):
        """Score the candidates of every example and compute loss and metrics.

        Args:
            mode: tensor flag; ``torch.sum(mode) > 0`` selects the training
                branch, a sum of zero the evaluation branch.
            text_and_summaries_ids: token ids indexed by example along
                dimension 0; each slice is fed to the encoder as ``input_ids``
                (presumably [batch, n_candidates, seq_len] - TODO confirm).
            text_and_summaries_mask: attention masks aligned with the ids.
            scores: ``scores[i][j]`` holds the per-candidate metric values of
                example ``i`` for task ``j``.

        Returns:
            A dict with ``"loss"`` and ``"loss_nce"`` (the same tensor),
            ``"total_predictions_idx"`` (per example, the candidate with the
            highest task-averaged probability), ``"overall_predictions"``
            (per example, the task-averaged probabilities), and for every name
            in ``args.scoring_methods`` the batch means ``accuracy_<name>``,
            ``rank_<name>`` and ``prediction_<name>``, plus ``"prediction_sum"``
            and ``"overall_sum"``.
        """
        loss = torch.tensor(0.0).to(self.pretrained_model.device)
        accuracy = [0 for j in range(self.args.n_tasks)]
        rank = [0 for j in range(self.args.n_tasks)]
        predictions_idx = [[] for j in range(self.args.n_tasks)]
        predictions = [[] for j in range(self.args.n_tasks)]
        total_predictions_idx = []
        overall_sums = []
        overall_predictions = []
        # examples are processed one at a time
        for i in range(text_and_summaries_ids.shape[0]):

            # data
            text_and_summaries_ids_i = text_and_summaries_ids[i]
            text_and_summaries_mask_i = text_and_summaries_mask[i]

            # labels construction: 1 for every candidate that reaches the task's best score
            scores_i = scores[i]
            original_scores_i = scores_i.clone().detach()
            labels_i = torch.zeros(scores_i.shape, device = self.pretrained_model.device)
            for j in range(self.args.n_tasks):
                best_j = scores_i[j].max()
                if self.args.sharp_pos:
                    # with sharp_pos, no positive is assigned when all candidates tie
                    if best_j > scores_i[j].min():
                        labels_i[j][scores_i[j] == best_j] = 1
                else:
                    labels_i[j][scores_i[j] == best_j] = 1
            original_labels_i = labels_i.clone().detach()

            # candidate sampling
            # NOTE(review): candidate_subsampling is defined outside this class;
            # it presumably returns the kept indices plus the inputs restricted
            # to them - confirm against its definition.
            selected_idx, text_and_summaries_ids_i, text_and_summaries_mask_i, scores_i, labels_i = candidate_subsampling(
                mode, text_and_summaries_ids_i, text_and_summaries_mask_i, scores_i, labels_i, self.args
            )
            self.selected_idx += selected_idx

            # model output
            # LM encoding
            outputs_i = self.pretrained_model(
                input_ids = text_and_summaries_ids_i, attention_mask = text_and_summaries_mask_i, output_hidden_states = True
            )
            encs = outputs_i["last_hidden_state"]
            # keep the encoding of the first token of every sequence
            encs = encs[:, 0, :]
            # shared bottom
            if self.args.use_shared_bottom:
                preds_i = self.fc2(self.relu(self.fc1(encs)))
            else:
                preds_i = encs
            # MoE: returns one representation per task plus an auxiliary loss;
            # gate values are collected only outside training
            train = torch.sum(mode) > 0
            preds_i, aux_loss_i = self.moe(preds_i, train = train, collect_gates = not(train))

            loss_i = torch.tensor(0.0).to(self.pretrained_model.device)
            # running sum of the per-task probabilities for every candidate
            total_predictions = np.zeros(len(preds_i[0]))
            for j in range(self.args.n_tasks):

                # pred: one logit per candidate from the task's tower
                preds_i_j = self.towers[j](preds_i[j])[:, 0]

                # labels
                labels_i_j = labels_i[j]
                if torch.sum(mode) > 0:
                    # training-time statistics: positives before/after sampling,
                    # mean score of the positives, and whether the best original
                    # score is still present among the sampled candidates
                    self.original_training_labels[j].append(original_labels_i[j].sum().item())
                    self.training_labels[j].append(labels_i_j.sum().item())
                    if labels_i_j.sum() > 0:
                        self.training_scores[j].append(scores_i[j][labels_i_j == 1].mean().item())
                    self.training_hits[j].append(int(scores_i[j].max().item() == original_scores_i[j].max().item()))

                # loss
                loss_i_j = self.loss(preds_i_j, labels_i_j)
                loss_i = loss_i + loss_i_j

                # predictions: probability per candidate, best candidate and its score
                preds_i_j = self.sigmoid(preds_i_j).detach().cpu().numpy()
                prediction_idx = np.argmax(preds_i_j)
                predictions_idx[j].append(prediction_idx)
                prediction = scores_i[j][prediction_idx].item()
                predictions[j].append(prediction)
                total_predictions += preds_i_j

                # accuracy: 100 when the predicted candidate reaches the best score, else 0
                pos_idx = scores_i[j].argmax().item()
                accuracy_i_j = 100 * int(scores_i[j][prediction_idx].item() == scores_i[j][pos_idx].item())
                accuracy[j] = accuracy[j] + accuracy_i_j

                # ranks: smallest predicted rank among all candidates tied for the best score
                ranks = rank_array(preds_i_j)
                all_pos_idx = [k for k in range(len(scores_i[j])) if scores_i[j][k].item() == scores_i[j][pos_idx].item()]
                rank_i_j = np.min(ranks[all_pos_idx])
                rank[j] = rank[j] + rank_i_j
            # average the task losses; the auxiliary MoE loss is optional
            loss_i = loss_i / self.args.n_tasks
            if self.args.use_aux_loss:
                loss_i = loss_i + aux_loss_i
            loss = loss + loss_i
            # overall prediction: candidate with the highest task-averaged probability
            total_predictions /= self.args.n_tasks
            total_prediction_idx = np.argmax(total_predictions)
            total_predictions_idx.append(total_prediction_idx)
            overall_sum = sum([scores_i[j][total_prediction_idx].item() for j in range(self.args.n_tasks)])
            overall_sums.append(overall_sum)
            overall_predictions.append(total_predictions)

        # average over the examples of the batch
        loss /= scores.shape[0]
        outputs = {
            "loss": loss,
            "loss_nce": loss,
            "total_predictions_idx": total_predictions_idx,
            "overall_predictions": overall_predictions
        }
        prediction_sum = 0
        for j in range(self.args.n_tasks):
            accuracy[j] /= scores.shape[0]
            outputs["accuracy_{}".format(self.args.scoring_methods[j])] = torch.tensor(accuracy[j]).float().to(loss.device)
            rank[j] /= scores.shape[0]
            outputs["rank_{}".format(self.args.scoring_methods[j])] = torch.tensor(rank[j]).float().to(loss.device)
            if torch.sum(mode) == 0:
                # evaluation mode: remember the per-example predictions on the instance
                self.multi_summary_pred_idx[j] += predictions_idx[j]
                self.multi_summary_preds[j] += predictions[j]
            predictions[j] = np.mean(predictions[j])
            outputs["prediction_{}".format(self.args.scoring_methods[j])] = torch.tensor(predictions[j]).float().to(loss.device)
            prediction_sum += predictions[j]
        outputs["prediction_sum"] = torch.tensor(prediction_sum).float().to(loss.device)
        outputs["overall_sum"] = torch.tensor(np.mean(overall_sums)).float().to(loss.device)

        return outputs

In [None]:
# #Custom model trainer
# #made for creating a custom output for evaluation
# class Trainer:
#     def __init__(self, *args, **kwargs):
#         requires_backends(self, ["torch"])

# class CustomTrainer(Trainer):
#   def compute_loss(self, model, inputs, return_outputs=False):
#     mode = inputs["mode"]
#     text_and_summaries_ids = inputs["text_and_summaries_input_ids"]
#     text_and_summaries_mask = inputs["text_and_summaries_attn_mask"]
#     scores = inputs["scores"]

#     outputs = model(mode, text_and_summaries_ids, text_and_summaries_mask, scores)

#     loss = outputs["loss"]
#     output = torch.zeros(2 + 3 * args.n_tasks + 2).float().to(loss.device)
#     output[0] = loss
#     output[1] = outputs["loss_nce"]
#     for j in range(args.n_tasks):
#         output[2 + j * 3] = outputs["accuracy_{}".format(args.scoring_methods[j])]
#         output[3 + j * 3] = outputs["rank_{}".format(args.scoring_methods[j])]
#         output[4 + j * 3] = outputs["prediction_{}".format(args.scoring_methods[j])]
#     output[-2] = outputs["prediction_sum"]
#     output[-1] = outputs["overall_sum"]

#     return (loss, output) if return_outputs else loss

#   def prediction_step( self,
#           model: nn.Module,
#           inputs: Dict[str, Union[torch.Tensor, Any]],
#           prediction_loss_only: bool,
#           ignore_keys: Optional[List[str]] = None,
#   ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
#     """
#     Perform an evaluation step on :obj:`model` using obj:`inputs`.

#     Subclass and override to inject custom behavior.

#     Args:
#         model (:obj:`nn.Module`):
#             The model to evaluate.
#         inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
#             The inputs and targets of the model.

#             The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
#             argument :obj:`labels`. Check your model's documentation for all accepted arguments.
#         prediction_loss_only (:obj:`bool`):
#             Whether or not to return the loss only.
#         ignore_keys (:obj:`Lst[str]`, `optional`):
#             A list of keys in the output of your model (if it is a dictionary) that should be ignored when
#             gathering predictions.

#     Return:
#         Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
#         logits and labels (each being optional).
#     """
#     has_labels = all(inputs.get(k) is not None for k in self.label_names)
#     inputs = self._prepare_inputs(inputs)
#     if ignore_keys is None:
#         if hasattr(self.model, "config"):
#             ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
#         else:
#             ignore_keys = []

#     # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
#     if has_labels:
#         labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
#         if len(labels) == 1:
#             labels = labels[0]
#     else:
#         labels = None

#     with torch.no_grad():
#         if has_labels:
#             loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
#             loss = loss.mean().detach()
#             if isinstance(outputs, dict):
#                 logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
#             else:
#                 logits = outputs[1:]
#         else:
#             loss = None
#             if self.use_amp:
#                 # with autocast():
#                 outputs = model(**inputs)
#             else:
#                 text_inputs_ids = inputs["text_inputs_ids"]
#                 text_attention_mask = inputs["text_attention_mask"]
#                 text_inputs = {
#                     "input_ids": text_inputs_ids,
#                     "attention_mask": text_attention_mask
#                 }
#                 outputs = model(**text_inputs)
#             if isinstance(outputs, dict):
#                 logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
#             else:
#                 logits = outputs
#             # TODO: this needs to be fixed and made cleaner later.
#             if self.args.past_index >= 0:
#                 self._past = outputs[self.args.past_index - 1]

#     if prediction_loss_only:
#         return (loss, None, None)

#     logits = nested_detach(logits)
#     if len(logits) == 1:
#         logits = logits[0]

#     return (loss, logits, labels)

#   def get_train_dataloader(self) -> DataLoader:
#     """
#     Returns the training :class:`~torch.utils.data.DataLoader`.

#     Will use no sampler if :obj:`self.train_dataset` does not implement :obj:`__len__`, a random sampler (adapted
#     to distributed training if necessary) otherwise.

#     Subclass and override this method if you want to inject some custom behavior.
#     """
#     if self.train_dataset is None:
#       raise ValueError("Trainer: training requires a train_dataset.")

#     train_dataset = self.train_dataset
#     if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
#       train_dataset = self._remove_unused_columns(train_dataset, description="training")

#     return DataLoader(
#       train_dataset,
#       batch_size=self.args.train_batch_size,
#       shuffle=train_dataset.args.shuffle_train,
#       collate_fn=self.data_collator,
#       drop_last=self.args.dataloader_drop_last,
#       num_workers=self.args.dataloader_num_workers,
#       pin_memory=self.args.dataloader_pin_memory,
#     )

# def compute_metrics(eval_preds):
#   preds, labels = eval_preds
#   loss_nce = np.mean([preds[i] for i in range(0, len(preds), 1 + 3 * args.n_tasks + 2)])
#   result = {
#     "loss_nce": loss_nce
#   }
#   for j in range(args.n_tasks):
#     accuracy_arr = [preds[i] for i in range(1 + j * 3, len(preds), 1 + 3 * args.n_tasks + 2)]
#     accuracy = np.mean(accuracy_arr)
#     rank_arr = [preds[i] for i in range(2 + j * 3, len(preds), 1 + 3 * args.n_tasks + 2)]
#     rank = np.mean(rank_arr)
#     prediction_arr = [preds[i] for i in range(3 + j * 3, len(preds), 1 + 3 * args.n_tasks + 2)]
#     prediction = np.mean(prediction_arr)
#     print("Task {}, # pred batches: {}".format(j + 1, len(accuracy_arr)))
#     result["accuracy_{}".format(args.scoring_methods[j])] = accuracy
#     result["rank_{}".format(args.scoring_methods[j])] = rank
#     result["prediction_{}".format(args.scoring_methods[j])] = prediction
#   prediction_sum = np.mean([preds[i] for i in range(1 + 3 * args.n_tasks, len(preds), 1 + 3 * args.n_tasks + 2)])
#   result["prediction_sum"] = prediction_sum
#   overall_sum = np.mean([preds[i] for i in range(1 + 3 * args.n_tasks + 1, len(preds), 1 + 3 * args.n_tasks + 2)])
#   result["overall_sum"] = overall_sum

#   return result

In [None]:
#EVALS

def overall_eval(val_texts, val_summaries, val_labels, args):
    """Run every evaluation enabled in ``args`` and gather the per-sample scores.

    Args:
        val_texts: source documents (used by the ROUGE and new-n-gram evaluations).
        val_summaries: summaries under evaluation.
        val_labels: reference summaries.
        args: configuration with the switches ``eval_rouge``, ``eval_bertscore``,
            ``eval_bartscore`` and ``eval_new_ngram``.

    Returns:
        ``(scores, names)``: two parallel lists, one entry per computed metric.
        The new-n-gram evaluation is run for its side effects only and adds
        nothing to either list.
    """
    score_arrays, score_names = [], []

    # ROUGE contributes three metrics at once
    if args.eval_rouge:
        rouge_1, rouge_2, rouge_l = rouge_eval("true labels", val_texts, val_summaries, val_labels, args)
        score_arrays.extend([rouge_1, rouge_2, rouge_l])
        score_names.extend(["ROUGE-1", "ROUGE-2", "ROUGE-L"])

    # BERTScore
    if args.eval_bertscore:
        score_arrays.append(bertscore_eval(val_summaries, val_labels, args))
        score_names.append("BERTScore")

    # BARTScore
    if args.eval_bartscore:
        score_arrays.append(bartscore_eval(val_summaries, val_labels, args))
        score_names.append("BARTScore")

    # Abstractiveness
    if args.eval_new_ngram:
        new_ngram_eval(val_texts, val_summaries, args)

    return score_arrays, score_names

def rouge_eval(mode, val_texts, val_summaries, val_labels, args):
    """Score every summary against its reference with ROUGE and print the averages.

    Args:
        mode: free-text tag inserted into the banner line.
        val_texts: not used by this function.
        val_summaries: summaries to evaluate; each one is passed through
            ``pre_rouge_processing`` before scoring.
        val_labels: reference summaries, index-aligned with ``val_summaries``.
        args: configuration; ``args.stemmer`` toggles stemming in the scorer.

    Returns:
        Three numpy arrays with the ROUGE-1, ROUGE-2 and ROUGE-Lsum F-measures
        of every summary, scaled by 100.
    """
    print("\n", "*"*10, "1 - ROUGE evaluation with {}".format(mode), "*"*10)
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeLsum'], use_stemmer = args.stemmer)
    all_r1s, all_r2s, all_rls = [], [], []
    for i in range(len(val_summaries)):
        summary = pre_rouge_processing(val_summaries[i], args)
        r1, r2, rl = get_rouge_scores(summary, val_labels[i], scorer, args)
        all_r1s.append(r1)
        all_r2s.append(r2)
        all_rls.append(rl)
    all_r1s = 100 * np.array(all_r1s)
    all_r2s = 100 * np.array(all_r2s)
    all_rls = 100 * np.array(all_rls)
    mean_r1 = np.mean(all_r1s)
    mean_r2 = np.mean(all_r2s)
    mean_rl = np.mean(all_rls)
    mean_r = (mean_r1 + mean_r2 + mean_rl) / 3
    # The spread reported next to each mean is a standard deviation (np.std),
    # so it is labelled "std" rather than "var".
    print("Mean R: {:.4f}, R-1: {:.4f} (std: {:.4f}), R-2: {:.4f} (std: {:.4f}), R-L: {:.4f} (std: {:.4f})".format(
        mean_r, mean_r1, np.std(all_r1s), mean_r2, np.std(all_r2s), mean_rl, np.std(all_rls)))

    return all_r1s, all_r2s, all_rls

def get_rouge_scores(summary, label, scorer, args):
  """Return the (ROUGE-1, ROUGE-2, ROUGE-Lsum) F-measures of ``summary`` against ``label``.

  ``scorer.score`` is called with the reference first and the summary second.
  ``args`` is accepted but not used.
  """
  result = scorer.score(label, summary)
  return tuple(result[metric].fmeasure for metric in ("rouge1", "rouge2", "rougeLsum"))


In [None]:
#UTILS

def rank_array(t):
  """Return the descending-order rank of every element of ``t``.

  The largest value gets rank 0, the next largest rank 1, and so on. Tied
  values receive consecutive ranks in order of appearance, so the ranks are
  always a permutation of ``0 .. len(t) - 1``.

  Args:
    t: one-dimensional array-like of numbers.

  Returns:
    A float64 numpy array of the same length as ``t`` holding the ranks.
  """
  values = np.asarray(t, dtype=float)
  # A stable sort of the negated values orders them from largest to smallest
  # while keeping tied elements in their input order; this replaces a
  # quadratic scan over all (element, position) pairs.
  order = np.argsort(-values, kind="stable")
  ranks = np.zeros(len(values))
  ranks[order] = np.arange(len(values))
  return ranks

def nested_detach(tensors):
  """Detach every tensor found in an arbitrarily nested list/tuple structure.

  Lists and tuples are rebuilt with their own type; anything else is expected
  to provide ``.detach()`` and is returned detached.
  """
  if not isinstance(tensors, (list, tuple)):
    return tensors.detach()

  container_type = type(tensors)
  return container_type(nested_detach(item) for item in tensors)

def build_tokenizer(model_name=None, cache_dir=None):
  """Load the RoBERTa fast tokenizer.

  Args:
    model_name: checkpoint name or path handed to ``from_pretrained``. When
      omitted, the notebook-level ``args.model`` is used.
    cache_dir: directory for downloaded files. When omitted, the notebook-level
      ``args.cache_dir`` is used if it is set, otherwise ``None`` is passed.

  Returns:
    The ``RobertaTokenizerFast`` instance.
  """
  print("\nUsing RoBERTa tokenizer")
  if model_name is None:
    # zero-argument calls keep reading the model name from the global args
    model_name = args.model
  if cache_dir is None:
    # tolerate a missing global `args` / missing attribute when a model name was given
    cache_dir = getattr(globals().get("args"), "cache_dir", None)
  return RobertaTokenizerFast.from_pretrained(model_name, cache_dir = cache_dir)

def build_model(args):
  """Load the pretrained RoBERTa encoder named by ``args.model``.

  ``args.cache_dir`` is forwarded to ``from_pretrained`` as the cache directory.
  """
  print("\nUsing RoBERTa model")
  return RobertaModel.from_pretrained(args.model, cache_dir = args.cache_dir)

def build_optimizer(model, args):
  """Create the optimizer selected by ``args.optimizer``.

  Args:
    model: module whose ``parameters()`` are optimised.
    args: configuration providing ``optimizer`` ("adam" or "adamw"), ``lr``
      and ``wd`` (weight decay).

  Returns:
    The optimizer instance, or ``None`` when ``args.optimizer`` is neither
    "adam" nor "adamw".
  """
  if args.optimizer == "adam":
    banner, optimizer_cls = "\nUsing Adam", torch.optim.Adam
  elif args.optimizer == "adamw":
    banner, optimizer_cls = "\nUsing AdamW", torch.optim.AdamW
  else:
    return None

  print(banner)
  return optimizer_cls(model.parameters(), lr=args.lr, weight_decay=args.wd)

def build_scheduler(optimizer, train_steps, args):
    """Create the learning-rate scheduler selected by ``args.scheduler``.

    Only "linear_warmup" is recognised: the warmup lasts
    ``int(args.warmup_ratio * train_steps)`` steps out of ``train_steps``.
    Any other value yields ``None``.
    """
    if args.scheduler != "linear_warmup":
        return None

    print("\nUsing linear warmup scheduler")
    n_warmup = int(args.warmup_ratio * train_steps)
    return get_linear_schedule_with_warmup(optimizer, n_warmup, train_steps)

def check_data_pipe(loaders):
  """Print a short preview of the first batch produced by each loader.

  For every loader, a separator line is followed by the batch's language tag
  and the first ten token-id columns, first for the text and then for the
  summary. Loaders that yield nothing are skipped silently.
  """
  nothing = object()
  fields = (("text_lang", "text_inputs"), ("summary_lang", "summary_inputs"))
  for loader in loaders:
    first_batch = next(iter(loader), nothing)
    if first_batch is nothing:
      continue
    print("*"*50)
    for lang_key, inputs_key in fields:
      print(first_batch[lang_key])
      print(first_batch[inputs_key]["input_ids"][:,:10])

def display_losses(mode, losses):
  """Print the latest loss together with the lowest loss and the 1-based iteration that reached it."""
  loss_values = np.array(losses)
  template = "Current {} loss is {:.4f}, best {} loss is {:.4f} achieved at iter {} / {}"
  print(template.format(mode, losses[-1], mode, np.min(loss_values), np.argmin(loss_values) + 1, len(losses)))


def display_scores(mode, scores):
  """Print the best value of every tracked metric and the 1-based iteration that reached it.

  Metrics whose name contains "loss" are minimised; all others are maximised.
  """
  for name, history in scores.items():
    values = np.array(history)
    if "loss" in name:
      best_of, position_of = np.min, np.argmin
    else:
      best_of, position_of = np.max, np.argmax
    print("Best {} {} is {:.4f} achieved at iter {} / {}".format(
      mode, name, best_of(values), position_of(values) + 1, len(history)))


def compute_r1s(sents):
  """Leave-one-out ROUGE-1 of every sentence against the rest of the text.

  For each sentence, the remaining sentences are joined with single spaces and
  used as the reference; the sentence itself is the scored text.

  Returns:
    A numpy array with one ROUGE-1 F-measure (scaled by 100) per sentence.
  """
  scorer = rouge_scorer.RougeScorer(['rouge1'], use_stemmer=False)

  def leave_one_out_r1(position):
    remaining_text = " ".join(sents[:position] + sents[(position + 1):])
    return 100 * scorer.score(remaining_text, sents[position])["rouge1"].fmeasure

  return np.array([leave_one_out_r1(position) for position in range(len(sents))])


def check_scores(dataset):
  """Return the mean, over all samples, of the best candidate score (the oracle score).

  Args:
    dataset: object exposing ``scored_summaries``; element ``[i][1]`` holds the
      candidate scores of sample ``i``, and its overall maximum is used.

  Returns:
    The mean of the per-sample maxima.
  """
  # The notebook imports the `tqdm` *module* at the top, which is not callable;
  # import the progress-bar callable itself, locally so the fix is self-contained.
  from tqdm import tqdm as progress_bar

  best_scores = []
  for i in progress_bar(range(len(dataset.scored_summaries))):
    candidate_scores = dataset.scored_summaries[i][1]
    best_scores.append(np.max(np.array(candidate_scores)))

  return np.mean(best_scores)


In [None]:
#HyperParameters for my implementation

# candidate generation and scoring
gen_method = "diverse_beam_search"
scoring_methods = ["rouge_1", "rouge_2", "rouge_l"]
filter_out_duplicates = True
sep_symbol = "[SEP]"
n_beams = 15
n_tasks = 3
use_shared_bottom = True

# layer widths
hidden_size = 1024
bottom_hidden_size = 1024
expert_hidden_size = 1024
tower_hidden_size = 1024

# mixture of experts
num_experts = 6
k = 3 #choices for topk

max_length = 512
sharp_pos = False
use_aux_loss = False
# the original misspelling is kept as an alias so existing references keep working
use_aux_los = use_aux_loss

# upper bounds on the number of samples kept per split
max_train_size = 1000000
max_validation_size = 10000
max_test_size = 10000

# split sizes (the training set comes in two parts)
train_sizes = [143000,144113]
val_size= 13368
test_size = 11490
# plural spelling kept as an alias of the singular name used alongside val_size
test_sizes = test_size
max_source_length = 384
max_summary_length = 128
eval_threshold = 500

pegasus_model_names = ["pegasus_second_half", "pegasus_first_half"]

In [None]:
#demo HyperParams(uses argparser)
import argparse

parser = argparse.ArgumentParser()
# parse an explicit empty argument list so the kernel's own command line is ignored
args = parser.parse_args(args = [])

# use the GPU when one is present, otherwise fall back to the CPU
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# candidate generation / scoring
args.generation_methods = ["diverse_beam_search"]
args.num_beams = 15
args.scoring_methods = ["rouge_1", "rouge_2", "rouge_l"]
args.filter_out_duplicates = True
args.sep_symbol = "[SEP]"

# reranker architecture
args.n_tasks = 3
args.hidden_size = 1024
args.use_shared_bottom = True
args.bottom_hidden_size = 1024
args.num_experts = 6
args.k = 3
args.expert_hidden_size = 1024
args.tower_hidden_size = 1024

# training objective switches
args.sharp_pos = False
args.use_aux_loss = False

# sequence length limits (tokens)
args.max_length = 512
args.max_source_length = 384
args.max_summary_length = 128

In [None]:
import gc

from rouge_score import rouge_scorer

# The demo needs a source document (`text`) and its reference summary (`label`)
# from an earlier cell; fail before the large model download if they are absent.
_missing_inputs = [name for name in ("text", "label") if name not in globals()]
if _missing_inputs:
    raise NameError("Define {} before running the demo cell".format(" and ".join(_missing_inputs)))

base_model_name = "google/pegasus-cnn_dailymail"
base_tokenizer = PegasusTokenizer.from_pretrained(base_model_name)
base_model = PegasusForConditionalGeneration.from_pretrained(base_model_name)
base_model = base_model.to(args.device)
base_model = base_model.eval()

# candidates
tok_text = base_tokenizer(text, return_tensors="pt", padding="max_length", max_length=1024)
# padding never shortens a sequence, so cut over-long inputs to the model limit
tok_text["input_ids"] = tok_text["input_ids"][:, :1024]
tok_text["attention_mask"] = tok_text["attention_mask"][:, :1024]
with torch.no_grad():
    # diverse beam search with one beam per group; every beam is returned as a candidate
    generated = base_model.generate(
        input_ids=tok_text["input_ids"].to(args.device),
        attention_mask=tok_text["attention_mask"].to(args.device),
        num_beams=args.num_beams,
        num_beam_groups=args.num_beams,
        diversity_penalty=1.0,
        num_return_sequences=args.num_beams,
        repetition_penalty=1.0,
        length_penalty=0.8,
        no_repeat_ngram_size=3
    )
candidates = base_tokenizer.batch_decode(generated, skip_special_tokens=True)
scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeLsum'], use_stemmer = True)
print("\nSummary candidates:")
for j in range(len(candidates)):
    candidates[j] = candidates[j].replace("<n>", " ")
    # rougeLsum expects one sentence per line
    candidates[j] = "\n".join(sent_tokenize(candidates[j]))
    rouge_scores = scorer.score(label, candidates[j])
    r1 = 100 * rouge_scores["rouge1"].fmeasure
    r2 = 100 * rouge_scores["rouge2"].fmeasure
    rl = 100 * rouge_scores["rougeLsum"].fmeasure
    mean_r = (r1 + r2 + rl) / 3
    print("\nCandidates {} (Mean R: {:.2f}, R-1: {:.2f}, R-2: {:.2f}, R-L: {:.2f})".format(j, mean_r, r1, r2, rl))
    # back to a single line for display and for the reranker input
    candidates[j] = candidates[j].replace("\n", " ")
    print(candidates[j])

# release the generator before the reranker is loaded
del base_model
del tok_text
del generated
gc.collect()

# SummaReranker
# model
model_name = "roberta-large"
# both classes must be built with from_pretrained: their constructors expect
# vocabulary files / a config object, not a checkpoint name
tokenizer = RobertaTokenizerFast.from_pretrained(model_name)
model = RobertaModel.from_pretrained(model_name)
model = model.to(args.device)
summareranker_model = ModelMultitaskBinary(model, tokenizer, args)
summareranker_model = summareranker_model.to(args.device)
# NOTE(review): machine-specific absolute path - point it at your own checkpoint.
# torch.load unpickles the file, so only load checkpoints from a trusted source.
summareranker_model_path = "/data/mathieu/2nd_stage_summarization/4_supervised_multitask_reranking/saved_models/cnndm/multitask_3_tasks_ablation_8/checkpoint-12500/pytorch_model.bin"
# map_location lets a checkpoint saved on another device load onto args.device
summareranker_model.load_state_dict(torch.load(summareranker_model_path, map_location=args.device))
summareranker_model = summareranker_model.eval()
# prepare the data
text_inputs = tokenizer(text, return_tensors="pt", max_length=args.max_source_length, padding='max_length')
text_inputs["input_ids"] = text_inputs["input_ids"][:, :args.max_source_length]
text_and_candidates_ids, text_and_candidates_masks = [], []
for j in range(len(candidates)):
    candidate = candidates[j]
    candidate_inputs = tokenizer(candidate, return_tensors="pt", max_length=args.max_summary_length, padding='max_length')
    candidate_inputs["input_ids"] = candidate_inputs["input_ids"][:, :args.max_summary_length]
    # decode the length-limited text and candidate and join them with the separator
    block = tokenizer.batch_decode(text_inputs["input_ids"], skip_special_tokens = True)[0] + args.sep_symbol + tokenizer.batch_decode(candidate_inputs["input_ids"], skip_special_tokens = True)[0]
    text_and_candidate = tokenizer(block, return_tensors="pt", padding="max_length", max_length=args.max_length)
    ids = text_and_candidate["input_ids"][:, :args.max_length]
    mask = text_and_candidate["attention_mask"][:, :args.max_length]
    text_and_candidates_ids.append(ids)
    text_and_candidates_masks.append(mask)
# stack the candidates and add a leading batch dimension of size 1
text_and_candidates_ids = torch.cat(text_and_candidates_ids, 0).unsqueeze(0)
text_and_candidates_ids = text_and_candidates_ids.to(args.device)
text_and_candidates_masks = torch.cat(text_and_candidates_masks, 0).unsqueeze(0)
text_and_candidates_masks = text_and_candidates_masks.to(args.device)
# inference
# a mode tensor that sums to zero selects the model's evaluation branch
mode = torch.tensor([0]).to(args.device)
scores = torch.randn(1, len(args.scoring_methods), len(candidates))
scores = scores.to(args.device) # create random candidate scores
with torch.no_grad():
    output = summareranker_model(
        mode,
        text_and_candidates_ids,
        text_and_candidates_masks,
        scores
    )
candidate_scores = output["overall_predictions"][0]
print("\nSummaReranker predicted scores:")
for j in range(len(candidates)):
    print("Candidate {} has score: {:.4f}".format(j, candidate_scores[j]))
best_idx = np.argmax(np.array(candidate_scores))
print("\nSummaReranker output summary is candidate #{}".format(best_idx))
print(candidates[best_idx])

NameError: ignored

Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: rouge_score
  Building wheel for rouge_score (setup.py) ... [?25l[?25hdone
  Created wheel for rouge_score: filename=rouge_score-0.1.2-py3-none-any.whl size=24933 sha256=2aac40109fbc4ca64f9e63fddbc503649bdbc14faea993a26de95dd112ff5f8f
  Stored in directory: /root/.cache/pip/wheels/5f/dd/89/461065a73be61a532ff8599a28e9beef17985c9e9c31e541b4
Successfully built rouge_score
Installing collected packages: rouge_score
Successfully installed rouge_score-0.1.2


In [None]:
#IN WORK, don't run after preceding block
#Dataset __init__

save_path = data_pretrained + "pretrained/"

#device choice
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
print("Using device: {}".format(device))

#tokenizer
# build_tokenizer reads the model name and cache directory from `args`
args.model = "roberta-large"
args.cache_dir = getattr(args, "cache_dir", None)
tokenizer = build_tokenizer()

#datasets
# NOTE(review): train_datasets / val_dataset / test_dataset, args.highlights,
# args.train_bs and args.max_{train,val,test}_size are not set in the cells
# shown here - confirm they are defined before this cell runs.
datasets = []
splits = [(train_datasets, train_sizes), (val_dataset, val_size), (test_dataset, test_sizes)]
for split_idx, (split_name, split_size) in enumerate(splits):
    # NOTE(review): only the first split is treated as training data - confirm
    train = split_idx == 0
    texts, summaries, scored_summaries = load_data(split_name, split_size, args, individual_txt=args.highlights, train=train)
    print("loaded new data!", len(texts), len(summaries), len(scored_summaries), len(scored_summaries[0]),
          len(scored_summaries[0][0]), len(scored_summaries[0][1]), len(scored_summaries[0][1][0]))

    mode = "train" if train else "val"
    dataset = MultitaskRerankingDatasetTrain(mode, tokenizer, texts, scored_summaries, summaries, args)
    datasets.append(dataset)
    print("There are {} {} batches".format(int(len(dataset.texts) / args.train_bs), split_name))


def _truncate_dataset(dataset, max_size):
    """Keep only the first ``max_size`` samples of ``dataset`` (in place) and return it."""
    dataset.texts = dataset.texts[:max_size]
    dataset.scored_summaries = dataset.scored_summaries[:max_size]
    dataset.labels = dataset.labels[:max_size]
    return dataset


# NOTE: val_dataset / test_dataset are rebound from split identifiers to dataset
# objects here, so this cell cannot be re-run without redefining them first.
train_dataset = _truncate_dataset(datasets[0], args.max_train_size)
val_dataset = _truncate_dataset(datasets[1], args.max_val_size)
test_dataset = _truncate_dataset(datasets[2], args.max_test_size)

print(train_dataset.texts[0])
print("*" * 30)
print(val_dataset.texts[0])
print("*" * 30)
print(test_dataset.texts[0])

In [None]:
# check oracle
# Run check_scores on each split and report the three results on one line.
# NOTE(review): presumably the best score reachable by picking the top
# candidate per example; confirm against check_scores.
m_train_score, m_val_score, m_test_score = [
    check_scores(split) for split in (train_dataset, val_dataset, test_dataset)
]
print(
    "\nOracle - train: {:.4f}, val: {:.4f}, test: {:.4f}".format(
        m_train_score, m_val_score, m_test_score
    )
)


In [None]:
# model __init__
# Wrap the pretrained model in the multitask reranker and move it to `device`.
pretrained_model = build_model(args)
model = ModelMultitaskBinary(pretrained_model, tokenizer, args)
n_params = sum(
    weight.numel() for weight in model.parameters() if weight.requires_grad
)
print("\nThe model has {} trainable parameters".format(n_params))
model = model.to(device)

# Every tunable value is read from `args`; the literals below are fixed for
# this notebook (always train + evaluate, never predict, no step logging).
training_kwargs = {
    "output_dir": args.save_model_path,  # will be changed
    "overwrite_output_dir": True,
    "do_train": True,
    "do_eval": True,
    "do_predict": False,
    "evaluation_strategy": args.evaluation_strategy,
    "eval_steps": args.eval_every,
    "save_total_limit": args.n_checkpoints_to_save,
    "num_train_epochs": args.n_epochs,
    "adafactor": args.adafactor,
    "lr_scheduler_type": args.scheduler,
    "warmup_ratio": args.warmup_ratio,
    "per_device_train_batch_size": args.train_bs,
    "per_device_eval_batch_size": args.inference_bs,
    "learning_rate": args.lr,
    "gradient_accumulation_steps": args.gradient_accumulation_steps,
    "weight_decay": args.wd,
    "max_grad_norm": args.gradient_clipping,
    "logging_strategy": "no",
    # Checkpoints are saved on the same schedule as evaluation.
    "save_strategy": args.evaluation_strategy,
    "save_steps": args.eval_every,
    "metric_for_best_model": args.metric_for_best_model,
    "fp16": args.fp16,
    # Reload the best checkpoint after training; higher metric values win.
    "load_best_model_at_end": True,
    "greater_is_better": True,
    "disable_tqdm": False,
    "deepspeed": args.deepspeed,
    "sharded_ddp": args.sharded_ddp,
    "local_rank": args.local_rank,
}
train_args = TrainingArguments(**training_kwargs)
data_collator = default_data_collator

# The trainer trains on the train split and evaluates on the val split.
trainer = CustomTrainer(
    model=model,
    args=train_args,
    compute_metrics=compute_metrics,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
)


In [None]:
# Training
# Optionally evaluate before any training step so later results have a baseline.
if args.eval_epoch_0:
    results = trainer.evaluate()
    print("*" * 50, "Init VAL results:")
    print(results)
    model.moe.display_tasks_probs()

# Either train from the current weights or, when training is disabled,
# optionally restore previously saved weights.
if args.train:
    trainer.train()
elif args.load_model:
    # map_location places the loaded tensors on the active device, so a
    # checkpoint saved on a GPU still loads on a CPU-only runtime.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(args.load_model_path, map_location=device))
    print("Loaded the model weights!", args.load_model_path)

# validate with the best model
results = trainer.evaluate()
print("\n", "*" * 50, "BEST VAL RESULTS")
print(results)
model.moe.display_tasks_probs()

# test results
test_results = trainer.predict(test_dataset)
print("\n", "*" * 50, "TEST RESULTS:")
# Index 2 of the prediction output; presumably the metrics dict, verify
# against CustomTrainer.predict.
print(test_results[2])
model.moe.display_tasks_probs()