In [1]:
import random
import math
import gc
import os

from torch.utils.data import Dataset, DataLoader, IterableDataset
from webdataset.handlers import warn_and_continue
from easydict import EasyDict as edict
from matplotlib import pyplot as plt
from open_clip import tokenizer
from einops import rearrange
import webdataset as wds
import numpy as np
import torchvision
import open_clip
import torch
import cv2

from distributed import init_distributed_device, is_primary
from vivq import VIVQ, BASE_SHAPE
from utils import sample_paella
from paella import DenoiseUNet

  warn(f"Failed to load image Python extension: {e}")
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Drop unreachable Python objects first so that any tensors they were holding are
# freed, then ask PyTorch to release its cached-but-unused CUDA memory.
gc.collect()
torch.cuda.empty_cache()

In [3]:
# Platform path separator (os.sep and os.path.sep are the same value).
SEP = os.sep
# The current working directory with its last four path components removed.
ROOT_PATH = SEP.join(os.getcwd().split(SEP)[:-4])

In [4]:
# Run configuration, built in one literal. EasyDict exposes every key as an
# attribute (args.batch_size, ...) and converts the nested "urls" dict as well.
args = edict({
    # --- run identification ---
    "run_name": "Paella_Test",
    "model": "maskgit",
    "dataset": "first_stage",

    # --- WebDataset shard locations, one entry per modality ---
    "urls": {
        "videos": "/home/projects/DataSets/webvid/dataset/00000.tar",
        "images": "/home/projects/DataSets/Coyo/coyo-700m-webdataset/{00000..03506}.tar",
    },

    # --- schedule, batching and logging ---
    "total_steps": 300_000,
    "batch_size": 4,
    "num_workers": 1,
    "log_period": 2000,
    "extra_ckpt": 10_000,
    "accum_grad": 2,

    # --- checkpoints and model hyper-parameters ---
    "vq_path": 'models/vivq_8192_drop_video/model_250000.pt',
    "du_path": 'models/Paella_Test_/model_90000.pt',
    "dim": 1224,
    "num_tokens": 8192,
    "max_seq_len": 6 * 16 * 16,
    "depth": 22,
    "dim_context": 1024,  # for clip, 512 for T5
    "heads": 22,

    # --- video clip sampling (consumed by ProcessVideos) ---
    "clip_len": 10,
    "skip_frames": 5,

    # --- distributed setup (passed to init_distributed_device) ---
    "n_nodes": 1,
    "dist_url": "env://",
    "dist_backend": "nccl",
    "no_set_device_rank": False,
})

print("Launching with args: ", args)


Launching with args:  {'run_name': 'Paella_Test', 'model': 'maskgit', 'dataset': 'first_stage', 'urls': {'videos': '/home/projects/DataSets/webvid/dataset/00000.tar', 'images': '/home/projects/DataSets/Coyo/coyo-700m-webdataset/{00000..03506}.tar'}, 'total_steps': 300000, 'batch_size': 4, 'num_workers': 1, 'log_period': 2000, 'extra_ckpt': 10000, 'accum_grad': 2, 'vq_path': 'models/vivq_8192_drop_video/model_250000.pt', 'du_path': 'models/Paella_Test_/model_90000.pt', 'dim': 1224, 'num_tokens': 8192, 'max_seq_len': 1536, 'depth': 22, 'dim_context': 1024, 'heads': 22, 'clip_len': 10, 'skip_frames': 5, 'n_nodes': 1, 'dist_url': 'env://', 'dist_backend': 'nccl', 'no_set_device_rank': False}


In [5]:
class MixImageVideoDataset(IterableDataset):
    """Iterable dataset that interleaves image samples and video samples.

    Samples are emitted in runs of ``batch_size``: ``batch_size`` items from the
    image pipeline, then ``batch_size`` items from the video pipeline, and so on.
    A consumer that batches with the same size (as ``get_dataloader`` does) thus
    receives batches drawn from a single modality.

    Image samples are ``(jpg, txt)`` tuples and video samples are
    ``(image, video, txt)`` tuples.
    """

    def __init__(self, args):
        """Build both pipelines.

        Args:
            args: configuration object providing ``batch_size``, ``urls["videos"]``,
                ``urls["images"]``, ``clip_len`` and ``skip_frames``.
        """
        super().__init__()
        self.batch_size = args.batch_size  # TODO: split this into image bs and video bs
        self.video_dataset, self.image_dataset = self.init_dataloaders(args)

    def init_dataloaders(self, args):
        """Create the video and image WebDataset pipelines.

        Every fallible stage uses ``warn_and_continue`` so that a broken sample is
        logged and skipped instead of terminating iteration.

        Returns:
            ``(video_dataset, image_dataset)``.
        """
        video_dataset = (
            wds.WebDataset(args.urls["videos"], resampled=True, handler=warn_and_continue)
            .decode(wds.torch_video, handler=warn_and_continue)
            .map(ProcessVideos(clip_len=args.clip_len, skip_frames=args.skip_frames),
                 handler=warn_and_continue)
            .to_tuple("image", "video", "txt", handler=warn_and_continue)
            .shuffle(690, handler=warn_and_continue)
        )
        image_dataset = (
            wds.WebDataset(args.urls["images"], resampled=True, handler=warn_and_continue)
            # The decode stage previously had no handler, so a single undecodable
            # image would raise and stop the whole stream.
            .decode("rgb", handler=warn_and_continue)
            .map(ProcessImages(), handler=warn_and_continue)
            .to_tuple("jpg", "txt", handler=warn_and_continue)
            .shuffle(6969, initial=10000)
        )
        return video_dataset, image_dataset

    def __iter__(self):
        """Yield samples, alternating modality every ``batch_size`` items.

        Iteration ends as soon as either underlying source is exhausted.
        """
        sources = [iter(self.image_dataset), iter(self.video_dataset)]
        while True:
            for source in sources:
                for _ in range(self.batch_size):
                    try:
                        yield next(source)
                    except StopIteration:
                        return


In [6]:
def collate_first_stage(batch):
    """Collate first-stage samples into batched image and video tensors.

    Args:
        batch: sequence of samples. Element 0 of each sample is the image tensor
            and element 1 is the video tensor; further elements are ignored.

    Returns:
        A two-item list ``[images, videos]``, each stacked along a new leading
        batch dimension.
    """
    def column(index):
        # Gather the ``index``-th field of every sample, in batch order.
        return [sample[index] for sample in batch]

    images = torch.stack(column(0), dim=0)
    videos = torch.stack(column(1), dim=0)
    return [images, videos]


def collate_second_stage(batch):
    """Collate second-stage samples, which are either image-only or image+video.

    The layout is decided from the first sample: a 2-tuple is treated as
    ``(image, caption)``, anything else as ``(image, video, caption)``.

    Args:
        batch: non-empty sequence of samples sharing one of the two layouts.

    Returns:
        ``[images, videos, captions]`` where ``images`` (and ``videos`` when
        present) are stacked along a new leading dimension, ``videos`` is ``None``
        for image-only batches, and ``captions`` is a plain list.
    """
    image_only = len(batch[0]) == 2
    images = torch.stack([sample[0] for sample in batch], dim=0)

    if image_only:
        videos = None
        caption_index = 1
    else:
        videos = torch.stack([sample[1] for sample in batch], dim=0)
        caption_index = 2

    captions = [sample[caption_index] for sample in batch]
    return [images, videos, captions]


class ProcessVideos:
    """Sample mapper that cuts a random strided clip out of a decoded video.

    For each sample it picks a random start frame, takes ``clip_len + 1`` frames
    spaced ``skip_frames`` apart, and stores the first one under ``"image"`` and
    the remaining ``clip_len`` under ``"video"``.

    Args:
        clip_len: number of frames in the returned ``"video"`` clip.
        skip_frames: stride, in source frames, between consecutive sampled frames.
    """

    def __init__(self, clip_len=10, skip_frames=4):
        self.video_transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(128),
            torchvision.transforms.RandomCrop(128)
        ])
        self.clip_len = clip_len
        self.skip_frames = skip_frames
        print(f"Using clip length of {clip_len} and {skip_frames} skip frames.")

    def __call__(self, data):
        """Add ``"image"`` and ``"video"`` entries to ``data`` and return it.

        Args:
            data: sample dict whose ``"mp4"`` entry is a sequence with the frame
                tensor at index 0. The frames are presumably laid out as
                (frames, height, width, channels) with values in 0..255, which is
                what the ``permute`` and ``/ 255.`` below rely on - confirm
                against the decoder in use.

        Returns:
            The same ``data`` dict, mutated in place.

        Raises:
            Exception: if the video has too few frames for one full strided clip,
                or if the extracted clip does not contain ``clip_len`` frames.
        """
        video = data["mp4"][0]
        # Distance between the first and the last sampled frame. The slice needs
        # span + 1 source frames, so that (not just span) is the minimum length.
        span = self.clip_len * self.skip_frames
        max_seek = video.shape[0] - (span + 1)
        if max_seek < 0:
            raise Exception(f"Video too short ({video.shape[0]} frames), skipping.")
        # Inclusive on both ends; every start in [0, max_seek] leaves room for the
        # full clip.
        start = random.randint(0, max_seek)
        video = video[start:start + span + 1:self.skip_frames]
        video = video.permute(0, 3, 1, 2) / 255.
        if self.video_transform:
            video = self.video_transform(video)
        image, video = video[0], video[1:]
        data["image"] = image
        data["video"] = video
        # Sanity check against the configured length rather than a literal 10.
        if video.shape[0] != self.clip_len:
            raise Exception(f"Expected {self.clip_len} frames but got {video.shape[0]}.")
        return data


class ProcessImages:
    """Sample mapper that preprocesses the ``"jpg"`` entry of a sample dict."""

    def __init__(self,):
        # Tensor conversion, then resize with target size 128, then a random
        # 128x128 crop.
        steps = [
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Resize(128),
            torchvision.transforms.RandomCrop(128),
        ]
        self.transforms = torchvision.transforms.Compose(steps)

    def __call__(self, data):
        """Replace ``data["jpg"]`` with its transformed version and return ``data``."""
        processed = self.transforms(data["jpg"])
        data["jpg"] = processed
        return data


def get_dataloader(args):
    """Build the ``DataLoader`` selected by ``args.dataset``.

    Args:
        args: configuration object. Reads ``dataset``, ``batch_size`` and
            ``num_workers``; the ``"first_stage"`` branch additionally reads
            ``clip_len``, ``skip_frames`` and the shard location (see below).

    Returns:
        A ``torch.utils.data.DataLoader``:
        * ``"first_stage"``  - video-only WebDataset yielding ``(image, video)``.
        * ``"second_stage"`` - ``MixImageVideoDataset`` (images and videos with captions).
        * anything else      - a ``VideoDataset`` (see the note in that branch).
    """
    if args.dataset == "first_stage":
        # Prefer an explicit ``dataset_path``. Configs that only define the
        # per-modality ``urls`` mapping (as MixImageVideoDataset uses) fall back
        # to the video shards instead of failing with AttributeError.
        shards = getattr(args, "dataset_path", None) or args.urls["videos"]
        dataset = (
            wds.WebDataset(shards, resampled=True, handler=warn_and_continue)
            .decode(wds.torch_video, handler=warn_and_continue)
            .map(ProcessVideos(clip_len=args.clip_len, skip_frames=args.skip_frames),
                 handler=warn_and_continue)
            .to_tuple("image", "video", handler=warn_and_continue)
            .shuffle(690, handler=warn_and_continue)
        )
        dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers,
                                collate_fn=collate_first_stage)

    elif args.dataset == "second_stage":
        dataset = MixImageVideoDataset(args)
        dataloader = DataLoader(dataset, batch_size=args.batch_size, collate_fn=collate_second_stage,
                                num_workers=args.num_workers)

    else:
        # NOTE(review): ``VideoDataset`` and ``transforms`` are neither defined nor
        # imported in the visible part of this file; unless they are provided
        # elsewhere, this branch raises NameError.
        dataset = VideoDataset(video_transform=transforms)
        dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)
    return dataloader

In [7]:
# Resolve the torch device for this process from the distributed settings in args.
device = init_distributed_device(args)

# VQ model: restore the checkpoint, then freeze it in eval mode. It is used
# further down only to decode sampled token indices.
vqmodel = VIVQ(codebook_size=args.num_tokens, model = 'maskgit').to(device)
vqmodel.load_state_dict(torch.load(args.vq_path, map_location=device))
# NOTE(review): presumably pushes the quantizer's step counter past any
# step-dependent warm-up behaviour - confirm against the vivq module.
vqmodel.vqmodule.q_step_counter += int(1e9)
vqmodel.eval().requires_grad_(False)

# Token denoiser, conditioned on text embeddings of width args.dim_context.
model = DenoiseUNet(num_labels = args.num_tokens, down_levels = [4, 6, 8],
                    up_levels  = [8, 6, 4], c_clip = args.dim_context).to(device)

# NOTE(review): checkpoint loading is disabled, so `model` keeps the weights its
# constructor produced. Re-enable the next line to sample from args.du_path.
# model.load_state_dict(torch.load(args.du_path, map_location = device))

# Text encoder. The two extra return values of create_model_and_transforms are discarded.
clip_model, _, _ = open_clip.create_model_and_transforms('ViT-H-14', pretrained = 'laion2b_s32b_b79k',
                                                         cache_dir = '/fsx/max/.cache')

# Only encode_text is called later, so the visual sub-module is dropped before
# the model is moved to the device.
del clip_model.visual
model.eval()
clip_model = clip_model.to(device).eval().requires_grad_(False)

# Runtime message (Korean), roughly "model loading done".
print('모델 로딩 완.')

모델 로딩 완.


In [8]:
# Video-only pipeline, one stage per line:
# shards -> decoded video -> ProcessVideos -> (image, video, txt) tuples -> shuffle buffer.
dataset = (
    wds.WebDataset(args.urls["videos"], resampled=True, handler=warn_and_continue)
    .decode(wds.torch_video, handler=warn_and_continue)
    .map(ProcessVideos(clip_len=args.clip_len, skip_frames=args.skip_frames), handler=warn_and_continue)
    .to_tuple("image", "video", "txt", handler=warn_and_continue)
    .shuffle(690, handler=warn_and_continue)
)
# collate_first_stage stacks only the first two tuple fields, so "txt" is dropped here.
dataloader = DataLoader(
    dataset,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    collate_fn=collate_first_stage,
)

Using clip length of 10 and 5 skip frames.


In [9]:
# Pull a single batch and display both tensor shapes as a quick layout check.
sample = next(iter(dataloader))
image, video = sample
(image.shape, video.shape)

(torch.Size([4, 3, 128, 128]), torch.Size([4, 10, 3, 128, 128]))

In [10]:
# One prompt per line. The context manager closes the file handle as soon as the
# text has been read, instead of leaving it open until garbage collection.
with open('cool_captions.txt') as captions_file:
    cool_captions_text = captions_file.read().splitlines()

# Tokenise all prompts at once and encode them with the CLIP text encoder;
# the result is cast to float32.
text_tokens            = tokenizer.tokenize(cool_captions_text).to(device)
text_tokens_embeddings = clip_model.encode_text(text_tokens).float()

In [11]:
# Decoded samples, moved to the CPU, collected across all processed captions.
cool_captions_sampled = []

# NOTE(review): chunk(10) splits the embeddings into at most 10 groups along
# dim 0, and only the first embedding of each group is used. If every caption is
# meant to be sampled, this should iterate over the embeddings directly - confirm.
for caption_embedding in text_tokens_embeddings.chunk(10):
    # Keep the first row of the chunk and restore a leading batch dimension of 1.
    caption_embedding = caption_embedding[0].float().to(device).unsqueeze(0)

    # Sample token indices for this prompt, then decode them with the VQ model.
    sampled_text = vqmodel.decode_indices(sample_paella(model, caption_embedding))

    cool_captions_sampled.extend(s.cpu() for s in sampled_text)