In [1]:
import os
import re
#import tensorrt as trt
import cv2
import traceback
import numpy as np
import time
import torch
from PIL import Image
from tqdm import tqdm
import time
import glob
import grpc
#import tritonclient.grpc as grpcclient
from albumentations.augmentations.functional import image_compression
from training.zoo.classifiers import DeepFakeClassifier, DeepFakeClassifierWithViT
from facenet_pytorch.models.mtcnn import MTCNN
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
import torchvision.transforms as transforms
from torchvision.transforms import Normalize
#TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
normalize_transform = Normalize(mean, std)
DeNormalize = transforms.Compose([  Normalize(mean = [0., 0., 0. ], std = [1/0.229, 1/0.224, 1/0.225]),
                                    Normalize(mean = [-0.485, -0.456,-0.406], std = [1., 1., 1.]), ])
cuda = torch.device(2)
def put_to_center(img, input_size):
    img = img[:input_size, :input_size]
    image = np.zeros((input_size, input_size, 3), dtype=np.uint8)
    start_w = (input_size - img.shape[1]) // 2
    start_h = (input_size - img.shape[0]) // 2
    image[start_h:start_h + img.shape[0], start_w: start_w + img.shape[1], :] = img
    return image



def isotropically_resize_image(img, size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC):
    h, w = img.shape[:2]
    if max(w, h) == size:
        return img
    if w > h:
        scale = size / w
        h = h * scale
        w = size
    else:
        scale = size / h
        w = w * scale
        h = size
    interpolation = interpolation_up if scale > 1 else interpolation_down
    resized = cv2.resize(img, (int(w), int(h)), interpolation=interpolation)
    return resized

In [2]:
detector = MTCNN(margin=0, thresholds=[0.7, 0.8, 0.8], device=cuda)
model = DeepFakeClassifier(encoder="resnest269e").to(cuda)
checkpoint = torch.load("../weights/resnest269rec/resnest269rec_999_DeepFakeClassifier_resnest269e_0_best_dice")
state_dict = checkpoint.get("state_dict", checkpoint)
model.load_state_dict({re.sub("^module.", "", k): v for k, v in state_dict.items()}, strict=False)
model.eval()
model = model.half()
del checkpoint

In [3]:
img_path_list = sorted(glob.glob('../../faceforensics_benchmark_images/*.png'))
file_name = []
prob = []
label = []

with torch.no_grad():
    n = 0
    for i, img_path in enumerate(img_path_list):
        if i % 100 == 0:
            print(f"{i} of 1000")
        img = Image.open(img_path)
        frame = np.array(img)
        img = img.resize(size=[s // 2 for s in img.size])

        batch_boxes, probs = detector.detect(img, landmarks=False)
        if batch_boxes is None:
            file_name.append(img_path.split('/')[-1])
            prob.append(0.5)
            label.append('real')
            continue

        faces = []
        scores = []

        for bbox, score in zip(batch_boxes, probs):
            if bbox is not None:
                xmin, ymin, xmax, ymax = [int(b * 2) for b in bbox]
                w = xmax - xmin
                h = ymax - ymin
                p_h = h // 3
                p_w = w // 3
                crop = frame[max(ymin - p_h, 0):ymax + p_h, max(xmin - p_w, 0):xmax + p_w]
                faces.append(crop)
                scores.append(score)

        batch_size=len(faces)
        input_size=416
        x = np.zeros((batch_size,input_size,input_size, 3), dtype=np.uint8)

        n = 0
        for face in faces:
            resized_face = isotropically_resize_image(face, input_size)
            resized_face = put_to_center(resized_face, input_size)
            x[n] = resized_face
            n += 1

        x = torch.tensor(x, device=cuda).float()

        x = x.permute((0, 3, 1, 2))
        for i in range(len(x)):
            x[i] = normalize_transform(x[i] / 255.)
        y_pred = model(x.half())
        y_pred = torch.sigmoid(y_pred.squeeze())
        y_pred = y_pred.cpu().numpy()
        y_pred = y_pred.mean()
        file_name.append(img_path.split('/')[-1])
        prob.append(y_pred)
        if(y_pred >= 0.5):
            label.append('fake')
        else:
            label.append('real')

0 of 1000
100 of 1000
200 of 1000
300 of 1000
400 of 1000
500 of 1000
600 of 1000
700 of 1000
800 of 1000
900 of 1000


In [4]:
print(file_name)
print(label)
print(prob)

['0000.png', '0001.png', '0002.png', '0003.png', '0004.png', '0005.png', '0006.png', '0007.png', '0008.png', '0009.png', '0010.png', '0011.png', '0012.png', '0013.png', '0014.png', '0015.png', '0016.png', '0017.png', '0018.png', '0019.png', '0020.png', '0021.png', '0022.png', '0023.png', '0024.png', '0025.png', '0026.png', '0027.png', '0028.png', '0029.png', '0030.png', '0031.png', '0032.png', '0033.png', '0034.png', '0035.png', '0036.png', '0037.png', '0038.png', '0039.png', '0040.png', '0041.png', '0042.png', '0043.png', '0044.png', '0045.png', '0046.png', '0047.png', '0048.png', '0049.png', '0050.png', '0051.png', '0052.png', '0053.png', '0054.png', '0055.png', '0056.png', '0057.png', '0058.png', '0059.png', '0060.png', '0061.png', '0062.png', '0063.png', '0064.png', '0065.png', '0066.png', '0067.png', '0068.png', '0069.png', '0070.png', '0071.png', '0072.png', '0073.png', '0074.png', '0075.png', '0076.png', '0077.png', '0078.png', '0079.png', '0080.png', '0081.png', '0082.png', '00

In [5]:
import pandas as pd
a = pd.DataFrame()

In [6]:
a['file_name'] = file_name
a['label'] = label
a['score'] = prob
a

Unnamed: 0,file_name,label,score
0,0000.png,fake,0.983398
1,0001.png,real,0.017853
2,0002.png,real,0.007935
3,0003.png,real,0.036987
4,0004.png,fake,0.982422
...,...,...,...
995,0995.png,fake,0.971191
996,0996.png,real,0.043610
997,0997.png,fake,0.981934
998,0998.png,real,0.008186


In [7]:
freq = a.groupby(['label']).count()
print(freq)

       file_name  score
label                  
fake         311    311
real         689    689


In [8]:
detector = MTCNN(margin=0, thresholds=[0.7, 0.8, 0.8], device=cuda)
model = DeepFakeClassifierWithViT(encoder="deit_base_patch16_384").to(cuda)
checkpoint = torch.load("weights/deitb_384_1")
state_dict = checkpoint.get("state_dict", checkpoint)
model.load_state_dict({re.sub("^module.", "", k): v for k, v in state_dict.items()}, strict=False)
model.eval()
#model = model.half()
del checkpoint

In [9]:
img_path_list = sorted(glob.glob('../../faceforensics_benchmark_images/*.png'))
file_name = []
prob = []
label = []

with torch.no_grad():
    n = 0
    for i, img_path in enumerate(img_path_list):
        if i % 100 == 0:
            print(f"{i} of 1000")
        img = Image.open(img_path)
        frame = np.array(img)
        img = img.resize(size=[s // 2 for s in img.size])

        batch_boxes, probs = detector.detect(img, landmarks=False)
        if batch_boxes is None:
            file_name.append(img_path.split('/')[-1])
            prob.append(0.5)
            label.append('real')
            continue

        faces = []
        scores = []

        for bbox, score in zip(batch_boxes, probs):
            if bbox is not None:
                xmin, ymin, xmax, ymax = [int(b * 2) for b in bbox]
                w = xmax - xmin
                h = ymax - ymin
                p_h = h // 3
                p_w = w // 3
                crop = frame[max(ymin - p_h, 0):ymax + p_h, max(xmin - p_w, 0):xmax + p_w]
                faces.append(crop)
                scores.append(score)

        batch_size=len(faces)
        input_size=384
        x = np.zeros((batch_size,input_size,input_size, 3), dtype=np.uint8)
        #print(x.shape)
        n = 0
        for face in faces:
            resized_face = isotropically_resize_image(face, input_size)
            resized_face = put_to_center(resized_face, input_size)
            x[n] = resized_face
            n += 1

        x = torch.tensor(x, device=cuda).float()

        x = x.permute((0, 3, 1, 2))
        for i in range(len(x)):
            x[i] = normalize_transform(x[i] / 255.)
        #print(x.shape)
        y_pred = model(x)
        y_pred = torch.sigmoid(y_pred.squeeze())
        y_pred = y_pred.cpu().numpy()
        y_pred = y_pred.mean()
        file_name.append(img_path.split('/')[-1])
        prob.append(y_pred)
        if(y_pred >= 0.5):
            label.append('fake')
        else:
            label.append('real')

0 of 1000
100 of 1000
200 of 1000
300 of 1000
400 of 1000
500 of 1000
600 of 1000
700 of 1000
800 of 1000
900 of 1000


In [10]:
b = pd.DataFrame()
b['file_name'] = file_name
b['label'] = label
b['score'] = prob
b

Unnamed: 0,file_name,label,score
0,0000.png,real,0.468024
1,0001.png,real,0.007693
2,0002.png,real,0.011367
3,0003.png,real,0.013138
4,0004.png,fake,0.980663
...,...,...,...
995,0995.png,fake,0.991924
996,0996.png,fake,0.979821
997,0997.png,fake,0.987997
998,0998.png,real,0.027533


In [11]:
freq = b.groupby(['label']).count()
print(freq)

       file_name  score
label                  
fake         370    370
real         630    630


In [12]:
result = dict()

In [13]:
for i in range(len(file_name)):
    result[file_name[i]] = label[i]

In [14]:
print(a)

    file_name label     score
0    0000.png  fake  0.983398
1    0001.png  real  0.017853
2    0002.png  real  0.007935
3    0003.png  real  0.036987
4    0004.png  fake  0.982422
..        ...   ...       ...
995  0995.png  fake  0.971191
996  0996.png  real  0.043610
997  0997.png  fake  0.981934
998  0998.png  real  0.008186
999  0999.png  fake  0.983398

[1000 rows x 3 columns]


In [15]:
import json
with open("benchmark_deit.json", "w") as outfile:
    json.dump(result,outfile)

In [16]:
c = pd.DataFrame()

In [17]:
import math
p = 0.644
w = math.log(p/(1-p))
print(w)
p2 = 0.620
w2 = math.log(p2/(1-p2))
print(w2)
r_w = w/(w+w2)*1
r_w2 = w2/(w+w2)*1
print(r_w)
print(r_w2)

0.5927679952523233
0.4895482253187058
0.5476846636739672
0.4523153363260329


In [69]:
import math
p = 0.636
w = math.log(p/(1-p))
print(w)
p2 = 0.836
w2 = math.log(p2/(1-p2))
print(w2)
r_w = w/(w+w2)
r_w2 = w2/(w+w2)
print(r_w)
print(r_w2)

0.5580446957033814
1.6287621852605028
0.28070570411007295
0.8192942958899271


In [70]:
c['file_name'] = a['file_name']
c['score'] = a['score'] *r_w + b['score'] * r_w2

In [71]:
ensembled_label = []
for i in c['score']:
    if i > 0.5:
        ensembled_label.append('fake')
    else:
        ensembled_label.append('real')

In [72]:
len(ensembled_label)

1000

In [73]:
ensembled_result = dict()

In [74]:
for i in range(len(file_name)):
    ensembled_result[file_name[i]] = ensembled_label[i]

In [75]:
with open("benchmark_ensembled3.json", "w") as outfile:
    json.dump(ensembled_result,outfile)