In [None]:
from transformers import T5Tokenizer, TFT5ForConditionalGeneration
from datasets import load_dataset

# One checkpoint name drives both the tokenizer and the model so that the
# vocabulary used for tokenization always matches the model's embeddings.
model_name = "t5-small"

tokenizer = T5Tokenizer.from_pretrained(model_name)
model = TFT5ForConditionalGeneration.from_pretrained(model_name)

# Only the first 10% of the CNN/DailyMail (config "3.0.0") training split is
# loaded; each record provides the "article" and "highlights" fields that
# the preprocessing step reads.
dataset = load_dataset(
    path="cnn_dailymail",
    name="3.0.0",
    split="train[:10%]",
)

# Fixed sequence lengths. They are stated once here because the tokenizer
# calls and the tf.data output signature below must agree on them.
MAX_INPUT_LENGTH = 512
MAX_TARGET_LENGTH = 64

# Columns forwarded to the model; every other dataset column is dropped
# by the generator.
_MODEL_INPUT_KEYS = ("input_ids", "attention_mask", "labels")


def preprocess(example):
    """Tokenize one dataset record into fixed-length model inputs.

    Args:
        example: Mapping with an "article" string (source text) and a
            "highlights" string (reference summary).

    Returns:
        A dict with:
          * "input_ids": token ids of "summarize: <article>", truncated or
            padded to MAX_INPUT_LENGTH.
          * "attention_mask": 1 for real source tokens and 0 for padding, so
            the encoder can ignore the padded positions.
          * "labels": token ids of the highlights, truncated or padded to
            MAX_TARGET_LENGTH. Padding keeps the tokenizer's pad id, so the
            loss used for training is responsible for masking it.
    """
    # Plain Python lists are requested (no return_tensors): Dataset.map
    # stores the values as columns anyway, so building a TF tensor per
    # example would only add overhead.
    source = tokenizer(
        f"summarize: {example['article']}",
        padding="max_length",
        truncation=True,
        max_length=MAX_INPUT_LENGTH,
    )
    target = tokenizer(
        example["highlights"],
        padding="max_length",
        truncation=True,
        max_length=MAX_TARGET_LENGTH,
    )
    return {
        "input_ids": source["input_ids"],
        "attention_mask": source["attention_mask"],
        "labels": target["input_ids"],
    }


tokenized_dataset = dataset.map(preprocess)


def generator():
    """Yield one training example at a time, restricted to the model inputs."""
    for ex in tokenized_dataset:
        yield {key: ex[key] for key in _MODEL_INPUT_KEYS}


import tensorflow as tf

# Stream the tokenized examples into tf.data. The signature must list every
# key the generator yields, with shapes equal to the padded lengths above.
tf_dataset = tf.data.Dataset.from_generator(
    generator,
    output_signature={
        "input_ids": tf.TensorSpec(shape=(MAX_INPUT_LENGTH,), dtype=tf.int32),
        "attention_mask": tf.TensorSpec(shape=(MAX_INPUT_LENGTH,), dtype=tf.int32),
        "labels": tf.TensorSpec(shape=(MAX_TARGET_LENGTH,), dtype=tf.int32),
    },
).batch(16).prefetch(tf.data.AUTOTUNE)

from transformers import TFTrainingArguments
def _masked_token_cross_entropy(y_true, y_pred):
    """Cross-entropy averaged over real target tokens only.

    The labels are padded to a fixed length, and a plain
    SparseCategoricalCrossentropy averages over every position, padding
    included. That trains the model to emit pad tokens and dilutes the
    loss. This version ignores positions holding the model's pad id (and
    the conventional -100 ignore index, should the labels use it).

    Args:
        y_true: Integer label ids, one per target position.
        y_pred: Unnormalized logits with a trailing vocabulary axis.

    Returns:
        Scalar mean loss over the non-ignored positions (0 if there are none).
    """
    labels = tf.cast(y_true, tf.int32)
    keep = tf.logical_and(
        tf.not_equal(labels, model.config.pad_token_id),
        tf.not_equal(labels, -100),
    )
    # Ignored positions get a valid class id so the per-token loss is always
    # computable; their contribution is zeroed out by the mask afterwards.
    safe_labels = tf.where(keep, labels, tf.zeros_like(labels))
    per_token = tf.keras.losses.sparse_categorical_crossentropy(
        safe_labels, y_pred, from_logits=True
    )
    weights = tf.cast(keep, per_token.dtype)
    return tf.math.divide_no_nan(
        tf.reduce_sum(per_token * weights), tf.reduce_sum(weights)
    )


model.compile(optimizer='adam', loss=_masked_token_cross_entropy)

model.fit(tf_dataset, epochs=5)

def generate_summary(text, *, max_input_length=512, max_summary_length=64, num_beams=8):
    """Summarize `text` with the module-level T5 model using beam search.

    Args:
        text: Source text. Surrounding whitespace is stripped and the
            "summarize: " task prefix is prepended before tokenization.
        max_input_length: Inputs longer than this many tokens are truncated.
        max_summary_length: Upper bound on the generated sequence length.
        num_beams: Number of beams explored during decoding.

    Returns:
        The decoded summary string with special tokens removed.
    """
    # A single sequence needs no padding. Padding it to the full input
    # length only makes the encoder and every cross-attention step process
    # hundreds of pad tokens.
    encoded = tokenizer(
        f"summarize: {text.strip()}",
        return_tensors="tf",
        truncation=True,
        max_length=max_input_length,
    )

    summary_ids = model.generate(
        encoded.input_ids,
        # Pass the mask explicitly instead of relying on generate() to infer it.
        attention_mask=encoded.attention_mask,
        max_length=max_summary_length,
        num_beams=num_beams,
        length_penalty=2.0,
        early_stopping=True,
        no_repeat_ngram_size=2,  # never repeat the same bigram in the output
    )

    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

# Example usage: summarize a short hard-coded passage with the model trained
# above and print the result. The literal is passed as-is; generate_summary
# strips surrounding whitespace and adds the "summarize: " prefix itself.
article = """The Eiffel Tower is one of the most iconic landmarks in Paris. Built in 1889, 
it stands at over 300 meters tall and attracts millions of tourists every year. 
Originally constructed for the World's Fair, it was initially met with criticism but 
has since become a beloved symbol of France."""

summary = generate_summary(article)
print("Generated Summary:", summary)
