In [None]:
import os
# Move one directory up (presumably to the project root) so that the
# `src.collision_alert` imports in later cells resolve. The path is relative
# to the current working directory, so re-running this cell climbs another
# level each time.
os.chdir("../")

In [None]:
from dataclasses import dataclass
from pathlib import Path


@dataclass(frozen=True)
class TrainingConfig:
    """Immutable settings for the model-training stage.

    Instances are built by ``ConfigurationManager.get_training_config`` and
    consumed by ``Training``. Path fields come from the pipeline config YAML;
    the ``params_*`` fields mirror keys of the params YAML.
    """

    # Output directory of the training stage (created by the config manager).
    root_dir: Path
    # File the trained model is saved to after training.
    trained_model_path: Path
    # Weights file the YOLO model is loaded from before training.
    updated_base_model_path: Path
    # Dataset location handed to the trainer as ``data``.
    training_data: Path
    # Number of training epochs (params key ``EPOCHS``).
    params_epochs: int
    # Batch size (params key ``BATCH_SIZE``).
    params_batch_size: int
    # Augmentation flag (params key ``AUGMENTATION``).
    params_is_augmentation: bool
    # Image size as listed in the params file (params key ``IMAGE_SIZE``).
    params_image_size: list


In [None]:
from src.collision_alert.utils.common import create_directories, read_yaml
from src.collision_alert.constants import *

In [None]:
class ConfigurationManager:
    """Load the project YAML files and build typed per-stage configs.

    Constructing an instance reads the config and params YAML files with
    ``read_yaml`` and creates the ``artifacts_root`` directory named in the
    config.
    """

    def __init__(
        self,
        config_filepath=CONFIG_FILE_PATH,
        params_filepath=PARAMS_FILE_PATH,
    ):
        """Read both YAML files and ensure the artifacts root exists.

        Args:
            config_filepath: path of the pipeline config YAML.
            params_filepath: path of the hyper-parameter YAML.
        """
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)
        create_directories([self.config.artifacts_root])

    def get_training_config(self) -> TrainingConfig:
        """Assemble the ``TrainingConfig`` and create its root directory.

        Reads the ``training``, ``prepare_base_model`` and ``data_ingestion``
        sections of the config plus the ``EPOCHS``, ``BATCH_SIZE``,
        ``AUGMENTATION`` and ``IMAGE_SIZE`` params.

        Returns:
            A frozen ``TrainingConfig`` populated from the loaded YAML.
        """
        train_section = self.config.training
        base_model_section = self.config.prepare_base_model
        hyper = self.params

        # Dataset folder inside the ingestion unzip directory; the spelling
        # "trafic_data" is the on-disk name and must match exactly.
        data_dir = Path(self.config.data_ingestion.unzip_dir) / "trafic_data"

        root = Path(train_section.root_dir)
        create_directories([root])

        return TrainingConfig(
            root_dir=root,
            trained_model_path=Path(train_section.trained_model_path),
            updated_base_model_path=Path(base_model_section.updated_base_model_path),
            training_data=data_dir,
            params_epochs=hyper.EPOCHS,
            params_batch_size=hyper.BATCH_SIZE,
            params_is_augmentation=hyper.AUGMENTATION,
            params_image_size=hyper.IMAGE_SIZE,
        )

In [None]:
import ultralytics
from ultralytics import YOLO, checks, hub
import requests
# Ultralytics environment check. Per the ultralytics docs it prints a
# software/hardware summary; it is called here only for that printed output.
checks()

In [None]:
class Training:
    """Train an ultralytics YOLO model and save the result.

    Args:
        config: paths and hyper-parameters for the training stage.
    """

    def __init__(self, config: TrainingConfig):
        self.config = config

    def get_base_model(self):
        """Load a YOLO model from ``config.updated_base_model_path`` into ``self.model``."""
        self.model = YOLO(
            self.config.updated_base_model_path
        )

    @staticmethod
    def save_model(path: Path, model):
        """Persist *model* to *path* using the model's own ``save`` method."""
        model.save(path)

    def train(self):
        """Train ``self.model`` on the configured dataset, then save it.

        The base model is loaded first if ``get_base_model`` has not been
        called yet.

        Returns:
            Whatever ``YOLO.train`` returns (training metrics in current
            ultralytics releases, not a model).
        """
        if getattr(self, "model", None) is None:
            self.get_base_model()

        # YOLO.train() returns metrics, not a model. Keep self.model pointing
        # at the YOLO object (which holds the trained weights afterwards) so
        # that save_model receives something with a .save() method.
        # NOTE(review): ultralytics detection normally expects a dataset YAML
        # for `data`; training_data is built as a directory path. Confirm.
        # NOTE(review): params_batch_size, params_image_size and
        # params_is_augmentation are not forwarded to YOLO.train. Confirm
        # that is intended.
        results = self.model.train(
            data=self.config.training_data,
            task='detect',
            epochs=self.config.params_epochs,
            save=True,
        )

        self.save_model(
            path=str(self.config.trained_model_path),
            model=self.model
        )
        return results

In [None]:
# Run the training stage end to end. No try/except wrapper: the previous
# `except Exception as e: raise e` only re-raised, so any failure still
# propagates with its full traceback.
config = ConfigurationManager()

training_config = config.get_training_config()
training = Training(config=training_config)
training.get_base_model()
training.train()