In [56]:
import torch
from datasets import load_dataset
from transformers import (
    BartTokenizer,
    BartForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq
)
import evaluate
from bert_score import BERTScorer
import wandb


In [57]:
# --- Tokenization limits (all measured in tokens) ---
CHUNK_SIZE = 512      # tokens per chunk
STRIDE = 128          # overlap tokens between chunks
STAGE1_MAX_SUM = 250  # max target length for stage 1
STAGE2_MAX_SUM = 350  # max target length for stage 2

# --- Training schedule ---
BATCH_SIZE = 8        # examples per device per step
EPOCHS_STAGE1 = 20    # epochs in stage 1
EPOCHS_STAGE2 = 20    # epochs in stage 2

# --- Pretrained checkpoint shared by the tokenizer and the model ---
MODEL_NAME = "facebook/bart-base"

In [58]:
# Load the train/validation CSV files, then normalise the schema in one chain
# so `dataset` is only ever bound to its final form (re-running the cell is safe).
dataset = (
    load_dataset(
        "csv",
        data_files={"train": "train.csv", "validation": "val.csv"},
    )
    # Drop the extra column and rename so our code lines up:
    .remove_columns("word_count")
    .rename_column("full_text", "text")
    .rename_column("brief_summary", "summary")
)

print("Columns now:", dataset["train"].column_names)

Columns now: ['text', 'summary']


In [59]:
# Tokenizer and seq2seq model are both restored from the same checkpoint name.
tokenizer = BartTokenizer.from_pretrained(MODEL_NAME)
model = BartForConditionalGeneration.from_pretrained(MODEL_NAME)

# Metric objects are created once here and reused by compute_metrics.
rouge = evaluate.load("rouge")
bertscore = BERTScorer(lang="en", rescale_with_baseline=True)

# Batch collator for the trainers.
# NOTE(review): labels are padded with the tokenizer's pad id, not an ignore
# index such as -100 — confirm that padded label positions are really meant
# to contribute to the training loss.
data_collator = DataCollatorForSeq2Seq(
    tokenizer, model=model, label_pad_token_id=tokenizer.pad_token_id
)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [60]:
def preprocess_function(examples, max_target_length):
    """Tokenize a batch of source texts and summaries for seq2seq training.

    Args:
        examples: Batch mapping with "text" and "summary" columns, each a
            list of values (one per example).
        max_target_length: Maximum number of tokens kept per summary; longer
            summaries are truncated and shorter ones padded to this length.

    Returns:
        The tokenizer output for the source texts (truncated/padded to
        CHUNK_SIZE tokens) with an extra "labels" entry holding the token ids
        of the tokenized summaries.
    """
    def _as_text(value):
        # A missing cell may arrive as None; str(None) would silently turn it
        # into the literal word "None" and train on it.
        return "" if value is None else str(value)

    inputs = [_as_text(x) for x in examples["text"]]
    targets = [_as_text(x) for x in examples["summary"]]

    # No `stride` here: it only influences overflow windows, which are not
    # requested (return_overflowing_tokens is off), so it had no effect.
    model_inputs = tokenizer(
        inputs,
        max_length=CHUNK_SIZE,
        padding="max_length",
        truncation=True,
    )
    # `text_target=` replaces the deprecated `as_target_tokenizer()` context
    # manager for tokenizing the decoder-side text.
    labels = tokenizer(
        text_target=targets,
        max_length=max_target_length,
        padding="max_length",
        truncation=True,
    )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs



In [61]:
def compute_metrics(eval_pred):
    """Compute ROUGE and BERTScore for generated summaries.

    Args:
        eval_pred: Pair ``(predictions, labels)`` of token-id arrays. The
            predictions may also be a tuple whose first element is the
            token-id array.

    Returns:
        Dict with the ROUGE scores plus ``bertscore_precision``,
        ``bertscore_recall`` and ``bertscore_f1``, each rounded to 4 decimals.
    """
    import numpy as np  # local import: the notebook's import cell does not bring in numpy

    preds, labels = eval_pred
    if isinstance(preds, tuple):
        preds = preds[0]

    # -100 is used as an ignore/padding index and is not a valid vocabulary
    # id, so map it back to the pad token before decoding; the pad token is
    # then dropped by skip_special_tokens.
    pad_id = tokenizer.pad_token_id
    preds = np.asarray(preds)
    labels = np.asarray(labels)
    preds = np.where(preds != -100, preds, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # ROUGE
    result = rouge.compute(
        predictions=decoded_preds,
        references=decoded_labels
    )
    # BERTScore
    P, R, F1 = bertscore.score(decoded_preds, decoded_labels)
    result.update({
        "bertscore_precision": P.mean().item(),
        "bertscore_recall":    R.mean().item(),
        "bertscore_f1":        F1.mean().item()
    })
    return {k: round(v, 4) for k, v in result.items()}

In [62]:
# Both splits are tokenized the same way; only the source split differs.
_stage1_map_kwargs = {
    "batched": True,
    "remove_columns": ["text", "summary"],
    # Forwarded to preprocess_function as its max_target_length argument.
    "fn_kwargs": {"max_target_length": STAGE1_MAX_SUM},
}

train_ds_stage1 = dataset["train"].map(preprocess_function, **_stage1_map_kwargs)
val_ds_stage1 = dataset["validation"].map(preprocess_function, **_stage1_map_kwargs)

print("Stage1 train columns:", train_ds_stage1.column_names)

Map:   0%|          | 0/644 [00:00<?, ? examples/s]

Map:   0%|          | 0/161 [00:00<?, ? examples/s]

Stage1 train columns: ['input_ids', 'attention_mask', 'labels']


In [63]:
# Stage 1: fine-tune the pretrained model with summaries capped at
# STAGE1_MAX_SUM tokens, logging to a dedicated W&B run.
wandb.init(project="bart_game_news", name="stage1")

args1 = Seq2SeqTrainingArguments(
    output_dir                  ="stage1_output",
    evaluation_strategy         ="epoch",
    per_device_train_batch_size =BATCH_SIZE,
    per_device_eval_batch_size  =BATCH_SIZE,
    num_train_epochs            =EPOCHS_STAGE1,
    predict_with_generate       =True,
    # Without this, evaluation-time generation falls back to the model's own
    # default generation length, which can be far shorter than the targets
    # and would make ROUGE/BERTScore compare clipped summaries.
    generation_max_length       =STAGE1_MAX_SUM,
    logging_dir                 ="logs_stage1",
    logging_steps               =100,
    report_to                   ="wandb",
)

# NOTE(review): the installed transformers version may prefer
# `processing_class=` over `tokenizer=` — confirm before renaming.
trainer1 = Seq2SeqTrainer(
    model           = model,
    args            = args1,
    train_dataset   = train_ds_stage1,
    eval_dataset    = val_ds_stage1,
    tokenizer       = tokenizer,
    data_collator   = data_collator,
    compute_metrics = compute_metrics,
)

try:
    trainer1.train()
    metrics1 = trainer1.evaluate()
    print("Stage 1 Validation Metrics:", metrics1)

    # Stage 2 reloads the model from this directory.
    trainer1.save_model("bart_stage1")
finally:
    # Close the W&B run even if training is interrupted, so the next
    # wandb.init() starts a clean run instead of inheriting this one.
    wandb.finish()

  trainer1 = Seq2SeqTrainer(


Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 

In [None]:
# Stage 2 starts from the weights saved under "bart_stage1".
model_stage2 = BartForConditionalGeneration.from_pretrained("bart_stage1")

# Re-tokenize both splits with the longer stage-2 summary limit.
_stage2_map_kwargs = {
    "batched": True,
    "remove_columns": ["text", "summary"],
    # Forwarded to preprocess_function as its max_target_length argument.
    "fn_kwargs": {"max_target_length": STAGE2_MAX_SUM},
}

train_ds_stage2 = dataset["train"].map(preprocess_function, **_stage2_map_kwargs)
val_ds_stage2 = dataset["validation"].map(preprocess_function, **_stage2_map_kwargs)

In [None]:
# Stage 2: continue training the stage-1 model with summaries capped at
# STAGE2_MAX_SUM tokens, logging to its own W&B run.
wandb.init(project="bart_game_news", name="stage2")

args2 = Seq2SeqTrainingArguments(
    output_dir                  ="stage2_output",
    evaluation_strategy         ="epoch",
    per_device_train_batch_size =BATCH_SIZE,
    per_device_eval_batch_size  =BATCH_SIZE,
    num_train_epochs            =EPOCHS_STAGE2,
    predict_with_generate       =True,
    # Without this, evaluation-time generation falls back to the model's own
    # default generation length, which can be far shorter than the targets
    # and would make ROUGE/BERTScore compare clipped summaries.
    generation_max_length       =STAGE2_MAX_SUM,
    logging_dir                 ="logs_stage2",
    logging_steps               =100,
    report_to                   ="wandb",
)

# The shared `data_collator` was constructed around the stage-1 `model`
# object; build one bound to the model actually being trained in this stage.
data_collator_stage2 = DataCollatorForSeq2Seq(
    tokenizer,
    model=model_stage2,
    label_pad_token_id=tokenizer.pad_token_id
)

# NOTE(review): the installed transformers version may prefer
# `processing_class=` over `tokenizer=` — confirm before renaming.
trainer2 = Seq2SeqTrainer(
    model           = model_stage2,
    args            = args2,
    train_dataset   = train_ds_stage2,
    eval_dataset    = val_ds_stage2,
    tokenizer       = tokenizer,
    data_collator   = data_collator_stage2,
    compute_metrics = compute_metrics,
)

try:
    trainer2.train()
    metrics2 = trainer2.evaluate()
    print("Stage 2 Validation Metrics:", metrics2)

    trainer2.save_model("bart_stage2")
finally:
    # Close the W&B run even if training is interrupted.
    wandb.finish()