In [None]:
import os
import sys
import numpy as np

from tqdm.notebook import tqdm
from PIL import Image

In [None]:
import matplotlib.pyplot as plt

def plot_images_with_texts(images, texts, num_each_row=5):
    """Show images in a grid with a caption underneath each one.

    Every image uses two vertically stacked axes: the upper one shows a
    thumbnail (at most 300x300, made from a copy so the input is untouched)
    and the lower one shows the caption text.

    Args:
        images: sequence of PIL images.
        texts: captions, paired positionally with ``images``.
        num_each_row: maximum number of images per grid row.
    """
    num_images = len(images)
    if num_images == 0:
        # A 0x0 subplot grid cannot be created; there is nothing to draw.
        return

    num_rows = (num_images + num_each_row - 1) // num_each_row
    num_cols = min(num_images, num_each_row)

    # squeeze=False keeps `axes` two-dimensional even when the grid has a
    # single column, so axes[row, col] is valid for every grid shape.
    fig, axes = plt.subplots(
        num_rows * 2, num_cols, figsize=(15, 6 * num_rows), squeeze=False
    )

    for i, (image, text) in enumerate(zip(images, texts)):
        row = (i // num_cols) * 2
        col = i % num_cols

        copy_image = image.copy()
        copy_image.thumbnail((300, 300))
        axes[row, col].imshow(copy_image)
        axes[row, col].axis('off')

        axes[row + 1, col].text(0.5, 0.5, text, ha='center', va='center', wrap=True)
        axes[row + 1, col].axis('off')

    # Drop the unused image/caption axes at the end of the last row.
    for i in range(num_images, num_rows * num_cols):
        row = (i // num_cols) * 2
        col = i % num_cols
        fig.delaxes(axes[row, col])
        fig.delaxes(axes[row + 1, col])

    plt.tight_layout()
    plt.show()

    
def plot_images(images, num_each_row=5):
    """Show images as a grid of thumbnails.

    Each image is copied and shrunk to at most 300x300 before being drawn, so
    the inputs are left untouched. The figure is laid out but ``plt.show()``
    is not called here.

    Args:
        images: sequence of PIL images.
        num_each_row: maximum number of images per grid row.
    """
    num_images = len(images)
    if num_images == 0:
        # A 0x0 subplot grid cannot be created; there is nothing to draw.
        return

    num_rows = (num_images + num_each_row - 1) // num_each_row
    num_cols = min(num_images, num_each_row)

    # squeeze=False keeps `axes` two-dimensional; without it a single row
    # (or a single image) yields a 1-D array / bare Axes and axes[row, col]
    # raises.
    fig, axes = plt.subplots(
        num_rows, num_cols, figsize=(15, 3 * num_rows), squeeze=False
    )

    for i, image in enumerate(images):
        row = i // num_cols
        col = i % num_cols

        copy_image = image.copy()
        copy_image.thumbnail((300, 300))
        axes[row, col].imshow(copy_image)
        axes[row, col].axis('off')

    # Drop the unused axes at the end of the last row.
    for i in range(num_images, num_rows * num_cols):
        fig.delaxes(axes[i // num_cols, i % num_cols])
    plt.tight_layout()


def plot_image_items(image_items, num_each_row=5, func=None):
    """Turn every item into an image and hand the images to ``plot_images``.

    Args:
        image_items: iterable of items to display.
        num_each_row: forwarded to ``plot_images``.
        func: callable mapping an item to an image; when omitted, the item's
            own ``get_image()`` is used.
    """
    to_image = func if func is not None else (lambda item: item.get_image())
    rendered = []
    for entry in tqdm(image_items):
        rendered.append(to_image(entry))
    plot_images(rendered, num_each_row=num_each_row)

In [None]:
import torch
from torchvision import transforms
from PIL import Image, ImageDraw
from facenet_pytorch import MTCNN, InceptionResnetV1
from copy import deepcopy

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda:0')

# Face detector shared by the whole notebook (used via `mtcnn.detect`).
mtcnn = MTCNN(
    image_size=160,
    margin=0,
    min_face_size=20,
    thresholds=[0.6, 0.7, 0.7],
    factor=0.709,
    post_process=True,
    select_largest=False,
    keep_all=True,
    device=device,
)

# Embedding network with the 'vggface2' weights, in eval mode on `device`.
resnet = InceptionResnetV1(pretrained='vggface2')
resnet = resnet.eval().to(device)

In [None]:
class ImageItem:
    """Represents an image in the collection.

    Attributes set by ``__init__``:
        filepath: path of the image file on disk.
        face_bounding_boxes: result of ``get_faces_images()``; computed
            eagerly, so constructing an item runs face detection.
        embedding: starts as ``None``; filled in by later notebook cells.
        dict: an empty dict (not used inside this class).
    """
    def __init__(self, filepath: str):
        self.filepath = filepath
        self.face_bounding_boxes = self.get_faces_images()
        self.embedding = None
        self.dict = dict()

    def get_image(self) -> "Image.Image":
        """Open the file at ``filepath`` and return the PIL image."""
        return Image.open(self.filepath)

    def get_faces_images(self):
        """Return the face bounding boxes reported by ``mtcnn.detect``.

        Returns an empty list when the detector reports no boxes.
        """
        # The context manager releases the file handle as soon as the RGB
        # copy exists; otherwise one handle per item stays open while the
        # whole collection is being constructed.
        with Image.open(self.filepath) as source:
            image = source.convert('RGB')
        boxes, _ = mtcnn.detect(image)
        if boxes is None:
            return []
        return boxes
    
def draw_bounding_box(img, bb, color='red', width=15):
    """Return a copy of ``img`` with the rectangle ``bb`` outlined.

    Args:
        img: PIL image; it is deep-copied, so the input is not modified.
        bb: four numbers ``(x1, y1, x2, y2)`` giving two opposite corners.
        color: line colour passed to ``ImageDraw`` (default ``'red'``).
        width: line width in pixels (default ``15``).

    Returns:
        The annotated copy.
    """
    img = deepcopy(img)
    draw = ImageDraw.Draw(img)
    x1, y1, x2, y2 = bb
    corners = [(x1, y1), (x2, y1), (x2, y2), (x1, y2)]
    # Each edge is drawn as its own segment, so the outline is the same no
    # matter which of the two opposite corners comes first in `bb`.
    for start, end in zip(corners, corners[1:] + corners[:1]):
        draw.line([start, end], fill=color, width=width)
    return img

In [None]:
def recursive_file_search(path):
    """Yield the normalized path of every regular file below ``path``.

    Directories are descended into recursively, in the order returned by
    ``os.listdir``. Entries that are neither a file nor a directory are
    skipped.
    """
    for entry_name in os.listdir(path):
        entry_path = os.path.join(path, entry_name)
        if os.path.isdir(entry_path):
            yield from recursive_file_search(entry_path)
        elif os.path.isfile(entry_path):
            yield os.path.normpath(entry_path)

In [None]:
# Input directories, relative to the notebook's working directory.
root_path = "../image-serve-path/dataset/"
sample_path = "../image-serve-path/sample-images/"

# Every file found below each directory (no filtering by file type).
all_image_paths = [*recursive_file_search(root_path)]
all_sample_image_paths = [*recursive_file_search(sample_path)]

# Constructing an ImageItem runs face detection, hence the progress bars.
all_images = list(map(ImageItem, tqdm(all_image_paths)))
all_samples = list(map(ImageItem, tqdm(all_sample_image_paths)))

In [None]:
from torchvision.transforms import transforms

# Preprocessing applied to every face crop before it is passed to `resnet`:
# resize to 256, center-crop 224x224, convert to a tensor and normalize with
# the per-channel mean/std below.
# A 160x160 Resize + ToTensor pipeline used to be assigned to `transform`
# right before this one; it was overwritten immediately and never took
# effect, so only the effective definition is kept.
# NOTE(review): the mean/std below look like ImageNet statistics, whereas the
# embedding model is loaded with 'vggface2' weights and MTCNN is configured
# with image_size=160 - presumably 160x160 inputs with facenet-pytorch's own
# standardization are expected. TODO confirm and compare results.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),  # todo. remove this. maybe removing this improves it
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def get_image_embedding(image: "Image.Image"):
    """Return the embedding of ``image`` as a flat NumPy array.

    The image is converted to RGB, run through the module-level ``transform``
    and embedded by ``resnet`` on ``device`` with gradients disabled.

    Args:
        image: PIL image (typically a face crop).

    Returns:
        1-D NumPy array holding the network output for this image.
    """
    # Normalize in `transform` is configured for exactly three channels, so
    # grayscale, palette or RGBA inputs must be converted first.
    tensor = transform(image.convert('RGB'))
    batch = tensor.unsqueeze(0).to(device)  # todo use batch instead of this
    with torch.no_grad():
        embedding = resnet(batch).cpu().numpy().reshape(-1)
    return embedding

In [None]:
def update_embedding(item):
    """Compute one embedding per detected face and store the list on ``item``.

    ``item.embedding`` ends up with one entry per box in
    ``item.face_bounding_boxes``, in the same order (empty list if there are
    no boxes).
    """
    if len(item.face_bounding_boxes) == 0:
        # No faces: nothing to crop, so do not even open the file.
        item.embedding = []
        return
    # Open the image once for all faces (it used to be reopened per face) and
    # close the file handle when done.
    with item.get_image() as image:
        item.embedding = [get_image_embedding(image.crop(bb)) for bb in item.face_bounding_boxes]
    
# Embed the detected faces of every dataset image ...
for item in tqdm(all_images):
    update_embedding(item)
    
# ... and of every sample (reference) image.
for item in tqdm(all_samples):
    update_embedding(item)    

In [None]:
def sim_score(all_samples, item: "ImageItem"):
    """Score how well the faces in ``item`` match the sample faces.

    Every face embedding of ``item`` is compared (cosine similarity) with
    every face embedding found in ``all_samples``. The similarities are
    averaged over the sample faces, and the face of ``item`` with the highest
    average wins.

    Args:
        all_samples: iterable of items whose ``embedding`` attribute is a
            sequence of 1-D vectors.
        item: the item to score; its ``embedding`` attribute is a sequence of
            1-D vectors, one per face.

    Returns:
        ``(score, index)``: the best mean cosine similarity and the index of
        the corresponding face within ``item.embedding``. When ``item`` has no
        embeddings the result is ``(0.0, None)`` - always a 2-tuple, so
        callers can unpack it unconditionally.

    Raises:
        ValueError: if ``all_samples`` contains no embeddings at all.
    """
    if len(item.embedding) == 0:
        return 0.0, None

    # A separate loop name avoids shadowing the `item` parameter.
    sample_vectors = [vec for sample in all_samples for vec in sample.embedding]
    if len(sample_vectors) == 0:
        raise ValueError("all_samples contains no face embeddings to compare against")

    item_matrix = np.asarray(item.embedding)
    item_matrix = item_matrix / np.linalg.norm(item_matrix, axis=-1)[:, None]
    sample_matrix = np.asarray(sample_vectors)
    sample_matrix = sample_matrix / np.linalg.norm(sample_matrix, axis=-1)[:, None]

    # Rows: faces of `item`; columns: sample faces. With unit-length rows the
    # dot product is the cosine similarity; the matrix product avoids the
    # (faces x samples x dim) temporary of an explicit broadcast.
    similarity = item_matrix @ sample_matrix.T
    mean_per_face = similarity.mean(axis=1)  # average over samples
    best = int(np.argmax(mean_per_face))     # maximum over the item's faces
    return float(mean_per_face[best]), best

In [None]:
# One (item, best score, index of best face) tuple per dataset image,
# ordered from the highest score to the lowest.
annot = sorted(
    ((item, *sim_score(all_samples, item)) for item in all_images),
    key=lambda pair: pair[1],
    reverse=True,
)

## Load and save

In [None]:
import pickle

def save(filename, all_images, all_samples):
    """Pickle both collections into ``filename`` as a single dict.

    The dict has the keys ``"all_images"`` and ``"all_samples"``; an existing
    file is overwritten.
    """
    payload = {"all_images": all_images, "all_samples": all_samples}
    with open(filename, 'wb') as handle:
        pickle.dump(payload, handle)

def load(filename):
    """Read a checkpoint written by ``save`` and return ``(all_images, all_samples)``.

    The file is unpickled, so only load checkpoints you created yourself:
    unpickling untrusted data can execute arbitrary code.
    """
    with open(filename, 'rb') as handle:
        payload = pickle.load(handle)
    return payload["all_images"], payload["all_samples"]

In [None]:
# Checkpoint the MTCNN boxes and the `resnet` embeddings computed above.
save("checkpoint1", all_images, all_samples)

In [None]:
# Resume point: restores both collections from disk. The names are rebound to
# freshly unpickled objects; `annot` keeps referring to the previous ones.
all_images, all_samples = load("checkpoint1")

In [None]:
def small_image(img, size=(300, 300)):
    """Shrink ``img`` to fit inside ``size`` and return it.

    ``thumbnail`` works in place, so the image passed in is itself modified;
    the return value is that same object, which allows call chaining.

    Args:
        img: PIL image to shrink.
        size: ``(max_width, max_height)`` bound, 300x300 by default.
    """
    img.thumbnail(size)
    return img

In [None]:
    
# Thumbnail of every ranked image with its best-matching face outlined.
# Entries without a face index (p[2] is None) are shown without a box instead
# of failing on the bounding-box lookup.
images = [
    small_image(
        draw_bounding_box(p[0].get_image(), p[0].face_bounding_boxes[p[2]])
        if p[2] is not None
        else p[0].get_image()
    )
    for p in tqdm(annot)
]
texts = [f'similarity: {p[1]}' for p in annot]

In [None]:
# Thumbnails of every dataset image with ALL of its detected face boxes drawn.
all_images_with_bb = []

for item in tqdm(all_images):
    img = item.get_image()
    # draw_bounding_box returns a new copy each time, so boxes accumulate on
    # successive copies rather than on the opened image.
    for bb in item.face_bounding_boxes:
        img = draw_bounding_box(img, bb)
    img = small_image(img)
    all_images_with_bb.append(img)

# Result: InceptionResnetV1 with VGGFace2 weights

In [None]:
# Every dataset image with its detected face boxes.
plot_images(all_images_with_bb)

In [None]:
# Ranked images (highest similarity first), captioned with their score.
plot_images_with_texts(images, texts, num_each_row=3)

## Compare with face_recognition

In [None]:
import face_recognition


def get_image_embedding_fr(image: "Image.Image"):
    """Return the ``face_recognition`` encodings of the faces in ``image``.

    Args:
        image: PIL image (the whole picture; face detection is done by
            ``face_recognition.face_encodings`` itself).

    Returns:
        Whatever ``face_recognition.face_encodings`` returns for the image.
    """
    # Convert to an RGB array in memory. The previous tmp.jpg round trip
    # re-compressed the picture lossily, failed for modes JPEG cannot store
    # (e.g. RGBA) and left a stray file in the working directory.
    pixels = np.array(image.convert('RGB'))
    return face_recognition.face_encodings(pixels)


def update_embedding(item):
    """Replace ``item.embedding`` with face_recognition encodings of the whole image.

    This redefines the earlier ``update_embedding`` of this notebook. The
    stored ``face_bounding_boxes`` are not touched.
    """
    # Alternative that was tried: one encoding per stored box, i.e.
    # [get_image_embedding_fr(item.get_image().crop(bb)) for bb in item.face_bounding_boxes]
    whole_image = item.get_image()
    item.embedding = get_image_embedding_fr(whole_image)


# Recompute the embeddings of every dataset image with face_recognition ...
for item in tqdm(all_images):
    update_embedding(item)
    
# ... and of every sample image.
for item in tqdm(all_samples):
    update_embedding(item)    

In [None]:
# Checkpoint after the embeddings were replaced by face_recognition encodings
# (the stored bounding boxes are unchanged at this point).
save("checkpoint2", all_images, all_samples)

In [None]:
def update_item(item):
    """Re-detect faces with ``face_recognition`` and store boxes and encodings on ``item``.

    Both ``item.face_bounding_boxes`` and ``item.embedding`` are replaced;
    they have one entry per detected face, in the same order.
    """
    # Convert to an RGB array in memory. The previous tmp.jpg round trip
    # re-compressed the picture lossily, failed for modes JPEG cannot store
    # (e.g. RGBA) and left a stray file in the working directory.
    pixels = np.array(item.get_image().convert('RGB'))

    face_locations = face_recognition.face_locations(pixels)
    face_encodings = face_recognition.face_encodings(pixels, face_locations)

    # Each location (a, b, c, d) is stored rotated as (b, c, d, a).
    # NOTE(review): face_recognition documents locations as
    # (top, right, bottom, left), which would make the stored order
    # (right, bottom, left, top) rather than PIL's (left, top, right, bottom).
    # The order is deliberately left as-is because a later cell of this
    # notebook swaps the corner pairs after loading checkpoint4; change both
    # places together.
    item.face_bounding_boxes = [(b, c, d, a) for (a, b, c, d) in face_locations]
    item.embedding = face_encodings
    
# Replace boxes and embeddings of every dataset image ...
for item in tqdm(all_images):
    update_item(item)
    
# ... and of every sample image.
for item in tqdm(all_samples):
    update_item(item)    

In [None]:
# Checkpoint with face_recognition boxes and encodings on every item.
save("checkpoint3", all_images, all_samples)

In [None]:
# One (item, best score, index of best face) tuple per dataset image,
# ordered from the highest score to the lowest.
annot = sorted(
    ((item, *sim_score(all_samples, item)) for item in all_images),
    key=lambda pair: pair[1],
    reverse=True,
)

In [None]:
# Thumbnail of every ranked image with its best-matching face outlined.
# Entries without a face index (p[2] is None) are shown without a box instead
# of failing on the bounding-box lookup.
images = [
    small_image(
        draw_bounding_box(p[0].get_image(), p[0].face_bounding_boxes[p[2]])
        if p[2] is not None
        else p[0].get_image()
    )
    for p in tqdm(annot)
]
texts = [f'similarity: {p[1]}' for p in annot]

# Result: dlib (face_recognition)

In [None]:
# Ranked images (highest similarity first), captioned with their score.
plot_images_with_texts(images, texts, num_each_row=3)

In [None]:
# Thumbnails of every dataset image with ALL of its face boxes drawn, now
# using the boxes stored by the face_recognition-based update_item.
all_images_with_bb = []

for item in tqdm(all_images):
    img = item.get_image()
    # draw_bounding_box returns a new copy each time, so boxes accumulate on
    # successive copies rather than on the opened image.
    for bb in item.face_bounding_boxes:
        img = draw_bounding_box(img, bb)
    img = small_image(img)
    all_images_with_bb.append(img)

plot_images(all_images_with_bb)

In [None]:
# Checkpoint of the current items; read back by the cell that follows.
save("checkpoint4", all_images, all_samples)

In [None]:
# Start from the stored checkpoint so the in-place fix below is applied to a
# known state.
all_images, all_samples = load("checkpoint4")

# Bug fix: the stored box coordinates were reversed, so swap the two corner
# pairs of every box.
def update_item(item):
    """Swap the two corner pairs of every box in ``item.face_bounding_boxes``.

    ``(x1, y1, x2, y2)`` becomes ``(x2, y2, x1, y1)``. Applying this twice
    restores the original boxes, so run it exactly once per item.
    """
    item.face_bounding_boxes = [
        (second_x, second_y, first_x, first_y)
        for (first_x, first_y, second_x, second_y) in item.face_bounding_boxes
    ]
    
# Apply the corner swap to every dataset image ...
for item in tqdm(all_images):
    update_item(item)
    
# ... and to every sample image.
for item in tqdm(all_samples):
    update_item(item)    

# Persist the corrected boxes.
save("checkpoint5", all_images, all_samples)    

In [None]:
# Two independent copies of the same checkpoint: the cp* copies are only read
# (old embeddings and boxes) while all_images / all_samples are modified below.
cpall_images, cpall_samples = load("checkpoint5")
all_images, all_samples = load("checkpoint5")

def update_item(item, cpitem, old_weight=3):
    """Concatenate each stored embedding with a fresh ``get_image_embedding`` of the same face.

    For every face of ``cpitem`` the face is cropped from ``item``'s image
    using the stored box and embedded with ``get_image_embedding``; the result
    is appended to ``old_weight`` times the stored embedding. The combined
    vectors are written to ``item.embedding``.

    Args:
        item: item that receives the combined embeddings.
        cpitem: untouched copy of the same item providing the stored
            embeddings and boxes.
        old_weight: factor applied to the stored embedding before
            concatenation (default 3).
    """
    assert item.filepath == cpitem.filepath
    image = item.get_image()
    assert len(cpitem.embedding) == len(cpitem.face_bounding_boxes)
    embeddings = []
    for old_embedding, bb in zip(cpitem.embedding, cpitem.face_bounding_boxes):
        nw_embedding = get_image_embedding(image.crop(bb))
        embeddings.append(np.concatenate([old_weight * old_embedding, nw_embedding]))
    item.embedding = embeddings
    
# zip() has no length, so tqdm is given the total explicitly; otherwise the
# progress bar cannot show how much work is left.
for item, cpitem in tqdm(zip(all_images, cpall_images), total=len(all_images)):
    update_item(item, cpitem)

for item, cpitem in tqdm(zip(all_samples, cpall_samples), total=len(all_samples)):
    update_item(item, cpitem)

## Result: concatenation of the two embeddings

In [None]:
# Rank the dataset images with the concatenated embeddings, best match first.
annot = [(item, *sim_score(all_samples, item)) for item in all_images] # item, max, max_id
annot.sort(key=lambda pair: pair[1], reverse=True)
# Entries without a face index (p[2] is None) are shown without a box instead
# of failing on the bounding-box lookup.
images = [
    small_image(
        draw_bounding_box(p[0].get_image(), p[0].face_bounding_boxes[p[2]])
        if p[2] is not None
        else p[0].get_image()
    )
    for p in tqdm(annot)
]
texts = [f'similarity: {p[1]}' for p in annot]
plot_images_with_texts(images, texts, num_each_row=3)

In [None]:
# Checkpoint with the concatenated embeddings.
save("checkpoint6", all_images, all_samples)

## Clustering Idea

In [None]:
# The clustering experiments start from the concatenated-embedding checkpoint.
all_images, all_samples = load("checkpoint6")

In [None]:
class FaceItem:
    """A single face: the image item it came from, its embedding and its box."""

    def __init__(self, image_item, embedding, bb):
        self.image_item, self.embedding, self.bb = image_item, embedding, bb

    def get_cropped_image(self):
        """Open the source image and return the region given by ``bb``."""
        full_image = self.image_item.get_image()
        return full_image.crop(self.bb)

# Flatten the collections into one list with a FaceItem per detected face,
# sample images first, then dataset images.
all_faces = []

for item in (all_samples + all_images):
    # Boxes and embeddings are paired positionally.
    for bb, emb in zip(item.face_bounding_boxes, item.embedding):
        all_faces.append(FaceItem(item, emb, bb))

In [None]:
import face_recognition
import os
import numpy as np
from scipy.spatial.distance import pdist
from scipy.cluster.hierarchy import linkage, fcluster

# Scale every face embedding to unit length.
encodings = [face.embedding / np.linalg.norm(face.embedding) for face in all_faces]
# Pairwise cosine distances between all faces.
distances = pdist(encodings, metric='cosine')
# Hierarchical clustering with complete linkage on those distances.
linkage_matrix = linkage(distances, method='complete')
# Cut-off used with criterion='distance' to form the flat clusters.
threshold = 0.1
# labels[i] is the cluster label of all_faces[i].
labels = fcluster(linkage_matrix, threshold, criterion='distance')

# lol!

In [None]:
# Inspect one cluster: thumbnails of every face whose label is 151.
images = [
    small_image(face.image_item.get_image().crop(face.bb))
    for face, label in zip(all_faces, labels)
    if label == 151
]
plot_images(images)

In [None]:
# Face indices ordered by cluster label, so faces of one cluster are adjacent.
indices = sorted(range(len(all_faces)), key=lambda i: labels[i])

images = [
    small_image(all_faces[idx].image_item.get_image().crop(all_faces[idx].bb))
    for idx in tqdm(indices)
]
texts = [f'label: {labels[idx]}' for idx in indices]

## Clustering works very well!

In [None]:
# First 200 face crops in label order, captioned with their cluster label.
plot_images_with_texts(images[:200], texts[:200], num_each_row=5)

In [None]:
class Cluster:
    """A group of faces that received the same cluster label."""
    def __init__(self, faces):
        # List of the faces belonging to this cluster.
        self.faces = faces


# Group the faces by the label fcluster assigned to them.
cluster_map = {}
for face, label in zip(all_faces, labels):
    cluster_map.setdefault(label, []).append(face)

# One Cluster per distinct label, in order of first appearance.
all_clusters = [Cluster(faces) for faces in cluster_map.values()]

In [None]:
# Three thumbnails per cluster, one cluster per grid row.
images = []

for cluster in tqdm(all_clusters):
    for i in range(3):
        # Clusters with fewer than three faces repeat their last face so that
        # every cluster still fills exactly one row.
        images.append(small_image(cluster.faces[min(i, len(cluster.faces)-1)].get_cropped_image()))

plot_images(images, num_each_row=3)

In [None]:
def sim_cluster(all_samples, cluster):
    """Score how well the faces of ``cluster`` match the sample faces.

    Every face embedding in the cluster is compared (cosine similarity) with
    every face embedding found in ``all_samples``. The similarities are
    averaged over the sample faces, and the cluster face with the highest
    average wins.

    Args:
        all_samples: iterable of items whose ``embedding`` attribute is a
            sequence of 1-D vectors.
        cluster: object with a ``faces`` sequence; each face has a 1-D
            ``embedding`` vector.

    Returns:
        ``(score, index)``: the best mean cosine similarity and the index of
        the corresponding face within ``cluster.faces``.

    Raises:
        ValueError: if the cluster has no faces or ``all_samples`` contains no
            embeddings.
    """
    cluster_vectors = [face.embedding for face in cluster.faces]
    sample_vectors = [vec for sample in all_samples for vec in sample.embedding]
    if len(cluster_vectors) == 0:
        raise ValueError("cluster contains no faces")
    if len(sample_vectors) == 0:
        raise ValueError("all_samples contains no face embeddings to compare against")

    # TODO DRY THIS OUT LATER (same computation as sim_score)
    cluster_matrix = np.asarray(cluster_vectors)
    cluster_matrix = cluster_matrix / np.linalg.norm(cluster_matrix, axis=-1)[:, None]
    sample_matrix = np.asarray(sample_vectors)
    sample_matrix = sample_matrix / np.linalg.norm(sample_matrix, axis=-1)[:, None]

    # Rows: cluster faces; columns: sample faces. With unit-length rows the
    # dot product is the cosine similarity; the matrix product avoids the
    # (faces x samples x dim) temporary of an explicit broadcast.
    similarity = cluster_matrix @ sample_matrix.T
    mean_per_face = similarity.mean(axis=1)  # average over samples
    best = int(np.argmax(mean_per_face))     # maximum over the cluster's faces
    return float(mean_per_face[best]), best



In [None]:
# One (cluster, best score, index of best face) tuple per cluster, ordered
# from the highest score to the lowest.
annot = sorted(
    ((cluster, *sim_cluster(all_samples, cluster)) for cluster in all_clusters),
    key=lambda pair: pair[1],
    reverse=True,
)

In [None]:
# One annotated thumbnail per face, grouped by cluster in ranking order.
images = []
texts = []

# cid is the rank of the cluster in `annot` (0 = best score), not the label
# assigned by fcluster.
for cid, (cluster, score, idx) in enumerate(tqdm(annot)):
    for face in cluster.faces:
        img = face.image_item.get_image()
        img = draw_bounding_box(img, face.bb)
        images.append(small_image(img))
        texts.append(f"cluster: {cid}\nscore: {score}")

In [None]:
# First 80 faces of the best-ranked clusters with cluster rank and score.
plot_images_with_texts(images[:80], texts[:80], num_each_row=5)