In [1]:
%matplotlib inline
from pycocotools.coco import COCO
import urllib.request
import numpy as np
import skimage.io as io
import matplotlib.pyplot as plt
import pylab
from tqdm import tqdm
# Set the default 'figure.figsize' rcParam once, so figures created later in
# the notebook pick up this size without each plotting cell setting it.
pylab.rcParams['figure.figsize'] = (8.0, 10.0)


In [2]:
import functools
import os.path
import random
import sys
import xml.etree.ElementTree
import numpy as np
import matplotlib.pyplot as plt
import skimage.data
import cv2
import PIL.Image

In [3]:
import os
import random
from argparse import ArgumentParser
from PIL import Image

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torchvision.utils as vutils

from model.networks import Generator
from utils.tools import get_config, random_bbox, mask_image, is_image_file, default_loader, normalize, get_model_list

SEED = 42

# Seed every random number generator used in this notebook with the same value
# (Python's `random`, NumPy, torch and torch's CUDA generators), in that order.
for _seed_rng in (random.seed,
                  np.random.seed,
                  torch.manual_seed,
                  torch.cuda.manual_seed_all):
    _seed_rng(SEED)


In [4]:
from pathlib import Path
from glob import glob

In [52]:
# Root folder of the COCO dataset and the split whose annotations are used.
dataDir = 'datasets/coco'
dataType = 'val2017'

# Annotation files for that split: instance annotations and person keypoints.
annFile_full = f'{dataDir}/annotations/instances_{dataType}.json'
annFile_human_pose = f'{dataDir}/annotations/person_keypoints_{dataType}.json'

In [76]:
# Load both annotation files: `coco` indexes the instance annotations and
# `coco_kps` the person-keypoint annotations.
coco = COCO(annFile_full)
coco_kps = COCO(annFile_human_pose)

# Ids of the 'person' category and of every image associated with it.
catIds = coco.getCatIds(catNms=['person'])
imgIds = coco.getImgIds(catIds=catIds)

# Keep one keypoint annotation per image, and only for images that have
# exactly one person annotation with iscrowd=0 whose 'area' is not below 3000
# and whose 'num_keypoints' is not below 10.
list_of_anns = []
for img_id in imgIds:
    img = coco.loadImgs(img_id)[0]
    annIds = coco_kps.getAnnIds(imgIds=img['id'], catIds=catIds, iscrowd=0)
    if len(annIds) == 1:
        anns = coco_kps.loadAnns(annIds)
        # Area is checked first; 'num_keypoints' is only read when it passes.
        if not (anns[0]['area'] < 3000 or anns[0]['num_keypoints'] < 10):
            list_of_anns.append(anns[0])

loading annotations into memory...
Done (t=0.53s)
creating index...
index created!
loading annotations into memory...
Done (t=0.17s)
creating index...
index created!


In [94]:
class DataLoader:
    '''Match up image files from three folders with their ground-truth annotations.

    The three folders hold the real, occluded and generated images. Only
    ``*.jpg`` file names present in all three folders are kept, and the
    ground-truth list is reduced to the entries whose ``'image_id'`` equals the
    integer encoded in one of those file names.
    '''
    def __init__(self,
                 real_img_path,
                 occ_img_path,
                 gen_img_path,
                 ls_gt,
                 ):
        '''Store the three folder paths and the ground-truth list.

        Args:
            real_img_path: folder containing the real images.
            occ_img_path: folder containing the occluded images.
            gen_img_path: folder containing the generated images.
            ls_gt: list of ground-truth entries; each entry is indexed with
                ``'image_id'`` when the dataset is built.
        '''
        self.real_img_path = real_img_path
        self.occ_img_path = occ_img_path
        self.gen_img_path = gen_img_path
        self.ls_gt = ls_gt

    def get_the_paths(self, folder_path):
        '''Return the sorted file names of the ``*.jpg`` files directly in ``folder_path``.

        Only the file name (e.g. ``'0001.jpg'``) is returned, never the folder part.
        '''
        # Use ``Path.name`` rather than splitting the string on '/': the split
        # only yields the file name when ``folder_path`` is a single relative
        # component, and breaks for nested/absolute folders or other separators.
        return sorted(file.name for file in Path(folder_path).glob('*.jpg'))

    def get_shared_img(self, real_list, occ_list, gen_list):
        '''Return the set of names that appear in all three lists.'''
        return set(real_list) & set(occ_list) & set(gen_list)

    def remove_unshared_img(self, list_imgs, large_list):
        '''Return the items of ``large_list`` (order preserved) that are in ``list_imgs``.'''
        # A set keeps each membership test O(1) even if a list is passed in.
        keep = set(list_imgs)
        return [x for x in large_list if x in keep]

    def remove_unshared_img_gt(self, list_imgs, large_list):
        '''Return the ground-truth entries whose ``'image_id'`` matches a kept image.

        Each name in ``list_imgs`` is converted to an id by taking the text
        before its first ``'.'`` and parsing it with ``int`` (so
        ``'000000000785.jpg'`` becomes ``785``).

        Raises:
            ValueError: if the text before the first ``'.'`` of a name is not
                an integer.

        NOTE(review): the result keeps the order of ``large_list``; it is not
        re-sorted to follow the sorted file-name lists. Confirm that callers do
        not rely on index-wise alignment between them.
        '''
        # Build the id set once instead of scanning a list for every entry.
        wanted_ids = {int(name.split('.')[0]) for name in list_imgs}

        return [x for x in large_list if x['image_id'] in wanted_ids]

    def create_dictionary_of_lists(self, real_list, occ_list, gen_list, final_ls_gt):
        '''Print the length of each list and bundle the four lists into a dict.

        The image lists are keyed by the folder path they came from
        (``self.real_img_path``, ``self.occ_img_path``, ``self.gen_img_path``)
        and the annotations by ``'ground_truth'``.
        '''
        print(' ------------- Check List Length ------------- ')
        print(f'Real list : {len(real_list)}\nOccluded list : {len(occ_list)}\nGenerated list : {len(gen_list)}\nGround Thruth list : {len(final_ls_gt)}')
        return {
            self.real_img_path: real_list,
            self.occ_img_path: occ_list,
            self.gen_img_path: gen_list,
            'ground_truth': final_ls_gt,
        }

    def make_dataset(self):
        '''Build the dataset dictionary from the three folders and the ground truth.

        Returns:
            The dict produced by ``create_dictionary_of_lists``, restricted to
            the file names shared by all three folders.
        '''
        # File names found in each folder.
        real_imgs_paths = self.get_the_paths(self.real_img_path)
        occ_imgs_paths = self.get_the_paths(self.occ_img_path)
        gen_imgs_paths = self.get_the_paths(self.gen_img_path)

        # Names present in every folder.
        shared_img = self.get_shared_img(real_imgs_paths, occ_imgs_paths, gen_imgs_paths)

        # Drop files (and annotations) that are missing from any folder.
        real_imgs_paths = self.remove_unshared_img(shared_img, real_imgs_paths)
        occ_imgs_paths = self.remove_unshared_img(shared_img, occ_imgs_paths)
        gen_imgs_paths = self.remove_unshared_img(shared_img, gen_imgs_paths)
        final_ls_gt = self.remove_unshared_img_gt(shared_img, self.ls_gt)

        return self.create_dictionary_of_lists(real_imgs_paths, occ_imgs_paths, gen_imgs_paths, final_ls_gt)

In [95]:
# Build the dataset from the three image folders and the filtered annotations.
loader = DataLoader(
    real_img_path='real_img',
    occ_img_path='occ_img',
    gen_img_path='gen_img',
    ls_gt=list_of_anns,
)
dataset = loader.make_dataset()

 ------------- Check List Length ------------- 
Real list : 545
Occluded list : 545
Generated list : 545
Ground Thruth list : 545
