In [68]:
import os
import sys
from pathlib import Path

In [69]:
# Locate the repository root so the `src.*` package imports below resolve.
# When `__file__` is defined the root is two levels above this file; otherwise
# (e.g. in a notebook kernel) it is the parent of the current working directory.
# The path is resolved up front so it stays valid if the working directory changes.
project_root = Path(__file__).resolve().parent.parent if "__file__" in locals() else Path().resolve().parent

# Guard the insertion: re-running this cell must not keep appending duplicates.
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

In [70]:
from dataclasses import dataclass

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, DataCollatorForSeq2Seq, TrainingArguments, Trainer
import torch
from datasets import load_from_disk


In [71]:
from src.Briefly.constants import *
from src.Briefly.utils.common import read_yaml, create_dirs

In [72]:
# Anchor the imported config/params locations at the repository root so they
# can be opened regardless of the directory the kernel was started in.
# NOTE: this rebinds the names brought in by the star import above.
CONFIG_PATH = project_root.joinpath(CONFIG_PATH)
PARAMS_PATH = project_root.joinpath(PARAMS_PATH)

In [73]:
@dataclass
class TrainingModelConfig:
    """Settings for the model-training stage.

    The first three fields come from the ``training_model`` section of the
    config YAML; the remaining fields come from the ``TrainingArguments``
    section of the params YAML.
    """

    # Directory for this stage's outputs (created by ConfigManager).
    root_dir: Path
    # Location of the dataset that is read with `load_from_disk`.
    data_path: Path
    # Identifier handed to `from_pretrained`. This is a model id such as
    # "google/pegasus-cnn_dailymail" (or a local directory), not necessarily a
    # filesystem path, so it is typed as `str` rather than `Path`.
    model_checkpoint: str
    # Training hyperparameters.
    num_train_epochs: int
    warmup_steps: int
    per_device_train_batch_size: int
    per_device_eval_batch_size: int
    weight_decay: float
    logging_steps: int
    evaluation_strategy: str
    eval_steps: int
    save_steps: float
    gradient_accumulation_steps: int

Configuration:

In [74]:
class ConfigManager:
    """Reads the project's YAML settings and builds per-stage config objects.

    Two files are loaded with ``read_yaml``: the main config and the params
    file. Their contents are accessed by attribute.
    """

    def __init__(self, config_path: Path = CONFIG_PATH, params_path: Path = PARAMS_PATH):
        """Load both YAML files and ensure the artifacts root directory exists.

        Args:
            config_path: Location of the main config YAML.
            params_path: Location of the params YAML.
        """
        self.config = read_yaml(config_path)
        self.params = read_yaml(params_path)

        create_dirs([self.config.artifacts_root])

    def get_training_model_config(self) -> TrainingModelConfig:
        """Assemble the training-stage settings.

        Ensures the stage's ``root_dir`` exists, then combines the
        ``training_model`` section of the config with the ``TrainingArguments``
        section of the params.

        Returns:
            A populated ``TrainingModelConfig``.
        """
        stage = self.config.training_model
        hyperparams = self.params.TrainingArguments

        create_dirs([stage.root_dir])

        # Hyperparameters are copied across one-to-one under the same names.
        hyperparam_names = (
            "num_train_epochs",
            "per_device_train_batch_size",
            "per_device_eval_batch_size",
            "warmup_steps",
            "weight_decay",
            "logging_steps",
            "evaluation_strategy",
            "eval_steps",
            "save_steps",
            "gradient_accumulation_steps",
        )
        return TrainingModelConfig(
            root_dir=stage.root_dir,
            data_path=stage.data_path,
            model_checkpoint=stage.model_checkpoint,
            **{name: getattr(hyperparams, name) for name in hyperparam_names},
        )
    

Components:

In [75]:
class TrainingModel:
    """Fine-tunes the configured seq2seq checkpoint and saves model + tokenizer.

    All hyperparameters are taken from the ``TrainingModelConfig`` passed to the
    constructor, so changing the params file changes the run.
    """

    def __init__(self, config: "TrainingModelConfig"):
        """Store the training-stage settings.

        Args:
            config: Settings for this stage (paths, checkpoint, hyperparameters).
        """
        self.config = config

    def _training_arguments_kwargs(self) -> dict:
        """Map the stage config onto keyword arguments for ``TrainingArguments``.

        Returns:
            A dict that can be splatted into ``TrainingArguments(**kwargs)``.
        """
        cfg = self.config
        return {
            "output_dir": cfg.root_dir,
            "num_train_epochs": cfg.num_train_epochs,
            "warmup_steps": cfg.warmup_steps,
            "per_device_train_batch_size": cfg.per_device_train_batch_size,
            "per_device_eval_batch_size": cfg.per_device_eval_batch_size,
            # float() guards against YAML loaders that hand back scientific
            # notation such as `1e6` as a string instead of a number.
            "weight_decay": float(cfg.weight_decay),
            "logging_steps": cfg.logging_steps,
            # `eval_strategy` is the current keyword; `evaluation_strategy` is
            # deprecated. Any transformers release that accepts the
            # `processing_class` argument used below also accepts this name.
            "eval_strategy": cfg.evaluation_strategy,
            "eval_steps": cfg.eval_steps,
            "save_steps": float(cfg.save_steps),
            "gradient_accumulation_steps": cfg.gradient_accumulation_steps,
        }

    def train(self, train_split: str = "test", eval_split: str = "validation"):
        """Run fine-tuning and write the model and tokenizer under ``root_dir``.

        Args:
            train_split: Name of the dataset split to train on.
                NOTE(review): the default stays "test" to preserve the existing
                (short) run; training on the test split leaks evaluation data,
                so pass ``train_split="train"`` for a real training run.
            eval_split: Name of the dataset split used for evaluation.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"

        tokenizer = AutoTokenizer.from_pretrained(self.config.model_checkpoint)
        model = AutoModelForSeq2SeqLM.from_pretrained(self.config.model_checkpoint).to(device)
        data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

        dataset = load_from_disk(self.config.data_path)

        training_arguments = TrainingArguments(**self._training_arguments_kwargs())

        trainer = Trainer(
            model=model,
            args=training_arguments,
            processing_class=tokenizer,
            data_collator=data_collator,
            train_dataset=dataset[train_split],
            eval_dataset=dataset[eval_split],
        )
        trainer.train()

        model.save_pretrained(os.path.join(self.config.root_dir, "pegasus-samsum-model"))
        tokenizer.save_pretrained(os.path.join(self.config.root_dir, "tokenizer"))

In [76]:
# Build the training-stage settings from the YAML files, then run fine-tuning.
config = ConfigManager()
training_model_config = config.get_training_model_config()
# The trainer gets its own name so `training_model_config` keeps referring to
# the settings object instead of being overwritten by a different kind of thing.
training_model = TrainingModel(config=training_model_config)
training_model.train()

[2025-01-08 03:05:19,653: INFO: common: YaML File: C:\Users\ABHINAV\Desktop\Programming\Projects\Briefly\config\config.yaml loaded successfully]
[2025-01-08 03:05:19,655: INFO: common: YaML File: C:\Users\ABHINAV\Desktop\Programming\Projects\Briefly\params.yaml loaded successfully]
[2025-01-08 03:05:19,656: INFO: common: Creating directory at: artifacts]
[2025-01-08 03:05:19,657: INFO: common: Creating directory at: artifacts/training_model]


  2%|▏         | 1/51 [04:01<3:21:25, 241.71s/it]
Some weights of PegasusForConditionalGeneration were not initialized from the model checkpoint at google/pegasus-cnn_dailymail and are newly initialized: ['model.decoder.embed_positions.weight', 'model.encoder.embed_positions.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  0%|          | 0/51 [00:00<?, ?it/s]

KeyboardInterrupt: 