In [1]:
import os
os.chdir("../")

In [2]:
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV

In [3]:
from dataclasses import dataclass
from pathlib import Path

@dataclass(frozen=True)
class ModelPreparationTrainingConfig:
    root_dir: Path
    save_models: Path
    param_target_col: str
    param_random_state: int
    param_n_estimators: list
    param_c_svc: list
    param_gamma_svc: list
    param_c_log_reg: list
    param_number_cv: int

In [4]:
from Mushroom_Classification.utils.common import create_directories, read_yaml, save_object, read_file
from Mushroom_Classification.constants import *
from Mushroom_Classification import logger

In [5]:
class ConfigurationManager:
    def __init__(self,
                 config_filepath = CONFIG_FILE_PATH,
                 params_filepath = PARAMS_FILE_PATH):
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)

        create_directories([self.config.artifacts_root]) #the artifacts_root is the key of the dictionary created
                                                # in the yaml file and we can read this key like that instead of
                                                # ["artifacts_root"] because we used the ConfigBox in the common.py file


    def get_model_preparation_training_config(self) -> ModelPreparationTrainingConfig:
        config = self.config.model_preparation_training #model_preparation_training is the other key value of the dictionary in the config.yaml file

        create_directories([config.root_dir,config.save_models])

        model_preparation_training_config = ModelPreparationTrainingConfig(
            root_dir=config.root_dir,
            save_models= config.save_models,
            param_target_col=self.params.TARGET,
            param_random_state=self.params.RANDOM_STATE,
            param_n_estimators=self.params.N_ESTIMATORS,
            param_c_svc=self.params.C_SVC,
            param_gamma_svc=self.params.GAMMA_SVC,
            param_c_log_reg=self.params.C_LOG_REG,
            param_number_cv=self.params.NUMBER_CV
        )                                     

        return model_preparation_training_config

In [16]:
class ModelPreparationTraining:
    def __init__(self, config: ModelPreparationTrainingConfig):
        self.config = config


    def model_preparation_training(self):
        """Stage of model preparation and training"""

        train_file = "train.csv"
        train_pca_file = "train_pca.csv"

        train = read_file(Path(self.config.root_dir),train_file)
        train_pca = read_file(Path(self.config.root_dir),train_pca_file)
        X_train = train.drop(columns=self.config.param_target_col, axis = 1)
        y_train = train[self.config.param_target_col]
        X_train_pca = train_pca.drop(columns=self.config.param_target_col, axis = 1)
        y_train_pca = train_pca[self.config.param_target_col]


        model_params = {
                'rnd_for':{
                    'model' : RandomForestClassifier(random_state=self.config.param_random_state),
                    'params':{
                        'n_estimators':self.config.param_n_estimators
                    }
                },
                'log_reg':{
                    'model': LogisticRegression(max_iter = 5000),
                    'params': {
                        'C':self.config.param_c_log_reg
                    }
                },
                'svm':{
                    'model': SVC(),
                    'params':{
                        'gamma':self.config.param_gamma_svc,
                        'C':self.config.param_c_svc
                    }
                }
                }
        
        best_models=[]

        logger.info("Performing GridSearch\n")

        for model,model_param in model_params.items():
            grid = GridSearchCV(model_param['model'], model_param['params'], cv = 5)
            grid.fit(X_train,y_train)
            best_models.append({
                'model':model,
                'best_params':grid.best_params_,
                'best_score':grid.best_score_
            })

        best_model_df = pd.DataFrame(best_models)
        ind = max(best_model_df.best_score.index)
        model =best_model_df.iloc[ind]["model"]
        param =best_model_df.iloc[ind]["best_params"]
        models = {
                'rnd_for': RandomForestClassifier(random_state=self.config.param_random_state),
                'log_reg': LogisticRegression(max_iter = 5000),
                'svm':SVC(),
                }
        logger.info(f"The chosen model is {model} with the following parameters: {param}")
        final_model = models[best_model_df.iloc[ind]["model"]]
        final_model.set_params(**param)
        final_model.fit(X_train,y_train)
        save_object(Path(self.config.save_models),final_model,"best_model.pkl")

        final_model_pca = models[best_model_df.iloc[ind]["model"]]
        final_model_pca.set_params(**param)
        final_model_pca.fit(X_train_pca,y_train_pca)
        save_object(Path(self.config.save_models),final_model,"pca_model.pkl")

        return None


In [17]:
try:
    config = ConfigurationManager()
    model_preparation_training_config = config.get_model_preparation_training_config()
    model_preparation_training = ModelPreparationTraining(config=model_preparation_training_config)
    model_preparation_training.model_preparation_training()
    
except Exception as e:
    raise e

[2024-06-17 23:12:37,951: INFO: common: yaml file: config\config.yaml loaded successfully]
[2024-06-17 23:12:37,952: INFO: common: yaml file: params.yaml loaded successfully]
[2024-06-17 23:12:37,952: INFO: common: The directory artifacts already exists]
[2024-06-17 23:12:37,952: INFO: common: The directory artifacts/training already exists]
[2024-06-17 23:12:37,952: INFO: common: The directory artifacts/models already exists]
[2024-06-17 23:12:37,971: INFO: 2399315960: Performing GridSearch
]
[2024-06-17 23:12:42,665: INFO: 2399315960: The chosen model is svm with the following parameters: {'C': 0.01, 'gamma': 0.1}]
