In [None]:
import json
import os
import shutil
import threading
import urllib
import urllib.request
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO

import requests
from PIL import Image, ImageFile

# let Pillow load images whose file data is cut short instead of raising an error on them
ImageFile.LOAD_TRUNCATED_IMAGES = True

## Introduction

- The code in this notebook is used to grab more images from the ImageNet database. It is based on the `mapping.json` file.
- This code is purely experimental and needs more testing. However, it does grab images and has the potential to massively increase our dataset.
- Due to time constraints, the scraper was written but not used. We would definitely consider grabbing more data as a viable strategy to increase performance.

In [None]:
def get_labels(fname):
    """
    Read the mapping file and collect the ImageNet labels it contains.

    Args:
        fname (string): path of the JSON mapping file
    Returns:
        labels (list): the top-level entries of the loaded JSON in file order
            (the keys, when the file holds a JSON object)
    """
    with open(fname, mode="r") as mapping_file:
        mapping = json.load(mapping_file)
    # iterating the loaded mapping yields one label per entry
    return list(mapping)

In [None]:
# the labels read from the mapping file; one folder of new images is created per label
labels = get_labels('mapping.json')

# where we want to store our new images
new_image_path = 'new_images/'

# make a place to hold our new images if it doesn't exist already
# (exist_ok keeps this cell safe to re-run and avoids the check-then-create race)
os.makedirs(new_image_path, exist_ok=True)

# this is the base url that can help us get a list of urls based on ImageNet's internal database
URL = "http://www.image-net.org/api/text/imagenet.synset.geturls"

In [None]:
# a small function to cleanup the new images in case we get too many and need more hard disk space
def cleanup(new_image_path):
    """
    Delete the folder of downloaded images together with everything inside it.

    Calling this when the folder does not exist (e.g. running the cell twice) is a no-op.

    Args:
        new_image_path (string): path of the folder to remove
    """
    try:
        shutil.rmtree(new_image_path)
    except FileNotFoundError:
        # nothing to clean up; any other failure (permissions, not a directory) still raises
        pass

In [None]:
# how many maximum extra images we want to grab per class
limit = 50

# how many seconds to wait on a single HTTP request before giving up on it,
# so that one stalled host cannot block a worker thread forever
REQUEST_TIMEOUT = 10


# this function goes through the ImageNet database of urls for a given label, and saves the images to drive
def get_image(label):
    """
    Download up to `limit` thumbnails for one label and save them to disk.

    The list of image urls is requested from `URL` with `wnid=label`. Each listed image
    is fetched, shrunk to fit within 64x64 and written to
    `new_image_path/<label>/<label>_<i>.jpeg`, where `i` is the position of the url in
    the list. Files that already exist are not downloaded again, and urls that cannot
    be fetched or decoded as an image are skipped.

    Args:
        label (string): the label (ImageNet wnid) to query; also used as the folder name
    Returns:
        count (int): the number of files in the label's folder when the function returns
    """
    path = new_image_path + label + '/'

    # makedirs also recreates new_image_path if it was removed in the meantime
    os.makedirs(path, exist_ok=True)

    count = len(os.listdir(path))

    if count >= limit:
        return count

    try:
        r = requests.get(url=URL, params={'wnid': label}, timeout=REQUEST_TIMEOUT)
        # an error page must not be treated as a list of image urls
        r.raise_for_status()
    except requests.RequestException as e:
        # without the url list there is nothing to download; report it and keep what we have
        print("could not fetch the url list for {}: {}".format(label, e))
        return count

    # for each url in the imagenet database, grab the images
    for i, raw_url in enumerate(r.iter_lines()):
        filepath = path + "{}_{}.jpeg".format(label, i)

        if not os.path.isfile(filepath):

            try:
                url = raw_url.decode("utf-8").strip()
                response = requests.get(url, timeout=REQUEST_TIMEOUT)
                response.raise_for_status()
                im = Image.open(BytesIO(response.content))
                # JPEG cannot store palette or alpha modes; convert those up front
                # instead of letting im.save fail and the image be silently dropped
                if im.mode not in ("L", "RGB"):
                    im = im.convert("RGB")
                im.thumbnail((64, 64))
                im.save(filepath)
                count += 1
            except (IOError, UnicodeDecodeError):
                # dead link, bad url, HTTP error or not an image: move on to the next url
                continue

        if count >= limit:
            break

    return count

In [None]:
# to speed up the image grabbing, we can use threadpools
MAX_WORKERS = 50

with ThreadPoolExecutor(max_workers=MAX_WORKERS) as pool:
    # one future per label: pool.map would re-raise the first worker exception while
    # iterating and throw away the counts of every other label
    futures = [(label, pool.submit(get_image, label)) for label in labels]

# leaving the with-block waits for all downloads, so every result is ready here
image_counts = []
for label, future in futures:
    try:
        image_counts.append(future.result())
    except Exception as e:
        # report the failing label and keep its slot so the list stays aligned with `labels`
        print("failed to grab images for {}: {}".format(label, e))
        image_counts.append(None)

print(image_counts)