## Create Image Pair Candidates for Editing

In [1]:
# General imports
import torch
import numpy as np
import os, sys
import json
from tqdm import tqdm
from sklearn.neighbors import NearestNeighbors

In [2]:
# Local imports
sys.path.insert(0, 'src')
from utils import read_json, read_lists, write_lists, load_image
from utils.knn_utils import _get_k_nearest_neighbors as get_k_nearest_neighbors
from utils.model_utils import prepare_device
from utils.visualizations import show_image_rows, show_image
from parse_config import ConfigParser
from data_loader import data_loaders
import datasets.datasets as module_data
import model.model as module_arch

In [3]:
# Define constants, paths
config_path = 'configs/cinic10_imagenet_edit_knn.json'
class_list_path = 'metadata/cinic-10/class_names.txt'
# Candidate single-class run (the loop further down covers every class instead):
# target_class =  3  # 5 (dog) is worst accuracy (40.914%) followed by 3 (cat) with 54.94%

# Seed every RNG this notebook can draw from, for reproducibility.
# NumPy alone is not enough: if the data loader shuffles, the sample order
# (and therefore the order of the saved path lists) comes from torch's RNG.
np.random.seed(0)
torch.manual_seed(0)

In [4]:
# Load config file, models
# The raw JSON dict is kept alongside the parsed ConfigParser object because
# later cells read dataset/loader settings straight from `config_json`.
config_json = read_json(config_path)
config = ConfigParser(config_json)

# `layernum` is forwarded to the model constructor below; its exact meaning is
# defined by the architecture class in `module_arch` (not visible here).
layernum = config.config['layernum']
# `device` is where input batches are sent later on; `device_ids` is not used
# in the visible cells of this notebook.
device, device_ids = prepare_device(config['n_gpu'])
print("Read in config file from {}".format(config_path))

# Build the architecture named under config['arch'] and put it in eval mode.
# NOTE(review): the model is never explicitly moved to `device` here, although
# inputs are sent there later; presumably the model class handles its own
# device placement. Confirm.
model = config.init_obj('arch', module_arch, layernum=layernum)
model.eval()
print("Initialized model from {}".format(config.config['arch']['args']['checkpoint_path']))

# Class names; the position in this list is used as the integer class label
# when filtering batches later in the notebook.
class_list = read_lists(class_list_path)

Read in config file from configs/cinic10_imagenet_edit_knn.json
Initialized model from external_code/PyTorch_CIFAR10/cifar10_models/state_dicts/vgg16_bn.pt


In [5]:
# Shallow-copy the keyword arguments out of the raw config so that this cell
# never mutates `config_json` itself.
data_loader_args = dict(config_json["data_loader"]["args"])
dataset_args = dict(config_json["dataset_args"])

# Create training data loader
# The image paths and labels come from two parallel list files named in the config.
image_paths = read_lists(config_json['dataset_paths']['train_images'])
labels = read_lists(config_json['dataset_paths']['train_labels'])
train_data_loader = torch.utils.data.DataLoader(
    module_data.CINIC10Dataset(
        # NOTE(review): an empty data_dir presumably means the entries in
        # `image_paths` are already usable as-is. Confirm against the dataset class.
        data_dir="",
        image_paths=image_paths,
        labels=labels,
        # Each batch is unpacked downstream as (image, target, path).
        return_paths=True,
        **dataset_args
    ),
    **data_loader_args
)

print("Initialized train data loader")

Initialized train data loader


Find the correctly and incorrectly predicted images for each class in the training data loader

In [6]:
# For every class: run the model over that class's training images, split them
# into correctly and incorrectly predicted, report the split, and write both
# lists of image paths under metadata/CINIC10-ImageNet/<class name>/<arch type>/.
# NOTE: this makes one full pass over the data loader per class.
for target_class in range(len(class_list)):
    print("Target class: {} ({})".format(class_list[target_class], target_class))

    correct_image_paths = []
    correct_images = []
    incorrect_image_paths = []
    incorrect_images = []
    incorrect_predictions = []
    with torch.no_grad():
        for image, target, path in tqdm(train_data_loader):
            # Batch positions whose label is the target class. Taking column 0
            # of nonzero() always yields a 1-D index tensor; a bare squeeze()
            # collapses to 0-d when exactly one sample matches, which breaks
            # both iteration and batched indexing below.
            target_idxs = (target == target_class).nonzero()[:, 0]

            # Skip any batches with no examples from target class
            if target_idxs.numel() == 0:
                continue

            # Keep only the target-class samples and their paths
            image = image[target_idxs]
            path = [path[i] for i in target_idxs.tolist()]

            # Move data to the compute device
            # NOTE(review): assumes `model` already lives on `device`. Confirm.
            image = image.to(device)

            output = model(image)
            prediction = torch.argmax(output, dim=1)

            # Obtain indices of where model predicted correctly and incorrectly
            is_correct = prediction == target_class
            correct_idxs = is_correct.nonzero()[:, 0]
            incorrect_idxs = (~is_correct).nonzero()[:, 0]

            # Images are copied back to the CPU batch by batch so they do not
            # accumulate in device memory over the whole pass.
            correct_image_paths += [path[i] for i in correct_idxs.tolist()]
            correct_images.append(image[correct_idxs].cpu())

            incorrect_image_paths += [path[i] for i in incorrect_idxs.tolist()]
            incorrect_images.append(image[incorrect_idxs].cpu())
            incorrect_predictions.append(prediction[incorrect_idxs])

    n_correct = len(correct_image_paths)
    n_incorrect = len(incorrect_image_paths)
    n_total = n_correct + n_incorrect

    # Nothing to concatenate or report if the class never appears in the loader;
    # torch.cat on an empty list and the percentage division would both fail.
    if n_total == 0:
        print("No images found for target class {} ({}); skipping".format(
            class_list[target_class], target_class))
        continue

    correct_images = torch.cat(correct_images, dim=0)
    incorrect_images = torch.cat(incorrect_images, dim=0)

    print("{} ({:.2f} %) correct images and {} ({:.2f} %) incorrect images".format(
        n_correct,
        100 * n_correct / n_total,
        n_incorrect,
        100 * n_incorrect / n_total))

    save_dir = os.path.join(
        'metadata',
        'CINIC10-ImageNet',
        class_list[target_class],
        config.config['arch']['args']['type'])

    os.makedirs(save_dir, exist_ok=True)

    # Output locations for this class. Only the path lists are written below;
    # the *_images_save_path names are defined but no tensors are saved here.
    correct_image_paths_filepath = os.path.join(save_dir, 'correct_image_paths.txt')
    correct_images_save_path = os.path.join(save_dir, 'correct_images.pth')

    incorrect_image_paths_filepath = os.path.join(save_dir, 'incorrect_image_paths.txt')
    incorrect_images_save_path = os.path.join(save_dir, 'incorrect_images.pth')

    print("Saving lists to {}".format(save_dir))

    write_lists(correct_image_paths_filepath, correct_image_paths)
    write_lists(incorrect_image_paths_filepath, incorrect_image_paths)
    

Target class: airplane (0)


100%|█████████████████████████████████████████████████████████████████████| 274/274 [05:08<00:00,  1.12s/it]


6142 (87.74 %) correct images and 858 (12.26 %) incorrect images
Saving lists to metadata/CINIC10-ImageNet/airplane/vgg16_bn
Target class: automobile (1)


100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:28<00:00,  9.59it/s]


5283 (75.47 %) correct images and 1717 (24.53 %) incorrect images
Saving lists to metadata/CINIC10-ImageNet/automobile/vgg16_bn
Target class: bird (2)


100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:31<00:00,  8.80it/s]


5113 (73.04 %) correct images and 1887 (26.96 %) incorrect images
Saving lists to metadata/CINIC10-ImageNet/bird/vgg16_bn
Target class: cat (3)


100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:23<00:00, 11.50it/s]


3761 (53.73 %) correct images and 3239 (46.27 %) incorrect images
Saving lists to metadata/CINIC10-ImageNet/cat/vgg16_bn
Target class: deer (4)


100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:12<00:00, 21.71it/s]


4248 (60.69 %) correct images and 2752 (39.31 %) incorrect images
Saving lists to metadata/CINIC10-ImageNet/deer/vgg16_bn
Target class: dog (5)


100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:38<00:00,  7.19it/s]


2975 (42.50 %) correct images and 4025 (57.50 %) incorrect images
Saving lists to metadata/CINIC10-ImageNet/dog/vgg16_bn
Target class: frog (6)


100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.92it/s]


5633 (80.47 %) correct images and 1367 (19.53 %) incorrect images
Saving lists to metadata/CINIC10-ImageNet/frog/vgg16_bn
Target class: horse (7)


100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:07<00:00, 37.88it/s]


5176 (73.94 %) correct images and 1824 (26.06 %) incorrect images
Saving lists to metadata/CINIC10-ImageNet/horse/vgg16_bn
Target class: ship (8)


100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:08<00:00, 31.80it/s]


4765 (68.07 %) correct images and 2235 (31.93 %) incorrect images
Saving lists to metadata/CINIC10-ImageNet/ship/vgg16_bn
Target class: truck (9)


100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:07<00:00, 39.00it/s]

4793 (68.47 %) correct images and 2207 (31.53 %) incorrect images
Saving lists to metadata/CINIC10-ImageNet/truck/vgg16_bn



