In [1]:
import av

from PIL import Image
from matplotlib import pyplot as plt

import tensorflow as tf

import keras
from keras import layers
from keras.layers import Conv2D, MaxPool2D, Flatten, Dense, BatchNormalization, Dropout, Layer
from keras.optimizers import Adam
from keras import backend as K

import os
import random
import math
import numpy as np
import cv2

In [2]:
DS_CDFV1 = 'celeb_df_v1/'
DS_CDFV2 = 'celeb_df_v2/'

DS_ORGINAL = 'dataset_original/'
DS_SPLIT = 'dataset_split/'
DS_IFRAMES = 'dataset_iframes/'
DS_FACE = 'dataset_face/'
DS_FACE_IMG = 'dataset_face_img/'
DS_SRM_SNIPPETS = 'dataset_srm_snippets_5/'
DS_SEGMENTS = 'dataset_segments/'
DS_RAW = 'dataset_raw/'
DS_RESIDUALS = 'dataset_residuals/'

SEG_1 = 'seg_1/'
SEG_2 = 'seg_2/'
SEG_3 = 'seg_3/'
SEG_4 = 'seg_4/'
SEG_5 = 'seg_5/'

DS_TRAIN = 'train_dataset/'
DS_TEST = 'test_dataset/'
DS_VAL = 'val_dataset/'

CLASS_FAKE = 'fake/'
CLASS_REAL = 'real/'


TOP_LEVEL_1 = [DS_SPLIT, DS_IFRAMES, DS_FACE, DS_FACE_IMG, DS_SRM_SNIPPETS]
TOP_LEVEL_2 = [DS_SEGMENTS, DS_RAW, DS_RESIDUALS]
SEGMENTS = [SEG_1, SEG_2, SEG_3, SEG_4, SEG_5]
SPLIT = [DS_TRAIN, DS_TEST, DS_VAL]
CLASS = [CLASS_REAL, CLASS_FAKE]

DATASET = [DS_CDFV1, DS_CDFV2]

# Snippet Extraction

## Functions

In [100]:
# Returns the index of frames that begin a new segment (except the first segment)
def get_segment_dividers(frame_count, num_segments):
    segments_per_frame = math.floor(frame_count / num_segments)

    return [(segments_per_frame * i) for i in range(1, num_segments) ]

In [101]:
# Returns the indices of the frames that will be randomly selected from each segment
# Multiple snippets indices per segment can be returned by setting the num_snippets arg 
def get_snippet_indices(segment_dividers, num_snippets):
    start_index = 0
    num_snippets = 1 if num_snippets <= 0 else num_snippets

    snippet_indices = []
    for end_index in segment_dividers:

        # Extracting multiple snippets per segment (if needed)
        for _ in range(num_snippets):
            snippet_indices.append(random.randint(start_index, end_index - 1))

        start_index = end_index
        
    return snippet_indices

In [102]:
# Returns an array of randomly selected snippets(PIL.Image) from each segment of the input video
def extract_snippets(fp, num_segments, num_snippets):
    vid_container = av.open(fp)
    vid_stream = vid_container.streams.video[0]
    frame_count = vid_stream.frames

    snippets = []

    # If number of frames in video is less than the number of frames that need to sampled
    # then take all frames in the video
    if frame_count < num_segments * num_snippets:
        for frame in vid_container.decode():
            snippets.append(frame.to_image())

    else:
        segment_dividers = get_segment_dividers(frame_count, num_segments)
        segment_dividers = segment_dividers + [frame_count]

        snippet_indices = get_snippet_indices(segment_dividers, num_snippets)

        frame_index = 0
        for frame in vid_container.decode():
            if frame_index > max(snippet_indices):
                break

            if frame_index in snippet_indices:
                snippets.append(frame.to_image())

            frame_index += 1

    return snippets

## Testing Logic

In [54]:
tmp_count = 30
tmp_seg = get_segment_dividers(tmp_count, 3)
tmp_snip = get_snippet_indices(tmp_seg + [tmp_count], 2)

print(f'Segment Dividers: {tmp_seg}')
print(f'Snippets{tmp_snip}')

Segment Dividers: [10, 20]
Snippets[4, 9, 15, 13, 20, 26]


In [99]:
# test_file = os.listdir(DS_CDFV1 + DS_SPLIT + DS_TRAIN + CLASS_REAL)[0]
test_file = DS_CDFV2 + DS_SPLIT + DS_TRAIN + CLASS_REAL + 'id27_0005.mp4'
# test_input = av.open(os.path.realpath(DS_CDFV1 + DS_SPLIT + DS_TRAIN + CLASS_REAL + test_file))
test_input = av.open(test_file)

print(test_input.streams.video[0].frames)

# for frame in test_input.decode():
#     print(frame.key_frame)

1


In [95]:
# test_file = os.listdir(DS_CDFV1 + DS_FACE + DS_TRAIN + CLASS_REAL)[0]
test_file = DS_CDFV2 + DS_FACE + DS_TRAIN + CLASS_REAL + 'id27_0005.mp4'
# test_file = DS_CDFV1 + DS_FACE + DS_TRAIN + CLASS_REAL + test_file
tmp_snippets = extract_snippets(test_file, 5, 1)

for s in tmp_snippets:
    s.show()

## Implementation

### Celeb-DF v1 & v2

In [103]:
def save_snippets_CDF(dataset, num_segments, num_snippets):
    if dataset != DS_CDFV1 and dataset != DS_CDFV2:
        print(dataset)
        return
    
    random.seed(1)
    
    src_base_path = dataset + DS_FACE
    dst_base_path = dataset + DS_SRM_SNIPPETS

    for split in SPLIT:
        print(f'---Split started: {split}---')
        for class_dir in CLASS:
            print(f'Class started: {class_dir}')

            for video in os.listdir(src_base_path + split + class_dir):
                fp = src_base_path + split + class_dir + video
                snippets = extract_snippets(fp, num_segments, num_snippets)

                for i, snippet in enumerate(snippets, start=1):
                    seg_index = math.ceil(float(i) / num_snippets)
                    snip_index = (i - 1) % num_snippets
              
                    dst = f'{dst_base_path + split + class_dir + os.path.splitext(video)[0]}_s{seg_index}_f{snip_index}.jpeg'
                    snippet.save(dst)         

In [85]:
# CELEB DF V1
save_snippets_CDF(DS_CDFV1, num_segments=5, num_snippets=1)

---Split started: train_dataset/---
Class started: real/
Class started: fake/
---Split started: test_dataset/---
Class started: real/
Class started: fake/
---Split started: val_dataset/---
Class started: real/
Class started: fake/


In [104]:
# CELEB DF V2
save_snippets_CDF(DS_CDFV2, num_segments=5, num_snippets=1)

---Split started: train_dataset/---
Class started: real/
Class started: fake/
---Split started: test_dataset/---
Class started: real/
Class started: fake/
---Split started: val_dataset/---
Class started: real/
Class started: fake/


# Tensor Dataset Creation

In [3]:
def create_tensor_dataset(dataset, split):
    ds = keras.utils.image_dataset_from_directory(
        directory = dataset + DS_SRM_SNIPPETS + split,
        labels = 'inferred',
        label_mode = 'binary',
        batch_size = 32,
        color_mode = 'rgb',
        shuffle = True,
        seed = 1
    )

    return ds

## Celeb DF v1

In [4]:
train_dataset_cdfv1 = create_tensor_dataset(DS_CDFV1, DS_TRAIN)
test_dataset_cdfv1 = create_tensor_dataset(DS_CDFV1, DS_TEST)
val_dataset_cdfv1 = create_tensor_dataset(DS_CDFV1, DS_VAL)

Found 4415 files belonging to 2 classes.
Found 500 files belonging to 2 classes.
Found 1100 files belonging to 2 classes.


## Celeb DF v2

In [39]:
train_dataset_cdfv2 = create_tensor_dataset(DS_CDFV2, DS_TRAIN)
test_dataset_cdfv2 = create_tensor_dataset(DS_CDFV2, DS_TEST)
val_dataset_cdfv2 = create_tensor_dataset(DS_CDFV2, DS_VAL)

Found 24046 files belonging to 2 classes.
Found 2590 files belonging to 2 classes.
Found 6005 files belonging to 2 classes.


# Model Creation

**For Convolution**
- Input Shape:  [batch_size, height, width, channels]
- Filter Shape: [height, width, in_channels, filter_count]

## Functions

In [21]:
class SRMLayer(keras.layers.Layer):
    def __init__(self, strides=[1,1,1,1], padding='SAME'):
        super(SRMLayer, self).__init__()
        self.strides = strides
        self.padding = padding

        # Set of 3 fixed SRM Filters used to extract noise & semantic features
        self.filter_small = tf.constant([[0, 0,  0, 0, 0],
                                         [0, 0,  0, 0, 0],
                                         [0, 1, -2, 1, 0],
                                         [0, 0,  0, 0, 0],
                                         [0, 0,  0, 0, 0]], dtype=tf.float32)
        
        self.filter_med = tf.constant([[0,  0,  0,  0, 0],
                                       [0, -1,  2, -1, 0],
                                       [0,  2, -4,  2, 0],
                                       [0, -1,  2, -1, 0],
                                       [0,  0,  0,  0, 0]], dtype=tf.float32)
        
        self.filter_large = tf.constant([[-1,  2,  -2,  2, -1],
                                         [ 2, -6,   8, -6,  2],
                                         [-2,  8, -12,  8, -2],
                                         [ 2, -6,   8, -6,  2],
                                         [-1,  2,  -2,  2, -1]], dtype=tf.float32)

        # Learnability in SRM filters introduced by 'q' values
        # SRM filters are divided by their respective 'q' values before convolution
        # 'q' values are updated during backpropagation using gradient descent
        self.q_small = self.add_weight(name='q_small',
                                       shape=(1, ),
                                       initializer=keras.initializers.Constant(value=2.0),
                                       trainable=True)
        
        self.q_med = self.add_weight(name='q_med',
                                     shape=(1, ),
                                     initializer=keras.initializers.Constant(value=4.0),
                                     trainable=True)
        
        self.q_large = self.add_weight(name='q_large',
                                       shape=(1, ),
                                       initializer=keras.initializers.Constant(value=12.0),
                                       trainable=True)
        
        # 3rd dimension of filters => number of input channels (Three channels)
        self.filter_small = tf.stack([self.filter_small, self.filter_small, self.filter_small], axis=2)
        self.filter_med   = tf.stack([self.filter_med, self.filter_med, self.filter_med], axis=2)
        self.filter_large = tf.stack([self.filter_large, self.filter_large, self.filter_large], axis=2)

        # 4th dimension of filters => number of output feature maps (One feature map)
        # Each filter gives a single output feature map
        self.filter_small = tf.expand_dims(self.filter_small, axis=-1)
        self.filter_med   = tf.expand_dims(self.filter_med, axis=-1)
        self.filter_large = tf.expand_dims(self.filter_large, axis=-1)
        
    def call(self, inputs):
        filter_small = tf.math.divide(self.filter_small, self.q_small)
        filter_med   = tf.math.divide(self.filter_med, self.q_med)
        filter_large = tf.math.divide(self.filter_large, self.q_large)

        output_small = tf.nn.conv2d(inputs, filter_small, strides=self.strides, padding=self.padding)
        output_med   = tf.nn.conv2d(inputs, filter_med,   strides=self.strides, padding=self.padding)
        output_large = tf.nn.conv2d(inputs, filter_large, strides=self.strides, padding=self.padding)

        return tf.concat([output_small, output_med, output_large], axis=3)

    def get_config(self):
        config = super(SRMLayer, self).get_config()
        config.update({'strides': self.strides,
                       'padding': self.padding,
                       'filter_small': self.filter_small,
                       'filter_med': self.filter_med,
                       'filter_large': self.filter_large,
                       'q_small': self.q_small,
                       'q_med': self.q_med,
                       'q_large': self.q_large})
        
        return config


In [24]:
class XceptionPreProcessor(keras.layers.Layer):
    def __init__(self):
        super(XceptionPreProcessor, self).__init__()


    def call(self, inputs):
        return tf.keras.applications.xception.preprocess_input(inputs)

In [23]:
Xception_network = keras.applications.Xception(
    include_top = False,
    weights = 'imagenet',
    input_shape = (256, 256, 3),
    pooling = max,
    classes = 2
)

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5


## Implementation

In [29]:
# Freezing all layers in Xception network to use pre-trained weights (trained on ImageNet)
Xception_network.trainable = False

for layer in Xception_network.layers:
    assert layer.trainable == False

In [None]:
SRM_Model = keras.Sequential([
    keras.layers.Input(shape=(256, 256, 3)),
    SRMLayer(),
    
])

## Testing Logic

In [25]:
tmp_srm_model = keras.Sequential([
    keras.layers.Input(shape=(256, 256, 3)),
    SRMLayer(),
    XceptionPreProcessor(),
    keras.layers.Flatten(),
    keras.layers.Dense(1, activation='sigmoid')
])

tmp_srm_model.compile(optimizer=keras.optimizers.Adam(),
                      loss=keras.losses.BinaryCrossentropy(),
                      metrics=['acc'])

tmp_srm_model.summary()

Model: "sequential_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 srm_layer_3 (SRMLayer)      (None, 256, 256, 3)       3         
                                                                 
 xception_pre_processor (Xce  (None, 256, 256, 3)      0         
 ptionPreProcessor)                                              
                                                                 
 flatten_3 (Flatten)         (None, 196608)            0         
                                                                 
 dense_3 (Dense)             (None, 1)                 196609    
                                                                 
Total params: 196,612
Trainable params: 196,612
Non-trainable params: 0
_________________________________________________________________


In [7]:
tmp_srm_model.layers[0].get_weights()

[array([2.], dtype=float32),
 array([4.], dtype=float32),
 array([12.], dtype=float32)]

In [8]:
tmp_srm_model.fit(train_dataset_cdfv1, epochs=1, validation_data=val_dataset_cdfv1)



<keras.callbacks.History at 0x20a6951efe0>

In [9]:
tmp_srm_model.layers[0].get_weights()

[array([2.0912213], dtype=float32),
 array([4.1425376], dtype=float32),
 array([12.13928], dtype=float32)]

In [14]:
tmp_srm_model.layers[2].get_weights()

[array([[-0.00065065],
        [-0.00062465],
        [-0.00713642],
        ...,
        [-0.00407955],
        [ 0.00504078],
        [ 0.00208746]], dtype=float32),
 array([-0.0025801], dtype=float32)]

In [12]:
tmp_srm_model.save_weights('models/tmp_srm_model')

In [15]:
tmp_srm_model_2 = keras.Sequential([
    keras.layers.Input(shape=(256, 256, 3)),
    SRMLayer(),
    keras.layers.Flatten(),
    keras.layers.Dense(1, activation='sigmoid')
])

tmp_srm_model.compile(optimizer=keras.optimizers.Adam(),
                      loss=keras.losses.BinaryCrossentropy(),
                      metrics=['acc'])

tmp_srm_model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 srm_layer (SRMLayer)        (None, 256, 256, 3)       3         
                                                                 
 flatten (Flatten)           (None, 196608)            0         
                                                                 
 dense (Dense)               (None, 1)                 196609    
                                                                 
Total params: 196,612
Trainable params: 196,612
Non-trainable params: 0
_________________________________________________________________


In [19]:
tmp_srm_model_2.layers[0].get_weights()

[array([2.0912213], dtype=float32),
 array([4.1425376], dtype=float32),
 array([12.13928], dtype=float32)]

In [18]:
tmp_srm_model_2.load_weights('models/tmp_srm_model')

<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x20a7b2c7370>

In [4]:
tmp_filter_small = tf.constant([[0, 0,  0, 0, 0],
                                [0, 0,  0, 0, 0],
                                [0, 1, -2, 1, 0],
                                [0, 0,  0, 0, 0],
                                [0, 0,  0, 0, 0]], dtype=tf.float32)

tmp_filter_med = tf.constant([[0,  0,  0,  0, 0],
                              [0, -1,  2, -1, 0],
                              [0,  2, -4,  2, 0],
                              [0, -1,  2, -1, 0],
                              [0,  0,  0,  0, 0]], dtype=tf.float32)

tmp_filter_small = tf.expand_dims(tf.expand_dims(tmp_filter_small, axis=-1), axis=-1)
tmp_filter_med   = tf.expand_dims(tf.expand_dims(tmp_filter_med,   axis=-1), axis=-1)
# tf.print(tmp_filter_small.shape)

tmp_filter_small = tf.tile(tmp_filter_small, [1, 1, 3, 1])
tmp_filter_med   = tf.tile(tmp_filter_med,   [1, 1, 3, 1])

tmp_filters = tf.concat([tmp_filter_small, tmp_filter_med], axis=3)

tf.print(tmp_filters.shape)

tmp_conv = Conv2D(filters=2, 
                  kernel_size=5, 
                  kernel_initializer=tf.keras.initializers.Constant(tmp_filters))

tmp_x = tf.random.normal([32, 256, 256, 3])
tmp_y = tmp_conv(tmp_x)

print(tmp_y.shape)

TensorShape([5, 5, 3, 2])
(32, 252, 252, 2)


### Custom SRM Layer

In [9]:
class SSRM(Layer):
    def __init__(self, **kwargs):
        super(SSRM, self).__init__(**kwargs)

        self.filter_small = tf.constant([[0, 0,  0, 0, 0],
                                         [0, 0,  0, 0, 0],
                                         [0, 1, -2, 1, 0],
                                         [0, 0,  0, 0, 0],
                                         [0, 0,  0, 0, 0]], dtype=tf.float32)
        
        self.filter_med = tf.constant([[0,  0,  0,  0, 0],
                                       [0, -1,  2, -1, 0],
                                       [0,  2, -4,  2, 0],
                                       [0, -1,  2, -1, 0],
                                       [0,  0,  0,  0, 0]], dtype=tf.float32)
        
        self.filter_large = tf.constant([[-1,  2,  -2,  2, -1],
                                         [ 2, -6,   8, -6,  2],
                                         [-2,  8, -12,  8, -2],
                                         [ 2, -6,   8, -6,  2],
                                         [-1,  2,  -2,  2, -1]], dtype=tf.float32)
        
        self.filter_small = tf.expand_dims(tf.expand_dims(self.filter_small, axis=-1), axis=-1)
        self.filter_med   = tf.expand_dims(tf.expand_dims(self.filter_med,   axis=-1), axis=-1)
        self.filter_large = tf.expand_dims(tf.expand_dims(self.filter_large, axis=-1), axis=-1)

        self.filter_small = tf.tile(self.filter_small, [1, 1, 3, 1])
        self.filter_med   = tf.tile(self.filter_med,   [1, 1, 3, 1])
        self.filter_large = tf.tile(self.filter_large, [1, 1, 3, 1])
        
        def call(self, inputs):
            output_small = tf.nn.conv2d(inputs, self.filter_small, strides=[1, 1, 1, 1], padding='SAME')
            output_med   = tf.nn.conv2d(inputs, self.filter_med,   strides=[1, 1, 1, 1], padding='SAME')
            output_large = tf.nn.conv2d(inputs, self.filter_large, strides=[1, 1, 1, 1], padding='SAME')

            return tf.concat([output_small, output_med, output_large], axis=3)

In [17]:
class PSRM_Small(Layer):
    def __init__(self):
        super(PSRM_Small, self).__init__()

        self.filter_small = tf.constant([[0, 0,  0, 0, 0],
                                         [0, 0,  0, 0, 0],
                                         [0, 1, -2, 1, 0],
                                         [0, 0,  0, 0, 0],
                                         [0, 0,  0, 0, 0]], dtype=tf.float32)
        
        self.filter_small = tf.stack([self.filter_small, self.filter_small, self.filter_small], axis=2)
        self.filter_small = tf.expand_dims(self.filter_small, axis=-1)

        self.q_small = self.add_weight(name = 'q_small',
                                 shape = (1, ),
                                 initializer = tf.initializers.Constant(value=2),
                                 trainable = True)
        
    def call(self, inputs):
        filter_small = tf.math.divide(self.filter_small, self.q_small)
        return tf.nn.conv2d(inputs, filter_small, strides=[1,1,1,1], padding='SAME')
        

In [19]:
PSRM_Model = keras.Sequential([
    keras.layers.Input(shape=(256,256,3)),
    PSRM_Small(),
    Flatten(),
    Dense(1, activation='sigmoid')
])

PSRM_Model.compile(optimizer=keras.optimizers.Adam(),
                   loss=keras.losses.BinaryCrossentropy(),
                   metrics=['acc'])

PSRM_Model.summary()

Model: "sequential_7"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 psrm__small_6 (PSRM_Small)  (None, 256, 256, 1)       1         
                                                                 
 flatten_7 (Flatten)         (None, 65536)             0         
                                                                 
 dense_7 (Dense)             (None, 1)                 65537     
                                                                 
Total params: 65,538
Trainable params: 65,538
Non-trainable params: 0
_________________________________________________________________


In [20]:
PSRM_Model.fit(train_dataset_cdfv1, epochs=1)



<keras.callbacks.History at 0x1982a8b99f0>

### Checking Custom Dense Layer

In [10]:
class TMPDense(keras.layers.Layer):
    def __init__(self, units, input_dim):
        super(TMPDense, self).__init__()

        self.w = self.add_weight(name = 'w',
                                 shape = (input_dim, units),
                                 initializer = tf.initializers.random_normal,
                                 trainable = True)
        
        self.b = self.add_weight(name = 'b',
                                  shape = (units, ),
                                  initializer = tf._initializers.zeros,
                                  trainable = True)
        
        self.m = self.add_weight(name = 'm',
                                 shape = (1, ),
                                 initializer = tf.initializers.Constant(value=2.0),
                                 trainable=True)
        
    def call(self, inputs):
        div = tf.math.divide(self.w, self.m)
        return tf.matmul(inputs, div) + self.b

In [12]:
dModel = keras.Sequential([
    keras.layers.Input(shape=(256,256,3)),
    SSRM(),
    Flatten(),
    TMPDense(5, 196608),
    TMPDense(1, 5),
])

dModel.compile(optimizer = keras.optimizers.Adam(),
               loss = keras.losses.BinaryCrossentropy(),
               metrics = ['accuracy']
               )

dModel.summary()

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 ssrm_1 (SSRM)               (None, 256, 256, 3)       0         
                                                                 
 flatten_1 (Flatten)         (None, 196608)            0         
                                                                 
 tmp_dense_2 (TMPDense)      (None, 5)                 983046    
                                                                 
 tmp_dense_3 (TMPDense)      (None, 1)                 7         
                                                                 
Total params: 983,053
Trainable params: 983,053
Non-trainable params: 0
_________________________________________________________________


In [13]:
dModel.layers[3].get_weights()

[array([[-0.00287271],
        [-0.04635007],
        [ 0.03803451],
        [-0.02399216],
        [-0.03421838]], dtype=float32),
 array([0.], dtype=float32),
 array([2.], dtype=float32)]

In [17]:
dModel.fit(train_dataset_cdfv1, epochs=5, validation_data=val_dataset_cdfv1)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x1c33ee7e920>

In [18]:
dModel.layers[3].get_weights()

[array([[-0.01285018],
        [-0.05632751],
        [ 0.04801176],
        [-0.01401469],
        [-0.0242409 ]], dtype=float32),
 array([-0.00997706], dtype=float32),
 array([2.0099752], dtype=float32)]

### Custom Conv2D Layer

In [7]:
class TMPConv(Layer):
    def __init__(self):
        super(TMPConv, self).__init__()

        # self.kernel = self.add_weight(name = 'kernel', 
        #                               shape = (3, 3, 3, 1),
        #                               initializer = tf.initializers.random_normal,
        #                               trainable = True)

        self.kernel = tf.constant([[0, 0,  0, 0, 0],
                                         [0, 0,  0, 0, 0],
                                         [0, 1, -2, 1, 0],
                                         [0, 0,  0, 0, 0],
                                         [0, 0,  0, 0, 0]], dtype=tf.float32)
        
        self.kernel = tf.stack([self.kernel, self.kernel, self.kernel], axis=2)
        self.kernel = tf.expand_dims(self.kernel, axis=-1)
        
        self.q = self.add_weight(name = 'q',
                                 shape = (1, ),
                                 initializer = tf.initializers.Constant(value=2),
                                 trainable = True)
        
    def call(self, inputs):
        kernel = tf.math.divide(self.kernel, self.q)
        return tf.nn.conv2d(inputs, kernel, strides = [1,1,1,1], padding='SAME')

In [8]:
tmp_model = keras.Sequential([
    keras.layers.Input(shape=(256, 256, 3)),
    TMPConv(),
    Flatten(),
    Dense(1, activation='sigmoid'),
])

tmp_model.compile(loss=keras.losses.BinaryCrossentropy(),
                  optimizer=Adam(),
                  metrics=['acc'])

tmp_model.summary()

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 tmp_conv (TMPConv)          (None, 256, 256, 1)       1         
                                                                 
 flatten_1 (Flatten)         (None, 65536)             0         
                                                                 
 dense_1 (Dense)             (None, 1)                 65537     
                                                                 
Total params: 65,538
Trainable params: 65,538
Non-trainable params: 0
_________________________________________________________________


In [58]:
tmp_model.layers[0].get_weights()

[array([2.], dtype=float32)]

In [59]:
tmp_model.fit(train_dataset_cdfv1, epochs=1, validation_data=val_dataset_cdfv1)



<keras.callbacks.History at 0x1c3000a2800>

In [60]:
tmp_model.layers[0].get_weights()

[array([2.0662959], dtype=float32)]

### Old

In [15]:
class TSRM(Layer):
    def __init__(self, **kwargs):
        super(TSRM, self).__init__(**kwargs)

        self.filter_small = tf.constant([[0, 0,  0, 0, 0],
                                         [0, 0,  0, 0, 0],
                                         [0, 1, -2, 1, 0],
                                         [0, 0,  0, 0, 0],
                                         [0, 0,  0, 0, 0]], dtype=tf.float32)
        
        self.q_small = self.add_weight(shape=(),
                                       initializer='ones',
                                       dtype=tf.float32,
                                       trainable=True)
        
    def call(self, inputs):
        filter_small = self.filter_small / self.q_small

        # Input shape: [batch_size, height, width, channels]
        # Filter shape: [height, width, in_channels, out_channels]

        conv_small = tf.nn.conv2d(inputs, 
                                  tf.expand_dims(tf.expand_dims(filter_small, axis=-1), axis=-1), 
                                  strides=[1, 1, 1, 1], 
                                  padding='SAME')
        
        # conv_meddd = tf.nn.conv2d(inputs, filter_small, strides=[1, 1, 1, 1], padding='SAME')

        # outputs = tf.concat([conv_small, conv_meddd], axis=-1)
        return conv_small

In [None]:
tmp_model = keras.Sequential([
    keras.layers.Input(shape=(256, 256, 3)),
    TSRM(),
    Flatten(),
    Dense(1, activation='sigmoid')
])

tmp_model.compile(optimizer=Adam(), loss='binary_crossentropy', metrics=['accuracy'])

In [40]:
tmp_history = tmp_model.fit(train_dataset_cdfv1, epochs=1, validation_data=val_dataset_cdfv1)



In [34]:
tmp_kernel1 = np.array([[0, 0,  0, 0, 0],
                    [0, 0,  0, 0, 0],
                    [0, 1, -2, 1, 0],
                    [0, 0,  0, 0, 0],
                    [0, 0,  0, 0, 0]], dtype=float)

tmp_kernel2 = np.array([[0,  0,  0,  0, 0],
                    [0, -1,  2, -1, 0],
                    [0,  2, -4,  2, 0],
                    [0, -1,  2, -1, 0],
                    [0,  0,  0,  0, 0]], dtype=float)

tmp_kernel3 = np.array([[-1,  2,  -2,  2, -1],
                    [ 2, -6,   8, -6,  2],
                    [-2,  8, -12,  8, -2],
                    [ 2, -6,   8, -6,  2],
                    [-1,  2,  -2,  2, -1]], dtype=float)

tmp_fixed_kernals = [tmp_kernel1, tmp_kernel2, tmp_kernel3]
tmp_custom_layer = SRMLayer(filters=3, kernel_size=(5, 5), fixed_filters=tmp_fixed_kernals)