In [None]:
%matplotlib inline

import numpy as np
import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import signal
from IPython import display

from sklearn.svm import SVC
from sklearn.preprocessing import LabelEncoder
from skimage.transform import resize
from tensorflow.keras.models import load_model
import tensorflow as tf
from tensorflow import keras
print(keras.__version__)
print(tf.__version__)

from sklearn.manifold import TSNE

In [None]:
# OpenCV Haar cascade for frontal faces (path relative to the notebook's working directory).
cascade_path = "haarcascades/haarcascade_frontalface_alt2.xml"

In [None]:
# FaceNet embedding network (Keras .h5). The helpers and FaceDemo below use
# `model` as a module-level global.
model_path = 'model/facenet_keras.h5'
model = load_model(filepath=model_path)

In [None]:
# GAN used only by generate_step(); its call sites in FaceDemo are currently commented out.
gan = load_model(filepath='model/pfgan-1')

In [None]:
def generate_step(img, pixel_max=255.0):
    """Pass one face crop through the GAN and return the result at FaceNet size.

    The crop is resized to 128x128 and linearly mapped from [0, pixel_max] to
    [-1, 1] before being fed to the module-level `gan`. The GAN output is
    resized to 160x160.

    Args:
        img: a single image of shape (H, W, C). By default it is assumed to be
            in [0, 255]. Crops produced by ``skimage.transform.resize`` (as in
            FaceDemo) are floats in [0, 1], so pass ``pixel_max=1.0`` for those.
        pixel_max: upper bound of the input pixel range. The default (255.0)
            reproduces the original ``img / 127.5 - 1.0`` scaling.

    Returns:
        The first (only) item of the resized GAN output batch, as a tf.Tensor.
        NOTE(review): the output value range depends on the GAN and is not
        rescaled here.
    """
    img = tf.image.resize(img, (128, 128))
    # Map [0, pixel_max] -> [-1, 1].
    img = img / (pixel_max / 2.0) - 1.0
    img = img[np.newaxis, :]  # add a batch axis of size 1
    img = gan(img)
    img = tf.image.resize(img, (160, 160))
    return img[0]

In [None]:
def prewhiten(x):
    """Standardize image data to zero mean and unit standard deviation.

    A 3-D array is treated as one image and standardized over all its values.
    A 4-D array is treated as a batch along axis 0, and each item is
    standardized independently. The standard deviation is floored at
    1/sqrt(values per image), so near-constant images are not blown up.

    Raises:
        ValueError: if `x` is neither 3-D nor 4-D.
    """
    if x.ndim == 3:
        reduce_axes, n_values = (0, 1, 2), x.size
    elif x.ndim == 4:
        reduce_axes, n_values = (1, 2, 3), x[0].size
    else:
        raise ValueError('Dimension should be 3 or 4')

    centered = x - np.mean(x, axis=reduce_axes, keepdims=True)
    floor = 1.0 / np.sqrt(n_values)
    scale = np.maximum(np.std(x, axis=reduce_axes, keepdims=True), floor)
    return centered / scale

def l2_normalize(x, axis=-1, epsilon=1e-10):
    """Scale `x` to unit L2 norm along `axis`.

    The squared norm is floored at `epsilon` before taking the square root,
    so all-zero vectors come back as zeros rather than NaN.
    """
    squared_norm = np.sum(np.square(x), axis=axis, keepdims=True)
    denominator = np.sqrt(np.maximum(squared_norm, epsilon))
    return x / denominator

In [None]:
def calc_embs(imgs, margin, batch_size):
    """Return L2-normalized embeddings from the global `model` for a stack of face crops.

    `imgs` is prewhitened and fed to the model in slices of `batch_size`
    (presumably a 4-D batch of crops — TODO confirm with callers).
    `margin` is accepted for call compatibility but is not used.
    """
    whitened = prewhiten(imgs)
    batch_outputs = [
        model.predict_on_batch(whitened[offset:offset + batch_size])
        for offset in range(0, len(whitened), batch_size)
    ]
    return l2_normalize(np.concatenate(batch_outputs))

In [None]:
class FaceDemo(object):
    """Webcam face enrolment and recognition demo: FaceNet embeddings + linear SVM.

    Workflow:
      1. `capture_images(name)` once per person to enrol them.
      2. `train()` to fit the classifier.
      3. `infer()` to label faces live.

    During `infer()`, faces the classifier cannot tell apart confidently are
    auto-enrolled as ``UnknownN``. Relies on the module-level FaceNet `model`
    and the `prewhiten` helper.
    """

    # Minimum top-class probability before a new sample is added to an
    # auto-enrolled ``UnknownN`` person in whoAreYou().
    REFINE_CONFIDENCE = 0.8

    def __init__(self, cascade_path):
        self.vc = None
        self.cascade = cv2.CascadeClassifier(cascade_path)
        self.margin = 10  # extra pixels around each face box (half on each side)
        self.batch_size = 1
        self.n_img_per_person = 10  # crops per enrolment; train() pads/truncates to this
        self.is_interrupted = False
        self.data = {}  # name -> array of captured face crops
        self.le = None
        self.clf = None
        self.person_list = {}  # name -> index into self.database (also the SVM label)
        self.num_person = 0
        self.database = []  # per person: array of embeddings, one row per sample
        self.unknownCount = 0  # number of UnknownN labels handed out so far
        self.distinguishableThreshold = 1.5  # required ratio best / runner-up probability

    def printInfo(self):
        """Print the name -> index map, the person count and each person's embedding-array shape."""
        print(self.person_list)
        print(self.num_person)
        # Print per-person shapes. Entries can have different row counts (e.g.
        # mid-capture), and np.array() on such a ragged list fails on recent NumPy.
        print([np.shape(emb) for emb in self.database])

    def _signal_handler(self, signal, frame):
        """SIGINT handler used while a webcam loop runs: request a clean stop."""
        self.is_interrupted = True

    def _install_sigint_handler(self):
        """Route SIGINT to _signal_handler; return the previous handler so it can be restored."""
        self.is_interrupted = False
        return signal.signal(signal.SIGINT, self._signal_handler)

    @staticmethod
    def _restore_sigint_handler(previous):
        """Re-install the SIGINT handler that was active before a webcam loop."""
        # signal.signal() returns None when the previous handler was not set from
        # Python; None cannot be re-installed, so leave the handler as is then.
        if previous is not None:
            signal.signal(signal.SIGINT, previous)

    def _detect_faces(self, frame):
        """Run the Haar cascade on an RGB frame and return its (x, y, w, h) boxes."""
        return self.cascade.detectMultiScale(frame,
                                             scaleFactor=1.1,
                                             minNeighbors=3,
                                             minSize=(100, 100))

    def _crop_face(self, frame, face):
        """Cut one face box (plus margin) out of `frame` and resize it to 160x160.

        The start indices are clamped at 0. A negative start would wrap around
        in the slice and produce an empty or wrong crop for faces at the
        image edge.

        Returns:
            (crop, (left, right, bottom, top)): the crop and the box that was used.
        """
        x, y, w, h = (int(v) for v in face)
        half = self.margin // 2
        left = max(x - half, 0)
        right = x + w + half
        bottom = max(y - half, 0)
        top = y + h + half
        img = resize(frame[bottom:top, left:right, :], (160, 160), mode='reflect')
        return img, (left, right, bottom, top)

    @staticmethod
    def _draw_box(frame, box):
        """Draw the (left, right, bottom, top) box on `frame` in place."""
        left, right, bottom, top = box
        cv2.rectangle(frame,
                      (left - 1, bottom - 1),
                      (right + 1, top + 1),
                      (255, 0, 0), thickness=2)

    def _embed(self, img):
        """Return the FaceNet embedding of a single crop, as a batch with one row."""
        return model.predict(prewhiten(np.array([img])))

    def _add_embedding(self, name, emb):
        """Append embedding row(s) `emb` to `name`'s database entry, enrolling `name` if new."""
        if name not in self.person_list:
            self.person_list[name] = self.num_person
            self.num_person += 1
            self.database.append(emb)
        else:
            idx = self.person_list[name]
            self.database[idx] = np.concatenate((self.database[idx], emb), axis=0)

    @staticmethod
    def _render_frame(frame, title):
        """Show `frame` in the notebook output, replacing the previously shown frame."""
        plt.imshow(frame)
        plt.title(title)
        plt.xticks([])
        plt.yticks([])
        display.clear_output(wait=True)

    @staticmethod
    def _pause():
        """Give matplotlib a moment to draw the current figure."""
        # plt.pause() may raise on some backends. A missed refresh is harmless,
        # so the error is deliberately ignored.
        try:
            plt.pause(0.1)
        except Exception:
            pass

    def capture_images(self, name='Unknown'):
        """Enrol `name` by capturing up to `n_img_per_person` face crops from webcam 0.

        For every frame with a detected face, the first face is cropped. The
        crop is stored in ``self.data[name]`` and its embedding is added to
        ``self.database``. Capturing stops when enough crops were collected,
        when the camera stops delivering frames, or on SIGINT (kernel
        interrupt). The camera is always released and the previous SIGINT
        handler restored.
        """
        vc = cv2.VideoCapture(0)
        self.vc = vc
        imgs = []
        previous_handler = self._install_sigint_handler()
        try:
            # The first read only checks that frames arrive; that frame is discarded.
            is_capturing = vc.isOpened() and vc.read()[0]
            while is_capturing:
                is_capturing, frame = vc.read()
                if not is_capturing:
                    break  # `frame` is None when the read fails
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

                faces = self._detect_faces(frame)
                if len(faces) != 0:
                    img, box = self._crop_face(frame, faces[0])
                    imgs.append(img)
                    self._draw_box(frame, box)
                    self._add_embedding(name, self._embed(img))

                self._render_frame(frame, '{}/{}'.format(len(imgs), self.n_img_per_person))
                self.data[name] = np.array(imgs)
                if len(imgs) == self.n_img_per_person:
                    break
                self._pause()
                if self.is_interrupted:
                    break
        finally:
            vc.release()
            self._restore_sigint_handler(previous_handler)

    def train(self):
        """Fit a linear SVM on the stored embeddings; the label is the person's database index.

        First, each person's entry is normalised in place to exactly
        `n_img_per_person` rows. Short entries are tiled by repeated doubling,
        then the last rows are kept. This gives every person equal weight.

        Raises:
            ValueError: if a person has no embeddings (doubling could never
                reach the target). ValueError is also raised (by NumPy /
                scikit-learn) when there are no people or only one.
        """
        target = self.n_img_per_person
        for i, feats in enumerate(self.database):
            if feats.shape[0] == 0:
                raise ValueError('person {} has no embeddings'.format(i))
            while feats.shape[0] < target:
                feats = np.concatenate((feats, feats), axis=0)
            self.database[i] = feats[-target:]

        x = np.concatenate(self.database, axis=0)
        y = np.concatenate([np.full(len(feats), i) for i, feats in enumerate(self.database)])
        self.clf = SVC(kernel='linear', probability=True).fit(x, y)

    def whoAreYou(self, img):
        """Classify one face crop.

        Returns:
            ``(shouldUpdate, pred)``:
              - ``(True, 'UnknownN')`` when the best class probability is less
                than `distinguishableThreshold` times the runner-up. In that
                case `unknownCount` is incremented; the caller decides whether
                to enrol the face.
              - ``(False, name)`` of the best class otherwise.
        """
        embeddings = self._embed(img)
        proba = self.clf.predict_proba(embeddings)[0]

        # Take the top two probabilities by sorting. This keeps ties, so two
        # equally likely people count as indistinguishable.
        ranked = np.sort(proba)[::-1]
        mx1 = ranked[0]
        mx2 = ranked[1] if len(ranked) > 1 else 0.0
        print('mx1: ' + str(mx1) + ' mx2: ' + str(mx2))

        # Equivalent to mx1 / mx2 < threshold, without dividing by zero when mx2 == 0.
        if mx1 < self.distinguishableThreshold * mx2:
            self.unknownCount += 1
            return True, 'Unknown' + str(self.unknownCount)

        # predict_proba columns follow clf.classes_; labels are database indices.
        person_id = int(self.clf.classes_[np.argmax(proba)])
        id_to_name = {idx: person for person, idx in self.person_list.items()}
        pred = id_to_name[person_id]

        # Add a confident sample to a still-sparse auto-enrolled UnknownN person and refit.
        # NOTE(review): train() pads every entry to n_img_per_person rows, so once
        # train() has run this length check is normally false.
        if ('Unknown' in pred and mx1 > self.REFINE_CONFIDENCE
                and len(self.database[person_id]) < self.n_img_per_person):
            self.database[person_id] = np.concatenate(
                (self.database[person_id], embeddings), axis=0)
            self.train()

        return False, pred

    # check whether the given key is in the person_list or not
    def checkKey(self, key):
        """Return True if `key` is an enrolled person's name."""
        return key in self.person_list

    def infer(self):
        """Label faces from webcam 0 live until the camera stops or SIGINT arrives.

        Every detected face is classified with whoAreYou(). Faces it reports as
        new are enrolled as ``UnknownN`` after the frame is shown, and the SVM
        is refit. The camera is released and the previous SIGINT handler
        restored on exit.

        Raises:
            RuntimeError: if train() has not been called yet.
        """
        if self.clf is None:
            raise RuntimeError('call train() before infer()')
        vc = cv2.VideoCapture(0)
        self.vc = vc
        previous_handler = self._install_sigint_handler()
        try:
            # The first read only checks that frames arrive; that frame is discarded.
            is_capturing = vc.isOpened() and vc.read()[0]
            while is_capturing:
                is_capturing, frame = vc.read()
                if not is_capturing:
                    break  # `frame` is None when the read fails
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

                pred = None
                new_faces = []  # (name, crop) for every face whoAreYou flagged as new
                for face in self._detect_faces(frame):
                    img, box = self._crop_face(frame, face)
                    should_update, pred = self.whoAreYou(img)
                    if should_update:
                        new_faces.append((pred, img))
                    self._draw_box(frame, box)
                    left, _, bottom, _ = box
                    cv2.putText(frame, str(pred), (left - 1, bottom - 1 - 10),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 0, 0), 2)

                self._render_frame(frame, pred)
                self._pause()
                if self.is_interrupted:
                    break

                # Enrol new faces inside this loop rather than restarting infer(),
                # so the camera stays open and no recursion builds up.
                for new_name, new_img in new_faces:
                    self.data[new_name] = np.array([new_img])
                    self._add_embedding(new_name, self._embed(new_img))
                if new_faces:
                    self.train()
        finally:
            vc.release()
            self._restore_sigint_handler(previous_handler)
            

In [None]:
f = FaceDemo(cascade_path)

In [None]:
# Train with two or more people
f.capture_images('ivan')

In [None]:
# Train with two or more people
f.capture_images('park seo joon')

In [None]:
# Train with two or more people
f.capture_images('jisoo')

In [None]:
# Train with two or more people
f.capture_images('iu')

In [None]:
f.printInfo()

In [None]:
f.train()

In [None]:
f.infer()

In [None]:
f.printInfo()

In [None]:
# Stack all stored embeddings for the t-SNE plot:
#   X: (n_samples, 128) embedding rows, grouped person by person.
#   Y: (n_people, n_per_person) person index repeated once per sample.
# The plot slices X into equal-sized runs per person, so every person must have
# the same number of embeddings; f.train() pads/truncates them to n_img_per_person.
sample_counts = [emb.shape[0] for emb in f.database]
assert len(set(sample_counts)) == 1, 'run f.train() first so every person has the same number of embeddings'
X = np.concatenate(f.database, axis=0).reshape(-1, 128)
Y = np.array([[i] * n for i, n in enumerate(sample_counts)])

In [None]:
# Project the embeddings to 2-D with t-SNE and colour them by person.
# The first four enrolled people keep their anonymised legend labels. Anyone
# enrolled later (e.g. UnknownN faces added by infer()) is labelled with their
# stored name and gets a distinct tab10 colour, rather than all being merged
# into the fourth entry.
PLOT_STYLES = [('male1', 'red'), ('male2', 'blue'), ('female1', 'orange'), ('female2', 'green')]
id_to_name = {idx: name for name, idx in f.person_list.items()}

embedded = TSNE(n_components=2, perplexity=10, learning_rate=200, init='pca', random_state=501).fit_transform(X)
n_people, n_per_person = Y.shape

fig, ax = plt.subplots()
legend_handles = []
for i in range(n_people):
    if i < len(PLOT_STYLES):
        label, color = PLOT_STYLES[i]
    else:
        label = id_to_name.get(i, 'person{}'.format(i))
        color = plt.cm.tab10(i % 10)
    # Rows of X (and therefore of `embedded`) are grouped person by person.
    points = embedded[i * n_per_person:(i + 1) * n_per_person]
    ax.scatter(points[:, 0], points[:, 1], color=color)
    legend_handles.append(mpatches.Patch(color=color, label=label))

ax.set_title("Visualizing Embedding with t-SNE")
ax.set_xlabel('t-SNE dimension 1')
ax.set_ylabel('t-SNE dimension 2')
ax.legend(handles=legend_handles)
plt.show()