<a href="https://www.kaggle.com/code/kenny3s/lfw-transformer?scriptVersionId=157677262" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [1]:
!pip install git+https://github.com/openai/CLIP.git
!pip install open_clip_torch
!pip install sentence_transformers

import torch
import open_clip
import cv2
from sentence_transformers import util
import torchvision.datasets as datasets
from PIL import Image
import numpy as np
import os
# image processing model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-16-plus-240', pretrained="laion400m_e32")
model.to(device)
def imageEncoder(img):
    img1 = Image.fromarray(img).convert('RGB')
    img1 = preprocess(img1).unsqueeze(0).to(device)
    img1 = model.encode_image(img1)
    return img1
def generateScore(image1, image2):
    test_img = cv2.imread(image1, cv2.IMREAD_UNCHANGED)
    data_img = cv2.imread(image2, cv2.IMREAD_UNCHANGED)
    img1 = imageEncoder(test_img)
    img2 = imageEncoder(data_img)
    cos_scores = util.pytorch_cos_sim(img1, img2)
    score = round(float(cos_scores[0][0])*100, 2)
    return score
#print(f"similarity Score: ", round(generateScore(image1, image2), 2))


Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-snwt8eeu
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-snwt8eeu
  Resolved https://github.com/openai/CLIP.git to commit a1d071733d7111c9c014f024669f959182114e33
  Preparing metadata (setup.py) ... [?25l- done
[?25hCollecting ftfy (from clip==1.0)
  Obtaining dependency information for ftfy from https://files.pythonhosted.org/packages/91/f8/dfa32d06cfcbdb76bc46e0f5d69c537de33f4cedb1a15cd4746ab45a6a26/ftfy-6.1.3-py3-none-any.whl.metadata
  Downloading ftfy-6.1.3-py3-none-any.whl.metadata (6.2 kB)
Collecting wcwidth<0.3.0,>=0.2.12 (from ftfy->clip==1.0)
  Obtaining dependency information for wcwidth<0.3.0,>=0.2.12 from https://files.pythonhosted.org/packages/31/b1/a59de0ad3aabb17523a39804f4c6df3ae87ead053a4e25362ae03d73d03a/wcwidth-0.2.12-py2.py3-none-any.whl.metadata
  Downloading wcwidth-0.2.12-p

100%|████████████████████████████████████████| 834M/834M [00:08<00:00, 102MiB/s]


In [2]:
class LFWDataset(datasets.ImageFolder):
    def __init__(self, dir, pairs_path, image_size, transform=None, half_face=False, both=False):
        super(LFWDataset, self).__init__(dir, transform)
        self.image_size = image_size
        self.pairs_path = pairs_path
        self.validation_images = self.get_lfw_paths(dir)
        self.half = half_face
        self.both = both

    def read_lfw_pairs(self, pairs_filename):
        pairs = []
        with open(pairs_filename, 'r') as f:
            for line in f.readlines()[1:]:
                pair = line.strip().split()
                pairs.append(pair)
        return np.array(pairs, dtype=object)

    def get_lfw_paths(self, lfw_dir, file_ext="png"):

        pairs = self.read_lfw_pairs(self.pairs_path)

        nrof_skipped_pairs = 0
        path_list = []
        issame_list = []

        for i in range(len(pairs)):
            # for pair in pairs:
            pair = pairs[i]
            if len(pair) == 3:
                path0 = os.path.join(lfw_dir, pair[0], pair[0] + '_' + '%04d' % int(pair[1]) + '.' + file_ext)
                path1 = os.path.join(lfw_dir, pair[0], pair[0] + '_' + '%04d' % int(pair[2]) + '.' + file_ext)
                issame = True
            elif len(pair) == 4:
                path0 = os.path.join(lfw_dir, pair[0], pair[0] + '_' + '%04d' % int(pair[1]) + '.' + file_ext)
                path1 = os.path.join(lfw_dir, pair[2], pair[2] + '_' + '%04d' % int(pair[3]) + '.' + file_ext)
                issame = False
            if os.path.exists(path0) and os.path.exists(path1):  # Only add the pair if both paths exist
                path_list.append((path0, path1, issame))
                issame_list.append(issame)
            else:
                nrof_skipped_pairs += 1
        if nrof_skipped_pairs > 0:
            print('Skipped %d image pairs' % nrof_skipped_pairs)

        return path_list

#     def __getitem__(self, index):
#         (path_1, path_2, issame) = self.validation_images[index]
#         image1, image2 = Image.open(path_1), Image.open(path_2)

#         if self.half:
#             image2 = image2.convert('1')
#             if self.both:
#                 image1 = image1.convert('1')

#         image1 = resize_image(image1, [self.image_size[1], self.image_size[0]], letterbox_image=True)
#         image2 = resize_image(image2, [self.image_size[1], self.image_size[0]], letterbox_image=True)

#         image1, image2 = np.transpose(preprocess_input(np.array(image1, np.float32)), [2, 0, 1]), np.transpose(
#             preprocess_input(np.array(image2, np.float32)), [2, 0, 1])

#         return image1, image2, issame
    def __getitem__(self, index):
        (path_1, path_2, issame) = self.validation_images[index]
        return path_1, path_2, issame

    def __len__(self):
        return len(self.validation_images)

In [3]:
dataset = LFWDataset(dir="/kaggle/input/lfw-codeformer/lfw_codeformer", pairs_path="/kaggle/input/lfwpeople/pairs.txt", image_size=(512, 512))

In [4]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from tqdm import tqdm, trange
# Step 1: Calculate scores for all image pairs
scores = []
labels = []
for i in trange(len(dataset)):
    image1_path, image2_path, issame = dataset[i]
    score = generateScore(image1_path, image2_path)
    scores.append(score)
    labels.append(issame)

# Step 2: Calculate metrics for a range of thresholds
best_accuracy = 0
best_thresh = 0
best_precision = 0
best_recall = 0
best_f1 = 0
for thresh in tqdm(np.arange(0, 101)):  # assuming scores are in range 0-100
    predictions = [s > thresh for s in scores]
    accuracy = accuracy_score(labels, predictions)
    precision = precision_score(labels, predictions, zero_division=1)
    recall = recall_score(labels, predictions, zero_division=1)
    f1 = f1_score(labels, predictions, zero_division=1)
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        best_thresh = thresh
        best_precision = precision
        best_recall = recall
        best_f1 = f1

# Step 3: Print the best threshold and corresponding metrics
print(f"Best threshold = {best_thresh}")
print(f"Accuracy = {best_accuracy}")
print(f"Precision = {best_precision}")
print(f"Recall = {best_recall}")
print(f"F1_score = {best_f1}")

100%|██████████| 6000/6000 [05:54<00:00, 16.93it/s]
100%|██████████| 101/101 [00:04<00:00, 22.06it/s]

Best threshold = 58
Accuracy = 0.8318333333333333
Precision = 0.8224813735017816
Recall = 0.8463333333333334
F1_score = 0.8342368983078692



