# Thresholding
This notebook will guide you through finding the optimal threshold hyperparameters for your data and models.

In [1]:
import torch
# from ..utils import *
import utils
import nets_LV
import torch.optim as optim
import torchvision.transforms as transforms
import numpy as np
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from tqdm import tqdm
import copy
from sklearn.metrics import precision_score, accuracy_score, average_precision_score, roc_auc_score
import os
from collections import defaultdict
import json
from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmartentyrk[0m ([33mvqvaeanomaly[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Preprocessing applied to every image: resize to 128x128, then convert to a tensor.
transform = transforms.Compose([transforms.Resize((128, 128)), transforms.ToTensor()])


Set your paths here. Make sure your data is structured as follows:

- data
    - FaceForensics
        - test_set
            - fake
                - 000_003
                    - face_0_0_10.jpg
                    - ...
                - ...
            - real
                - 000
                - ...
        - validation_set
            - ...
    - FFHQ
        - test_set
            - easy
                - 1 # indicating fake
                    - 1_1110.png
                    - ...
                - 0 # indicating real
                    - 51010.png
                    - ...
            - medium
                - ...
            - hard
                - ...
        - validation_set
            - ...
    - Output
        - faceforensics
        - ffhq
    - Checkpoints
        - ffhq_vqvae.pt
        - ffhq_ar.pt
        - faceforensics_vqvae.pt
        - faceforensics_ar.pt

The checkpoints must be named `{dataset}_{model}.pt`, with `dataset` in `['ffhq', 'faceforensics']` and `model` in `['vqvae', 'ar']`.

In [None]:
# CHOOSE DATASET HERE AND SET PATHS AS DESCRIBED, THEN NONE OF THE CODE BELOW NEEDS TO BE CHANGED
data_dir = '../../data'
checkpoint_path = 'checkpoints'
dataset = 'faceforensics'  # choose from ['ffhq', 'faceforensics']
# mode = 'sample' # choose from ['sample', 'pixel']
split = 'test'  # choose from ['val', 'test']
pruned = False  # set to True if you want to use the smaller faceforensics dataset, e.g. for hyperparameter tuning
thresholds = [7]  # candidate thresholds; give a single value when testing
difficulties = ['easy', 'medium', 'hard']  # only used for ffhq
assert dataset in ['ffhq', 'faceforensics']

# Maps the class folder names to binary labels.
txt_to_label = {'real': 0, 'fake': 1}

# Directory-name suffix that selects the pruned variant of a split.
pr = "_pruned" if pruned else ""

if dataset == "ffhq":
    # One ImageFolder dataset and one (unshuffled) loader per difficulty level.
    # NOTE(review): the folder is read from <data_dir>/<dataset>/<split>/<difficulty>, while the
    # layout described in the markdown uses "<split>_set" -- confirm which one matches the data on disk.
    ffhq_sets = {
        dif: ImageFolder(os.path.join(data_dir, dataset, split, dif), transform=transform)
        for dif in difficulties
    }
    ffhq_loaders = {
        dif: DataLoader(ffhq_sets[dif], batch_size=64, shuffle=False)
        for dif in difficulties
    }


FaceForensics data is loaded without `ImageFolder`; see the prediction loop below.

In [5]:
# Load the two checkpoints (`vqvae` and `ar`) of the selected dataset from
# <checkpoint_path>/<dataset>_<model>.pt, mapping their tensors onto `device`.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
vqvae_cp, ar_cp = (
    torch.load(os.path.join(checkpoint_path, f'{dataset}_{model_name}.pt'), map_location=device)
    for model_name in ('vqvae', 'ar')
)

In [6]:
# device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Rebuild the VQ-VAE with the hyperparameters below and restore its weights from the checkpoint.
vq_model = nets_LV.VQVAE(
    d=3,
    n_channels=(16, 32, 64, 256),
    code_size=128,
    n_res_block=2,
    dropout_p=.1
).to(device)
vq_model.load_state_dict(vqvae_cp["model"])
# Freshly built modules are in training mode and torch.no_grad() does not change that:
# switch to eval mode so that dropout is disabled while scoring.
vq_model = vq_model.eval()

# Rebuild the latent model on top of the VQ-VAE and restore its weights as well.
ar_model = nets_LV.VQLatentSNAIL(
    feature_extractor_model=vq_model,
    shape=(16, 16),
    n_block=4,
    n_res_block=4,
    n_channels=128
).to(device)
ar_model.load_state_dict(ar_cp['model'])
# Inference only from here on: eval mode for the same reason as above.
ar_model = ar_model.eval()

## Generate anomaly scores for various thresholds

In [None]:
# Delete macOS ".DS_Store" files below the data directory so that the directory listings used
# for scoring only contain data. Walks the configured `data_dir` (rather than a hard-coded path)
# in pure Python, so the cell does not depend on a Unix `find` binary.
for _root, _dirs, _files in os.walk(data_dir):
    for _name in _files:
        _candidate = os.path.join(_root, _name)
        # Regular files only (not symlinks), like `find -type f`.
        if _name == ".DS_Store" and os.path.isfile(_candidate) and not os.path.islink(_candidate):
            os.remove(_candidate)

In [None]:
print(f"Evaluating {dataset} with thresholds {thresholds}")
input_dir = data_dir + f'/{dataset}'
output_dir = data_dir + '/output'

# Anomaly scores keyed by threshold. json.dump turns the threshold keys into strings,
# so readers of scores.json must index with str(threshold).
#   ffhq:          predictions[thr][difficulty] = {'scores': [one list per batch], 'labels': [one list per batch]}
#   faceforensics: predictions[thr]             = {'scores': [one float per folder], 'labels': [one int per folder]}
predictions = {}

if dataset == "ffhq":
    # Iterate over the configured difficulties (the same list the loaders were built from).
    predictions = {
        thr: {dif: {'scores': [], 'labels': []} for dif in difficulties}
        for thr in thresholds
    }
    with torch.no_grad():
        for dif in difficulties:
            for batch, cl in tqdm(ffhq_loaders[dif]):
                # The loss does not depend on the threshold: compute it once per batch
                # and apply every threshold to it.
                loss = ar_model.loss(batch.to(device), reduction="none")["loss"].flatten(1)
                for thr in thresholds:
                    # Image score = sum of the per-position losses that exceed the threshold.
                    scores = torch.sum(loss * (loss > thr), 1).float()
                    predictions[thr][dif]['scores'].append(scores.cpu().tolist())
                    predictions[thr][dif]['labels'].append(cl.tolist())

# mean version
elif dataset == "faceforensics":
    # `pr` is "_pruned" when the pruned dataset was requested in the config cell, "" otherwise.
    input_dir = os.path.join(data_dir, dataset, f'{split}_set{pr}')
    predictions = {thr: {'scores': [], 'labels': []} for thr in thresholds}
    with torch.no_grad():
        # Only the known class folders are visited; stray entries in input_dir are ignored.
        for rf, label in txt_to_label.items():
            class_dir = os.path.join(input_dir, rf)
            for img_id in tqdm(sorted(os.listdir(class_dir))):
                frame_dir = os.path.join(class_dir, img_id)
                frames = []
                for img in sorted(os.listdir(frame_dir)):
                    # Close each file after reading it and force 3 channels,
                    # as ImageFolder's default loader does for the ffhq data.
                    with Image.open(os.path.join(frame_dir, img)) as im:
                        frames.append(transform(im.convert('RGB')))
                if not frames:
                    # Nothing to score in an empty folder (torch.stack would fail on an empty list).
                    continue
                batch = torch.stack(frames).to(device)
                loss = ar_model.loss(batch, reduction="none")["loss"].flatten(1)
                for thr in thresholds:
                    scores = torch.sum(loss * (loss > thr), 1).float()
                    # One score per folder: the mean over the images it contains.
                    predictions[thr]['scores'].append(scores.mean().item())
                    predictions[thr]['labels'].append(label)

# Write the scores of the dataset that was actually evaluated (and only that one).
scores_dir = os.path.join(output_dir, dataset)
os.makedirs(scores_dir, exist_ok=True)
with open(os.path.join(scores_dir, "scores.json"), "w") as write_file:
    json.dump(predictions, write_file)
               

    

## Sample-wise score threshold

In [None]:
# Read back the scores written by the scoring cell; the JSON keys are the thresholds as strings.
with open(output_dir + f"/{dataset}/scores.json") as score_file:
    scores = json.load(score_file)


def _empty_metrics():
    """Placeholder entry for a result that has not been computed yet."""
    return {'auroc': None, 'ap': None}


if dataset == 'ffhq':
    # res[threshold][difficulty] -> {'auroc': ..., 'ap': ...}
    res = defaultdict(lambda: defaultdict(_empty_metrics))
    for thr_key in map(str, thresholds):
        for dif in ('easy', 'hard'):
            # Scores and labels are stored as one list per batch: flatten them first.
            labels = np.concatenate(scores[thr_key][dif]['labels'])
            preds = np.concatenate(scores[thr_key][dif]['scores'])
            metrics = {
                'auroc': roc_auc_score(labels, preds),
                'ap': average_precision_score(labels, preds, pos_label=1),
            }
            res[thr_key][dif].update(metrics)

elif dataset == 'faceforensics':
    # res[threshold] -> {'auroc': ..., 'ap': ...}
    res = defaultdict(_empty_metrics)
    for thr_key in map(str, thresholds):
        labels = scores[thr_key]['labels']
        preds = scores[thr_key]['scores']
        metrics = {
            'auroc': roc_auc_score(labels, preds),
            'ap': average_precision_score(labels, preds, pos_label=1),
        }
        res[thr_key].update(metrics)

# One line of results per threshold.
for thr in thresholds:
    print(res[str(thr)].items())

Make sure to save the best threshold!

# Find pixel score threshold

In [None]:
def reconstruct(n, img, threshold_log_p = 5):
    """Generate n reconstructions for each image in img.

    Every image is encoded into a grid of latent codes. Each latent position whose
    cross-entropy under the latent model exceeds threshold_log_p is resampled from
    the model's predicted distribution, and the resulting codes are decoded back to images.

    Args:
        n: number of reconstructions to generate per image.
        img: batch of input images; the first dimension is the batch and the last two
            are used as the spatial size of the returned reconstructions.
        threshold_log_p: cross-entropy above which a latent code is resampled.
            Pass None to skip resampling and decode the original codes.

    Returns:
        A tuple (x_tilde, samples):
        - x_tilde: the decoded images, reshaped to
          (img.shape[0] * img.shape[1], n, *img.shape[-2:]).
        - samples: the (possibly resampled) codes, reshaped to (img.shape[0], n, *code_size).

    NOTE(review): the x_tilde reshape folds img.shape[1] into the first dimension. If the
    decoder returns (batch * n, channels, H, W) with channels > 1, this does not line up
    with the decoder's memory layout -- confirm against the callers before relying on the
    per-image / per-sample axes.
    """
    # Everything here is inference only, so the encoding step is kept under no_grad too.
    with torch.no_grad():
        # Encode the original images into their latent codes.
        codes = ar_model.retrieve_codes(img)
        code_size = codes.shape[-2:]

        # n copies of every image's code grid, the copies of one image being contiguous.
        samples = codes.clone().unsqueeze(1).repeat(1, n, 1, 1).reshape(img.shape[0] * n, *code_size)

        if threshold_log_p is not None:
            # Visit the latent positions row by row, column by column; each step sees the
            # codes already replaced at the previously visited positions.
            for r in tqdm(range(code_size[0])):
                for c in range(code_size[1]):

                    code_logits = ar_model.forward_latent(samples)[:, :, r, c]
                    loss = F.cross_entropy(code_logits, samples[:, r, c], reduction='none')
                    probs = F.softmax(code_logits, dim=1)

                    # Replace only the codes considered too unlikely by the latent model.
                    resample = loss > threshold_log_p
                    samples[resample, r, c] = torch.multinomial(probs, 1).squeeze(-1)[resample]

        # Look up the codebook vectors of the codes and move the embedding axis to position 1.
        z = vq_model.codebook.embedding(samples.unsqueeze(1))
        z = z.squeeze(1).permute(0, 3, 1, 2).contiguous()

        # Decode one image's n samples at a time to bound the decoder batch size.
        x_tilde = torch.cat([vq_model.decode(z[i * n:(i + 1) * n]) for i in range(img.shape[0])])

    return x_tilde.reshape(img.shape[0]*img.shape[1],n,*img.shape[-2:]), samples.reshape(img.shape[0],n,*code_size)

reconstructions = []