In [1]:
import transformers
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from datasets import load_dataset, load_from_disk, DatasetDict
import numpy as np
from evaluate import load



  from .autonotebook import tqdm as notebook_tqdm


In [28]:
max_input = 1024
max_target = 128
model_checkpoint = "facebook/bart-large-cnn"

In [36]:
raw_datasets = load_from_disk("../data/hf_dataset")

In [37]:
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['summary', 'article_bias', 'id', 'article', 'summary_bias'],
        num_rows: 4704
    })
    validation: Dataset({
        features: ['summary', 'article_bias', 'id', 'article', 'summary_bias'],
        num_rows: 523
    })
    test: Dataset({
        features: ['summary', 'article_bias', 'id', 'article', 'summary_bias'],
        num_rows: 581
    })
})

In [5]:
# Define the sample sizes
train_sample_size = 100  # Adjust as needed
val_sample_size = 20     # Adjust as needed
test_sample_size = 20    # Adjust as needed

# Sample the datasets
sampled_datasets = DatasetDict({
    "train": raw_datasets["train"].shuffle(seed=42).select(range(train_sample_size)),
    "validation": raw_datasets["validation"].shuffle(seed=42).select(range(val_sample_size)),
    "test": raw_datasets["test"].shuffle(seed=42).select(range(test_sample_size)),
})

In [6]:
sampled_datasets

DatasetDict({
    train: Dataset({
        features: ['summary', 'id', 'article', 'bias'],
        num_rows: 100
    })
    validation: Dataset({
        features: ['summary', 'id', 'article', 'bias'],
        num_rows: 20
    })
    test: Dataset({
        features: ['summary', 'id', 'article', 'bias'],
        num_rows: 20
    })
})

In [7]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [30]:
def preprocess_data(examples):
  #get all the articles, prepend each with "bias;"
  inputs = [f"{bias}; article" for bias, article in zip(examples["bias"], examples["article"])]
  #tokenize the inputs
  model_inputs = tokenizer(inputs,  max_length=max_input, padding='max_length', truncation=True)
    
  #tokenize the summaries
  targets = tokenizer(examples['summary'], max_length=max_target, padding='max_length', truncation=True)

  #set labels
  model_inputs['labels'] = targets['input_ids']
  #return the tokenized data
  #input_ids, attention_mask and labels
  return model_inputs

In [31]:
tokenize_data = sampled_datasets.map(preprocess_data, batched = True)

Map: 100%|██████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 2940.83 examples/s]


In [10]:
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

In [11]:
batch_size = 4
model_name = model_checkpoint.split("/")[-1]
args = Seq2SeqTrainingArguments(
    f"{model_name}-finetuned",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=1,
    predict_with_generate=True
)

In [12]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [17]:
metric = load("rouge")

Downloading builder script: 100%|█████████████████████████████████████████| 6.27k/6.27k [00:00<00:00, 15.2MB/s]


In [13]:
import nltk
import numpy as np
nltk.download('punkt')

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
    
    # Note that other metrics may not have a `use_aggregator` parameter
    # and thus will return a list, computing a metric for each sentence.
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True, use_aggregator=True)
    # Extract a few results
    result = {key: value * 100 for key, value in result.items()}
    
    # Add mean generated length
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    
    return {k: round(v, 4) for k, v in result.items()}

[nltk_data] Downloading package punkt to /Users/andrew/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [32]:
trainer = Seq2SeqTrainer(
    model, 
    args,
    train_dataset=tokenize_data['train'],
    eval_dataset=tokenize_data['validation'],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [33]:
trainer.train()



Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Gen Len
1,No log,2.978171,24.6008,3.1202,14.9601,22.5376,121.15


TrainOutput(global_step=25, training_loss=2.907449951171875, metrics={'train_runtime': 142.8934, 'train_samples_per_second': 0.7, 'train_steps_per_second': 0.175, 'total_flos': 216710460211200.0, 'train_loss': 2.907449951171875, 'epoch': 1.0})

In [34]:
preds = trainer.predict(tokenize_data["test"])

In [35]:
predicted_token_ids = preds.predictions.argmax(-1)

# Now let's convert these token ids to actual text
# We will decode the token ids while skipping the special tokens
predicted_texts = [tokenizer.decode(token_ids, skip_special_tokens=True) for token_ids in predicted_token_ids]

test_dataset = sampled_datasets['test']
# Show a few example outputs
for i in range(5):  # Just show the first 5 examples
    print(f"Example {i+1}:")
    print(f"Bias: {test_dataset[i]['bias']}")
    print(f"Input: {test_dataset[i]['article']}")
    print(f"Target Summary: {test_dataset[i]['summary']}")
    print(f"Predicted Summary: {predicted_texts[i]}\n")

Example 1:
Bias: left
Input: Trump administration says Israel’s West Bank settlements do not violate international law
Secretary of State Mike Pompeo said Monday that the Trump administration had determined that Israel’s West Bank settlements do not violate international law, a decision he said had “increased the likelihood” of a Middle East peace settlement.
Pompeo said the Trump administration, as it did with recognition of Jerusalem as the Israeli capital and Israel’s sovereignty over the disputed Golan Heights, had simply “recognized the reality on the ground.”
The move upends more than 40 years of U.S. policy that has declared Israeli expansion into territories occupied since the 1967 war a major obstacle to settling the Israeli-Palestinian conflict.
In response to a question, Pompeo denied that the announcement was connected to turmoil in Israel in which Prime Minister Benjamin Netanyahu, who has supported the Israeli annexation of West Bank territory, is fighting for his politic