<a href="https://colab.research.google.com/github/anish0045h/ai_news/blob/main/ai_news(BRET).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    TrainingArguments,
    Trainer,
    pipeline
)

In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

def load_model_and_tokenizer(model_name="facebook/bart-large-cnn"):
    """Loads the BART model and tokenizer for summarization."""
    print(f"Loading model and tokenizer: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return model, tokenizer

In [None]:
from datasets import load_dataset

def load_and_prepare_dataset_for_bart(tokenizer, train_samples=2000, eval_samples=500):
    print("Loading CNN/DailyMail dataset...")
    dataset = load_dataset("cnn_dailymail", "3.0.0")

    # Select subset for faster training or debugging
    train_dataset = dataset["train"].select(range(train_samples))
    eval_dataset = dataset["validation"].select(range(eval_samples))

    def preprocess_function(examples):
        """Tokenize input (article) and target (highlights) for BART."""
        # No prefix needed for BART
        inputs = examples["article"]
        model_inputs = tokenizer(
            inputs,
            max_length=1024,
            truncation=True,
            padding="max_length"
        )

        # Tokenize summaries (targets)
        labels = tokenizer(
            text_target=examples["highlights"],
            max_length=128,
            truncation=True,
            padding="max_length"
        )
        model_inputs["labels"] = labels["input_ids"]

        return model_inputs

    print("Tokenizing dataset...")
    tokenized_train_dataset = train_dataset.map(preprocess_function, batched=True, remove_columns=dataset["train"].column_names)
    tokenized_eval_dataset = eval_dataset.map(preprocess_function, batched=True, remove_columns=dataset["validation"].column_names)

    return tokenized_train_dataset, tokenized_eval_dataset


In [None]:
from transformers import TrainingArguments

def get_training_arguments(output_dir="./results_bart_summarizer"):
    return TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=4,                   # ✅ BART usually converges faster (3–4 epochs is enough)
        per_device_train_batch_size=4,        # ✅ Slightly lower to prevent GPU OOM on Colab
        per_device_eval_batch_size=4,
        gradient_accumulation_steps=2,        # ✅ Effectively doubles batch size
        learning_rate=2e-5,                   # ✅ Slightly smaller LR works better for BART
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=100,
        eval_strategy="epoch",          # ✅ Corrected keyword argument name from 'evaluation_strategy' to 'eval_strategy'
        save_strategy="epoch",
        save_total_limit=2,                   # ✅ Keeps only 2 best checkpoints
        load_best_model_at_end=True,
        metric_for_best_model="loss",         # ✅ Select best model by lowest loss
        greater_is_better=False,
        report_to="none",                     # Disable wandb/tensorboard unless needed
        fp16=True,                            # ✅ Mixed precision for Colab GPU
        dataloader_num_workers=2,             # ✅ Speeds up data loading
    )

In [None]:
def train_model(model, tokenizer, train_dataset, eval_dataset, training_args):
    print("Starting fine-tuning on BART...")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
    )
    trainer.train()
    print("Training complete!")
    return trainer

In [None]:
def save_and_test_model(trainer, test_article, save_path="./my_bart_summarizer"):
    # Save the fine-tuned model
    trainer.save_model(save_path)
    print(f"✅ Model saved to: {save_path}")

    # Load summarization pipeline
    summarizer = pipeline("summarization", model=save_path, tokenizer=save_path)

    # Generate summary
    summary = summarizer(
        test_article,
        max_length=150,   # controls length of summary
        min_length=30,    # prevents too-short summaries
        do_sample=False   # deterministic output (optional)
    )

    # Display result
    print("\n--- TEST ARTICLE ---")
    print(test_article[:500], "...")  # print only part of article
    print("\n--- GENERATED SUMMARY ---")
    print(summary[0]["summary_text"])

In [None]:
def main():
    # Use a BART model for better summarization quality
    model_name = "facebook/bart-large-cnn"
    model, tokenizer = load_model_and_tokenizer(model_name)

    # Load and prepare CNN/DailyMail dataset
    train_dataset, eval_dataset = load_and_prepare_dataset_for_bart(tokenizer)

    # Get optimized training arguments for GPU
    training_args = get_training_arguments()

    # Train the model
    trainer = train_model(model, tokenizer, train_dataset, eval_dataset, training_args)

    # Example test article
    test_article = """
    The Karnataka Forest Department has initiated a new conservation program in the forests surrounding Sirsi
    to protect the Malabar pied hornbill. The program involves local communities in monitoring nesting sites
    and preventing illegal logging. Officials stated on Friday that this collaborative effort aims to ensure
    the long-term survival of the iconic bird species, which is crucial for the region's biodiversity.
    The initiative also includes awareness campaigns in local schools.
    """

    # Save and test
    new_save_path = "./my_bart_summarizer"
    save_and_test_model(trainer, test_article, save_path=new_save_path)

main()

Loading model and tokenizer: facebook/bart-large-cnn
Loading CNN/DailyMail dataset...
Tokenizing dataset...
Starting fine-tuning on BART...


  trainer = Trainer(


Epoch,Training Loss,Validation Loss
1,1.0445,0.638992
2,0.5198,0.619548
3,0.354,0.684747
4,0.2221,0.764771


There were missing keys in the checkpoint model loaded: ['model.encoder.embed_tokens.weight', 'model.decoder.embed_tokens.weight', 'lm_head.weight'].


Training complete!
✅ Model saved to: ./my_bart_summarizer


Device set to use cuda:0
Your max_length is set to 150, but your input_length is only 111. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=55)



--- TEST ARTICLE ---

    The Karnataka Forest Department has initiated a new conservation program in the forests surrounding Sirsi
    to protect the Malabar pied hornbill. The program involves local communities in monitoring nesting sites
    and preventing illegal logging. Officials stated on Friday that this collaborative effort aims to ensure
    the long-term survival of the iconic bird species, which is crucial for the region's biodiversity.
    The initiative also includes awareness campaigns in local school ...

--- GENERATED SUMMARY ---
Karnataka Forest Department has initiated a new conservation program in the forests surrounding Sirsi .
The program involves local communities in monitoring nesting sites and preventing illegal logging .


In [31]:
!ls ./my_bart_summarizer

config.json		model.safetensors	 tokenizer.json
generation_config.json	special_tokens_map.json  training_args.bin
merges.txt		tokenizer_config.json	 vocab.json


In [32]:
# Zip the saved model directory
!zip -r my_bart_summarizer.zip ./my_bart_summarizer

# Download the zip file
from google.colab import files
files.download('my_bart_summarizer.zip')

  adding: my_bart_summarizer/ (stored 0%)
  adding: my_bart_summarizer/tokenizer_config.json (deflated 75%)
  adding: my_bart_summarizer/vocab.json (deflated 59%)
  adding: my_bart_summarizer/generation_config.json (deflated 46%)
  adding: my_bart_summarizer/special_tokens_map.json (deflated 52%)
  adding: my_bart_summarizer/training_args.bin (deflated 53%)
  adding: my_bart_summarizer/tokenizer.json (deflated 82%)
  adding: my_bart_summarizer/config.json (deflated 62%)
  adding: my_bart_summarizer/merges.txt (deflated 53%)
  adding: my_bart_summarizer/model.safetensors (deflated 7%)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>