In [1]:
import os

In [2]:
# Move the kernel's working directory up one level so that the relative paths
# used later in this notebook resolve (presumably against the project root --
# TODO confirm the notebook's location within the repository).
# NOTE: this is not idempotent -- re-running the cell climbs one more directory
# each time, so run it exactly once per kernel session.
os.chdir("../")

In [3]:
# Display the current working directory to check where the kernel is running.
%pwd

'e:\\project\\Text Summerization'

In [4]:
from dataclasses import dataclass
from pathlib import Path


@dataclass(frozen=True)
class ModelTrainerConfig:
    """Immutable bundle of settings for the model-training stage.

    The dataclass is frozen, so fields cannot be reassigned after
    construction. Instances are built by
    ``ConfigurationManager.get_model_trainer_config`` and consumed by
    ``ModelTrainer``.
    """
    # --- Locations / checkpoint -------------------------------------------
    # Stage directory; ModelTrainer uses it as the Trainer output_dir and as
    # the parent of the saved model and tokenizer folders.
    root_dir: Path
    # Location handed to datasets.load_from_disk.
    data_path: Path
    # Identifier handed to from_pretrained for both tokenizer and model.
    # NOTE(review): may be a hub model id (str) rather than a filesystem path
    # -- confirm against config.yaml.
    model_ckpt: Path

    # --- Training hyperparameters -----------------------------------------
    # These field names mirror the keyword arguments of the same names on
    # transformers.TrainingArguments.
    num_train_epochs: int
    warmup_steps: int
    per_device_train_batch_size: int
    per_device_eval_batch_size: int
    weight_decay: float
    logging_steps: int
    eval_strategy: str
    eval_steps: int
    save_steps: float
    gradient_accumulation_steps: int
    learning_rate: float
    fp16: bool
    save_total_limit: int

In [5]:
from CutYourText.constants import *
from CutYourText.utils.common import read_yaml, create_directories

In [6]:
class ConfigurationManager:
    """Reads the project's YAML config/params files and builds stage configs.

    Both files are parsed once in the constructor and kept on the instance as
    ``self.config`` and ``self.params``; the artifacts root directory named in
    the config file is created at the same time.
    """

    def __init__(self, config_filepath = CONFIG_FILE_PATH, params_filepath = PARAMS_FILE_PATH):
        """Load both YAML files and make sure the artifacts root exists.

        Args:
            config_filepath: Path of the main configuration YAML file.
            params_filepath: Path of the hyperparameter YAML file.
        """
        # Config is read first, then params (tuple elements evaluate left to right).
        self.config, self.params = read_yaml(config_filepath), read_yaml(params_filepath)

        create_directories([self.config.artifacts_root])

    def get_model_trainer_config(self):
        """Assemble the ``ModelTrainerConfig`` for the training stage.

        Creates the stage's ``root_dir`` as a side effect.

        Returns:
            ModelTrainerConfig: paths taken from the ``model_trainer`` section
            of the config file and hyperparameters taken from the
            ``TrainingArguments`` section of the params file.
        """
        stage_section = self.config.model_trainer
        hyperparams = self.params.TrainingArguments

        create_directories([stage_section.root_dir])

        # Fields copied verbatim from the config file's model_trainer section.
        path_fields = ("root_dir", "data_path", "model_ckpt")
        # Fields copied verbatim from the params file's TrainingArguments section.
        hyperparam_fields = (
            "num_train_epochs",
            "warmup_steps",
            "per_device_train_batch_size",
            "per_device_eval_batch_size",
            "weight_decay",
            "logging_steps",
            "eval_strategy",
            "eval_steps",
            "save_steps",
            "gradient_accumulation_steps",
            "learning_rate",
            "fp16",
            "save_total_limit",
        )

        field_values = {name: getattr(stage_section, name) for name in path_fields}
        for name in hyperparam_fields:
            field_values[name] = getattr(hyperparams, name)

        return ModelTrainerConfig(**field_values)

In [7]:
from transformers import TrainingArguments, Trainer
from transformers import DataCollatorForSeq2Seq
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from datasets import load_from_disk
import torch
from torch.optim import AdamW

  from .autonotebook import tqdm as notebook_tqdm


[2024-08-18 06:33:15,589: INFO: config: PyTorch version 2.4.0+cu121 available.]


In [8]:
class ModelTrainer:
    """Fine-tunes a pretrained seq2seq checkpoint with the Hugging Face ``Trainer``.

    All paths and hyperparameters come from the ``ModelTrainerConfig`` passed
    to the constructor.
    """

    def __init__(self, config: ModelTrainerConfig):
        """Store the training configuration.

        Args:
            config: Paths and hyperparameters for this training run.
        """
        self.config = config

    def train(self):
        """Run fine-tuning and save the resulting model and tokenizer.

        Loads the tokenizer and model from ``config.model_ckpt``, loads the
        dataset from ``config.data_path`` (its ``"train"`` and ``"validation"``
        splits are used), trains with a custom AdamW optimizer, then writes
        the model to ``<root_dir>/bart-dialogsum-model`` and the tokenizer to
        ``<root_dir>/tokenizer``.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        tokenizer = AutoTokenizer.from_pretrained(self.config.model_ckpt)
        model_bart = AutoModelForSeq2SeqLM.from_pretrained(self.config.model_ckpt).to(device)
        seq2seq_data_collator = DataCollatorForSeq2Seq(tokenizer, model=model_bart)

        # Load data
        dataset_dialogsum_pt = load_from_disk(self.config.data_path)

        # Define training arguments
        trainer_args = TrainingArguments(
            output_dir=self.config.root_dir,
            num_train_epochs=self.config.num_train_epochs,
            warmup_steps=self.config.warmup_steps,
            per_device_train_batch_size=self.config.per_device_train_batch_size,
            # Fix: previously the *train* batch size was passed here, so the
            # configured eval batch size was silently ignored.
            per_device_eval_batch_size=self.config.per_device_eval_batch_size,
            weight_decay=self.config.weight_decay,
            logging_steps=self.config.logging_steps,
            eval_strategy=self.config.eval_strategy,
            eval_steps=self.config.eval_steps,
            save_steps=self.config.save_steps,
            gradient_accumulation_steps=self.config.gradient_accumulation_steps,
            learning_rate=self.config.learning_rate,
            fp16=self.config.fp16,
            save_total_limit=self.config.save_total_limit,
        )

        # Define the optimizer using torch.optim.AdamW.
        # Fix: lr and weight_decay used to be hard-coded (3e-5 / 0.01). Because
        # this optimizer is handed to the Trainer, the values in the config
        # had no effect on the parameter updates; read them from the config so
        # params.yaml is the single source of truth.
        optimizer = AdamW(
            model_bart.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay,
            eps=1e-8,  # Epsilon to avoid division by zero in Adam
        )

        # Initialize Trainer with the custom optimizer; passing None as the
        # second tuple element leaves the LR scheduler for the Trainer to create.
        trainer = Trainer(
            model=model_bart,
            args=trainer_args,
            tokenizer=tokenizer,
            data_collator=seq2seq_data_collator,
            train_dataset=dataset_dialogsum_pt["train"],
            eval_dataset=dataset_dialogsum_pt["validation"],
            optimizers=(optimizer, None),
        )

        trainer.train()

        # Save model
        model_bart.save_pretrained(os.path.join(self.config.root_dir, "bart-dialogsum-model"))
        # Save tokenizer
        tokenizer.save_pretrained(os.path.join(self.config.root_dir, "tokenizer"))

In [9]:
# Run the model-training stage end to end: load configuration, build the
# stage config, then train and save the model.
# The former `try: ... except Exception as e: raise e` wrapper was removed: it
# re-raised every exception unchanged, so the same errors still propagate, just
# without the extra traceback frame.
config = ConfigurationManager()
model_trainer_config = config.get_model_trainer_config()
model_trainer = ModelTrainer(model_trainer_config)
model_trainer.train()

[2024-08-18 06:33:52,528: INFO: common: yaml file: config\config.yaml loaded successfully]
[2024-08-18 06:33:52,540: INFO: common: yaml file: params.yaml loaded successfully]


[2024-08-18 06:33:52,541: INFO: common: created directory at: artifacts]
[2024-08-18 06:33:52,543: INFO: common: created directory at: artifacts/model_trainer]


  self.scaler = torch.cuda.amp.GradScaler(**kwargs)
  attn_output = torch.nn.functional.scaled_dot_product_attention(
  2%|▏         | 50/2334 [12:52<10:05:42, 15.91s/it]

{'loss': 1.8231, 'grad_norm': 5.613476276397705, 'learning_rate': 2.88e-06, 'epoch': 0.06}


  4%|▍         | 100/2334 [26:18<10:46:40, 17.37s/it]

{'loss': 1.2746, 'grad_norm': 5.3094048500061035, 'learning_rate': 5.82e-06, 'epoch': 0.13}


  6%|▋         | 150/2334 [39:06<9:02:53, 14.91s/it] 

{'loss': 1.1924, 'grad_norm': 6.198858261108398, 'learning_rate': 8.82e-06, 'epoch': 0.19}


  9%|▊         | 200/2334 [52:13<9:34:51, 16.16s/it] 

{'loss': 1.1347, 'grad_norm': 5.065151691436768, 'learning_rate': 1.182e-05, 'epoch': 0.26}


 11%|█         | 250/2334 [1:04:57<9:09:22, 15.82s/it]

{'loss': 1.0996, 'grad_norm': 4.685226917266846, 'learning_rate': 1.482e-05, 'epoch': 0.32}


 13%|█▎        | 300/2334 [1:19:23<10:03:06, 17.79s/it]

{'loss': 1.1092, 'grad_norm': 4.759219646453857, 'learning_rate': 1.782e-05, 'epoch': 0.39}


 15%|█▍        | 350/2334 [1:34:54<10:09:50, 18.44s/it]

{'loss': 1.0819, 'grad_norm': 4.807496547698975, 'learning_rate': 2.082e-05, 'epoch': 0.45}


 17%|█▋        | 400/2334 [1:49:28<9:39:09, 17.97s/it] 

{'loss': 1.0858, 'grad_norm': 3.9146718978881836, 'learning_rate': 2.3820000000000002e-05, 'epoch': 0.51}


 19%|█▉        | 450/2334 [2:02:41<8:02:39, 15.37s/it] 

{'loss': 1.0669, 'grad_norm': 5.43733549118042, 'learning_rate': 2.682e-05, 'epoch': 0.58}


 21%|██▏       | 500/2334 [2:15:55<8:11:05, 16.07s/it]

{'loss': 1.0549, 'grad_norm': 4.416407585144043, 'learning_rate': 2.982e-05, 'epoch': 0.64}


 24%|██▎       | 550/2334 [2:31:05<8:59:18, 18.14s/it] 

{'loss': 1.0492, 'grad_norm': 4.626941680908203, 'learning_rate': 2.9231188658669577e-05, 'epoch': 0.71}


 26%|██▌       | 600/2334 [2:46:57<9:02:06, 18.76s/it] 

{'loss': 1.0636, 'grad_norm': 3.954674482345581, 'learning_rate': 2.841330425299891e-05, 'epoch': 0.77}


 28%|██▊       | 650/2334 [3:01:37<7:37:54, 16.32s/it] 

{'loss': 1.0363, 'grad_norm': 4.806249141693115, 'learning_rate': 2.7595419847328245e-05, 'epoch': 0.83}


 30%|██▉       | 700/2334 [3:15:07<8:17:29, 18.27s/it]

{'loss': 1.0186, 'grad_norm': 3.965263843536377, 'learning_rate': 2.677753544165758e-05, 'epoch': 0.9}


 32%|███▏      | 750/2334 [3:28:30<6:51:26, 15.58s/it]

{'loss': 1.0468, 'grad_norm': 4.339626789093018, 'learning_rate': 2.5959651035986914e-05, 'epoch': 0.96}


 34%|███▍      | 800/2334 [3:41:31<7:11:00, 16.86s/it]

{'loss': 0.9726, 'grad_norm': 3.81845760345459, 'learning_rate': 2.5141766630316248e-05, 'epoch': 1.03}


 36%|███▋      | 850/2334 [3:56:55<8:06:29, 19.67s/it]

{'loss': 0.8744, 'grad_norm': 4.437399864196777, 'learning_rate': 2.4323882224645582e-05, 'epoch': 1.09}


 39%|███▊      | 900/2334 [4:12:24<7:13:45, 18.15s/it]

{'loss': 0.8657, 'grad_norm': 6.225265979766846, 'learning_rate': 2.350599781897492e-05, 'epoch': 1.16}


 41%|████      | 950/2334 [4:27:33<6:16:25, 16.32s/it]

{'loss': 0.8462, 'grad_norm': 5.015539169311523, 'learning_rate': 2.2688113413304254e-05, 'epoch': 1.22}


 43%|████▎     | 1000/2334 [4:40:29<6:52:29, 18.55s/it]

{'loss': 0.8535, 'grad_norm': 3.362178087234497, 'learning_rate': 2.187022900763359e-05, 'epoch': 1.28}


                                                       
Non-default generation parameters: {'max_length': 142, 'min_length': 56, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


{'eval_loss': 1.0869678258895874, 'eval_runtime': 136.1615, 'eval_samples_per_second': 3.672, 'eval_steps_per_second': 1.836, 'epoch': 1.28}


 45%|████▍     | 1050/2334 [4:55:39<5:02:33, 14.14s/it] 

{'loss': 0.8559, 'grad_norm': 5.644124984741211, 'learning_rate': 2.1052344601962923e-05, 'epoch': 1.35}


 47%|████▋     | 1100/2334 [5:09:54<6:10:46, 18.03s/it]

{'loss': 0.8542, 'grad_norm': 3.822324275970459, 'learning_rate': 2.0234460196292257e-05, 'epoch': 1.41}


 49%|████▉     | 1150/2334 [5:25:36<5:49:30, 17.71s/it]

{'loss': 0.8515, 'grad_norm': 3.7318804264068604, 'learning_rate': 1.941657579062159e-05, 'epoch': 1.48}


 51%|█████▏    | 1200/2334 [5:41:08<6:10:15, 19.59s/it]

{'loss': 0.8222, 'grad_norm': 3.40995717048645, 'learning_rate': 1.859869138495093e-05, 'epoch': 1.54}


 54%|█████▎    | 1250/2334 [5:56:45<5:54:13, 19.61s/it]

{'loss': 0.8564, 'grad_norm': 4.109475612640381, 'learning_rate': 1.7780806979280263e-05, 'epoch': 1.61}


 56%|█████▌    | 1300/2334 [6:12:16<5:34:00, 19.38s/it]

{'loss': 0.8466, 'grad_norm': 3.1213114261627197, 'learning_rate': 1.6962922573609597e-05, 'epoch': 1.67}


 58%|█████▊    | 1350/2334 [6:27:49<4:35:53, 16.82s/it]

{'loss': 0.83, 'grad_norm': 3.3308284282684326, 'learning_rate': 1.614503816793893e-05, 'epoch': 1.73}


 60%|█████▉    | 1400/2334 [6:43:31<5:17:17, 20.38s/it]

{'loss': 0.8511, 'grad_norm': 5.015727519989014, 'learning_rate': 1.5327153762268266e-05, 'epoch': 1.8}


 62%|██████▏   | 1450/2334 [6:59:27<5:01:26, 20.46s/it]

{'loss': 0.8426, 'grad_norm': 3.2566137313842773, 'learning_rate': 1.4509269356597602e-05, 'epoch': 1.86}


 64%|██████▍   | 1500/2334 [7:14:45<4:18:55, 18.63s/it]

{'loss': 0.8367, 'grad_norm': 3.4181246757507324, 'learning_rate': 1.3691384950926936e-05, 'epoch': 1.93}


 66%|██████▋   | 1550/2334 [7:30:40<4:16:38, 19.64s/it]

{'loss': 0.8536, 'grad_norm': 3.391620635986328, 'learning_rate': 1.2873500545256272e-05, 'epoch': 1.99}


 69%|██████▊   | 1600/2334 [7:46:08<4:04:23, 19.98s/it]

{'loss': 0.6698, 'grad_norm': 3.7891039848327637, 'learning_rate': 1.2055616139585606e-05, 'epoch': 2.05}


 71%|███████   | 1650/2334 [8:01:46<3:14:43, 17.08s/it]

{'loss': 0.6512, 'grad_norm': 4.015625, 'learning_rate': 1.123773173391494e-05, 'epoch': 2.12}


 73%|███████▎  | 1700/2334 [8:17:21<3:04:09, 17.43s/it]

{'loss': 0.6601, 'grad_norm': 2.927855968475342, 'learning_rate': 1.0419847328244276e-05, 'epoch': 2.18}


 75%|███████▍  | 1750/2334 [8:31:39<2:33:13, 15.74s/it]

{'loss': 0.6339, 'grad_norm': 3.631047248840332, 'learning_rate': 9.601962922573609e-06, 'epoch': 2.25}


 77%|███████▋  | 1800/2334 [8:45:29<2:08:26, 14.43s/it]

{'loss': 0.6553, 'grad_norm': 3.169023036956787, 'learning_rate': 8.784078516902945e-06, 'epoch': 2.31}


 79%|███████▉  | 1850/2334 [9:00:52<2:36:35, 19.41s/it]

{'loss': 0.656, 'grad_norm': 3.6083710193634033, 'learning_rate': 7.96619411123228e-06, 'epoch': 2.38}


 81%|████████▏ | 1900/2334 [9:14:53<1:58:09, 16.33s/it]

{'loss': 0.647, 'grad_norm': 4.061622619628906, 'learning_rate': 7.148309705561615e-06, 'epoch': 2.44}


 84%|████████▎ | 1950/2334 [9:30:43<1:57:25, 18.35s/it]

{'loss': 0.654, 'grad_norm': 3.283297538757324, 'learning_rate': 6.330425299890949e-06, 'epoch': 2.5}


 86%|████████▌ | 2000/2334 [9:45:01<1:35:48, 17.21s/it]

{'loss': 0.6447, 'grad_norm': 3.4708502292633057, 'learning_rate': 5.512540894220283e-06, 'epoch': 2.57}


                                                       
Non-default generation parameters: {'max_length': 142, 'min_length': 56, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


{'eval_loss': 1.1158305406570435, 'eval_runtime': 154.9879, 'eval_samples_per_second': 3.226, 'eval_steps_per_second': 1.613, 'epoch': 2.57}


 88%|████████▊ | 2050/2334 [10:04:11<1:27:02, 18.39s/it]

{'loss': 0.6461, 'grad_norm': 3.986788749694824, 'learning_rate': 4.694656488549618e-06, 'epoch': 2.63}


 90%|████████▉ | 2100/2334 [10:21:14<1:13:02, 18.73s/it]

{'loss': 0.6454, 'grad_norm': 4.123053073883057, 'learning_rate': 3.8767720828789534e-06, 'epoch': 2.7}


 92%|█████████▏| 2150/2334 [10:37:00<55:27, 18.08s/it]  

{'loss': 0.6396, 'grad_norm': 3.468980312347412, 'learning_rate': 3.058887677208288e-06, 'epoch': 2.76}


 94%|█████████▍| 2200/2334 [10:53:13<40:35, 18.18s/it]  

{'loss': 0.6619, 'grad_norm': 3.330507755279541, 'learning_rate': 2.2410032715376227e-06, 'epoch': 2.83}


 96%|█████████▋| 2250/2334 [11:07:35<24:20, 17.38s/it]

{'loss': 0.6315, 'grad_norm': 3.1077840328216553, 'learning_rate': 1.4231188658669574e-06, 'epoch': 2.89}


 99%|█████████▊| 2300/2334 [11:22:56<10:42, 18.90s/it]

{'loss': 0.6307, 'grad_norm': 3.819241762161255, 'learning_rate': 6.052344601962923e-07, 'epoch': 2.95}


Non-default generation parameters: {'max_length': 142, 'min_length': 56, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}
100%|██████████| 2334/2334 [11:33:50<00:00, 17.84s/it]
Non-default generation parameters: {'max_length': 142, 'min_length': 56, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


{'train_runtime': 41630.2394, 'train_samples_per_second': 0.898, 'train_steps_per_second': 0.056, 'train_loss': 0.8786970373903877, 'epoch': 3.0}
