## Implementing the Wavelet Transform with Convolutions

In [65]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn.functional as F
import decord
from einops import rearrange
from diffusers.utils import export_to_video

In [76]:
# Decode the first 16 frames of the clip into a NumPy array.
video = decord.VideoReader('0003.mp4').get_batch(range(16)).asnumpy()
# Add a leading batch axis, then move the trailing axis to position 1.
video = torch.tensor(video).unsqueeze_(0).permute(0, 4, 1, 2, 3)
# Shift and scale by 127.5, then store the result in half precision.
video = video.sub(127.5).div(127.5).to(torch.float16)
# Notebook display: shape and value range of the prepared clip.
video.shape, video.min(), video.max()

(torch.Size([1, 3, 16, 480, 720]),
 tensor(-1., dtype=torch.float16),
 tensor(1., dtype=torch.float16))

In [77]:
# Synthetic alternative input, kept for quick experiments without a video file:
# B, C, T, H, W = 2, 16, 8, 32, 32
# tensor = torch.randn(B, C, T, H, W, dtype=torch.float16)
# Work on the loaded clip and keep its five axis sizes for the reshapes below.
tensor = video
B, C, T, H, W = tensor.shape

In [78]:
class WaveletTransform(torch.nn.Module):
    """Two-tap, stride-2 filter bank applied along the last axis of a 1-D signal.

    ``forward`` splits an input of shape ``(N, 1, L)`` into a low-pass and a
    high-pass band of shape ``(N, 1, L // 2)`` each.  ``inverse`` maps the two
    bands back to a signal of shape ``(N, 1, 2 * (L // 2))``.

    The analysis filters are not orthogonal, so the synthesis filters used by
    ``inverse`` are taken from the inverse of the 2x2 analysis matrix rather
    than by reusing the analysis taps.

    NOTE(review): the four tap values match the numerators of the
    Daubechies-4 scaling coefficients divided by 4 and split over two 2-tap
    filters -- confirm that this split is intended.

    Args:
        dtype: dtype of the stored filters; must match the dtype of the
            tensors passed to ``forward`` / ``inverse``.
    """

    def __init__(self, dtype=torch.float16):
        super().__init__()

        sqrt3 = np.sqrt(3)
        h0 = (1 + sqrt3) / 4
        h1 = (3 + sqrt3) / 4
        g0 = (3 - sqrt3) / 4
        g1 = (1 - sqrt3) / 4

        # Analysis filters, shaped [out_channels, in_channels, kernel_width]
        # as expected by F.conv1d.  Registered as non-persistent buffers so
        # they follow ``.to(device)`` while the state dict stays empty.
        self.register_buffer(
            'low_pass_filter',
            torch.tensor([h0, h1], dtype=dtype).view(1, 1, 2),
            persistent=False,
        )
        self.register_buffer(
            'high_pass_filter',
            torch.tensor([g0, g1], dtype=dtype).view(1, 1, 2),
            persistent=False,
        )

        # Each sample pair (x0, x1) is mapped to
        #   low  = h0 * x0 + h1 * x1
        #   high = g0 * x0 + g1 * x1
        # Inverting that 2x2 system gives the taps that scatter one low/high
        # coefficient back onto the pair (x0, x1).
        det = h0 * g1 - h1 * g0
        self.register_buffer(
            'low_pass_synthesis_filter',
            torch.tensor([g1 / det, -g0 / det], dtype=dtype).view(1, 1, 2),
            persistent=False,
        )
        self.register_buffer(
            'high_pass_synthesis_filter',
            torch.tensor([-h1 / det, h0 / det], dtype=dtype).view(1, 1, 2),
            persistent=False,
        )

    def forward(self, x):
        """Return ``(low, high)`` bands of ``x`` (shape ``(N, 1, L)``).

        With kernel width 2, stride 2 and no padding, a trailing sample of an
        odd-length input is not covered by any window.
        """
        low = F.conv1d(x, self.low_pass_filter, stride=2, padding=0)
        high = F.conv1d(x, self.high_pass_filter, stride=2, padding=0)
        return low, high

    def inverse(self, low, high):
        """Reconstruct the signal from the bands returned by ``forward``.

        Both inputs have shape ``(N, 1, M)``; the result has shape
        ``(N, 1, 2 * M)``.
        """
        x = F.conv_transpose1d(low, self.low_pass_synthesis_filter, stride=2, padding=0)
        x = x + F.conv_transpose1d(high, self.high_pass_synthesis_filter, stride=2, padding=0)
        return x

In [79]:
# Single filter-bank instance, reused for the h-, w- and t-axis passes below.
transform = WaveletTransform()

In [80]:
# --- h axis -----------------------------------------------------------------
# Fold every axis except h into the batch so the 1-D transform runs along h.
tensor = rearrange(video, 'b c t h w -> (b c t w) 1 h')
h_a, h_d = (
    rearrange(band, '(b c t w) 1 h -> b c t h w', b=B, c=C, t=T, h=H//2, w=W)
    for band in transform(tensor)
)

# --- w axis -----------------------------------------------------------------
# Only the low-pass output of the h pass is transformed further.
h_a = rearrange(h_a, 'b c t h w -> (b c t h) 1 w')
h_a_w_a, h_a_w_d = (
    rearrange(band, '(b c t h) 1 w -> b c t h w', b=B, c=C, t=T, h=H//2, w=W//2)
    for band in transform(h_a)
)

# --- t axis -----------------------------------------------------------------
# Again only the low-pass output of the previous pass is transformed.
h_a_w_a = rearrange(h_a_w_a, 'b c t h w -> (b c h w) 1 t')
h_a_w_a_t_a, h_a_w_a_t_d = (
    rearrange(band, '(b c h w) 1 t -> b c t h w', b=B, c=C, t=T//2, h=H//2, w=W//2)
    for band in transform(h_a_w_a)
)

In [81]:
# Notebook display: shape and value range of the t-axis high-pass band.
h_a_w_a_t_d.shape, h_a_w_a_t_d.min(), h_a_w_a_t_d.max()

(torch.Size([1, 3, 8, 240, 360]),
 tensor(-1.5859, dtype=torch.float16),
 tensor(1.4922, dtype=torch.float16))

In [83]:
def view_cthw(tensor):
    """Convert a ``(c, t, h, w)`` tensor into a list of PIL images, one per frame.

    Values are min-max normalised over the whole tensor to ``[0, 255]`` and
    truncated to ``uint8``.

    Args:
        tensor: 4-D tensor laid out as ``(c, t, h, w)``.

    Returns:
        A list of ``t`` ``PIL.Image`` objects built from ``(h, w, c)`` arrays.
    """
    frames = rearrange(tensor, 'c t h w -> t h w c')
    # .numpy() needs a CPU tensor detached from autograd; float32 avoids
    # the coarse rounding of normalising in half precision.
    frames = frames.detach().cpu().float()

    lowest = frames.min()
    span = frames.max() - lowest
    if span > 0:
        frames = (frames - lowest) / span
    else:
        # Constant input: dividing by a zero range would give NaN, so emit
        # all-black frames instead.
        frames = torch.zeros_like(frames)

    frames = frames.mul(255).to(torch.uint8).numpy()
    return [Image.fromarray(frame) for frame in frames]

In [None]:
# Write three sub-bands of the first batch item as 4-fps videos for inspection.
export_to_video(view_cthw(h_a_w_a_t_d[0]), './h_a_w_a_t_d.mp4', fps=4)  # high-frequency in time, low-frequency in space --> motion
export_to_video(view_cthw(h_a_w_a_t_a[0]), './h_a_w_a_t_a.mp4', fps=4)  # low-pass output of the t-axis pass
export_to_video(view_cthw(h_a_w_d[0]), './h_a_w_d.mp4', fps=4)  # high-pass output of the w-axis pass



'./h_a_w_d.mp4'

## VidGen-1M

In [1]:
import json

In [2]:
# Caption metadata for the VidGen-1M dataset.
meta_file = '/backup/data/qiguojunLab/VidGen-1M-meta/VidGen_1M_video_caption.json'
# Pin the encoding so that decoding the captions does not depend on the
# platform's locale default.
with open(meta_file, 'r', encoding='utf-8') as f:
    meta = json.load(f)

In [4]:
# Notebook display: the first record, a dict with 'vid' and 'caption' keys.
meta[0]

{'vid': 'Eep9uvenxAo-Scene-0030',
 'caption': "The video shows a person's hand touching and moving flowers on a plant. The flowers are red in color and the plant has green leaves. The person's hand is visible in the foreground, and the background shows a house and a driveway. The video is shot during the daytime, and the lighting is natural."}