# Setup

In [None]:
import os
import random
import numpy as np
from matplotlib import pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Layer, Conv2D, Dense, MaxPooling2D, Input, Flatten
import shutil
from pathlib import Path
import re
import uuid
from tensorflow.keras.metrics import Precision, Recall
from tensorflow.keras.mixed_precision import set_global_policy
from tensorflow.keras.callbacks import *
import time
from sklearn.metrics import accuracy_score
import gc

# Optional DirectML acceleration: configure the plugin when it can be imported,
# otherwise keep going on the CPU.
# NOTE(review): the `tensorflow_directml` module name and its
# `enable_mixed_precision()` helper are used here as-is -- confirm both exist in
# the installed DirectML package; only ImportError is handled below.
try:

    import tensorflow_directml as tf_directml
    print("DirectML plugin loaded successfully!")
    tf_directml.enable_mixed_precision()
    print("Available devices:")
    for device in tf.config.list_physical_devices():
        print(f"  {device}")

    physical_devices = tf.config.list_physical_devices('DML')
    if physical_devices:
        for device in physical_devices:
            # Enable on-demand memory growth for every DirectML device found.
            tf.config.experimental.set_memory_growth(device, True)
        print(f"DirectML devices configured: {len(physical_devices)} devices")
    else:
        print("No DirectML devices found, falling back to CPU")
        
except ImportError:
    print("tensorflow-directml not found. Install with: pip install tensorflow-directml")
    print("Falling back to CPU training")

# Request float16 compute with float32 variables. Catch Exception rather than
# using a bare `except:` so Ctrl-C / SystemExit are never swallowed here.
try:
    set_global_policy('mixed_float16')
except Exception:
    print("Mixed precision not supported, using float32")

In [None]:
# Thread-pool sizing for TensorFlow ops.
# NOTE(review): 0 presumably lets TensorFlow pick the thread counts itself --
# confirm against the tf.config.threading documentation.
threading_config = tf.config.threading
threading_config.set_inter_op_parallelism_threads(0)
threading_config.set_intra_op_parallelism_threads(0)

In [None]:
# Dataset locations, all relative to the current working directory.
TRAIN, TEST, ARCH = (
    os.path.join('data', subdir) for subdir in ('training', 'test', 'archive')
)

In [None]:
class DirectMLCompatibleEarlyStopping(Callback):
    """Early stopping driven by pair-verification accuracy.

    After every epoch the model is run on ``validation_data`` (a tuple of
    ``(anchor_images, comparison_images, labels)``) and the fraction of pairs
    whose last model output, thresholded at 0.5, matches the label is used as
    the monitored value.

    Training is stopped when either
      * the accuracy has not improved by more than ``min_delta`` for
        ``patience`` consecutive epochs, or
      * the accuracy had already reached ``min_verification_acc`` and then
        falls back below it.

    Args:
        validation_data: ``(anchors, comparisons, labels)``; anchors and
            comparisons may be file paths (str) or image arrays/tensors.
            When ``None`` the callback does nothing.
        patience: epochs without improvement tolerated before stopping.
        min_delta: minimum accuracy gain that counts as an improvement.
        min_verification_acc: accuracy floor enforced once it has been reached.
        restore_best_weights: restore the best-scoring weights when stopping.
        verbose: print the accuracy after every epoch when truthy.
    """

    def __init__(self, validation_data=None, patience=5, min_delta=0.01,
                 min_verification_acc=0.90, restore_best_weights=True, verbose=1):
        super().__init__()
        self.validation_data = validation_data
        self.patience = patience
        self.min_delta = min_delta
        self.min_verification_acc = min_verification_acc
        self.restore_best_weights = restore_best_weights
        self.verbose = verbose
        # Best accuracy seen so far and the weights that produced it.
        self.best_verification_acc = 0
        self.best_weights = None
        # Number of consecutive epochs without a sufficient improvement.
        self.wait = 0

    def on_train_begin(self, logs=None):
        """Reset the patience counter at the start of every ``fit()`` call.

        The best accuracy/weights are deliberately kept so the callback can be
        reused across training stages, but a stale ``wait`` from a previous
        run must not end the new run after its first epoch.
        """
        self.wait = 0

    def on_epoch_end(self, epoch, logs=None):
        """Score the model on the validation pairs and decide whether to stop."""
        if self.validation_data is None:
            return

        val_acc = self.calculate_verification_accuracy()

        if self.verbose:
            print(f"Epoch {epoch + 1} - Verification Accuracy: {val_acc:.4f}")

        if val_acc > self.best_verification_acc + self.min_delta:
            self.best_verification_acc = val_acc
            self.wait = 0
            if self.restore_best_weights:
                self.best_weights = self.model.get_weights()
        else:
            self.wait += 1

        # Only treat the floor as violated once it has actually been reached;
        # otherwise the very first (still low) epoch would abort training.
        floor_was_reached = self.best_verification_acc >= self.min_verification_acc
        if floor_was_reached and val_acc < self.min_verification_acc:
            print(f"Verification accuracy dropped below {self.min_verification_acc}. Stopping training.")
            self.model.stop_training = True

        if self.wait >= self.patience:
            print(f"Verification accuracy hasn't improved for {self.patience} epochs. Stopping training.")
            self.model.stop_training = True

        if (self.model.stop_training and self.restore_best_weights
                and self.best_weights is not None):
            self.model.set_weights(self.best_weights)
            print(f"Restored best weights with verification accuracy: {self.best_verification_acc:.4f}")

    @staticmethod
    def _prepare_image(img):
        """Turn one validation item (path or pixel data) into a model input."""
        if isinstance(img, str):
            return preprocess_with_augmentation(img, is_training=False)
        # Pixel data whose maximum exceeds 1.0 is rescaled by 1/255.
        if tf.reduce_max(img) > 1.0:
            return tf.cast(img, tf.float32) / 255.0
        return img

    def calculate_verification_accuracy(self):
        """Return the accuracy of the thresholded predictions on the validation pairs.

        Pairs are evaluated in batches of 16. A batch that raises is reported
        and skipped; its labels are skipped too, so the remaining predictions
        stay aligned with their own labels. Returns 0.0 when no batch could be
        evaluated.
        """
        anchor_imgs, comparison_imgs, labels = self.validation_data
        labels = np.asarray(labels)

        # Smaller batch size for DirectML stability
        batch_size = 16
        # Resolve the target device once instead of once per batch.
        device = '/DML:0' if tf.config.list_physical_devices('DML') else '/CPU:0'

        predictions = []
        matched_labels = []

        for start in range(0, len(anchor_imgs), batch_size):
            stop = start + batch_size
            try:
                processed_anchors = tf.stack(
                    [self._prepare_image(img) for img in anchor_imgs[start:stop]])
                processed_comparisons = tf.stack(
                    [self._prepare_image(img) for img in comparison_imgs[start:stop]])

                with tf.device(device):
                    # Explicit inference mode for this evaluation pass.
                    model_output = self.model(
                        [processed_anchors, processed_comparisons], training=False)

                # Multi-output models: the last output is used as the prediction.
                if isinstance(model_output, (list, tuple)):
                    batch_preds = model_output[-1]
                else:
                    batch_preds = model_output
                batch_scores = batch_preds.numpy().flatten()
            except Exception as e:
                print(f"Error in batch processing: {e}")
                continue

            predictions.extend(batch_scores)
            matched_labels.extend(labels[start:start + len(batch_scores)])

        if not predictions:
            return 0.0

        binary_preds = (np.array(predictions) > 0.5).astype(int)
        return accuracy_score(matched_labels, binary_preds)



In [None]:
class DirectMLGPUMonitor(Callback):
    """Best-effort memory housekeeping performed at the start of every epoch."""

    def __init__(self):
        super().__init__()
        # Report a failing memory-stats reset only once instead of every epoch.
        self._reset_failure_reported = False

    def on_epoch_begin(self, epoch, logs=None):
        """Run Python garbage collection and reset DirectML memory statistics."""
        # More aggressive garbage collection for DirectML
        gc.collect()
        # NOTE(review): clear_session() resets Keras' global state; kept as in
        # the original flow, but confirm it is safe to call while fit() is
        # running on the target TensorFlow version.
        tf.keras.backend.clear_session()

        if tf.config.list_physical_devices('DML'):
            try:
                tf.config.experimental.reset_memory_stats('DML:0')
            except Exception as e:
                # Deliberately non-fatal: this is housekeeping only. Catching
                # Exception (not a bare except) keeps Ctrl-C working.
                if not self._reset_failure_reported:
                    print(f"Could not reset DirectML memory stats: {e}")
                    self._reset_failure_reported = True

# Data manipulation

In [None]:
def create_directml_data_pipeline(anchors, comparisons, labels, batch_size=32, 
                                 prefetch_buffer=tf.data.AUTOTUNE, num_parallel_calls=4):
    """Build the training ``tf.data`` pipeline for image pairs.

    Each element of the returned dataset is ``((anchor_batch, comparison_batch),
    label_batch)`` with float32 labels; incomplete final batches are dropped.

    Stage order matters here:
      1. decode/resize the images (deterministic),
      2. cache the decoded images when the dataset is small (< 20000 pairs),
      3. shuffle,
      4. apply the random augmentation,
      5. batch and prefetch.
    Caching before shuffling and augmenting means every epoch gets a new order
    and new augmentations; caching afterwards would replay the first epoch.

    Args:
        anchors: sequence of anchor image paths.
        comparisons: sequence of comparison image paths, aligned with anchors.
        labels: per-pair labels, aligned with anchors.
        batch_size: number of pairs per batch.
        prefetch_buffer: batches to prefetch; ``tf.data.AUTOTUNE`` (the
            default) is replaced by a fixed 2 for DirectML stability.
        num_parallel_calls: parallelism of the map stages;
            ``tf.data.AUTOTUNE`` is replaced by a fixed 4.

    Raises:
        ValueError: if ``anchors`` is empty.
    """
    if len(anchors) == 0:
        raise ValueError("Cannot build a training pipeline from an empty set of pairs")

    # Reduce parallelism for DirectML stability
    if num_parallel_calls == tf.data.AUTOTUNE:
        num_parallel_calls = 4  # Fixed value for DirectML
    if prefetch_buffer == tf.data.AUTOTUNE:
        prefetch_buffer = 2     # Reduced prefetch for DirectML

    anchor_ds = tf.data.Dataset.from_tensor_slices(anchors)
    comparison_ds = tf.data.Dataset.from_tensor_slices(comparisons)
    labels_ds = tf.data.Dataset.from_tensor_slices(labels)
    dataset = tf.data.Dataset.zip((anchor_ds, comparison_ds, labels_ds))

    # Smaller shuffle buffer for DirectML
    shuffle_buffer = min(5000, len(anchors))

    def load_pair(anchor_path, comparison_path, label):
        # Deterministic part only (no augmentation), so it is safe to cache.
        return (
            (preprocess_with_augmentation(anchor_path, is_training=False),
             preprocess_with_augmentation(comparison_path, is_training=False)),
            tf.cast(label, tf.float32),
        )

    def augment_pair(images, label):
        anchor_img, comparison_img = images
        return (
            (DIRECTML_AUGMENTATION(anchor_img, training=True),
             DIRECTML_AUGMENTATION(comparison_img, training=True)),
            label,
        )

    # More conservative caching for DirectML
    if len(anchors) < 20000:  # Reduced threshold
        dataset = dataset.map(load_pair, num_parallel_calls=num_parallel_calls)
        dataset = dataset.cache()
        dataset = dataset.shuffle(buffer_size=shuffle_buffer)
    else:
        # Too large to cache: shuffle the cheap path strings, then decode.
        dataset = dataset.shuffle(buffer_size=shuffle_buffer)
        dataset = dataset.map(load_pair, num_parallel_calls=num_parallel_calls)

    dataset = dataset.map(augment_pair, num_parallel_calls=num_parallel_calls)

    # Smaller batch size for DirectML
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(prefetch_buffer)

    return dataset

In [None]:
def create_directml_augmentation():
    """Build the random augmentation stack: flip, rotation, zoom, contrast."""
    keras_layers = tf.keras.layers
    augmentation_steps = [
        keras_layers.RandomFlip("horizontal"),
        keras_layers.RandomRotation(0.05),   # small rotation factor
        keras_layers.RandomZoom(0.05),       # small zoom factor
        keras_layers.RandomContrast(0.05),   # small contrast factor
    ]
    return tf.keras.Sequential(augmentation_steps)


# Single shared augmentation model, built once when this cell runs.
DIRECTML_AUGMENTATION = create_directml_augmentation()

In [None]:
def preprocess_with_augmentation(img_path, is_training=True):
    """Load one image file as a 100x100, 3-channel float32 tensor.

    The file is read, decoded with three channels, resized to 100x100 and
    divided by 255. When ``is_training`` is true the result is additionally
    passed through ``DIRECTML_AUGMENTATION``.
    """
    raw_bytes = tf.io.read_file(img_path)
    decoded = tf.io.decode_jpeg(raw_bytes, channels=3)
    resized = tf.image.resize(decoded, (100, 100))
    scaled = tf.cast(resized, tf.float32) / 255.0

    if not is_training:
        return scaled
    return DIRECTML_AUGMENTATION(scaled)

In [None]:
def _list_person_images(person_path):
    """Return paths of the .jpg/.jpeg/.png files directly inside *person_path*."""
    return [os.path.join(person_path, f) for f in os.listdir(person_path)
            if f.lower().endswith(('.jpg', '.jpeg', '.png'))]


def create_optimized_pairs_from_directory(directory, max_people=1000, max_pairs_per_person=100):
    """Build positive and negative image-path pairs from a per-person directory tree.

    ``directory`` must contain one sub-directory per person. Positive pairs
    (label 1) combine two images of the same person, at most
    ``max_pairs_per_person`` per person. Negative pairs (label 0) combine an
    image of one person with a randomly chosen image of another person, and
    generation stops once there are as many negatives as positives.

    Args:
        directory: root directory holding one sub-directory per person.
        max_people: if more people are present, a random subset of this size
            is drawn with ``np.random.choice``.
        max_pairs_per_person: cap on positive pairs contributed by one person.

    Returns:
        ``(anchor_paths, comparison_paths, labels)`` where the two path lists
        contain positives followed by negatives and ``labels`` is a 1-D
        tensor of ones followed by zeros.
    """
    from itertools import combinations, islice

    person_dirs = [d for d in os.listdir(directory) if os.path.isdir(os.path.join(directory, d))]

    if len(person_dirs) > max_people:
        person_dirs = np.random.choice(person_dirs, max_people, replace=False)

    print(f"Using {len(person_dirs)} people for DirectML training")

    # List every person's directory exactly once; the negative-pair loops below
    # would otherwise re-read the same directories many times.
    images_by_person = {
        person: _list_person_images(os.path.join(directory, person))
        for person in person_dirs
    }

    pos_anchor_paths = []
    pos_comparison_paths = []
    neg_anchor_paths = []
    neg_comparison_paths = []

    positive_cap = max(0, max_pairs_per_person)
    for person in person_dirs:
        person_images = images_by_person[person]
        if len(person_images) < 2:
            continue

        # combinations() yields (i, j) with i < j in the same order as a
        # nested index loop; islice enforces the per-person cap.
        for first, second in islice(combinations(person_images, 2), positive_cap):
            pos_anchor_paths.append(first)
            pos_comparison_paths.append(second)

    print(f"Created {len(pos_anchor_paths)} positive pairs")

    target_negative_pairs = len(pos_anchor_paths)

    for person in person_dirs:
        if len(neg_anchor_paths) >= target_negative_pairs:
            break

        person_images = images_by_person[person]
        if not person_images:
            continue

        other_people = [p for p in person_dirs if p != person]
        if not other_people:
            continue

        for anchor_img in person_images:
            if len(neg_anchor_paths) >= target_negative_pairs:
                break

            # Up to 10 distinct other identities per anchor image.
            sampled_others = np.random.choice(other_people, min(len(other_people), 10), replace=False)

            for other_person in sampled_others:
                if len(neg_anchor_paths) >= target_negative_pairs:
                    break

                other_person_images = images_by_person[other_person]
                if not other_person_images:
                    continue

                # Up to 4 random images (drawn with replacement) per identity.
                for _ in range(min(4, len(other_person_images))):
                    if len(neg_anchor_paths) >= target_negative_pairs:
                        break
                    negative_img = np.random.choice(other_person_images)
                    neg_anchor_paths.append(anchor_img)
                    neg_comparison_paths.append(negative_img)

    print(f"Created {len(neg_anchor_paths)} negative pairs")

    all_anchor_paths = pos_anchor_paths + neg_anchor_paths
    all_comparison_paths = pos_comparison_paths + neg_comparison_paths

    positive_labels = tf.ones(len(pos_anchor_paths))
    negative_labels = tf.zeros(len(neg_anchor_paths))
    all_labels = tf.concat([positive_labels, negative_labels], axis=0)

    return all_anchor_paths, all_comparison_paths, all_labels

In [None]:
def create_validation_data(test_directory, num_pairs=500):
    """Build a small set of labelled validation pairs from a per-person tree.

    Only the first 30 person directories are used. For each of them, up to
    three positive pairs (label 1) are formed from consecutive images among
    that person's first five images. For every positive pair one negative
    pair (label 0) is added that reuses the anchor and picks a random image
    of a *different* person.

    Args:
        test_directory: root directory holding one sub-directory per person.
        num_pairs: maximum number of pairs returned. When truncation is
            needed, positives and negatives are cut evenly so the result
            stays balanced.

    Returns:
        ``(anchors, comparisons, labels)`` as three aligned lists; positives
        come first, followed by negatives.
    """
    image_exts = ('.jpg', '.jpeg', '.png')
    listing_cache = {}

    def images_of(person):
        # List each person's directory at most once.
        if person not in listing_cache:
            person_path = os.path.join(test_directory, person)
            listing_cache[person] = [
                os.path.join(person_path, f) for f in os.listdir(person_path)
                if f.lower().endswith(image_exts)
            ]
        return listing_cache[person]

    person_dirs = [d for d in os.listdir(test_directory)
                   if os.path.isdir(os.path.join(test_directory, d))]
    candidates = person_dirs[:30]  # Reduced for DirectML

    # Keep the owning person with every positive pair so the negative pair
    # built from the same anchor can reliably exclude that identity.
    positives = []
    for person in candidates:
        images = images_of(person)[:5]
        for anchor, comparison in zip(images[:3], images[1:]):
            positives.append((anchor, comparison, person))

    negatives = []
    for anchor, _, person in positives:
        others = [p for p in candidates if p != person]
        if not others:
            continue
        other_images = images_of(random.choice(others))
        if other_images:
            negatives.append((anchor, random.choice(other_images)))

    # Balanced truncation: split the budget between both classes, then hand
    # any budget one class cannot fill to the other.
    budget = max(0, num_pairs)
    n_neg = min(len(negatives), budget // 2)
    n_pos = min(len(positives), budget - n_neg)
    n_neg = min(len(negatives), budget - n_pos)

    kept_pos = positives[:n_pos]
    kept_neg = negatives[:n_neg]

    anchors = [pair[0] for pair in kept_pos] + [pair[0] for pair in kept_neg]
    comparisons = [pair[1] for pair in kept_pos] + [pair[1] for pair in kept_neg]
    labels = [1] * n_pos + [0] * n_neg

    return anchors, comparisons, labels

# Helper functions

In [None]:
def create_directml_embedding():
    """Build the image-embedding network shared by both Siamese branches.

    Input is a 100x100x3 image. Three Conv2D -> MaxPooling2D ->
    BatchNormalization stages (32, 64 and 128 filters) are followed by global
    average pooling, dropout and a 256-unit sigmoid Dense layer that is
    pinned to float32.
    """
    layers = tf.keras.layers
    image_input = tf.keras.Input(shape=(100, 100, 3), name="input_img")

    # (filters, kernel size) for each convolutional stage, in order.
    conv_stages = ((32, (5, 5)), (64, (5, 5)), (128, (3, 3)))

    features = image_input
    for filters, kernel_size in conv_stages:
        features = layers.Conv2D(filters, kernel_size, activation='relu', padding='same')(features)
        features = layers.MaxPooling2D((2, 2))(features)
        features = layers.BatchNormalization()(features)

    pooled = layers.GlobalAveragePooling2D()(features)
    regularized = layers.Dropout(0.3)(pooled)
    embedding = layers.Dense(256, activation='sigmoid', dtype='float32')(regularized)

    return tf.keras.Model(inputs=image_input, outputs=embedding, name='directml_embedding')

In [None]:
class L1Dist(Layer):
    """Layer returning the element-wise absolute difference of two embeddings."""

    def __init__(self, **kwargs):
        # Forward kwargs (name, dtype, ...) to Layer; previously they were
        # accepted but silently dropped.
        super().__init__(**kwargs)

    def call(self, in_embed, valid_embed):
        """Return ``|in_embed - valid_embed|`` computed element-wise."""
        return tf.math.abs(in_embed - valid_embed)

class EuclideanDistance(Layer):
    """Layer computing the Euclidean distance between two embedding tensors.

    ``call`` takes a pair ``[anchor, comparison]`` and reduces over the last
    axis, keeping it as a size-1 dimension.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
    
    def call(self, inputs):
        """Return ``sqrt(sum((anchor - comparison)^2))`` over the last axis."""
        anchor, comparison = inputs
        sum_squared = tf.reduce_sum(tf.square(anchor - comparison), axis=-1, keepdims=True)
        # The gradient of sqrt is unbounded at 0, which yields inf/NaN updates
        # when both embeddings coincide; clamp to a tiny positive value first.
        return tf.sqrt(tf.maximum(sum_squared, tf.keras.backend.epsilon()))

In [None]:
def create_directml_siamese_model():
    """Assemble the two-input Siamese model around one shared embedding network.

    Returns:
        ``(model, embedding_network)``. The model takes the inputs ``anchor``
        and ``comparison`` (100x100x3 each) and has two outputs: the
        ``euclidean_distance`` between the embeddings and a sigmoid
        ``prediction`` computed from that distance.
    """
    shared_encoder = create_directml_embedding()

    image_shape = (100, 100, 3)
    anchor_in = Input(shape=image_shape, name='anchor')
    comparison_in = Input(shape=image_shape, name='comparison')

    # The same encoder instance (and therefore the same weights) embeds both images.
    embedded_pair = [shared_encoder(anchor_in), shared_encoder(comparison_in)]

    pair_distance = EuclideanDistance(name='euclidean_distance')(embedded_pair)
    pair_prediction = Dense(1, activation='sigmoid', name='prediction', dtype='float32')(pair_distance)

    siamese = Model(
        inputs=[anchor_in, comparison_in],
        outputs=[pair_distance, pair_prediction],
        name='directml_siamese',
    )
    return siamese, shared_encoder

In [None]:
class DirectMLContrastiveLoss(tf.keras.losses.Loss):
    """Contrastive loss on a distance output.

    With ``y_true == 1`` the loss is ``distance^2`` (pull the pair together);
    with ``y_true == 0`` it is ``max(0, margin - distance)^2`` (push the pair
    at least ``margin`` apart). The mean over all elements is returned.

    Args:
        margin: distance beyond which ``y_true == 0`` pairs contribute no loss.
        **kwargs: forwarded to ``tf.keras.losses.Loss``.
    """

    def __init__(self, margin=1.0, **kwargs):
        super().__init__(**kwargs)
        self.margin = margin
    
    def call(self, y_true, y_pred):
        """Return the mean contrastive loss for distances ``y_pred``."""
        y_pred = tf.cast(y_pred, tf.float32)
        # Give labels the same shape as the distances so (batch,) labels cannot
        # broadcast against (batch, 1) distances into a (batch, batch) matrix.
        y_true = tf.reshape(tf.cast(y_true, tf.float32), tf.shape(y_pred))

        # Distances are only floored at 0. They must not be clipped from
        # above: an upper clip has zero gradient, so y_true == 1 pairs that
        # are farther apart than the clip value would never be pulled together.
        y_pred = tf.maximum(y_pred, 0.0)

        loss = y_true * tf.square(y_pred) + \
               (1 - y_true) * tf.square(tf.maximum(0.0, self.margin - y_pred))
        return tf.reduce_mean(loss)

    def get_config(self):
        """Include ``margin`` so the loss can be re-created from its config."""
        config = super().get_config()
        config['margin'] = self.margin
        return config

In [None]:
def create_directml_model():
    """Create the Siamese model and compile it with its two per-output losses.

    The ``euclidean_distance`` output is trained with the contrastive loss
    (weight 1.0) and the ``prediction`` output with binary cross-entropy
    (weight 0.5); accuracy is tracked on ``prediction``.

    Returns:
        ``(siamese_model, embedding_model)``.
    """
    siamese_model, embedding_model = create_directml_siamese_model()

    # Adam with gradient-norm clipping for stability.
    adam_settings = {
        'learning_rate': 5e-4,
        'epsilon': 1e-7,
        'clipnorm': 1.0,
    }
    output_losses = {
        'euclidean_distance': DirectMLContrastiveLoss(margin=1.0),
        'prediction': 'binary_crossentropy',
    }
    output_loss_weights = {'euclidean_distance': 1.0, 'prediction': 0.5}
    output_metrics = {'prediction': 'accuracy'}

    siamese_model.compile(
        optimizer=tf.keras.optimizers.Adam(**adam_settings),
        loss=output_losses,
        loss_weights=output_loss_weights,
        metrics=output_metrics,
    )

    return siamese_model, embedding_model

# Training

In [None]:
def directml_curriculum_training(model, train_directory, test_directory, 
                               initial_epochs=15, fine_tune_epochs=20, batch_size=32):
    """Train ``model`` in two stages on pairs built from ``train_directory``.

    Stage 1 trains for up to ``initial_epochs`` with the optimizer's current
    learning rate. Stage 2 lowers the learning rate to 1e-4, tightens the
    early-stopping settings and trains for up to ``fine_tune_epochs`` on the
    same dataset and with the same callbacks.

    Args:
        model: compiled Keras model to train.
        train_directory: per-person image tree used to build training pairs.
        test_directory: per-person image tree used to build validation pairs.
        initial_epochs: maximum epochs for stage 1.
        fine_tune_epochs: maximum epochs for stage 2.
        batch_size: pairs per training batch.

    Returns:
        ``(model, [history_stage1, history_stage2])``.
    """
    print("Creating validation data for DirectML training...")
    val_anchors, val_comparisons, val_labels = create_validation_data(test_directory)
    
    print("\n=== DirectML Stage 1: Initial Training ===")
    
    # Reduced data size for DirectML
    train_anchors, train_comparisons, train_labels = create_optimized_pairs_from_directory(
        train_directory, max_people=800, max_pairs_per_person=25
    )
    
    print(f"Created {len(train_anchors)} training pairs for DirectML")

    train_dataset = create_directml_data_pipeline(
        train_anchors, train_comparisons, train_labels, batch_size=batch_size
    )

    # Kept in a named variable because stage 2 adjusts its settings below.
    early_stopping = DirectMLCompatibleEarlyStopping(
        validation_data=(val_anchors, val_comparisons, val_labels),
        patience=5,
        min_verification_acc=0.65,  # Lower threshold for DirectML
        verbose=1
    )

    callbacks = [
        early_stopping,
        DirectMLGPUMonitor(),
        ReduceLROnPlateau(
            monitor='loss',
            factor=0.7,
            patience=3,
            min_lr=1e-6,
            verbose=1
        ),
        ModelCheckpoint(
            'face_verification_directml.h5',
            monitor='loss',
            save_best_only=True,
            verbose=1
        )
    ]
    
    print("Starting DirectML training...")
    history1 = model.fit(
        train_dataset,
        epochs=initial_epochs,
        callbacks=callbacks,
        verbose=1
    )
    
    print("\n=== DirectML Stage 2: Fine-tuning ===")
    
    # Reduce learning rate for fine-tuning
    model.optimizer.learning_rate = 1e-4
    early_stopping.min_verification_acc = 0.75
    early_stopping.patience = 4
    # Start fine-tuning with a fresh patience counter: if stage 1 ended because
    # patience ran out, the stale count would stop stage 2 after one epoch.
    early_stopping.wait = 0
    
    print("Starting DirectML fine-tuning...")
    history2 = model.fit(
        train_dataset,
        epochs=fine_tune_epochs,
        callbacks=callbacks,
        verbose=1
    )
    
    return model, [history1, history2]

In [None]:
def train_directml_face_verification():
    """Build the model, run the two-stage training and save the result.

    Trains on ``TRAIN`` with validation pairs from ``TEST`` and writes the
    final model to ``face_verification_directml_final.h5``.

    Returns:
        ``(trained_model, training_histories)``.
    """
    print("Creating DirectML-optimized model...")
    # Only the full Siamese model is needed here; the embedding sub-model is unused.
    model, _embedding_model = create_directml_model()

    # Smaller batch size for DirectML
    schedule = {
        'initial_epochs': 20,
        'fine_tune_epochs': 25,
        'batch_size': 24,
    }

    print("Starting DirectML curriculum training...")
    trained_model, training_histories = directml_curriculum_training(
        model, TRAIN, TEST, **schedule
    )

    print("Saving DirectML model...")
    trained_model.save('face_verification_directml_final.h5')
    print("DirectML model saved!")

    return trained_model, training_histories

# Execute

In [None]:
# Entry point of the notebook: report the available devices, then train.
print("Starting DirectML Face Verification Training...")
print("=" * 50)

dml_devices = tf.config.list_physical_devices('DML')
if not dml_devices:
    print("No DirectML devices found. Training will use CPU.")
else:
    print(f"DirectML devices available: {len(dml_devices)}")

# Any failure during training is reported as a message rather than raised.
try:
    final_model, histories = train_directml_face_verification()
    print("\n🎉 DirectML training completed successfully!")
except Exception as e:
    print(f"❌ DirectML training failed: {e}")
    print("Try reducing batch size or model complexity further.")