In [33]:
import numpy as np
from PIL import Image
import os
import scipy.io

In [2]:
def read_mat(mat_file_name, key='ATmask'):
    """Load a single variable from a MATLAB ``.mat`` file.

    Parameters
    ----------
    mat_file_name : str
        Path to the ``.mat`` file.
    key : str, optional
        Name of the variable to return. Defaults to ``'ATmask'``, the
        annotation variable this notebook reads.

    Returns
    -------
    The value stored under ``key``, loaded with ``mat_dtype=True`` and
    ``squeeze_me=True`` (singleton dimensions removed).

    Raises
    ------
    KeyError
        If ``key`` is not a variable in the file.
    """
    import scipy.io

    mat = scipy.io.loadmat(mat_file_name, mat_dtype=True, squeeze_me=True,
                           struct_as_record=False)
    return mat[key]

In [3]:
def convert_mat_annotations_to_png(mat_file_name, key='ATmask'):
    """Convert a ``.mat`` annotation mask into an RGB colour image.

    Parameters
    ----------
    mat_file_name : str
        Path to the ``.mat`` file.
    key : str, optional
        Name of the 2-D label variable inside the file (default ``'ATmask'``).

    Returns
    -------
    numpy.ndarray
        ``uint8`` array of shape ``(h, w, 3)``, where ``(h, w)`` is the shape
        of the label array. Pixel ``[r, c]`` gets the colour of label
        ``labels[r, c]``: labels 1-4 map to ``label_colours[0..3]``; label 0
        maps to ``label_colours[-1]`` (the wrap-around the earlier per-pixel
        indexing produced); any other value stays black.
    """
    import scipy.io

    label_colours = [(128, 0, 0), (0, 128, 0), (0, 0, 128), (128, 128, 0)]

    mat = scipy.io.loadmat(mat_file_name, mat_dtype=True, squeeze_me=True,
                           struct_as_record=False)
    labels = np.asarray(mat[key])
    h, w = labels.shape

    # Build the (rows, cols, 3) array directly. The previous PIL version
    # created Image.new('RGB', (h, w)), but PIL sizes are (width, height),
    # so non-square masks indexed out of bounds or came out transposed.
    outputs = np.zeros((h, w, 3), dtype=np.uint8)
    for value in range(len(label_colours) + 1):
        # value - 1 == -1 for label 0, intentionally reusing the last colour.
        outputs[labels == value] = label_colours[value - 1]
    return outputs

In [4]:
def convert_segmentation_array_to_png(mask, class_id):
    """Render a binary instance mask as an RGB image in its class colour.

    Parameters
    ----------
    mask : array-like, 2-D
        Instance mask; pixels equal to 1 belong to the instance.
    class_id : int
        Class of the instance. Matching pixels are painted
        ``label_colours[class_id - 1]``; with Python's negative indexing,
        ``class_id == 0`` therefore selects the last colour.

    Returns
    -------
    numpy.ndarray
        ``uint8`` array of shape ``mask.shape + (3,)``. Pixels where
        ``mask == 1`` hold the class colour; all others are black.

    Raises
    ------
    IndexError
        If some pixel equals 1 and ``class_id - 1`` is not a valid index
        into the colour table.
    """
    label_colours = [(128, 0, 0), (0, 128, 0), (0, 0, 128), (128, 128, 0)]

    mask = np.asarray(mask)
    h, w = mask.shape
    outputs = np.zeros((h, w, 3), dtype=np.uint8)

    hit = mask == 1
    # Only look the colour up when something is painted, so an empty mask
    # never raises for an out-of-range class_id (matches the old loop).
    if hit.any():
        outputs[hit] = label_colours[int(class_id - 1)]
    return outputs

In [5]:
def convert_array_to_png(mat):
    """Render a 2-D label array as an RGB colour image.

    Parameters
    ----------
    mat : array-like, 2-D
        Label array. Labels 1-4 are painted ``label_colours[0..3]``; label 0
        is painted ``label_colours[-1]`` (Python negative indexing, kept from
        the earlier per-pixel implementation). Any other value, including
        negative labels, stays black.

    Returns
    -------
    numpy.ndarray
        ``uint8`` array of shape ``mat.shape + (3,)``.
    """
    label_colours = [(128, 0, 0), (0, 128, 0), (0, 0, 128), (128, 128, 0)]

    labels = np.asarray(mat)
    h, w = labels.shape
    outputs = np.zeros((h, w, 3), dtype=np.uint8)
    for value in range(len(label_colours) + 1):
        # value - 1 == -1 for label 0, intentionally reusing the last colour.
        outputs[labels == value] = label_colours[value - 1]
    return outputs
    

In [6]:
class Unionfind:
    """Disjoint-set (union-find) forest over the integers ``0 .. n-1``.

    ``father[x]`` holds the parent of ``x``; ``x`` is a root when
    ``father[x] == x``. ``_find`` compresses paths, so once every node has
    been found, each ``father`` entry is the index of its set's root.
    """

    def __init__(self, n, mask):
        """Create ``n`` singleton sets.

        Parameters
        ----------
        n : int
            Number of elements.
        mask
            Unused; kept so existing ``Unionfind(n, mask)`` calls still work.
        """
        self.father = np.arange(n)

    def _find(self, x):
        """Return the root of ``x`` and re-parent every node on its path to it.

        Iterative rather than recursive: parent-chain length is not bounded,
        so on a large image (e.g. 1201*1201 pixels) a recursive walk could
        exceed Python's default recursion limit.
        """
        root = x
        while self.father[root] != root:
            root = self.father[root]
        # Path compression: point each node on the path directly at root.
        while self.father[x] != root:
            parent = self.father[x]
            self.father[x] = root
            x = parent
        return root

    def connect(self, a, b):
        """Merge the sets containing ``a`` and ``b``.

        The root of ``b``'s set is attached under the root of ``a``'s set, so
        ``a``'s root represents the merged set.
        """
        root_a = self._find(a)
        root_b = self._find(b)
        if root_a != root_b:
            self.father[root_b] = root_a

    def set_father(self, pos, val):
        """Overwrite the parent of ``pos`` with ``val`` (no union logic).

        The notebook stores ``-1`` here as an "excluded" marker. Such entries
        must not be passed to ``_find``/``connect``, because numpy would read
        ``-1`` as the last index.
        """
        self.father[pos] = val
    


In [39]:
# Load the 'ATmask' label array for one annotated image (via read_mat).
mat_file_name = './cedars-224/masks/0000.mat'
matfile = read_mat(mat_file_name)

In [40]:
# class_id = 4
# mask = (matfile == class_id)
# mask = mask.astype(int)
# outputs = convert_segmentation_array_to_png(mask, class_id)
# im = Image.fromarray(outputs)
# im.show()

In [41]:
# Colour every class of the full mask and display it with PIL's Image.show().
outputs = convert_mat_annotations_to_png(mat_file_name)
im = Image.fromarray(outputs)
im.show()

In [42]:
# Connected-component labelling with union-find: pixels of the same class
# that touch horizontally or vertically end up in the same set. Pixels of
# EXCLUDED_CLASS are not split into components (presumably background —
# TODO confirm) and are marked with the sentinel -1.
EXCLUDED_CLASS = 4

mask = matfile.astype(int)
r, c = mask.shape
my_union = Unionfind(r * c, mask.reshape(-1))

# Marking excluded pixels once is enough: they never take part in a union.
my_union.father[mask.reshape(-1) == EXCLUDED_CLASS] = -1

# Only the right and down neighbours are needed. When a pixel is visited in
# row-major order, its up/left neighbours have already been unioned with it,
# so checking them could never merge anything.
dr = [0, 1]  # right, down
dc = [1, 0]
old = np.ones(r * c)
# Repeat until a full pass leaves `father` unchanged; at that point every
# pixel points directly at its component root.
while not np.array_equal(my_union.father, old):
    old = np.copy(my_union.father)
    for i in range(r):
        for j in range(c):
            if mask[i, j] == EXCLUDED_CLASS:
                continue
            for di, dj in zip(dr, dc):
                ni, nj = i + di, j + dj
                if ni < r and nj < c and mask[ni, nj] == mask[i, j]:
                    my_union.connect(i * c + j, ni * c + nj)
print(np.unique(my_union.father))

[     -1    7854   10274  336133  457428 1102555]


In [43]:
# Reshape the per-pixel root ids back to image layout. Each distinct value is
# one connected component; -1 marks the class-4 pixels that the labelling
# step excluded.
res = my_union.father.reshape((r, c))
print(len(np.unique(res)), np.unique(res))
# seg = (res == 1102555).astype(int)
# outputs = convert_segmentation_array_to_png(seg, 0)
# im = Image.fromarray(outputs)
# im.show()

6 [     -1    7854   10274  336133  457428 1102555]


In [44]:
# One binary channel per component, ordered by sorted root id (so the -1
# excluded set comes first), plus one class id per component.
unique_set = np.unique(res)

# Raw mask value -> class id. Raw class 4 (the excluded -1 set) becomes id 0.
hash_dict = {'4': 0, '1': 1, '2': 2, '3': 3}

# Stack all channels in one call. Calling np.dstack inside the loop re-copied
# the growing array on every iteration, and with a single component it left
# `stack` 2-D instead of (r, c, 1).
stack = np.dstack([(res == val).astype(int) for val in unique_set])

class_ids = []
for val in unique_set:
    # Every pixel of a component has the same class, so the first one is enough.
    row, col = np.argwhere(res == val)[0]
    class_ids.append(hash_dict[str(mask[row, col])])
        


In [45]:
# Sanity check: stack is (rows, cols, n_components), with one class id per
# component.
print(stack.shape)
print(class_ids)

(1201, 1201, 6)
[0, 3, 3, 1, 1, 3]


In [46]:
# Save the per-component masks and their class ids to a .mat file.
# `stack` is one mostly-zero integer channel per component (about 69 MB
# uncompressed for 1201x1201x6 int64), so enable zlib compression; the data
# read back by loadmat is unchanged.
res_dict = {'segmentation': stack, 'class_ids': class_ids}
scipy.io.savemat('./cedars-224/masks/0000_instance.mat', res_dict,
                 do_compression=True)

In [47]:
# Round-trip check: reload the saved file and print its shape and class ids.
instance = scipy.io.loadmat('./cedars-224/masks/0000_instance.mat', 
                            mat_dtype=True, squeeze_me=True, struct_as_record=False)
print(instance['segmentation'].shape, instance['class_ids'])

(1201, 1201, 6) [0 3 3 1 1 3]


In [31]:
# visualize the segementation results
# for i in range(len(class_ids)):
#     outputs = convert_segmentation_array_to_png(stack[:,:,i], class_ids[i])
#     im = Image.fromarray(outputs)
#     im.show()

In [None]:
# # make a list of the class
# # to do adjust class number
# class_id = []
# for i in range(unique_len):
#     index = np.argwhere(res == i)[0]
#     class_id.append(mask[index[0], index[1]])
# print(class_id)

In [None]:
# seg = (res == 3).astype(int)
# outputs = convert_segmentation_array_to_png(seg, 0)
# im = Image.fromarray(outputs)
# im.show()