In [8]:
import torch
import utils
import nets_LV
import torch.optim as optim
import torchvision.transforms as transforms
import numpy as np
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from tqdm import tqdm
import copy
from sklearn.metrics import precision_score, accuracy_score

In [2]:
transform_pipeline = transforms.Compose([
    transforms.Resize((128, 128)),
    # transforms.RandomHorizontalFlip(),
    # transforms.ColorJitter(brightness=.3, hue=.2),
    # transforms.RandomRotation(degrees=(-10, 10)),
    transforms.ToTensor()
])

In [3]:
device = torch.device('cuda')

In [4]:
lambda_set = '../../../../../project/gpuuva022/shared/AnomalyDetection/FFHQ_Data/FFHQ_data/l_tuning'

images_pixel_score_path = '/home/lcur1737/AnomalyDetection/data/thresholding_images'
val_dataset = ImageFolder(lambda_set, transform=transform_pipeline)
val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=True)

In [5]:
vqvae_checkpoint_path = '/home/lcur1737/AnomalyDetection/src/checkpoints/VQVAE/ffhq_continued_020.pt'
ar_checkpoint_path = '/home/lcur1737/AnomalyDetection/src/checkpoints/AR/ffhq_ar_030.pt'

In [6]:
vq_model = nets_LV.VQVAE(
    d=3,
    n_channels=(16, 32, 64, 256),
    code_size=128,
    n_res_block=2,
    dropout_p=.1
)

vqvae_checkpoint = torch.load(vqvae_checkpoint_path)
vq_model.load_state_dict(vqvae_checkpoint["model"])
vq_model = vq_model.to(device)

ar_model = nets_LV.VQLatentSNAIL(
    feature_extractor_model=vq_model,
    shape=(16, 16),
    n_block=4,
    n_res_block=4,
    n_channels=128
)

ar_checkpoint = torch.load(ar_checkpoint_path)

ar_model.load_state_dict(ar_checkpoint['model'])

ar_model = ar_model.to(device)

# Sample wise score threshold

In [7]:
pred = []
accuracy = []
labels = []
ar_model.eval()

# 0 is fake
# 1 is real
thresholds = np.linspace(1,20,2)

loaded_dataloader = tqdm(val_dataloader)
print('Starting the measurements')
for thr in thresholds:
    print('for threshold', thr)
    temp_preds = []
    temp_labels = []
    for batchX, batchY in loaded_dataloader:
        batchX = batchX.to(device)
        batchY = batchY.to(device)
        with torch.no_grad():
            loss = ar_model.loss(batchX, reduction='none')['loss'].flatten(1)

            score = torch.sum(loss*(loss>thr), 1).float()
            temp_preds.extend(score.cpu().numpy())
            temp_labels.extend(batchY.cpu().numpy())

    pred.append(temp_preds)
    labels.append(temp_labels)

    


  0%|          | 0/8 [00:00<?, ?it/s]

Starting the measurements
for threshold 1.0


100%|██████████| 8/8 [03:05<00:00, 23.18s/it]


for threshold 20.0


## Find best threshold

In [20]:

acc_per_thr = []
prec_per_thr = []

for threshold_idx, prediction in enumerate(pred):
    pred_copy = np.array(copy.deepcopy(prediction))
    real_preds = pred_copy[np.array(labels)[threshold_idx] == 1]
    fake_preds = pred_copy[np.array(labels)[threshold_idx] == 0]
    
    max_real = np.max(real_preds)
    min_fake = np.min(fake_preds)
    
    avg_thr = np.mean([max_real, min_fake])
    #Below are real
    pred_copy[pred_copy <= avg_thr] = 1
    
    #Above are fake
    pred_copy[pred_copy > avg_thr] = 0
    acc_per_thr.append(accuracy_score(labels[threshold_idx], pred_copy))
    prec_per_thr.append(precision_score(labels[threshold_idx], pred_copy))


best_threshold = np.argmax(acc_per_thr)
print('best threshold value is', thresholds[best_threshold])
    

best threshold value is 1.0


  _warn_prf(average, modifier, msg_start, len(result))


# Find pixel score threshold

In [7]:
def reconstruct(n, img, threshold_log_p = 5):
    """ Generates n reconstructions for each image in img.
    Resamples latent variables with cross-entropy > threshold
    Returns corrected images and associated latent variables"""
          
    #Use VQ-VAE to encode original image
    codes = ar_model.retrieve_codes(img)
    code_size = codes.shape[-2:]
    
    with torch.no_grad():
        samples = codes.clone().unsqueeze(1).repeat(1,n,1,1).reshape(img.shape[0]*n,*code_size)
        logits = ar_model.forward_latent(samples)

        for r in range(code_size[0]):
            for c in range(code_size[1]):
                code_logits = logits[:, :, r, c]

                loss = F.cross_entropy(code_logits, samples[:, r, c], reduction='none')
                probs = F.softmax(code_logits, dim=1)

                samples[loss > threshold_log_p, r, c] = torch.multinomial(probs, 1).squeeze(-1)[loss > threshold_log_p]

        z = vq_model.codebook.embedding(samples.unsqueeze(1))
        z = z.squeeze(1).permute(0,3,1,2).contiguous()
        
        # Split the calculation in batches
        x_tilde = []
        for i in range(img.shape[0]):
            x_tilde.append(vq_model.decode(z[i*n:(i+1)*n]))
        x_tilde = torch.cat(x_tilde)
        
        
    return x_tilde.reshape(img.shape[0],n,*img.shape[-2:]), samples.reshape(img.shape[0],n,*code_size)

reconstructions = []

# for X,_ in loaded_dataloader:
#     x_tilde, latent_sample = reconstruct(n=15,img=X, threshold_log_p=5)
#     reconstructions.append(x_tilde)

## 5 images experiment

In [11]:
print(len(val_dataloader.dataset))
# for x, _ in loaded_dataloader:

504
