# In [3]:
import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy

def build_base_network(input_shape, *, embedding_dim=64):
    """Build the shared convolutional encoder used by both Siamese branches.

    Two Conv2D(3x3, ReLU) + MaxPooling2D stages are followed by a Flatten
    and two ReLU Dense layers that produce a flat embedding vector.

    Args:
        input_shape: Shape of one image without the batch dimension,
            e.g. ``(128, 128, 1)``.
        embedding_dim: Width of the final Dense layer, i.e. the size of the
            embedding. Defaults to 64, the previously hard-coded value.

    Returns:
        An uncompiled ``Sequential`` model mapping a batch of images to a
        ``(batch, embedding_dim)`` tensor.
    """
    return Sequential([
        # Declare the input explicitly instead of passing `input_shape=` to
        # Conv2D; Keras 3 warns about input_shape on non-Input layers.
        layers.Input(shape=input_shape),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D(),
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.MaxPooling2D(),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dense(embedding_dim, activation='relu'),
    ])

# Shape of a single image fed to each branch, without the batch dimension.
# Presumably (height, width, channels) under Keras' default channels_last
# data format — confirm if the backend config is changed.
_IMAGE_SIDE = 128
_IMAGE_CHANNELS = 1
input_shape = (_IMAGE_SIDE, _IMAGE_SIDE, _IMAGE_CHANNELS)

def siamese_network(input_shape, *, show_summary=True):
    """Build a two-input Siamese model that scores a pair of images.

    Both inputs are encoded by ONE ``build_base_network`` instance, so the
    two branches share weights. The element-wise absolute difference of the
    two embeddings is fed to a single sigmoid unit.

    Args:
        input_shape: Shape of one image without the batch dimension,
            e.g. ``(128, 128, 1)``.
        show_summary: If True (the default, matching the original
            behaviour), print ``model.summary()`` after building.

    Returns:
        A ``Model`` with inputs ``[input_1, input_2]`` and a ``(batch, 1)``
        sigmoid output.
    """
    # A single encoder is called on both inputs, so the weights are shared.
    base_network = build_base_network(input_shape)

    input_1 = layers.Input(shape=input_shape)
    input_2 = layers.Input(shape=input_shape)

    encoded_1 = base_network(input_1)
    encoded_2 = base_network(input_2)

    # Element-wise L1 distance between the two embeddings.
    distance = layers.Lambda(
        lambda tensors: tf.abs(tensors[0] - tensors[1])
    )([encoded_1, encoded_2])

    output = layers.Dense(1, activation='sigmoid')(distance)

    # Named `model` (not `siamese_network`) so it does not shadow this
    # function's own name inside its body.
    model = Model(inputs=[input_1, input_2], outputs=output)

    if show_summary:
        model.summary()
    return model

def train_siamese_network(siamese_network, image_pairs, labels, *,
                          learning_rate=0.001, batch_size=32, epochs=10,
                          validation_split=0.2):
    """Compile and fit a two-input Siamese model on image pairs.

    Args:
        siamese_network: A Keras model taking ``[first_images, second_images]``,
            e.g. the model returned by ``siamese_network()``.
        image_pairs: Array indexable as ``image_pairs[:, 0]`` and
            ``image_pairs[:, 1]``; presumably shaped
            ``(num_pairs, 2, *input_shape)`` — confirm with the caller.
        labels: One target per pair, in ``[0, 1]`` as required by the
            binary cross-entropy loss.
        learning_rate: Adam learning rate (default 0.001).
        batch_size: Mini-batch size passed to ``fit`` (default 32).
        epochs: Number of training epochs (default 10).
        validation_split: Fraction of the data held out for validation
            (default 0.2).

    Returns:
        The ``History`` object returned by ``fit``.
    """
    siamese_network.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss=BinaryCrossentropy(),
        metrics=['accuracy'],
    )
    # NOTE: Keras takes the validation split from the LAST samples, before
    # shuffling. Shuffle image_pairs/labels beforehand if they are ordered
    # (e.g. all positive pairs first), or the validation set will be biased.
    return siamese_network.fit(
        [image_pairs[:, 0], image_pairs[:, 1]],
        labels,
        batch_size=batch_size,
        epochs=epochs,
        validation_split=validation_split,
    )
