In [None]:
%%capture
# Install the notebook's dependencies; %%capture keeps pip's console output
# out of the cell output.
# NOTE(review): hf_xet and rouge_score are not version-pinned — consider pinning
# them as well for reproducibility.
!pip install transformers==4.28.0 fsspec==2023.6.0 hf_xet datasets==2.17.1 rouge_score

In [None]:
# Directory the trained model and tokenizer are written to via save_pretrained
# and read back from via from_pretrained later in this notebook.
# NOTE(review): the path looks like a Colab Google Drive mount, but no mount
# call is visible in this notebook — confirm Drive is mounted before running.
output_dir = "/content/drive/MyDrive/bert2bert-cnn-dailymail"

In [None]:
from datasets import load_dataset

# Share of every split to keep, counted from the start of the split.
# None keeps the complete splits; set e.g. 0.05 for a quick run on a small prefix.
SUBSET_FRACTION = None


def _take_prefix(split, fraction):
    """Return the leading `fraction` of `split`, or `split` itself when `fraction` is None.

    Args:
        split: A dataset split supporting `len()` and `.select(indices)`.
        fraction: None for the full split, otherwise the share of rows to keep.

    Returns:
        The unchanged split, or a selection of its first
        `int(fraction * len(split))` rows.
    """
    if fraction is None:
        return split
    return split.select(range(int(fraction * len(split))))


dataset = load_dataset("cnn_dailymail", "3.0.0")

# The subset switch replaces the former commented-out string literal, which was
# evaluated (and echoed as cell output) without having any effect.
train_dataset = _take_prefix(dataset["train"], SUBSET_FRACTION)
val_dataset = _take_prefix(dataset["validation"], SUBSET_FRACTION)
test_dataset = _take_prefix(dataset["test"], SUBSET_FRACTION)

In [None]:
from transformers import BertTokenizerFast, EncoderDecoderModel

# Tokenizer: reuse its CLS and SEP tokens as the BOS and EOS tokens.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
tokenizer.bos_token, tokenizer.eos_token = tokenizer.cls_token, tokenizer.sep_token

# Encoder-decoder model whose encoder and decoder both start from the same checkpoint.
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-uncased",
    "bert-base-uncased",
)

# Copy the tokenizer's special-token ids onto the model config.
model.config.pad_token_id = tokenizer.pad_token_id
model.config.eos_token_id = tokenizer.eos_token_id
model.config.decoder_start_token_id = tokenizer.bos_token_id

In [None]:
from transformers import BertTokenizerFast

# Build a fresh tokenizer and reuse its CLS/SEP tokens as BOS/EOS.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
tokenizer.bos_token, tokenizer.eos_token = tokenizer.cls_token, tokenizer.sep_token

# Report each token together with its id.
print("BOS token:", tokenizer.bos_token)
print("BOS token ID:", tokenizer.bos_token_id)
print("EOS token:", tokenizer.eos_token)
print("EOS token ID:", tokenizer.eos_token_id)

# Cross-check via the tokenizer's own special-token views; sep="\n" puts the
# value on the line below its heading.
print("\nSpecial Tokens Map:", tokenizer.special_tokens_map, sep="\n")
print("\nAll Special Tokens:", tokenizer.all_special_tokens, sep="\n")

In [None]:
batch_size = 16
encoder_max_length = 512
decoder_max_length = 128


def process_data_to_model_inputs(batch):
    """Tokenize a batch of articles and highlights into model inputs and labels.

    Args:
        batch: Mapping with "article" and "highlights" entries, each a list of
            strings.

    Returns:
        The same `batch` object with three entries added: "input_ids" and
        "attention_mask" from the tokenized articles (padded/truncated to
        `encoder_max_length`), and "labels" from the tokenized highlights
        (padded/truncated to `decoder_max_length`) with every padding id
        replaced by -100.
    """
    article_enc = tokenizer(
        batch["article"],
        padding="max_length",
        truncation=True,
        max_length=encoder_max_length,
    )
    highlight_enc = tokenizer(
        batch["highlights"],
        padding="max_length",
        truncation=True,
        max_length=decoder_max_length,
    )

    # Swap padding ids for -100 — presumably so the training loss skips padded
    # positions (-100 is the usual ignore index); confirm against the model.
    pad_id = tokenizer.pad_token_id
    masked_labels = []
    for token_ids in highlight_enc["input_ids"]:
        masked_labels.append(
            [-100 if token_id == pad_id else token_id for token_id in token_ids]
        )

    batch["input_ids"] = article_enc["input_ids"]
    batch["attention_mask"] = article_enc["attention_mask"]
    batch["labels"] = masked_labels
    return batch

In [None]:
# Tokenize the training and validation splits with identical settings; the raw
# text columns and the id column are dropped from the results.
train_data, val_data = (
    split.map(
        process_data_to_model_inputs,
        batched=True,
        batch_size=batch_size,
        remove_columns=["article", "highlights", "id"],
    )
    for split in (train_dataset, val_dataset)
)

In [None]:
# Show tokenized details for the first 10 examples, with special tokens included
num_preview = min(10, len(train_data))  # avoid an IndexError on very small subsets
for i in range(num_preview):
    example = train_data[i]

    print(f"\nExample {i+1}:")
    print("Input IDs:", example["input_ids"])

    # Keep only attended positions so the decoded text shows the special tokens
    # without a long run of padding (assumes attention_mask is 1 exactly on
    # non-padding tokens).
    input_tokens = [
        tok for tok, keep in zip(example["input_ids"], example["attention_mask"]) if keep
    ]
    # skip_special_tokens=False: the output is labelled "with special tokens",
    # so they must not be stripped while decoding.
    print("Decoded Input (with special tokens):", tokenizer.decode(input_tokens, skip_special_tokens=False))

    print("\nAttention Mask:", example["attention_mask"])

    print("\nLabels:", example["labels"])

    # Drop the -100 placeholders (they are not token ids) before decoding labels
    decoded_labels = [token for token in example["labels"] if token != -100]
    print("Decoded Labels (with special tokens):", tokenizer.decode(decoded_labels, skip_special_tokens=False))

In [None]:
# Set the vocabulary size (taken from the decoder config) and the generation
# settings on the model config in a single update; keys are applied in the
# order listed.
model.config.update(
    {
        "vocab_size": model.config.decoder.vocab_size,
        "max_length": 142,
        "min_length": 56,
        "no_repeat_ngram_size": 3,
        "early_stopping": True,
        "length_penalty": 1.2,
        "num_beams": 4,
        "repetition_penalty": 1.5,
    }
)

In [None]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    # Where checkpoints and logs are written.
    output_dir="/content/drive/MyDrive/bert2bert-checkpoints",
    logging_dir="/content/drive/MyDrive/bert2bert-logs",
    # Optimization settings.
    num_train_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    fp16=True,
    # Step-based cadence for logging, evaluation and checkpointing.
    logging_steps=100,
    evaluation_strategy="steps",
    eval_steps=500,
    save_steps=500,
    save_total_limit=3,
    # Ask the trainer to use generation during prediction.
    predict_with_generate=True,
)

In [None]:
from transformers import Seq2SeqTrainer

# Bundle the model, its training arguments, the tokenized train/validation
# splits and the tokenizer into one trainer object.
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_data,
    eval_dataset=val_data,
)

In [None]:
import torch
import builtins

# SECURITY NOTE: despite its name, safe_torch_load makes torch.load *less*
# restrictive. Forcing weights_only=False allows full unpickling, which can
# execute arbitrary code, so only load checkpoint files you created or trust.
#
# The marker attribute makes this cell idempotent: re-running it no longer wraps
# the already-patched function again or overwrites the handle to the real loader.
if not getattr(torch.load, "_forces_full_unpickle", False):
    _real_torch_load = torch.load

    def safe_torch_load(*args, **kwargs):
        """Call the original torch.load with weights_only forced to False.

        All positional and keyword arguments are passed through unchanged,
        except that any caller-supplied `weights_only` value is overridden.
        Presumably needed because resuming from a checkpoint fails when
        torch.load defaults to weights_only=True — confirm against the
        installed torch version.
        """
        kwargs['weights_only'] = False  # Force it to False to fix the issue
        return _real_torch_load(*args, **kwargs)

    safe_torch_load._forces_full_unpickle = True
    torch.load = safe_torch_load

In [None]:
from transformers.trainer_utils import get_last_checkpoint
import torch
from torch.serialization import add_safe_globals
import numpy as np

# Register these NumPy objects as safe globals with torch's serialization layer.
add_safe_globals(
    [
        np.core.multiarray._reconstruct,
        np.ndarray,
        np.dtype,
        np.float64,
        np.int64,
        np.dtypes.UInt32DType,
    ]
)

# Ask for the last checkpoint inside the trainer's output directory.
checkpoint_dir = training_args.output_dir
last_checkpoint = get_last_checkpoint(checkpoint_dir)

# Announce which path is taken, then train once: passing None for
# resume_from_checkpoint is the same as calling train() without it.
if last_checkpoint is None:
    print("Starting from scratch.")
else:
    print(f"Resuming from checkpoint: {last_checkpoint}")
trainer.train(resume_from_checkpoint=last_checkpoint)

In [None]:
# Run the trainer's evaluation (on the eval_dataset it was constructed with)
# and print the returned metrics.
results = trainer.evaluate()
print(results)

In [None]:
# Persist the model and the tokenizer side by side in output_dir so both can be
# reloaded together with from_pretrained.
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

In [None]:
# Load the dataset again and rebind the three split variables to its splits.
dataset = load_dataset("cnn_dailymail", "3.0.0")
train_dataset, val_dataset, test_dataset = (
    dataset[split_name] for split_name in ("train", "validation", "test")
)

In [None]:
from transformers import BertTokenizerFast, EncoderDecoderModel

# Reload the model and tokenizer from the directory they were saved to,
# replacing the in-memory `model` and `tokenizer` objects.
model = EncoderDecoderModel.from_pretrained(output_dir)
tokenizer = BertTokenizerFast.from_pretrained(output_dir)

In [None]:
encoder_max_length = 512
decoder_max_length = 128


def preprocess_test_set(batch):
    """Tokenize the articles of a batch; highlights are left untouched.

    Args:
        batch: Mapping with an "article" entry holding a list of strings.

    Returns:
        The same `batch` object with "input_ids" and "attention_mask" added,
        padded/truncated to `encoder_max_length`.
    """
    encoded = tokenizer(
        batch["article"],
        padding="max_length",
        truncation=True,
        max_length=encoder_max_length,
    )
    for field in ("input_ids", "attention_mask"):
        batch[field] = encoded[field]
    return batch

# Tokenize the test split in batches. All original columns are removed, so
# only the fields added by preprocess_test_set remain in tokenized_test.
tokenized_test = test_dataset.map(
    preprocess_test_set,
    batched=True,
    remove_columns=test_dataset.column_names
)

In [None]:
import torch
from tqdm import tqdm

# Move model to device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Put the model in evaluation mode explicitly so generation does not depend on
# how `model` was obtained (e.g. an in-memory model left in training mode).
model.eval()

# Store generated summaries
generated_summaries = []

# Batch size
batch_size = 16

# Generate in batches
for i in tqdm(range(0, len(tokenized_test), batch_size)):
    # Slicing the dataset yields a dict of column lists for rows i .. i+batch_size-1
    batch = tokenized_test[i: i + batch_size]

    input_ids = torch.tensor(batch["input_ids"]).to(device)
    attention_mask = torch.tensor(batch["attention_mask"]).to(device)

    # Generate summaries
    # NOTE(review): length_penalty=2.0 here differs from the 1.2 set on
    # model.config earlier in the notebook — confirm which value is intended.
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=decoder_max_length,
        num_beams=4,
        length_penalty=2.0,
        no_repeat_ngram_size=3,
        early_stopping=True
    )

    # Decode summaries
    decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    generated_summaries.extend(decoded)

In [None]:
# Number of samples to show
num_samples = 25

# Never index past the summaries that were actually generated (previously an
# IndexError when fewer than num_samples summaries existed).
for i in range(min(num_samples, len(generated_summaries))):
    example = test_dataset[i]  # fetch the row once instead of once per field
    print(f"\n--- Example {i+1} ---")
    print("\nARTICLE:\n", example["article"])
    print("\nREFERENCE SUMMARY:\n", example["highlights"])
    print("\nGENERATED SUMMARY:\n", generated_summaries[i])

In [None]:
from datasets import load_metric

rouge = load_metric("rouge")

# Reference summaries: the "highlights" of the first len(generated_summaries)
# test rows, read with one column slice so they line up with the predictions.
reference_summaries = test_dataset[: len(generated_summaries)]["highlights"]

# Score the generated summaries against the references.
results = rouge.compute(
    references=reference_summaries,
    predictions=generated_summaries,
)

# Report the mid F-measure of every ROUGE variant returned.
for key in results:
    print(f"{key}: {results[key].mid.fmeasure:.4f}")