In [1]:
import math
import copy
import torch
from torch import nn, einsum
import torch.nn.functional as F
from functools import partial

from torch.utils import data
from pathlib import Path
from torch.optim import Adam
from torchvision import transforms as T, utils
from torch.cuda.amp import autocast, GradScaler
from PIL import Image

from tqdm import tqdm
from einops import rearrange
from einops_exts import check_shape, rearrange_many

from rotary_embedding_torch import RotaryEmbedding

from video_diffusion_pytorch.text import tokenize, bert_embed, BERT_MODEL_DIM

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Random stand-in video batch laid out as (batch, channels, frames, height, width).
videos = torch.randn((2, 3, 5, 32, 32))
# Random stand-in for text conditioning: one 64-dim vector per video
# (the BERT-large output is assumed to have dimension 64 here).
text = torch.randn((2, 64))

In [3]:
# Show the sample batch; the printed repr abbreviates the values with "...".
videos

tensor([[[[[-4.3891e-01, -4.0580e-01, -6.7191e-01,  ...,  1.5507e-02,
            -7.4809e-01, -6.1074e-03],
           [-3.8195e-01, -1.5135e+00, -8.3344e-02,  ..., -1.9664e-01,
            -1.0374e+00,  9.0908e-01],
           [-3.0262e-01,  1.2833e+00,  9.3202e-01,  ..., -6.8047e-01,
             3.7489e-01,  3.5777e-01],
           ...,
           [ 1.2676e+00,  1.0160e+00,  1.8983e+00,  ..., -3.9257e-01,
             3.4796e-01, -1.8097e+00],
           [ 1.2884e-01, -3.5021e-01,  6.6888e-01,  ...,  1.3110e+00,
             6.8381e-01,  1.6851e+00],
           [ 1.1797e+00, -6.6863e-01, -7.8625e-01,  ..., -7.5229e-01,
            -2.0344e+00,  2.4676e+00]],

          [[ 9.2108e-02, -1.1087e+00, -1.0631e+00,  ..., -6.0763e-01,
             5.9526e-01,  1.0224e+00],
           [-1.6439e+00, -8.4051e-01,  3.8818e-01,  ...,  6.5746e-02,
            -1.5066e+00, -8.2118e-01],
           [-9.5423e-01, -2.6438e-01,  9.3915e-01,  ...,  1.1333e+00,
            -7.4235e-01, -9.9136e-01],
 

In [4]:
from video_diffusion_pytorch import Unet3D, GaussianDiffusion

# `diffusion` has to exist before it can be called: build the denoising
# network and wrap it in the Gaussian diffusion process. The sizes mirror the
# sample tensors defined earlier (32x32 frames, 5 frames per clip, 64-dim text).
model = Unet3D(
    dim = 64,
    cond_dim = 64,            # presumably must equal text.shape[-1]; confirm against the library docs
    dim_mults = (1, 2, 4, 8)
)

diffusion = GaussianDiffusion(
    model,
    image_size = 32,
    num_frames = 5,
    timesteps = 1000,         # number of diffusion steps
    loss_type = 'l1'
)

# Training loss for one batch of videos conditioned on the text embeddings.
loss = diffusion(videos, cond = text)

In [4]:
def normalize_img(t):
    """Scale ``t`` by 2 and subtract 1 (so 0 maps to -1 and 1 maps to 1)."""
    doubled = t * 2
    return doubled - 1

In [5]:
# Hand-unrolled setup of the diffusion forward pass on the sample batch.
x = videos
image_size = 32
channels = 3
num_frames = 5
num_timesteps = 1000
timesteps = 1000
loss_type = 'l1'

b, device, img_size = x.shape[0], x.device, image_size
# Fail early if the batch is not (b, channels, num_frames, img_size, img_size).
check_shape(x, 'b c f h w', c = channels, f = num_frames, h = img_size, w = img_size)
# Draw one random integer timestep per batch element from [0, num_timesteps).
t = torch.randint(0, num_timesteps, (b,), device=device).long()
x = normalize_img(x)

# The loss step is unrolled by hand in the following cells. Neither `p_losses`
# nor `args`/`kwargs` exist at notebook scope, so the call is kept only as a
# reference to the original method body.
# p_losses(x, t, *args, **kwargs)

In [9]:
def exists(x):
    """Tell whether ``x`` holds a value, i.e. is anything other than ``None``."""
    return not (x is None)

In [10]:
def default(val, d):
    """Return ``val`` when it is not ``None``; otherwise fall back to ``d``.

    A callable ``d`` is invoked (with no arguments) and its result returned,
    so the fallback is only built when it is actually needed.
    """
    if not exists(val):
        return d() if callable(d) else d
    return val

In [11]:
# Start of the hand-unrolled loss computation: x plays the role of the clean sample.
x_start = x
b, c, f, h, w = x_start.shape
device = x_start.device
# No noise is supplied, so default() falls back to fresh Gaussian noise shaped like x_start.
noise = default(None, lambda: torch.randn_like(x_start))
# x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)

In [None]:
def extract(a, t, x_shape):
    """Pick the entries of ``a`` at indices ``t`` and make them broadcastable.

    The gathered values are reshaped to ``(b, 1, ..., 1)``, where ``b`` is the
    first dimension of ``t`` and the result has ``len(x_shape)`` dimensions.
    """
    batch, *_rest = t.shape
    picked = a.gather(-1, t)
    target_shape = (batch,) + (1,) * (len(x_shape) - 1)
    return picked.reshape(*target_shape)

In [None]:
# Cosine noise schedule
def cosine_beta_schedule(timesteps, s = 0.008):
    """
    Build the beta values of the cosine schedule,
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ

    Returns a float64 tensor of ``timesteps`` betas clipped to [0, 0.9999].
    """
    grid = torch.linspace(0, timesteps, timesteps + 1, dtype = torch.float64)
    angle = ((grid / timesteps) + s) / (1 + s) * torch.pi * 0.5
    cumulative = torch.cos(angle) ** 2
    # Rescale so the cumulative product starts at exactly 1.
    cumulative = cumulative / cumulative[0]
    ratio = cumulative[1:] / cumulative[:-1]
    return torch.clip(1 - ratio, 0, 0.9999)

In [None]:
# Noise schedule and the quantities derived from it, kept in float64 while computing.
betas = cosine_beta_schedule(timesteps)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
# Cumulative product shifted right by one step, with 1.0 in the first slot.
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)

timesteps, = betas.shape
num_timesteps = int(timesteps)

# At notebook scope there is no module to attach buffers to, so the float32
# copies are collected in a plain dict keyed by buffer name.
buffers = {}

def register_buffer(name, val):
    """Store a float32 copy of ``val`` in ``buffers`` under ``name``."""
    buffers[name] = val.to(torch.float32)

register_buffer('betas', betas)
register_buffer('alphas_cumprod', alphas_cumprod)
register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

# calculations for diffusion q(x_t | x_{t-1}) and others

register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

# calculations for posterior q(x_{t-1} | x_t, x_0)

posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)

register_buffer('posterior_variance', posterior_variance)

# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain

register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

# q_sample reads these two schedules as plain notebook globals, so bind the
# float32 copies to module-level names as well.
sqrt_alphas_cumprod = buffers['sqrt_alphas_cumprod']
sqrt_one_minus_alphas_cumprod = buffers['sqrt_one_minus_alphas_cumprod']

In [None]:
# Per-step alphas and their running product, derived from the betas.
alphas = 1. - betas
alphas_cumprod = alphas.cumprod(dim=0)
# Same running product shifted right by one step, with 1.0 in the first slot.
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)

In [None]:
def q_sample(x_start, t, noise = None):
    """Mix ``x_start`` with noise using the schedule values at timesteps ``t``.

    When ``noise`` is None, Gaussian noise shaped like ``x_start`` is drawn.
    Reads the module-level ``sqrt_alphas_cumprod`` and
    ``sqrt_one_minus_alphas_cumprod`` tensors.
    """
    noise = default(noise, lambda: torch.randn_like(x_start))

    # Per-sample coefficients, reshaped by extract() to broadcast over x_start.
    signal_scale = extract(sqrt_alphas_cumprod, t, x_start.shape)
    noise_scale = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)
    return signal_scale * x_start + noise_scale * noise