In [2]:
import os
import torch
import pandas as pd
import numpy as np
from datasets import load_dataset
from torch.utils.data import Dataset
from transformers import (
    T5Tokenizer,
    T5ForConditionalGeneration,
    Trainer,
    TrainingArguments,
    DataCollatorForSeq2Seq
)

# Configuration

In [3]:
class Config:
    MODEL_NAME = "google/flan-t5-base"

    # --- Hyperparameters ---
    MAX_INPUT_LENGTH = 512
    MAX_TARGET_LENGTH = 150
    BATCH_SIZE = 8
    EPOCHS = 3
    LEARNING_RATE = 2e-4
    WEIGHT_DECAY = 0.01
    WARMUP_STEPS = 500

    # --- Compression Thresholds ---
    # Ratio = (Summary Length / Article Length)
    HARSH_THRESHOLD = 0.30
    STANDARD_THRESHOLD = 0.60
    MAX_HARSH_SAMPLES = 15000
    # --- System ---
    OUTPUT_DIR = "./summarization_model_output"
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
def setup_storage():
    try:
        from google.colab import drive
        print("Google Colab detected. Mounting Drive...")
        drive.mount('/content/drive')

        # Define a safe path in your Drive
        drive_path = "/content/drive/MyDrive/Project/Text Summarizer/model"

        # Create directory if it doesn't exist
        if not os.path.exists(drive_path):
            print(f"Creating directory: {drive_path}")
            os.makedirs(drive_path, exist_ok=True)

        Config.OUTPUT_DIR = drive_path
        print(f"Storage configured. Model will be saved to: {Config.OUTPUT_DIR}")

    except ImportError:
        print(f"Google Colab not detected. Saving locally to: {Config.OUTPUT_DIR}")

# DATA PROCESSING

In [5]:
def calculate_compression_ratio(text, summary):
    text_len = len(text.split())
    summary_len = len(summary.split())
    if text_len == 0: return 0
    return summary_len / text_len

def assign_length_category(ratio):
    if ratio <= Config.HARSH_THRESHOLD:
        return "harsh"
    elif ratio <= Config.STANDARD_THRESHOLD:
        return "standard"
    else:
        return "detailed"

In [6]:
def load_and_preprocess_data():
    """
    Loads CNN/DailyMail and performs SMART SAMPLING to fix imbalance and speed.
    """
    print(f"Loading FULL dataset: cnn_dailymail (3.0.0)...")
    dataset = load_dataset("cnn_dailymail", "3.0.0")

    # Use Pandas for easier filtering
    train_df = dataset['train'].to_pandas()
    val_df = dataset['validation'].to_pandas()
    # We can skip test set processing for training to save RAM

    def process_and_balance(df, split_name):
        print(f"\nProcessing {split_name} data...")

        # 1. Calculate Metrics
        df['compression_ratio'] = df.apply(lambda row: calculate_compression_ratio(row['article'], row['highlights']), axis=1)
        df['length_category'] = df['compression_ratio'].apply(assign_length_category)

        # 2. Separate Categories
        df_harsh = df[df['length_category'] == 'harsh']
        df_standard = df[df['length_category'] == 'standard']
        df_detailed = df[df['length_category'] == 'detailed']

        print(f"  - Found {len(df_harsh)} harsh, {len(df_standard)} standard, {len(df_detailed)} detailed.")

        # 3. SMART SAMPLING (The Fix)
        # Keep ALL standard and detailed (because they are rare)
        # Cap 'harsh' to save time
        if len(df_harsh) > Config.MAX_HARSH_SAMPLES and split_name == 'train':
            print(f"  - Downsampling 'harsh' from {len(df_harsh)} to {Config.MAX_HARSH_SAMPLES} to balance data & speed up training.")
            df_harsh = df_harsh.sample(n=Config.MAX_HARSH_SAMPLES, random_state=42)

        # 4. Recombine
        balanced_df = pd.concat([df_harsh, df_standard, df_detailed])
        balanced_df = balanced_df.sample(frac=1, random_state=42).reset_index(drop=True) # Shuffle

        # 5. Format Prompt
        balanced_df['input_text'] = balanced_df.apply(lambda row: f"summarize {row['length_category']}: {row['article']}", axis=1)
        balanced_df['target_text'] = balanced_df['highlights']

        return balanced_df

    train_df = process_and_balance(train_df, 'train')
    val_df = process_and_balance(val_df, 'validation') # We don't strictly need to balance validation, but it helps speed

    print("\n Final Training Data Distribution:")
    print(train_df['length_category'].value_counts())
    print(f"Total Training Samples: {len(train_df)}")

    return train_df, val_df

In [7]:
class SummarizationDataset(Dataset):
    def __init__(self, dataframe, tokenizer):
        self.data = dataframe.reset_index(drop=True)
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]

        input_encoding = self.tokenizer(
            row['input_text'],
            max_length=Config.MAX_INPUT_LENGTH,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        target_encoding = self.tokenizer(
            row['target_text'],
            max_length=Config.MAX_TARGET_LENGTH,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        labels = target_encoding['input_ids']
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {
            'input_ids': input_encoding['input_ids'].squeeze(),
            'attention_mask': input_encoding['attention_mask'].squeeze(),
            'labels': labels.squeeze()
        }

# TRAINING PIPELINE

In [8]:
def train():
    print(f"Initializing Model: {Config.MODEL_NAME}")
    tokenizer = T5Tokenizer.from_pretrained(Config.MODEL_NAME)
    model = T5ForConditionalGeneration.from_pretrained(Config.MODEL_NAME)
    model.to(Config.DEVICE)

    train_df, val_df = load_and_preprocess_data()

    train_dataset = SummarizationDataset(train_df, tokenizer)
    val_dataset = SummarizationDataset(val_df, tokenizer)

    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

    args = TrainingArguments(
        output_dir=Config.OUTPUT_DIR,
        num_train_epochs=Config.EPOCHS,
        per_device_train_batch_size=Config.BATCH_SIZE,
        per_device_eval_batch_size=Config.BATCH_SIZE,
        warmup_steps=Config.WARMUP_STEPS,
        weight_decay=Config.WEIGHT_DECAY,
        logging_dir=f"{Config.OUTPUT_DIR}/logs",
        logging_steps=50,
        eval_strategy="steps",
        eval_steps=500,
        save_steps=500,
        save_total_limit=2,
        load_best_model_at_end=True,
        learning_rate=Config.LEARNING_RATE,
        report_to="none",
        fp16=True,
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer
    )

    print("Starting Training...")
    trainer.train()

    print("Saving Model...")
    trainer.save_model(Config.OUTPUT_DIR)
    tokenizer.save_pretrained(Config.OUTPUT_DIR)
    print(f"Model saved to: {Config.OUTPUT_DIR}")
    return model, tokenizer

#Inference

In [9]:
def generate_summary(text, style, model, tokenizer):
    """
    Generates a summary based on the requested style (harsh/standard/detailed).
    This function handles the length constraints at inference time.
    """
    model.eval()

    # 1. Construct the Prompt
    input_text = f"summarize {style}: {text}"

    inputs = tokenizer(
        input_text,
        max_length=Config.MAX_INPUT_LENGTH,
        truncation=True,
        return_tensors="pt"
    ).to(Config.DEVICE)

    # 2. Dynamic Length Configuration
    base_len = 150
    if style == "harsh":
        # Force concise output
        min_len, max_gen_len = 10, int(base_len * 0.4)
    elif style == "standard":
        min_len, max_gen_len = 30, base_len
    else: # detailed
        # Encourage longer output
        min_len, max_gen_len = 60, int(base_len * 1.5)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=max_gen_len,
            min_length=min_len,
            num_beams=4,
            length_penalty=1.0 if style != "harsh" else 2.0,
            early_stopping=True,
            no_repeat_ngram_size=3
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)

In [10]:
if __name__ == "__main__":
    # 1. Setup Drive
    setup_storage()

    # 2. Train the model
    model, tokenizer = train()

    # 3. Test the "Sliders"
    print("\n" + "="*50)
    print("TESTING MODEL CAPABILITIES")
    print("="*50)

    sample_article = """
    Artificial intelligence (AI) is intelligence demonstrated by machines, as opposed to the natural intelligence displayed by animals including humans.
    Leading AI textbooks define the field as the study of "intelligent agents": any system that perceives its environment and takes actions that maximize its chance of achieving its goals.
    Some popular accounts use the term "artificial intelligence" to describe machines that mimic "cognitive" functions that humans associate with the human mind, such as "learning" and "problem solving", however, this definition is rejected by major AI researchers.
    AI applications include advanced web search engines (e.g., Google), recommendation systems (used by YouTube, Amazon and Netflix), understanding human speech (such as Siri and Alexa), self-driving cars (e.g., Tesla), automated decision-making and competing at the highest level in strategic game systems (such as chess and Go).
    As machines become increasingly capable, tasks considered to require "intelligence" are often removed from the definition of AI, a phenomenon known as the AI effect.
    For instance, optical character recognition is frequently excluded from things considered to be AI, having become a routine technology.
    """

    print(f"\nOriginal Text Length: {len(sample_article.split())} words")

    styles = ["harsh", "standard", "detailed"]

    for style in styles:
        summary = generate_summary(sample_article, style, model, tokenizer)
        print(f"\n--- {style.upper()} SUMMARY ---")
        print(summary)
        print(f"Length: {len(summary.split())} words")

Google Colab detected. Mounting Drive...
Mounted at /content/drive
Storage configured. Model will be saved to: /content/drive/MyDrive/Project/Text Summarizer/model
Initializing Model: google/flan-t5-base


tokenizer_config.json: 0.00B [00:00, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/990M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

Loading FULL dataset: cnn_dailymail (3.0.0)...


README.md: 0.00B [00:00, ?B/s]

3.0.0/train-00000-of-00003.parquet:   0%|          | 0.00/257M [00:00<?, ?B/s]

3.0.0/train-00001-of-00003.parquet:   0%|          | 0.00/257M [00:00<?, ?B/s]

3.0.0/train-00002-of-00003.parquet:   0%|          | 0.00/259M [00:00<?, ?B/s]

3.0.0/validation-00000-of-00001.parquet:   0%|          | 0.00/34.7M [00:00<?, ?B/s]

3.0.0/test-00000-of-00001.parquet:   0%|          | 0.00/30.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/287113 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/13368 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/11490 [00:00<?, ? examples/s]


Processing train data...
  - Found 285497 harsh, 1453 standard, 163 detailed.
  - Downsampling 'harsh' from 285497 to 15000 to balance data & speed up training.

Processing validation data...
  - Found 13247 harsh, 116 standard, 5 detailed.

 Final Training Data Distribution:
length_category
harsh       15000
standard     1453
detailed      163
Name: count, dtype: int64
Total Training Samples: 16616


  trainer = Trainer(


Starting Training...


  batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64)


Step,Training Loss,Validation Loss
500,0.0,
1000,0.0,
1500,0.0,
2000,0.0,
2500,0.0,
3000,0.0,
3500,0.0,
4000,0.0,
4500,0.0,
5000,0.0,


Saving Model...
Model saved to: /content/drive/MyDrive/Project/Text Summarizer/model

TESTING MODEL CAPABILITIES

Original Text Length: 175 words

--- HARSH SUMMARY ---
summary harsh: Artificial intelligence (AI) is intelligence demonstrated by machines, as opposed to the natural intelligence displayed by animals including humans.
Length: 21 words

--- STANDARD SUMMARY ---
The study of "intelligent agents": any system that perceives its environment and takes actions that maximize its chance of achieving its goals.
Length: 22 words

--- DETAILED SUMMARY ---
Understand the meaning of "artificial intelligence" in the context of computer science. Understand the definition of artificial intelligence. Understand how artificial intelligence is used in computer science and technology. Understand what artificial intelligence means to us and what it means to the rest of the world. Understand why artificial intelligence has become a standard practice.
Length: 55 words
