# Data Prototype

In [61]:
# import libraries
!pip install "datasets<4.0.0" --upgrade
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig



Load datasets

In [62]:
# Training split of each of the three summarisation corpora.

# 1) CNN/DailyMail, config "3.0.0" (official Parquet export)
cnn = load_dataset("cnn_dailymail", name="3.0.0", split="train")

# 2) SAMSum dialogues (Arrow export by knkarthick)
sams = load_dataset("knkarthick/samsum", split="train")

# 3) Reddit-TIFU posts (Arrow export by Oguzz07)
tifu = load_dataset("Oguzz07/reddit-tifu-dataset", split="train")

In [63]:
# Inspect each corpus's schema; the column names differ between datasets.
for corpus in (sams, tifu, cnn):
    print(corpus.column_names)

['id', 'dialogue', 'summary']
['instruction', 'response']
['article', 'highlights', 'id']


Normalise dataset

In [64]:
def normalise(ds, text_col, sum_col):
    """Rename a dataset's source/target columns to ``text`` and ``summary``.

    The function is safe to call again on an already-normalised dataset
    (e.g. when the notebook cell is re-run): columns that already carry the
    canonical name are left alone instead of triggering a rename error.

    Args:
        ds: Dataset-like object exposing ``column_names`` and
            ``rename_columns(mapping)``.
        text_col: Name of the column holding the document to summarise.
        sum_col: Name of the column holding the reference summary.

    Returns:
        The object returned by ``ds.rename_columns`` for the renames that are
        still required, or ``ds`` itself when nothing needs renaming.
    """
    existing = set(ds.column_names)
    renames = {}
    for source, target in ((text_col, "text"), (sum_col, "summary")):
        if source == target:
            # Identity rename (e.g. SAMSum's "summary"): nothing to do.
            continue
        if source not in existing and target in existing:
            # Already renamed by an earlier call; keep the call idempotent.
            continue
        # A genuinely missing column is still passed through so that
        # rename_columns reports it, exactly as before.
        renames[source] = target
    return ds.rename_columns(renames) if renames else ds

# Bring all three corpora onto the shared ("text", "summary") schema.
cnn, sams, tifu = (
    normalise(dataset, source_col, target_col)
    for dataset, source_col, target_col in (
        (cnn, "article", "highlights"),
        (sams, "dialogue", "summary"),
        (tifu, "instruction", "response"),
    )
)

In [65]:
# Concatenate the three normalised corpora and shuffle with a fixed seed
from datasets import concatenate_datasets
ds = concatenate_datasets([cnn, sams, tifu]).shuffle(seed=42)
print(ds)

Dataset({
    features: ['text', 'summary', 'id'],
    num_rows: 302787
})


Train and Validation Split

In [66]:
# Hold out 10% of the shuffled rows for validation (seeded split).
split = ds.train_test_split(test_size=0.1, seed=42)
train_ds, val_ds = split["train"], split["test"]

Tokenization

In [67]:
# Tokenizer for the facebook/bart-large-cnn checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    "facebook/bart-large-cnn",
)

def tokenize_fn(batch):
    """Tokenise a batch of (text, summary) pairs for seq2seq training.

    Args:
        batch: Mapping with a ``"text"`` and a ``"summary"`` entry, each a
            list of values. Values are coerced to ``str`` first: ``None``
            becomes ``""``, lists in ``"text"`` are space-joined, and
            anything else is passed through ``str()``.

    Returns:
        dict with ``"input_ids"`` and ``"attention_mask"`` (inputs padded or
        truncated to 1024 tokens) and ``"labels"`` (targets padded or
        truncated to 150 tokens). Padded label positions are set to ``-100``
        so the training loss skips them instead of learning to emit padding.
    """

    def _to_text(value, join_lists):
        # Coerce one raw cell to a string the tokenizer will accept.
        if isinstance(value, str):
            return value
        if value is None:
            return ""  # empty string for missing values
        if join_lists and isinstance(value, list):
            # join lists of utterances just in case
            return " ".join(map(str, value))
        return str(value)  # last resort: cast to str

    texts = [_to_text(t, join_lists=True) for t in batch["text"]]
    targets = [_to_text(s, join_lists=False) for s in batch["summary"]]

    inputs = tokenizer(
        texts,
        max_length=1024,
        truncation=True,
        padding="max_length"
    )
    labels = tokenizer(
        targets,
        max_length=150,
        truncation=True,
        padding="max_length"
    )

    # Mask padding via the attention mask rather than by comparing against the
    # pad id, so a real token that happens to share the pad id is never hidden.
    masked_labels = [
        [tok if keep else -100 for tok, keep in zip(ids, mask)]
        for ids, mask in zip(labels["input_ids"], labels["attention_mask"])
    ]

    return {
        "input_ids":      inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "labels":         masked_labels,
    }

# Each corpus is tokenised separately because TIFU has no "id" column, so the
# set of raw columns to drop differs between datasets.
map_options = {"batched": True, "batch_size": 64}

# CNN (has id)
cnn_tok = cnn.map(tokenize_fn, remove_columns=["text", "summary", "id"], **map_options)

# SAMSum (has id)
sams_tok = sams.map(tokenize_fn, remove_columns=["text", "summary", "id"], **map_options)

# TIFU (no id column)
tifu_tok = tifu.map(tokenize_fn, remove_columns=["text", "summary"], **map_options)


Map:   0%|          | 0/14732 [00:00<?, ? examples/s]

In [70]:
# Merge the tokenised corpora, shuffle, and hold out 10% for validation.
mixed_tok = concatenate_datasets([cnn_tok, sams_tok, tifu_tok]).shuffle(seed=42)
split = mixed_tok.train_test_split(test_size=0.1, seed=42)
train_ds, val_ds = split["train"], split["test"]
print(mixed_tok)

# Serve the three model columns as torch tensors on both subsets.
for subset in (train_ds, val_ds):
    subset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 302787
})


Initialise Model

In [71]:
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
!pip install evaluate
!pip install rouge_score
from evaluate import load
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq
import numpy as np

collator = DataCollatorForSeq2Seq(tokenizer, model=model)
rouge = load("rouge")


def compute_rouge(eval_preds):
    """Decode predictions and references and score them with ROUGE.

    Args:
        eval_preds: ``(predictions, labels)`` pair of token-id arrays. If
            ``predictions`` is a tuple, its first element is used as the
            generated token ids. Positions equal to ``-100`` in either array
            are treated as padding.

    Returns:
        dict mapping ``"rouge_<metric name>"`` to the score scaled to 0-100.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    # -100 is not a valid token id, so decoding it fails; swap it for the pad
    # token, which skip_special_tokens then drops. This is needed for
    # predictions as well as labels.
    pad_id = tokenizer.pad_token_id
    preds = np.asarray(preds)
    labels = np.asarray(labels)
    preds = np.where(preds != -100, preds, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = rouge.compute(
        predictions=decoded_preds,
        references=decoded_labels,
        use_stemmer=True
    )

    scores = {}
    for name, value in result.items():
        # The `evaluate` ROUGE metric returns plain floats; accept the
        # aggregate-score objects (with .mid.fmeasure) as well, so either
        # result shape works.
        fmeasure = value.mid.fmeasure if hasattr(value, "mid") else value
        scores[f"rouge_{name}"] = float(fmeasure) * 100
    return scores




In [76]:
# Training configuration for fine-tuning on the mixed corpus.
args = Seq2SeqTrainingArguments(
    output_dir="bart-mixed",
    do_train=True,
    do_eval=True,
    # The default evaluation strategy is "no": without this line `eval_steps`
    # is ignored and compute_rouge never runs during training.
    # NOTE(review): each evaluation generates summaries for the whole of
    # val_ds; consider passing a smaller subset if that is too slow.
    eval_strategy="steps",
    logging_steps=200,
    eval_steps=1000,
    save_steps=1000,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=8,   # effective batch of 16 sequences per device
    learning_rate=3e-5,
    num_train_epochs=3,
    predict_with_generate=True,      # metrics are computed on generated text
    fp16=True,
    report_to="none"
)

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    # `tokenizer=` is deprecated for Seq2SeqTrainer (see the warning in this
    # cell's previous output); `processing_class` is its replacement.
    processing_class=tokenizer,
    data_collator=collator,
    compute_metrics=compute_rouge,
)

trainer.train()


  trainer = Seq2SeqTrainer(
  batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64)


KeyboardInterrupt: 