In [22]:
import net
import torch
import os
from face_alignment import align
import numpy as np
from tqdm import tqdm

In [23]:
# Maps each supported architecture name to the checkpoint file holding its
# pretrained weights (path is relative to the notebook's working directory).
adaface_models = dict(
    ir_101="./models/adaface_ir101_webface12m.ckpt",
)

def load_pretrained_model(architecture='ir_101', device='cuda'):
    """Build the network for ``architecture`` and load its pretrained weights.

    Parameters
    ----------
    architecture : str
        Key of ``adaface_models`` selecting which checkpoint to load.
    device : str or torch.device, optional
        Device the model is moved to after loading. Defaults to ``'cuda'``,
        which was previously hard-coded.

    Returns
    -------
    The object returned by ``net.build_model(architecture)`` with the
    checkpoint weights loaded, moved to ``device`` and switched to eval mode.

    Raises
    ------
    AssertionError
        If ``architecture`` is not a key of ``adaface_models``.
    """
    assert architecture in adaface_models, (
        f"unknown architecture {architecture!r}; expected one of {list(adaface_models)}"
    )
    model = net.build_model(architecture)
    # Deserialize onto the CPU first so loading does not depend on the device
    # the checkpoint was saved from; the model is moved to `device` below.
    statedict = torch.load(adaface_models[architecture], map_location='cpu')['state_dict']
    # The checkpoint stores the network's weights under a 'model.' prefix;
    # keep only those entries and strip the prefix so the keys match `model`.
    prefix = 'model.'
    model_statedict = {
        key[len(prefix):]: val
        for key, val in statedict.items()
        if key.startswith(prefix)
    }
    model.load_state_dict(model_statedict)
    model = model.to(device)
    model.eval()
    return model
# Shared model instance used by the scoring cells below.
model = load_pretrained_model(architecture='ir_101')


In [24]:
# Race names; each one is used as a path component for the pair files and
# the image folders.
races = [
    'African',
    'Asian',
    'Caucasian',
    'Indian',
]
# Gender names; these match the per-race sub-folder names.
genders = [
    'Man',
    'Woman',
]

In [25]:
# Entry names found in each race/gender folder of the test data. The scoring
# loop uses these listings to decide which race/gender folder an identity
# from a pair file belongs to.
(
    African_wom,
    Caucasian_wom,
    Asian_wom,
    Indian_wom,
    Asian_man,
    African_man,
    Caucasian_man,
    Indian_man,
) = (
    os.listdir('../images/test/data/African/Woman'),
    os.listdir('../images/test/data/Caucasian/Woman'),
    os.listdir('../images/test/data/Asian/Woman'),
    os.listdir('../images/test/data/Indian/Woman'),
    os.listdir('../images/test/data/Asian/Man'),
    os.listdir('../images/test/data/African/Man'),
    os.listdir('../images/test/data/Caucasian/Man'),
    os.listdir('../images/test/data/Indian/Man'),
)

In [26]:
def to_input(pil_rgb_image):
    """Convert an RGB image into a normalized, batched BGR float tensor.

    Parameters
    ----------
    pil_rgb_image
        Image in RGB channel order with layout (height, width, channel);
        anything ``np.array`` can turn into such an array is accepted.

    Returns
    -------
    torch.Tensor
        Float32 tensor of shape (1, 3, height, width) whose channels are in
        BGR order and whose values are scaled from [0, 255] to [-1, 1].
    """
    np_img = np.array(pil_rgb_image)
    # Reverse the channel axis (RGB -> BGR) before normalizing. The variable
    # was already named for BGR but the flip itself was missing.
    # NOTE(review): the pretrained AdaFace weights are documented as
    # expecting BGR input -- confirm against the checkpoint in use.
    bgr_img = ((np_img[:, :, ::-1] / 255.) - 0.5) / 0.5
    # HWC -> CHW, then wrap in a list to add the leading batch dimension.
    tensor = torch.tensor(np.array([bgr_img.transpose(2, 0, 1)])).float()
    return tensor

In [27]:
# Image path -> normalized embedding, so each image is embedded only once.
cache = dict()

In [28]:
# Score every image pair listed in each race's pair file and append one CSV
# row per pair (img1,img2,score,sim) to ./{race}_{gender}.csv, where gender
# is the gender folder of the first image.
#   * 3-field line "<id> <n1> <n2>": two images of one identity, sim = 1
#   * otherwise    "<id1> <n1> <id2> <n2>": two identities, sim = 0
# NOTE(review): file names are built as "<id>_000<n>.jpg", which presumably
# assumes single-digit image indices -- confirm against the dataset.

# Folder listings as sets keyed by race: O(1) membership tests, and a plain
# dict lookup instead of eval() on a constructed variable name. Insertion
# order is the order in which the races are searched for a second identity.
women_by_race = {
    'African': set(African_wom),
    'Caucasian': set(Caucasian_wom),
    'Asian': set(Asian_wom),
    'Indian': set(Indian_wom),
}
men_by_race = {
    'African': set(African_man),
    'Caucasian': set(Caucasian_man),
    'Asian': set(Asian_man),
    'Indian': set(Indian_man),
}

# Feed inputs to whatever device the model's weights already live on.
model_device = next(model.parameters()).device


def _find_race_and_gender(identity):
    """Return the (race, gender) folder pair whose listing contains ``identity``.

    Women's folders are searched before men's. If the identity is in none of
    the listings the result falls back to ("Indian", "Man").
    """
    for gender_name, by_race in (("Woman", women_by_race), ("Man", men_by_race)):
        for race_name, members in by_race.items():
            if identity in members:
                return race_name, gender_name
    return "Indian", "Man"


def _embed(image_path):
    """Return the unit-norm embedding of ``image_path``, cached in ``cache``."""
    if image_path not in cache:
        with torch.no_grad():
            features, _ = model(to_input(align.get_aligned_face(image_path)).to(model_device))
        features = features.cpu().detach().numpy()
        features /= np.linalg.norm(features)
        cache[image_path] = features
    return cache[image_path]


for race in races:
    with open(f"../images/test/txts/{race}/{race}_pairs.txt", 'r') as pairs_file:
        # Blank lines carry no pair; skip them instead of reporting an
        # IndexError for each one.
        pairs = [line.split() for line in pairs_file if line.strip()]
    for pair in tqdm(pairs):
        try:
            # Identities not listed under the race's Woman folder are treated as Man.
            gender = "Woman" if pair[0] in women_by_race[race] else "Man"
            img1 = f"../images/test/data/Just_Images/{race}/{gender}/{pair[0]}_000{pair[1]}.jpg"
            if len(pair) == 3:
                img2 = f"../images/test/data/Just_Images/{race}/{gender}/{pair[0]}_000{pair[2]}.jpg"
                sim = 1
            else:
                race2, gen2 = _find_race_and_gender(pair[2])
                img2 = f"../images/test/data/Just_Images/{race2}/{gen2}/{pair[2]}_000{pair[3]}.jpg"
                sim = 0
            vec1 = _embed(img1)
            vec2 = _embed(img2)
            # Both vectors are unit-norm, so the dot product is the cosine similarity.
            score = np.dot(vec1, vec2.T)[0][0]
            # Append per pair so rows already scored survive an interrupted run.
            with open(f"./{race}_{gender}.csv", 'a') as results_file:
                results_file.write(f"{img1},{img2},{score},{sim}\n")
        except Exception as e:
            # Deliberately best-effort: a missing image or a failed alignment
            # is reported and the pair is skipped so the run can continue.
            print(pair, e)
            continue

100%|██████████| 6000/6000 [02:01<00:00, 49.42it/s]
100%|██████████| 6000/6000 [01:49<00:00, 54.65it/s] 
100%|██████████| 6000/6000 [06:38<00:00, 15.07it/s]
 25%|██▌       | 1520/6000 [01:58<05:47, 12.91it/s]

['m.02r7h1g', '2', 'm.0bk56n', '5'] [Errno 2] No such file or directory: '../images/test/data/Just_Images/African/Man/m.0bk56n_0005.jpg'


100%|██████████| 6000/6000 [06:26<00:00, 15.54it/s]
