In [44]:
import argparse
import csv
import json
from tqdm import tqdm
from PIL import Image
import os


##### Dataset split to process: "validation", "train", or "test"

In [48]:
# Name of the data split to process. Later cells insert it into the
# bbox_annotations/<split>/ and images/<split>/ paths under base_dir.
dataset = "validation"

In [49]:
# Root data directory; later cells read and write classes/, bbox_annotations/
# and images/ beneath it. Set the CAPSTONE_DATA_DIR environment variable to
# run the notebook on another machine; when it is unset the original
# location is used, so existing setups keep working unchanged.
base_dir = os.environ.get("CAPSTONE_DATA_DIR", "/Users/yuxiong/Desktop/Capstone/data")

### Translate Class Descriptions
#### Trainable Class: Handbag

In [50]:
def translate_class_descriptions(trainable_classes_file, descriptions_file):
    """Look up the human-readable description of the trainable class.

    Args:
        trainable_classes_file: Path to a text file whose entire stripped
            content is treated as one class label id.
        descriptions_file: Path to a CSV file whose first column is a label
            id and whose second column is that label's description.

    Returns:
        A one-element list holding the description of the trainable class,
        with double quotes, single quotes and backticks removed.

    Raises:
        KeyError: If the trainable class id has no row in descriptions_file.
    """
    with open(trainable_classes_file, 'r') as file:
        trainable_class = file.read().strip()

    description_table = {}
    # Read the file the caller passed in. This argument used to be ignored in
    # favour of a path hard-coded from the global base_dir.
    with open(descriptions_file, 'r') as f:
        for row in csv.reader(f):
            # Skip blank lines and rows without a description column, which
            # would otherwise raise IndexError on row[1].
            if len(row) >= 2:
                description_table[row[0]] = (
                    row[1].replace("\"", "").replace("'", "").replace('`', '')
                )

    return [description_table[trainable_class]]


def save_classes(formatted_data, translated_path):
    """Write formatted_data to translated_path as a JSON document.

    The target file is created, or truncated if it already exists.
    """
    with open(translated_path, 'w+') as out_file:
        json.dump(formatted_data, fp=out_file)

In [51]:
# Resolve the trainable class id to its description; both input files are
# expected under <base_dir>/classes/.
translated = translate_class_descriptions("{}/classes/classes-bbox-trainable.txt".format(base_dir), "{}/classes/class-descriptions.csv".format(base_dir))

# NOTE: save_classes writes JSON (json.dump), even though the target file
# name ends in .csv.
save_classes(translated, "{}/classes/trainable_translated.csv".format(base_dir))


###### 1 Format Annotations

In [52]:
def format_annotations(annotation_path, trainable_classes_path):
    """Collect the bounding-box annotations that carry the trainable label.

    Args:
        annotation_path: Path to the annotation CSV. Columns are read by
            position: 0 = image id, 2 = label, 3 = confidence, 4 = x0,
            5 = x1, 6 = y0, 7 = y1.
        trainable_classes_path: Path to a text file whose entire stripped
            content is the single label to keep.

    Returns:
        A tuple (annotations, ids). annotations is a list of dicts with the
        keys 'id', 'label', 'confidence', 'x0', 'x1', 'y0' and 'y1', whose
        values are the raw CSV strings. ids lists each matching image id
        once, in first-seen order.
    """
    with open(trainable_classes_path, 'r') as file:
        trainable_class = file.read().strip()

    annotations = []
    ids = []
    seen_ids = set()
    with open(annotation_path, 'r') as annofile:
        for row in csv.reader(annofile):
            # Blank lines and truncated rows would otherwise raise IndexError
            # when the columns below are accessed.
            if len(row) < 8:
                continue
            # Check the label first so no dict is built for the (typically
            # large) majority of rows that belong to other classes.
            if row[2] != trainable_class:
                continue
            annotations.append({'id': row[0], 'label': row[2], 'confidence': row[3],
                                'x0': row[4], 'x1': row[5], 'y0': row[6], 'y1': row[7]})
            # Order-preserving de-duplication of the image ids.
            if row[0] not in seen_ids:
                seen_ids.add(row[0])
                ids.append(row[0])
    return annotations, ids


In [53]:
def dedupe(seq):
    """Return the items of seq as a list with repeats dropped.

    The first occurrence of each item is kept and the original order is
    preserved. Items must be hashable.
    """
    unique_items = []
    already_seen = set()
    for item in seq:
        if item in already_seen:
            continue
        already_seen.add(item)
        unique_items.append(item)
    return unique_items

In [54]:
# Parse the bounding-box annotations of the selected split, keeping only the
# trainable class; valid_image_ids holds each matching image id once.
annotations, valid_image_ids = format_annotations("{}/bbox_annotations/{}/annotations-human-bbox.csv".format(base_dir,dataset), "{}/classes/classes-bbox-trainable.txt".format(base_dir))



In [55]:
# Show the first kept annotation as a sanity check of the parsed fields.
annotations[0]

{'confidence': '1',
 'id': '3339a5c598981879',
 'label': '/m/080hkjn',
 'x0': '0.054879',
 'x1': '0.945046',
 'y0': '0.020680',
 'y1': '0.961674'}

###### 2 Format Images URL csv file

In [56]:
def format_images(images_path):
    """Read the images CSV and return one {'id', 'url'} dict per row.

    Column 0 becomes the 'id' and column 2 the 'url'. Every row is
    converted, so a header row in the file shows up as an entry too.
    """
    with open(images_path, 'r') as csv_file:
        rows = list(csv.reader(csv_file))
        return [{'id': fields[0], 'url': fields[2]} for fields in rows]

In [57]:
# Load the id/url pairs of every image listed for the selected split.
images = format_images("{}/images/{}/images.csv".format(base_dir,dataset))


In [58]:
# Show the second parsed record as a sanity check.
images[1]

{'id': '0001eeaf4aed83f9',
 'url': 'https://c2.staticflickr.com/6/5606/15611395595_f51465687d_o.jpg'}

###### 3 Filter Image URL by Trainable Class

In [59]:
# Keep an image only if its ID has a bounding box annotation associated with it.
def filter_images(dataset, ids):
    """Return the elements of dataset whose 'id' appears in ids.

    Args:
        dataset: Iterable of dicts that each have an 'id' key.
        ids: Iterable of ids to keep.

    Returns:
        A list of the matching elements, in their original order.
    """
    wanted_ids = set(ids)
    progress = tqdm(dataset, desc="filtering out non-essential images")
    return [entry for entry in progress if entry['id'] in wanted_ids]


def save_data(data, out_path):
    """Dump data as JSON into out_path, creating or truncating the file."""
    with open(out_path, 'w+') as json_file:
        json.dump(data, fp=json_file)


# Collects the annotations of each image id into one record per image.
def points_maker(annotations):
    """Group annotations by image id.

    Args:
        annotations: Iterable of dicts that each have an 'id' key.

    Returns:
        A list of {'id': ..., 'annotations': [...]} dicts, one per distinct
        id, emitted in the order dict.popitem() yields the groups.
    """
    grouped = {}
    for entry in tqdm(annotations, desc="grouping annotations"):
        grouped.setdefault(entry['id'], []).append(entry)

    records = []
    # Drain the mapping with popitem so the output order matches the order
    # in which the dict hands its entries back.
    while grouped:
        image_id, members = grouped.popitem()
        records.append({'id': image_id, 'annotations': members})
    return records

In [60]:
# Group the kept annotations into one record per image id.
points = points_maker(annotations)
#filtered_images = filter_images(images, valid_image_ids)

grouping annotations: 100%|██████████| 60/60 [00:00<00:00, 143886.93it/s]


In [61]:
# Show one grouped record: an image id plus its list of annotations.
points[0]

{'annotations': [{'confidence': '1',
   'id': '368f0045b95affc3',
   'label': '/m/080hkjn',
   'x0': '0.000049',
   'x1': '1.000000',
   'y0': '0.000000',
   'y1': '0.999843'}],
 'id': '368f0045b95affc3'}

In [62]:
# Keep only the image records whose id has at least one kept annotation.
filtered_images = filter_images(images, valid_image_ids)

filtering out non-essential images: 100%|██████████| 41621/41621 [00:00<00:00, 1081116.51it/s]


In [63]:
# Show the first image record that survived the filter.
filtered_images[0]

{'id': '0040009ad56c2bc2',
 'url': 'https://farm3.staticflickr.com/5174/5405309020_7ce65b0636_o.jpg'}

In [64]:
# Persist both intermediate results as JSON under base_dir; later cells
# read filtered_images.json for downloading and points.json for resizing.
save_data(filtered_images, "{}/filtered_images.json".format(base_dir))
save_data(points,"{}/points.json".format(base_dir))

###### 4. Download Handbag Images

In [65]:
import json
import os
import random
import requests
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed
import argparse
# Downloads all image files contained in the dataset; if an image fails to download, it is skipped.


# This is a nice parallel processing tool that uses tqdm
# to help visualize time-to-completion.
def parallel_process(array, function, n_jobs=16, use_kwargs=False, front_num=3):
    """
        A parallel version of the map function with a progress bar.
        Args:
            array (array-like): An array to iterate over. It is sliced, so it must support slicing.
            function (function): A python function to apply to the elements of array
            n_jobs (int, default=16): The number of worker processes to use
            use_kwargs (boolean, default=False): Whether to consider the elements of array as dictionaries of
                keyword arguments to function
            front_num (int, default=3): The number of iterations to run serially before kicking off the parallel job.
                Useful for catching bugs. Zero or a negative value skips the serial warm-up.
        Returns:
            [function(array[0]), function(array[1]), ...]
            Results are listed in input order. In the parallel part, a call that raised contributes
            its exception object in place of a result; serial calls let exceptions propagate.
    """
    def call(item):
        # Apply `function` to one element, honouring use_kwargs.
        return function(**item) if use_kwargs else function(item)

    # A negative count would make the slices below cut from the wrong end.
    front_num = max(front_num, 0)
    # We run the first few iterations serially to catch bugs. `front` is always
    # bound here: it used to stay undefined when front_num <= 0, which made the
    # return statements below fail with NameError.
    front = [call(a) for a in array[:front_num]]
    remaining = array[front_num:]
    # If we set n_jobs to 1, just run a list comprehension. This is useful for benchmarking and debugging.
    if n_jobs == 1:
        return front + [call(a) for a in tqdm(remaining)]
    # Assemble the workers
    with ProcessPoolExecutor(max_workers=n_jobs) as pool:
        # Pass the elements of array into function
        if use_kwargs:
            futures = [pool.submit(function, **a) for a in remaining]
        else:
            futures = [pool.submit(function, a) for a in remaining]
        kwargs = {
            'total': len(futures),
            'unit': 'it',
            'unit_scale': True,
            'leave': True
        }
        # Print out the progress as tasks complete
        for f in tqdm(as_completed(futures), **kwargs):
            pass
    out = []
    # Get the results from the futures, in submission order.
    for future in tqdm(futures):
        try:
            out.append(future.result())
        except Exception as e:
            out.append(e)
    return front + out


def download(element):
    """Download one image and store it under <base_dir>/images/<dataset>/.

    The file is named <id>.<ext>, where <ext> is whatever follows the last
    '.' in the URL. Downloading is best effort: on a network error, a
    timeout, an HTTP error status or an empty body, nothing is written.

    Args:
        element: Dict with an 'id' key (used for the file name) and a 'url'
            key (the address to fetch).

    Returns:
        None.
    """
    dir_path = "{}/images/{}".format(base_dir,dataset)
    browser_headers = [
        {
            "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/51.0.2704 Safari/537.36"},
        {
            "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/52.0.2743 Safari/537.36"},
        {"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.11; rv:44.0) Gecko/20100101 Firefox/44.0"}
    ]
    try:
        # NOTE(review): verify=False disables TLS certificate checks. It is kept
        # so the same hosts remain reachable, but should be removed if possible.
        response = requests.get(element['url'],
                                headers=random.choice(browser_headers),
                                verify=False,
                                # Without a timeout a stalled server blocks the
                                # worker process forever.
                                timeout=30)
        # Treat 4xx/5xx answers as failures so that error pages are not
        # saved to disk as if they were images.
        response.raise_for_status()
    except requests.exceptions.RequestException:
        # Best effort: skip images that cannot be fetched.
        return
    image_content = response.content
    if not image_content:
        return
    complete_file_path = os.path.join(dir_path, element['id']+'.'+element['url'].split('.')[-1])
    # The with-statement closes the file; no explicit close() is needed.
    with open(complete_file_path, "wb") as f:
        f.write(image_content)



In [68]:
# NOTE(review): this parser is configured but parse_args() is never called in
# this cell; the paths used below are built from base_dir and dataset instead.
parser = argparse.ArgumentParser()
parser.add_argument('--images_path', dest='images_path', required=True)
parser.add_argument('--images_output_directory', dest='images_output_directory', required=True)


# NOTE(review): download() builds its own copy of this header list, so this
# module-level list appears to be unused by the visible code.
browser_headers = [
    {
        "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/51.0.2704 Safari/537.36"},
    {
        "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/52.0.2743 Safari/537.36"},
    {"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.11; rv:44.0) Gecko/20100101 Firefox/44.0"}
]
# Make sure the directory that download() writes into exists.
try:
    os.makedirs("{}/images/{}".format(base_dir,dataset))
except OSError:
    pass  # already exists (note: any other OSError is ignored here as well)
# Load the image records saved earlier and fetch them in parallel.
with open("{}/filtered_images.json".format(base_dir), 'r') as f:
    image_urls = json.load(f)
parallel_process(image_urls, download)


###### 5. Process_Images: Image Verification and Dimension Reduction

In [69]:
def process_images(saved_images_path, resized_images_path, points):
    """Shrink each point's image to fit 256x256 and drop points without a usable image.

    Args:
        saved_images_path: Directory holding the downloaded images; each one
            is looked up as <id>.jpg.
        resized_images_path: Directory the thumbnail is written to, again as
            <id>.jpg. If falsy, the stored image is overwritten in place.
        points: Iterable of dicts that each have an 'id' key.

    Returns:
        The points whose image could be opened, resized and saved, in input
        order. Points whose image is missing or unreadable are skipped, with
        a message, instead of aborting the whole run.
    """
    cleaned_points = []
    for point in tqdm(points, desc="checking if images are valid from label index"):
        # NOTE: download() names files after the extension found in the URL,
        # so an image saved with another extension is not found here.
        stored_path = os.path.join(saved_images_path, point['id'] + '.jpg')
        if resized_images_path:
            target_path = os.path.join(resized_images_path, point['id'] + '.jpg')
        else:
            target_path = stored_path

        try:
            # The context manager releases the file handle even on failure.
            with Image.open(stored_path) as im:
                # thumbnail() resizes in place and preserves the aspect ratio.
                im.thumbnail((256, 256))
                # save() overwrites an existing file, so the source does not
                # have to be removed first when writing back in place.
                im.save(target_path)
        except (IOError, OSError) as err:
            # Missing file, a download that is not an image, or a failed write.
            tqdm.write("skipping {}: {}".format(point['id'], err))
            continue
        cleaned_points.append(point)

    return cleaned_points

In [70]:
def load_dataset(file_path):
    """Parse the JSON document stored at file_path and return the result."""
    with open(file_path, 'r') as source:
        return json.load(source)


def save_dataset(data, file_path):
    """Store data as JSON in file_path, creating or truncating the file."""
    with open(file_path, 'w+') as target:
        json.dump(data, fp=target)
        


    
# Reload the grouped annotations, resize the downloaded images (the source
# and destination directories are the same, so files are overwritten in
# place), then overwrite points.json with the points process_images returned.
points = load_dataset("{}/points.json".format(base_dir))
filtered_points = process_images("{}/images/{}".format(base_dir,dataset), "{}/images/{}".format(base_dir,dataset), points)
save_dataset(filtered_points, "{}/points.json".format(base_dir))


checking if images are valid from label index: 100%|██████████| 55/55 [00:01<00:00, 51.52it/s]
