In [1]:
# !pip install transformers datasets sacrebleu

## Load Dataset

In [None]:
from datasets import load_dataset

# Load the "samsum" dataset; the returned object is indexed by split name below.
dataset = load_dataset("samsum")
# Print the dataset object to see which splits it contains
print(dataset)

# Print the first record of the training split
print(dataset["train"][0])  # sample from training set


  from .autonotebook import tqdm as notebook_tqdm


DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 14732
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 819
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 818
    })
})
{'id': '13818513', 'dialogue': "Amanda: I baked  cookies. Do you want some?\r\nJerry: Sure!\r\nAmanda: I'll bring you tomorrow :-)", 'summary': 'Amanda baked cookies and will bring Jerry some tomorrow.'}


In [3]:
# Same two inspection prints as the loading cell: the split overview, then the
# first training record.
print(dataset)
print(dataset["train"][0])  # first data point of the training split

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 14732
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 819
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 818
    })
})
{'id': '13818513', 'dialogue': "Amanda: I baked  cookies. Do you want some?\r\nJerry: Sure!\r\nAmanda: I'll bring you tomorrow :-)", 'summary': 'Amanda baked cookies and will bring Jerry some tomorrow.'}


In [14]:
# Bind each split of the raw (untokenized) dataset to its own name.
# NOTE(review): the training cells visible further down read from
# `tokenized_dataset`, not from these variables -- confirm they are still needed.
train_dataset = dataset["train"]
val_dataset = dataset["validation"]
eval_dataset = dataset["test"]  # despite the name, this is the *test* split

In [18]:
# Display the first raw training example (keys: id, dialogue, summary).
train_dataset[0]

{'id': '13818513',
 'dialogue': "Amanda: I baked  cookies. Do you want some?\r\nJerry: Sure!\r\nAmanda: I'll bring you tomorrow :-)",
 'summary': 'Amanda baked cookies and will bring Jerry some tomorrow.'}

## Load Pretrained Seq2Seq Model
facebook/bart-base: Pretrained BART model (you can also try t5-base, etc.).

AutoTokenizer: Loads the tokenizer that matches the checkpoint (for BART this is a byte-level BPE tokenizer, not word-piece).

AutoModelForSeq2SeqLM: Special class for encoder-decoder (Seq2Seq) tasks like summarization.


In [4]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Checkpoint identifier shared by the model and the tokenizer.
model_checkpoint = "facebook/bart-base"

# Load the pretrained seq2seq model and the tokenizer from the same checkpoint
# so that token ids and embeddings agree.
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

##  Preprocess the Dataset



In [23]:
# Set limits
max_input_length = 512   # max tokens kept from each dialogue (model input)
max_target_length = 128  # max tokens kept from each summary (labels)


# Preprocessing function

def preprocess_function(examples):
    """Tokenize a batch of dialogue/summary pairs for seq2seq training.

    Args:
        examples: Batch mapping with a 'dialogue' and a 'summary' entry
            (used with `dataset.map(..., batched=True)` in this notebook).

    Returns:
        The tokenizer output for the dialogues, with an added "labels" entry
        holding the token ids of the tokenized summaries.
    """
    # Pad/truncate to the limits declared above instead of separate hard-coded
    # lengths, so the "Set limits" values are the single source of truth.
    model_inputs = tokenizer(
        examples['dialogue'],
        max_length=max_input_length,
        padding="max_length",
        truncation=True,
    )

    # Tokenize the summaries the same way, capped at the target limit.
    labels = tokenizer(
        examples['summary'],
        max_length=max_target_length,
        padding="max_length",
        truncation=True,
    )

    # NOTE(review): padded label positions hold the tokenizer's pad id rather
    # than -100, so presumably they still count toward the training loss --
    # confirm, and consider masking them once the metric code handles -100.
    model_inputs["labels"] = labels["input_ids"]

    return model_inputs


In [24]:
# Apply data preprocessing to the whole dataset object; batched=True hands
# preprocess_function a batch of examples per call instead of one at a time.
tokenized_dataset = dataset.map(preprocess_function, batched=True)

Map: 100%|██████████| 14732/14732 [00:03<00:00, 4188.39 examples/s]
Map: 100%|██████████| 819/819 [00:00<00:00, 4420.62 examples/s]
Map: 100%|██████████| 818/818 [00:00<00:00, 4553.15 examples/s]


In [25]:
# Split into train, test splits

# Hold out 20% of the tokenized *training* split for evaluation. The seed is
# fixed so the same examples land in each part on every run.
# NOTE: this rebinds `tokenized_dataset` to the two-way split. Re-run the
# mapping cell before re-running this one; otherwise the already reduced
# train part would be split a second time.
tokenized_dataset = tokenized_dataset["train"].train_test_split(test_size=0.2, seed=42)

### Explanation:

- `tokenizer(inputs, max_length=..., truncation=True)`: Converts raw text into token IDs, capped to a max length.
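- `tokenizer(examples['summary'], ...)`: The target summaries are tokenized with the same tokenizer call as the inputs; the older `as_target_tokenizer()` context manager is not used in this notebook.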
- `labels["input_ids"]`: These are the gold summaries to compare during training.
- `dataset.map(...)`: Applies preprocessing to every example in the dataset.

In [26]:
# !pip install evaluate

In [27]:
import nltk
import numpy as np
import evaluate  # the `evaluate` package provides the metric loaded below

# Download the NLTK "punkt" data; compute_metrics calls nltk.sent_tokenize
nltk.download("punkt")

# Load the metric using 'evaluate'
metric = evaluate.load("sacrebleu")  # sacreBLEU metric object used by compute_metrics

[nltk_data] Downloading package punkt to /Users/tejasgadi/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [28]:
# compute_metrics function
def compute_metrics(eval_pred):
    """Score generated summaries against the reference summaries with sacreBLEU.

    Args:
        eval_pred: Pair `(predictions, labels)` of token-id arrays. Either
            array may contain -100 as a filler/ignore value.

    Returns:
        Whatever `metric.compute(...)` returns for the decoded texts.
    """
    predictions, labels = eval_pred
    # If predictions arrive as a tuple, the first element is taken as the token ids.
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    pad_id = tokenizer.pad_token_id
    # -100 is an ignore marker, not a real token id, and cannot be decoded.
    # Keep every real id and swap only the -100 entries for the pad id, which
    # `skip_special_tokens=True` then drops. (np.where(cond, a, b) picks `a`
    # where cond is true -- the real ids must be the second argument.)
    predictions = np.where(predictions != -100, predictions, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Put each sentence on its own line before scoring
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

    # sacreBLEU expects a list of reference lists: one single-item list per prediction
    return metric.compute(predictions=decoded_preds, references=[[l] for l in decoded_labels])


In [29]:
# !pip install --upgrade transformers

In [30]:
# !pip install 'accelerate>=0.26.0'
# !pip install torch
# !pip install "transformers[torch]"



In [31]:
# Define training args

from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",               # Evaluate after every epoch
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    weight_decay=0.01,
    save_total_limit=3,                        # Keep only last 3 checkpoints
    num_train_epochs=3,
    predict_with_generate=True,                # Generate summaries during evaluation
    # Let evaluation generate up to the same length the reference summaries are
    # capped at; otherwise the model config's default generation length applies.
    generation_max_length=max_target_length,
    fp16=False,                                # Mixed precision is OFF; set True only on a GPU that supports it
    logging_dir="./logs",
)


In [32]:
from transformers import Seq2SeqTrainer

# Wire the model, the training arguments, the two parts of the train/test
# split, the tokenizer and the metric function into one trainer.
# NOTE(review): the recorded output shows a warning pointing at this
# constructor call -- presumably about one of its arguments; confirm against
# the installed transformers version.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

# Start fine-tuning. The recorded run ended with a KeyboardInterrupt, and no
# per-epoch results appear in the saved output.
trainer.train()


  trainer = Seq2SeqTrainer(


Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 