In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from datasets import load_dataset, load_from_disk, DatasetDict, Dataset
import numpy as np
from evaluate import load
from tqdm.auto import tqdm

In [2]:
import nltk
import numpy as np
nltk.download('punkt')

[nltk_data] Downloading package punkt to /home/andrew/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [3]:
metric = load("rouge")

In [4]:
max_input = 1024
max_target = 128
# model_checkpoint = "./bart-large-cnn-finetuned/checkpoint-10650/"
# model_checkpoint = "./BART-SFT/checkpoint-900/"
# model_checkpoint = "./BART-SFT2/checkpoint-900"
model_checkpoint = "./BART-SFT-r1/checkpoint-4200/"

In [5]:
raw_datasets = load_from_disk("../data/hf_dataset")
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['summary', 'article', 'article_bias', 'id', 'summary_bias'],
        num_rows: 4664
    })
    validation: Dataset({
        features: ['summary', 'article', 'article_bias', 'id', 'summary_bias'],
        num_rows: 542
    })
    test: Dataset({
        features: ['summary', 'article', 'article_bias', 'id', 'summary_bias'],
        num_rows: 602
    })
})

In [6]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [7]:
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

In [8]:
def preprocess_data(examples):
    # get all the articles, prepend each with "bias;"
    inputs = [
        f"{bias}; {article}"
        for bias, article in zip(examples["summary_bias"], examples["article"])
    ]
    # tokenize the inputs
    model_inputs = tokenizer(
        inputs, max_length=max_input, padding="max_length", truncation=True
    )

    # tokenize the summaries, DO NOT TRUNCATE (unlike training)
    targets = tokenizer(
        examples["summary"],
        max_length=None,
        padding=False,
        truncation=False,
    )

    # set labels
    model_inputs["labels"] = targets["input_ids"]
    # return the tokenized data
    # input_ids, attention_mask and labels
    return model_inputs


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Rouge expects a newline after each sentence
    decoded_preds = [
        "\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds
    ]
    decoded_labels = [
        "\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels
    ]

    # Note that other metrics may not have a `use_aggregator` parameter
    # and thus will return a list, computing a metric for each sentence.
    result = metric.compute(
        predictions=decoded_preds,
        references=decoded_labels,
        use_stemmer=True,
        use_aggregator=True,
    )
    # Extract a few results
    result = {key: value * 100 for key, value in result.items()}

    # Add mean generated length
    prediction_lens = [
        np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions
    ]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}

In [9]:
tokenized_data = raw_datasets.map(preprocess_data, batched=True)

In [10]:
tokenizer.decode(tokenized_data["test"][0]['input_ids'], skip_special_tokens=True) # example input



In [11]:
tokenizer.decode(tokenized_data["test"][0]['labels'], skip_special_tokens=True) # example target summary

'QAnon, a conspiracy theory claiming that Donald Trump is fighting against nefarious forces, was condemned by a resolution in the House, but not unanimously. Some Republicans, including Reps. Jodey Arrington, Michael Burgess, Bill Flores, and Brian Babin of Texas; Rob Bishop of Utah; Mo Brooks of Alabama; Buddy Carter and Drew Ferguson of Georgia; Warren Davidson of Ohio; Jeff Duncan and Ralph Norman of South Carolina; Paul Gosar of Arizona; Mike Kelly and Scott Perry of Pennsylvania; Tom Tiffany of Wisconsin; Daniel Webster of Florida; and Steve King of Iowa, voted against it or did not vote at all. QAnon has been classified as a domestic-terror threat by the FBI for its threatening behavior towards those who do not believe this theory, and one of its followers will soon be elected to Congress. Despite having received bipartisan support, the resolution was not supported by all members of the House.\n'

In [15]:
batch_size = 4
args = Seq2SeqTrainingArguments(
    "test",
    evaluation_strategy="steps",
    # eval_steps=150,
    # warmup_steps=500,
    # learning_rate=2e-5,
    # per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # weight_decay=0.01,
    # log_level="info",
    # logging_dir="./log",
    # logging_first_step=True,
    # logging_steps=5,
    # save_total_limit=3,
    # save_strategy="steps",
    # save_steps=150,
    # load_best_model_at_end=True,
    # num_train_epochs=25,
    predict_with_generate=True,
    fp16=True,
)

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_data["train"],
    eval_dataset=tokenized_data["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

using `logging_steps` to initialize `eval_steps` to 500
PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).
Using auto half precision backend


In [34]:
test = tokenized_data['test']

In [16]:
preds = trainer.predict(tokenized_data["test"])

The following columns in the test set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: summary, id, article, summary_bias, article_bias. If summary, id, article, summary_bias, article_bias are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Prediction *****
  Num examples = 602
  Batch size = 4


In [20]:
preds.metrics # test scores

{'test_loss': 2.259944438934326,
 'test_rouge1': 42.931,
 'test_rouge2': 15.2797,
 'test_rougeL': 25.4485,
 'test_rougeLsum': 37.9946,
 'test_gen_len': 138.211,
 'test_runtime': 834.4269,
 'test_samples_per_second': 0.721,
 'test_steps_per_second': 0.181}

In [28]:
predicted_texts = [tokenizer.decode(token_ids, skip_special_tokens=True) for token_ids in preds.predictions]

In [37]:
# get topic ids in test split
ids = set()
for ex in raw_datasets['test']:
    ids.add(ex['id'])
ids = list(ids)

# get indexes of ex with same topic id
id = ids[1]
indexes = []
for idx, item in enumerate(test):
    if item['id'] == id:
        indexes.append(idx)

In [40]:
# Show a few example outputs
for i in indexes:  # Just show the first 5 examples
    print(f"Article Bias: {test[i]['article_bias']}, Summary Bias: {test[i]['summary_bias']}, ID: {test[i]['id']}")
    print(f"Article: {test[i]['article']}")
    print()
    print(f"Target Summary: {test[i]['summary']}")
    print()
    print(f"Predicted Summary: {predicted_texts[i]}\n")
    print('*' * 100, '\n')

Article Bias: center, Summary Bias: center, ID: 3711
Article: House Passes Extensive Election And Campaign Finance Overhaul Bill
The House passed an extensive bill Friday that would overhaul the way Americans vote and take aim at the money currently flowing through the U.S. political system.
The bill was dubbed the "For The People Act" by House Democrats who want election accessibility and weeding out corruption to be core tenets of their majority agenda the next two years. The bill passed along straight party lines, 234-193.
"For months, for years, really for decades, millions of Americans have been looking at Washington and feeling like they've been left behind," said Rep. John Sarbanes, D-Md., the lead author of the bill. "Too many Americans have faced this challenge where getting to the ballot box every two years is like getting through an obstacle course."
House Democrats gathered on the Capitol steps moments before the vote to celebrate the impending passage.
The more than 500-pa

## filter out examples with same article-summary bias

In [41]:
include_indexes = []
for idx, example in enumerate(test):
    if example['article_bias'] != example['summary_bias']:
        include_indexes.append(idx)

In [56]:
preds_filtered = preds.predictions[include_indexes]
labels_filtered = preds.label_ids[include_indexes]

In [58]:
compute_metrics((preds_filtered, labels_filtered))

{'rouge1': 38.6627,
 'rouge2': 11.2963,
 'rougeL': 21.7057,
 'rougeLsum': 33.9749,
 'gen_len': 138.2356}

# Train/Val

In [None]:
train_outputs = trainer.predict(tokenized_data["train"])

The following columns in the test set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: summary, id, article, summary_bias, article_bias. If summary, id, article, summary_bias, article_bias are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Prediction *****
  Num examples = 4664
  Batch size = 4


In [24]:
train_outputs.metrics

{'test_loss': 1.4919403791427612,
 'test_rouge1': 46.1617,
 'test_rouge2': 19.0326,
 'test_rougeL': 28.409,
 'test_rougeLsum': 41.2628,
 'test_gen_len': 138.0838,
 'test_runtime': 6421.2517,
 'test_samples_per_second': 0.726,
 'test_steps_per_second': 0.182}

In [None]:
val_outputs = trainer.predict(tokenized_data["validation"])

The following columns in the test set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: summary, id, article, summary_bias, article_bias. If summary, id, article, summary_bias, article_bias are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Prediction *****
  Num examples = 542
  Batch size = 4


In [None]:
val_outputs.metrics

{'test_loss': 2.2837963104248047,
 'test_rouge1': 43.2729,
 'test_rouge2': 15.0279,
 'test_rougeL': 25.0943,
 'test_rougeLsum': 38.3158,
 'test_gen_len': 138.4705,
 'test_runtime': 746.3585,
 'test_samples_per_second': 0.726,
 'test_steps_per_second': 0.182}