Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

How can I use only one mask file? #72

Open
HamiguaLu opened this issue Jun 14, 2023 · 1 comment
Open

How can I use only one mask file? #72

HamiguaLu opened this issue Jun 14, 2023 · 1 comment

Comments

@HamiguaLu
Copy link

I want to remove a watermark from a video. The watermark is at a fixed position, so can I use only one mask file for all the frames? If so, how?

Thanks in advance

@HamiguaLu
Copy link
Author

Hi, in case someone else wants this too, you can try the following code.
By the way, it's based on https://github.com/gaomingqi/Track-Anything/blob/master/inpainter/base_inpainter.py
`import os
import glob
from PIL import Image

import torch
import yaml
import cv2
import importlib
import numpy as np
from tqdm import tqdm
import torchvision
from util.tensor_util import resize_frames, resize_masks

class BaseInpainter:
    """Video inpainter around the E2FGVI-HQ generator.

    `inpaint` / `inpaint_efficient` apply ONE mask (shape 1, H, W) to every
    frame, e.g. for a watermark at a fixed position. `inpaint_ori` keeps the
    per-frame-mask variant (masks shaped T, H, W).
    """

    # Frames are mirror-padded up to multiples of these sizes before the
    # forward pass and cropped back to (h, w) afterwards.
    _MOD_SIZE_H = 60
    _MOD_SIZE_W = 108

    def __init__(self, E2FGVI_checkpoint, device) -> None:
        """
        E2FGVI_checkpoint: checkpoint of inpainter (version hq, with multi-resolution support)
        device: device the model and all tensors are moved to
        """
        net = importlib.import_module('model.e2fgvi_hq')
        self.model = net.InpaintGenerator().to(device)
        self.model.load_state_dict(torch.load(E2FGVI_checkpoint, map_location=device))
        self.model.eval()
        self.device = device
        # load configurations
        with open("config/config.yaml", 'r') as stream:
            config = yaml.safe_load(stream)
        self.neighbor_stride = config['neighbor_stride']
        self.num_ref = config['num_ref']
        self.step = config['step']
        # config for E2FGVI with splits
        self.num_subset_frames = config['num_subset_frames']
        self.num_external_ref = config['num_external_ref']

    # sample reference frames from the whole video
    def get_ref_index(self, f, neighbor_ids, length):
        """
        Pick reference frame indices for the window centred on frame `f`.

        f: index of the current frame
        neighbor_ids: indices already used as local neighbours (never returned)
        length: number of frames available; every returned index is < length

        Output:
        list of indices sampled every `self.step` frames; the whole video when
        `self.num_ref == -1`, otherwise a span of `num_ref // 2` steps on each
        side of `f`.
        """
        neighbors = set(neighbor_ids)
        if self.num_ref == -1:
            return [i for i in range(0, length, self.step) if i not in neighbors]

        half_span = self.step * (self.num_ref // 2)
        start_idx = max(0, f - half_span)
        # Clamp to the last valid index: clamping to `length` would let the
        # inclusive range below emit `length` itself, which is out of bounds.
        end_idx = min(length - 1, f + half_span)
        ref_index = []
        for i in range(start_idx, end_idx + 1, self.step):
            if i in neighbors:
                continue
            if len(ref_index) > self.num_ref:
                break
            ref_index.append(i)
        return ref_index

    @staticmethod
    def _dilate_masks(masks, dilate_radius):
        """Clip masks to {0, 1}, dilate each one and return them as N, H, W, 1."""
        masks = np.clip(masks, 0, 1)  # np.clip returns a new array; input is untouched
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilate_radius, dilate_radius))
        masks = np.stack([cv2.dilate(mask, kernel) for mask in masks], 0)
        return np.expand_dims(masks, axis=3)

    def _predict(self, selected_imgs, selected_masks, num_local, h, w):
        """Mask the selected frames, run the generator, return predictions in [0, 255].

        The returned numpy array is float, laid out as N, h, w, 3.
        """
        with torch.no_grad():
            masked_imgs = selected_imgs * (1 - selected_masks)
            h_pad = (self._MOD_SIZE_H - h % self._MOD_SIZE_H) % self._MOD_SIZE_H
            w_pad = (self._MOD_SIZE_W - w % self._MOD_SIZE_W) % self._MOD_SIZE_W
            # mirror-pad height, then width, by appending a flipped copy and cropping
            masked_imgs = torch.cat(
                [masked_imgs, torch.flip(masked_imgs, [3])],
                3)[:, :, :, :h + h_pad, :]
            masked_imgs = torch.cat(
                [masked_imgs, torch.flip(masked_imgs, [4])],
                4)[:, :, :, :, :w + w_pad]
            pred_imgs, _ = self.model(masked_imgs, num_local)
            pred_imgs = pred_imgs[:, :, :h, :w]
            pred_imgs = (pred_imgs + 1) / 2
            return pred_imgs.cpu().permute(0, 2, 3, 1).numpy() * 255

    @staticmethod
    def _accumulate(comp_frames, idx, img):
        """Store `img` for frame `idx`, averaging 50/50 with an earlier result if any."""
        if comp_frames[idx] is None:
            comp_frames[idx] = img
        else:
            comp_frames[idx] = comp_frames[idx].astype(
                np.float32) * 0.5 + img.astype(np.float32) * 0.5

    def inpaint_efficient(self, frames, masks, num_tcb, num_tca, dilate_radius=15, ratio=1):
        """
        Perform Inpainting for video subsets, using one mask for all frames
        frames: numpy array, T, H, W, 3 (context-before + split + context-after)
        masks: numpy array, 1, H, W (the single mask shared by every frame)
        num_tcb: constant, number of temporal context before, frames
        num_tca: constant, number of temporal context after, frames
        dilate_radius: radius when applying dilation on masks
        ratio: down-sample ratio; only 1 is supported by this variant

        Output:
        inpainted_frames: numpy array, T - num_tcb - num_tca, H, W, 3

        Raises:
        ValueError: if ratio != 1
        """
        # Validate before any heavy work (previously this path hit `return null`).
        if ratio != 1:
            raise ValueError(f"inpaint_efficient only supports ratio == 1, got {ratio}")

        # --------------------
        # pre-processing
        # --------------------
        binary_masks = self._dilate_masks(masks, dilate_radius)    # 1, H, W, 1
        T = frames.shape[0]
        h, w = frames.shape[1:3]
        video_length = T - (num_tca + num_tcb)  # real video length
        # convert to tensor
        imgs = (torch.from_numpy(frames).permute(0, 3, 1, 2).contiguous().unsqueeze(0).float().div(255)) * 2 - 1
        mask_tensor = torch.from_numpy(binary_masks).permute(0, 3, 1, 2).contiguous().unsqueeze(0)
        imgs, mask_tensor = imgs.to(self.device), mask_tensor.to(self.device)
        comp_frames = [None] * video_length

        # separate temporal-context frames from the frames to inpaint
        tcb_imgs = imgs[:, :num_tcb] if num_tcb > 0 else None
        if num_tca > 0:
            tca_imgs = imgs[:, -num_tca:]
            end_idx = -num_tca
        else:
            tca_imgs = None
            end_idx = T
        imgs = imgs[:, num_tcb:end_idx]
        frames = frames[num_tcb:end_idx]                # only neighbor area are involved
        # --------------------
        # end of pre-processing
        # --------------------

        blend_mask = binary_masks[0]
        for f in tqdm(range(0, video_length, self.neighbor_stride), desc='Inpainting image'):
            neighbor_ids = list(range(max(0, f - self.neighbor_stride),
                                      min(video_length, f + self.neighbor_stride + 1)))
            ref_ids = self.get_ref_index(f, neighbor_ids, video_length)

            # order: neighbours, context before, reference frames, context after
            parts = [imgs[:, neighbor_ids]]
            if tcb_imgs is not None:
                parts.append(tcb_imgs)
            parts.append(imgs[:, ref_ids])
            if tca_imgs is not None:
                parts.append(tca_imgs)
            selected_imgs = torch.cat(parts, dim=1)

            # The single mask (time dim of size 1) broadcasts over every
            # selected frame, so it must NOT be concatenated along time.
            pred_imgs = self._predict(selected_imgs, mask_tensor, len(neighbor_ids), h, w)
            for i, idx in enumerate(neighbor_ids):
                img = pred_imgs[i].astype(np.uint8) * blend_mask + frames[idx] * (
                        1 - blend_mask)
                self._accumulate(comp_frames, idx, img)
            torch.cuda.empty_cache()
        inpainted_frames = np.stack(comp_frames, 0)
        return inpainted_frames.astype(np.uint8)

    def inpaint(self, frames, masks, dilate_radius=15, ratio=1):
        """
        Inpaint a whole video split by split, using one mask for all frames
        frames: numpy array, T, H, W, 3
        masks: numpy array, 1, H, W (the single mask shared by every frame)
        dilate_radius: radius when applying dilation on masks
        ratio: down-sample ratio; must be in (0, 1] and only 1 is supported

        Output:
        inpainted_frames: numpy array, T, H, W, 3

        Raises:
        ValueError: if ratio != 1 or the video has no frames
        """
        assert masks.shape[0] == 1, 'expected exactly one mask shared by all frames'
        assert ratio > 0 and ratio <= 1, 'ratio must in (0, 1]'
        if ratio != 1:
            raise ValueError(f"single-mask inpainting only supports ratio == 1, got {ratio}")

        # set num_subset_frames
        num_subset_frames = self.num_subset_frames
        # split frames into subsets
        video_length = len(frames)
        if video_length == 0:
            raise ValueError("frames is empty, nothing to inpaint")
        num_splits = video_length // num_subset_frames
        id_splits = [[i*num_subset_frames, (i+1)*num_subset_frames] for i in range(num_splits)]  # id splits

        if num_splits == 0:
            id_splits = [[0, video_length]]

        # if remaining frames > num_subset_frames/3, add a new split; otherwise
        # shift every split by the remainder and let the FIRST split absorb it
        remainder = video_length - id_splits[-1][-1]
        if remainder > num_subset_frames / 3:
            id_splits.append([num_splits*num_subset_frames, video_length])
        else:
            id_splits = [[ids[0]+remainder, ids[1]+remainder] for ids in id_splits]
            id_splits[0][0] = 0     # if OOM, let it happen at the begining :D

        # perform inpainting for each split
        inpainted_splits = []
        for id_split in id_splits:
            video_split = frames[id_split[0]:id_split[1]]

            # | id_before | ----- | id_split[0] | ----- | id_split[1] | ----- | id_after |
            # for each split, consider its temporal context [-context_range] frames and [context_range] frames
            id_before = max(0, id_split[0] - self.step * self.num_external_ref)
            tcb_ids = range(id_before, (id_split[0]-self.step) + 1, self.step)
            id_after = min(video_length, id_split[1] + self.step * self.num_external_ref + 1)
            tca_ids = range(id_split[1]+self.step, id_after, self.step)
            num_tcb = len(tcb_ids)
            num_tca = len(tca_ids)

            # Only FRAMES get context; the single mask is shared, so it is
            # never indexed by frame id (doing so used to raise IndexError,
            # which a bare except swallowed and thereby dropped all context).
            if num_tcb > 0:
                tcb_frames = np.stack([frames[idb] for idb in tcb_ids], 0)
                video_split = np.concatenate([tcb_frames, video_split], 0)
            if num_tca > 0:
                tca_frames = np.stack([frames[ida] for ida in tca_ids], 0)
                video_split = np.concatenate([video_split, tca_frames], 0)

            torch.cuda.empty_cache()
            # inpaint each split
            inpainted_splits.append(self.inpaint_efficient(video_split, masks, num_tcb, num_tca, dilate_radius, ratio))
            torch.cuda.empty_cache()
        inpainted_frames = np.concatenate(inpainted_splits, 0)

        return inpainted_frames.astype(np.uint8)

    def inpaint_ori(self, frames, masks, dilate_radius=15, ratio=1):
        """
        Inpaint a video in a single pass with one mask PER frame
        frames: numpy array, T, H, W, 3
        masks: numpy array, T, H, W
        dilate_radius: radius when applying dilation on masks
        ratio: down-sample ratio

        Output:
        inpainted_frames: numpy array, T, H, W, 3 (at the down-sampled size
        when ratio != 1)
        """
        assert frames.shape[:3] == masks.shape, 'different size between frames and masks'
        assert ratio > 0 and ratio <= 1, 'ratio must in (0, 1]'
        T, H, W = masks.shape
        binary_masks = self._dilate_masks(masks, dilate_radius)    # T, H, W, 1
        # size: (w, h)
        if ratio != 1:
            size = [int(W*ratio), int(H*ratio)]
            size = [si+1 if si%2>0 else si for si in size]  # only consider even values
            # shortest side should be larger than 50
            if min(size) < 50:
                ratio = 50. / min(H, W)
                size = [int(W*ratio), int(H*ratio)]
            # NOTE: a hard-coded `size = [160, 120]` used to overwrite the
            # value computed above, making `ratio` meaningless; it is removed.
            binary_masks = resize_masks(binary_masks, tuple(size))
            frames = resize_frames(frames, tuple(size))          # T, H, W, 3
        # frames and binary_masks are numpy arrays
        h, w = frames.shape[1:3]
        video_length = T

        # convert to tensor
        imgs = (torch.from_numpy(frames).permute(0, 3, 1, 2).contiguous().unsqueeze(0).float().div(255)) * 2 - 1
        mask_tensor = torch.from_numpy(binary_masks).permute(0, 3, 1, 2).contiguous().unsqueeze(0)
        imgs, mask_tensor = imgs.to(self.device), mask_tensor.to(self.device)
        comp_frames = [None] * video_length

        for f in tqdm(range(0, video_length, self.neighbor_stride), desc='Inpainting image'):
            neighbor_ids = list(range(max(0, f - self.neighbor_stride),
                                      min(video_length, f + self.neighbor_stride + 1)))
            ref_ids = self.get_ref_index(f, neighbor_ids, video_length)
            selected_ids = neighbor_ids + ref_ids
            selected_imgs = imgs[:1, selected_ids, :, :, :]
            selected_masks = mask_tensor[:1, selected_ids, :, :, :]
            pred_imgs = self._predict(selected_imgs, selected_masks, len(neighbor_ids), h, w)
            for i, idx in enumerate(neighbor_ids):
                img = pred_imgs[i].astype(np.uint8) * binary_masks[idx] + frames[idx] * (
                        1 - binary_masks[idx])
                self._accumulate(comp_frames, idx, img)
            torch.cuda.empty_cache()
        inpainted_frames = np.stack(comp_frames, 0)
        return inpainted_frames.astype(np.uint8)

def read_mask(mpath):
    """
    Load every file in directory `mpath` as a binary mask.

    Files are read in sorted name order. Each one is converted to grayscale,
    binarised (any non-zero pixel counts as masked) and dilated 4 times with
    a 3x3 cross kernel.

    Output:
    masks: numpy uint8 array, N, H, W, with values 0 or 255

    Raises:
    ValueError: if the directory contains no files
    """
    kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))
    masks = []
    for mp in sorted(os.listdir(mpath)):
        # the context manager closes the file handle once pixels are read
        with Image.open(os.path.join(mpath, mp)) as img:
            m = np.array(img.convert('L'))
        m = np.array(m > 0).astype(np.uint8)
        m = cv2.dilate(m, kernel, iterations=4)
        masks.append(m * 255)

    if not masks:
        raise ValueError(f"no mask files found in {mpath!r}")
    return np.stack(masks)

# read frames from video

def read_frame_from_videos(vname, use_mp4 = False):
    """
    Read all frames of a video as RGB PIL images.

    vname: path to a video file (use_mp4=True) or to a directory of image
           files that are read in sorted name order (use_mp4=False)
    use_mp4: whether `vname` is a video file decoded with cv2.VideoCapture

    Output:
    frames: list of PIL.Image in RGB order

    Raises:
    ValueError: if an entry of the image directory cannot be decoded
    """
    frames = []
    if use_mp4:
        vidcap = cv2.VideoCapture(vname)
        try:
            success, image = vidcap.read()
            while success:
                # OpenCV decodes to BGR; convert before wrapping in PIL
                frames.append(Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)))
                success, image = vidcap.read()
        finally:
            # always give the decoder / file handle back
            vidcap.release()
    else:
        for name in sorted(os.listdir(vname)):
            path = os.path.join(vname, name)
            image = cv2.imread(path)
            if image is None:
                # cv2.imread signals failure with None instead of raising
                raise ValueError(f"could not read image file: {path}")
            frames.append(Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)))
    return frames

# generate video after vos inference

def generate_video_from_frames(frames, output_path, fps=30):
    """
    Generates a video from a list of frames.

    Args:
        frames (list of numpy arrays): The frames to include in the video.
        output_path (str): The path to save the generated video.
        fps (int, optional): The frame rate of the output video. Defaults to 30.

    Returns:
        str: `output_path`, unchanged.
    """
    frames = torch.from_numpy(np.asarray(frames))
    out_dir = os.path.dirname(output_path)
    # dirname is '' for a bare file name; os.makedirs('') would raise
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
    return output_path

if __name__ == '__main__':
    # Demo: remove a fixed-position watermark using ONE mask for all frames.

    # long and large video
    frames = read_frame_from_videos("examples/schoolgirls.mp4", True)
    frames = np.array([np.array(f).astype(np.uint8) for f in frames])  # T, H, W, 3

    # the directory must hold exactly one mask image (inpaint() asserts this)
    masks = np.array(read_mask("examples/schoolgirls_mask"))

    # ----------------------------------------------
    # how to use
    # ----------------------------------------------
    # 1/3: set checkpoint and device
    checkpoint = os.path.join('release_model', 'E2FGVI-HQ-CVPR22.pth')
    device = 'cuda:0'
    # 2/3: initialise inpainter
    base_inpainter = BaseInpainter(checkpoint, device)
    # 3/3: inpainting (frames: numpy array, T, H, W, 3; masks: numpy array, 1, H, W)
    # ratio: (0, 1], ratio for down sample, default value is 1
    inpainted_frames = base_inpainter.inpaint(frames, masks)   # numpy array, T, H, W, 3

    # save
    generate_video_from_frames(inpainted_frames, "result/test.mp4")

    torch.cuda.empty_cache()
    print('switch to ori')

    # inpainted_frames = base_inpainter.inpaint_ori(frames[:50], masks[:50], ratio=0.1)
    # save_path = '/ssd1/gaomingqi/results/inpainting/avengers'
    # # ----------------------------------------------
    # # end
    # # ----------------------------------------------
    # # save
    # for ti, inpainted_frame in enumerate(inpainted_frames):
    #   frame = Image.fromarray(inpainted_frame).convert('RGB')
    #   frame.save(os.path.join(save_path, f'{ti:05d}.jpg'))

`

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant