Skip to content

Commit

Permalink
code linted
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinKalema committed Jun 11, 2024
1 parent f7738ef commit 90f068a
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 21 deletions.
2 changes: 2 additions & 0 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import streamlit as st
from fastai.text.all import *


@st.cache_resource
def load_model():
learn = load_learner('text_classifier_learner.pth')
return learn


learn = load_model()

# Streamlit app
Expand Down
4 changes: 3 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,6 @@ def run_pipeline_stage(stage_name, pipeline_class) -> None:

if __name__ == '__main__':
run_pipeline_stage("DATA INGESTION STAGE", DataIngestionTrainingPipeline)
run_pipeline_stage("MODEL TRAINING AND EVALUATION STAGE", ModelTrainingAndEvaluationPipeline)
run_pipeline_stage(
"MODEL TRAINING AND EVALUATION STAGE",
ModelTrainingAndEvaluationPipeline)
3 changes: 2 additions & 1 deletion src/swahiliNewsClassifier/components/data_ingestion.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ def extract_zip_file(self):
with zipfile.ZipFile(zip_file, "r") as zip_ref:
zip_ref.extractall(decompress_path)

log.info(f"Extracted zip file {zip_file} into: {decompress_path}")
log.info(
f"Extracted zip file {zip_file} into: {decompress_path}")
except Exception as e:
log.error(f"Error extracting zip file: {zip_file}")
raise e
4 changes: 3 additions & 1 deletion src/swahiliNewsClassifier/configuration/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

load_dotenv()


class ConfigurationManager:
def __init__(self, config_filepath=CONFIG_FILE_PATH,
params_filepath=PARAMS_FILE_PATH):
Expand Down Expand Up @@ -41,7 +42,8 @@ def get_data_ingestion_config(self) -> DataIngestionConfig:
decompressed_dir=data_ingestion_config.decompressed_dir
)

def get_model_training_and_evaluation_config(self) -> ModelTrainingAndEvaluationConfig:
def get_model_training_and_evaluation_config(
self) -> ModelTrainingAndEvaluationConfig:
"""
Get the model training and evaluation configuration.
Expand Down
32 changes: 16 additions & 16 deletions src/swahiliNewsClassifier/entity/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,39 +32,39 @@ class ModelTrainingAndEvaluationConfig:
test_size (float): Proportion of the dataset to include in the test split. This parameter is used to split the dataset into training and validation sets.
learning_rate_1 (float): Learning rate for training the language model learner. This is used during the fine-tuning of the pre-trained language model.
learning_rate_2 (float): Learning rate for the first phase of classifier training. This is used in the initial phase of training the text classifier.
learning_rate_3 (float): Learning rate for the second phase of classifier training. This is used in the second phase of training the text classifier.
learning_rate_4 (float): Learning rate for the third phase of classifier training. This is used in the third phase of training the text classifier.
learning_rate_5 (float): Learning rate for the fourth phase of classifier training. This is used in the final phase of training the text classifier.
batch_size_1 (int): Batch size for language model training. This parameter defines the number of samples that will be propagated through the network at once during language model training.
batch_size_2 (int): Batch size for text classifier training. This parameter defines the number of samples that will be propagated through the network at once during text classifier training.
epochs_1 (int): Number of epochs for training the language model learner. This defines the number of complete passes through the training dataset.
epochs_2 (int): Number of epochs for the first phase of classifier training. This defines the number of complete passes through the training dataset in the first phase.
epochs_3 (int): Number of epochs for the second phase of classifier training. This defines the number of complete passes through the training dataset in the second phase.
epochs_4 (int): Number of epochs for the third phase of classifier training. This defines the number of complete passes through the training dataset in the third phase.
epochs_5 (int): Number of epochs for the fourth phase of classifier training. This defines the number of complete passes through the training dataset in the final phase.
training_data (Path): Path to the training data CSV file. This file contains the text data and corresponding labels for training and validation.
root_dir (Path): Root directory for storing model artifacts. This directory is used to save trained models, logs, and other artifacts.
mlflow_tracking_uri (str): URI for the MLflow tracking server. This is used to log and track experiments with MLflow.
mlflow_repo_name (str): Repository name for MLflow tracking. This is used to organize and identify different MLflow runs within the repository.
mlflow_repo_owner (str): Owner of the MLflow repository. This is used to identify the owner of the MLflow repository.
all_params (dict): Dictionary containing all parameters used for model training. This includes all hyperparameters and other settings for reproducibility and logging.
"""
test_size: float
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ def main(self):
"""
try:
data_ingestion_config = self.config.get_data_ingestion_config()
data_ingestion = DataIngestion(data_ingestion_configurations=data_ingestion_config)
data_ingestion = DataIngestion(
data_ingestion_configurations=data_ingestion_config)
data_ingestion.download_file()
data_ingestion.extract_zip_file()
except Exception as e:
Expand Down
3 changes: 2 additions & 1 deletion tests/test_data_ingestion.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ def data_ingestion_configurations():

@pytest.fixture
def data_ingestion(data_ingestion_configurations):
return DataIngestion(data_ingestion_configurations=data_ingestion_configurations)
return DataIngestion(
data_ingestion_configurations=data_ingestion_configurations)


@patch('swahiliNewsClassifier.components.data_ingestion.os.makedirs')
Expand Down

0 comments on commit 90f068a

Please sign in to comment.