In [2]:
import os
import time
import math
import cv2 as cv
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import mode
from sklearn.cluster import KMeans
from tqdm import tqdm
import seaborn as sns
import pandas as pd
import torch
import shelve
from carvekit.api.high import HiInterface


from PIL import Image

from carvekit.api.interface import Interface
from carvekit.ml.wrap.fba_matting import FBAMatting
from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
from carvekit.ml.wrap.u2net import U2NET
from carvekit.pipelines.postprocessing import MattingMethod
from carvekit.pipelines.preprocessing import PreprocessingStub
from carvekit.trimap.generator import TrimapGenerator

In [3]:
# manual
# Check doc strings for more information
object_type = "object"
_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Segmentation backbone: U2NET when object_type is 'hair_like', TracerUniversalB7 otherwise.
# NOTE(review): fp16=True is requested even when _device is 'cpu' — confirm carvekit accepts that.
if object_type != 'hair_like':
    seg_net = TracerUniversalB7(device=_device, batch_size=5, input_image_size=640, fp16=True)
else:
    seg_net = U2NET(device=_device, batch_size=5, input_image_size=320, fp16=True)

# Matting module and trimap generator, combined below into the post-processing stage.
fba = FBAMatting(device=_device, input_tensor_size=2048, batch_size=1)
trimap = TrimapGenerator(prob_threshold=231, kernel_size=30, erosion_iters=5)

# Assemble the full pipeline: (no-op) pre-processing -> segmentation -> matting post-processing.
preprocessing = PreprocessingStub()
postprocessing = MattingMethod(matting_module=fba, trimap_generator=trimap, device=_device)
interface = Interface(pre_pipe=preprocessing, post_pipe=postprocessing, seg_pipe=seg_net)


In [4]:
# input video path
DIR_PATH = 'videos'
FILENAME = 'test1.mp4'
INPUT_VIDEO_PATH = os.path.join(DIR_PATH, FILENAME)
OUTPUT_PATH = 'out'

# Base name of the input video without its extension. os.path.splitext strips only the final
# extension, so a name like 'clip.v2.mp4' keeps 'clip.v2' (str.split('.')[0] would give 'clip').
_video_stem = os.path.splitext(FILENAME)[0]

OUTPUT_FOREGROUND_VIDEO_NAME = _video_stem + '_foreground.mp4'
OUTPUT_BACKGROUND_VIDEO_NAME = _video_stem + '_background.mp4'
OUTPUT_PANORAMA_IMG_NAME = _video_stem + '_panorama.png'
HOLE_FILLED_BACKGROUND_VIDEO_NAME = _video_stem + '_filled_background_video.mp4'

OUTPUT_FOREGROUND_VIDEO_PATH = os.path.join(OUTPUT_PATH, OUTPUT_FOREGROUND_VIDEO_NAME)
OUTPUT_BACKGROUND_VIDEO_PATH = os.path.join(OUTPUT_PATH, OUTPUT_BACKGROUND_VIDEO_NAME)
OUTPUT_PANORAMA_IMG_PATH = os.path.join(OUTPUT_PATH, OUTPUT_PANORAMA_IMG_NAME)
HOLE_FILLED_BACKGROUND_VIDEO_PATH = os.path.join(OUTPUT_PATH, HOLE_FILLED_BACKGROUND_VIDEO_NAME)

# shelve file used to persist homography matrices between runs
H_PERSIST_FILENAME = 'H_persist'

In [5]:
def read_frames(path):
    """
    Read every frame of a video into memory.

    :param path: path of the video file
    :return: list of frames in the order cv.VideoCapture.read returns them (BGR arrays)
    :raises IOError: if the video cannot be opened
    """
    cap = cv.VideoCapture(path)
    if not cap.isOpened():
        raise IOError("Open video failed!")

    frames = []
    try:
        while True:
            ret, frame = cap.read()
            if not ret or frame is None:
                break
            frames.append(frame)
    finally:
        # release the capture even if reading is interrupted or raises
        cap.release()
    return frames


# Probe the input video: frame rate, frame size and its first two frames.
cap = cv.VideoCapture(INPUT_VIDEO_PATH)
if not cap.isOpened():
    raise IOError("Open video failed!")

try:
    fps = int(cap.get(cv.CAP_PROP_FPS))
    width = int(cap.get(cv.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv.CAP_PROP_FRAME_HEIGHT))

    # Fail early: later cells index into img1/img2, which would be None if the read failed.
    ret, img1 = cap.read()
    if not ret:
        raise IOError("Read frame failed!")
    ret, img2 = cap.read()
    if not ret:
        raise IOError("Read frame failed!")
finally:
    cap.release()

In [11]:
gray2 = cv.cvtColor(img2, cv.COLOR_BGR2GRAY)
# Boolean mask, same shape as gray2: True where the grayscale value of the second frame is non-zero.
indice = np.greater(gray2, 0)
print(indice)

[[ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]
 ...
 [ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]]


In [12]:
# Overlay img2 onto img1 (in place) wherever `indice` is set, then show both frames until a key press.
np.copyto(img1, img2, where=indice[..., np.newaxis])
for window_name, image in (('asdf', img1), ('qwer', img2)):
    cv.imshow(window_name, image)
cv.waitKey(0)
cv.destroyAllWindows()

In [7]:
def get_foreground_labels(path, early_quit_frame_number=math.inf, early_quit_time=None):
    """
    Segment the frames of a video and record which pixels are foreground.

    Each frame is converted BGR -> RGB and passed through the module-level carvekit `interface`.
    A pixel counts as foreground when the last channel of the result is non-zero.

    :param path: video file path
    :param early_quit_frame_number: process at most this many frames
    :param early_quit_time: stop after roughly this many seconds (falsy means no limit)
    :return: list with one set of (row, col) foreground pixel coordinates per processed frame;
             partial if interrupted with Ctrl-C
    :raises IOError: if the video cannot be opened or a frame cannot be read
    """
    cap = cv.VideoCapture(path)
    # isOpened must be *called*; the bound method object itself is always truthy.
    if not cap.isOpened():
        raise IOError("Open video failed!")

    foreground_labels = []
    startTime = time.time()
    try:
        n_frames = int(min(cap.get(cv.CAP_PROP_FRAME_COUNT), early_quit_frame_number))
        for _ in tqdm(range(n_frames)):
            if early_quit_time and time.time() - startTime > early_quit_time:
                break

            ret, frame = cap.read()
            if not ret:
                raise IOError("Read frame failed!")

            frame = cv.cvtColor(frame, cv.COLOR_BGR2RGB)
            image = Image.fromarray(frame, mode="RGB")
            resImg = interface([image])[0]
            # Last channel of the segmentation result; presumably the alpha channel of an
            # RGBA image — TODO confirm against carvekit's Interface output mode.
            mask = np.asarray(resImg)[..., -1]
            rows, cols = np.nonzero(mask)
            # plain Python ints, matching the (row, col) tuples the old per-pixel loop produced
            foreground_labels.append(set(zip(rows.tolist(), cols.tolist())))

    except KeyboardInterrupt:
        print('Interrupted!')
    finally:
        cap.release()
        cv.destroyAllWindows()
    return foreground_labels

In [8]:
# # compute foreground labels
# foreground_labels = get_foreground_labels(INPUT_VIDEO_PATH)

In [9]:
# # save foreground labels
# foreground_label_persistence_file = '{}_foreground_labels.npy'.format(FILENAME.split('.')[0])
# np.save(foreground_label_persistence_file, [list(label) for label in foreground_labels])

In [10]:
# load foreground_labels
# foreground_label_persistence_file = '{}_foreground_labels.npy'.format(FILENAME.split('.')[0])
# foreground_labels = np.load(foreground_label_persistence_file, allow_pickle=True)
# # convert foreground_labels to list of set
# foreground_labels = [set(_) for _ in foreground_labels]

In [11]:
def generate_foreground_background_videos(foreground_labels):
    """
    Split the input video into a foreground-only video and a background-only video.

    For each frame i, pixels listed in foreground_labels[i] are copied into an otherwise white
    foreground frame and painted white in the background frame.

    :param foreground_labels: sequence with one collection of (row, col) pixel coordinates per
        frame, e.g. as returned by get_foreground_labels
    :raises IOError: if the video cannot be opened, a writer cannot be created, or the video has
        fewer frames than len(foreground_labels)
    """
    os.makedirs(OUTPUT_PATH, exist_ok=True)

    cap = cv.VideoCapture(INPUT_VIDEO_PATH)
    # isOpened must be *called*; the bound method object itself is always truthy.
    if not cap.isOpened():
        raise IOError("Open video failed!")

    fps = int(cap.get(cv.CAP_PROP_FPS))
    frame_size = (int(cap.get(cv.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv.CAP_PROP_FRAME_HEIGHT)))
    # NOTE(review): XVID in an .mp4 container depends on the OpenCV video backend accepting the
    # pairing; 'mp4v' is the usual MP4 fourcc — verify the output plays back.
    fourcc = cv.VideoWriter_fourcc(*'XVID')
    foreground_out = cv.VideoWriter(OUTPUT_FOREGROUND_VIDEO_PATH, fourcc, fps, frame_size)
    background_out = cv.VideoWriter(OUTPUT_BACKGROUND_VIDEO_PATH, fourcc, fps, frame_size)

    try:
        if not foreground_out.isOpened() or not background_out.isOpened():
            raise IOError("Init videoWriter failed!")

        for i in tqdm(range(len(foreground_labels))):
            ret, frame = cap.read()
            if not ret:
                raise IOError("Read frame failed!")

            # (N, 2) array of (row, col); reshape keeps an empty label as shape (0, 2)
            label = np.array(list(foreground_labels[i]), dtype=np.intp).reshape(-1, 2)
            rows, cols = label[:, 0], label[:, 1]

            foreground = np.full_like(frame, 255)
            foreground[rows, cols] = frame[rows, cols]
            frame[rows, cols] = 255  # broadcasts to every channel -> white

            foreground_out.write(foreground)
            background_out.write(frame)

    except KeyboardInterrupt:
        print('Interrupted!')
    finally:
        cap.release()
        foreground_out.release()
        background_out.release()
        cv.destroyAllWindows()

In [12]:
# generate_foreground_background_videos(foreground_labels)

In [13]:
def get_H_matrix(img1, img2, min_match_count=10, ratio=0.7, ransac_reproj_threshold=5.0):
    """
    Estimate the homography mapping pixel coordinates of img1 onto img2.

    SIFT keypoints of both images (converted with COLOR_BGR2GRAY) are matched with a FLANN k-d tree
    (k=2). Matches are filtered with Lowe's ratio test and H is fitted with RANSAC.

    :param img1: source image (BGR)
    :param img2: destination image (BGR)
    :param min_match_count: minimum number of ratio-test survivors required
    :param ratio: Lowe's ratio-test threshold
    :param ransac_reproj_threshold: RANSAC reprojection threshold in pixels
    :return: 3x3 homography H such that img2_point ~ H @ img1_point (x, y, 1)
    :raises ValueError: if an image yields no SIFT descriptors, too few good matches survive,
        or RANSAC cannot fit a homography
    """
    # initiate feature detector, currently use SIFT, may try orb later
    sift = cv.SIFT_create()

    # find the keypoints and descriptors with SIFT
    kp1, des1 = sift.detectAndCompute(cv.cvtColor(img1, cv.COLOR_BGR2GRAY), None)
    kp2, des2 = sift.detectAndCompute(cv.cvtColor(img2, cv.COLOR_BGR2GRAY), None)
    # detectAndCompute returns None descriptors for featureless images; FLANN would fail cryptically.
    if des1 is None or des2 is None:
        raise ValueError("No SIFT features found in one of the images")

    FLANN_INDEX_KDTREE = 1
    index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
    search_params = dict(checks=50)
    flann = cv.FlannBasedMatcher(index_params, search_params)
    matches = flann.knnMatch(des1, des2, k=2)

    # Lowe's ratio test. knnMatch can return fewer than 2 neighbours for a query, so
    # unpacking each entry as (m, n) unconditionally could raise.
    good = [pair[0] for pair in matches
            if len(pair) == 2 and pair[0].distance < ratio * pair[1].distance]

    if len(good) < min_match_count:
        raise ValueError("Not enough matches are found - {}/{}".format(len(good), min_match_count))

    src_pts = np.float32([kp1[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
    dst_pts = np.float32([kp2[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)
    H, _ = cv.findHomography(src_pts, dst_pts, cv.RANSAC, ransac_reproj_threshold)
    # findHomography yields None when no model can be estimated
    if H is None:
        raise ValueError("Homography estimation failed")
    return H

In [14]:
# compute all H per gap
# Hs = get_all_H_matrix_per_step(INPUT_VIDEO_PATH, gap)

In [15]:
# save Hs
# np.save('{}_all_H.npy'.format(FILENAME.split('.')[0]), Hs)

In [16]:
# load Hs
# Hs = np.load('{}_all_H.npy'.format(FILENAME.split('.')[0]))

In [17]:
def fill_holes_with_H_cached(step, grap, max_iteration_depth, persist_file, start_index = 0, end_index = 10000000):
    """
    Run fill_holes with its homography cache persisted in a shelve file.

    The cache is the dict stored under the FILENAME key of the shelf, created empty on first use.
    The shelf is opened with writeback=True, so entries fill_holes adds to that dict are written
    back when the shelf closes.

    :param step: forwarded to fill_holes
    :param grap: forwarded to fill_holes as its `gap` argument
    :param max_iteration_depth: forwarded to fill_holes
    :param persist_file: shelve filename
    :param start_index: forwarded to fill_holes
    :param end_index: forwarded to fill_holes
    :return: whatever fill_holes returns (the list of hole-filled frames)
    """
    with shelve.open(filename=persist_file, writeback=True) as store:
        h_cache = store.setdefault(FILENAME, {})
        return fill_holes(step, grap, max_iteration_depth, h_cache, start_index, end_index)


def _fill_holes_from_target(frame, hole, target_frame, target_label, H):
    """
    Copy donor pixels into `frame` for every hole pixel whose projection through H is usable.

    :param frame: image being repaired; modified in place
    :param hole: collection of (row, col) pixels of `frame` still to fill
    :param target_frame: donor image
    :param target_label: collection of (row, col) pixels of target_frame that are holes themselves
    :param H: 3x3 homography mapping `frame` (x, y) coordinates to target_frame (x, y) coordinates
    :return: set of (row, col) pixels of `frame` that were filled
    """
    if not hole:
        return set()
    target_h, target_w = target_frame.shape[:2]

    hole_rc = np.array(list(hole), dtype=np.intp).reshape(-1, 2)
    # perspectiveTransform needs a (1, N, 2) float array of (x, y) = (col, row) points
    src_pts = hole_rc[:, ::-1].astype(np.float32)[np.newaxis]
    dst_xy = np.rint(cv.perspectiveTransform(src_pts, H)[0]).astype(int)
    dst_x, dst_y = dst_xy[:, 0], dst_xy[:, 1]
    in_bounds = (0 <= dst_x) & (dst_x < target_w) & (0 <= dst_y) & (dst_y < target_h)

    # Mark the donor's own holes so projections landing on them are rejected.
    target_hole = np.zeros((target_h, target_w), dtype=bool)
    target_rc = np.array(list(target_label), dtype=np.intp).reshape(-1, 2)
    inside = ((0 <= target_rc[:, 0]) & (target_rc[:, 0] < target_h)
              & (0 <= target_rc[:, 1]) & (target_rc[:, 1] < target_w))
    target_hole[target_rc[inside, 0], target_rc[inside, 1]] = True

    fillable = in_bounds.copy()
    fillable[in_bounds] = ~target_hole[dst_y[in_bounds], dst_x[in_bounds]]

    rows, cols = hole_rc[fillable, 0], hole_rc[fillable, 1]
    frame[rows, cols] = target_frame[dst_y[fillable], dst_x[fillable]]
    return set(zip(rows.tolist(), cols.tolist()))


# fill per step frame's hole using adjacent frames per gap, with max_iteration_depth
def fill_holes(step, gap, max_iteration_depth, m = None, start_index = 0, end_index = 10000000):
    """
    Fill the foreground holes of background frames with pixels warped from neighbouring frames.

    Frames are read from OUTPUT_BACKGROUND_VIDEO_PATH. Their holes come from the module-level
    `foreground_labels`, one collection of (row, col) pixels per frame.
    For each processed frame i, donor frames are tried in the order i + gap, i - gap, i + 2*gap,
    i - 2*gap, ..., skipping indices outside the video. A hole pixel is copied over when its
    projection through get_H_matrix(frame_i, donor) lands inside the donor and outside the
    donor's own hole.

    :param step: controls the distance between each two frames for which we are filling their holes
    :param gap: is the distance unit between target frame and current frame when filling hole
    :param max_iteration_depth: we use at most max_iteration_depth donor frames for one frame, even when
        we haven't filled the hole completely. This bounds the time complexity of the method
    :param m: optional dictionary persisting H matrices under key (i, donor_index); updated in place
    :param start_index: first frame index to process
    :param end_index: processing stops before this index (clamped to the number of frames)
    :return: a list of hole_filled frames, one per processed index
    :raises IndexError: if the frame count differs from len(foreground_labels) or is smaller
        than 1 + max_iteration_depth * gap
    """
    frames = read_frames(OUTPUT_BACKGROUND_VIDEO_PATH)
    res = []
    all_filled_count = 0
    count = 0
    end_index = min(len(frames), end_index)

    if len(frames) != len(foreground_labels):
        raise IndexError("foreground_labels length not equal to frames!")

    if len(frames) < 1 + max_iteration_depth * gap:
        raise IndexError("video number of frames too small!")

    for i in tqdm(range(start_index, end_index, step)):
        count += 1
        cur_label = set(foreground_labels[i])
        iteration_depth = 0
        left = 1
        right = 1
        # False: the next attempt goes right (i + right*gap); True: it goes left (i - left*gap)
        last_direction = False
        while cur_label and iteration_depth < max_iteration_depth:
            right_index = i + right * gap
            left_index = i - left * gap
            right_ok = 0 <= right_index < len(frames)
            left_ok = 0 <= left_index < len(frames)
            if not right_ok and not left_ok:
                # No donor frame is left on either side; the counters only advance on in-range
                # attempts, so continuing would loop forever.
                break

            target_index, target_ok = (left_index, left_ok) if last_direction else (right_index, right_ok)
            if target_ok:
                if m is None:
                    H = get_H_matrix(frames[i], frames[target_index])
                elif (i, target_index) in m:
                    H = m[(i, target_index)]
                else:
                    H = get_H_matrix(frames[i], frames[target_index])
                    m[(i, target_index)] = H

                cur_label -= _fill_holes_from_target(
                    frames[i], cur_label, frames[target_index], foreground_labels[target_index], H)
                iteration_depth += 1

                if last_direction:
                    left += 1
                else:
                    right += 1

            last_direction = not last_direction

        res.append(frames[i])
        all_filled_count += 1 if len(cur_label) == 0 else 0

    print('total number of frames filled: {}'.format(count))
    print('all_filled frames count: {}'.format(all_filled_count))
    print('{} frames not fully filled'.format(count - all_filled_count))
    return res


# filled_frames = fill_holes_with_H_cached(60, 6, 6, persist_file=H_PERSIST_FILENAME)

In [18]:
# fill hole for all frames in range(start_index, end_index, step); end_index frame is not included
def generate_hole_filled_background_images_between_indices(folder_path, start_index, end_index,
                                                           step=1, gap=6, max_iteration_depth=30):
    """
    Hole-fill background frames in [start_index, end_index) and write each one as a JPEG.

    Files are named '<video stem>_background_<frame index>.jpg' inside folder_path.

    :param folder_path: output directory, created (with parents) if missing
    :param start_index: first frame index
    :param end_index: exclusive end frame index
    :param step: frame stride, forwarded to fill_holes_with_H_cached
    :param gap: donor-frame distance unit, forwarded to fill_holes_with_H_cached
    :param max_iteration_depth: max donor frames per frame, forwarded to fill_holes_with_H_cached
    :raises IOError: if an image cannot be written
    """
    os.makedirs(folder_path, exist_ok=True)
    filled_frames = fill_holes_with_H_cached(step, gap, max_iteration_depth, persist_file=H_PERSIST_FILENAME,
                                             start_index=start_index, end_index=end_index)
    stem = os.path.splitext(FILENAME)[0]
    for k, frame in enumerate(filled_frames):
        # the k-th returned frame is frame start_index + k * step of the video
        out_path = os.path.join(folder_path, '{}_background_{}.jpg'.format(stem, start_index + k * step))
        if not cv.imwrite(out_path, frame):
            raise IOError("Write image failed: {}".format(out_path))

# generate_hole_filled_background_images_between_indices('out_images', 0, 10)

In [19]:
# print(os.path.join('out_images', '{}_background_{}.jpg'.format(FILENAME.split('.')[0], 1)))

In [20]:
# print(len(filled_frames))
# for f in filled_frames:
#     cv.imshow('asdf', f)
#     cv.waitKey(0)
#     cv.destroyAllWindows()
#     cv.waitKey(1)

In [21]:
# H = get_H_matrix(filled_frames[1], filled_frames[0])
# cv.imshow('src', filled_frames[1])
# cv.imshow('dst', filled_frames[0])
# cv.waitKey(0)
# cv.destroyAllWindows()
# warped = cv.warpPerspective(filled_frames[1], H, (width, height))
# cv.imshow('src warped', warped)
# cv.waitKey(0)
# cv.destroyAllWindows()

In [22]:
def warpPerspectivePadded(
        src, dst, M,
        flags=cv.INTER_LINEAR,
        borderMode=cv.BORDER_CONSTANT,
        borderValue=0):
    """
    Warp src by homography M onto a padded copy of dst so that no part of the warped src is cropped.

    dst is padded on the top/left only when the warped src reaches negative coordinates, and on
    the bottom/right only when it extends past dst. src is warped onto a canvas of exactly the
    padded size, with the same translation applied.

    :param src: image to warp
    :param dst: image whose pixel grid is M's output coordinate frame
    :param M: 3x3 homography src -> dst (or dst -> src when flags includes cv.WARP_INVERSE_MAP)
    :param flags: cv.warpPerspective interpolation flags, optionally OR-ed with cv.WARP_INVERSE_MAP
    :param borderMode: border mode for both padding and warping
    :param borderValue: border value for both padding and warping
    :return: (dst_padded, src_warped, padding), where padding = [top, bottom, left, right]
        added around dst, as Python ints
    :raises ValueError: if M must be inverted but is singular
    """
    assert M.shape == (3, 3), \
        'Perspective transformation shape should be (3, 3).\n' \
        + 'Use warpAffinePadded() for (2, 3) affine transformations.'

    M = M / M[2, 2]  # ensure a legal homography
    # WARP_INVERSE_MAP is a bit flag that may be OR-ed with any interpolation mode. When it is
    # set, M maps dst -> src: invert it and clear the bit so the bounds below are computed src -> dst.
    if flags & cv.WARP_INVERSE_MAP:
        retval, M = cv.invert(M)
        if retval == 0:
            raise ValueError('Perspective transformation is singular and cannot be inverted.')
        flags &= ~cv.WARP_INVERSE_MAP

    # it is enough to find where the corners of the image go to find
    # the padding bounds; points in clockwise order from origin
    src_h, src_w = src.shape[:2]
    lin_homg_pts = np.array([
        [0, src_w, src_w, 0],
        [0, 0, src_h, src_h],
        [1, 1, 1, 1]])

    # transform points
    transf_lin_homg_pts = M.dot(lin_homg_pts)
    transf_lin_homg_pts /= transf_lin_homg_pts[2, :]

    # find min and max points
    min_x = np.floor(np.min(transf_lin_homg_pts[0])).astype(int)
    min_y = np.floor(np.min(transf_lin_homg_pts[1])).astype(int)
    max_x = np.ceil(np.max(transf_lin_homg_pts[0])).astype(int)
    max_y = np.ceil(np.max(transf_lin_homg_pts[1])).astype(int)

    # add translation to the transformation matrix to shift to positive values
    anchor_x, anchor_y = 0, 0
    transl_transf = np.eye(3, 3)
    if min_x < 0:
        anchor_x = -min_x
        transl_transf[0, 2] += anchor_x
    if min_y < 0:
        anchor_y = -min_y
        transl_transf[1, 2] += anchor_y
    shifted_transf = transl_transf.dot(M)
    shifted_transf /= shifted_transf[2, 2]

    # create padded destination image
    dst_h, dst_w = dst.shape[:2]

    padding = [int(anchor_y), int(max(max_y, dst_h) - dst_h),
               int(anchor_x), int(max(max_x, dst_w) - dst_w)]

    dst_padded = cv.copyMakeBorder(dst, *padding,
                                   borderType=borderMode, value=borderValue)

    dst_pad_h, dst_pad_w = dst_padded.shape[:2]
    src_warped = cv.warpPerspective(
        src, shifted_transf, (dst_pad_w, dst_pad_h),
        flags=flags, borderMode=borderMode, borderValue=borderValue)

    return dst_padded, src_warped, padding

In [23]:
def stitch(dst, src, H, flags=cv.INTER_LINEAR,
        borderMode=cv.BORDER_CONSTANT,
        borderValue=0):
    """
    Warp src by H onto a canvas large enough for both images, then paste dst's non-black pixels on top.

    :param dst: reference image; its pixel grid is H's output coordinate frame
    :param src: image to warp
    :param H: 3x3 homography mapping src (x, y) pixel coordinates to dst (x, y) pixel coordinates
    :param flags: interpolation flags forwarded to cv.warpPerspective
    :param borderMode: border mode forwarded to cv.warpPerspective
    :param borderValue: border value forwarded to cv.warpPerspective
    :return: (stitched_image, shifted_transf); shifted_transf is the normalised translation @ H
        actually used to place src on the canvas
    """
    assert H.shape == (3, 3)
    H = H / H[2, 2]  # ensure a legal homography

    # Corners of src as homogeneous column vectors: x row, y row, 1 row.
    src_h, src_w = src.shape[:2]
    corners = np.array([
        [0, src_w, src_w, 0],
        [0, 0, src_h, src_h],
        [1, 1, 1, 1]])
    warped_corners = H.dot(corners)
    warped_corners /= warped_corners[2, :]

    # bounding box of the warped src
    min_x = np.floor(np.min(warped_corners[0])).astype(int)
    min_y = np.floor(np.min(warped_corners[1])).astype(int)
    max_x = np.ceil(np.max(warped_corners[0])).astype(int)
    max_y = np.ceil(np.max(warped_corners[1])).astype(int)

    # add translation to the transformation matrix to shift to positive values
    anchor_x = -min_x if min_x < 0 else 0
    anchor_y = -min_y if min_y < 0 else 0
    transl_transf = np.eye(3, 3)
    transl_transf[0, 2] += anchor_x
    transl_transf[1, 2] += anchor_y
    shifted_transf = transl_transf.dot(H)
    shifted_transf /= shifted_transf[2, 2]

    # padding = [top, bottom, left, right] around dst on the output canvas
    dst_h, dst_w = dst.shape[:2]
    padding = [anchor_y, max(max_y, dst_h) - dst_h,
               anchor_x, max(max_x, dst_w) - dst_w]

    stitched_image = cv.warpPerspective(
        src, shifted_transf, (dst_w + padding[2] + padding[3], dst_h + padding[0] + padding[1]),
        flags=flags, borderMode=borderMode, borderValue=borderValue)

    # Paste every dst pixel with at least one non-zero channel (vectorised; also handles
    # single-channel images, where a per-pixel any() on a scalar would raise TypeError).
    covered = dst != 0
    if covered.ndim == 3:
        covered = covered.any(axis=2)
    region = stitched_image[padding[0]:padding[0] + dst_h, padding[2]:padding[2] + dst_w]
    region[covered] = dst[covered]  # region is a view, so this writes into stitched_image
    return stitched_image, shifted_transf

In [24]:
# need fix, issue introduced by my change, original was right
# NOTE(review): known-broken per the note above. Concrete problems visible in this body:
#   1. When no side grows (new_padding == old_padding), dst_padded stays the output of
#      warpPerspectivePadded, i.e. dst padded by tmp_padding. The copy loop, however, offsets
#      indices by new_padding - tmp_padding, which can misplace pixels or index past the
#      array (IndexError) when old_padding exceeds tmp_padding.
#   2. When a side grows, dst is re-padded by max(new - old, 0). The copy loop's offset
#      new - tmp equals that extra padding minus tmp only when old_padding is zero on that side.
#   3. `dst_padded.shape >= src_warped.shape` compares tuples lexicographically, not per axis.
# The correct fix depends on whether H maps src into dst's padded pixel grid or into an
# unpadded reference frame — TODO confirm against the caller before changing the code.
def stitch_two_images(dst, src, old_padding, H):
    """
    Warp src into dst with homography H and paste it into dst's empty (all-zero) pixels.

    :param dst: current panorama
    :param src: image to add
    :param old_padding: [top, bottom, left, right] padding dst already carries
    :param H: 3x3 homography passed as-is to warpPerspectivePadded(src, dst, H)
    :return: (dst_padded, new_padding); new_padding is the element-wise max of old_padding and
        the padding reported by warpPerspectivePadded
    """
    print('old panorama shape:', dst.shape)
    dst_padded, src_warped, tmp_padding = warpPerspectivePadded(src, dst, H)
    new_padding = [max(a,b) for a,b in zip(tmp_padding, old_padding)]
    if any([a > b for a, b in zip(new_padding, old_padding)]):
        dst_padded = cv.copyMakeBorder(dst, *[max(a - b, 0) for a, b in zip(new_padding, old_padding)], borderType=cv.BORDER_CONSTANT, value=0)
    # new_padding is an element-wise max, so it is always >= old_padding
    print('old_padding:', old_padding, 'tmp_padding:', tmp_padding, 'new padding:', new_padding)
    print('src wrapped shape:', src_warped.shape)
    print('new panorama shape:', dst_padded.shape)
    assert dst_padded.shape >= src_warped.shape
    # copy warped-src pixels that are non-black into dst_padded pixels that are still black
    for i in range(len(src_warped)):
        for j in range(len(src_warped[0])):
            if any(src_warped[i][j]) and not any(dst_padded[i + new_padding[0] - tmp_padding[0]][j + new_padding[2] - tmp_padding[2]]):
                dst_padded[i + new_padding[0] - tmp_padding[0]][j + new_padding[2] - tmp_padding[2]] = src_warped[i][j]
    return dst_padded, new_padding

# Warp src into dst's coordinate frame with H and extend dst with the padding needed to hold it.
# Only the newly added border is taken from the warped src (copied verbatim, black pixels included);
# dst's own area is never overwritten. Any padding dst already carries is treated as image content.
def worker(dst, src, H):
    """
    Extend dst with the parts of the warped src that fall outside dst's original extent.

    :param dst: reference image; its pixel grid is H's output coordinate frame
    :param src: image to warp
    :param H: 3x3 homography src -> dst, forwarded to warpPerspectivePadded
    :return: (dst_padded, padding), where padding = [top, bottom, left, right] added around dst
    """
    dst_padded, src_warped, padding = warpPerspectivePadded(src, dst, H)
    assert dst_padded.shape == src_warped.shape and len(padding) == 4
    top, bottom, left, right = padding
    n_rows, n_cols = src_warped.shape[:2]
    # True everywhere except the rectangle occupied by the original dst
    border = np.ones((n_rows, n_cols), dtype=bool)
    border[top:n_rows - bottom, left:n_cols - right] = False
    dst_padded[border] = src_warped[border]
    return dst_padded, padding

# Warp src into dst's coordinate frame with H, padding dst so the warped src fits, then copy
# warped-src pixels only where the padded dst is empty (all channels zero) and the warped src is not.
def tmp_worker(dst, src, H):
    """
    Fill the empty (zero) pixels of a padded dst with the warped src.

    :param dst: reference image; its pixel grid is H's output coordinate frame
    :param src: image to warp
    :param H: 3x3 homography src -> dst, forwarded to warpPerspectivePadded
    :return: (dst_padded, padding), where padding = [top, bottom, left, right] added around dst
    """
    dst_padded, src_warped, padding = warpPerspectivePadded(src, dst, H)
    assert dst_padded.shape == src_warped.shape and len(padding) == 4
    src_has = src_warped != 0
    dst_has = dst_padded != 0
    if src_has.ndim == 3:  # multi-channel: a pixel is non-empty if any channel is non-zero
        src_has = src_has.any(axis=2)
        dst_has = dst_has.any(axis=2)
    fill = src_has & ~dst_has
    dst_padded[fill] = src_warped[fill]
    return dst_padded, padding


In [25]:
# get lower resolution image for faster H matrix computation
def scale_image(img, scale_factor=0.8):
    """
    Resize img by scale_factor in both dimensions.

    :param img: image array, rows first (shape[0] is height, shape[1] is width)
    :param scale_factor: multiplier applied to width and height
    :return: resized image; each side is at least 1 pixel, so tiny inputs don't make cv.resize fail
    """
    new_w = max(1, int(round(img.shape[1] * scale_factor)))
    new_h = max(1, int(round(img.shape[0] * scale_factor)))
    # cv.resize takes dsize as a (width, height) tuple
    return cv.resize(img, (new_w, new_h), interpolation=cv.INTER_LINEAR_EXACT)

def scale_to_1920_1080(img):
    """Resize img to exactly 1920x1080 (width x height), without preserving its aspect ratio."""
    target_size = (1920, 1080)  # cv.resize dsize order is (width, height)
    resized = cv.resize(img, target_size, interpolation=cv.INTER_LINEAR_EXACT)
    return resized

In [26]:
def get_H_matrix_cached_version(i, j, i_frame, j_frame, m):
    """
    Return the homography from frame i to frame j, memoised in mapping m under key (i, j).

    On a cache miss, get_H_matrix(i_frame, j_frame) is computed and stored in m before returning.
    """
    cache_key = (i, j)
    cached = cache_key in m
    H = m[cache_key] if cached else get_H_matrix(i_frame, j_frame)
    if not cached:
        m[cache_key] = H
    return H


# H computed in this method is persisted using shelve; returns the panorama and its final padding
def generate_panorama(step=10, original_frames=None, hole_filled_frames=None, right_cache=None, left_cache=None):
    """
    Build a panorama by chaining homographies from both ends of the video toward the middle frame.

    Process:
    - The last frame is merged (via worker) into the frame `step` earlier, repeatedly, until the
      middle is near.
    - The first frame is merged forward the same way.
    - Both partial panoramas are then merged into the middle (anchor) frame.

    Homographies are estimated on the original frames and cached with shelve under
    d[FILENAME]['original'], the same sub-cache generate_panorama_from_middle_to_both_sides
    uses. This keeps them apart from the background-frame homographies fill_holes_with_H_cached
    stores directly under d[FILENAME] with (i, j) keys of the same form.

    :param step: distance in frames between consecutive merges
    :param original_frames: frames used for homography estimation (read from INPUT_VIDEO_PATH if None)
    :param hole_filled_frames: frames supplying the pixels (read from HOLE_FILLED_BACKGROUND_VIDEO_PATH if None)
    :param right_cache: optional list; receives a copy of the right-side panorama after each merge
    :param left_cache: optional list; receives a copy of the left-side panorama after each merge
    :return: (panorama, padding), where padding = [top, bottom, left, right] added by the final
        merge around the anchor-side panorama; None if interrupted with Ctrl-C
    :raises IndexError: if the two frame lists differ in length
    """
    if original_frames is None:
        original_frames = read_frames(INPUT_VIDEO_PATH)
    if hole_filled_frames is None:
        hole_filled_frames = read_frames(HOLE_FILLED_BACKGROUND_VIDEO_PATH)
    if len(hole_filled_frames) != len(original_frames):
        raise IndexError("original frames and hole_filled background frames have different length! {}, {}".format(len(original_frames), len(hole_filled_frames)))

    def _translation(dx, dy):
        # 3x3 homogeneous translation (x, y) -> (x + dx, y + dy)
        return np.array([[1.0, 0.0, dx], [0.0, 1.0, dy], [0.0, 0.0, 1.0]])

    def _from_padded(padding):
        # Homographies are estimated on unpadded frames; a padded panorama holds its anchor frame
        # shifted by (left, top) = (padding[2], padding[0]), so undo that shift before applying H.
        return _translation(-padding[2], -padding[0])

    try:
        with shelve.open(writeback=True, filename=H_PERSIST_FILENAME) as d:
            if FILENAME not in d:
                d[FILENAME] = {}
            if 'original' not in d[FILENAME]:
                d[FILENAME]['original'] = {}
            m = d[FILENAME]['original']
            anchor_index = len(original_frames) // 2

            # iterate from the right end toward the middle
            right = len(original_frames) - 1
            right_padded = hole_filled_frames[right].copy()
            r = right                     # frame right_padded is currently anchored on
            right_padding = [0, 0, 0, 0]  # where frame r sits inside right_padded
            for i in tqdm(range(right, anchor_index + step, -step)):
                H = get_H_matrix_cached_version(i, i - step, original_frames[i], original_frames[i - step], m)
                right_padded, right_padding = worker(hole_filled_frames[i - step], right_padded,
                                                     H @ _from_padded(right_padding))
                r = i - step  # the merge re-anchors the panorama on frame i - step
                # `is not None` so an empty list supplied by the caller still collects results
                if right_cache is not None:
                    right_cache.append(right_padded.copy())

            # iterate from the left end toward the middle
            left = 0
            left_padded = hole_filled_frames[left].copy()
            l = left
            left_padding = [0, 0, 0, 0]
            for i in tqdm(range(left, anchor_index - step, step)):
                H = get_H_matrix_cached_version(i, i + step, original_frames[i], original_frames[i + step], m)
                left_padded, left_padding = worker(hole_filled_frames[i + step], left_padded,
                                                   H @ _from_padded(left_padding))
                l = i + step
                if left_cache is not None:
                    left_cache.append(left_padded.copy())

            assert l < anchor_index < r

            # now only left_padded, anchor frame and right_padded need to blend
            # first blend right_padded with anchor
            H = get_H_matrix_cached_version(r, anchor_index, original_frames[r], original_frames[anchor_index], m)
            right_padded, anchor_padding = worker(hole_filled_frames[anchor_index], right_padded,
                                                  H @ _from_padded(right_padding))
            # then blend left_padded into the result, which holds the anchor frame shifted by anchor_padding
            H = get_H_matrix_cached_version(l, anchor_index, original_frames[l], original_frames[anchor_index], m)
            H_into_panorama = _translation(anchor_padding[2], anchor_padding[0]) @ H @ _from_padded(left_padding)
            left_padded, padding_from_left = worker(right_padded, left_padded, H_into_panorama)

            # return panorama and its final padding info
            return left_padded, padding_from_left
    except KeyboardInterrupt:
        print('Interrupted!')









def generate_panorama_from_middle_to_both_sides(shifted_H_persist, right_step=10, left_step=10, original_frames=None,
                                                hole_filled_frames=None, right_cache=None, left_cache=None,
                                                clear_cache=False, anchor_index=-1):
    """Grow a panorama outward from an anchor frame, alternating right and left.

    Starting from ``hole_filled_frames[anchor_index]``, the frames
    ``anchor + right_step, anchor + 2*right_step, ...`` and
    ``anchor - left_step, anchor - 2*left_step, ...`` are stitched onto the
    running panorama in a right, left, right, left order until both ends of
    the sequence are exhausted (an exhausted side just yields its turn).

    The homography from each frame to the running panorama is cached in the
    ``H_PERSIST_FILENAME`` shelf under ``d[FILENAME]['panorama<anchor>_r<right_step>_l<left_step>']``.
    The step sizes are part of the key because each H maps into the panorama
    as it looked at that point, and that panorama depends on which frames
    were stitched before it.

    Args:
        shifted_H_persist: dict that receives ``(frame_index, anchor_index) -> shifted_H``
            for every stitched frame (second return value of ``stitch``).
        right_step: stride in frames when moving right; must be > 0.
        left_step: stride in frames when moving left; must be > 0.
        original_frames: frames of the input video; read from ``INPUT_VIDEO_PATH`` if None.
            Only their count is used here, plus the shelf's 'original' cache is cleared
            together with the panorama cache when ``clear_cache`` is set.
        hole_filled_frames: hole-filled background frames; read from
            ``HOLE_FILLED_BACKGROUND_VIDEO_PATH`` if None.
        right_cache: optional list receiving the panorama after each right-side stitch.
        left_cache: optional list receiving the panorama after each left-side stitch.
        clear_cache: if True, drop the cached homographies before starting.
        anchor_index: index of the anchor frame; -1 selects the middle frame.

    Returns:
        The stitched panorama, or None if interrupted with Ctrl-C.

    Raises:
        IndexError: if the two frame sequences differ in length or anchor_index is out of range.
        ValueError: if a step size is not positive (the loop would never terminate).
    """
    if original_frames is None:
        original_frames = read_frames(INPUT_VIDEO_PATH)
    if hole_filled_frames is None:
        hole_filled_frames = read_frames(HOLE_FILLED_BACKGROUND_VIDEO_PATH)
    if len(hole_filled_frames) != len(original_frames):
        raise IndexError("original frames and hole_filled background frames have different length! {}, {}".format(len(original_frames), len(hole_filled_frames)))
    if right_step <= 0 or left_step <= 0:
        raise ValueError("right_step and left_step must be positive, got {}, {}".format(right_step, left_step))

    n_frames = len(original_frames)
    if anchor_index == -1:
        anchor_index = n_frames // 2
    if not 0 <= anchor_index < n_frames:
        raise IndexError("anchor_index {} out of range for {} frames".format(anchor_index, n_frames))

    try:
        with shelve.open(writeback=True, filename=H_PERSIST_FILENAME) as d:
            # writeback=True: the nested dicts returned here are cached by the
            # shelf and written back on close, so in-place updates persist.
            file_cache = d.setdefault(FILENAME, {})
            original_H_cache = file_cache.setdefault('original', {})
            panorama_cache = file_cache.setdefault(
                'panorama{}_r{}_l{}'.format(anchor_index, right_step, left_step), {})

            if clear_cache:
                original_H_cache.clear()
                panorama_cache.clear()

            def stitch_frame(dst, index, snapshots):
                """Stitch hole_filled_frames[index] onto dst, reusing a cached H if present."""
                if index in panorama_cache:
                    print('cache hit')
                    H = panorama_cache[index]
                else:
                    print('cache missed')
                    H = get_H_matrix(scale_image(hole_filled_frames[index], scale_factor=1),
                                     scale_image(dst, scale_factor=1))
                    panorama_cache[index] = H
                dst, shifted_H = stitch(dst, hole_filled_frames[index], H)
                shifted_H_persist[(index, anchor_index)] = shifted_H
                if snapshots is not None:
                    snapshots.append(dst)
                return dst

            dst = hole_filled_frames[anchor_index].copy()
            right = anchor_index + right_step
            left = anchor_index - left_step
            go_right = True
            count = 1

            while left >= 0 or right < n_frames:
                print('iteration {}, time: {}'.format(count, time.time()))
                count += 1
                # right, left, right, left pattern; an exhausted side is a no-op turn
                if go_right:
                    if right < n_frames:
                        dst = stitch_frame(dst, right, right_cache)
                        right += right_step
                elif left >= 0:
                    dst = stitch_frame(dst, left, left_cache)
                    left -= left_step
                go_right = not go_right
            return dst
    except KeyboardInterrupt:
        print('Interrupted!')

In [27]:
# def generate_panorama_from_middle_to_both_sides(step=10, original_frames=None, hole_filled_frames=None, right_cache=None, left_cache=None, shifted_H_persist = False):
#     # if original_frames is None:
#     #     original_frames = read_frames(INPUT_VIDEO_PATH)
#     if hole_filled_frames is None:
#         hole_filled_frames = read_frames(HOLE_FILLED_BACKGROUND_VIDEO_PATH)
#     # if len(hole_filled_frames) != len(original_frames):
#     #     raise IndexError("original frames and hole_filled background frames have different length! {}, {}".format(len(original_frames), len(hole_filled_frames)))
#     if shifted_H_persist:
#         holder = {}
    
#     try:
#         with shelve.open(writeback=True, filename=H_PERSIST_FILENAME) as d:
#             # using original video H cache
#             if FILENAME not in d:
#                 d[FILENAME] = {}
#             if 'original' not in d[FILENAME]:
#                 d[FILENAME]['original'] = {}
#             m = d[FILENAME]['original']
#             anchor_index = len(hole_filled_frames) // 2

#             dst = hole_filled_frames[anchor_index].copy()
#             left = anchor_index - step
#             right = anchor_index + step
#             last_direction = False
#             padding = [0] * 4
#             count = 1
#             while left >= 0 or right < len(hole_filled_frames):
#                 print('iteration {}, time: {}'.format(count, time.time()))
#                 count += 1
#                 # right, left, right, left pattern
#                 if not last_direction:
#                     if right < len(hole_filled_frames):
#                         dst, padding = tmp_worker(dst, hole_filled_frames[right], get_H_matrix(scale_image(hole_filled_frames[right], scale_factor=1), scale_image(dst, scale_factor=1)))
#                         right_cache.append(dst)
#                         right += step
#                 else:
#                     if left >= 0:
#                         dst, padding = tmp_worker(dst, hole_filled_frames[left], get_H_matrix(scale_image(hole_filled_frames[left], scale_factor=1), scale_image(dst, scale_factor=1)))
#                         left_cache.append(dst)
#                         left -= step
#                 last_direction = not last_direction
#             return dst, padding
#     except KeyboardInterrupt:
#         print('Interrupted!')

In [28]:
# Empty the persisted homography cache so the next panorama run recomputes every H.
with shelve.open(H_PERSIST_FILENAME, writeback=True) as h_cache:
    h_cache.clear()

In [29]:
# Stitch outward from the anchor frame, 5 frames per step on each side,
# collecting intermediate panoramas for later inspection.
right_cache, left_cache = [], []
shifted_H_holder = {}
anchor_index = -1  # -1 lets the function pick the middle frame
panorama = generate_panorama_from_middle_to_both_sides(
    shifted_H_holder,
    right_step=5,
    left_step=5,
    right_cache=right_cache,
    left_cache=left_cache,
    clear_cache=True,
    anchor_index=anchor_index,
)
# panorama, shifted_H_holder = generate_panorama_from_middle_to_both_sides(step = 30, right_cache=right_cache, left_cache=left_cache)

iteration 1, time: 1670556383.5891812
cache missed
iteration 2, time: 1670556386.7917035
cache missed
iteration 3, time: 1670556390.0528283
cache missed
iteration 4, time: 1670556393.313234
cache missed
iteration 5, time: 1670556396.3712335
cache missed
iteration 6, time: 1670556399.839995
cache missed
iteration 7, time: 1670556403.0646582
cache missed
iteration 8, time: 1670556406.6185622
cache missed
iteration 9, time: 1670556410.1721678
cache missed
iteration 10, time: 1670556413.6592703
cache missed
iteration 11, time: 1670556416.9512503
cache missed
iteration 12, time: 1670556421.3692458
cache missed
iteration 13, time: 1670556425.517329
cache missed
iteration 14, time: 1670556430.0081985
cache missed
iteration 15, time: 1670556434.2537668
cache missed
iteration 16, time: 1670556438.8662415
cache missed
iteration 17, time: 1670556443.24675
cache missed
iteration 18, time: 1670556448.073892
cache missed
iteration 19, time: 1670556452.28196
cache missed
iteration 20, time: 167055645

In [212]:
# Preview the finished panorama, scaled to fit a 1920x1080 screen.
preview = scale_to_1920_1080(panorama)
cv.imshow('adf', preview)
cv.waitKey(0)
cv.destroyAllWindows()

In [213]:
# Save the per-frame shifted homographies of the last run under this video's entry.
with shelve.open(filename='shifted_H_persist', writeback=True) as shelf:
    m = shelf.setdefault(FILENAME, {})
    # NOTE(review): when anchor_index is -1 the entry is stored under -1, not under
    # the resolved middle-frame index (the keys of shifted_H_holder carry that
    # index as their second element) — confirm readers of this shelf expect that.
    m[anchor_index] = shifted_H_holder

In [214]:
# Step through the left-side intermediate panoramas; any key advances.
for left in left_cache:
    preview = scale_to_1920_1080(left)
    cv.imshow('adf', preview)
    cv.waitKey(0)
    cv.destroyAllWindows()

In [89]:
# Step through the right-side intermediate panoramas; any key advances.
for right in right_cache:
    preview = scale_to_1920_1080(right)
    cv.imshow('adf', preview)
    cv.waitKey(0)
    cv.destroyAllWindows()

In [None]:
# Estimate the ten consecutive homographies between frames 0..10 of the original
# video, plus the direct 0 -> 10 homography, so the two can be compared below.
with shelve.open(filename=H_PERSIST_FILENAME, writeback=True) as shelf:
    # original-video H cache (created on first use)
    m = shelf.setdefault(FILENAME, {}).setdefault('original', {})
    original_frames = read_frames(INPUT_VIDEO_PATH)
    Hs = [
        get_H_matrix_cached_version(i, i + 1, original_frames[i], original_frames[i + 1], m)
        for i in range(10)
    ]
    H_direct = get_H_matrix_cached_version(0, 10, original_frames[0], original_frames[10], m)

In [None]:
# Look at the first original frame at the default scale_image size.
preview = scale_image(original_frames[0])
cv.imshow('asdf', preview)
cv.waitKey(0)
cv.destroyAllWindows()

In [None]:
# Direct frame 0 -> frame 10 homography (computed in the Hs cell), displayed for
# comparison with the chained product of the per-step homographies in `Hs`.
H_direct

In [None]:
# Chain the ten per-step homographies into one 0 -> 10 homography:
# a = Hs[9] @ Hs[8] @ ... @ Hs[0]. The slice must run down to index 0
# (`Hs[-2::-1]`); the previous `Hs[-2:0:-1]` silently dropped Hs[0].
# NOTE(review): assumes Hs[i] maps frame i into frame i+1 — confirm against
# get_H_matrix_cached_version.
a = Hs[-1]
for h in Hs[-2::-1]:
    a = a.dot(h)
# Homographies are only defined up to scale; normalise so a[2, 2] == 1 and
# the result is directly comparable with H_direct.
print(a / a[2, 2])

In [None]:
# Save a copy of the panorama resized to 2560 px wide (aspect ratio kept).
target_width = 2560
scale = target_width / panorama.shape[1]
new_size = (int(panorama.shape[1] * scale), int(panorama.shape[0] * scale))
resized_panorama = cv.resize(panorama, new_size)
cv.imwrite(OUTPUT_PANORAMA_IMG_NAME, resized_panorama)

In [40]:
# Alternate right/left intermediate panoramas, each resized to 2560 px wide.
for right, left in zip(right_cache, left_cache):
    for window_name, snapshot in (('asd', right), ('adf', left)):
        scale = 2560 / snapshot.shape[1]
        size = (int(snapshot.shape[1] * scale), int(snapshot.shape[0] * scale))
        cv.imshow(window_name, cv.resize(snapshot, size))
        cv.waitKey(0)
        cv.destroyAllWindows()

In [27]:
# Alternate right/left intermediate panoramas, scaled to fit 1920x1080.
for right, left in zip(right_cache, left_cache):
    for window_name, snapshot in (('asd', right), ('adf', left)):
        cv.imshow(window_name, scale_to_1920_1080(snapshot))
        cv.waitKey(0)
        cv.destroyAllWindows()

In [None]:
# Load the per-frame background images (frames 0..403) written by an earlier step.
# cv.imread returns None for a missing/unreadable file instead of raising, which
# would otherwise surface later as an obscure error inside the stitcher.
NUM_BACKGROUND_FRAMES = 404
BACKGROUND_IMAGE_TEMPLATE = "out_images/out_images/{}_background_{}.jpg"

frames = []
for i in range(NUM_BACKGROUND_FRAMES):
    image_path = BACKGROUND_IMAGE_TEMPLATE.format(FILENAME.split('.')[0], i)
    image = cv.imread(image_path)
    if image is None:
        raise FileNotFoundError("could not read background frame {}".format(image_path))
    frames.append(image)

In [None]:
s = cv.Stitcher.create(mode = cv.Stitcher_PANORAMA)
pano = None
def test(gap):
    """Grow the global ``pano`` from the middle of ``frames`` outward with cv.Stitcher.

    Frames anchor+gap, anchor-gap, anchor+2*gap, ... are stitched onto the running
    panorama in right/left alternation until both ends are exhausted. ``pano`` is
    only replaced after a successful stitch, so it still holds the last good result
    after a failure or Ctrl-C.

    Args:
        gap: stride in frames between stitched frames; must be > 0.

    Raises:
        ValueError: if gap is not positive (the loop would never terminate).
        KeyError: if the stitcher reports a non-OK status; the message names the
            frame index and the status code.
    """
    global pano
    if gap <= 0:
        raise ValueError("gap must be positive, got {}".format(gap))
    anchor_index = len(frames) // 2
    left = anchor_index - gap
    right = anchor_index + gap
    go_right = True
    pano = frames[anchor_index]
    try:
        while left >= 0 or right < len(frames):
            # right, left, right, left pattern; an exhausted side is a no-op turn
            index = None
            if go_right:
                if right < len(frames):
                    index = right
                    right += gap
            elif left >= 0:
                index = left
                left -= gap
            go_right = not go_right
            if index is None:
                continue
            status, stitched = s.stitch([pano, frames[index]])
            if status != cv.Stitcher_OK:
                raise KeyError("stitching frame {} failed with status {}".format(index, status))
            pano = stitched
    except KeyboardInterrupt:
        print("Interrupted!")
test(10)


In [None]:
# Try the two-step Stitcher API on a single pair: estimate, then compose.
pair = [frames[0], frames[10]]
s.estimateTransform(pair)
s.composePanorama(pair)

In [None]:
# NOTE(review): `status` is only assigned as a local inside `test()`; no visible
# cell defines it at module level, so this raises NameError on a fresh kernel
# unless some other cell sets it.
print(status)

In [None]:
# Inspect the panorama produced by the OpenCV Stitcher experiment.
window_name = 'asdf'
cv.imshow(window_name, pano)
cv.waitKey(0)
cv.destroyAllWindows()

In [None]:
# Build the panorama around the middle frame and save a 2560-px-wide copy.
anchor = len(filled_frames) // 2
panorama = generate_panorama(filled_frames, anchor)
scale = 2560 / panorama.shape[1]
new_size = (int(panorama.shape[1] * scale), int(panorama.shape[0] * scale))
resized_panorama = cv.resize(panorama, new_size)
cv.imwrite(OUTPUT_PANORAMA_IMG_NAME, resized_panorama)

In [None]:
# Show the 2560-px-wide panorama that was just written to disk.
window_name = 'resized_panorama'
cv.imshow(window_name, resized_panorama)
cv.waitKey(0)
cv.destroyAllWindows()

In [None]:
# Fresh OpenCV stitcher in panorama mode (rebinds `s`; a later cell rebinds it
# again to a `stitching.Stitcher`).
s = cv.Stitcher.create(cv.Stitcher_PANORAMA)

In [None]:
# Debug: compare the padded destination and warped source shapes.
padded_shape = dst_padded.shape
print(padded_shape)
warped_shape = src_warped.shape
print(warped_shape)

In [None]:
# Scratch check: all() is False as soon as one element is falsy.
print(all((0, 0, 1)))

In [None]:
import sys
from stitching import Stitcher

# def generate_panorama(filled_frames):
#     stitcher = Stitcher(filled_frames)

#     # resize images
#     stitcher.resize()

#     # find features
#     stitcher.find_features()

#     # Match Features
#     stitcher.match()

#     # subset
#     stitcher.subset()

#     # camera Estimation
#     stitcher.camera_correction()

#     # warp image
#     stitcher.wrap()

#     # Seam Masks
#     stitcher.seam()

#     # Exposure Error Compensation
#     stitcher.exposure_error_compensation()

#     # Blending
#     panorama = stitcher.blending()

#     cv.imwrite(OUTPUT_PANORAMA_IMG_PATH, panorama)

In [None]:
# Load the 404 per-frame background images for the `stitching` experiments.
# cv.imread returns None for a missing/unreadable file, so check explicitly
# instead of passing None entries on to the stitcher.
_background_stem = FILENAME.split('.')[0]
imgs = []
for _i in range(404):
    _path = 'out_images/out_images/{}_background_{}.jpg'.format(_background_stem, _i)
    _img = cv.imread(_path)
    if _img is None:
        raise FileNotFoundError('could not read background frame {}'.format(_path))
    imgs.append(_img)

In [None]:
# Fetch (creating on first use) this video's entry in the 'H_persist' shelf.
# NOTE(review): `m` is used after the shelf is closed here, so later edits to it
# will not be persisted.
with shelve.open(filename='H_persist', writeback=True) as shelf:
    m = shelf.setdefault(FILENAME, {})

In [None]:
def plot_image(img, figsize_in_inches=(5, 5)):
    """Display an OpenCV (BGR) image inline via matplotlib.

    Args:
        img: image in OpenCV's BGR channel order.
        figsize_in_inches: matplotlib figure size as (width, height).
    """
    _, axis = plt.subplots(figsize=figsize_in_inches)
    rgb = cv.cvtColor(img, cv.COLOR_BGR2RGB)  # matplotlib expects RGB
    axis.imshow(rgb)
    plt.show()

In [None]:
from pathlib import Path
def get_image_paths(img_set, root='tmp'):
    """Return the paths of all files under ``root`` whose names start with ``img_set``.

    Args:
        img_set: filename prefix; matched recursively as the glob ``f'{img_set}*'``.
        root: directory to search; defaults to ``'tmp'`` relative to the CWD.

    Returns:
        Sorted list of path strings (relative to the CWD when ``root`` is relative).
        ``rglob`` does not guarantee any order, so sorting keeps the image order
        (and hence the stitching input) reproducible across machines.
    """
    return sorted(str(path) for path in Path(root).rglob(f'{img_set}*'))
# Collect the 'test1*' files under ./tmp as the input set for the stitching pipeline below.
img_set = get_image_paths('test1')

In [None]:
# Keyword options for stitching.Stitcher. The commented entries are presumably
# alternatives that were experimented with.
settings = dict(
    try_use_gpu=True,
    crop=False,
    # warper_type='plane',
    # adjuster='no',
    # match_conf=0.1,
)

In [None]:
# Rebind `s` to a `stitching.Stitcher` (imported as `Stitcher` in an earlier cell)
# configured with `settings`; the following cells drive its pipeline step by step.
s = Stitcher(**settings)


In [None]:
# Hand the collected image paths (`img_set`) to the stitcher.
s.initialize_registration(img_set)

In [None]:
# Registration stage: resize to the stitcher's medium resolution, then detect
# features on those images (`features` feeds matching and camera estimation).
images = s.resize_medium_resolution()
features = s.find_features(images)

In [None]:
# Match features between images and estimate camera parameters from the matches.
matches = s.match_features(features)
cameras = s.estimate_camera_parameters(features, matches)

In [None]:
# Wave-correct the estimated cameras, then estimate the warp scale from them.
cameras = s.perform_wave_correction(cameras)
s.estimate_scale(cameras)

In [None]:
# Downscale to the stitcher's low resolution, presumably for the cheaper
# warping / exposure / seam-finding stage that follows.
images = s.resize_low_resolution(images)

In [None]:
# Warp the low-resolution images with the cameras, then use the warped images,
# corners and masks to estimate exposure errors and find seam masks; both are
# reused in the final-resolution pass.
images, masks, corners, sizes = s.warp_low_resolution(images, cameras)
s.estimate_exposure_errors(corners, images, masks)
seam_masks = s.find_seam_masks(images, corners, masks)

In [None]:
# Final-resolution pass: re-warp at the final resolution, apply the exposure
# compensation estimated earlier, resize the low-res seam masks to match, then
# blend everything into the panorama `pano`. Order of these calls matters.
images = s.resize_final_resolution()
images, masks, corners, sizes = s.warp_final_resolution(images, cameras)
s.set_masks(masks)
images = s.compensate_exposure_errors(corners, images)
seam_masks = s.resize_seam_masks(seam_masks)

s.initialize_composition(corners, sizes)
s.blend_images(images, seam_masks, corners)
pano = s.create_final_panorama()

In [None]:
# Inspect the panorama produced by the `stitching` pipeline.
window_name = 'pano'
cv.imshow(window_name, pano)
cv.waitKey(0)
cv.destroyAllWindows()

In [None]:
# Generate background panorama
# Generate background panorama
# NOTE(review): called with one argument here, while an earlier cell calls
# generate_panorama(filled_frames, len(filled_frames) // 2); neither
# generate_panorama nor filled_frames is defined in this part of the notebook —
# confirm the signature.
generate_panorama(filled_frames)

In [None]:
# def fill_hole(step):
#     background_frames = read_frames(OUTPUT_BACKGROUND_VIDEO_PATH)
#     print(len(background_frames))
#     hole_filled_background_out = cv.VideoWriter(HOLE_FILLED_BACKGROUND_VIDEO_PATH, cv.VideoWriter_fourcc(*'XVID'), fps, (width, height))
#     if not hole_filled_background_out.isOpened():
#         raise IOError("cap init failed")

#     try:
#         # i is frame index
#         for i in tqdm(range(len(foreground_labels) - step)):        
#             # label is current frame's foreground label set
#             label = foreground_labels[i]
#             for row, col in label:
#                 # H = np.identity(3)
#                 # done = False
#                 # index = i
#                 # while not done:
#                 #     for _ in range(0, step):
#                 #         if index >= len(Hs):
#                 #             # if can't fill, just break
#                 #             break
#                 #         H = np.matmul(H, np.linalg.inv(Hs[index]))
#                 #         index += 1

#                 #     if index >= len(Hs):
#                 #         break

#                 #     [corresponding_row, corresponding_col, _] = np.matmul(H, np.array([row, col, 1]))
#                 #     corresponding_row = round(corresponding_row)
#                 #     corresponding_col = round(corresponding_col)
#                 #     if corresponding_row >= 0 and corresponding_row < height and corresponding_col >= 0 and corresponding_col < width and \
#                 #                             (corresponding_row, corresponding_col) not in foreground_labels[index]:
#                 #         background_frames[i][row][col] = background_frames[index][corresponding_row][corresponding_col].copy()
#                 #     done = True
#                 H = Hs[0]
#                 [corresponding_row, corresponding_col, _] = np.matmul(H, np.array([row, col, 1]))
#                 corresponding_row = round(corresponding_row)
#                 corresponding_col = round(corresponding_col)
#                 if corresponding_row >= 0 and corresponding_row < height and corresponding_col >= 0 and corresponding_col < width and \
#                                         (corresponding_row, corresponding_col) not in foreground_labels[step]:
#                     background_frames[i][row][col] = background_frames[i + step][corresponding_row][corresponding_col].copy()
#             hole_filled_background_out.write(background_frames[i])
#             break

#     except KeyboardInterrupt:
#         print('Interrupted!')   
    
#     hole_filled_background_out.release()


In [None]:
# pd.DataFrame(np.array(choices[0])).describe()

In [None]:
# all_differences.view([(f'f{i}',all_differences.dtype) for i in range(all_differences.shape[-1])])[...,0].astype('O')

In [None]:
# test = np.array([[[1, 1], [1, 2], [1, 996]],
#                  [[1, 1], [2, 1], [1, 996]],
#                  [[1, 1], [1, 2], [1, 996]],
#                  [[1, 1], [1, 2], [1, 996]],
#                  [[1, 1], [1, 2], [1, 996]]])

In [None]:
def get_macroblocks_vectors(flow, macroblock_size):
    """Average a dense per-pixel vector field over square macroblocks.

    Args:
        flow: array-like of shape (height, width, C), e.g. a dense optical-flow
            field with C == 2. Nested lists are accepted.
        macroblock_size: side length in pixels of each block; must be > 0.

    Returns:
        float64 array of shape (ceil(height / size), ceil(width / size), C) with
        the mean vector of each block. Blocks on the right/bottom edge may be
        smaller than ``macroblock_size``; they are averaged over the pixels they cover.

    Raises:
        ValueError: if macroblock_size is not positive or flow is not 3-D.
    """
    if macroblock_size <= 0:
        raise ValueError("macroblock_size must be positive, got {}".format(macroblock_size))
    flow = np.asarray(flow)
    if flow.ndim != 3:
        raise ValueError("flow must have shape (height, width, C), got {}".format(flow.shape))

    height, width, channels = flow.shape
    n_rows = -(-height // macroblock_size)  # ceil division
    n_cols = -(-width // macroblock_size)
    macroblocks = np.zeros((n_rows, n_cols, channels))

    for row in range(n_rows):
        y = row * macroblock_size
        for col in range(n_cols):
            x = col * macroblock_size
            # numpy slicing clips at the array edge, so edge blocks are simply smaller
            block = flow[y:y + macroblock_size, x:x + macroblock_size]
            macroblocks[row, col] = block.reshape(-1, channels).mean(axis=0)

    return macroblocks
# flow0_macroblock = get_macroblocks_vectors(flows[-1], 16)
# flow0_macroblock.shape

In [None]:
# convert pixel flow to block flow
# convert pixel flow to block flow: one mean flow vector per 16x16 block of the
# last flow field. NOTE(review): `flows` is not defined in this part of the notebook.
block_flow = get_macroblocks_vectors(flows[-1], 16)

In [None]:
# total_k_means.cluster_centers_

In [None]:
# total_k_means.labels_

In [None]:
# find mode
# for row in flows[0]:
#     u, c = np.unique(row, axis=0, return_counts=True)
#     y = u[c == c.max()]
#     print(y)
#     break

In [None]:
# fgmasks = np.zeros((len(flows), 1080, 1920), np.uint8)
# for flow in tqdm(range(len(flows))):
#     for row in range(1080):
#         for col in range(1920):
#             dir_y,dir_x = flows[flow][row][col][0], flows[flow][row][col][1]
#             if dir_y  <= 0.2:
#                         fgmasks[flow, row, col] = 1
#     cv.imshow("dasd", frames[flow] * fgmasks[flow, :, :, np.newaxis])
#     cv.waitKey(1000)


In [None]:
# Cluster the per-block mean flow vectors into two groups (presumably foreground
# vs. background motion). Flatten the block grid to (n_blocks, C) directly; the
# previous round-trip through Python tuples produced the same array more slowly.
# random_state=0 makes the clustering (and so the labels) reproducible.
# KMeans `algorithm` accepts "lloyd" or "elkan" ("auto"/"full" only in older scikit-learn).
blocks_inliers = block_flow.reshape(-1, block_flow.shape[-1])
blocks_k_means = KMeans(n_clusters=2, random_state=0).fit(blocks_inliers)

In [None]:
# frame must be the prev frame of label
def separate(label, frame, block_flow_width, block_size):
    """Split a frame into foreground and background images using per-block labels.

    The frame is tiled into ``block_size`` x ``block_size`` blocks in row-major
    order, ``block_flow_width`` blocks per row; ``label[i]`` decides block ``i``.
    Pixels of truthy-labelled blocks are copied into ``foreground``, the rest into
    ``background``; everything not copied stays white (255).
    Per the original note, ``frame`` must be the frame *preceding* the flow
    that produced ``label``.

    Args:
        label: sequence of per-block labels (e.g. KMeans ``labels_``).
        frame: image array of shape (height, width[, channels]).
        block_flow_width: number of blocks per row in the label grid.
        block_size: side length of each block in pixels.

    Returns:
        (foreground, background): uint8 arrays with the same shape as ``frame``.
    """
    frame = np.asarray(frame)
    height, width = frame.shape[:2]
    # Size the outputs from `frame` itself, not from a global frame list.
    foreground = np.full(frame.shape, 255, dtype=np.uint8)
    background = foreground.copy()

    for i, is_foreground in enumerate(label):
        block_y, block_x = divmod(i, block_flow_width)
        rows = slice(block_y * block_size, min((block_y + 1) * block_size, height))
        cols = slice(block_x * block_size, min((block_x + 1) * block_size, width))
        target = foreground if is_foreground else background
        target[rows, cols] = frame[rows, cols]

    return foreground, background

# Split the second-to-last frame using the KMeans block labels; block_flow was
# built from flows[-1] with 16-px blocks, so the block grid width is block_flow.shape[1].
foreground, background = separate(blocks_k_means.labels_, frames[-2], block_flow.shape[1], 16)

In [None]:
# Show the foreground image (non-foreground blocks are white).
window_name = 'Frame'
cv.imshow(window_name, foreground)
cv.waitKey(0)
cv.destroyAllWindows()
cv.waitKey(1)  # presumably lets HighGUI process the close before returning

In [None]:
# Show the background image (foreground blocks are white).
window_name = 'Frame'
cv.imshow(window_name, background)
cv.waitKey(0)
cv.destroyAllWindows()
cv.waitKey(1)  # presumably lets HighGUI process the close before returning

In [None]:
# Show the foreground image again (same as the earlier foreground preview).
window_name = 'Frame'
cv.imshow(window_name, foreground)
cv.waitKey(0)
cv.destroyAllWindows()
cv.waitKey(1)  # presumably lets HighGUI process the close before returning

In [None]:
# Show the background image again (same as the earlier background preview).
window_name = 'Frame'
cv.imshow(window_name, background)
cv.waitKey(0)
cv.destroyAllWindows()
cv.waitKey(1)  # presumably lets HighGUI process the close before returning