### Download the XSum Dataset

In [20]:
# Clone the XSum dataset repo; it provides the xsum/xsum.py loading script used below.
# NOTE(review): git refuses to clone into an existing non-empty ./xsum, so re-running this cell prints an error (harmless).
!git clone https://huggingface.co/datasets/EdinburghNLP/xsum

Cloning into 'xsum'...


### Load the XSum Dataset

In [1]:
from datasets import load_dataset

# Build the train/validation/test splits with the loading script from the cloned repo.
# NOTE(review): trust_remote_code=True lets `datasets` execute xsum/xsum.py — review that script before running.
raw_datasets = load_dataset("xsum/xsum.py", trust_remote_code=True)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Show the available splits, their columns and row counts.
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 204045
    })
    validation: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11332
    })
    test: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11334
    })
})

### Load the Tokenizer

In [3]:
from transformers import AutoTokenizer

# Tokenizer for the facebook/bart-base checkpoint; the model cell loads the same
# checkpoint so token ids line up between tokenizer and model.
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")

### Exploratory Analysis

In [4]:
def show_samples(dataset, num_samples=3, seed=42, split="train"):
    """Print a few randomly chosen document/summary pairs from one split.

    Args:
        dataset: a DatasetDict-like mapping of split name -> dataset whose rows
            have "document" and "summary" fields.
        num_samples: how many examples to print (capped at the split size).
        seed: shuffle seed, so the same examples are shown on every run.
        split: which split of ``dataset`` to sample from (default "train").
    """
    # Use the argument rather than a notebook global, so any DatasetDict works.
    split_data = dataset[split]
    num_samples = min(num_samples, len(split_data))
    sample = split_data.shuffle(seed=seed).select(range(num_samples))
    for example in sample:
        print("-" * 100)
        print(f"\n'>> Document: \n\n{example['document']}'")
        print("-" * 100)
        print(f"'>> Summary: \n\n{example['summary']}'")

# Eyeball a few shuffled training pairs (fixed seed, so the same ones appear on each run).
show_samples(raw_datasets)


----------------------------------------------------------------------------------------------------

'>> Document: 

In Wales, councils are responsible for funding and overseeing schools.
But in England, Mr Osborne's plan will mean local authorities will cease to have a role in providing education.
Academies are directly funded by central government and head teachers have more freedom over admissions and to change the way the school works.
It is a significant development in the continued divergence of schools systems on either side of Offa's Dyke.
And although the Welsh Government will get extra cash to match the money for English schools to extend the school day, it can spend it on any devolved policy area.
Ministers have no plans to follow suit.
At the moment, governing bodies are responsible for setting school hours and they need ministerial permission to make significant changes.
There are already more than 2,000 secondary academies in England and its extension to all state school

In [5]:
import numpy as np

# Per-example length columns used by the statistics below and by the filtering step.
# Token counts come from the loaded tokenizer without truncation (special tokens
# added by default are included); word counts are whitespace splits.

def _token_lengths(texts):
    """Return the number of token ids the tokenizer produces for each text."""
    encoded = tokenizer(texts, truncation=False)
    return [len(token_ids) for token_ids in encoded["input_ids"]]


def _word_lengths(texts):
    """Return the number of whitespace-separated words in each text."""
    return [len(text.split()) for text in texts]


# Add token count for documents
def check_token_count(examples):
    """Batched map fn: {"token_count": [...]} for the "document" column."""
    return {"token_count": _token_lengths(examples["document"])}


# Add token count for summaries
def check_summary_token_count(examples):
    """Batched map fn: {"summary_token_count": [...]} for the "summary" column."""
    return {"summary_token_count": _token_lengths(examples["summary"])}


# Add word count for documents
def check_word_count(examples):
    """Batched map fn: {"word_count": [...]} for the "document" column."""
    return {"word_count": _word_lengths(examples["document"])}


# Add word count for summaries
def check_summary_word_count(examples):
    """Batched map fn: {"summary_word_count": [...]} for the "summary" column."""
    return {"summary_word_count": _word_lengths(examples["summary"])}


def add_length_columns(examples):
    """Compute all four length columns for a batch in a single pass.

    One ``map`` call instead of four means the dataset is processed (and a new
    copy materialized) once rather than once per column.
    """
    columns = {}
    for column_fn in (
        check_token_count,
        check_summary_token_count,
        check_word_count,
        check_summary_word_count,
    ):
        columns.update(column_fn(examples))
    return columns


raw_datasets = raw_datasets.map(add_length_columns, batched=True)

# Function to calculate statistics
def calculate_statistics(counts):
    """Return (mean, min, max) of a sequence of counts.

    Args:
        counts: any sequence of numbers (e.g. a dataset column of lengths).

    Returns:
        Tuple ``(avg, min, max)`` as numpy scalars. For an empty input all three
        are ``nan`` instead of raising, because ``np.min``/``np.max`` have no
        identity value for an empty array.
    """
    counts_np = np.asarray(counts)
    if counts_np.size == 0:
        nan = float("nan")
        return nan, nan, nan
    return counts_np.mean(), counts_np.min(), counts_np.max()

# Print mean / min / max of each length column, split by split.
length_columns = (
    ("Document Token Counts", "token_count"),
    ("Summary Token Counts", "summary_token_count"),
    ("Document Word Counts", "word_count"),
    ("Summary Word Counts", "summary_word_count"),
)

for split_name in raw_datasets:
    split_title = split_name.capitalize()
    for label, column in length_columns:
        avg_value, min_value, max_value = calculate_statistics(raw_datasets[split_name][column])
        print(f"{split_title} Set - {label} - Average: {avg_value}, Min: {min_value}, Max: {max_value}")
    print("-" * 50)

Train Set - Document Token Counts - Average: 488.6465289519469, Min: 2, Max: 35314
Train Set - Summary Token Counts - Average: 28.147403758974736, Min: 3, Max: 118
Train Set - Document Word Counts - Average: 373.8646328015879, Min: 0, Max: 29189
Train Set - Summary Word Counts - Average: 21.09764512730035, Min: 1, Max: 70
--------------------------------------------------
Validation Set - Document Token Counts - Average: 481.9947052594423, Min: 2, Max: 6563
Validation Set - Summary Token Counts - Average: 28.146664313448643, Min: 5, Max: 102
Validation Set - Document Word Counts - Average: 369.1336039534063, Min: 0, Max: 3937
Validation Set - Summary Word Counts - Average: 21.126720790681258, Min: 1, Max: 86
--------------------------------------------------
Test Set - Document Token Counts - Average: 491.47714840303513, Min: 2, Max: 15278
Test Set - Summary Token Counts - Average: 28.141697547203105, Min: 5, Max: 103
Test Set - Document Word Counts - Average: 376.1446973707429, Min: 0

### Preprocessing the data

In [6]:
from datasets import DatasetDict

# Filter dataset based on token and word counts
def filter_dataset(dataset, min_doc_tokens=10, min_summary_tokens=5, min_doc_words=1):
    """Drop examples whose document or summary is too short to be useful.

    Relies on the ``token_count``, ``summary_token_count`` and ``word_count``
    columns added during the analysis step, and prints how many rows were kept.

    Args:
        dataset: a dataset exposing ``filter`` and ``len``.
        min_doc_tokens: minimum document token count to keep an example.
        min_summary_tokens: minimum summary token count to keep an example.
        min_doc_words: minimum document word count (default 1, i.e. non-empty).

    Returns:
        The filtered dataset.
    """
    def is_valid_example(example):
        return (
            example["token_count"] >= min_doc_tokens
            and example["summary_token_count"] >= min_summary_tokens
            and example["word_count"] >= min_doc_words
        )

    filtered_dataset = dataset.filter(is_valid_example)

    total = len(dataset)
    kept = len(filtered_dataset)
    removed = total - kept
    # An empty split would otherwise raise ZeroDivisionError here.
    removed_pct = removed / total * 100 if total else 0.0

    print(f"Total examples: {total}")
    print(f"Remaining after filtering: {kept}")
    print(f"Removed examples: {removed} ({removed_pct:.2f}%)")

    return filtered_dataset

# Apply the same length filter to every split (train, validation, test);
# filter_dataset prints the kept/removed counts for each one in turn.
filtered_datasets = DatasetDict(
    {name: filter_dataset(split_ds) for name, split_ds in raw_datasets.items()}
)

Total examples: 204045
Remaining after filtering: 203966
Removed examples: 79 (0.04%)
Total examples: 11332
Remaining after filtering: 11326
Removed examples: 6 (0.05%)
Total examples: 11334
Remaining after filtering: 11331
Removed examples: 3 (0.03%)


In [7]:
filtered_datasets

DatasetDict({
    train: Dataset({
        features: ['document', 'summary', 'id', 'token_count', 'summary_token_count', 'word_count', 'summary_word_count'],
        num_rows: 203966
    })
    validation: Dataset({
        features: ['document', 'summary', 'id', 'token_count', 'summary_token_count', 'word_count', 'summary_word_count'],
        num_rows: 11326
    })
    test: Dataset({
        features: ['document', 'summary', 'id', 'token_count', 'summary_token_count', 'word_count', 'summary_word_count'],
        num_rows: 11331
    })
})

In [8]:
# Prepare dataset for training.
# Documents longer than max_input_length tokens are truncated. max_target_length
# is well above the longest summary measured in the analysis (118 tokens in train),
# so labels are effectively never truncated.
max_input_length = 1024
max_target_length = 512


def preprocess_function(examples):
    """Tokenize a batch of documents (model inputs) and summaries (labels).

    Args:
        examples: batch dict with "document" and "summary" lists of strings.

    Returns:
        The tokenizer's encoding of the documents, with an added "labels"
        entry holding the summary token ids.
    """
    model_inputs = tokenizer(
        examples["document"],
        max_length=max_input_length,
        truncation=True,
    )

    # `text_target` is the documented way to tokenize seq2seq labels: it applies
    # the tokenizer's target-side settings (for a tokenizer without separate
    # target settings this matches a plain call).
    labels = tokenizer(
        text_target=examples["summary"],
        max_length=max_target_length,
        truncation=True,
    )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


tokenized_datasets = filtered_datasets.map(preprocess_function, batched=True)

In [9]:
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq

# Pretrained bart-base sequence-to-sequence model to fine-tune (same checkpoint as the tokenizer).
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base")

# Batch collator for seq2seq examples. NOTE(review): presumably pads inputs and labels
# per batch, and passing `model` lets it derive decoder inputs from labels — confirm
# against the DataCollatorForSeq2Seq docs.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)


In [10]:
import nltk
import evaluate
from nltk.tokenize import sent_tokenize

# sent_tokenize (used in compute_metrics) needs NLTK's Punkt sentence models.
# Newer NLTK releases load them from the "punkt_tab" resource instead of the
# pickled "punkt"; fetch both so the notebook works on old and new NLTK versions.
for _nltk_resource in ("punkt", "punkt_tab"):
    nltk.download(_nltk_resource, quiet=True)

# ROUGE metric from the `evaluate` library, used by compute_metrics.
rouge_score = evaluate.load("rouge")

def compute_metrics(eval_pred):
    """Decode generated and reference summaries and score them with ROUGE.

    Args:
        eval_pred: ``(predictions, labels)`` pair from the trainer's evaluation
            loop with ``predict_with_generate=True``; both hold token ids.
            If predictions arrive as a tuple, its first element is used.

    Returns:
        dict mapping each key returned by ``rouge_score.compute`` to its score
        scaled to 0-100 and rounded to 4 decimals.
    """
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    # -100 is the ignore/padding index: it is not a valid token id and cannot be
    # decoded, so map it back to the pad token in BOTH arrays before decoding.
    predictions = np.where(np.asarray(predictions) != -100, predictions, tokenizer.pad_token_id)
    labels = np.where(np.asarray(labels) != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # rougeLsum treats newlines as sentence boundaries, so put one sentence per line.
    decoded_preds = ["\n".join(sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(sent_tokenize(label.strip())) for label in decoded_labels]

    result = rouge_score.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )
    # The metric returns one aggregated score per key; report it as a percentage.
    return {key: round(value * 100, 4) for key, value in result.items()}



[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\priks\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [11]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

# Fine-tuning configuration. Evaluation and checkpointing both run once per epoch
# so the best checkpoint (by ROUGE-L) can be restored when training ends.
training_args = Seq2SeqTrainingArguments(
    output_dir="./bart-base-xsum-checkpoints",
    # `eval_strategy` is the current name; `evaluation_strategy` is deprecated
    # in recent transformers releases and scheduled for removal.
    eval_strategy="epoch",
    save_strategy="epoch",  # must match eval_strategy when load_best_model_at_end=True
    learning_rate=3e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=1,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=5,
    warmup_steps=1000,
    lr_scheduler_type="cosine",
    predict_with_generate=True,  # score ROUGE on generated summaries
    load_best_model_at_end=True,
    metric_for_best_model="rougeL",  # a key returned by compute_metrics
    greater_is_better=True,
    fp16=True,  # mixed precision; requires a CUDA GPU
    label_smoothing_factor=0.1,
)





In [12]:
# Wire model, data, collator and ROUGE metric into the trainer, then fine-tune.
# Evaluation runs on the validation split; compute_metrics scores generated summaries.
# NOTE(review): newer transformers releases deprecate `tokenizer=` here in favour of
# `processing_class=` (the warning in this cell's output may be that) — confirm the
# installed version before switching.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

  trainer = Seq2SeqTrainer(
  attn_output = torch.nn.functional.scaled_dot_product_attention(


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,3.3402,3.164266,37.4957,16.0158,30.7641,30.7659
2,3.1763,3.095118,38.2165,16.8821,31.5157,31.5127
3,3.0448,3.061704,38.8404,17.5522,32.2,32.1992
4,2.9831,3.051939,39.2075,17.8449,32.4087,32.4081
5,2.9226,3.050806,39.2079,17.8686,32.4777,32.4734


There were missing keys in the checkpoint model loaded: ['model.encoder.embed_tokens.weight', 'model.decoder.embed_tokens.weight', 'lm_head.weight'].


TrainOutput(global_step=63740, training_loss=3.126221167543745, metrics={'train_runtime': 35039.9492, 'train_samples_per_second': 29.105, 'train_steps_per_second': 1.819, 'total_flos': 6.008027139794534e+17, 'train_loss': 3.126221167543745, 'epoch': 5.0})

In [13]:
import json

# Persist the fine-tuned weights and the tokenizer in one directory so the pair
# can be reloaded together later (the pipeline cell reads this directory back).
model_output_dir = "bart-base-xsum"
trainer.save_model(model_output_dir)
tokenizer.save_pretrained(model_output_dir)

# Keep the trainer's logged losses/metrics next to the model for later inspection.
metrics_output_file = f"{model_output_dir}/log_history.json"
with open(metrics_output_file, "w") as log_file:
    json.dump(trainer.state.log_history, fp=log_file)

In [14]:
# Evaluate the loaded model (the best checkpoint, since load_best_model_at_end=True)
# on the validation split (the trainer's default eval_dataset), then on the held-out test split.
metrics = trainer.evaluate()
print("Validation Set Metrics:", metrics)

test_metrics = trainer.evaluate(eval_dataset=tokenized_datasets["test"])
print("Test Set Metrics:", test_metrics)



Validation Set Metrics: {'eval_loss': 3.0508058071136475, 'eval_rouge1': 39.2079, 'eval_rouge2': 17.8686, 'eval_rougeL': 32.4777, 'eval_rougeLsum': 32.4734, 'eval_runtime': 3226.566, 'eval_samples_per_second': 3.51, 'eval_steps_per_second': 0.11, 'epoch': 5.0}
Test Set Metrics: {'eval_loss': 3.0606689453125, 'eval_rouge1': 39.2149, 'eval_rouge2': 17.7573, 'eval_rougeL': 32.419, 'eval_rougeLsum': 32.402, 'eval_runtime': 585.0005, 'eval_samples_per_second': 19.369, 'eval_steps_per_second': 0.607, 'epoch': 5.0}


In [1]:
from transformers import pipeline

# Load the summarization pipeline from the fine-tuned model saved in ./bart-base-xsum
summarizer = pipeline("summarization", model="bart-base-xsum")

# Input text for summarization (a made-up news-style paragraph)
text = (
    "In a significant breakthrough in renewable energy, scientists have developed "
    "a novel solar panel technology that promises to dramatically reduce costs and "
    "increase efficiency. The new panels are lighter, more durable, and easier to install "
    "than conventional models, marking a major advancement in sustainable energy solutions. "
    "Experts believe this innovation could lead to wider adoption of solar power across residential "
    "and commercial sectors, ultimately reducing global reliance on fossil fuels."
)

# Generate summary: the pipeline returns a list with one dict per input.
# NOTE(review): the recorded output names "University of California, Berkeley", which the
# input never mentions — likely a hallucinated detail, so spot-check facts in generated summaries.
summary = summarizer(text)[0]["summary_text"]
print("Generated Summary:", summary)


  from .autonotebook import tqdm as notebook_tqdm
Device set to use cuda:0
Your max_length is set to 128, but your input_length is only 80. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=40)
  attn_output = torch.nn.functional.scaled_dot_product_attention(


Generated Summary: Scientists at the University of California, Berkeley, have developed a new type of solar panel.
