In [1]:
from datasets import load_dataset

# Load the DialogSum corpus from the Hugging Face Hub; the result is a
# DatasetDict keyed by split name, used by every later cell.
df = load_dataset(path="knkarthick/dialogsum")

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Pretrained seq2seq checkpoint to fine-tune. Tokenizer and model are loaded
# from the same name so their vocabularies match.
MODEL_NAME = "facebook/bart-large-cnn"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)


In [None]:
def preprocess(batch, max_source_length=512, max_target_length=128, tok=None):
    """Tokenize a batch of dialogue/summary pairs for seq2seq fine-tuning.

    Intended for ``Dataset.map(preprocess, batched=True)``, so ``batch`` maps
    column names to lists of strings.

    Args:
        batch: mapping with list-valued ``"dialogue"`` (encoder input) and
            ``"summary"`` (decoder target) columns.
        max_source_length: token budget for the dialogue. Dialogues are the
            long side of each pair, so they get a separate, larger budget
            than the summary. NOTE(review): 512 is a conservative default
            (BART accepts up to 1024 positions); tune it after inspecting the
            token-length distribution.
        max_target_length: token budget for the summary / labels.
        tok: tokenizer to use; defaults to the notebook-level ``tokenizer``.

    Returns:
        dict with ``input_ids`` and ``attention_mask`` for the dialogue, and
        ``labels`` for the summary with padding positions replaced by -100.
    """
    tok = tokenizer if tok is None else tok

    source_enc = tok(
        batch["dialogue"],
        truncation=True,
        padding="max_length",
        max_length=max_source_length,
    )
    target_enc = tok(
        batch["summary"],
        truncation=True,
        padding="max_length",
        max_length=max_target_length,
    )

    # -100 is the ignore index of the cross-entropy loss used by HF seq2seq
    # models, so padded label positions do not contribute to the loss.
    pad_id = tok.pad_token_id
    labels = [
        [-100 if token_id == pad_id else token_id for token_id in seq]
        for seq in target_enc["input_ids"]
    ]

    return {
        "input_ids": source_enc["input_ids"],
        "attention_mask": source_enc["attention_mask"],
        "labels": labels,
    }

# Apply `preprocess` to every split of the DatasetDict in batches; the new
# model-input columns are added next to the original text columns.
df_tokenized = df.map(function=preprocess, batched=True)


In [None]:
import inspect

from transformers import TrainingArguments, Trainer

# `evaluation_strategy` was deprecated in transformers 4.41 in favour of
# `eval_strategy` and has since been dropped, so passing the old name raises a
# TypeError on current releases. Use whichever keyword this install accepts.
_eval_strategy_kw = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments).parameters
    else "evaluation_strategy"
)

training_args = TrainingArguments(
    output_dir="./bart_dialogsum",
    per_device_train_batch_size=4,
    num_train_epochs=2,
    logging_steps=50,
    save_steps=500,
    **{_eval_strategy_kw: "epoch"},
)

# Monitor on the validation split so the test split stays untouched for the
# final evaluation; fall back to "test" only if no validation split exists.
eval_split = "validation" if "validation" in df_tokenized else "test"

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=df_tokenized["train"],
    eval_dataset=df_tokenized[eval_split],
)

trainer.train()


In [None]:
# Write the fine-tuned model and its tokenizer into the same directory so it
# can be reloaded later with `from_pretrained("./dialogsum_model")`.
SAVE_DIR = "./dialogsum_model"
for artifact in (model, tokenizer):
    artifact.save_pretrained(SAVE_DIR)
