In [6]:
%load_ext autoreload
%autoreload 1

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
import os
import shutil
from multiprocessing import Pool

from tqdm import tqdm

from config import *
from mrf import *
from utilities import *

In [8]:
# Image-set index files: each line holds one sequence name.
train_imageset_path = '../trainval/DAVIS/ImageSets/2017/train.txt'
val_imageset_path = '../trainval/DAVIS/ImageSets/2017/val.txt'
testd_imageset_path = '../testd/DAVIS/ImageSets/2017/test-dev.txt'

# Root directories for frames, annotations, flow, rough masks and results.
train_image_root = '../trainval/DAVIS/JPEGImages/480p/'
train_mask_root = '../trainval/DAVIS/Annotations/480p/'
testd_image_root = '../testd/DAVIS/JPEGImages/480p/'
testd_mask_root = '../testd/DAVIS/Annotations/480p/'
train_flow_root = '../flow/trainval/'
testd_flow_root = '../flow/test/'
rough_annotation_root = '../rough_annotation/osvos/'
result_root = '../result/mrf/'

# Load the sequence names of every split, stripping surrounding whitespace.
with open(train_imageset_path, 'r') as f:
    train_list = [raw_line.strip() for raw_line in f]
with open(val_imageset_path, 'r') as f:
    val_list = [raw_line.strip() for raw_line in f]
with open(testd_imageset_path, 'r') as f:
    testd_list = [raw_line.strip() for raw_line in f]

print(val_list)

['bike-packing', 'blackswan', 'bmx-trees', 'breakdance', 'camel', 'car-roundabout', 'car-shadow', 'cows', 'dance-twirl', 'dog', 'dogs-jump', 'drift-chicane', 'drift-straight', 'goat', 'gold-fish', 'horsejump-high', 'india', 'judo', 'kite-surf', 'lab-coat', 'libby', 'loading', 'mbike-trick', 'motocross-jump', 'paragliding-launch', 'parkour', 'pigs', 'scooter-black', 'shooting', 'soapbox']


In [9]:
def compute_energy(task):
    """Evaluate ``energy`` for a single work item.

    Takes one tuple argument so it can be handed directly to ``Pool.map``.

    Args:
        task: A 4-tuple ``(t, x, y, k)``; the four values are forwarded
            unchanged, in that order, to ``energy``.

    Returns:
        Whatever ``energy`` returns for this item.

    Note:
        ``mask`` and ``osvos_mask`` are read from module-level globals, so
        they must be defined before this function is called.
    """
    t_idx, x_idx, y_idx, label = task
    return energy(mask, osvos_mask, t_idx, x_idx, y_idx, label)

In [10]:
# Refine the rough per-frame masks of one validation sequence with ICM and
# write the resulting colour masks to result_root/<sequence>/.
#
# NOTE: `mask` and `osvos_mask` must remain module-level names, because
# `compute_energy` reads them as globals inside the worker processes.

# Debug selection: only val_list[selected_sequence_idx] is processed, and
# only its first `max_frames` frames. With the val list recorded in this
# notebook's output, index 15 is 'horsejump-high'.
selected_sequence_idx = 15
max_frames = 10

for p in range(len(val_list)):
    if p != selected_sequence_idx:
        continue
    image_path = os.path.join(train_image_root, val_list[p])
    mask_path = os.path.join(train_mask_root, val_list[p] + '/00000.png')
    osvos_path = os.path.join(rough_annotation_root, val_list[p])
    result_path = os.path.join(result_root, val_list[p])
    flow_path = os.path.join(train_flow_root, val_list[p])  # not used in this cell

    image_list = sorted(os.listdir(image_path))
    image_list = image_list[:max_frames]
    frame_cnt = len(image_list)

    # cv2.imread returns None instead of raising when a file cannot be read.
    mask = cv2.imread(mask_path)
    if mask is None:
        raise FileNotFoundError(f"cannot read first-frame annotation: {mask_path}")

    # Replicate the first-frame annotation once per frame; the colour-mask
    # array also serves as the shape/dtype template for the frame buffers.
    mask = np.expand_dims(mask, axis=0)
    mask = np.tile(mask, (frame_cnt, 1, 1, 1))
    imgs = np.zeros_like(mask)
    osvos_mask = np.zeros_like(mask)
    mask, color_to_gray_map, gray_to_color_map = convert_to_gray_mask(mask)

    type_cnt = len(color_to_gray_map)
    print('type_cnt:', type_cnt)
    # TODO: type_cnt > 2

    # Rough masks are looked up by frame index, not by the names in image_list.
    for i in range(frame_cnt):
        rough_mask_path = os.path.join(osvos_path, f"{i:05d}.png")
        rough_mask = cv2.imread(rough_mask_path)
        if rough_mask is None:
            raise FileNotFoundError(f"cannot read rough annotation: {rough_mask_path}")
        osvos_mask[i] = rough_mask

    osvos_mask, _, _ = convert_to_gray_mask(osvos_mask)
    # Frame 0 has a ground-truth annotation, so it replaces the rough mask there.
    osvos_mask[0] = mask[0]

    # Downscale both label volumes to the working resolution (nearest
    # neighbour, so no new label values are introduced).
    resized_mask = np.zeros((mask.shape[0], Resize[1], Resize[0]))
    resized_osvos_mask = np.zeros((mask.shape[0], Resize[1], Resize[0]))

    for i in range(mask.shape[0]):
        resized_mask[i] = cv2.resize(mask[i], Resize, interpolation=cv2.INTER_NEAREST)
        resized_osvos_mask[i] = cv2.resize(osvos_mask[i], Resize, interpolation=cv2.INTER_NEAREST)

    mask = resized_mask
    osvos_mask = resized_osvos_mask
    # One 2-component flow vector per pixel per frame, same dtype as `mask`.
    flo = np.zeros(mask.shape + (2,), dtype=mask.dtype)
    del resized_mask, resized_osvos_mask

    # Load every frame once and convert it from OpenCV's BGR order to RGB.
    resized_imgs = np.zeros((mask.shape[0], Resize[1], Resize[0], 3), dtype=np.uint8)
    for i in range(frame_cnt):
        frame_path = os.path.join(image_path, f"{i:05d}.jpg")
        frame = cv2.imread(frame_path)
        if frame is None:
            raise FileNotFoundError(f"cannot read frame: {frame_path}")
        imgs[i] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    # Shape of a single full-resolution frame.
    print(imgs.shape[1:])

    for i in range(frame_cnt):
        resized_imgs[i] = cv2.resize(imgs[i], Resize, interpolation=cv2.INTER_NEAREST)

    imgs = resized_imgs
    del resized_imgs
    gray_imgs = np.zeros((mask.shape[0], Resize[1], Resize[0]), dtype=np.uint8)

    for i in range(frame_cnt):
        gray_imgs[i] = cv2.cvtColor(imgs[i], cv2.COLOR_RGB2GRAY)

    # Dense flow between consecutive frames; the last frame has no successor,
    # so its flow stays zero.
    for i in range(frame_cnt - 1):
        flo[i] = cv2.calcOpticalFlowFarneback(gray_imgs[i], gray_imgs[i + 1], None, 0.5, 3, 15, 3, 5, 1.2, 0)

    # Flow-magnitude statistics before and after filtering.
    norm = np.linalg.norm(flo, axis=-1)
    print(norm.max(), norm.mean(), norm.var())
    flo = filter_unreliable_flow(flo)
    norm = np.linalg.norm(flo, axis=-1)
    print(norm.max(), norm.mean(), norm.var())

    init(flo, mask)

    print('Start ICM...')

    # The work list depends only on the volume shape and the label count,
    # neither of which changes between iterations, so build it once.
    tasks = [(t, x, y, k) for t in range(frame_cnt) for x in range(mask.shape[1]) for y in
             range(mask.shape[2]) for k in range(type_cnt)]
    energy_shape = (frame_cnt, mask.shape[1], mask.shape[2], type_cnt)

    for icm_step in tqdm(range(ICM_iter)):
        # A fresh pool per iteration: the workers read `mask` as a global, so
        # they must be started after `mask` was updated (presumably relies on
        # a fork-based start method -- TODO confirm on non-Linux platforms).
        with Pool() as pool:
            e_results = np.array(pool.map(compute_energy, tasks)).reshape(energy_shape)
        # Each pixel takes the label with the lowest energy.
        mask = np.argmin(e_results, axis=-1)

    print('ICM done')

    # write
    if os.path.exists(result_path):
        shutil.rmtree(result_path)
    os.makedirs(result_path)

    resized_mask = np.zeros((mask.shape[0], OriginalSize[1], OriginalSize[0]))
    for i in range(frame_cnt):
        # np.argmin yields a platform-integer array, which some OpenCV builds
        # reject in cv2.resize; cast to float64 (the dtype of the destination
        # buffer) first. Label values are unchanged by the cast.
        resized_mask[i] = cv2.resize(mask[i].astype(np.float64), OriginalSize, interpolation=cv2.INTER_NEAREST)

    for i in range(frame_cnt):
        result_i_path = result_path + f"/{i:05d}.png"
        cv2.imwrite(result_i_path, restore_color_mask(resized_mask[i], gray_to_color_map))

    del resized_mask

    break


type_cnt: 3
(480, 854, 3)
9.75387896024259 2.1221505112863 1.7667939514830064
5.397999285277433 1.8829396540998464 1.1169505493136194
Start ICM...


100%|██████████| 5/5 [00:11<00:00,  2.24s/it]


ICM done
