In [None]:
from PIL import Image
from tqdm.auto import tqdm
from numpy import asarray
import math
import torch
from utils import utils, verification
import os

In [None]:
# Angular-distance thresholds in degrees, per dataset and target FRS. A
# reconstruction whose feature angle to a target image is <= the threshold
# counts as a successful attack.
_THRESHOLDS = {
    'LFW': {'t1': 72.54239687627792, 't2': 77.29096700560457, 't3': 74.33573314862649,
            't4': 75.81816627143655, 't5': 63.57666123410324},
    'AGE': {'t1': 76.40837722605214, 't2': 80.50276553148295, 't3': 79.04721580110888,
            't4': 78.75527640762357, 't5': 70.42746205557108},
    'CFP': {'t1': 77.29096700560457, 't2': 81.08320374700904, 't3': 79.63024019452259,
            't4': 81.08320374700904, 't5': 73.73979529168804},
}

# Number of target indices evaluated per dataset.
_TARGET_NUMBER = {'LFW': 3000, 'AGE': 3000, 'CFP': 3500}

# Pair files read when target_bin=True.
_BIN_FILES = {
    'LFW': "./utils/dataset/lfw.bin",
    'AGE': "./utils/dataset/agedb_30.bin",
    'CFP': "./utils/dataset/cfp_fp.bin",
}


def _preprocess(x, target_FRS):
    """Scale pixel values for ``target_FRS``.

    't2' receives the values unchanged; every other target receives
    ``(x - 127.5) / 255``.
    """
    if target_FRS == 't2':
        return x
    return (x - 127.5) / 255


def _load_png(path, img_size, target_FRS, device):
    """Load an image file as a (1, 3, H, W) float32 tensor on ``device``.

    The image is resized to ``img_size`` then converted to RGB (same order as
    before), and the file handle is closed as soon as the pixels are read.
    """
    with Image.open(path) as img:
        pixels = asarray(img.resize(img_size).convert('RGB'))
    x = torch.tensor(pixels, dtype=torch.float32)
    return _preprocess(x, target_FRS).permute(2, 0, 1).unsqueeze(0).to(device)


def _angle_deg(feat_a, feat_b):
    """Angle in degrees between two unit-norm feature tensors.

    The cosine is clamped to [-1, 1]: rounding can push the dot product of two
    (near-)identical unit vectors slightly above 1, where ``acos`` returns NaN
    and the comparison against the threshold would count as a failure.
    """
    cos = (feat_a * feat_b).sum().clamp(-1.0, 1.0)
    return torch.acos(cos) * 180 / math.pi


def check(target_FRS='t1', source_FRS='AWS', dataset='LFW', device='cuda:0', target_bin=False):
    """Measure attack success rate (ASR) of reconstructed faces against a target FRS.

    For each target index ``i``, the reconstruction made from ``source_FRS``
    (``./recon/<dataset>/<source_FRS>_<dataset>/<i>.png``) is embedded by
    ``target_FRS`` and compared by angle with two target images, "Type-1" and
    "Type-2". An angle <= the dataset/target threshold counts as a success;
    the Type-1 and Type-2 ASRs are printed.

    Args:
        target_FRS: target model, one of 't1', 't2', 't3', 't4', 't5'.
        source_FRS: model the reconstructions came from, e.g. 't1'..'t5',
            'AWS', 'KAIROS', 'FACEPP'; selects the reconstruction folder.
        dataset: 'LFW', 'AGE' or 'CFP'.
        device: torch device string, e.g. 'cuda:0' or 'cpu'.
        target_bin: if True, Type-1/Type-2 images come from the dataset's .bin
            file instead of the ``Type1_``/``Type2_`` PNG folders.

    Returns:
        A (2, target_number) float tensor of angles in degrees; row 0 is
        against the Type-1 image, row 1 against the Type-2 image.

    Raises:
        ValueError: if ``dataset`` or ``target_FRS`` is not supported.
    """
    # Validate before any expensive work so bad arguments fail immediately.
    if dataset not in _THRESHOLDS:
        raise ValueError(f"Unsupported dataset {dataset!r}; expected one of {sorted(_THRESHOLDS)}")
    if target_FRS not in _THRESHOLDS[dataset]:
        raise ValueError(
            f"Unsupported target_FRS {target_FRS!r}; expected one of {sorted(_THRESHOLDS[dataset])}")

    print("Source FRS : "+source_FRS)
    print("Target FRS : "+target_FRS)
    print("Target Dataset : "+dataset)

    # The PNG files (target images and the corresponding reconstructions) can be
    # downloaded from the Zenodo link in the README. The ASR differs slightly
    # depending on whether target images are read from the .bin file or from PNGs.

    thx = _THRESHOLDS[dataset][target_FRS]
    target_number = _TARGET_NUMBER[dataset]
    # Images are resized to 160x160 for t5 and to 112x112 for every other target.
    img_size = (160, 160) if target_FRS == 't5' else (112, 112)

    device = torch.device(device)
    blackbox = utils.target_FRS(target_FRS=target_FRS, device=device)

    if target_bin:
        print("Using bin file instead of png image")
        target_dataset = verification.load_bin(_BIN_FILES[dataset], img_size)
        # Type-1 images are taken from the even rows and Type-2 from the odd rows
        # of target_dataset[0][0]; only indices whose target_dataset[1] flag is
        # truthy are evaluated.
        pair_images = target_dataset[0][0]
        target_ind = torch.where(torch.tensor(target_dataset[1]))[0]

    recon_root = "./recon/" + dataset + "/"
    result = torch.zeros(2, target_number)

    for i in tqdm(range(target_number)):
        if target_bin:
            idx = int(target_ind[i])
            data_Type1 = _preprocess(pair_images[0::2][idx], target_FRS).unsqueeze(0).to(device)
            data_Type2 = _preprocess(pair_images[1::2][idx], target_FRS).unsqueeze(0).to(device)
        else:
            data_Type1 = _load_png(recon_root + "Type1_" + dataset + "/" + str(i) + ".png",
                                   img_size, target_FRS, device)
            data_Type2 = _load_png(recon_root + "Type2_" + dataset + "/" + str(i) + ".png",
                                   img_size, target_FRS, device)
        data_recon = _load_png(recon_root + source_FRS + "_" + dataset + "/" + str(i) + ".png",
                               img_size, target_FRS, device)

        with torch.no_grad():
            feat_Type1 = blackbox(data_Type1)
            feat_Type2 = blackbox(data_Type2)
            feat_recon = blackbox(data_recon)

            feat_Type1 = feat_Type1/feat_Type1.norm()
            feat_Type2 = feat_Type2/feat_Type2.norm()
            feat_recon = feat_recon/feat_recon.norm()

            result[0, i] = _angle_deg(feat_Type1, feat_recon).to("cpu")
            result[1, i] = _angle_deg(feat_Type2, feat_recon).to("cpu")

    # ASR = percentage of indices whose angle is within the threshold.
    for row, label in enumerate(("Type-1", "Type-2")):
        asr = float(100 * (result[row, :] <= thx).sum(0) / target_number)
        print(label + " ASR : " + str(round(asr, 2)) + "%")

    return result

In [None]:
# Table 8 Direct attacks
def check_direct(target='t1', target_bin=False, device='cuda:0'):
    """Run direct attacks (source FRS == target FRS) on LFW, AGE and CFP.

    Args:
        target: FRS used as both source and target ('t1'..'t5').
        target_bin: forwarded to ``check``; read target images from .bin files.
        device: torch device string forwarded to ``check``.

    Returns:
        ``[lfw_result, age_result, cfp_result]``, the tensors returned by
        ``check`` for each dataset, in that order.
    """
    return [check(target, target, dataset, device, target_bin) for dataset in ('LFW', 'AGE', 'CFP')]

In [None]:
# Table 8: direct attacks (source FRS == target FRS) for each target t1..t5.
# Each dir_tN is the list [LFW, AGE, CFP] returned by check_direct, where every
# entry is a (2, N) tensor of angles in degrees (Type-1 row, Type-2 row).
# All runs use 'cuda:0' and the PNG target images (target_bin=False).
dir_t1 = check_direct('t1')
dir_t2 = check_direct('t2')
dir_t3 = check_direct('t3')
dir_t4 = check_direct('t4')
dir_t5 = check_direct('t5')

In [1]:
# Table 10 and 11 transfer attacks (against non-commercial targets)
# Note that the paper only reports LFW results, but this can run on all 3 datasets.
sources = ['t1','t2','t3','t4','t5','AWS','FACEPP','KAIROS']
def check_transfer(target='t1', dataset='LFW', target_bin=False, device='cuda:0', source_list=None):
    """Run transfer attacks against ``target`` from every other source FRS.

    Args:
        target: target FRS ('t1'..'t5').
        dataset: 'LFW', 'AGE' or 'CFP'.
        target_bin: forwarded to ``check``; read target images from .bin files.
        device: torch device string forwarded to ``check``.
        source_list: source FRS names to try; defaults to the module-level
            ``sources`` list. ``target`` itself is always skipped.

    Returns:
        A list with one ``check`` result per source, in ``source_list`` order.
    """
    if source_list is None:
        source_list = sources
    return [check(target, source, dataset, device, target_bin) for source in source_list if source != target]

In [None]:
# Table 10 (bin): transfer attacks on LFW (check_transfer's default dataset),
# reading the Type-1/Type-2 target images from the .bin file.
# Each trans_tN is a list with one (2, N) angle tensor per entry of `sources`
# other than tN, in `sources` order.
trans_t1 = check_transfer('t1',target_bin=True)
trans_t2 = check_transfer('t2',target_bin=True)
trans_t3 = check_transfer('t3',target_bin=True)
trans_t4 = check_transfer('t4',target_bin=True)
trans_t5 = check_transfer('t5',target_bin=True)

In [None]:
# Table 10 (png): same transfer attacks on LFW, but the Type-1/Type-2 target
# images are read from the PNG folders under ./recon/ instead of the .bin file.
# ASR can differ slightly from the .bin variant above.
trans_t1_nb = check_transfer('t1',target_bin=False)
trans_t2_nb = check_transfer('t2',target_bin=False)
trans_t3_nb = check_transfer('t3',target_bin=False)
trans_t4_nb = check_transfer('t4',target_bin=False)
trans_t5_nb = check_transfer('t5',target_bin=False)