# 3.- Training Configuration

In [15]:
import os
import requests
import tarfile
import re
from git import Repo

In [16]:
def download_file(url, destination_folder, file_name, timeout=60, chunk_size=1024 * 1024):
    """
    Download `url` and store it as `destination_folder/file_name`.

    The destination folder is created when missing. The outcome is reported
    with a printed message; a non-200 response does not raise.

    Parameters:
    - url (str): Address of the file to fetch.
    - destination_folder (str): Folder that receives the file.
    - file_name (str): Name of the file inside `destination_folder`.
    - timeout (float): Seconds to wait for the server before `requests` gives up.
    - chunk_size (int): Number of bytes written to disk per iteration.

    Returns:
    - None
    """
    os.makedirs(destination_folder, exist_ok=True)
    path_to_write = f"{destination_folder}/{file_name}"

    # stream=True avoids holding the whole body in memory (checkpoint archives
    # are large); the timeout keeps a stalled server from blocking the kernel.
    with requests.get(url, stream=True, timeout=timeout) as response:
        if response.status_code == 200:
            with open(path_to_write, 'wb') as file:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    file.write(chunk)
            print(f"File downloaded and saved to: {path_to_write}")
        else:
            print(f"Failed to download the file, status code: {response.status_code}")


In [17]:
# Key of the MODELS_CONFIG entry to train.
chosen_model = 'ssd-mobilenet-v2-fpnlite-320'

# Per-model file names: the model name, its pipeline config file and the
# pretrained checkpoint archive.
MODELS_CONFIG = {
    'ssd-mobilenet-v2': dict(
        model_name='ssd_mobilenet_v2_320x320_coco17_tpu-8',
        base_pipeline_file='ssd_mobilenet_v2_320x320_coco17_tpu-8.config',
        pretrained_checkpoint='ssd_mobilenet_v2_320x320_coco17_tpu-8.tar.gz',
    ),
    'efficientdet-d0': dict(
        model_name='efficientdet_d0_coco17_tpu-32',
        base_pipeline_file='ssd_efficientdet_d0_512x512_coco17_tpu-8.config',
        pretrained_checkpoint='efficientdet_d0_coco17_tpu-32.tar.gz',
    ),
    'ssd-mobilenet-v2-fpnlite-320': dict(
        model_name='ssd_mobilenet_v2_fpnlite_320x320_coco17_tpu-8',
        base_pipeline_file='ssd_mobilenet_v2_fpnlite_320x320_coco17_tpu-8.config',
        pretrained_checkpoint='ssd_mobilenet_v2_fpnlite_320x320_coco17_tpu-8.tar.gz',
    ),
}

# Pull the three settings of the chosen model into their own variables.
model_name, pretrained_checkpoint, base_pipeline_file = (
    MODELS_CONFIG[chosen_model][field]
    for field in ('model_name', 'pretrained_checkpoint', 'base_pipeline_file')
)

## Download Model

In [18]:
# chosen_model is deliberately left untouched here: it must keep holding the
# MODELS_CONFIG key, because update_pipeline_config compares it against keys
# such as 'ssd-mobilenet-v2' and 'efficientdet-d0'.
download_tar = 'http://download.tensorflow.org/models/object_detection/tf2/20200711/' + pretrained_checkpoint
destination_folder = "../train/pretrained_models"
# Store the archive under the name it has on the server, so the file on disk
# always follows the model selected in MODELS_CONFIG.
model_file_name = pretrained_checkpoint

In [19]:
# Fetch the pretrained checkpoint archive into the pretrained-models folder.
download_file(url=download_tar, destination_folder=destination_folder, file_name=model_file_name)

File downloaded and saved to: ../train/pretrained_models/ssd_mobilenet_v2_fpnlite_320x320_coco17_tpu-8.tar.gz


In [20]:
# Unpack the checkpoint archive next to where it was downloaded. The context
# manager closes the archive even when extraction fails.
with tarfile.open(f"{destination_folder}/{model_file_name}") as tar:
    if hasattr(tarfile, "data_filter"):
        # Extraction filters are available: "data" rejects members that would
        # be written outside destination_folder (absolute paths, "..", links).
        tar.extractall(path=destination_folder, filter="data")
    else:
        # Older Python without extraction filters: plain extraction.
        tar.extractall(path=destination_folder)

## Download Model Config File

In [21]:
# The pipeline file is not always "<model_name>.config" (see the
# 'efficientdet-d0' entry of MODELS_CONFIG), so both the URL and the local
# file name come from base_pipeline_file, which is also the name the local
# config cell reads back from destination_folder.
url_config = f"https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/configs/tf2/{base_pipeline_file}"
config_file_name = base_pipeline_file

In [22]:
# Fetch the base pipeline config into the same folder as the checkpoint.
download_file(url=url_config, destination_folder=destination_folder, file_name=config_file_name)

File downloaded and saved to: ../train/pretrained_models/ssd_mobilenet_v2_fpnlite_320x320_coco17_tpu-8.config


## Clone the TensorFlow Model Garden GitHub Repository

In [23]:
def clone_repository(repo_url, local_dir):
    """
    Clone a git repository into `local_dir`, creating the directory if needed.

    The call is best-effort: progress and failures are reported with printed
    messages and no exception is propagated from the clone itself. When
    `local_dir` already contains a `.git` directory the clone is skipped, so
    re-running the notebook does not fail on the existing checkout.

    Parameters:
    - repo_url (str): URL of the repository to clone.
    - local_dir (str): Directory that receives the working tree.

    Returns:
    - None
    """
    # Create the local directory if it does not exist
    if not os.path.exists(local_dir):
        os.makedirs(local_dir)
        print(f"Directory '{local_dir}' created.")
    else:
        print(f"Directory '{local_dir}' already exists.")

    # git refuses to clone into a non-empty directory, so an earlier successful
    # run would otherwise turn every later run into a printed failure.
    if os.path.isdir(os.path.join(local_dir, '.git')):
        print(f"Repository already present in {local_dir}; skipping clone.")
        return

    # Clone the repository into the specified directory
    try:
        print(f"Cloning {repo_url} into {local_dir}")
        Repo.clone_from(repo_url, local_dir)
        print("Repository successfully cloned.")
    except Exception as e:
        # Deliberately broad: the notebook keeps running and reports the reason.
        print(f"Failed to clone the repository: {e}")


In [None]:
# Remote repository to fetch
repo_url = 'https://github.com/tensorflow/models.git'

# Folder that will hold the cloned sources
train_code = '../train/src/'

# Run the clone
clone_repository(repo_url=repo_url, local_dir=train_code)

## Define Parameters in Config Files

In [31]:
import os
import re

def _sub_literal(pattern, replacement, text):
    """Replace every match of `pattern` in `text` with `replacement` taken verbatim."""
    # A callable replacement is inserted as-is, so backslashes in caller-supplied
    # paths are never interpreted as regex replacement escapes.
    return re.sub(pattern, lambda _match: replacement, text)


def update_pipeline_config(pipeline_path, output_path, fine_tune_checkpoint, train_record, val_record,
                           label_map, batch_size, num_steps, num_classes, chosen_model):
    """
    Updates a TensorFlow Object Detection API pipeline configuration file with custom values.

    The file is edited as plain text with regular expressions. Every occurrence
    of a matching setting is replaced, not just the first one (for example all
    `batch_size: N` and all `label_map_path: "..."` entries).

    Parameters:
    - pipeline_path (str): Path to the original pipeline.config file.
    - output_path (str): Path where the updated config file should be saved.
      Its parent directory is created when missing; a bare file name is
      written to the current working directory.
    - fine_tune_checkpoint (str): Path to the fine-tune checkpoint.
    - train_record (str): Path to the training TFRecord file.
    - val_record (str): Path to the validation TFRecord file.
    - label_map (str): Path to the label map file.
    - batch_size (int): Training batch size.
    - num_steps (int): Number of training steps.
    - num_classes (int): Number of classes in the dataset.
    - chosen_model (str): Model type (e.g., 'ssd-mobilenet-v2', 'efficientdet-d0').

    Returns:
    - None
    """
    print("Writing custom configuration file...")

    # Ensure the destination directory exists. dirname() is '' for a bare file
    # name, and os.makedirs('') raises, so only create when there is a directory.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with open(pipeline_path, 'r') as f:
        config_content = f.read()

    # Update fine-tune checkpoint path
    config_content = _sub_literal(r'fine_tune_checkpoint: ".*?"',
                                  f'fine_tune_checkpoint: "{fine_tune_checkpoint}"', config_content)

    # Update TFRecord paths (the train/val placeholders share the same key, so
    # they are told apart by the PATH_TO_BE_CONFIGURED/<split> prefix)
    config_content = _sub_literal(r'(input_path: ".*?)(PATH_TO_BE_CONFIGURED/train)(.*?")',
                                  f'input_path: "{train_record}"', config_content)
    config_content = _sub_literal(r'(input_path: ".*?)(PATH_TO_BE_CONFIGURED/val)(.*?")',
                                  f'input_path: "{val_record}"', config_content)

    # Update label map path
    config_content = _sub_literal(r'label_map_path: ".*?"',
                                  f'label_map_path: "{label_map}"', config_content)

    # Update batch size
    config_content = _sub_literal(r'batch_size: \d+',
                                  f'batch_size: {batch_size}', config_content)

    # Update training steps
    config_content = _sub_literal(r'num_steps: \d+',
                                  f'num_steps: {num_steps}', config_content)

    # Update number of classes
    config_content = _sub_literal(r'num_classes: \d+',
                                  f'num_classes: {num_classes}', config_content)

    # Change fine-tune checkpoint type to 'detection'
    config_content = _sub_literal(r'fine_tune_checkpoint_type: "classification"',
                                  'fine_tune_checkpoint_type: "detection"', config_content)

    # Adjust learning rate if using ssd-mobilenet-v2 (dots escaped so they only
    # match a literal decimal point)
    if chosen_model == 'ssd-mobilenet-v2':
        config_content = _sub_literal(r'learning_rate_base: \.8',
                                      'learning_rate_base: .08', config_content)
        config_content = _sub_literal(r'warmup_learning_rate: 0\.13333',
                                      'warmup_learning_rate: .026666', config_content)

    # Adjust resizer settings if using efficientdet-d0 (for TFLite compatibility)
    if chosen_model == 'efficientdet-d0':
        config_content = _sub_literal(r'keep_aspect_ratio_resizer', 'fixed_shape_resizer', config_content)
        config_content = _sub_literal(r'pad_to_max_dimension: true', '', config_content)
        config_content = _sub_literal(r'min_dimension', 'height', config_content)
        config_content = _sub_literal(r'max_dimension', 'width', config_content)

    # Save updated config file
    with open(output_path, 'w') as f:
        f.write(config_content)

    print(f"Updated pipeline configuration saved to: {output_path}")



### Local Config File

In [32]:
# Base pipeline file downloaded earlier, and where the customised copy goes.
pipeline_path = f"{destination_folder}/{base_pipeline_file}"
output_path = '../train/config_files/local_config_file.config'

# Paths written into the config file.
# NOTE(review): these /app/... paths look like locations inside the training
# environment rather than on this machine; confirm before changing them.
fine_tune_checkpoint = f"/app/train/pretrained_models/{model_name}/checkpoint/ckpt-0"
train_record = '/app/data/TFRecords/train.tfrecord'
val_record = '/app/data/TFRecords/val.tfrecord'
label_map = '/app/data/TFRecords/label_map.pbtxt'

# Training settings for the local config.
batch_size = num_steps = num_classes = 1

In [33]:
# Write the local pipeline config from the values defined in the cell above.
update_pipeline_config(
    pipeline_path=pipeline_path,
    output_path=output_path,
    fine_tune_checkpoint=fine_tune_checkpoint,
    train_record=train_record,
    val_record=val_record,
    label_map=label_map,
    batch_size=batch_size,
    num_steps=num_steps,
    num_classes=num_classes,
    chosen_model=chosen_model,
)

Writing custom configuration file...
Updated pipeline configuration saved to: ../train/config_files/local_config_file.config
