In [1]:
import os 

In [2]:
# Move from the notebook's folder up to the project root so the `src.whisper`
# package imported further down is importable from the working directory.
# Guarded so re-running this cell (or starting the kernel at the root) does not
# keep climbing directories — a bare `os.chdir("../")` is not idempotent.
if not os.path.isdir("src"):
    os.chdir("../")

In [3]:
%pwd

'c:\\Users\\ASUS\\Desktop\\MLOps_Whisper'

In [47]:
from dataclasses import dataclass
from pathlib import Path
from transformers import WhisperForConditionalGeneration, WhisperProcessor


@dataclass(frozen=True)
class PrepareBaseModelConfig:
    """Immutable settings for the "prepare base model" pipeline stage.

    The path fields come from the ``prepare_base_model`` section of the config
    YAML; the remaining fields come from the params YAML (see
    ``ConfigurationManager.get_prepare_base_model_config``). ``frozen=True``
    makes instances read-only after construction.
    """

    # Working directory for this stage's artifacts.
    root_dir: Path
    # Where the pretrained model is saved right after download.
    base_model_path: Path
    # Where the model is saved after the "update" step.
    updated_base_model_path: Path
    # Training hyperparameters read from the params file. The names presumably
    # mirror Hugging Face (Seq2Seq)TrainingArguments options — TODO confirm.
    gradient_accumulation_steps: int
    learning_rate: float
    warmup_steps: int 
    max_steps: int
    gradient_checkpointing: bool
    fp16: bool
    per_device_eval_batch_size: int
    predict_with_generate: bool
    generation_max_length: int
    save_steps: int
    eval_steps: int

In [48]:
from src.whisper.constants import *
from src.whisper.utils.common import read_yaml, create_directories

In [49]:
class ConfigurationManager:
    """Load the project's config and params YAML files and build stage configs.

    Args:
        config_filepath: Path to the main config YAML (defaults to
            ``CONFIG_FILE_PATH`` from ``src.whisper.constants``).
        params_filepath: Path to the hyperparameter YAML (defaults to
            ``PARAMS_FILE_PATH``).
    """

    def __init__(
        self,
        config_filepath=CONFIG_FILE_PATH,
        params_filepath=PARAMS_FILE_PATH):

        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)

        # Make sure the top-level artifacts directory exists before any stage
        # tries to write into it.
        create_directories([self.config.artifacts_root])

    def get_prepare_base_model_config(self) -> PrepareBaseModelConfig:
        """Build the configuration for the "prepare base model" stage.

        Side effect: creates ``prepare_base_model.root_dir`` on disk.

        Returns:
            PrepareBaseModelConfig whose paths come from the config YAML and
            whose hyperparameters come from the params YAML.
        """
        config = self.config.prepare_base_model
        params = self.params
        # NOTE(review): not read by any code visible in this notebook (the
        # model name is hard-coded where the model is loaded). Kept so existing
        # callers that inspect it keep working.
        self.model_name = "openai/whisper-small"

        create_directories([config.root_dir])

        return PrepareBaseModelConfig(
            root_dir=Path(config.root_dir),
            base_model_path=Path(config.base_model_path),
            updated_base_model_path=Path(config.updated_base_model_path),
            gradient_accumulation_steps=params.gradient_accumulation_steps,
            learning_rate=params.learning_rate,
            warmup_steps=params.warmup_steps,
            max_steps=params.max_steps,
            gradient_checkpointing=params.gradient_checkpointing,
            fp16=params.fp16,
            per_device_eval_batch_size=params.per_device_eval_batch_size,
            predict_with_generate=params.predict_with_generate,
            generation_max_length=params.generation_max_length,
            save_steps=params.save_steps,
            # Bug fix: this field was previously filled from `save_steps`.
            # Fall back to `save_steps` only when the params file has no
            # `eval_steps` key, so older params files keep working. (Assumes a
            # missing key raises AttributeError, as ConfigBox's BoxKeyError
            # does — confirm for read_yaml's return type.)
            eval_steps=getattr(params, "eval_steps", params.save_steps),
        )

In [50]:
import torch
from transformers import WhisperForConditionalGeneration

In [53]:
class PrepareBaseModel:
    """Download the pretrained Whisper checkpoint and save base/updated copies.

    Args:
        config: Output paths and hyperparameters for this stage.
    """

    # Hugging Face Hub checkpoint used for both the model and the processor.
    MODEL_NAME = "openai/whisper-small"

    def __init__(self, config: PrepareBaseModelConfig):
        self.config = config

    def get_base_model(self):
        """Load the pretrained model and processor, then save the model.

        Sets ``self.model`` and ``self.processor`` and writes the model to
        ``config.base_model_path``. Must be called before
        :meth:`update_base_model`.
        """
        self.model = WhisperForConditionalGeneration.from_pretrained(self.MODEL_NAME)
        self.processor = WhisperProcessor.from_pretrained(self.MODEL_NAME)
        # NOTE(review): only the model is persisted; the processor is not
        # saved — confirm downstream stages reload it from the Hub.
        self.save_model(self.config.base_model_path, self.model)

    def _prepare_full_model(self, model, processor, **kwargs):
        """Placeholder hook for adapting the base model; currently a no-op.

        The keyword arguments (hyperparameters from the stage config) are
        accepted but ignored, and ``model``/``processor`` are returned unchanged.
        """
        return model, processor

    def update_base_model(self):
        """Run the model through :meth:`_prepare_full_model` and save the result
        to ``config.updated_base_model_path``.

        Raises:
            AttributeError: if :meth:`get_base_model` has not been called yet.
        """
        if not hasattr(self, "model"):
            # Same exception type as before, but with an actionable message.
            raise AttributeError(
                "get_base_model() must be called before update_base_model()"
            )
        self.full_model, self.processor = self._prepare_full_model(
            model=self.model,
            processor=self.processor,
            learning_rate=self.config.learning_rate,
            gradient_accumulation_steps=self.config.gradient_accumulation_steps,
            warmup_steps=self.config.warmup_steps,
            max_steps=self.config.max_steps,
            gradient_checkpointing=self.config.gradient_checkpointing,
            fp16=self.config.fp16,
            per_device_eval_batch_size=self.config.per_device_eval_batch_size,
            predict_with_generate=self.config.predict_with_generate,
            generation_max_length=self.config.generation_max_length,
            save_steps=self.config.save_steps,
            eval_steps=self.config.eval_steps
        )
        self.save_model(self.config.updated_base_model_path, self.full_model)

    # Single definition: previously an instance-method `save_model` was also
    # defined above and silently shadowed by this one.
    @staticmethod
    def save_model(path: Path, model):
        """Save ``model`` to ``path`` via ``save_pretrained`` and report it."""
        model.save_pretrained(path)
        print(f"Model saved to {path}")



In [None]:
# Run the "prepare base model" stage end to end. No try/except wrapper:
# re-raising the same exception added nothing but an extra traceback frame,
# and letting it propagate shows the original traceback in the notebook.
config = ConfigurationManager()
prepare_base_model_config = config.get_prepare_base_model_config()
prepare_base_model = PrepareBaseModel(config=prepare_base_model_config)
prepare_base_model.get_base_model()
prepare_base_model.update_base_model()