In [366]:
import numpy as np
import pickle
import matplotlib.pyplot as plt
import os
import sys
from os.path import join
%matplotlib inline

In [372]:
def color_map(N=256, normalized=False):
    """Build an (N, 3) colour table by spreading each index's bits over R, G, B.

    The index is consumed three bits at a time (lowest bits first). In every
    triple, bit 0 feeds red, bit 1 feeds green and bit 2 feeds blue; the first
    triple lands in the most significant bit of each channel, the next triple
    one bit lower, and so on for eight rounds.

    Parameters
    ----------
    N : int
        Number of table rows to generate.
    normalized : bool
        If True, return float32 values divided by 255; otherwise return the
        raw uint8 values.

    Returns
    -------
    numpy.ndarray
        Array of shape (N, 3); row i is the colour assigned to index i.
    """
    palette = np.zeros((N, 3), dtype='float32' if normalized else 'uint8')

    for index in range(N):
        rgb = [0, 0, 0]
        remaining = index
        # Walk the output bit positions from most to least significant.
        for out_bit in range(7, -1, -1):
            for channel in range(3):
                rgb[channel] |= ((remaining >> channel) & 1) << out_bit
            remaining >>= 3
        palette[index] = rgb

    if normalized:
        return palette / 255
    return palette


In [373]:
# Normalised palette: row i holds the colour of class index i.
cmap = color_map(normalized=True)

# Class names in index order. The final entry ('void') is registered under
# index 255 below rather than under its list position.
labels = [
    'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
    'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',
    'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep',
    'sofa', 'train', 'tvmonitor', 'void',
]

# Colour tuple -> palette index, for every row of the palette.
cmap_dict = dict(zip(map(tuple, cmap), range(len(cmap))))

# Class index -> class name.
label_dict = dict(enumerate(labels[:-1]))
label_dict[255] = labels[-1]

In [374]:
# Persist both lookup tables as pickle files in the working directory.
for pkl_name, table in (('voc_cmap_dict.pkl', cmap_dict),
                        ('voc_label_dict.pkl', label_dict)):
    with open(pkl_name, 'wb') as f:
        pickle.dump(table, f)
    

In [379]:
# Collapse each of the first 21 palette colours into a single number:
# channels are rescaled by 255/64 and weighted 100 (R), 10 (G), 1 (B).
scaled_palette = cmap[:21] * 255 / 64
cmap_num = (scaled_palette * [[100, 10, 1]]).sum(axis=1)

In [382]:
def img2seg(img, cmap_dict):
    """Convert a colour-coded mask into an array of class indices.

    Every pixel's three channels are rescaled by 255/64, weighted by
    (100, 10, 1) and summed; the sum is truncated to an integer code which is
    then translated through ``cmap_dict``.

    Parameters
    ----------
    img : numpy.ndarray, shape (..., 3)
        Colour mask with the channels on the last axis. Values are multiplied
        by 255 before encoding, so 0-1 scaled input is assumed
        (TODO confirm against the image loader).
    cmap_dict : dict
        Maps an integer colour code to a class index.

    Returns
    -------
    numpy.ndarray
        int64 array of shape ``img.shape[:-1]`` holding one class index per
        pixel.

    Raises
    ------
    ValueError
        If the image contains a colour code that is not a key of
        ``cmap_dict``.
    """
    codes = np.sum((img*255/64)*([100, 10, 1]), axis=-1).astype(np.int64)

    # Look up each distinct code once instead of once per pixel; flattening
    # first keeps ``inverse`` one-dimensional.
    unique_codes, inverse = np.unique(codes.reshape(-1), return_inverse=True)
    class_ids = [cmap_dict.get(code) for code in unique_codes]

    # Report unknown colours explicitly rather than letting None reach the
    # integer cast.
    missing = [int(code) for code, class_id in zip(unique_codes, class_ids)
               if class_id is None]
    if missing:
        raise ValueError(
            'img2seg: colour codes not present in cmap_dict: {}'.format(missing))

    lookup = np.array(class_ids, dtype=np.int64)
    return lookup[inverse].reshape(codes.shape)

In [383]:
# Integer colour code -> class index for class indices 0..20.
cm_dict = {code: class_id
           for code, class_id in zip(cmap_num.astype(np.int64), range(21))}
# Index 255 ('void') has palette colour [0.8784314, 0.8784314, 0.7529412];
# scaled by 255/64 that is [3.5, 3.5, 3], and weighting by [100, 10, 1]
# gives 350 + 35 + 3 = 388.
cm_dict[388] = 255

In [392]:
# Names of all entries in the directory holding the colour-coded masks.
segmentation_class_dir = 'VOCdevkit/VOC2012/SegmentationClass'
img_files = os.listdir(segmentation_class_dir)

In [439]:
# Convert every colour-coded mask to a class-index map and save it as a
# greyscale image under the same file name.
src_dir = 'VOCdevkit/VOC2012/SegmentationClass'
dst_dir = 'VOCdevkit/VOC2012/SegmentationMap'
os.makedirs(dst_dir, exist_ok=True)  # make sure the output directory exists before saving

# Count from 1 so the last progress line reads N/N instead of (N-1)/N.
for i, img_file in enumerate(img_files, start=1):
    sys.stdout.write('\r{}/{}'.format(i, len(img_files)))
    img = plt.imread(os.path.join(src_dir, img_file))
    seg = img2seg(img, cm_dict)
    # Pin the colour scale to 0-255. Without vmin/vmax, imsave stretches each
    # image's own min/max across the colormap, so the stored grey level would
    # only equal the class index for masks containing both 0 and 255.
    plt.imsave(arr=seg, fname=os.path.join(dst_dir, img_file),
               cmap='gray', vmin=0, vmax=255)

2912/2913