<a href="https://colab.research.google.com/github/aviguptatx/GansResearch/blob/master/Precision_and_Recall.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This code is based off of https://github.com/blandocs/improved-precision-and-recall-metric-pytorch (with a few modifications)


In [38]:
import os, torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torchsummary import summary
from tqdm import tqdm

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class feature_extractor(object):
    def __init__(self, args):
        # parameters
        self.args = args
        self.generated_dir = args.generated_dir
        self.real_dir = args.real_dir
        self.batch_size = args.batch_size
        self.cpu = args.cpu
        self.data_size = args.data_size

    def extract(self):
        cnn = models.vgg16(pretrained=True)
        cnn.classifier = nn.Sequential(*[cnn.classifier[i] for i in range(5)])
        cnn = cnn.to(device).eval()
        generated_features = []
        real_features = []
        generated_img_paths = []

        with torch.no_grad():

            generated_data = ImageDataset(self.generated_dir, self.data_size, self.batch_size)
            generated_loader = DataLoader(generated_data, batch_size=self.batch_size, shuffle=False)

            for imgs, img_paths in tqdm(generated_loader, ncols=80):
                target_features = cnn(imgs)

                img_paths = list(img_paths)
                generated_img_paths.extend(img_paths)

                for target_feature in torch.chunk(target_features, target_features.size(0), dim=0):
                    generated_features.append(target_feature)

            real_data = ImageDataset(self.real_dir, self.data_size, self.batch_size)
            real_loader = DataLoader(real_data, batch_size=self.batch_size, shuffle=False)

            for imgs, _ in tqdm(real_loader, ncols=80):
                target_features = cnn(imgs)

                for target_feature in torch.chunk(target_features, target_features.size(0), dim=0):
                    real_features.append(target_feature)

        return generated_features, real_features, generated_img_paths

    def show_image(self, img):
        unloader = transforms.ToPILImage()
        plt.ion()
        plt.figure()
        image = img.cpu().clone()
        image = image.squeeze(0)
        image = unloader(image)
        plt.imshow(image)
        plt.title(Image)
        plt.pause(10)

class ImageDataset(Dataset):
    def __init__(self, dir_path, data_size=100, batch_size=64):
        self.dir_path = dir_path

        data_size = data_size - data_size%batch_size

        self.img_paths = []

        for i, img_name in enumerate(os.listdir(dir_path)):
            if i >= data_size:
                break
            img_path = os.path.join(dir_path, img_name)
            self.img_paths.append(img_path)

        self.imsize = 224

        self.transformations = transforms.Compose([
            transforms.Resize(self.imsize),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x[torch.LongTensor([2,1,0])]),
            transforms.Normalize(mean=[0.40760392, 0.45795686, 0.48501961],
            std=[1,1,1]),
            transforms.Lambda(lambda x: x.mul_(255)),
            ])

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        image = Image.open(img_path)
        image = self.transformations(image)
        return image.to(device, torch.float), img_path

    def __len__(self):
        return len(self.img_paths)

In [39]:
import os, torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import operator
from tqdm import tqdm

class precision_and_recall(object):
    def __init__(self, args):
        # parameters
        self.args = args
        self.result_dir = args.result_dir
        self.batch_size = args.batch_size
        self.cpu = args.cpu
        self.data_size = args.data_size
        self.k = 3

    def run(self):
        
        # load data using vgg16
        extractor = feature_extractor(self.args)
        generated_features, real_features, _ = extractor.extract()

        # equal number of samples
        data_num = min(len(generated_features), len(real_features))
        print(f'data num: {data_num}')

        if data_num <= 0:
            print("there is no data")
            return
        generated_features = generated_features[:data_num]
        real_features = real_features[:data_num]

        # get precision and recall
        precision = self.manifold_estimate(real_features, generated_features, self.k)
        recall = self.manifold_estimate(generated_features, real_features, self.k)
 
        print("Precision is: " + str(precision))        
        print("Recall is   : " + str(recall))

    def manifold_estimate(self, A_features, B_features, k):
        
        KNN_list_in_A = {}
        for A in tqdm(A_features, ncols=80):
            pairwise_distances = np.zeros(shape=(len(A_features)))

            for i, A_prime in enumerate(A_features):
                d = torch.norm((A-A_prime), 2)
                pairwise_distances[i] = d

            v = np.partition(pairwise_distances, k)[k]
            KNN_list_in_A[A] = v

        n = 0 

        for B in tqdm(B_features, ncols=80):
            for A_prime in A_features:
                d = torch.norm((B-A_prime), 2)
                if d <= KNN_list_in_A[A_prime]:
                    n+=1
                    break

        return n/len(B_features)

class realism(object):
    def __init__(self, args):
        # parameters
        self.args = args
        # self.data_dir = args.data_dir
        self.result_dir = args.result_dir
        self.batch_size = args.batch_size
        self.cpu = args.cpu
        self.k = 3  

    def run(self):

        # load data using vgg16
        extractor = feature_extractor(self.args)
        generated_features, real_features, generated_img_paths = extractor.extract()

        # equal number of samples
        data_num = min(len(generated_features), len(real_features))
        print(f'data num: {data_num}')

        if data_num <= 0:
            print("there is no data")
            return
        generated_features = generated_features[:data_num]
        real_features = real_features[:data_num]
        generated_img_paths = generated_img_paths[:data_num]

        KNN_list_in_real = self.calculate_real_NNK(real_features, self.k, data_num)

        for i, generated_feature in enumerate(tqdm(generated_features, ncols=80)):

            max_value = 0
            for real_feature, KNN_radius in KNN_list_in_real:
                d = torch.norm((real_feature-generated_feature), 2)
                value = KNN_radius/d
                if max_value < value:
                    max_value = value

            # print images with specific names
            if 'high_realism' in generated_img_paths[i] or 'low_realism' in generated_img_paths[i]:
                print(f'{generated_img_paths[i]} realism score: {max_value}')

        return

    def calculate_real_NNK(self, real_features, k, data_num):
        KNN_list_in_real = {}
        for real_feature in tqdm(real_features, ncols=80):
            pairwise_distances = np.zeros(shape=(len(real_features)))

            for i, real_prime in enumerate(real_features):
                d = torch.norm((real_feature-real_prime), 2)
                pairwise_distances[i] = d

            v = np.partition(pairwise_distances, k)[k]
            KNN_list_in_real[real_feature] = v

        # remove half of larger values
        KNN_list_in_real = sorted(KNN_list_in_real.items(), key=operator.itemgetter(1)) 
        KNN_list_in_real = KNN_list_in_real[:int(data_num/2)]


        return KNN_list_in_real

In [40]:
# Sketch thing I have to do since I'm running in colab - Avi
!rmdir /content/fake_data/.ipynb_checkpoints
!rmdir /content/real_data/.ipynb_checkpoints

rmdir: failed to remove '/content/fake_data/.ipynb_checkpoints': No such file or directory
rmdir: failed to remove '/content/real_data/.ipynb_checkpoints': No such file or directory


In [41]:
import argparse, os, torch

G_DIRECTORY = '/content/fake_data'
R_DIRECTORY = '/content/real_data'

def parse_args():
    desc = "calcualte precision and recall OR realism"
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument('--cal_type', type=str, default='precision_and_recall', choices=['precision_and_recall', 'realism'], help='The type of calcualtion')
    
    parser.add_argument('--result_dir', type=str, default='results', help='Directory name to save the model')
    parser.add_argument('--batch_size', type=int, default=2, help='The size of batch')
    parser.add_argument('--cpu', action='store_true')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--data_size', type=int, default=4)

    parser.add_argument('--generated_dir', default=G_DIRECTORY)
    parser.add_argument('--real_dir', default=R_DIRECTORY)
    args, unknown = parser.parse_known_args()
    # print(args)
    return check_args(args)


def check_args(args):
    # --result_dir
    if not os.path.exists(args.result_dir):
        os.makedirs(args.result_dir)    
    # --batch_size
    try:
        assert args.batch_size >= 1
    except:
        print('batch size must be larger than or equal to one')

    return args


def main():
    # parse arguments
    args = parse_args()

    if args.cal_type == 'precision_and_recall':
        task = precision_and_recall(args)
    else:
        task = realism(args)

    task.run()
    
if __name__ == '__main__':
    main()   






100%|█████████████████████████████████████████████| 2/2 [00:00<00:00, 32.15it/s]





100%|█████████████████████████████████████████████| 2/2 [00:00<00:00, 26.72it/s]





100%|████████████████████████████████████████████| 4/4 [00:00<00:00, 199.62it/s]





100%|████████████████████████████████████████████| 4/4 [00:00<00:00, 763.26it/s]





100%|████████████████████████████████████████████| 4/4 [00:00<00:00, 624.87it/s]





100%|████████████████████████████████████████████| 4/4 [00:00<00:00, 837.60it/s]

data num: 4
Precision is: 1.0
Recall is   : 1.0



