In [0]:
# Logging / networking tweaks. This cell sits before the cell that imports
# tensorflow, so the environment variable is already set at import time.
import os
os.environ.update({'TF_CPP_MIN_LOG_LEVEL': '3'})

# Set a default max pool size of 50 on both urllib3 connection classes.
# NOTE(review): nothing in this notebook shows urllib3 reading
# `default_pool_maxsize` from these classes -- confirm the setting takes effect.
import urllib3
setattr(urllib3.connection.HTTPConnection, 'default_pool_maxsize', 50)
setattr(urllib3.connection.HTTPSConnection, 'default_pool_maxsize', 50)

In [0]:
import tensorflow as tf
import numpy as np
import json
import os
from PIL import Image

In [0]:
# Dataset location; train() passes this path to ObjectDetectionDataset.
image_dir = '/Volumes/shm/default/cppe5'

In [0]:
from cppe_modules import ObjectDetectionDataset, create_model, custom_loss

In [0]:
def train():
    """Load the dataset, fit the detection model, and return the results.

    Steps performed:
      1. Build an ``ObjectDetectionDataset`` over the module-level
         ``image_dir`` (capped at ``max_files=100``) and load images, boxes,
         classes and masks from it.
      2. Split all four arrays 80/20 into train/validation subsets with a
         fixed ``random_state`` so the split is reproducible.
      3. Wrap each subset with ``dataset.create_tf_dataset``.
      4. Build the model with ``create_model`` and compile it with Adam
         (learning rate 1e-4) and ``custom_loss``.
      5. Fit for up to 100 epochs with checkpointing, early stopping and
         learning-rate reduction, all driven by ``val_loss``.

    Side effects:
        Writes the best checkpoint seen so far to ``best_model.keras`` in the
        current working directory.

    Returns:
        tuple: ``(model, history)`` -- the fitted model and whatever
        ``model.fit`` returned (a Keras ``History`` object, assuming
        ``create_model`` returns a ``tf.keras`` model).
    """
    from sklearn.model_selection import train_test_split

    # Create dataset
    dataset = ObjectDetectionDataset(image_dir, max_files=100)
    images, boxes, classes, masks = dataset.load_data()

    # Split into train and validation
    (X_train, X_val,
     boxes_train, boxes_val,
     classes_train, classes_val,
     masks_train, masks_val) = train_test_split(
        images, boxes, classes, masks, test_size=0.2, random_state=42
    )

    # Create tf.data.Datasets
    train_dataset = dataset.create_tf_dataset(
        X_train, boxes_train, classes_train, masks_train)
    val_dataset = dataset.create_tf_dataset(
        X_val, boxes_val, classes_val, masks_val)

    # Create and compile model
    model = create_model()
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
        loss=custom_loss
    )

    # Training callbacks
    callbacks = [
        tf.keras.callbacks.ModelCheckpoint(
            'best_model.keras',
            save_best_only=True,
            monitor='val_loss',
            # 'val_loss' only exists in the logs at the end of an epoch.
            # An integer save_freq makes the callback check at batch end,
            # where the monitored metric is missing, so the "best" model
            # would be skipped. Checkpoint on epoch boundaries instead.
            save_freq='epoch'
        ),
        tf.keras.callbacks.EarlyStopping(
            patience=10,
            monitor='val_loss'
        ),
        tf.keras.callbacks.ReduceLROnPlateau(
            factor=0.1,
            patience=5,
            monitor='val_loss'
        )
    ]

    # Train
    history = model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=100,
        callbacks=callbacks
    )

    return model, history

In [0]:
# https://mlflow.org/docs/latest/python_api/mlflow.tensorflow.html
import mlflow
import mlflow.tensorflow

# Enable MLflow autologging for the Keras fit() call made later in the notebook.
# Checkpoints are saved on epoch boundaries: autolog's checkpointing monitors
# 'val_loss' with save-best-only by default, and that metric is only available
# at the end of an epoch, so a batch-count frequency cannot be compared
# reliably against it.
mlflow.tensorflow.autolog(
  every_n_iter=1,
  checkpoint_save_freq="epoch",
  )

In [0]:
# Bind the results so later cells can use the fitted model and the fit()
# return value instead of discarding them.
model, history = train()