In [1]:
import tensorflow as tf
import json
import os, sys
from datetime import datetime

In [2]:
# Make the sibling `data_processing` and `models` packages importable.
# Paths are resolved against the kernel's current working directory, so this
# assumes the notebook is run from its own folder. The membership check keeps
# the cell idempotent: re-running it does not pile up duplicate sys.path entries.
for _path_entry in (
    os.path.abspath(os.path.join('..', 'data_processing')),
    os.path.abspath(os.path.join('..', 'models')),
):
    if _path_entry not in sys.path:
        sys.path.append(_path_entry)
del _path_entry

### Import contrastive data

In [None]:
from contrastive_preprocessing import train_contrastive_dataset, val_contrastive_dataset

### Import triplet data

In [None]:
from triplet_preprocessing import train_triplet_dataset, val_triplet_dataset

### Import contrastive models

In [16]:
from contrastive_v1 import create_and_compile_contrastive_v1
from contrastive_v2 import create_and_compile_contrastive_v2
from contrastive_v3 import create_and_compile_contrastive_v3

### Import triplet models

In [4]:
from triplet_v4 import create_and_compile_triplet_v4
from triplet_v5 import create_and_compile_triplet_v5
from triplet_v6 import create_and_compile_triplet_v6

### Train or continue training the models

In [10]:
def _build_callbacks(best_model_path, patience, lr_patience):
    """Return the LR-schedule, checkpoint and early-stopping callbacks, all keyed on ``val_loss``."""
    return [
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=lr_patience,
            min_lr=1e-7,
            verbose=1
        ),
        tf.keras.callbacks.ModelCheckpoint(
            best_model_path,
            monitor='val_loss',
            save_best_only=True,
            mode='min',
            verbose=1
        ),
        tf.keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=patience,
            restore_best_weights=True,
            verbose=1
        )
    ]


def _load_previous_history(previous_history_path):
    """Return the ``training_history`` dict from an earlier training-info JSON, or ``{}``."""
    if not previous_history_path:
        return {}
    if not os.path.exists(previous_history_path):
        # Warn instead of silently dropping the prior history: a typo in the
        # path would otherwise overwrite the run's lineage without notice.
        import warnings
        warnings.warn(
            f"previous_history_path {previous_history_path!r} does not exist; "
            "starting with an empty training history."
        )
        return {}
    with open(previous_history_path, 'r') as f:
        previous_info = json.load(f)
    return previous_info.get("training_history", {})


def train_model(
        model,
        train_ds,
        val_ds,
        epochs,
        steps_per_epoch,
        validation_steps,
        initial_epoch=0,
        base_dir='../results/siamese',
        model_type='contrastive',  # 'contrastive' or 'triplet'
        model_name=None,
        batch_size=32,
        patience=10,
        previous_history_path=None,
        lr_patience=5
):
    """Train (or resume training) a compiled Keras model and log its history to JSON.

    Artifacts go to ``<base_dir>/<model_type>/<model_name>/`` (created if needed),
    each tagged with the timestamp at which this call started:

    * ``<model_name>_best_<timestamp>.h5`` -- written by ``ModelCheckpoint``
      whenever ``val_loss`` improves.
    * ``<model_name>_training_info_<timestamp>.json`` -- written once after
      ``model.fit`` returns. It holds the model name, the batch size, the epoch
      index reached and the per-epoch metric history. When a previous history is
      loaded, it comes first, followed by this run's values.

    Args:
        model: A compiled Keras model.
        train_ds: Training data passed to ``model.fit``.
        val_ds: Validation data passed as ``validation_data``.
        epochs: Final epoch to train up to. This follows Keras ``fit`` semantics:
            when ``initial_epoch > 0`` it is the target epoch, not an extra count.
        steps_per_epoch: Training batches drawn per epoch.
        validation_steps: Validation batches drawn per validation pass.
        initial_epoch: Epoch to resume from when continuing an earlier run.
        base_dir: Root directory for all results.
        model_type: Sub-directory grouping, e.g. ``'contrastive'`` or ``'triplet'``.
        model_name: Run name. Defaults to ``f"{model_type}_basic"``.
        batch_size: Recorded in the JSON only and NOT passed to ``fit``.
            Presumably the datasets are already batched; TODO confirm this value
            matches them.
        patience: Epochs without ``val_loss`` improvement before EarlyStopping
            halts training (``restore_best_weights=True`` is set).
        previous_history_path: Optional training-info JSON from an earlier run.
            Its ``training_history`` is prepended to this run's. A path that does
            not exist triggers a warning and is otherwise ignored.
        lr_patience: Epochs without ``val_loss`` improvement before
            ReduceLROnPlateau halves the learning rate (floor ``1e-7``).

    Returns:
        The ``History`` object returned by ``model.fit``, covering this run only.
    """
    if model_name is None:
        model_name = f"{model_type}_basic"

    model_dir = os.path.join(base_dir, model_type, model_name)
    os.makedirs(model_dir, exist_ok=True)

    # A per-call timestamp keeps successive runs from overwriting each other.
    current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
    training_info_path = os.path.join(model_dir, f'{model_name}_training_info_{current_time}.json')
    best_model_path = os.path.join(model_dir, f'{model_name}_best_{current_time}.h5')

    training_info = {
        "model_name": model_name,
        "batch_size": batch_size,
        "epochs_completed": 0,
        "training_history": _load_previous_history(previous_history_path)
    }

    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=epochs,
        initial_epoch=initial_epoch,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps,
        callbacks=_build_callbacks(best_model_path, patience, lr_patience),
        verbose=1
    )

    # Append this run's metrics; float() turns numpy scalars into JSON-serializable values.
    for key, values in history.history.items():
        training_info["training_history"].setdefault(key, []).extend(
            float(val) for val in values
        )

    # `history.epoch` lists the 0-based epoch indices actually run, starting at
    # `initial_epoch`. EarlyStopping can end training before `epochs`, so record
    # what really ran instead of the requested target.
    epochs_run = history.epoch
    training_info["epochs_completed"] = (epochs_run[-1] + 1) if epochs_run else initial_epoch

    with open(training_info_path, 'w') as f:
        json.dump(training_info, f, indent=4)

    return history

### Example usage

In [None]:
# Build and compile the v1 contrastive model; these hyperparameters are
# forwarded to the factory imported from the contrastive_v1 module.
v1 = create_and_compile_contrastive_v1(
    dropout_rate=0.3,
    learning_rate=1e-3,
)

In [None]:
# Train contrastive_v1 from scratch for up to 100 epochs. With train_model's
# default base_dir, artifacts land in ../results/siamese/contrastive/contrastive_v1/.
history = train_model(
    v1,
    train_contrastive_dataset,
    val_contrastive_dataset,
    epochs=100,
    steps_per_epoch=4000,
    validation_steps=2000,
    model_type='contrastive',
    model_name='contrastive_v1',
)