In [1]:
import os

# Move from the notebook's folder up to the project root so the relative paths
# used below (params.yaml, config/config.yaml) resolve. Guarded on params.yaml
# so that re-running this cell does not keep climbing the directory tree.
if not os.path.isfile("params.yaml"):
    os.chdir("../")

In [2]:
# Display the working directory to confirm we are at the project root.
os.getcwd()

'e:\\MyOnlineCourses\\ML_Projects\\arabic-digits-recognition'

In [3]:
from dataclasses import dataclass
from pathlib import Path


@dataclass(frozen=True)
class DataTFTrainConfig:
    """Immutable settings for the TensorFlow training stage.

    Attributes:
        root_dir: Stage directory holding ``train_metadata.csv`` and
            ``val_metadata.csv``. Plain strings (e.g. values read from YAML)
            are converted to ``pathlib.Path`` so the annotation holds.
        dst_path: Destination passed to ``model.save`` for the trained model.
    """
    root_dir: Path
    dst_path: str

    def __post_init__(self):
        # The dataclass is frozen, so normalise the field via object.__setattr__.
        if not isinstance(self.root_dir, Path):
            object.__setattr__(self, "root_dir", Path(self.root_dir))
    

In [4]:
from src.ard.constants import *
from src.ard.utils.help import read_yaml, create_directories

class ConfigurationManager:
    """Load the project YAML files and build per-stage configuration objects."""

    def __init__(
        self,
        config_filepath = CONFIG_FILE_PATH,
        params_filepath = PARAMS_FILE_PATH):
        """
        Args:
            config_filepath: Pipeline config YAML (defaults to ``CONFIG_FILE_PATH``).
            params_filepath: Hyper-parameter YAML (defaults to ``PARAMS_FILE_PATH``).
        """
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)

        # Make sure the top-level artifacts directory exists before any stage runs.
        create_directories([self.config.artifacts_root])

    def get_data_tf_training_config(self) -> "tuple[DataTFTrainConfig, Any]":
        """Build the training-stage config and return it together with the params.

        Creates ``data_tf_training.root_dir`` as a side effect.

        Returns:
            A 2-tuple ``(DataTFTrainConfig, params)`` where ``params`` is the object
            loaded from the params YAML (callers unpack both values).
        """
        config = self.config.data_tf_training

        create_directories([config.root_dir])

        data_tf_training_config = DataTFTrainConfig(
            # YAML yields a string; the dataclass field is annotated as Path.
            root_dir=Path(config.root_dir),
            dst_path=config.dst_path,
        )

        return data_tf_training_config, self.params

In [12]:
import tensorflow as tf
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Dict, Any, Tuple
from dataclasses import dataclass
import os
from ard import logger
from ard.utils.tf_utils import build_dataset, finalize_dataset

# Only ERROR-level (and above) messages from TensorFlow's Python logger are shown;
# WARNING/INFO chatter is suppressed.
tf.get_logger().setLevel(level='ERROR')


class ModelTraining:
    """Training stage: load file lists, build tf.data pipelines, train and save a CNN.

    Reads ``train_metadata.csv`` / ``val_metadata.csv`` from ``config.root_dir``,
    turns the listed audio files into datasets via ``build_dataset`` /
    ``finalize_dataset``, trains a small convolutional classifier, and saves it
    to ``config.dst_path``.
    """

    def __init__(self, config: DataTFTrainConfig, params: Dict[str, Any]):
        """
        Args:
            config: Stage paths (``root_dir`` with the metadata CSVs, ``dst_path``
                for the saved model).
            params: Hyper-parameters; must contain ``LABELS`` (one entry per class).
        """
        self.config = config
        self.params = params
        self.labels = params['LABELS']

    def load_and_preprocess_data(self) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
        """Read the file lists, log per-split channel stats, and return (train_ds, val_ds)."""
        train_files, val_files = self._read_csv_to_list()
        # Pass a description so each log line names the split it refers to.
        for desc, files in (("train", train_files), ("validation", val_files)):
            self._log_audio_info(files, desc=desc)
        return self._create_datasets(train_files, val_files)

    def _read_csv_to_list(self) -> Tuple[list, list]:
        """Return the ``path`` column of the train and validation metadata CSVs as lists.

        Raises:
            ValueError: if neither CSV lists any file.
        """
        train_df = pd.read_csv(os.path.join(self.config.root_dir, 'train_metadata.csv'))
        val_df = pd.read_csv(os.path.join(self.config.root_dir, 'val_metadata.csv'))
        n_train, n_val = len(train_df), len(val_df)
        total = n_train + n_val
        logger.info(f"Total files: {total}")
        if total == 0:
            # Avoid a ZeroDivisionError in the percentages below and fail with a clear message.
            raise ValueError(f"No audio files listed in the metadata CSVs under {self.config.root_dir}")
        logger.info(f"Training files: {n_train} ({n_train / total:.2%})")
        logger.info(f"Validation files: {n_val} ({n_val / total:.2%})")
        return train_df['path'].tolist(), val_df['path'].tolist()

    def _log_audio_info(self, files, desc=None):
        """Log the number of files in a split and how many are mono vs. multi-channel.

        Note: every file is fully read and decoded to obtain its channel count.
        """
        num_samples = len(files)
        logger.info(f'Number of total examples in {desc}: {num_samples}')

        monos, stereos = self._count_channels(files)
        logger.info(f"Mono audio files: {len(monos)}, Stereo audio files: {len(stereos)} in {desc} dataset")

    def _count_channels(self, files):
        """Split ``files`` into (mono, stereo) lists by decoding each WAV file.

        ``tf.audio.decode_wav`` returns audio shaped ``[samples, channels]``; any file
        whose channel count is not 1 is put in the "stereo" list.
        """
        monos, stereos = [], []
        for file in files:
            wav_contents = tf.io.read_file(file)
            wav, _ = tf.audio.decode_wav(contents=wav_contents)
            (monos if wav.shape[1] == 1 else stereos).append(file)
        return monos, stereos

    def _create_datasets(self, train_files, val_files):
        """Build and finalize the train/validation datasets with autotuned parallelism."""
        AUTOTUNE = tf.data.AUTOTUNE

        train_ds = build_dataset(train_files, AUTOTUNE)
        val_ds = build_dataset(val_files, AUTOTUNE)

        return finalize_dataset(train_ds, AUTOTUNE), finalize_dataset(val_ds, AUTOTUNE)

    def build_and_train_model(self, train_ds, val_ds):
        """Build, compile, train, and save the model.

        Returns:
            The Keras ``History`` object produced by ``model.fit``.
        """
        # The first batch's shape includes the batch dimension; drop it for the Input layer.
        input_shape = self._get_input_shape(train_ds)[1:]
        logger.info(f'Input shape: {input_shape}')
        # Use self.labels (read via subscription) so a plain dict works as well as attribute-style params.
        model = self._create_model(input_shape, len(self.labels))
        self._compile_model(model)
        history = self._train_model(model, train_ds, val_ds)

        model.save(self.config.dst_path)
        return history

    def _get_input_shape(self, dataset):
        """Return the shape of the first batch of inputs (batch dimension included).

        Raises:
            ValueError: if ``dataset`` yields no elements.
        """
        for spectrogram, _ in dataset.take(1):
            return spectrogram.shape
        raise ValueError("Cannot infer the model input shape from an empty dataset.")

    def _create_model(self, input_shape, num_labels):
        """Return an uncompiled CNN classifier whose last layer outputs raw logits."""
        return tf.keras.Sequential([
            tf.keras.layers.Input(shape=input_shape),
            tf.keras.layers.Conv2D(32, 3, activation='relu'),
            tf.keras.layers.Dropout(0.25),
            tf.keras.layers.Conv2D(64, 3, activation='relu'),
            tf.keras.layers.MaxPooling2D(),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(128, activation='relu'),
            tf.keras.layers.Dropout(0.5),
            # No softmax here; the loss is configured with from_logits=True to match.
            tf.keras.layers.Dense(num_labels)
        ])

    def _compile_model(self, model):
        """Compile with Adam and sparse categorical cross-entropy on integer labels."""
        model.compile(
            optimizer=tf.keras.optimizers.Adam(),
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            metrics=['accuracy']
        )

    def _train_model(self, model, train_ds, val_ds, epochs=100, patience=10):
        """Fit ``model`` with early stopping on the validation loss.

        Args:
            epochs: Maximum number of epochs.
            patience: Epochs without improvement before stopping.

        Returns:
            The Keras ``History`` object.
        """
        # restore_best_weights keeps the best-epoch weights instead of the weights
        # from `patience` epochs later, so the saved model is the best one seen.
        early_stopping = tf.keras.callbacks.EarlyStopping(
            verbose=1, patience=patience, restore_best_weights=True
        )
        history = model.fit(
            train_ds,
            validation_data=val_ds,
            epochs=epochs,
            callbacks=[early_stopping]
        )
        return history

    def train(self):
        """Run the full stage: load data, then build, train, and save the model.

        Returns:
            The Keras ``History`` object from training.
        """
        train_ds, val_ds = self.load_and_preprocess_data()
        return self.build_and_train_model(train_ds, val_ds)

In [13]:
# Build the stage configuration and run the full training pipeline.
# Exceptions propagate to the notebook unchanged, with their original traceback.
config = ConfigurationManager()
data_tf_training_config, data_tf_training_params = config.get_data_tf_training_config()
data_tf_training = ModelTraining(config=data_tf_training_config, params=data_tf_training_params)
history = data_tf_training.train()

[2024-08-11 23:32:28,822: INFO: help: yaml file: config\config.yaml loaded successfully. Content size: 8]
[2024-08-11 23:32:28,834: INFO: help: yaml file: params.yaml loaded successfully. Content size: 7]
[2024-08-11 23:32:28,837: INFO: help: Total directories created: 1]
[2024-08-11 23:32:28,842: INFO: help: Total directories created: 1]
[2024-08-11 23:32:29,923: INFO: 2185098702: Total files: 361]
[2024-08-11 23:32:30,101: INFO: 2185098702: Number of total examples in None: 321]
[2024-08-11 23:32:32,033: INFO: 2185098702: Mono audio files: 0, Stereo audio files: 321 in None dataset]
[2024-08-11 23:32:32,035: INFO: 2185098702: Number of total examples in None: 40]
[2024-08-11 23:32:32,192: INFO: 2185098702: Mono audio files: 0, Stereo audio files: 40 in None dataset]
[2024-08-11 23:32:45,040: INFO: 2185098702: Train dataset size: 14]
[2024-08-11 23:32:45,042: INFO: 2185098702: Validation dataset size: 2]
[2024-08-11 23:32:53,485: INFO: 2185098702: Input shape: (77, 129, 1)]
Epoch 1/10