In [None]:
import os
from keras.models import Sequential
from keras.applications.vgg16 import VGG16, preprocess_input
from keras.layers import BatchNormalization, Dense, Input, GlobalMaxPooling2D, MaxPooling2D,Flatten,Concatenate
from keras.models import Model
import keras.backend as K
import shutil
import random
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

In [None]:
class batch_generator:
    """Endless iterator yielding random (anchor, positive, negative) image triplets.

    ``path`` is expected to contain one sub-directory per class; the anchor and
    the positive are two different files from one randomly chosen sub-directory
    and the negative is a file from a different one.

    Parameters
    ----------
    batch_size : int
        Number of triplets per batch.
    target_size : tuple
        Size handed to PIL, i.e. ``(width, height)``. Batches therefore have
        shape ``(batch_size, height, width, 3)``.
    horizontal_flip : bool
        Mirror each image left/right with probability 0.5.
    rescale : float
        Factor every pixel value is multiplied by.
    rotate : bool
        Rotate each image by a random whole number of degrees.
    other_transform : bool
        Apply, with equal probability, either a perspective warp or a shear.
    path : str
        Root directory of the dataset.
    embedding_size : int
        Length of one embedding produced by the model; the all-zero dummy
        labels returned with every batch have ``3 * embedding_size`` columns.
    """

    # Shear factor used by the affine branch of ``other_transform``.
    _SHEAR = -0.5

    def __init__(self,batch_size=256,
                 target_size=(128,128),
                 horizontal_flip=False,
                 rescale=1,
                 rotate=False,
                 other_transform = False,
                 path = "",
                 embedding_size=128,
                ):
        self.batch_size = batch_size
        # Stored as a tuple so it can be concatenated into array shapes.
        self.target_size = tuple(target_size)
        self.horizontal_flip = horizontal_flip
        self.rescale = rescale
        self.rotate = rotate
        self.other_transform = other_transform
        self.path = path
        self.embedding_size = embedding_size
        self.x_shift = 0.5 * self.target_size[0]
        # Kept on the instance: transform() needs these for the perspective warp.
        self.coeffs = self.find_coeffs(
            [(0, 0), (256, 0), (256, 256), (0, 256)],
            [(0, 0), (256, 0), (self.target_size[0], self.target_size[1]), (self.x_shift, self.target_size[1])])

    def __iter__(self):
        return self

    def _load(self, *parts):
        """Open ``path/parts...``, force RGB, resize, augment and return an array."""
        file_path = os.path.join(self.path, *parts)
        # The context manager closes the file handle as soon as the pixels are read.
        with Image.open(file_path) as raw:
            # Grayscale/RGBA/palette files would not fit the 3-channel batch buffers.
            img = raw.convert("RGB").resize(self.target_size, Image.NEAREST)
        return np.array(self.transform(img))

    def read_image(self):
        """Sample one triplet and return ``(anchor, positive, negative)`` arrays.

        Raises
        ------
        ValueError
            If ``path`` holds fewer than two entries, the chosen class holds
            fewer than two images, or the negative class is empty.
        """
        class_dirs = os.listdir(self.path)
        if len(class_dirs) < 2:
            raise ValueError("need at least two class directories in %r" % (self.path,))
        # Two distinct classes: one for anchor/positive, one for the negative.
        original, fake = random.sample(class_dirs, 2)

        own_images = os.listdir(os.path.join(self.path, original))
        if len(own_images) < 2:
            raise ValueError("class %r needs at least two images" % (original,))
        # Two distinct files so the positive is never the anchor itself.
        anchor, positive = random.sample(own_images, 2)

        other_images = os.listdir(os.path.join(self.path, fake))
        if not other_images:
            raise ValueError("class %r contains no images" % (fake,))
        negative = random.choice(other_images)

        anchor_image = self._load(original, anchor)
        positive_image = self._load(original, positive)
        negative_image = self._load(fake, negative)
        return (anchor_image, positive_image, negative_image)

    def transform(self,img):
        """Apply the enabled random augmentations to a PIL image and return it."""
        if self.horizontal_flip and random.choice([True,False]):
            # Flips are a transpose operation in PIL, not a transform method.
            img = img.transpose(Image.FLIP_LEFT_RIGHT)

        if self.rotate:
            # Any whole-degree angle in [0, 359].
            img = img.rotate(random.randrange(360))

        if self.other_transform:
            if random.choice([True,False]):
                img = img.transform(self.target_size, Image.PERSPECTIVE,
                                    tuple(self.coeffs), Image.BICUBIC)
            else:
                # The horizontal offset only applies to a positive shear factor.
                x_offset = -self.x_shift if self._SHEAR > 0 else 0
                img = img.transform(self.target_size, Image.AFFINE,
                                    (1, self._SHEAR, x_offset, 0, 1, 0), Image.BICUBIC)

        return img

    def find_coeffs(self,pa, pb):
        """Return the 8 perspective coefficients relating the points ``pa`` to ``pb``.

        ``pa`` and ``pb`` are sequences of four ``(x, y)`` pairs. The linear
        system is solved in the least-squares sense and the result is a 1-D
        array of length 8.
        """
        rows = []
        for p1, p2 in zip(pa, pb):
            rows.append([p1[0], p1[1], 1, 0, 0, 0, -p2[0]*p1[0], -p2[0]*p1[1]])
            rows.append([0, 0, 0, p1[0], p1[1], 1, -p2[1]*p1[0], -p2[1]*p1[1]])

        A = np.array(rows, dtype=float)
        B = np.array(pb, dtype=float).reshape(8)

        # Least squares solves the same normal equations without forming an explicit inverse.
        solution = np.linalg.lstsq(A, B, rcond=None)[0]
        return np.asarray(solution).reshape(8)

    def __next__(self):
        """Return ``([anchors, positives, negatives], dummy_labels)`` for one batch."""
        # PIL sizes are (width, height); the arrays built from them are (height, width, 3).
        width, height = self.target_size
        shape = (self.batch_size, height, width, 3)
        anchor = np.zeros(shape)
        positive = np.zeros(shape)
        negative = np.zeros(shape)

        for i in range(self.batch_size):
            anchor[i], positive[i], negative[i] = self.read_image()

        inputs = [anchor*self.rescale, positive*self.rescale, negative*self.rescale]
        # All-zero placeholder targets, one column per concatenated embedding value.
        labels = np.zeros((self.batch_size, 3*self.embedding_size))
        return (inputs, labels)
            

In [None]:
# ImageNet-initialised VGG16 without its classifier head, used as a frozen
# feature extractor for 128x128 RGB inputs.
vgg16 = VGG16(weights='imagenet', include_top=False, input_shape=(128, 128, 3))

# Freeze every convolutional layer so only the layers added on top are trained.
for layer in vgg16.layers:
    layer.trainable = False
# To fine-tune the later layers instead, re-enable them from index 8 onwards:
# for layer in vgg16.layers[8:]:
#     layer.trainable = True

# Global max pooling collapses the final feature map to one vector per image.
x = vgg16.output
output = GlobalMaxPooling2D()(x)
pre_trained = Model(vgg16.input, output)
    
    
def base_model(input_shape):
    """Build the embedding head: Dense(1024, relu) followed by a linear Dense(128).

    ``input_shape`` is the length of the 1-D feature vector the head consumes.
    """
    features = Input(shape=(input_shape,), name="input")
    hidden = Dense(1024, activation="relu")(features)
    embedding = Dense(128)(hidden)
    return Model(features, embedding)

# Size the head from the length of the pooled backbone output.
conv_feat_size = K.int_shape(pre_trained.output)[-1]
base = base_model(conv_feat_size)

def Final_model(pre_trained):
    """Build the three-input triplet network.

    Each input passes through ``pre_trained`` and then the module-level
    ``base`` head, so all three branches share the same weights. The three
    embeddings are concatenated along the last axis in the order
    anchor, positive, negative.
    """
    image_shape = K.int_shape(pre_trained.input)[1:]

    # Layer names are runtime identifiers and are kept exactly as they were.
    branches = [
        Input(image_shape, name="anchor"),
        Input(image_shape, name="postive"),
        Input(image_shape, name="negative"),
    ]
    embeddings = [base(pre_trained(branch)) for branch in branches]

    merged = Concatenate(axis=-1)(embeddings)
    return Model(branches, merged)

# From here on the name refers to the built model, not the builder function.
Final_model = Final_model(pre_trained)

In [None]:
def triplet_loss(y_true, y_pred, alpha = 0.2):
    """Triplet margin loss on concatenated embeddings.

    ``y_pred`` holds three equally sized embeddings side by side on its last
    axis, in the order anchor, positive, negative. ``y_true`` is not used; it
    is only present because Keras passes it to every loss. ``alpha`` is the
    margin by which the negative must be farther from the anchor than the
    positive.

    Returns one loss value per sample:
    ``max(0, d(anchor, positive) - d(anchor, negative) + alpha)`` with
    Euclidean distance ``d``.
    """
    embedding_size = K.int_shape(y_pred)[-1] // 3
    a_pred = y_pred[:, :embedding_size]
    p_pred = y_pred[:, embedding_size:2 * embedding_size]
    n_pred = y_pred[:, 2 * embedding_size:]

    def _euclidean(u, v):
        squared = K.sum(K.square(u - v), axis=-1)
        # The gradient of sqrt is infinite at 0; clamping keeps training from
        # producing NaNs when two embeddings coincide.
        return K.sqrt(K.maximum(squared, K.epsilon()))

    positive_distance = _euclidean(a_pred, p_pred)
    negative_distance = _euclidean(a_pred, n_pred)

    loss = K.maximum(0.0, positive_distance - negative_distance + alpha)
    return loss
    

In [None]:
# Configure training: the custom triplet loss above, optimised with Adam.
Final_model.compile(loss=triplet_loss,optimizer="adam")

In [None]:
path = "../input/facedatasets/facedata/faceData/"
# Despite its name, test_gen is the generator the model is trained on.
# Pixels are scaled by 1/255 (8-bit values onto [0, 1]), the same factor the
# evaluation generator val_gen uses, so training and evaluation inputs match.
test_gen = iter(batch_generator(path=path,
                                target_size=(128,128),
                                horizontal_flip=False,
                                rotate=True,
                                other_transform = False,
                                rescale=1./255,
                                batch_size=16))

In [None]:
# Train for 10 epochs of 1000 batches, running 100 validation batches per epoch.
# NOTE(review): validation_data is the very same generator object as the
# training data, so the reported val_loss is not a held-out metric.
# NOTE(review): fit_generator is reportedly deprecated in newer Keras releases
# in favour of fit -- confirm against the installed version.
Final_model.fit_generator(test_gen,epochs=10,validation_data=test_gen, validation_steps= 100, steps_per_epoch=1000)

In [None]:
# Draw one batch: x is the list [anchors, positives, negatives], y the all-zero labels.
x,y = next(test_gen)

In [None]:
# Stack the three image batches into a single array; x[0] is then the anchor batch.
x = np.array(x)

In [None]:
# Embed the anchor batch: backbone features first, then the dense head.
model_predict = base.predict(pre_trained.predict(x[0]))

In [None]:
# Pooled backbone features of the anchor batch.
y_ = pre_trained.predict(x[0])

In [None]:
# Feed the features stored in y_ through the dense embedding head and display the result.
base.predict(y_)

In [None]:
# Download a sample photo; a later cell opens it as ./selena-gomez-cannes.jpg.
!wget https://thenypost.files.wordpress.com/2019/05/selena-gomez-cannes.jpg

In [None]:
# Download a sample photo; a later cell opens it as ./ap_616611858368.jpg.
!wget https://static-ssl.businessinsider.com/image/57fe6c004046ddf8008b5668-2400/ap_616611858368.jpg

In [None]:
# Download a sample photo; a later cell opens it as ./ariana-grande-shutterstock_213445195-600x487jpg.jpg.
!wget https://www.biography.com/.image/t_share/MTI2NDQwNDA2NTg5MTUwNDgy/ariana-grande-shutterstock_213445195-600x487jpg.jpg

In [None]:
def _load_resized(file_name, size=(128, 128)):
    """Open ``file_name``, resize it with nearest-neighbour sampling and return it as an array."""
    picture = Image.open(file_name)
    return np.array(picture.resize(size, Image.NEAREST))

# The three downloaded sample photos, each shrunk to 128x128.
img1 = _load_resized('./ap_616611858368.jpg')
img2 = _load_resized('./selena-gomez-cannes.jpg')
img3 = _load_resized('./ariana-grande-shutterstock_213445195-600x487jpg.jpg')

In [None]:
# Triplet generator over the one-shot testing images: 16 triplets per batch,
# 128x128 inputs scaled by 1/255, random rotation as the only augmentation.
val_gen = iter(
    batch_generator(
        batch_size=16,
        target_size=(128, 128),
        rescale=1./255,
        rotate=True,
        horizontal_flip=False,
        other_transform=False,
        path='../input/onshottesting/testing/testing/',
    )
)

In [None]:
# Draw one evaluation batch: x_ is the list [anchors, positives, negatives], y_ the all-zero labels.
x_,y_ = next(val_gen)

In [None]:
# x_ holds exactly three image batches (anchor, positive, negative), so only
# indices 0-2 exist; indexing a fourth element raises IndexError.
y_ = base.predict(pre_trained.predict(x_[0]))    # anchor embeddings
y__ = base.predict(pre_trained.predict(x_[1]))   # positive embeddings
y___ = base.predict(pre_trained.predict(x_[2]))  # negative embeddings

In [None]:
# Index of the triplet within the batch that the cells below inspect.
i = 0

In [None]:
# Run the full triplet model on the batch; each row of Y is the three embeddings concatenated.
Y = Final_model.predict(x_)

In [None]:
# Length of a single embedding: each output row concatenates three of them
# (anchor, positive, negative), not four.
Y[0].shape[0] // 3

In [None]:
# Y[i] is [anchor | positive | negative] with 128 values each, so the last
# embedding is the slice 2*128:3*128; a slice starting at 3*128 is empty and
# cannot be multiplied with a 128-long vector.
# NOTE(review): this compares the positive with the negative embedding, as the
# original slices suggest -- confirm that is the intended pair.
np.abs(1-np.sum(Y[i][128:2*128]*Y[i][2*128:3*128]))