In [1]:
from skibidi_face_detector.dataset.small_celebrities import train_loader, test_loader
from skibidi_face_detector.face_embedder.Model import Model
from sklearn.neighbors import KNeighborsClassifier
from tqdm import tqdm
import torch
from sklearn.metrics import accuracy_score
from numpy import dot
from numpy.linalg import norm
from resources.model import MODEL_FILE, MODEL_PARAMS

In [2]:
model = Model.load_from_checkpoint(MODEL_FILE, **MODEL_PARAMS)
# Clear these two components before inference; presumably they are only
# needed during training — TODO confirm against Model.
model.transformer = model.augments = None
model.eval()

Model(
  (embedder): Embedder(
    (feature_extractor): Sequential(
      (0): VggFace2(
        (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
        (layer1): Sequential(
          (0): Bottleneck(
            (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True,

In [3]:
# Embed every training batch; collect embeddings and labels on the CPU.
X, Y = [], []

for raw_batch in tqdm(train_loader):
    # Preprocessing stays outside inference mode, only the forward pass is inside.
    inputs, labels = model.transform_batch(raw_batch)
    with torch.inference_mode():
        batch_embeddings = model(inputs)
    X.append(batch_embeddings.cpu())
    Y.append(labels.cpu())

100%|██████████| 75/75 [04:49<00:00,  3.87s/it]


In [4]:
# Concatenate the per-batch embedding and label tensors along dim 0.
Xs, Ys = torch.cat(X), torch.cat(Y)

In [5]:
def cosine_distance(a, b):
    """Return the cosine distance ``1 - cos(a, b)`` between two vectors.

    The result lies in ``[0, 2]``:
    - 0 for vectors pointing the same way;
    - 1 for orthogonal vectors;
    - 2 for opposite vectors.

    If either vector has zero norm, the angle is undefined. Following
    scikit-learn's ``cosine_distances`` convention, the similarity is then
    treated as 0, so the distance is 1.0 instead of NaN.

    Args:
        a: first vector (array-like accepted by ``numpy.dot``).
        b: second vector, same length as ``a``.

    Returns:
        The cosine distance as a float.
    """
    norm_product = norm(a) * norm(b)
    if norm_product == 0:
        return 1.0
    similarity = dot(a, b) / norm_product
    # Rounding can push |similarity| marginally above 1; clamp it so the
    # distance never goes slightly negative or above 2.
    return 1.0 - max(-1.0, min(1.0, similarity))

# 1-nearest-neighbour classifier over the training embeddings. scikit-learn's
# built-in "cosine" metric computes the same 1 - cosine-similarity distance as
# `cosine_distance`, but vectorized, rather than calling a Python function once
# per (query, training sample) pair.
knn = KNeighborsClassifier(n_neighbors=1, metric="cosine")
knn.fit(Xs, Ys)

In [6]:
# Embed every test batch the same way as the training set.
X_test, Y_test = [], []

for raw_batch in tqdm(test_loader):
    # Preprocessing stays outside inference mode, only the forward pass is inside.
    inputs, labels = model.transform_batch(raw_batch)
    with torch.inference_mode():
        batch_embeddings = model(inputs)
    X_test.append(batch_embeddings.cpu())
    Y_test.append(labels.cpu())

100%|██████████| 19/19 [01:18<00:00,  4.13s/it]


In [7]:
# Concatenate the per-batch test embeddings and labels along dim 0.
Xs_test, Ys_test = torch.cat(X_test), torch.cat(Y_test)

In [8]:
predictions = knn.predict(X=Xs_test)  # label of the nearest training embedding

In [9]:
# sklearn's signature is accuracy_score(y_true, y_pred); pass by keyword so the
# ground-truth labels and the predictions cannot be swapped.
accuracy_score(y_true=Ys_test, y_pred=predictions)

0.968013468013468