In [1]:
import torch

if torch.cuda.is_available():
    print(f"GPU is available: {torch.cuda.get_device_name(0)}")
else:
    print("GPU is not available. Training will use CPU.")


GPU is available: NVIDIA GeForce RTX 3060 Ti


In [2]:
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    PreTrainedTokenizer
)
from datasets import load_dataset
import evaluate
import nltk
import torch
import numpy as np

# Ensure the NLTK tokenizer is available
try:
    nltk.data.find("tokenizers/punkt")
except (LookupError, OSError):
    nltk.download("punkt")

print("Libraries imported and NLTK tokenizer ready.")


  from .autonotebook import tqdm as notebook_tqdm



Libraries imported and NLTK tokenizer ready.


In [None]:

def preprocess_function(examples, tokenizer, max_input_length=1024, max_target_length=128):
    """Preprocess the data for training."""
    inputs = [doc for doc in examples["article"]]
    model_inputs = tokenizer(
        inputs,
        max_length=max_input_length,
        truncation=True,
        padding="max_length",
    )

    labels = tokenizer(
        [doc for doc in examples["highlights"]],
        max_length=max_target_length,
        truncation=True,
        padding="max_length",
    )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

def compute_metrics(eval_pred):
    """Compute ROUGE metrics."""
    rouge_score = evaluate.load("rouge")
    predictions, labels = eval_pred
    tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

    result = rouge_score.compute(
        predictions=decoded_preds,
        references=decoded_labels,
        use_stemmer=True
    )

    return {key: round(value * 100, 2) for key, value in result.items()}

print("Helper functions defined.")


Helper functions defined.


In [4]:
dataset_name = "cnn_dailymail"
dataset_config = "3.0.0"

print("Loading dataset...")
dataset = load_dataset(dataset_name, dataset_config)
print("Dataset loaded.")


Loading dataset...
Dataset loaded.


In [5]:
model_name = "facebook/bart-large-cnn"
max_input_length = 1024
max_target_length = 128

print("Loading tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
print("Model and tokenizer loaded.")


Loading tokenizer and model...
Model and tokenizer loaded.


In [6]:
# Limit the dataset size for training and validation
train_size = 1000
val_size = 100

print("Selecting subsets for training and validation...")
small_train_dataset = dataset["train"].select(range(train_size))
small_val_dataset = dataset["validation"].select(range(val_size))

print(f"Training dataset size: {len(small_train_dataset)}")
print(f"Validation dataset size: {len(small_val_dataset)}")


Selecting subsets for training and validation...
Training dataset size: 1000
Validation dataset size: 100


In [7]:
print("Preprocessing the sliced dataset...")
tokenized_train_dataset = small_train_dataset.map(
    lambda x: preprocess_function(x, tokenizer, max_input_length, max_target_length),
    batched=True,
    remove_columns=dataset["train"].column_names,
    desc="Preprocessing training dataset",
)

tokenized_val_dataset = small_val_dataset.map(
    lambda x: preprocess_function(x, tokenizer, max_input_length, max_target_length),
    batched=True,
    remove_columns=dataset["validation"].column_names,
    desc="Preprocessing validation dataset",
)

print("Subset preprocessing complete.")


Preprocessing the sliced dataset...
Subset preprocessing complete.


In [8]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

In [9]:
output_dir = "./fine_tuned_summarizer"
num_train_epochs = 3
batch_size = 2
learning_rate = 5e-5

training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    evaluation_strategy="epoch",
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=num_train_epochs,
    weight_decay=0.01,
    save_total_limit=3,
    predict_with_generate=True,
    logging_steps=100,
    logging_first_step=True,
    fp16=True if torch.cuda.is_available() else False,  
    gradient_accumulation_steps=4,
    dataloader_pin_memory=True,  
)

model.gradient_checkpointing_enable()
print("Training arguments set.")


Training arguments set.




In [10]:
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset= tokenized_val_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

print("Trainer initialized.")


Trainer initialized.


In [11]:
print("Starting training...")
trainer.train()
print("Training complete.")


Starting training...


  0%|          | 1/375 [00:02<15:23,  2.47s/it]

{'loss': 8.2744, 'grad_norm': inf, 'learning_rate': 5e-05, 'epoch': 0.01}


 27%|██▋       | 100/375 [18:18<49:18, 10.76s/it] 

{'loss': 1.0975, 'grad_norm': 2.664755344390869, 'learning_rate': 3.68e-05, 'epoch': 0.8}


                                                 
 33%|███▎      | 125/375 [25:21<45:41, 10.97s/it]

{'eval_loss': 0.5987657308578491, 'eval_rouge1': 35.72, 'eval_rouge2': 15.26, 'eval_rougeL': 25.35, 'eval_rougeLsum': 32.85, 'eval_runtime': 145.1947, 'eval_samples_per_second': 0.689, 'eval_steps_per_second': 0.344, 'epoch': 1.0}


 53%|█████▎    | 200/375 [39:05<31:23, 10.76s/it]  

{'loss': 0.4182, 'grad_norm': 2.5535361766815186, 'learning_rate': 2.3466666666666667e-05, 'epoch': 1.6}


                                                 
 67%|██████▋   | 250/375 [50:38<22:52, 10.98s/it]

{'eval_loss': 0.6577863097190857, 'eval_rouge1': 35.2, 'eval_rouge2': 14.9, 'eval_rougeL': 25.2, 'eval_rougeLsum': 32.56, 'eval_runtime': 143.3548, 'eval_samples_per_second': 0.698, 'eval_steps_per_second': 0.349, 'epoch': 2.0}


 80%|████████  | 300/375 [59:47<13:27, 10.77s/it]  

{'loss': 0.2677, 'grad_norm': 2.0699822902679443, 'learning_rate': 1.0133333333333333e-05, 'epoch': 2.4}


                                                   
100%|██████████| 375/375 [1:15:49<00:00, 12.13s/it]

{'eval_loss': 0.7513128519058228, 'eval_rouge1': 36.18, 'eval_rouge2': 14.76, 'eval_rougeL': 24.67, 'eval_rougeLsum': 33.23, 'eval_runtime': 136.0607, 'eval_samples_per_second': 0.735, 'eval_steps_per_second': 0.367, 'epoch': 3.0}
{'train_runtime': 4549.8992, 'train_samples_per_second': 0.659, 'train_steps_per_second': 0.082, 'train_loss': 0.5324372838338216, 'epoch': 3.0}
Training complete.





In [12]:
!nvidia-smi

Mon Dec  9 08:57:52 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 566.14                 Driver Version: 566.14         CUDA Version: 12.7     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 3060 Ti   WDDM  |   00000000:01:00.0 Off |                  N/A |
| 44%   59C    P3             63W /  200W |    7994MiB /   8192MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [13]:
import torch

torch.cuda.empty_cache()
print("empty")


empty
