In [11]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, LSTM, Dense, Flatten, Reshape
from tensorflow.keras.models import Model, Sequential 



In [None]:
import os
import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split

# Dataset location and the integer class ids assigned to the two image categories.
# NOTE: '/path/to/casia-fasd' is a placeholder and must be pointed at the local dataset root.
real_label, spoof_label = 0, 1
data_dir = '/path/to/casia-fasd'

# Function to load images
def load_images_from_folder(folder, label, target_size=(224, 224)):
    """Load every decodable image in ``folder`` as a normalized RGB array.

    Args:
        folder: Directory whose entries are passed to ``cv2.imread``.
        label: Class id attached to every image that loads successfully.
        target_size: ``(width, height)`` handed to ``cv2.resize``. Defaults to
            ``(224, 224)``.

    Returns:
        A tuple ``(images, labels)``. ``images`` is a float32 array of shape
        ``(n, height, width, 3)`` with values scaled to ``[0, 1]`` and channels
        in RGB order; ``labels`` is a length-``n`` array filled with ``label``.
        When nothing in the folder can be decoded, ``images`` still has the
        4-D shape ``(0, height, width, 3)`` so it can be concatenated with the
        arrays of another class.
    """
    images = []
    # os.listdir() returns entries in arbitrary order; sorting keeps the sample
    # order (and therefore any seeded split made later) stable between runs.
    for filename in sorted(os.listdir(folder)):
        img = cv2.imread(os.path.join(folder, filename))
        if img is None:
            # Entry could not be decoded as an image (e.g. sub-directory, non-image file).
            continue
        # OpenCV decodes to BGR; convert so the channel order is RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, target_size)
        # float32 halves the memory of the float64 produced by a plain `/ 255.0`.
        images.append(img.astype(np.float32) / 255.0)

    if not images:
        width, height = target_size
        return (
            np.empty((0, height, width, 3), dtype=np.float32),
            np.zeros((0,), dtype=int),
        )
    return np.stack(images), np.full(len(images), label)

# Load real and spoof images
real_images, real_labels = load_images_from_folder(os.path.join(data_dir, 'real'), real_label)
spoof_images, spoof_labels = load_images_from_folder(os.path.join(data_dir, 'spoof'), spoof_label)

# Combine real and spoof images
X = np.concatenate((real_images, spoof_images), axis=0)
y = np.concatenate((real_labels, spoof_labels), axis=0)

# One-hot encode labels
y = to_categorical(y, num_classes=2)

# Split dataset into training (70%), validation (15%), and test (15%) sets
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.3, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)

# Convert to TensorFlow Dataset for efficient loading and training.
# Shuffle BEFORE batching: shuffling an already-batched dataset only reorders
# whole batches, so every epoch would see the same samples grouped together.
train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train)).shuffle(buffer_size=1024).batch(32)
val_dataset = tf.data.Dataset.from_tensor_slices((X_val, y_val)).batch(32)
test_dataset = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(32)

# Optionally, save datasets ('/path/to/save' is a placeholder directory)
np.savez('/path/to/save/casia_train.npz', X_train=X_train, y_train=y_train)
np.savez('/path/to/save/casia_val.npz', X_val=X_val, y_val=y_val)
np.savez('/path/to/save/casia_test.npz', X_test=X_test, y_test=y_test)

# Now the datasets are ready for training


In [4]:
def resnet_feature_extractor(input_shape):
    """Build a frozen image backbone that emits one pooled feature vector per image.

    Despite the function name, the backbone instantiated here is MobileNetV2
    (ImageNet weights, classification head removed).

    Args:
        input_shape: Shape passed to ``MobileNetV2(input_shape=...)``.

    Returns:
        A Keras ``Model`` mapping the backbone input to the globally
        average-pooled backbone output.
    """
    backbone = tf.keras.applications.MobileNetV2(
        weights='imagenet',
        include_top=False,
        input_shape=input_shape,
    )
    # The pretrained weights are kept fixed; only layers added downstream train.
    backbone.trainable = False

    pooled_features = layers.GlobalAveragePooling2D()(backbone.output)
    return models.Model(inputs=backbone.input, outputs=pooled_features)

In [5]:
def lstm_temporal_analysis(x, lstm_units=64):
    """Run two stacked LSTMs plus residual self-attention and pool over the step axis.

    Args:
        x: Symbolic tensor whose last axis holds the features. It is reshaped to
            ``(batch, -1, features)`` before the LSTMs, so a 2-D
            ``(batch, features)`` input becomes a sequence with a single step.
        lstm_units: Width of both LSTM layers and the attention ``key_dim``.

    Returns:
        Tensor of shape ``(batch, lstm_units)``: the average over the step axis
        of the LSTM output plus its self-attention output.
    """
    feature_dim = x.shape[-1]
    sequence = layers.Reshape((-1, feature_dim))(x)

    # Two stacked recurrent layers, each returning the full sequence.
    for _ in range(2):
        sequence = layers.LSTM(lstm_units, return_sequences=True)(sequence)

    # Self-attention over the LSTM outputs, added back as a residual.
    self_attention = layers.MultiHeadAttention(num_heads=4, key_dim=lstm_units)
    attended = self_attention(sequence, sequence)
    residual = layers.Add()([sequence, attended])

    return layers.GlobalAveragePooling1D()(residual)



In [6]:
def tiny_transformer_block(x, d_model, num_heads, dff, training):
    """Apply one post-norm transformer encoder block to ``x``.

    Args:
        x: Symbolic sequence tensor. Its last axis must equal ``d_model`` for the
            second residual addition to be shape-compatible.
        d_model: Output width of the feed-forward sub-layer and attention ``key_dim``.
        num_heads: Number of attention heads.
        dff: Hidden width of the feed-forward sub-layer.
        training: Forwarded as the ``training`` argument of both dropout layers.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # --- Self-attention sub-layer with residual connection and layer norm ---
    self_attention = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
    attended = layers.Dropout(0.1)(self_attention(x, x), training=training)
    attn_normed = layers.LayerNormalization(epsilon=1e-6)(x + attended)

    # --- Position-wise feed-forward sub-layer with residual and layer norm ---
    hidden = layers.Dense(dff, activation='relu')(attn_normed)
    projected = layers.Dense(d_model)(hidden)
    projected = layers.Dropout(0.1)(projected, training=training)

    return layers.LayerNormalization(epsilon=1e-6)(attn_normed + projected)

def ensemble_adapters(x, num_adapters, d_model, bottleneck_dim):
    """Add the mean of several bottleneck adapters to ``x`` as a residual.

    Each adapter is ``Dense(bottleneck_dim, gelu) -> Dense(d_model)`` applied to
    ``x``; the adapter outputs are averaged element-wise and added to ``x``.

    Args:
        x: Symbolic tensor whose last axis equals ``d_model``.
        num_adapters: Number of parallel adapters; must be at least 1.
        d_model: Output width of every adapter.
        bottleneck_dim: Hidden width of every adapter.

    Returns:
        Tensor with the same shape as ``x``.

    Raises:
        ValueError: If ``num_adapters`` is smaller than 1.
    """
    if num_adapters < 1:
        raise ValueError(f"num_adapters must be >= 1, got {num_adapters}")

    adapter_outputs = []
    for _ in range(num_adapters):
        bottleneck = layers.Dense(bottleneck_dim, activation='gelu')(x)
        adapter_outputs.append(layers.Dense(d_model)(bottleneck))

    # Use Keras merge layers instead of raw tf.stack / tf.reduce_mean: raw TF ops
    # on symbolic Keras tensors only work where Keras wraps them automatically,
    # whereas merge layers are ordinary, serializable graph nodes.
    if len(adapter_outputs) == 1:
        # The mean of a single adapter is the adapter itself.
        ensemble = adapter_outputs[0]
    else:
        ensemble = layers.Average()(adapter_outputs)

    return layers.Add()([x, ensemble])

class _FeatureWiseAffine(layers.Layer):
    """Learnable per-feature affine transform: ``inputs * scale + shift``."""

    def __init__(self, d_model, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model

    def build(self, input_shape):
        # add_weight registers both vectors with the layer, so they appear in
        # model.trainable_weights and are included when weights are saved.
        self.scale = self.add_weight(
            name='scale', shape=(self.d_model,), initializer='ones', trainable=True
        )
        self.shift = self.add_weight(
            name='shift', shape=(self.d_model,), initializer='zeros', trainable=True
        )
        super().build(input_shape)

    def call(self, inputs):
        return inputs * self.scale + self.shift

    def get_config(self):
        config = super().get_config()
        config.update({'d_model': self.d_model})
        return config


def feature_wise_transformation(x, d_model):
    """Scale and shift every feature of ``x`` with learnable parameters.

    The scale starts at ones and the shift at zeros, so the transform is the
    identity at initialization. The parameters live in a Keras layer rather
    than in bare ``tf.Variable`` objects, because variables created directly
    inside a functional-model builder are not tracked as weights of the model.

    Args:
        x: Symbolic tensor whose last axis has ``d_model`` entries.
        d_model: Length of the scale and shift vectors.

    Returns:
        Tensor with the same shape as ``x``.

    Note:
        Reloading a saved model containing this layer will likely require
        passing ``_FeatureWiseAffine`` through ``custom_objects``.
    """
    return _FeatureWiseAffine(d_model)(x)

def adaptive_transformer_block(x, d_model, num_heads, dff, num_adapters, bottleneck_dim, training):
    """Chain a transformer block, an adapter ensemble, and a feature-wise transform.

    Args:
        x: Symbolic sequence tensor fed to ``tiny_transformer_block``.
        d_model: Model width shared by all three stages.
        num_heads: Attention heads for the transformer block.
        dff: Feed-forward hidden width for the transformer block.
        num_adapters: Number of adapters averaged by ``ensemble_adapters``.
        bottleneck_dim: Hidden width of each adapter.
        training: Forwarded to the transformer block's dropout layers.

    Returns:
        The output of ``feature_wise_transformation`` applied to the adapted
        transformer output.
    """
    encoded = tiny_transformer_block(x, d_model, num_heads, dff, training)
    adapted = ensemble_adapters(encoded, num_adapters, d_model, bottleneck_dim)
    return feature_wise_transformation(adapted, d_model)


In [7]:
def class_conditional_domain_discriminator(x, num_domains, d_model):
    """Predict a probability distribution over domains from the features ``x``.

    The head is ``Dense(d_model, relu) -> Dropout(0.5) -> Dense(num_domains, softmax)``.
    It receives only ``x``; no class label is passed in.

    Args:
        x: Symbolic feature tensor.
        num_domains: Number of softmax outputs.
        d_model: Width of the hidden dense layer.

    Returns:
        Output tensor of the softmax layer, which is named ``'domain_output'``.
    """
    hidden = layers.Dense(d_model, activation='relu')(x)
    hidden = layers.Dropout(0.5)(hidden)
    return layers.Dense(num_domains, activation='softmax', name='domain_output')(hidden)


In [8]:
def build_combined_model(input_shape, d_model=64, num_heads=4, dff=128, num_blocks=4, num_adapters=2, bottleneck_dim=32, lstm_units=64, num_domains=2):
    """Assemble the two-headed spoof / domain model.

    Pipeline: image backbone -> LSTM/attention stage -> ``num_blocks`` adaptive
    transformer blocks -> average pooling -> two heads.

    Args:
        input_shape: Shape of a single input image, e.g. ``(224, 224, 3)``.
        d_model: Width used inside the transformer blocks and domain head.
        num_heads: Attention heads per transformer block.
        dff: Feed-forward hidden width per transformer block.
        num_blocks: Number of adaptive transformer blocks.
        num_adapters: Adapters averaged in each block.
        bottleneck_dim: Hidden width of each adapter.
        lstm_units: Width of the LSTM stage.
        num_domains: Number of classes of the domain head.

    Returns:
        A Keras ``Model`` with outputs ``[spoof_pred, domain_pred]``:
        ``spoof_pred`` is a single sigmoid unit named ``'spoof_output'`` and
        ``domain_pred`` is a ``num_domains``-way softmax named ``'domain_output'``.
    """
    inputs = layers.Input(shape=input_shape)

    # Backbone feature extraction (the helper builds a MobileNetV2 despite its name)
    resnet_model = resnet_feature_extractor(input_shape=input_shape)
    x = resnet_model(inputs)

    # LSTM for Temporal Analysis and Repetitive Pattern Detection
    x = lstm_temporal_analysis(x, lstm_units=lstm_units)

    # Reshape to a length-1 sequence for the transformer blocks
    x = layers.Reshape((1, -1))(x)

    # The residual additions inside the transformer blocks need the feature axis
    # to equal d_model; project only when the LSTM width differs, so the default
    # configuration (lstm_units == d_model) is unchanged.
    if lstm_units != d_model:
        x = layers.Dense(d_model)(x)

    # Adaptive Transformer Blocks
    for _ in range(num_blocks):
        # training=None lets Keras decide per call (dropout on in fit(), off in
        # predict()/evaluate()). Hard-coding True would keep dropout active at
        # inference time as well.
        x = adaptive_transformer_block(x, d_model, num_heads, dff, num_adapters, bottleneck_dim, training=None)

    # Global Pooling and Feature Extraction
    x = layers.GlobalAveragePooling1D()(x)

    # Domain discriminator head
    domain_pred = class_conditional_domain_discriminator(x, num_domains, d_model)

    # Final Classification (Real vs Spoof)
    spoof_pred = layers.Dense(1, activation='sigmoid', name='spoof_output')(x)

    return models.Model(inputs=inputs, outputs=[spoof_pred, domain_pred])


In [9]:
import numpy as np
input_shape = (224, 224, 3)  # Example input shape for an image

# Seed TensorFlow so weight initialization and dropout are repeatable between runs.
tf.random.set_seed(42)

model = build_combined_model(input_shape=input_shape)

model.compile(optimizer='adam', 
              loss={'spoof_output': 'binary_crossentropy', 'domain_output': 'categorical_crossentropy'}, 
              metrics={'spoof_output': 'accuracy', 'domain_output': 'accuracy'})

# Example data: random tensors used only to smoke-test the training loop.
# NOTE: this rebinds X_train, replacing any array of that name loaded earlier.
rng = np.random.default_rng(42)  # seeded so the demo run is reproducible
X_train = rng.random((100, 224, 224, 3), dtype=np.float32)  # float32 halves memory vs float64
y_train_spoof = rng.integers(0, 2, size=(100, 1))
y_train_domain = tf.keras.utils.to_categorical(rng.integers(0, 2, size=(100, 1)), num_classes=2)

# Train the model; keep the History object so the loss/metric curves can be inspected.
history = model.fit(X_train, {'spoof_output': y_train_spoof, 'domain_output': y_train_domain}, epochs=10, batch_size=16)


Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x1905b26f408>

In [10]:
import tensorflow_model_optimization as tfmot

# Fine-tuning configuration
prune_epochs = 3
prune_batch_size = 32
prune_validation_split = 0.1

# Derive the schedule length from the actual number of optimizer steps. A fixed
# begin_step=2000 / end_step=10000 would never be reached in a run this short,
# leaving the model unpruned.
num_fit_samples = int(len(X_train) * (1 - prune_validation_split))
steps_per_epoch = (num_fit_samples + prune_batch_size - 1) // prune_batch_size
end_step = steps_per_epoch * prune_epochs

# Define a pruning schedule
pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(
    initial_sparsity=0.0, final_sparsity=0.5, begin_step=0, end_step=end_step
)


def _prune_dense_only(layer):
    """Wrap Dense layers for pruning and return every other layer unchanged."""
    if isinstance(layer, tf.keras.layers.Dense):
        return tfmot.sparsity.keras.prune_low_magnitude(layer, pruning_schedule=pruning_schedule)
    return layer


# Apply pruning to the top-level Dense layers only, so the cell does not depend
# on every other layer type in the model being supported by the pruning wrapper.
# Layers returned unchanged are the same objects used by `model`, so fine-tuning
# here also updates `model`.
pruned_model = tf.keras.models.clone_model(model, clone_function=_prune_dense_only)

# Compile the pruned model with one loss per head, ordered like the model
# outputs ([spoof, domain]). Lists are used instead of name-keyed dicts because
# the pruning wrapper changes the names of the wrapped output layers.
pruned_model.compile(
    optimizer='adam',
    loss=['binary_crossentropy', 'categorical_crossentropy'],
    metrics=['accuracy'],
)

# Fine-tune the pruned model on the same arrays used to train `model`
callbacks = [tfmot.sparsity.keras.UpdatePruningStep()]
pruned_model.fit(
    X_train,
    [y_train_spoof, y_train_domain],
    epochs=prune_epochs,
    batch_size=prune_batch_size,
    validation_split=prune_validation_split,
    callbacks=callbacks,
)


ModuleNotFoundError: No module named 'tensorflow_model_optimization'

In [None]:
# Remove the pruning wrappers added for fine-tuning to obtain the model to export.
final_model = tfmot.sparsity.keras.strip_pruning(pruned_model)

# Persist the stripped model to 'pruned_model.h5' in the working directory.
final_model.save(filepath='pruned_model.h5')


In [None]:
# Build a TFLite converter from the in-memory Keras model.
converter = tf.lite.TFLiteConverter.from_keras_model(final_model)

# Turn on the converter's default optimizations.
# NOTE(review): no representative dataset is supplied, so this is presumably
# dynamic-range quantization rather than full-integer quantization - confirm
# against the TFLite converter documentation.
converter.optimizations = [tf.lite.Optimize.DEFAULT]

# Run the conversion; the result is the serialized model as bytes.
tflite_model = converter.convert()

# Write the serialized model to disk.
with open('quantized_model.tflite', 'wb') as tflite_file:
    tflite_file.write(tflite_model)
