# Thresholding
This notebook guides you through evaluating your models, either to find the optimal threshold $\lambda_s$ or to obtain final results. Before running it, make sure your data is structured as described in `readme.md` and that the checkpoints are named `{dataset}_{model}.pt` (for `dataset` in `['ffhq', 'faceforensics']` and `model` in `['vqvae', 'ar']`).

In [4]:
import torch
# from ..utils import *
import utils
import nets_LV
import torch.optim as optim
import torchvision.transforms as transforms
import numpy as np
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from tqdm import tqdm
import copy
from sklearn.metrics import precision_score, accuracy_score, average_precision_score, roc_auc_score
import os
from collections import defaultdict
import json
from PIL import Image

In [5]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Preprocessing applied to every image: resize to 128x128, then convert to a tensor.
transform = transforms.Compose([transforms.Resize((128, 128)), transforms.ToTensor()])

# Integer label for each class name: real images are 0, fake images are 1.
txt_to_label = dict(real=0, fake=1)


Define your paths and evaluation parameters here. To find the optimal threshold, use the `'val'` split and define `thresholds` as a list of candidates (e.g. `[5, 6, 7, 8, 9, 10]`); set `pruned` to `True` to speed up the search. To evaluate final results, use the `'test'` split, define `thresholds` as a one-element list $[\lambda_s]$, and set `pruned` to `False`.

In [9]:
# CHOOSE DATASET HERE AND SET PATHS AS DESCRIBED, THEN NONE OF THE CODE BELOW NEEDS TO BE CHANGED
data_dir = '../../data'
checkpoint_path = 'checkpoints'
dataset = 'faceforensics'  # one of: 'ffhq', 'faceforensics'
# mode = 'sample' # choose from ['sample', 'pixel']
split = 'test'  # one of: 'val', 'test'
pruned = False  # True selects the smaller faceforensics dataset, e.g. for hyperparameter tuning
thresholds = [7]  # give exactly one value when testing
difficulties = ['easy', 'medium', 'hard']  # only used for ffhq

# Fail fast on a mistyped configuration value.
assert split in ('val', 'test')
assert dataset in ('ffhq', 'faceforensics')
assert set(difficulties) <= {'easy', 'medium', 'hard'}

# Suffix string that mirrors the `pruned` flag ("_pruned" or empty).
pr = "_pruned" if pruned else ""

# FFHQ is served through dataloaders built here (one per difficulty).
# FaceForensics is read directly from the filesystem later on.
if dataset == "ffhq":
    ffhq_sets = {
        dif: ImageFolder(os.path.join(data_dir, dataset, split, dif), transform=transform)
        for dif in difficulties
    }
    ffhq_loaders = {
        dif: DataLoader(ffhq_set, batch_size=64, shuffle=False)
        for dif, ffhq_set in ffhq_sets.items()
    }


In [10]:
# Read the two checkpoints for the selected dataset, '<dataset>_vqvae.pt' and
# '<dataset>_ar.pt', mapping every stored tensor onto the active device.
vqvae_cp, ar_cp = (
    torch.load(os.path.join(checkpoint_path, f'{dataset}_{model_name}.pt'), map_location=device)
    for model_name in ('vqvae', 'ar')
)

In [11]:
# Rebuild the VQ-VAE and restore its trained weights. The hyperparameters below
# have to match the checkpoint, otherwise load_state_dict raises on mismatched keys/shapes.
vq_model = nets_LV.VQVAE(
    d=3,
    n_channels=(16, 32, 64, 256),
    code_size=128,
    n_res_block=2,
    dropout_p=.1
).to(device)
vq_model.load_state_dict(vqvae_cp["model"])

# The AR model takes the VQ-VAE as its feature extractor.
ar_model = nets_LV.VQLatentSNAIL(
    feature_extractor_model=vq_model,
    shape=(16, 16),
    n_block=4,
    n_res_block=4,
    n_channels=128
).to(device)
ar_model.load_state_dict(ar_cp['model'])

# This notebook only evaluates. Switch both networks to eval mode so that dropout
# (the VQ-VAE is built with dropout_p=.1) and any other train-only behaviour is
# disabled. torch.no_grad() alone does not do this, and without eval() the anomaly
# scores would change from run to run.
vq_model.eval()
ar_model.eval()

## Generate anomaly scores for various thresholds
Here we use the VQ-VAE model to compute representations of the images in the val/test set. The AR model assigns likelihoods to these representations, which are used to calculate anomaly scores: for each image, the score is the sum of the AR loss values that exceed the threshold $\lambda_s$. See the methods section of the blog post for more information.

In [13]:
# If you are using a Mac, run this command to delete the .DS_Store files
# !find ../../data -name ".DS_Store" -type f -delete

In [None]:
print(f"Evaluating {dataset} with thresholds {thresholds}")
input_dir = data_dir + f'/{dataset}'
output_dir = data_dir + '/output'


def thresholded_scores(loss, thr):
    """Return one anomaly score per row of `loss`.

    Only entries strictly greater than `thr` contribute; they are summed along dim 1.
    """
    return torch.sum(loss * (loss > thr), 1).float()


def load_image(path):
    """Open the image at `path`, force three RGB channels and apply `transform`.

    The RGB conversion mirrors what ImageFolder's default loader does for the FFHQ
    branch; the context manager closes the file handle once the tensor is built.
    """
    with Image.open(path) as img:
        return transform(img.convert('RGB'))


# The model loss does not depend on the threshold, so it is computed once per batch
# and re-thresholded for every candidate. dict.fromkeys drops duplicates, keeps order.
unique_thresholds = list(dict.fromkeys(thresholds))
predictions = {}

with torch.no_grad():
    if dataset == "ffhq":
        unique_difficulties = list(dict.fromkeys(difficulties))
        for thr in unique_thresholds:
            predictions[thr] = {dif: {'scores': [], 'labels': []} for dif in unique_difficulties}
        for dif in unique_difficulties:
            # NOTE(review): ImageFolder numbers classes by sorted folder name; confirm
            # that the anomalous (fake) class ends up with label 1 for this dataset.
            for batch, cl in tqdm(ffhq_loaders[dif]):
                loss = ar_model.loss(batch.to(device), reduction="none")["loss"].flatten(1)
                batch_labels = cl.tolist()
                for thr in unique_thresholds:
                    # One list per batch; the evaluation cell concatenates them.
                    predictions[thr][dif]['scores'].append(thresholded_scores(loss, thr).cpu().tolist())
                    predictions[thr][dif]['labels'].append(batch_labels)

    elif dataset == "faceforensics":
        # Honour the `pruned` switch: `pr` is "_pruned" or "" (set in the config cell).
        # NOTE(review): assumes the full split lives in '<split>_set' - confirm against readme.md.
        input_dir = os.path.join(data_dir, dataset, f'{split}_set{pr}')
        for thr in unique_thresholds:
            predictions[thr] = {'scores': [], 'labels': []}
        for rf in os.listdir(input_dir):
            for img_id in tqdm(os.listdir(os.path.join(input_dir, rf))):
                frame_dir = os.path.join(input_dir, rf, img_id)
                frame_names = os.listdir(frame_dir)
                if not frame_names:
                    # torch.stack cannot handle an empty list; nothing to score here.
                    continue
                batch = torch.stack(
                    [load_image(os.path.join(frame_dir, name)) for name in frame_names]
                ).to(device)
                loss = ar_model.loss(batch, reduction="none")["loss"].flatten(1)
                for thr in unique_thresholds:
                    # Mean version: average the per-image scores of one folder into one score.
                    predictions[thr]['scores'].append(thresholded_scores(loss, thr).mean().item())
                    predictions[thr]['labels'].append(txt_to_label[rf])

# Write the scores of the dataset that was actually evaluated (and only that one),
# creating the output folder on first use. JSON turns the threshold keys into strings.
scores_path = os.path.join(output_dir, dataset, "scores.json")
os.makedirs(os.path.dirname(scores_path), exist_ok=True)
with open(scores_path, "w") as write_file:
    json.dump(predictions, write_file)
               

    

## Evaluate results per threshold

In [None]:
# Load the anomaly scores written by the previous cell for the selected dataset.
with open(os.path.join(output_dir, dataset, "scores.json")) as s:
    scores = json.load(s)


def auroc_and_ap(labels, preds):
    """Return {'auroc': ..., 'ap': ...} for binary `labels` and anomaly scores `preds`.

    Label 1 is treated as the positive class for average precision.
    """
    return {
        'auroc': roc_auc_score(labels, preds),
        'ap': average_precision_score(labels, preds, pos_label=1),
    }


res = {}
for thr in thresholds:
    key = str(thr)  # JSON object keys are strings, so thresholds are looked up as text
    if dataset == 'ffhq':
        # Evaluate every configured difficulty (previously hard-coded to easy/hard).
        # Scores and labels are stored as one list per batch, hence the concatenate.
        res[key] = {
            dif: auroc_and_ap(
                np.concatenate(scores[key][dif]['labels']),
                np.concatenate(scores[key][dif]['scores']),
            )
            for dif in difficulties
        }
    elif dataset == 'faceforensics':
        res[key] = auroc_and_ap(scores[key]['labels'], scores[key]['scores'])

for thr in thresholds:
    print(f"threshold = {thr}")
    print(res[str(thr)].items())