# Training a MediaPipe person detector on a custom dataset

In [None]:
#!pip install 'keras<3.0.0' mediapipe-model-maker

In [None]:
import os
import tensorflow as tf
#assert tf.__version__.startswith('2')
from mediapipe_model_maker import object_detector

In [None]:
class TransferLearning:
    """Fine-tune a MediaPipe Model Maker object detector on a custom dataset.

    Two Pascal VOC datasets are required: one for training and one for
    validation.

    Attributes:
        train_dataset_path: Path to the training dataset directory.
        validation_dataset_path: Path to the validation dataset directory.
        model: Trained ``object_detector.ObjectDetector``; None until
            ``train_model`` has run.
        train_data: Training dataset; None until ``load_data`` has run.
        val_data: Validation dataset; None until ``load_data`` has run.
            Also reachable through the ``validation_data`` alias.
        hparams: ``object_detector.HParams`` used for the last training run.
        options: ``object_detector.ObjectDetectorOptions`` used for the last
            training run.
    """

    def __init__(self,
                 train_dataset_path="/content/Person-Dataset-1/train",
                 validation_dataset_path="/content/Person-Dataset-1/valid"):
        """Store dataset locations; nothing is loaded until ``load_data``.

        Args:
            train_dataset_path: Directory of the training split.
            validation_dataset_path: Directory of the validation split.
        """
        self.train_dataset_path = train_dataset_path
        self.validation_dataset_path = validation_dataset_path
        self.model = None
        self.train_data = None
        self.val_data = None
        self.hparams = None
        self.options = None

    @property
    def validation_data(self):
        """Alias for ``val_data`` (the attribute name used by older code)."""
        return self.val_data

    @validation_data.setter
    def validation_data(self, value):
        self.val_data = value

    def load_data(self):
        """Load the train and validation datasets from their Pascal VOC directories."""
        self.train_data = object_detector.DataLoader.from_pascal_voc(self.train_dataset_path)
        # Store under ``val_data``: train_model and evaluate_model read this
        # attribute. It used to be written to a different attribute, so both
        # steps silently received None.
        self.val_data = object_detector.DataLoader.from_pascal_voc(self.validation_dataset_path)

    def train_model(self, batch_size=8, learning_rate=0.3, epochs=50, export_dir='exported_model'):
        """Train a MobileNetV2 detector on the loaded datasets.

        Args:
            batch_size: Batch size for training.
            learning_rate: Learning rate for training.
            epochs: Number of training epochs.
            export_dir: Directory Model Maker uses for exported artifacts.

        Raises:
            RuntimeError: If ``load_data`` has not been called yet.
        """
        if self.train_data is None or self.val_data is None:
            raise RuntimeError("Datasets are not loaded; call load_data() first.")

        # Pass by keyword: HParams is a dataclass whose field order does not
        # match this method's parameter order, so positional arguments ended
        # up in the wrong fields (e.g. batch_size used as the learning rate).
        self.hparams = object_detector.HParams(
            learning_rate=learning_rate,
            batch_size=batch_size,
            epochs=epochs,
            export_dir=export_dir,
        )
        self.options = object_detector.ObjectDetectorOptions(
            supported_model=object_detector.SupportedModels.MOBILENET_V2,
            hparams=self.hparams,
        )

        self.model = object_detector.ObjectDetector.create(
            train_data=self.train_data,
            validation_data=self.val_data,
            options=self.options,
        )

    def evaluate_model(self, batch_size=8):
        """Evaluate the trained model on the validation set and print the results.

        Args:
            batch_size: Batch size for evaluation.

        Returns:
            The ``(loss, metrics)`` pair returned by ``model.evaluate``.

        Raises:
            RuntimeError: If no model has been trained or no validation data
                has been loaded.
        """
        if self.model is None:
            raise RuntimeError("No trained model; call train_model() first.")
        if self.val_data is None:
            raise RuntimeError("Validation data is not loaded; call load_data() first.")
        loss, metrics = self.model.evaluate(self.val_data, batch_size=batch_size)
        print(f"Validation loss: {loss}")
        print(f"Validation coco metrics: {metrics}")
        return loss, metrics

    def export_model(self, model_name='people-detection.tflite'):
        """Export the trained model as a TFLite file.

        Args:
            model_name: File name of the exported ``.tflite`` model.

        Raises:
            RuntimeError: If no model has been trained.
        """
        if self.model is None:
            raise RuntimeError("No trained model; call train_model() first.")
        # Model Maker's ObjectDetector exposes ``export_model`` (not ``export``);
        # it writes the file under the hparams' export_dir.
        self.model.export_model(model_name)
        print(f"Model exported as {model_name}")
    


In [None]:
# Pipeline step 1: build the trainer, load both dataset splits and fine-tune.
# The ``TransferLearning_model`` global is reused by the evaluation and export cells.
if __name__ == "__main__":
    TransferLearning_model = TransferLearning()
    TransferLearning_model.load_data()
    TransferLearning_model.train_model(
        batch_size=8,
        learning_rate=0.3,
        epochs=50,
        export_dir='exported_model',
    )



In [None]:

# Pipeline step 2: report validation loss and COCO metrics for the trained model.
if __name__ == "__main__":
    TransferLearning_model.evaluate_model(batch_size=8)


In [None]:
# Pipeline step 3: export the trained detector via TransferLearning.export_model.
# Relies on ``TransferLearning_model`` created by the training cell above.
if __name__ == "__main__":
    TransferLearning_model.export_model()