In [1]:
# Directory containing the train/validation/test JSON files (or "<name>.zip" archives of them).
LABELS_PATH = '../data/labels/'

# Length of each multi-hot label vector built by parse().
LABEL_CNT = 228

# Downloaded images are resized to IMG_SIZE x IMG_SIZE pixels before saving.
IMG_SIZE = 299

# Maximum number of entries kept per split (set to None for unlimited).
DOWNLOAD_AMOUNT = 500

In [2]:
import numpy as np
import os
import zipfile
import json
import h5py
import urllib3
import multiprocessing
from PIL import Image
from tqdm import tqdm
from urllib3.util import Retry
urllib3.disable_warnings()
import io

# Load image URLs and labels

In [3]:
def parse(fname, path, max_parse=None, label_cnt=None):
    """Load image URLs and multi-hot label vectors from one JSON split file.

    If ``path/fname`` does not exist yet, the archive ``path/<fname>.zip`` is
    extracted into ``path`` first.

    Parameters
    ----------
    fname : str
        Name of the JSON file inside ``path`` (e.g. ``'train.json'``).
    path : str
        Directory that contains ``fname`` or ``<fname>.zip``. A trailing
        path separator is optional.
    max_parse : int or None, optional
        If given, keep only the first ``max_parse`` entries of each returned list.
    label_cnt : int or None, optional
        Length of each label vector. Defaults to the module-level ``LABEL_CNT``.

    Returns
    -------
    ids_urls : list of (imageId, url) tuples
        One per entry of ``data["images"]``, in file order.
    ids_labels : list of (imageId, numpy.ndarray) tuples
        One per entry of ``data["annotations"]``. Each array is an int8 vector
        of length ``label_cnt`` with a 1 at index ``labelId - 1`` for every
        listed label. The list is empty when the file has no ``"annotations"`` key.

    Raises
    ------
    ValueError
        If a ``labelId`` falls outside ``1..label_cnt``.
    """
    if label_cnt is None:
        label_cnt = LABEL_CNT

    json_path = os.path.join(path, fname)
    # Look for the extracted file where it will actually be opened (inside
    # ``path``), not relative to the current working directory.
    if not os.path.exists(json_path):
        with zipfile.ZipFile(os.path.join(path, fname + '.zip'), "r") as zip_ref:
            zip_ref.extractall(path)

    # JSON text is UTF-8; don't depend on the platform's default encoding.
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    ids_urls = [(image["imageId"], image["url"]) for image in data["images"]]

    ids_labels = []
    for annotation in data.get("annotations", []):
        # labelIds are shifted to 0-based indices. The explicit integer dtype
        # keeps the indexing below valid even when the label list is empty
        # (an empty np.array defaults to float, which numpy rejects as an index).
        label_idx = np.array([int(label) for label in annotation["labelId"]],
                             dtype=np.intp) - 1
        # Reject ids outside 1..label_cnt. Without this check, id 0 would
        # silently wrap around to the last class via negative indexing.
        if label_idx.size and (label_idx.min() < 0 or label_idx.max() >= label_cnt):
            raise ValueError(
                "labelId out of range 1..{} for imageId {!r}".format(
                    label_cnt, annotation["imageId"]))
        label_vector = np.zeros(label_cnt, dtype=np.int8)
        label_vector[label_idx] = 1
        ids_labels.append((annotation["imageId"], label_vector))

    if max_parse is not None:
        # NOTE(review): the two lists are truncated independently. This assumes
        # data["images"] and data["annotations"] list the same imageIds in the
        # same order; confirm before pairing them by position downstream.
        ids_urls = ids_urls[:max_parse]
        ids_labels = ids_labels[:max_parse]

    return ids_urls, ids_labels

In [4]:
# Read (unzipping first if needed) each split under LABELS_PATH, keeping at most
# DOWNLOAD_AMOUNT entries; labels of the test split are discarded (presumably it
# has no annotations — verify against the data).
train_ids_urls, train_ids_labels = parse(
    fname='train.json', path=LABELS_PATH, max_parse=DOWNLOAD_AMOUNT)
val_ids_urls, val_ids_labels = parse(
    fname='validation.json', path=LABELS_PATH, max_parse=DOWNLOAD_AMOUNT)
test_ids_urls, _ = parse(
    fname='test.json', path=LABELS_PATH, max_parse=DOWNLOAD_AMOUNT)

# Download and resize images

In [10]:
def download_image(id_url_fname, timeout=10.0):
    """Download one image, resize it to IMG_SIZE x IMG_SIZE and save it as JPEG.

    Parameters
    ----------
    id_url_fname : tuple
        ``(image_id, url, fname)``. The values are packed into one argument so
        the function can be mapped with ``multiprocessing.Pool.imap_unordered``.
        ``image_id`` is not used here.
    timeout : float, optional
        Per-request timeout in seconds. Without one, a stalled server can hang
        a worker indefinitely.

    Returns
    -------
    bool
        True if ``fname`` exists afterwards (it was already present or was
        just saved). False if the download, decode or save failed. Failures
        are reported rather than raised, so one dead URL does not abort a
        whole batch.
    """
    _, url, fname = id_url_fname
    if os.path.exists(fname):
        return True
    try:
        # Closing the manager releases its pooled connections. ``request``
        # has already read the whole body (preload_content defaults to True).
        with urllib3.PoolManager(retries=Retry(connect=3, read=2, redirect=3)) as http:
            response = http.request("GET", url, timeout=timeout)
        if response.status != 200:
            return False
        image = Image.open(io.BytesIO(response.data))
        # Convert before resizing: Pillow always resamples "P"/"1" mode
        # images with nearest-neighbour, which degrades palette images.
        image_rgb = image.convert("RGB").resize((IMG_SIZE, IMG_SIZE))
        # Write under a temporary name and rename. An interrupted save then
        # never leaves a truncated JPEG that the exists() check above would skip.
        tmp_fname = fname + ".part"
        image_rgb.save(tmp_fname, format='JPEG', quality=90)
        os.replace(tmp_fname, fname)
    except (urllib3.exceptions.HTTPError, OSError, ValueError,
            Image.DecompressionBombError):
        # An uncaught error here would be re-raised by imap_unordered in the
        # parent and stop the whole download, so report it instead.
        return False
    return True
    

def download(ids_urls, outdir, processes=30):
    """Download every ``(imageId, url)`` pair into ``outdir`` as ``<imageId>.jpg``.

    The work is spread over a ``multiprocessing.Pool`` running ``download_image``.
    ``download_image`` skips files that already exist, so re-running this call
    only fetches what is still missing.

    Parameters
    ----------
    ids_urls : list of (imageId, url) tuples
        As returned by ``parse``.
    outdir : str
        Destination directory. It is created, including any missing parents,
        if it does not exist.
    processes : int, optional
        Number of worker processes (default 30).
    """
    os.makedirs(outdir, exist_ok=True)
    params = [(image_id, url, os.path.join(outdir, "{}.jpg".format(image_id)))
              for image_id, url in ids_urls]
    n_failed = 0
    # The context managers close the progress bar and tear down the worker
    # processes when the loop finishes or raises.
    with multiprocessing.Pool(processes=processes) as pool, \
            tqdm(total=len(params)) as progress_bar:
        for result in pool.imap_unordered(download_image, params):
            # Only an explicit False is counted as a failure; a None result
            # is treated as success.
            if result is False:
                n_failed += 1
                progress_bar.set_postfix(failed=n_failed)
            progress_bar.update(1)

In [12]:
if not os.path.exists('../data/raw_images/'):
    # Shared parent folder for the per-split image directories below.
    os.mkdir('../data/raw_images/')

# Fetch each split, in order, into its own sub-directory.
for split_ids_urls, split_dir in [
    (train_ids_urls, '../data/raw_images/train/'),
    (val_ids_urls, '../data/raw_images/validation/'),
    (test_ids_urls, '../data/raw_images/test/'),
]:
    download(split_ids_urls, split_dir)

100%|██████████| 500/500 [00:05<00:00, 84.51it/s]
100%|██████████| 500/500 [00:06<00:00, 80.73it/s]
100%|██████████| 500/500 [00:06<00:00, 82.77it/s]
