In [8]:
import pickle as pkl
from io import BytesIO
import numpy as np
import base64
from PIL import Image, ImageFile
from pathlib import Path

from tqdm import tqdm
from multiprocessing import Pool

In [9]:
# Lookup table used to relabel segmentation pixels: every id in 0..149 maps to
# id + 1, and id 150 maps to 0.
shifted_id = {label: label + 1 for label in range(150)}
shifted_id[150] = 0

In [10]:
# Dataset split to convert. The value is interpolated into the annotation and
# image folder paths as well as the name of the TSV file that gets written.
split = 'validation' # training | validation

In [11]:
# Every annotation PNG of the selected split, in sorted path order, paired with
# a 1-based line id: [[1, first_path], [2, second_path], ...].
seg_files = [
    [position, png_path]
    for position, png_path in enumerate(
        sorted(Path(f'./ADEChallengeData2016/annotations/{split}').glob('*.png')),
        start=1,
    )
]

In [12]:
def return_row(line_id, seg_file):
    """Build one TSV row for an image / segmentation-map pair.

    The JPEG that shares its stem with ``seg_file`` is re-encoded as PNG, the
    segmentation map has every pixel label replaced via the module-level
    ``shifted_id`` table, and both are base64-encoded.

    Args:
        line_id: Number stored as the last column of the row.
        seg_file: ``pathlib.Path`` of the annotation PNG. Its stem also selects
            the image under ``./ADEChallengeData2016/images/{split}/``.

    Returns:
        A tab-separated string with the columns: base64 image PNG, base64
        relabelled segmentation PNG, image id, ``line_id``.

    Raises:
        FileNotFoundError: If the image or annotation file cannot be opened.
        KeyError: If the annotation holds a label that is missing from
            ``shifted_id``.
    """
    stem = seg_file.stem
    jpg_file = Path(f'./ADEChallengeData2016/images/{split}/{stem}.jpg')

    # The context manager releases the file handle once the PNG bytes exist.
    with Image.open(jpg_file) as image:
        output = BytesIO()
        image.save(output, 'PNG')
    image_base64_str = base64.b64encode(output.getvalue()).decode("utf-8")

    with Image.open(seg_file) as seg_image:
        # Copy so the array is writable and independent of the closed file.
        seg = np.asarray(seg_image).copy()

    # All masks are computed against the original labels *before* any pixel is
    # rewritten; relabelling on the fly would let an already-shifted value be
    # matched (and shifted) a second time.
    mask_dict = {seg_label: seg == seg_label for seg_label in np.unique(seg)}
    for seg_label, mask in mask_dict.items():
        seg[mask] = shifted_id[int(seg_label)]

    output = BytesIO()
    Image.fromarray(seg).save(output, 'PNG')
    seg_base64_str = base64.b64encode(output.getvalue()).decode("utf-8")

    # Take the last underscore-separated field rather than slicing off a
    # fixed-length 'ADE_val_' prefix, so stems with a different prefix length
    # (e.g. 'ADE_train_...') produce a clean id as well.
    img_id = stem.rsplit('_', 1)[-1]

    return '\t'.join([image_base64_str, seg_base64_str, img_id, str(line_id)])

class Result():
    """Accumulator for rows handed over one at a time."""

    def __init__(self):
        # Rows are kept in the order in which they were received.
        self.rows = list()

    def update_result(self, row):
        """Store ``row``; a ``None`` row is ignored."""
        if row is None:
            return
        self.rows.append(row)

    def get_rows(self):
        """Return the internal list of stored rows (not a copy)."""
        return self.rows

In [13]:
# Pool of 128 worker processes, plus the collector whose update_result method
# is used as the completion callback for the submitted jobs.
pool = Pool(processes=128)
result = Result()

In [14]:
# Submit one job per annotation file. Without an error_callback, apply_async
# discards worker exceptions and the TSV would silently come out with rows
# missing, so failures are collected here and reported before anything is
# written.
worker_errors = []
for line_id, seg_file in seg_files:
    pool.apply_async(
        return_row,
        args=(line_id, seg_file),
        callback=result.update_result,
        error_callback=worker_errors.append,
    )
pool.close()
pool.join()

if worker_errors:
    raise RuntimeError(
        f'{len(worker_errors)} of {len(seg_files)} rows failed; '
        f'first error: {worker_errors[0]!r}'
    ) from worker_errors[0]

# Callbacks fire in completion order, so restore the original order using the
# line id stored in the last column. rsplit with maxsplit=1 avoids scanning the
# large base64 columns for every tab.
rows = result.get_rows()
rows.sort(key=lambda x: int(x.rsplit('\t', 1)[-1]))
ade_fullfile = "\n".join(rows)

# Make sure the target folder exists so the run does not fail at the very end.
tsv_path = Path(f'./dataset/ade/{split}.tsv')
tsv_path.parent.mkdir(parents=True, exist_ok=True)
with open(tsv_path, 'w') as f:
    f.write(ade_fullfile)