In [1]:
import os
import json
import torch
import imageio
import math
import numpy as np
import torchvision
from torch.utils.data import Dataset
from typing import Dict, Optional, Sequence, List
from einops import rearrange

In [2]:
from llava.model import *

In [3]:
from llava.train.train import load_video, load_mask

In [4]:
# Seed every RNG this notebook touches (Python, NumPy, PyTorch CPU and CUDA)
# so frame sampling and diffusion noise repeat across runs.
seed = 0
import random
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# cuDNN autotuning (benchmark=True) may select different, possibly
# non-deterministic kernels from one run to the next, which defeats the
# seeding above. Keep cuDNN enabled but request deterministic kernels.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [5]:
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8):
    """Tile a batch of videos into one grid per time step and save it as a looping GIF.

    Args:
        videos: Tensor laid out as (batch, channels, time, height, width).
            Values are expected in [0, 1], or in [-1, 1] when ``rescale`` is True.
            May live on any device, in any floating dtype.
        path: Output file path handed to ``imageio.mimsave``. Missing parent
            directories are created.
        rescale: If True, map values from [-1, 1] to [0, 1] before quantising.
        n_rows: Forwarded to ``torchvision.utils.make_grid`` as ``nrow``
            (number of videos placed side by side in each grid row).
        fps: Playback rate; converted to a per-frame duration in milliseconds.

    Raises:
        ValueError: If ``fps`` is not positive.
    """
    if fps <= 0:
        raise ValueError("fps must be positive, got %r" % (fps,))

    # `.numpy()` below requires a CPU tensor without autograd history; float32
    # also avoids relying on half-precision CPU kernels.
    videos = videos.detach().cpu().float()
    videos = rearrange(videos, "b c t h w -> t b c h w")

    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        x = x.permute(1, 2, 0).squeeze(-1)  # c h w -> h w c
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        # Clamp first: casting out-of-range floats to uint8 wraps around and
        # shows up as speckle artefacts in the GIF.
        x = x.clamp(0.0, 1.0)
        x = (x * 255).numpy().astype(np.uint8)
        outputs.append(x)

    # `os.path.dirname` is '' for a bare file name, which `os.makedirs` rejects.
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    imageio.mimsave(path, outputs, duration=1000 * (1 / fps), loop=0)

In [6]:
# Half-precision dtype that the diffusion sub-modules and the input tensors
# are cast to in the cells below.
weight_dtype = torch.float16

In [7]:
# Load the LLaVA model from a local checkpoint directory.
# NOTE: the path is an absolute, machine-specific location.
model = LlavaLlamaForCausalLM.from_pretrained('/nas-hdd/shoubin/pretrained_model/mgie_ckpt/LLaVA-Lightning-7B-delta-v1-1')

The config attributes {'st_attn': False} were passed to VideoInpaintingModel, but are not expected and will be ignored. Please verify your config.json configuration file.


in_channels 9
_IncompatibleKeys(missing_keys=['down_blocks.0.attentions.0.transformer_blocks.0.attn_temp.to_q.weight', 'down_blocks.0.attentions.0.transformer_blocks.0.attn_temp.to_k.weight', 'down_blocks.0.attentions.0.transformer_blocks.0.attn_temp.to_v.weight', 'down_blocks.0.attentions.0.transformer_blocks.0.attn_temp.to_out.0.weight', 'down_blocks.0.attentions.0.transformer_blocks.0.attn_temp.to_out.0.bias', 'down_blocks.0.attentions.0.transformer_blocks.0.norm_temp.weight', 'down_blocks.0.attentions.0.transformer_blocks.0.norm_temp.bias', 'down_blocks.0.attentions.1.transformer_blocks.0.attn_temp.to_q.weight', 'down_blocks.0.attentions.1.transformer_blocks.0.attn_temp.to_k.weight', 'down_blocks.0.attentions.1.transformer_blocks.0.attn_temp.to_v.weight', 'down_blocks.0.attentions.1.transformer_blocks.0.attn_temp.to_out.0.weight', 'down_blocks.0.attentions.1.transformer_blocks.0.attn_temp.to_out.0.bias', 'down_blocks.0.attentions.1.transformer_blocks.0.norm_temp.weight', 'down_bloc

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of the model checkpoint at /nas-hdd/shoubin/pretrained_model/mgie_ckpt/LLaVA-Lightning-7B-delta-v1-1 were not used when initializing LlavaLlamaForCausalLM: ['model.layers.11.mlp.up_proj.weight', 'model.layers.5.mlp.gate_proj.weight', 'model.layers.3.self_attn.o_proj.weight', 'model.layers.31.self_attn.rotary_emb.inv_freq', 'model.layers.16.self_attn.o_proj.weight', 'model.layers.16.mlp.gate_proj.weight', 'model.layers.24.self_attn.o_proj.weight', 'model.layers.19.self_attn.k_proj.weight', 'model.layers.11.self_attn.v_proj.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.6.self_attn.k_proj.weight', 'model.layers.8.mlp.gate_proj.weight', 'model.layers.30.self_attn.v_proj.weight', 'model.layers.4.post_attention_layernorm.weight', 'model.layers.26.self_attn.v_proj.weight', 'model.layers.26.mlp.up_proj.weight', 'model.layers.4.self_attn.q_proj.weight', 'model.layers.0.self_attn.rotary_emb.inv_freq', 'model.layers.16.post_attention_layernorm.weight', 'model.layer

In [None]:
# NOTE(review): this import uses the `LLaVA.llava...` package root while the
# earlier cells import from `llava...` — confirm both resolve to the same code.
from LLaVA.llava.model.video_diffusion.unet import VideoInpaintingModel
# Alternative UNet checkpoints that were tried before (kept for reference):
# model.unet = VideoInpaintingModel.from_pretrained('/nas-hdd/shoubin/pretrained_model/lgvi/lgvi', subfolder='unet_trainedv3')
# model.unet = VideoInpaintingModel.from_pretrained('/nas-hdd/shoubin/pretrained_model/stable-diffusion-2-inpainting/', subfolder='unet')
# Replace the model's UNet with the weights in the `unet_finetuned` subfolder,
# then move the whole model to the first GPU.
model.unet = VideoInpaintingModel.from_pretrained('/nas-hdd/shoubin/pretrained_model/stable-diffusion-2-inpainting/', subfolder='unet_finetuned')
model = model.to('cuda:0')

The config attributes {'st_attn': False} were passed to VideoInpaintingModel, but are not expected and will be ignored. Please verify your config.json configuration file.


in_channels 9


In [None]:
# Cast the three diffusion sub-modules to `weight_dtype`. The return values
# are not reassigned, so this relies on `.to()` converting the modules in
# place (presumably they are torch.nn.Module instances — confirm).
model.text_encoder.to(weight_dtype)
model.vae.to(weight_dtype)
model.unet.to(weight_dtype)

In [None]:
from PIL import Image, ImageDraw
from einops import rearrange
import random as rnd

def load_mask(video_path, indices, mask_id, convert_to_box=False):
    """Load per-frame masks from a directory as a float tensor of 0.0 / 1.0.

    In the result, pixels inside the mask (or inside its bounding box when
    ``convert_to_box`` is True) are 0.0 and all other pixels are 1.0. Every
    frame is resized to 512x320 (width x height).

    Note: this notebook also imports a ``load_mask`` from ``llava.train.train``;
    this definition replaces it in the notebook namespace.

    Args:
        video_path: Directory with one mask image per frame. Files are used in
            sorted name order; names starting with '.' are skipped.
        indices: Positions into that sorted listing, one per output frame.
        mask_id: Currently unused — per-object selection is disabled and every
            non-zero pixel counts as mask. Kept for interface compatibility.
        convert_to_box: If True, replace each mask by its filled bounding box.

    Returns:
        Tensor of shape (len(indices), 1, 320, 512), assuming single-channel
        mask images.

    Raises:
        ValueError: If ``indices`` is empty.
    """
    WIDTH = 512
    HEIGHT = 320

    if len(indices) == 0:
        raise ValueError("load_mask needs at least one frame index")

    # Skip hidden files such as .DS_Store.
    frame_files = sorted(x for x in os.listdir(video_path) if not x.startswith('.'))
    selected_frames = [frame_files[i] for i in indices]
    frames = []

    for frame_name in selected_frames:
        # The context manager releases the file handle as soon as the pixels are read.
        with Image.open(os.path.join(video_path, frame_name)) as image:
            all_mask = np.array(image)
            image_size = image.size

        # Binarise before scaling: multiplying raw labels by 255 in uint8 wraps
        # around (e.g. 255 -> 1), and such faint values can be rounded to 0
        # along the edges by the bilinear resize below.
        mask = (all_mask > 0).astype(np.uint8) * 255

        if convert_to_box:
            # White canvas; the bounding box of the mask is painted black.
            box_image = Image.new("L", image_size, 255)
            draw = ImageDraw.Draw(box_image)
            rows = np.any(mask, axis=1)
            cols = np.any(mask, axis=0)
            if rows.any() and cols.any():  # an empty mask leaves the canvas untouched
                ymin, ymax = np.where(rows)[0][[0, -1]]
                xmin, xmax = np.where(cols)[0][[0, -1]]
                draw.rectangle([int(xmin), int(ymin), int(xmax), int(ymax)], fill=0)

            box_image = box_image.resize((WIDTH, HEIGHT), resample=Image.BILINEAR)
            box_tensor = torch.from_numpy(np.array(box_image)).float().div(255).unsqueeze(0)  # (1, H, W)
            frames.append(box_tensor)
        else:
            mask_image = Image.fromarray(mask).resize((WIDTH, HEIGHT), resample=Image.BILINEAR)
            frames.append(np.array(mask_image))

    if not convert_to_box:
        # (T, H, W) -> (T, 1, H, W); any non-zero (masked) pixel becomes 0.0.
        frames = torch.from_numpy(np.stack(frames, axis=0)).contiguous().unsqueeze(1)
        frames = torch.where(frames > 0, torch.tensor(0.0), torch.tensor(1.0))
    else:
        # The box is already 0 on a non-zero background; just binarise.
        frames = torch.stack(frames, dim=0)
        frames = torch.where(frames > 0, torch.tensor(1.0), torch.tensor(0.0))

    return frames
    
def load_video(video_path, sample_num=16, sample_type='uniform', given_index=None):
    """Load frames from a directory of images as a half-precision tensor in [-1, 1].

    Note: this notebook also imports a ``load_video`` from ``llava.train.train``;
    this definition replaces it in the notebook namespace.

    Args:
        video_path: Directory with one image per frame. Files are used in
            sorted name order; names starting with '.' are skipped.
        sample_num: Number of frames to return when ``given_index`` is None.
            Clips with fewer frames are padded by repeating the last frame.
        sample_type: 'uniform' takes the midpoint of each of ``sample_num``
            equal segments; 'random' draws one frame from each segment using
            the ``random`` module. Ignored when ``given_index`` is provided.
        given_index: Optional explicit frame positions; when provided they are
            used as-is and returned unchanged.

    Returns:
        ``(frames, indices)`` where ``frames`` is a float16 tensor of shape
        (T, 3, 320, 512) and ``indices`` lists the frame positions used.

    Raises:
        ValueError: If the directory has no frames, or (when sampling) if
            ``sample_num`` is below 1 or ``sample_type`` is unknown.
    """
    WIDTH = 512
    HEIGHT = 320

    # Skip hidden files such as .DS_Store.
    frame_files = sorted(x for x in os.listdir(video_path) if not x.startswith('.'))
    vlen = len(frame_files)
    if vlen == 0:
        raise ValueError("no frames found in %r" % (video_path,))

    if given_index is None:
        if sample_type not in ('random', 'uniform'):
            raise ValueError(
                "sample_type must be 'random' or 'uniform', got %r" % (sample_type,)
            )
        if sample_num < 1:
            raise ValueError("sample_num must be >= 1, got %r" % (sample_num,))

        # Split [0, vlen) into n_frms contiguous segments and pick one frame from each.
        n_frms = min(sample_num, vlen)
        intervals = np.linspace(start=0, stop=vlen, num=n_frms + 1).astype(int)
        ranges = [(int(lo), int(hi)) for lo, hi in zip(intervals[:-1], intervals[1:])]

        if sample_type == 'random':
            indices = [lo if lo == hi else rnd.choice(range(lo, hi)) for lo, hi in ranges]
        else:
            indices = [(lo + hi) // 2 for lo, hi in ranges]

        # Short clip: repeat the last frame. When padding is needed every
        # frame has been selected, so indices[-1] is the final frame.
        if len(indices) < sample_num:
            indices += [indices[-1]] * (sample_num - len(indices))
    else:
        indices = given_index

    selected_frames = [frame_files[i] for i in indices]

    frames = []
    for frame_name in selected_frames:
        # convert() returns an independent image, so the file can be closed right away.
        with Image.open(os.path.join(video_path, frame_name)) as image:
            rgb = image.convert("RGB").resize((WIDTH, HEIGHT), resample=Image.BILINEAR)
        frames.append(np.array(rgb))

    frames = np.stack(frames, axis=0)  # (T, H, W, 3)
    frames = torch.from_numpy(frames).permute(0, 3, 1, 2).contiguous()  # (T, 3, H, W)
    frames = frames.float().div(255).clamp(0, 1).half() * 2.0 - 1.0  # uint8 -> [-1, 1]
    return frames, indices

In [None]:
class EvalDataset(Dataset):
    """Evaluation dataset yielding (condition video, mask, prompt, target) per annotation.

    Each annotation is a dict with at least the keys 'vid', 'task' and
    'mask_id'. 'adding' items also need 'description' and 'editing' items
    need 'prompt'. Supported tasks are 'removal', 'adding' and 'editing'.

    Args:
        frame_num: Number of frames sampled per clip.
        anno_path: JSON file containing the list of annotations.
        video_base_path: Root directory of original frames, one folder per ``vid``.
        mask_base_path: Root directory of masks, one folder per ``<vid>-<mask_id>``.
        inpainted_base_path: Root directory of object-removed frames, laid out
            as ``<vid>/<mask_id>``.
    """

    def __init__(self, frame_num=16,
                 anno_path='/nas-hdd/shoubin/videos/rovi/data/advegas_benchmark.json',
                 video_base_path='/nas-hdd/shoubin/videos/rovi/data/JPEGImages/',
                 mask_base_path='/nas-hdd/shoubin/advegas/predicted_mask/',
                 inpainted_base_path='/nas-hdd/shoubin/videos/rovi/data/InpaintImages/'):
        super(EvalDataset, self).__init__()
        # Close the annotation file deterministically instead of leaking the handle.
        with open(anno_path) as f:
            self.data = json.load(f)
        self.video_base_path = video_base_path
        self.mask_base_path = mask_base_path
        self.inpainted_base_path = inpainted_base_path
        self.frame_num = frame_num
        print('--num data: %d--'%(len(self.data)))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Build the sample for annotation ``i``.

        Returns a dict with keys 'task' (str), 'target', 'condition', 'mask'
        (tensors produced by ``load_video`` / ``load_mask``) and 'text_prompt'
        (str).

        Raises:
            ValueError: If the annotation's task is not supported.
        """
        anno = self.data[i]
        vid = anno['vid']
        task = anno['task']
        mask_id = anno['mask_id']

        if task not in ('removal', 'adding', 'editing'):
            raise ValueError("unsupported task %r for item %r" % (task, i))

        # mask_id may be stored as an int in the JSON; path components must be str.
        mask_key = str(mask_id)
        video_dir = os.path.join(self.video_base_path, vid)
        inpainted_dir = os.path.join(self.inpainted_base_path, vid, mask_key)
        mask_dir = os.path.join(self.mask_base_path, vid + '-' + mask_key)

        if task == 'removal':
            # Original video in, object-free video out.
            text = 'inpainted background'
            target, index = load_video(inpainted_dir, sample_num=self.frame_num)
            condition, _ = load_video(video_dir, sample_num=self.frame_num, given_index=index)
            mask = load_mask(mask_dir, index, mask_id, convert_to_box=False)
        elif task == 'adding':
            # Object-free video in, original video out; the mask is a bounding box.
            text = anno['description']
            target, index = load_video(video_dir, sample_num=self.frame_num)
            condition, _ = load_video(inpainted_dir, sample_num=self.frame_num, given_index=index)
            mask = load_mask(mask_dir, index, mask_id, convert_to_box=True)
        else:  # 'editing'
            # Condition and target are the same frames of the same clip, so
            # decode once and copy instead of reading every image twice.
            text = anno['prompt']
            target, index = load_video(video_dir, sample_num=self.frame_num)
            condition = target.clone()
            mask = load_mask(mask_dir, index, mask_id, convert_to_box=False)

        data_dict = {}
        data_dict['task'] = task
        data_dict['target'] = target        # (T, 3, 320, 512) with this notebook's load_video
        data_dict['condition'] = condition  # (T, 3, 320, 512)
        data_dict['mask'] = mask            # (T, 1, 320, 512) with this notebook's load_mask
        data_dict['text_prompt'] = text
        return data_dict

In [None]:
# Build the evaluation dataset with its defaults (16 frames per clip).
eval_dataset = EvalDataset()

In [None]:
# Short object names keyed by integer. In this notebook the dict is only
# referenced from a commented-out line in the inference loop
# (`text = [short_dict[i]]`), where the key is the dataset index.
short_dict = {2: 'Orange', 5: 'Surfer in red', 8: 'Snowboarder in white', 11: 'Glass Robin',
14: 'Man in white shirt', 17: 'Yellow Magpie Bird', 20: 'Wooden Eagle',23: 'Corgi',
26: 'Rock Climber in blue', 29: 'Airplane in red', 32: 'Cat', 35: 'Man',
38: 'White Dog', 41: 'Man in black', 44: 'Belaying Man in Red', 47: 'Black Poodle',
50: 'Dog Handler as Superman', 53: 'Dalmatian', 56: 'Brown Bear', 59:'Orange Shark',
62: 'Man wearing a brown leather jacket', 65: 'Man in orange', 68: 'The baby in a white top', 71:'Iron Dog',
74: 'Black Dog', 77: 'Skier in white T-shirt', 80: 'A white vehicle', 83:'A baby in a bright red',
86: 'Green Snake', 89: 'Robot\'s Hand', 92: 'Woman in dark red dress', 95:'White Cat',
98: 'Player with Blue Hat', 101: 'Observer in white shirt', 104: 'BMX Rider in red', 107:'Panda',
110: 'blue-colored dog', 113: 'Spiderman', 116: 'Black Seagull', 119:'Person in rainbow-striped jacket',
122: 'White Cat', 125: 'Black Sheep', 128: 'Van Gogh Sheep', 131:'White Dog',
134: 'blue and white Stork', 137: 'Superman', 140: 'Runner in Red Shirt', 143:'wooden seagull',
146: 'White Camel', 149: 'Basketball Player in a red jersey', 152: 'White Shark', 155:'Black Bird of Prey',
158: 'Stone Bonsai', 161: 'woman in red t-shirt', 164: 'Basketball Player in a blue shirt and white shorts', 167:'Glass Dog',
170: 'Raccoon', 173: 'White Cat', 176: 'White Bird', 179:'White Bird',}

In [None]:
# Run inpainting on the first dataset item whose task is not 'adding', save it
# as a GIF and stop (note the `break` at the end of the loop body).
# The variables assigned here (`task`, `video`, `mask`, `text`, `inpainted`,
# `output_path`) stay in the notebook namespace and are read by later cells.
for i in range(len(eval_dataset)):
    input_dict = eval_dataset[i]
    # if input_dict is None:
    #     continue
    # print(i, input_dict['task'])
    if input_dict['task']=='adding':
        continue

    # model.inpaint is called with list-valued `task` / `prompt` and tensors
    # that carry a leading batch dimension of 1 (added by unsqueeze(0)).
    task = [input_dict['task']]
    video = input_dict['condition'].to(weight_dtype).unsqueeze(0).to('cuda:0') # [16, 3, 320, 512] before unsqueeze
    mask = input_dict['mask'].to(weight_dtype).unsqueeze(0).to('cuda:0') # [16, 1, 320, 512] before unsqueeze
    # video = video[:,:16,:,:,:]
    # mask = mask[:,:16,:,:,:]
    text = [input_dict['text_prompt']]
    # text = [short_dict[i]]
    inpainted = model.inpaint(
        video=video, # input video condition
        mask=mask,
        prompt=text,
        task = task)
    # print(inpainted.shape)
    # output_path = "/nas-hdd/shoubin/advegas/ablation/adding_wo_detailed/{}.gif".format(str(i))
    # output_path = "/nas-hdd/shoubin/advegas/ablation/pred_mask/{}/{}.gif".format(input_dict['task'], str(i))
    output_path = './reproduce.gif'
    save_videos_grid(inpainted, output_path)
    break

In [None]:
# Pick one dataset item and stage its inputs on the GPU for the cells below.
# NOTE(review): verify that this index is valid for the annotation file that
# is currently loaded by EvalDataset.
index = 12236
input_dict = eval_dataset[index]
print(input_dict.keys())
print(input_dict['task'])
task = [input_dict['task']]
# The inputs must live on the same device as the model, which this notebook
# moves to 'cuda:0'; 'cuda:1' here would cause a device mismatch in model.inpaint.
video = input_dict['condition'].to(weight_dtype).unsqueeze(0).to('cuda:0')  # [1, T, 3, 320, 512]
mask = input_dict['mask'].to(weight_dtype).unsqueeze(0).to('cuda:0')  # [1, T, 1, 320, 512]
text = [input_dict['text_prompt']]
print(text)
print(video.shape)

In [146]:
# Run inpainting with a hand-written prompt instead of the dataset prompt.
# `video`, `mask` and `task` are NOT defined in this cell: they come from
# whichever earlier cell assigned them last, so re-run that cell first.
# Previously tried prompts:
# text ='Kongfu Panda: The Panda is focused on performing dribbling exercises, showcasing control and agility.'
# text = ['Corgi: A Corgi dog with a golden coat is walking alongside a larger dog on a sandy beach. The Corgi, smaller and more agile, trots energetically, occasionally darting forward']
text = ['Woman: A casual woman in a white beach dress and with a straw hat, walking beside a man, likely accompanying him and his white dog. She moves with a relaxed stride, suggesting a leisurely outing together.']
inpainted = model.inpaint(
    video=video, # input video condition
    mask=mask,
    prompt=text,
    task = task)
print(inpainted.shape)
output_path = "teaser_adding_2.gif"
save_videos_grid(inpainted, output_path)

  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([1, 3, 16, 320, 512])


In [21]:
# Display the most recently assigned output path (depends on which cell ran last).
output_path

'/nas-hdd/shoubin/advegas/ours_sd2_50ep/editing.gif'

In [111]:
# Load only the mask of item 28; `video`, `task` and `text` are deliberately
# left as set by earlier cells (their assignments below are commented out).
index = 28
input_dict = eval_dataset[index]
print(input_dict.keys())
# print(input_dict['task'])
# video = input_dict['condition'].to(weight_dtype).unsqueeze(0).to('cuda:0') # [16, 3, 320, 512]
mask = input_dict['mask'].to(weight_dtype).unsqueeze(0).to('cuda:0') # [16, 1, 320, 512] before unsqueeze
# text = ['Giant Panda: The central figure in the video, a giant panda with distinct black and white fur, showcases its climbing skills. It has a large, round body, a characteristic bear-like face with black patches around its eyes, ears, and limbs.']
# text = [input_dict['text_prompt']]

# NOTE: `text` is not assigned in this cell, so this prints whatever an
# earlier cell left in the notebook namespace (NameError on a fresh kernel).
print(text)
# print(text)
# print(video.shape)

3fd96c5267
dict_keys(['task', 'target', 'condition', 'mask', 'text_prompt'])
['Surfer: An athlete wearing a dark wetsuit, possibly black or navy, showcasing talent in balancing and steering on the waves. Their stance is wide and steady, knees bent, arms outstretched for balance, and their posture exuding confidence.']


In [112]:
# Build a static mask: take the mask of the middle frame and repeat it for
# every frame. The frame count is read from `mask` (batch, frames, 1, H, W)
# rather than hard-coded, so this keeps working if the dataset is created
# with a different `frame_num`; for 16 frames it selects frame 8 as before.
num_frames = mask.shape[1]
mask_ = mask[:, num_frames // 2, :, :, :].unsqueeze(1)
print(mask.shape)
print(mask_.shape)
fixed_mask = torch.repeat_interleave(mask_, num_frames, dim=1)
print(fixed_mask.shape)

torch.Size([1, 16, 1, 320, 512])
torch.Size([1, 1, 1, 320, 512])
torch.Size([1, 16, 1, 320, 512])


In [101]:
import os
from PIL import Image
import imageio

def create_gif(image_folder, output_path, num_samples=16):
    """Sample evenly spaced ``.jpg`` frames from a folder and save them as a looping GIF.

    Frames are taken in sorted file-name order, resized to 512x320
    (width x height) and written at 8 frames per second.

    Args:
        image_folder: Directory containing the ``.jpg`` frames.
        output_path: Destination path of the GIF.
        num_samples: Maximum number of frames to keep. Folders with fewer
            images contribute all of their frames.

    Raises:
        ValueError: If ``num_samples`` is below 1 or the folder has no
            ``.jpg`` files.
    """
    WIDTH = 512
    HEIGHT = 320

    if num_samples < 1:
        raise ValueError("num_samples must be >= 1, got %r" % (num_samples,))

    # Collect the paths of all image files.
    images = sorted(
        os.path.join(image_folder, img)
        for img in os.listdir(image_folder)
        if img.endswith('.jpg')
    )
    if not images:
        raise ValueError("no .jpg files found in %r" % (image_folder,))

    # Sampling stride. With fewer images than num_samples the integer division
    # yields 0, which is not a valid slice step, so fall back to 1.
    step = max(1, len(images) // num_samples)

    # Pick evenly spaced images.
    selected_images = images[::step][:num_samples]

    # Read and resize each frame; the context manager closes the file right away.
    frames = []
    for img in selected_images:
        with Image.open(img) as im:
            frames.append(np.array(im.resize((WIDTH, HEIGHT), resample=Image.BILINEAR)))

    # Save the frames as a GIF; `duration` is the time between frames.
    imageio.mimsave(output_path, frames, 'GIF', duration=1000 * (1 / 8), loop=0)

# Example usage
# image_folder = '/nas-hdd/shoubin/videos/rovi/data/JPEGImages/b24fe36b2a'
# output_path = 'orginal.gif'
# create_gif(image_folder, output_path)


In [102]:
# inversed_latents = model.DDIM(
#         video=video, # input video condition
#         mask=mask,
#         prompt='',)

In [89]:
# inv_latents_path = '/nas-hdd/shoubin/advegas/davis_ddim/test.pt'
# torch.save(inversed_latents, inv_latents_path)

In [90]:
# index = 1
# input_dict = eval_dataset[index]
# print(input_dict.keys())
# # video = input_dict['video'].to(weight_dtype).unsqueeze(0).to('cuda:1') # [16, 3, 320, 512]
# mask = input_dict['mask'].to(weight_dtype).unsqueeze(0).to('cuda:1') # [16, 1, 320, 512]
# text = [input_dict['text_prompt']]
# # print(text)
# # text = ['change the man to the background']
# # print(text)
# # print(video.shape)

In [62]:
# inversed_latents[-1].shape
# len(inversed_latents)

In [123]:
# Inpaint with the static `fixed_mask` and a hand-written prompt.
# `video` and `task` are not defined in this cell; they come from whichever
# earlier cell assigned them last.
# text = ['Giant Panda: The central figure in the video, a giant panda with distinct black and white fur, showcases its climbing skills.']
text = ['White Bird: A bird is flying, its fur is pure white.']
reconstruct = model.inpaint(
        video=video, # input video condition
        mask=fixed_mask,
        prompt=text,
        task = task
        # latents=inversed_latents[-1]
        )

  0%|          | 0/50 [00:00<?, ?it/s]

In [124]:
# Save the result of the previous cell; the file name embeds the current
# value of the notebook-global `index`.
# video = reconstruct #pipe_output.videos
output_path = "v4_4000_{}_adding_fixed_mask3.gif".format(str(index))
save_videos_grid(reconstruct, output_path)

In [27]:
# mask

In [28]:
# video = (video / 2 + 0.5).clamp(0, 1)
# # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
# video = video.cpu().float().numpy()
# video = torch.from_numpy(video)
# video=video.permute(0,2,1,3,4)
# output_path = "in.gif"
# save_videos_grid(video, output_path)