In [38]:
import av

import keras
from keras import layers
from keras.layers import Conv2D, MaxPool2D, Flatten, Dense, BatchNormalization, Dropout
from keras.optimizers import Adam

import shutil
import imghdr
from PIL import Image

! pip install -q tensorflow-model-optimization

import os
import random
import numpy as np

import tempfile
import tensorflow_model_optimization as tfmot

In [2]:
# Directory-name constants used throughout the notebook.
# Every name ends with a trailing slash so that path pieces can be joined by
# plain string concatenation (e.g. DS_FACE_IMG + DS_TRAIN).

# Source datasets.
DS_CDFV1 = 'celeb_df_v1/'
DS_CDFV2 = 'celeb_df_v2/'
DATASET = [DS_CDFV1, DS_CDFV2]

# Top-level dataset directories.
# NOTE: the spelling of DS_ORGINAL is kept as-is because it is a public name.
DS_ORGINAL = 'dataset_original/'
DS_SPLIT = 'dataset_split/'
DS_IFRAMES = 'dataset_iframes/'
DS_FACE = 'dataset_face/'
DS_FACE_IMG = 'dataset_face_img/'
DS_SRM_SNIPPETS = 'dataset_srm_snippets_5/'
DS_SEGMENTS = 'dataset_segments/'
DS_RAW = 'dataset_raw/'
DS_RESIDUALS = 'dataset_residuals/'
DS_TEMPORAL = 'dataset_temporal/'

TOP_LEVEL_1 = [DS_SPLIT, DS_IFRAMES, DS_FACE, DS_FACE_IMG, DS_SRM_SNIPPETS]
TOP_LEVEL_2 = [DS_SEGMENTS, DS_RAW, DS_RESIDUALS]

# Segment directories ('seg_1/' ... 'seg_5/') and the matching
# name prefixes ('seg_1_' ... 'seg_5_').
SEG_1, SEG_2, SEG_3, SEG_4, SEG_5 = (f'seg_{n}/' for n in range(1, 6))
SEGMENTS = [SEG_1, SEG_2, SEG_3, SEG_4, SEG_5]
SEG = [f'seg_{n}_' for n in range(1, 6)]

# Train / test / validation split directories.
DS_TRAIN = 'train_dataset/'
DS_TEST = 'test_dataset/'
DS_VAL = 'val_dataset/'
SPLIT = [DS_TRAIN, DS_TEST, DS_VAL]

# Class sub-directories.
CLASS_FAKE = 'fake/'
CLASS_REAL = 'real/'
CLASS = [CLASS_REAL, CLASS_FAKE]

In [19]:
def create_model(input_size):
    """Build the (uncompiled) CNN binary classifier.

    The network is four Conv2D -> BatchNormalization -> MaxPool2D blocks
    followed by Flatten, a 16-unit Dense layer with LeakyReLU, and a single
    sigmoid output unit. Dropout(0.5) precedes each Dense layer.

    Args:
        input_size: Shape of one input sample, passed as ``input_shape`` to
            the first convolution (the notebook uses ``(256, 256, 3)``).

    Returns:
        A ``keras.Sequential`` model; the caller is responsible for
        compiling it.
    """
    # (filters, kernel_size, pool_size) for each convolutional block.
    # Pool size doubles as the stride, so the blocks downsample by
    # 2, 2, 4 and 4 respectively.
    conv_blocks = [
        (8, 3, 2),
        (8, 5, 2),
        (16, 5, 4),
        (16, 5, 4),
    ]

    model = keras.Sequential()
    for block_index, (n_filters, kernel, pool) in enumerate(conv_blocks):
        conv_kwargs = dict(filters=n_filters, kernel_size=kernel,
                           activation='relu', padding="same")
        if block_index == 0:
            # Only the first layer needs an input shape; later layers take
            # theirs from the preceding layer, so nothing is hard-coded to a
            # particular input resolution.
            conv_kwargs['input_shape'] = input_size
        model.add(layers.Conv2D(**conv_kwargs))
        model.add(BatchNormalization())
        model.add(MaxPool2D(pool, pool, padding="same"))

    model.add(Flatten())

    # Classifier head.
    model.add(Dropout(0.5))
    model.add(layers.Dense(16))
    model.add(layers.LeakyReLU())

    model.add(Dropout(0.5))
    model.add(layers.Dense(1, activation='sigmoid'))

    return model

In [20]:
# Training split, read from DS_FACE_IMG + DS_TRAIN.
# Labels are inferred from the directory structure and emitted in binary mode.
# Image size and batch size are not passed, so the Keras defaults apply.
train_ds = keras.utils.image_dataset_from_directory(
    directory=DS_FACE_IMG + DS_TRAIN,
    labels='inferred',
    label_mode='binary',
    color_mode='rgb',
)

Found 30585 files belonging to 2 classes.


In [21]:
# Test split, read from DS_FACE_IMG + DS_TEST with the same options as the
# training split (inferred binary labels, RGB images, Keras default
# image size and batch size).
test_ds = keras.utils.image_dataset_from_directory(
    directory=DS_FACE_IMG + DS_TEST,
    labels='inferred',
    label_mode='binary',
    color_mode='rgb',
)

Found 3299 files belonging to 2 classes.


In [22]:
# Validation split, read from DS_FACE_IMG + DS_VAL with the same options as
# the training split (inferred binary labels, RGB images, Keras default
# image size and batch size).
val_ds = keras.utils.image_dataset_from_directory(
    directory=DS_FACE_IMG + DS_VAL,
    labels='inferred',
    label_mode='binary',
    color_mode='rgb',
)

Found 7507 files belonging to 2 classes.


In [23]:
# Build the CNN for 256x256 RGB inputs and compile it for binary
# classification with Adam (learning rate 1e-4).
input_size = (256, 256, 3)
model = create_model(input_size)
model.compile(
    optimizer=Adam(learning_rate=0.0001),
    loss='binary_crossentropy',
    metrics=[
        keras.metrics.BinaryAccuracy(),
        keras.metrics.Precision(),
        keras.metrics.Recall(),
        keras.metrics.AUC(),
    ],
)
model.summary()

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_4 (Conv2D)           (None, 256, 256, 8)       224       
                                                                 
 batch_normalization_4 (Batc  (None, 256, 256, 8)      32        
 hNormalization)                                                 
                                                                 
 max_pooling2d_4 (MaxPooling  (None, 128, 128, 8)      0         
 2D)                                                             
                                                                 
 conv2d_5 (Conv2D)           (None, 128, 128, 8)       1608      
                                                                 
 batch_normalization_5 (Batc  (None, 128, 128, 8)      32        
 hNormalization)                                                 
                                                      

In [26]:
def setup_pretrained_weights(epochs=20):
    """Train the notebook-level ``model`` and save its weights to a temp file.

    Uses the notebook globals ``model``, ``train_ds``, ``val_ds`` and
    ``MS_MODEL`` (the checkpoint path given to ``ModelCheckpoint``).

    Args:
        epochs: Number of training epochs (defaults to 20).

    Returns:
        Path of the temporary file name under which the weights were saved.
    """
    # NOTE(review): MS_MODEL is not defined in the cells visible here; it must
    # be set before this function is called.
    model.fit(
        train_ds,
        epochs=epochs,  # `Model.fit` takes `epochs`; `max_epochs` raises TypeError
        validation_data=val_ds,
        callbacks=[keras.callbacks.ModelCheckpoint(MS_MODEL)],
        verbose=1,
    )

    # mkstemp returns an already-open OS-level descriptor together with the
    # path; close it right away so the handle is not leaked.
    fd, pretrained_weights = tempfile.mkstemp('.tf')
    os.close(fd)

    model.save_weights(pretrained_weights)
    return pretrained_weights

In [27]:
#pretrained_weights = setup_pretrained_weights()
#model.load_weights(pretrained_weights)    ---> recommended for accuracy, can only do this once model is trained

In [31]:
# Prune only the Dense layers.
def apply_pruning_to_dense(layer):
    """Clone helper that makes only Dense layers train with pruning.

    Args:
        layer: A Keras layer.

    Returns:
        ``layer`` wrapped by ``prune_low_magnitude`` when it is a
        ``keras.layers.Dense`` instance; otherwise ``layer`` itself, unchanged.
    """
    if not isinstance(layer, keras.layers.Dense):
        return layer

    return tfmot.sparsity.keras.prune_low_magnitude(layer)

In [32]:
# Clone `model` with `keras.models.clone_model`, passing every layer through
# `apply_pruning_to_dense` so that only the Dense layers get pruning wrappers.
model_for_pruning = keras.models.clone_model(model, clone_function=apply_pruning_to_dense)

model_for_pruning.summary()

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_4 (Conv2D)           (None, 256, 256, 8)       224       
                                                                 
 batch_normalization_4 (Batc  (None, 256, 256, 8)      32        
 hNormalization)                                                 
                                                                 
 max_pooling2d_4 (MaxPooling  (None, 128, 128, 8)      0         
 2D)                                                             
                                                                 
 conv2d_5 (Conv2D)           (None, 128, 128, 8)       1608      
                                                                 
 batch_normalization_5 (Batc  (None, 128, 128, 8)      32        
 hNormalization)                                                 
                                                      

In [42]:
# Step at which the pruning schedule ends:
# steps per epoch (ceil(images / batch size)) times the number of epochs.
N_TRAIN_IMAGES = 30585  # file count reported when train_ds was loaded
BATCH_SIZE = 32         # presumably the image_dataset_from_directory default — TODO confirm
PRUNING_EPOCHS = 20     # same epoch count as used in setup_pretrained_weights

end_step = np.ceil(N_TRAIN_IMAGES / BATCH_SIZE).astype(np.int32) * PRUNING_EPOCHS

In [43]:
# Sparsity schedule: polynomial decay from 50% to 80% sparsity between
# step 0 and `end_step`.
pruning_params = dict(
    pruning_schedule=tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.50,
        final_sparsity=0.80,
        begin_step=0,
        end_step=end_step,
    ),
)

# NOTE(review): this wraps the whole `model` (the summary lists every layer,
# including batch-norm and pooling, as PruneLowMagnitude), whereas
# `model_for_pruning` targets only the Dense layers — confirm which is intended.
pruned_model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)
pruned_model.summary()

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 prune_low_magnitude_conv2d_  (None, 256, 256, 8)      442       
 4 (PruneLowMagnitude)                                           
                                                                 
 prune_low_magnitude_batch_n  (None, 256, 256, 8)      33        
 ormalization_4 (PruneLowMag                                     
 nitude)                                                         
                                                                 
 prune_low_magnitude_max_poo  (None, 128, 128, 8)      1         
 ling2d_4 (PruneLowMagnitude                                     
 )                                                               
                                                                 
 prune_low_magnitude_conv2d_  (None, 128, 128, 8)      3210      
 5 (PruneLowMagnitude)                                

In [None]:
#logdir = tempfile.mkdtemp()

#callbacks = [
 # tfmot.sparsity.keras.UpdatePruningStep(),
 # tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),]
  
#model_for_pruning.fit(train_images, train_labels,
#                  batch_size=batch_size, epochs=epochs, validation_split=validation_split,
#                 callbacks=callbacks)