In [1]:
import os
# Change the working directory to the parent of the current one.
# NOTE(review): presumably the notebook lives one level below the project
# root, so this makes the `src.*` imports and relative paths resolve --
# confirm. Re-running this cell moves up another level each time.
os.chdir('../')

In [2]:
from dataclasses import dataclass
from src.logger.custom_logging import logger
from pathlib import Path
from src.constants import *
from src.utils.utlis import *
from src.exceptions.expection import CustomException
from sklearn.ensemble import RandomForestRegressor
from sklearn.svm import SVR
from sklearn.model_selection import train_test_split
import sys,joblib
import pandas as pd
from sklearn.base import BaseEstimator


@dataclass(frozen=True)
class ModelTrainerConfig:
    """Immutable bundle of the paths used by the model-training step.

    Instances are frozen: assigning to a field after construction raises.
    """

    # Directory associated with the model-trainer step.
    root_dir: Path
    # Destination file for the serialized best model.
    train_model_path: Path
    # CSV file holding the training split.
    training_data_path: Path
    # CSV file holding the testing split.
    testing_data_path: Path

In [3]:
class ConfigManager:
    """Load the project configuration and build per-step config objects."""

    def __init__(
        self,
        config_file=CONFIG_FILE_PATH,
        params_filepath=PARAMS_FILE_PATH,
    ):
        """Read both settings files and register the artifacts root.

        Args:
            config_file: Path handed to ``read_yaml`` for the main config.
            params_filepath: Path handed to ``read_yaml`` for the parameters.
        """
        self.config = read_yaml(config_file)
        self.params = read_yaml(params_filepath)

        # NOTE(review): create_directories presumably creates the listed
        # folders when they are missing -- confirm in src.utils.
        create_directories([self.config.artifacts_root])

    def get_model_trainer_config(self) -> ModelTrainerConfig:
        """Build a ModelTrainerConfig from the ``model_trainer`` section.

        The section's ``root_dir`` is passed to ``create_directories``
        before the config object is assembled.

        Returns:
            A ModelTrainerConfig populated with the section's four paths.
        """
        trainer_section = self.config.model_trainer
        create_directories([trainer_section.root_dir])

        return ModelTrainerConfig(
            root_dir=trainer_section.root_dir,
            train_model_path=trainer_section.train_model_path,
            training_data_path=trainer_section.training_data_path,
            testing_data_path=trainer_section.testing_data_path,
        )

In [8]:
class ModelTrainer:
    """Evaluate candidate regressors, pick the top scorer and persist it."""

    def __init__(self, config: ModelTrainerConfig):
        """Keep the configuration that supplies the data and model paths."""
        self.config = config

    @staticmethod
    def save_model(path: Path, model: BaseEstimator):
        """Save the trained model to the specified path.

        Args:
            path: Destination file passed to ``joblib.dump``.
            model: Estimator object to serialize.

        Raises:
            CustomException: Wraps any error raised while dumping the model;
                the error is printed and logged before being re-raised.
        """
        try:
            joblib.dump(model, path)
            print(f"Model saved at {path}")
        except Exception as e:
            print(f"Error saving model: {e}")
            logger.error(f"Error saving model: {e}")
            # Chain explicitly so the root cause stays attached.
            raise CustomException(e, sys) from e

    def initate_model_trainer(self):
        """Score the candidate models and save the best one.

        Reads the training and testing CSV files named in the config, treats
        the last column as the target and all preceding columns as features,
        scores a RandomForestRegressor and an SVR through ``final_model``,
        and saves the highest-scoring estimator to ``train_model_path``.

        Raises:
            CustomException: If evaluation, selection or saving fails.
        """
        # The CSV reads sit outside the try block on purpose: a missing or
        # unreadable file surfaces as the underlying pandas/OS error.
        train_data = pd.read_csv(self.config.training_data_path)
        test_data = pd.read_csv(self.config.testing_data_path)
        try:
            # Last column is the target; everything before it is a feature.
            X_train = train_data.iloc[:, :-1].values
            y_train = train_data.iloc[:, -1].values
            X_test = test_data.iloc[:, :-1].values
            y_test = test_data.iloc[:, -1].values

            models = {
                "Random Forest": RandomForestRegressor(),
                "SVR": SVR()
            }

            # Evaluate models.
            # NOTE(review): this assumes final_model fits the estimators held
            # in `models` in place and returns a name -> score mapping; if it
            # fits copies instead, the object saved below is unfitted --
            # confirm in src.utils.
            model_report = final_model(models, X_train, X_test, y_train, y_test)
            print(model_report)
            print('\n====================================================================================\n')
            logger.info(f'Model Report: {model_report}')

            # Single pass over the report; on ties the first entry wins.
            best_model_name, best_model_score = max(
                model_report.items(), key=lambda item: item[1]
            )
            best_model = models[best_model_name]

            print(f"Best Model Found, Model Name is: {best_model_name}, Accuracy_Score: {best_model_score}")
            print("\n***************************************************************************************\n")
            logger.info(f"Best model found, Model Name is {best_model_name}, Accuracy Score: {best_model_score}")
        except Exception as e:
            logger.error(f'Error occurred: {e}')
            raise CustomException(e, sys) from e

        # Save the best model. This runs outside the try block because
        # save_model already logs and wraps its own failures; keeping it
        # inside would log the same error twice and wrap a CustomException
        # in another CustomException.
        self.save_model(path=self.config.train_model_path, model=best_model)

In [9]:
# Driver cell: build the configuration, construct the trainer from the
# model-trainer config and run the training step.
try:
    config=ConfigManager()
    model_trainer_config=config.get_model_trainer_config()
    model_trainer=ModelTrainer(model_trainer_config)
    model_trainer.initate_model_trainer()

except Exception as e:
    # Any failure from the steps above is re-raised as a CustomException.
    raise CustomException(e,sys)



{'Random Forest': 0.9109858071598773, 'SVR': -0.10865270007595074}


Best Model Found, Model Name is: Random Forest, Accuracy_Score: 0.9109858071598773

***************************************************************************************

Model saved at artifacts/model_trainer/model.pkl
