In [None]:
import os

# Step from the notebook's directory up to its parent. Presumably this is the
# project root, so that `src.*` imports and relative paths resolve -- confirm
# against the repository layout.
# The sentinel makes the cell idempotent: without it, every re-run of this cell
# would climb one more directory and silently break later relative paths.
if '_NOTEBOOK_START_DIR' not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
    os.chdir('../')

In [None]:
from dotenv import find_dotenv, load_dotenv

# Locate a .env file with find_dotenv() and read it.
_ = load_dotenv(dotenv_path=find_dotenv())

# Both settings are mandatory: indexing os.environ raises KeyError if one is unset.
MODEL_CONFIG_FILE_PATH = os.environ['MODEL_CONFIG_FILE_PATH']
MODEL_PARAMS_FILE_PATH = os.environ['MODEL_PARAMS_FILE_PATH']

In [None]:
from dataclasses import dataclass
from pathlib import Path


@dataclass(frozen=True)
class ModelPredictionConfig:
    """Immutable settings that say which model artefacts the prediction step loads."""

    # Not read by the prediction code in this notebook; presumably the data location -- TODO confirm.
    data_path: Path
    # Identifier handed to the tokenizer's and the base model's `from_pretrained`.
    base_model: str
    # Location handed to `PeftModel.from_pretrained` to load the adapters.
    adapters_path: Path

@dataclass(frozen=True)
class ModelPredictionParameters:
    """Immutable generation settings that the prediction step passes to the pipeline call."""

    # Forwarded as the `length_penalty` generation argument.
    length_penalty: float
    # Forwarded as the `num_beams` generation argument.
    num_beams: int
    # Forwarded as the `max_length` generation argument.
    max_length: int

@dataclass(frozen=True)
class BitsAndBytesParameters:
    """Immutable quantization settings, forwarded to `BitsAndBytesConfig` under the same names."""

    # Passed as `load_in_4bit`.
    load_in_4bit: bool
    # Passed as `bnb_4bit_quant_type`.
    bnb_4bit_quant_type: str
    # Passed as `bnb_4bit_use_double_quant`.
    bnb_4bit_use_double_quant: bool

In [None]:
from src.utils.common import read_yaml

class ConfigurationManager:
    """Reads the model config and params YAML files and exposes them as dataclasses.

    The objects returned by `read_yaml` are accessed by attribute
    (`self.config.model_prediction`, `self.params.prediction_parameters`, ...),
    so they must support attribute-style lookup.
    """

    def __init__(self,
                model_config_filepath = MODEL_CONFIG_FILE_PATH,
                model_params_filepath = MODEL_PARAMS_FILE_PATH):
        """Load both YAML files.

        Args:
            model_config_filepath: Location of the config YAML; defaults to
                the `MODEL_CONFIG_FILE_PATH` module constant.
            model_params_filepath: Location of the params YAML; defaults to
                the `MODEL_PARAMS_FILE_PATH` module constant.
        """
        config_file = Path(model_config_filepath)
        self.config = read_yaml(config_file)

        params_file = Path(model_params_filepath)
        self.params = read_yaml(params_file)

    def get_model_prediction_config(self) -> ModelPredictionConfig:
        """Build a `ModelPredictionConfig` from the `model_prediction` config section.

        Values are passed through exactly as read from the YAML; no conversion
        to `Path` is performed here.
        """
        section = self.config.model_prediction
        return ModelPredictionConfig(
            data_path=section.data_path,
            base_model=section.base_model,
            adapters_path=section.adapters_path,
        )

    def get_bits_and_bytes_params(self) -> BitsAndBytesParameters:
        """Build a `BitsAndBytesParameters` from the `bits_and_bytes_parameters` params section."""
        section = self.params.bits_and_bytes_parameters
        return BitsAndBytesParameters(
            load_in_4bit=section.load_in_4bit,
            bnb_4bit_quant_type=section.bnb_4bit_quant_type,
            bnb_4bit_use_double_quant=section.bnb_4bit_use_double_quant,
        )

    def get_model_prediction_parameters(self) -> ModelPredictionParameters:
        """Build a `ModelPredictionParameters` from the `prediction_parameters` params section."""
        section = self.params.prediction_parameters
        return ModelPredictionParameters(
            length_penalty=section.length_penalty,
            num_beams=section.num_beams,
            max_length=section.max_length,
        )


In [None]:
import torch
from transformers import pipeline, AutoTokenizer, LlamaForCausalLM, BitsAndBytesConfig
from peft import PeftModel

from src.logging import logger

class ModelPrediction:
    """Generates an answer with a quantized Llama base model plus PEFT adapters.

    The tokenizer, quantization config, model and pipeline are built lazily on
    the first `predict` call and reused afterwards, so repeated predictions do
    not reload the model weights.
    """

    def __init__(self, config: ModelPredictionConfig, bits_and_bytes_parameters: BitsAndBytesParameters, params: ModelPredictionParameters):
        """Store the settings; nothing is loaded until `predict` is called.

        Args:
            config: Base model identifier and adapters location.
            bits_and_bytes_parameters: Quantization settings for `BitsAndBytesConfig`.
            params: Generation settings (`length_penalty`, `num_beams`, `max_length`).
        """
        self.config = config
        self.bits_and_bytes_parameters = bits_and_bytes_parameters
        self.params = params
        # Cached text-generation pipeline; created on first use in `predict`.
        self._pipe = None

    def __initialize_tokenizer(self, model_name: str):
        """Load the tokenizer for `model_name` and register a `[PAD]` pad token."""
        # Use the argument rather than silently ignoring it in favour of self.config.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # NOTE(review): if `[PAD]` is a new token, the model embeddings are not
        # resized here -- confirm this is safe for the batch sizes in use.
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        logger.info("Tokenizer initialized")

    def __initialize_bits_and_bytes(self):
        """Build the `BitsAndBytesConfig` used to quantize the base model."""
        self.nf4_config = BitsAndBytesConfig(
            load_in_4bit = self.bits_and_bytes_parameters.load_in_4bit,
            bnb_4bit_quant_type = self.bits_and_bytes_parameters.bnb_4bit_quant_type,
            bnb_4bit_use_double_quant = self.bits_and_bytes_parameters.bnb_4bit_use_double_quant,
            bnb_4bit_compute_dtype = torch.bfloat16
        )
        logger.info("Bits and bytes initialized")

    def __initialize_model(self):
        """Load the quantized base model and wrap it with the PEFT adapters."""
        self.model = LlamaForCausalLM.from_pretrained(self.config.base_model, device_map='auto', quantization_config=self.nf4_config)
        self.peft_model = PeftModel.from_pretrained(self.model, self.config.adapters_path)
        logger.info("Model initialized")

    def predict(self, question):
        """Generate text for `question`.

        Args:
            question: Prompt passed to the text-generation pipeline.

        Returns:
            The value stored under `generated_text` in the first pipeline result.
        """
        # getattr keeps this safe for instances created without running __init__.
        if getattr(self, "_pipe", None) is None:
            self.__initialize_tokenizer(self.config.base_model)
            self.__initialize_bits_and_bytes()
            self.__initialize_model()

            # "text-generation" is the registered task name; "generation" is not a valid task.
            self._pipe = pipeline("text-generation", model=self.peft_model, tokenizer=self.tokenizer)
            logger.info("Pipeline initialized")

        gen_kwargs = {"length_penalty": self.params.length_penalty,
                      # Previously max_length was passed here by mistake.
                      "num_beams": self.params.num_beams,
                      "max_length": self.params.max_length}

        logger.info("Generating output...")
        # The text-generation pipeline returns a list of dicts keyed by "generated_text".
        output = self._pipe(question, **gen_kwargs)[0]["generated_text"]
        # Pre-format the message: passing `output` as a bare extra argument with
        # no placeholder meant it was never rendered into the log line.
        logger.info(f"Output generated: {output}")

        return output

In [None]:
# Build the configuration objects and run one prediction end to end.
# The former `try/except Exception as e: raise e` wrapper was removed: it only
# re-raised the same exception and added a frame to the traceback.
question = "What is 2+2?"
config = ConfigurationManager()
model_prediction_config = config.get_model_prediction_config()
model_prediction_parameters = config.get_model_prediction_parameters()
bits_and_bytes_parameters = config.get_bits_and_bytes_params()
model_prediction = ModelPrediction(
    config=model_prediction_config,
    bits_and_bytes_parameters=bits_and_bytes_parameters,
    params=model_prediction_parameters,
)
# Last expression of the cell, so the notebook displays the generated answer.
model_prediction.predict(question)