# Thresholding
This notebook will guide you through finding the optimal threshold hyperparameters for your data and models.

In [1]:
import torch
# from ..utils import *
import utils
import nets_LV
import torch.optim as optim
import torchvision.transforms as transforms
import numpy as np
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from tqdm import tqdm
import copy
from sklearn.metrics import precision_score, accuracy_score, average_precision_score, roc_auc_score
import os
from collections import defaultdict
import json
from PIL import Image

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjonas-van-elburg[0m ([33mvqvaeanomaly[0m). Use [1m`wandb login --relogin`[0m to force relogin
  Referenced from: '/Users/jonase/opt/miniconda3/envs/lsr_mood/lib/python3.9/site-packages/torchvision/image.so'
  Expected in: '/Users/jonase/opt/miniconda3/envs/lsr_mood/lib/python3.9/site-packages/torch/lib/libtorch_cpu.dylib'
  warn(f"Failed to load image Python extension: {e}")


In [2]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor()
])


Set your paths here. Make sure your data is structured as follows:

- data
    - FaceForensics
        - test_set
            - fake
                - 000_003
                    - face_0_0_10.jpg
                    - ...
                - ...
            - real
                - 000
                - ...
         - validation_set
            - ...
    - FFHQ
        - test_set
            - easy
                - 1 # indicating fake
                    - 1_1110.png
                    - ...
                - 0 # indicating real
                    - 51010.png
                    - ...
            - medium
                - ...
            - hard
                - ...
        - validation_set
            - ...
    - Output
        - faceforensics
        - ffhq
    - Checkpoints
        - ffhq_vqvae.pt
        - ffhq_ar.pt
        - faceforensics_vqvae.pt
        - faceforensics_ar.pt

and the checkpoints are named {dataset}_{model}.pt for dataset in ['ffhq', 'faceforensics'] and model in ['vqvae', 'ar']

In [13]:
# CHOOSE DATASET HERE AND SET PATHS AS DESCRIBED, THEN NONE OF THE CODE BELOW NEEDS TO BE CHANGED
data_dir = '../../data'
checkpoint_path = 'checkpoints'
dataset = 'faceforensics' # choose from ['ffhq', 'faceforensics']
mode = 'sample' # choose from ['sample', 'pixel']
split = 'val' # choose from ['val', 'test']
test = False # choose from [True, False]. If False, the test set is used
thresholds = [2, 4, 6, 8] # Only input 1 if testing
difficulties = ['easy', 'medium', 'hard'] # only used for ffhq
assert dataset in ['ffhq', 'faceforensics']
txt_to_label = {'real': 0, 'fake': 1}

if dataset == 'faceforensics':
    ff_sets = {}
    ff_loaders = {}
    for fr in ['real', 'fake']:
      ff_sets[fr] = ImageFolder(os.path.join(data_dir, dataset, f'{split}_set_pruned', fr), transform=transform)
      ff_loaders[fr] = DataLoader(ff_sets[fr], batch_size=64, shuffle=False)

elif dataset == "ffhq":
    ffhq_test_sets = {}
    ffhq_test_loaders = {}
    for dif in difficulties:
      ffhq_test_sets[dif]= ImageFolder(os.path.join(data_dir, dataset, split, dif), transform=transform)
      ffhq_test_loaders[dif] = DataLoader(ffhq_test_sets[dif], batch_size=64, shuffle=False)


In [14]:
vqvae_cp = torch.load(os.path.join(checkpoint_path, dataset + '_vqvae.pt'), map_location=device)
ar_cp = torch.load(os.path.join(checkpoint_path, dataset + '_ar.pt'), map_location=device)

In [15]:
# device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
vq_model = nets_LV.VQVAE(
    d=3,
    n_channels=(16, 32, 64, 256),
    code_size=128,
    n_res_block=2,
    dropout_p=.1
).to(device)

vq_model.load_state_dict(vqvae_cp["model"])
vq_model = vq_model.to(device)

ar_model = nets_LV.VQLatentSNAIL(
    feature_extractor_model=vq_model,
    shape=(16, 16),
    n_block=4,
    n_res_block=4,
    n_channels=128
).to(device)


ar_model.load_state_dict(ar_cp['model'])
ar_model = ar_model.to(device)

## Generate anomaly scores for various thresholds

In [16]:
!find ../../data -name ".DS_Store" -type f -delete

In [1]:
print(f"Evaluating {dataset} with thresholds {thresholds}")
input_dir = data_dir + f'/{dataset}'
output_dir = data_dir + '/output'

# bruh moment right here
predictions = defaultdict(lambda: defaultdict())


if dataset == "ffhq":
    for thr in thresholds:
        predictions[thr] = {}
        with torch.no_grad():
          for dif in ['easy', 'hard']:
            predictions[thr][dif] = {'scores':[],'labels':[]}
            for batch, cl in tqdm(ffhq_test_loaders[dif]):
              loss = ar_model.loss(batch.to(device), reduction="none")["loss"].flatten(1)
              scores = torch.sum(loss * (loss > thr), 1).float()
              predictions[thr][dif]['scores'].append(scores.detach().cpu().numpy().tolist())
              predictions[thr][dif]['labels'].append(cl.detach().cpu().numpy().tolist())
            assert len(predictions[thr][dif]['scores']) == len(predictions[thr][dif]['labels'])
    with open(os.path.join(output_dir, "ffhq", "scores.json"), "w") as write_file:
        json.dump(predictions, write_file)


elif dataset == "faceforensics":
    input_dir = os.path.join(data_dir, dataset, f'{split}_set_pruned')
    for thr in thresholds:
        predictions[thr]['scores'] = []
        predictions[thr]['labels'] = []
        with torch.no_grad():
            for rf in os.listdir(input_dir):
              label = txt_to_label[rf] # 0 for real, 1 for fake, hacky way to get the labels because imagefolder doesn't give the correct labels
              loader = ff_loaders[rf]
              for t, (batch, _) in enumerate(tqdm(loader)):
                  if t > 4:
                     continue
                  batch = batch.to(device)
                  loss = ar_model.loss(batch, reduction="none")["loss"].flatten(1)
                  scores = torch.sum(loss * (loss > thr), 1).float()
                  labels = [label] * len(scores)
                  predictions[thr]['scores'].append(scores.detach().cpu().numpy().tolist())
                  predictions[thr]['labels'].append(labels)
    # make sure your output directory exists before running                  
    with open(os.path.join(output_dir, "faceforensics","scores.json"), "w") as write_file:
        json.dump(predictions, write_file)

    

IndentationError: expected an indented block (1736643887.py, line 31)

## Sample wise score threshold

In [19]:
with open(output_dir + f"/{dataset}/scores.json") as s:
    scores = json.load(s)


if dataset == 'ffhq':
  res = defaultdict(lambda: defaultdict(lambda: {'auroc': None, 'ap': None}))
  for thr in thresholds:
    thr = str(thr)
    for dif in ['easy', 'hard']:
      labels = np.concatenate(scores[thr][dif]['labels'])
      preds = np.concatenate(scores[thr][dif]['scores'])

      auroc = roc_auc_score(labels,preds)
      ap = average_precision_score(labels,preds, pos_label=1)

      res[thr][dif]['auroc'] = auroc
      res[thr][dif]['ap'] = ap

elif dataset == 'faceforensics':
  res = defaultdict(lambda: {'auroc' : None, 'ap': None})
  for thr in thresholds:
    thr = str(thr)
    labels = np.concatenate(scores[thr]['labels'])
    preds = np.concatenate(scores[thr]['scores'])

    auroc = roc_auc_score(labels,preds)
    ap = average_precision_score(labels,preds, pos_label=1)

    res[thr]['auroc'] = auroc
    res[thr]['ap'] = ap

print(res.items())


dict_items([('7', defaultdict(<function <lambda>.<locals>.<lambda> at 0x13a6b60d0>, {'easy': {'auroc': 0.051867219917012444, 'ap': 0.9977379077086245}, 'hard': {'auroc': 0.1584375, 'ap': 0.8615225794547787}})), ('8', defaultdict(<function <lambda>.<locals>.<lambda> at 0x16ea6ca60>, {'easy': {'auroc': 0.11659751037344399, 'ap': 0.9947267227046538}, 'hard': {'auroc': 0.18838541666666667, 'ap': 0.829091633201088}}))])


Make sure to save the best threshold!

## OLD STUFF BELOW

In [37]:

thresholds = np.linspace(5,15,3)
print(thresholds)
input_dir = data_dir + f'/{dataset}'
output_dir = data_dir + '/output'

# bruh moment right here
predictions = defaultdict(lambda: defaultdict())

if dataset == "ffhq":
    for thr in thresholds:
        with torch.no_grad():            
            for difficulty in os.listdir(input_dir):
                for real_or_fake in os.listdir(os.path.join(input_dir, difficulty)):
                    for file_id in tqdm(os.listdir(os.path.join(input_dir, difficulty, real_or_fake))):
                        img = Image.open(os.path.join(input_dir, difficulty, real_or_fake, file_id))
                        img = torch.unsqueeze(transform(img).to(device), 0)
                        loss = ar_model.loss(img, reduction="none")["loss"].flatten(1)
                        scores = torch.sum(loss * (loss > thr), 1).float()
                        score = scores.sum()
                        predictions[thr][difficulty][real_or_fake][file_id] = score.detach().cpu().numpy().tolist()
    with open(os.path.join(output_dir, "ffhq", "scores.json"), "w") as write_file:
        json.dump(predictions, write_file)

# 0 = real, 1 = fake
elif dataset == "faceforensics":
    input_dir = data_dir + '/faceforensics/test_set'
    for thr in thresholds:
        predictions[thr][0] = []
        predictions[thr][1] = []
        with torch.no_grad():
            for batch, cl in tqdm(test_dataloader): 
                # assert all(x == cl[0] for x in cl)
                cl = cl.detach().cpu().numpy().tolist()[0]
                # img = Image.open(os.path.join(input_dir, real_or_fake, vid, file_id))
                # img = torch.unsqueeze(transform(img).to(device), 0)
                loss = ar_model.loss(batch, reduction="none")["loss"].flatten(1)
                scores = torch.sum(loss * (loss > thr), 1).float()
                score = scores.sum()
                predictions[thr][cl].append(score.detach().cpu().numpy().tolist())
    with open(os.path.join(output_dir, "faceforensics","scores_sample.json"), "w") as write_file:
        json.dump(predictions, write_file)

    

[ 5. 10. 15.]


  2%|▏         | 6/327 [01:40<1:29:11, 16.67s/it]


KeyboardInterrupt: 

## Evaluation

In [None]:
pred = []
accuracy = []
labels = []
ar_model.eval()

# 0 is fake
# 1 is real
thresholds = np.linspace(1,20,2)

loaded_dataloader = tqdm(val_dataloader)
print('Starting the measurements')
for thr in thresholds:
    print('for threshold', thr)
    temp_preds = []
    temp_labels = []
    for batchX, batchY in loaded_dataloader:
        batchX = batchX.to(device)
        batchY = batchY.to(device)
        with torch.no_grad():
            loss = ar_model.loss(batchX, reduction='none')['loss'].flatten(1)

            score = torch.sum(loss*(loss>thr), 1).float()
            temp_preds.extend(score.cpu().numpy())
            temp_labels.extend(batchY.cpu().numpy())

    pred.append(temp_preds)
    labels.append(temp_labels)

    


## Find best threshold

In [None]:

acc_per_thr = []
prec_per_thr = []
big_threshold = []

for threshold_idx, prediction in enumerate(pred):
    pred_copy = np.array(copy.deepcopy(prediction))
    real_preds = pred_copy[np.array(labels)[threshold_idx] == 1]
    fake_preds = pred_copy[np.array(labels)[threshold_idx] == 0]
    
    max_real = np.max(real_preds)
    min_fake = np.min(fake_preds)
    
    avg_thr = np.mean([max_real, min_fake])
    big_threshold.append(avg_thr)
    #Below are real
    pred_copy[pred_copy <= avg_thr] = 1
    
    #Above are fake
    pred_copy[pred_copy > avg_thr] = 0
    acc_per_thr.append(accuracy_score(labels[threshold_idx], pred_copy))
    prec_per_thr.append(precision_score(labels[threshold_idx], pred_copy))


best_threshold = np.argmax(acc_per_thr)
print('the mean for the big threshold:', np.mean(big_threshold))
print('best threshold value is', thresholds[best_threshold])
    

# Find pixel score threshold

In [None]:
def reconstruct(n, img, threshold_log_p = 5):
    """ Generates n reconstructions for each image in img.
    Resamples latent variables with cross-entropy > threshold
    Returns corrected images and associated latent variables"""
          
    #Use VQ-VAE to encode original image
    codes = ar_model.retrieve_codes(img)
    code_size = codes.shape[-2:]
    

    with torch.no_grad():
        samples = codes.clone().unsqueeze(1).repeat(1,n,1,1).reshape(img.shape[0]*n,*code_size)

        if not threshold_log_p == None:
            for r in tqdm(range(code_size[0])):
                for c in range(code_size[1]):        

                    code_logits = ar_model.forward_latent(samples)[:,:,r,c]
                    loss = F.cross_entropy(code_logits, samples[:, r, c], reduction='none')
                    probs = F.softmax(code_logits, dim=1)

                    samples[loss > threshold_log_p, r, c] = torch.multinomial(probs, 1).squeeze(-1)[loss > threshold_log_p]

        z = vq_model.codebook.embedding(samples.unsqueeze(1))
        z = z.squeeze(1).permute(0,3,1,2).contiguous()
        
        # Split the calculation in batches
        x_tilde = []
        for i in range(img.shape[0]):
            x_tilde.append(vq_model.decode(z[i*n:(i+1)*n]))
        x_tilde = torch.cat(x_tilde)
        
        
    return x_tilde.reshape(img.shape[0]*img.shape[1],n,*img.shape[-2:]), samples.reshape(img.shape[0],n,*code_size)

reconstructions = []

# for X,_ in loaded_dataloader:
#     x_tilde, latent_sample = reconstruct(n=15,img=X, threshold_log_p=5)
#     reconstructions.append(x_tilde)

## 5 images experiment

In [None]:
import matplotlib.pyplot as plt

In [None]:
X = next(iter(pixel_score_loader))[0]
reconstructionmax = reconstruct(n=5,img=X, threshold_log_p=None)[0]
reconstruction8 = reconstruct(n=5,img=X, threshold_log_p=8)[0]
reconstruction9 = reconstruct(n=5,img=X, threshold_log_p=9)[0]

In [None]:
diffrec = torch.abs(X[0] - torch.mean(reconstructionmax, dim=1))
diff8 = torch.mean(torch.abs(reconstructionmax - reconstruction8), dim=1)
diff9 = torch.mean(torch.abs(reconstructionmax - reconstruction9), dim=1)
reconstructionmax = torch.mean(reconstructionmax, dim=1)
reconstruction8 = torch.mean(reconstruction8, dim=1)
reconstruction9 = torch.mean(reconstruction9, dim=1)

fig, axes = plt.subplots(2,4, figsize=(20,10))

axes[0,0].imshow(X[0].permute(1,2,0))
axes[0,0].set_title('Original')
axes[0,1].imshow(reconstructionmax.squeeze().permute(1,2,0))
axes[0,1].set_title('No resampling')
axes[0,2].imshow(reconstruction8.squeeze().permute(1,2,0))
axes[0,2].set_title('Threshold 8')
axes[0,3].imshow(reconstruction9.squeeze().permute(1,2,0))
axes[0,3].set_title('Threshold 9')
axes[1,1].imshow(diffrec.squeeze().permute(1,2,0))
axes[1,1].set_title('|original - reconstruction|')
axes[1,2].imshow(diff8.squeeze().permute(1,2,0))
axes[1,2].set_title('|no resampling - threshold 8|')
axes[1,3].imshow(diff9.squeeze().permute(1,2,0))
axes[1,3].set_title('|no resampling - threshold 9|')

plt.tight_layout()
plt.show()
