In [None]:
!pip install transformers
!pip install datasets

In [None]:
# The `rouge` package provides the `Rouge` scorer imported further down.
!pip install rouge

In [None]:
from datasets import load_dataset, load_metric
# Load the "train" split of the Hub dataset; it is the input to the oracle
# sentence selection performed later in this notebook.
train_dataset  = load_dataset("LA1512/train-20K-4096")["train"]

In [None]:
from nltk import tokenize
import nltk
import numpy as np
# Sentence-splitting data for `tokenize.sent_tokenize`. Recent NLTK releases
# load the `punkt_tab` resource instead of the pickled `punkt` one, so fetch
# both to work with whichever NLTK version is installed.
nltk.download('punkt')
nltk.download('punkt_tab')

In [None]:
from transformers import AutoTokenizer
# Tokenizer of the BART checkpoint; process_data_ORC uses it to count tokens
# of whole articles and of individual sentences.
model_name = "facebook/bart-large-cnn"
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:
from rouge import Rouge
# Module-level scorer instance shared by get_rouge2recall_scores.
rouge_pltrdy = Rouge()

def get_rouge2recall_scores(sentences, reference, oracle_type):
    """Score every sentence by its ROUGE-2 recall against ``reference``.

    Both the sentences and the reference are lower-cased before scoring.
    A tiny bias is added to each score so that sentences with equal recall
    still get distinct values when the result is ranked:

    * ``'padlead'`` -- deterministic bias ``(N - i) * 1e-12`` for position
      ``i``, so among tied sentences the earlier one gets the higher value.
    * ``'padrand'`` -- zero-mean Gaussian noise with scale ``1e-10`` drawn
      from the global NumPy random state, so ties are ordered randomly.

    Args:
        sentences: sequence of sentence strings to score.
        reference: reference text the sentences are compared against.
        oracle_type: ``'padrand'`` or ``'padlead'``.

    Returns:
        numpy.ndarray of shape ``(len(sentences),)`` holding recall + bias.

    Raises:
        ValueError: if ``oracle_type`` is not one of the supported values.
    """
    if oracle_type not in ('padrand', 'padlead'):
        raise ValueError("oracle_type must be padrand or padlead")

    # The scorer is case sensitive, so compare everything in lower case.
    reference = reference.lower()
    scores = np.zeros(len(sentences), dtype=float)
    for i, sent in enumerate(sentences):
        try:
            rouge_scores = rouge_pltrdy.get_scores(sent.lower(), reference)
            scores[i] = rouge_scores[0]['rouge-2']['r']
        except ValueError:
            # The scorer rejected this sentence; treat it as no overlap.
            scores[i] = 0.0
        except RecursionError:
            # Sentence too long for the scorer; fall back to a neutral 0.5.
            scores[i] = 0.5

    n = len(scores)
    if oracle_type == 'padlead':
        biases = np.arange(n, 0, -1) * 1e-12
    else:  # 'padrand' -- the only other value accepted above
        biases = np.random.normal(scale=1e-10, size=(n,))
    return scores + biases

In [None]:
# Token budget: process_data_ORC keeps an article whole when its token count
# is below this value and otherwise selects its top-ranked sentences.
max_length = 4096
# Tie-breaking mode passed to get_rouge2recall_scores ("padrand" or "padlead").
oracle_type = "padrand"

In [None]:
def process_data_ORC(batch):
    """Add oracle content-selection fields to one dataset example.

    Reads ``batch["article"]`` and ``batch["abstract"]`` and writes:

    * ``batch["article_CS"]`` -- the article itself when its token count is
      below ``max_length``; otherwise the selected sentences, in their
      original order, joined by single spaces.
    * ``batch["ext_target"]`` -- one 0/1 flag per sentence of the article,
      1 meaning the sentence was kept (all ones when the article is kept
      whole).

    For over-length articles, sentences are ranked by
    ``get_rouge2recall_scores`` (highest first) and taken in that order. The
    budget is checked *before* each sentence is added, so the selection may
    exceed ``max_length`` by the length of the last sentence taken.

    ``oracle_type`` and the ROUGE scorer are only consulted for over-length
    articles; short articles skip the per-sentence scoring entirely.

    Args:
        batch: a single example (mapping) with "article" and "abstract".

    Returns:
        The same ``batch`` object with the two fields above added.
    """
    input_text = batch["article"]
    sentences = tokenize.sent_tokenize(input_text)
    n_sentences = len(sentences)

    # [1:-1] drops the first and last ids of the encoding before counting
    # (presumably the tokenizer's start/end special tokens).
    if len(tokenizer.encode(input_text)[1:-1]) < max_length:
        batch["article_CS"] = input_text
        batch["ext_target"] = [1] * n_sentences
        return batch

    # Only over-length articles need the (expensive) per-sentence scoring.
    selection_score = get_rouge2recall_scores(sentences, batch["abstract"], oracle_type)
    rank = np.argsort(selection_score)[::-1]

    kept = set()
    total_length = 0
    for r in rank:
        if total_length >= max_length:
            break
        total_length += len(tokenizer.encode(sentences[r])[1:-1])
        kept.add(int(r))

    # Restore document order for the text; the set gives O(1) membership
    # tests when building the per-sentence flags.
    batch["article_CS"] = " ".join(sentences[j] for j in sorted(kept))
    batch["ext_target"] = [1 if i in kept else 0 for i in range(n_sentences)]
    return batch

In [None]:
# Apply process_data_ORC to the training split, which adds the
# "article_CS" and "ext_target" fields to each example.
train_dataset_change = train_dataset.map(process_data_ORC)

In [None]:
# Hub client used by the login and upload cells that follow.
!pip install huggingface_hub --q


In [None]:
# NOTE(review): "token acess" is a placeholder. Do not paste a real token into
# a saved or shared notebook; prefer an environment variable or an
# interactive login so the credential never lands in the file.
!huggingface-cli login --token "token acess"

In [None]:
# Publish the processed split to the Hub.
# NOTE(review): the repo name says "1024" while max_length in this notebook is
# 4096 — confirm the intended target name before uploading.
train_dataset_change.push_to_hub("LA1512/train_pubmed_ORC_1024_20K")