In [1]:
from tqdm import tqdm
from glob import glob

import copy
import numpy as np
import cv2
import pandas as pd
import tensorflow as tf
import os
import seaborn as sb
import matplotlib.pyplot as plt
import collections

from makiflow.augmentation.segmentation.balancing.hc_scanner import HCScanner
from makiflow.augmentation.segmentation.augment_ops import AffineAugment, ElasticAugment, FlipAugment
from makiflow.augmentation.segmentation.data_provider import Data 
from makiflow.augmentation.segmentation.image_mask_cutter import ImageCutter
from makiflow.tf_scripts import set_main_gpu

In [2]:
def mutate_masks(masks, mapping):
    """
    Remaps classes on the given `masks` according to the `mapping`.

    Every comparison is made against the ORIGINAL `masks`, so mappings do not chain:
    with [(4, 0), (5, 4)] a pixel that was 5 becomes 4 (not 0).

    Parameters
    ----------
    masks : list or numpy.array
        List or numpy array of masks. It is not modified.
    mapping : list
        List of tuples: [(source_class_number, new_class_number)],
        where `source_class_number` will be changed to `new_class_number` in the `masks`.

    Returns
    -------
    new_masks : the same type as `masks`
        Deep copy of `masks` with changed class numbers.

    Raises
    ------
    TypeError
        If `mapping` is not a list whose every element is a 2-tuple.
    """
    # Validate every entry, not only the first one, so a stray list/3-tuple is caught early.
    if not isinstance(mapping, list) or not all(
        isinstance(elem, tuple) and len(elem) == 2 for elem in mapping
    ):
        raise TypeError('mapping should be list of tuples')

    new_masks = copy.deepcopy(masks)

    # `src` is read from the untouched input; `dst` is the copy being written.
    for src, dst in zip(masks, new_masks):
        for old_value, new_value in mapping:
            dst[src == old_value] = new_value

    return new_masks

In [40]:
batch_num = 3
# Seed NumPy's global RNG so the np.random.choice sampling used for oversampling is reproducible.
# NOTE(review): ElasticAugment may draw from a different RNG — confirm it is covered by this seed.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

In [41]:
# glob() returns files in filesystem order; sort so sample order (and output file names) is reproducible.
mask_names = sorted(glob(f'/raid/rustam/med_data/unbalanced_batches/batch_{batch_num}/original_data/masks/*.bmp'))

In [42]:
masks = []
images = []
for mask_name in mask_names:
    # <root>/masks/<file> -> <root>/images/<file>. Built from path components so a file
    # name that itself contains 'masks' is not rewritten (str.replace hit every occurrence).
    mask_dir, file_name = os.path.split(mask_name)
    image_name = os.path.join(os.path.dirname(mask_dir), 'images', file_name)

    mask = cv2.imread(mask_name, cv2.IMREAD_GRAYSCALE)
    image = cv2.imread(image_name)
    # cv2.imread returns None instead of raising when a file is missing or unreadable.
    if mask is None or image is None:
        raise FileNotFoundError(f'could not read mask/image pair: {mask_name!r} / {image_name!r}')

    # Nearest-neighbour interpolation keeps mask pixels as existing class ids.
    masks.append(cv2.resize(mask, (1024, 1024), interpolation=cv2.INTER_NEAREST))
    images.append(cv2.resize(image, (1024, 1024)))

In [43]:
# Keep ids 0..3, fold source class 4 into class 0, and shift source ids 5..9 down by one,
# leaving 9 classes (0..8).
# NOTE: not idempotent — re-running this cell on already-remapped masks remaps them again.
masks = mutate_masks(
    masks,
    [(0, 0), (1, 1), (2, 2), (3, 3), (4, 0)] + [(src, src - 1) for src in range(5, 10)],
)

In [44]:
# Pixel count per class id over all (remapped) masks.
uniq, counts = np.unique(np.asarray(masks), return_counts=True)
(uniq, counts)

(array([0, 1, 2, 3, 4, 5, 6, 7, 8], dtype=uint8),
 array([86893907,  2978643,  1104453, 10559481,   917278,   224188,
           16700,    42817,  1071557]))

In [45]:
# Recompute rather than reuse `uniq` (a later cell reassigns it), and require exactly ids 0..8
# rather than just "9 distinct values".
assert set(np.unique(masks).tolist()) == set(range(9)), f'unexpected class ids: {np.unique(masks)}'

In [46]:
# Indices of masks that contain class 5, 6 and 7 respectively (filled by the next cell).
has_5_class_indicies, has_6_class_indicies, has_7_class_indicies = [], [], []

In [47]:
# Rebuild the lists from scratch so re-running this cell does not append duplicate indices.
has_5_class_indicies, has_6_class_indicies, has_7_class_indicies = [], [], []
for i, mask in enumerate(masks):
    # Local name so the notebook-level `uniq` from the class-count cell is not overwritten.
    present = np.unique(mask)
    if 5 in present:
        has_5_class_indicies.append(i)
    if 6 in present:
        has_6_class_indicies.append(i)
    if 7 in present:
        has_7_class_indicies.append(i)

In [48]:
# How many masks contain class 5, 6 and 7.
tuple(map(len, (has_5_class_indicies, has_6_class_indicies, has_7_class_indicies)))

(36, 6, 9)

In [49]:
def classes_vector(masks, num_classes=9):
    """
    Count, for every class id, how many masks contain that id at least once.

    Parameters
    ----------
    masks : iterable of numpy.ndarray
        Integer-valued masks.
    num_classes : int, optional
        Length of the returned vector. Every class id must lie in [0, num_classes).
        Defaults to 9, the number of classes left after the remapping in this notebook.

    Returns
    -------
    numpy.ndarray of float, shape (num_classes,)
        Element `c` is the number of masks in which class `c` occurs.

    Raises
    ------
    IndexError
        If a mask contains an id outside [0, num_classes). Negative ids used to wrap silently.
    """
    distribution = np.zeros((num_classes,))
    for mask in masks:
        present = np.unique(mask)  # sorted, each id once
        if present.size == 0:
            continue
        if present[0] < 0 or present[-1] >= num_classes:
            raise IndexError(f'class ids {present.tolist()} outside [0, {num_classes})')
        # No duplicate indices in `present`, so fancy-index += counts each class exactly once.
        distribution[present] += 1
    return distribution

In [50]:
# Number of masks in which each class id occurs.
distribution = classes_vector(masks=masks)

In [51]:
# Express each class's presence count relative to the presence count of class 0.
distribution = np.true_divide(distribution, distribution[0])

In [52]:
# Show per-class presence counts divided by the presence count of class 0.
distribution

array([1.        , 1.        , 0.96969697, 1.        , 0.90909091,
       0.36363636, 0.06060606, 0.09090909, 0.95959596])

In [53]:
# Union of class ids over every mask that contains class 6.
np.unique(np.asarray(masks)[has_6_class_indicies].ravel())

array([0, 1, 2, 3, 4, 5, 6, 7, 8], dtype=uint8)

In [54]:
def f(target=0.50, stop_at=(0.41, 0.41, 0.31), batch_size=5):
    """
    Oversample rare classes 5, 6 and 7 with elastic augmentation.

    The loop runs in phases. First it augments masks containing class 5 until class 5's
    presence count (divided by class 0's presence count) reaches `target`. It then does
    the same for class 6 and class 7, and repeats the phases while any of the three is
    still below `target`. It returns early as soon as the ratios for classes 5, 6 and 7
    all exceed the corresponding `stop_at` thresholds.

    Reads the notebook globals `images`, `masks`, `has_5_class_indicies`,
    `has_6_class_indicies` and `has_7_class_indicies`. The index lists must index into
    `images`/`masks`.

    Parameters
    ----------
    target : float
        Presence ratio each phase keeps augmenting towards.
    stop_at : tuple of 3 floats
        Early-stop thresholds (strict `>`) for classes 5, 6 and 7.
    batch_size : int
        Number of (image, mask) pairs sampled with replacement per augmentation round.

    Returns
    -------
    (images, masks) : tuple of numpy.ndarray
        The input samples followed by every augmented sample.
    """
    base_images = np.asarray(images)
    base_masks = np.asarray(masks)
    rare_classes = (
        (5, has_5_class_indicies),
        (6, has_6_class_indicies),
        (7, has_7_class_indicies),
    )

    counts = classes_vector(base_masks)
    distribution = counts / counts[0]
    new_images, new_masks = [], []

    def _result():
        # Concatenate once at the end. Concatenating every round copied the whole dataset each time.
        return (np.concatenate([base_images] + new_images),
                np.concatenate([base_masks] + new_masks))

    while any(distribution[class_id] < target for class_id, _ in rare_classes):
        for class_id, indices in rare_classes:
            while distribution[class_id] < target:
                # One shared draw so each augmented image stays paired with its own mask.
                # Images and masks used to be drawn independently, with sizes 5 and 6.
                picked = np.random.choice(indices, size=batch_size)
                batch = Data(base_images[picked], base_masks[picked])
                # Constructed per round as before. NOTE(review): hoisting could reuse the same
                # elastic map(s) if ElasticAugment generates them at construction — unverified.
                augment = ElasticAugment(alpha=700, noise_invert_scale=7, std=11,
                                         border_mode='reflect_101', num_maps=1,
                                         keep_old_data=False)
                aug_images, aug_masks = augment(batch).get_data()
                new_images.append(aug_images)
                new_masks.append(aug_masks)
                # Presence counts are additive over masks, so update them incrementally.
                counts = counts + classes_vector(aug_masks)
                distribution = counts / counts[0]
                print(distribution)
                if all(distribution[c] > t for (c, _), t in zip(rare_classes, stop_at)):
                    return _result()

    # Previously this point returned None, which broke `images, masks = f()`.
    return _result()

images, masks = f()
# `distribution` inside f is local, so recompute here instead of printing a stale global.
distribution = classes_vector(masks)
distribution = distribution / distribution[0]
print(distribution)

[1.         1.         0.97115385 1.         0.91346154 0.39423077
 0.06730769 0.08653846 0.96153846]
[1.         1.         0.97247706 1.         0.91743119 0.42201835
 0.06422018 0.09174312 0.96330275]
[1.         1.         0.97368421 1.         0.92105263 0.44736842
 0.06140351 0.09649123 0.96491228]
[1.         1.         0.96638655 1.         0.92436975 0.47058824
 0.05882353 0.09243697 0.96638655]
[1.         1.         0.96774194 1.         0.91935484 0.49193548
 0.05645161 0.09677419 0.96774194]
[1.         1.         0.96899225 1.         0.92248062 0.51162791
 0.05426357 0.09302326 0.96899225]
[1.         1.         0.96268657 1.         0.92537313 0.50746269
 0.08955224 0.08955224 0.97014925]
[1.         1.         0.96402878 1.         0.92805755 0.50359712
 0.12230216 0.09352518 0.97122302]
[1.         1.         0.96527778 1.         0.93055556 0.48611111
 0.15277778 0.09722222 0.97222222]
[1.         1.         0.96644295 1.         0.93288591 0.47651007
 0.18120805 0.0

In [55]:
# Re-measure class presence on the augmented set, relative to class 0.
distribution = classes_vector(masks)
distribution /= distribution[0]
distribution

array([1.        , 0.9984472 , 0.91925466, 1.        , 0.98136646,
       0.47360248, 0.41149068, 0.37111801, 0.99378882])

In [56]:
# Number of masks (original + augmented) about to be written out.
len(masks)

644

In [57]:
# Output layout: <out_dir>/images/<i>.bmp and <out_dir>/masks/<i>.bmp, paired by index i.
# NOTE: files left over from an earlier run with more samples are not removed.
out_dir = f'/raid/rustam/med_data/upsampling_batches/paper_batch_{batch_num}/train_set/orig_set'
images_dir = os.path.join(out_dir, 'images')
masks_dir = os.path.join(out_dir, 'masks')
os.makedirs(images_dir, exist_ok=True)
os.makedirs(masks_dir, exist_ok=True)

# zip() would silently drop the tail if the two arrays ever got out of step.
assert len(images) == len(masks), f'{len(images)} images vs {len(masks)} masks'

for i, (img, mask) in tqdm(enumerate(zip(images, masks)), total=len(masks)):
    # cv2.imwrite reports failure by returning False rather than raising.
    wrote_img = cv2.imwrite(os.path.join(images_dir, f'{i}.bmp'), img)
    wrote_mask = cv2.imwrite(os.path.join(masks_dir, f'{i}.bmp'), mask)
    if not (wrote_img and wrote_mask):
        raise IOError(f'failed to write sample {i} under {out_dir}')

644it [00:01, 338.07it/s]


SyntaxError: invalid syntax (<ipython-input-19-67b8243dc594>, line 1)