In [5]:
import math
import copy
import torch
from torch import nn, einsum
import torch.nn.functional as F
from functools import partial

from torch.utils import data
from pathlib import Path
from torch.optim import Adam
from torchvision import transforms as T, utils
from torch.cuda.amp import autocast, GradScaler
from PIL import Image

from tqdm import tqdm
from einops import rearrange
from einops_exts import check_shape, rearrange_many

from rotary_embedding_torch import RotaryEmbedding

from video_diffusion_pytorch.text import tokenize, bert_embed, BERT_MODEL_DIM

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# Random stand-in for a batch of videos, laid out as (batch, channels, frames, height, width).
videos = torch.randn((2, 3, 5, 32, 32))
# Random stand-in for the text conditioning: one 64-dim vector per video
# (the working assumption here is that the BERT-large output has dimension 64).
text = torch.randn((2, 64))

In [7]:
from video_diffusion_pytorch import Unet3D, GaussianDiffusion

# `diffusion` has to exist before it can be called. On a fresh kernel nothing above
# this cell defines it, so the denoiser and the diffusion wrapper are built here.
model = Unet3D(
    dim = 64,
    cond_dim = 64,              # same width as the rows of `text`
    dim_mults = (1, 2, 4, 8)
)
diffusion = GaussianDiffusion(
    model,
    image_size = 32,            # height / width of `videos`
    num_frames = 5,             # frame count of `videos`
    timesteps = 1000,           # number of steps
    loss_type = 'l1'            # L1 or L2
)

# One forward pass on the dummy batch with the dummy conditioning; the result is bound to `loss`.
loss = diffusion(videos, cond = text)

In [8]:
def normalize_img(t):
    """Rescale `t` linearly as `t * 2 - 1`, so that 0 maps to -1 and 1 maps to 1."""
    doubled = t * 2
    return doubled - 1

In [9]:
# Step-by-step replay of the start of a diffusion forward pass on the dummy batch.
x = videos
image_size = 32
channels = 3
num_frames = 5
num_timesteps = 1000
timesteps = 1000
loss_type = 'l1'

b, device, img_size = x.shape[0], x.device, image_size
# Shape check against the 'b c f h w' pattern with the sizes configured above
# (presumably raises on a mismatch — confirm against einops_exts.check_shape).
check_shape(x, 'b c f h w', c = channels, f = num_frames, h = img_size, w = img_size)
# Draw one random integer timestep per batch element, in [0, num_timesteps).
t = torch.randint(0, num_timesteps, (b,), device=device).long()
x = normalize_img(x)

# The loss call stays disabled in this cell: at module scope there is no `args` /
# `kwargs` to unpack, and `p_losses` is only defined in a later cell, so running it
# here raises NameError. Call it after the cell that defines `p_losses` instead.
# p_losses(x, t, cond = text)

In [10]:
def exists(x):
    """Return True for any value other than None (falsy values such as 0 or '' still count)."""
    return not (x is None)

In [11]:
def default(val, d):
    """Return `val` unless it is None; otherwise return `d()` when `d` is callable, else `d` itself."""
    if not exists(val):
        return d() if callable(d) else d
    return val

In [12]:
# Unrolled opening of the loss computation: the current `x` is taken as the clean sample.
x_start = x
(b, c, f, h, w), device = x_start.shape, x_start.device
# No noise is supplied here, so `default` falls back to Gaussian noise shaped like `x_start`.
noise = default(None, lambda: torch.randn_like(x_start))
#x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)

In [13]:
def extract(a, t, x_shape):
    """Pick the entries of `a` indexed by `t` and shape them for broadcasting.

    Gathers along the last dimension of `a` with the index tensor `t`, then reshapes
    the result to `(b, 1, ..., 1)`, where `b` is the first dimension of `t` and the
    number of trailing ones is `len(x_shape) - 1`.
    """
    batch, *_rest = t.shape
    picked = a.gather(-1, t)
    singleton_dims = (1,) * (len(x_shape) - 1)
    return picked.reshape(batch, *singleton_dims)

In [14]:
# Cosine schedule
def cosine_beta_schedule(timesteps, s = 0.008):
    """
    Betas of a cosine noise schedule,
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ

    Returns a float64 tensor holding `timesteps` betas, clipped to [0, 0.9999].
    """
    grid = torch.linspace(0, timesteps, timesteps + 1, dtype = torch.float64)
    phase = ((grid / timesteps) + s) / (1 + s) * torch.pi * 0.5
    cum_alphas = torch.cos(phase) ** 2
    # Normalise so the cumulative product starts at exactly 1.
    cum_alphas = cum_alphas / cum_alphas[0]
    step_ratio = cum_alphas[1:] / cum_alphas[:-1]
    return torch.clip(1 - step_ratio, 0, 0.9999)

In [15]:
import os
# NOTE(review): presumably a workaround for the "duplicate OpenMP runtime" abort seen
# with some torch installs — confirm it is still needed, and whether it should be set
# before torch is imported.
os.environ.update({"KMP_DUPLICATE_LIB_OK": "TRUE"})

In [16]:
# Unrolled schedule set-up. The notebook has no nn.Module to register buffers on, so
# every derived schedule tensor is kept as a plain module-level float32 variable.
# `q_sample` reads `sqrt_alphas_cumprod` and `sqrt_one_minus_alphas_cumprod` from here.
# `betas`, `alphas`, `alphas_cumprod`, `alphas_cumprod_prev` and `posterior_variance`
# keep the float64 dtype produced by `cosine_beta_schedule`.
betas = cosine_beta_schedule(timesteps)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
# Cumulative product shifted one step to the right, with 1. in front of the first step.
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)

timesteps, = betas.shape
num_timesteps = int(timesteps)

# calculations for diffusion q(x_t | x_{t-1}) and others

sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).to(torch.float32)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).to(torch.float32)
log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).to(torch.float32)
sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).to(torch.float32)
sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).to(torch.float32)

# calculations for posterior q(x_{t-1} | x_t, x_0)

posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)

# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain

posterior_log_variance_clipped = torch.log(posterior_variance.clamp(min = 1e-20)).to(torch.float32)
posterior_mean_coef1 = (betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)).to(torch.float32)
posterior_mean_coef2 = ((1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)).to(torch.float32)

In [14]:
def register_buffer(name, val):
    """Notebook stand-in for a module's `register_buffer`.

    Casts `val` to float32 and binds it to `name` in this module's global namespace,
    since there is no module instance to attach the tensor to. The earlier lambda
    called itself under the same name and recursed until RecursionError.

    name: global variable name to bind.
    val:  tensor to store; converted with `.to(torch.float32)`.
    Returns None.
    """
    globals()[name] = val.to(torch.float32)

In [15]:
# Hand the beta schedule to `register_buffer` under the name 'betas'.
register_buffer(name = 'betas', val = betas)
# register_buffer('alphas_cumprod', alphas_cumprod)
# register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

: 

: 

In [None]:
# Re-derive the per-step alphas and their running product from the current `betas`.
alphas = 1. - betas
alphas_cumprod = alphas.cumprod(dim = 0)
# Same running product shifted one step to the right, padded with 1. in front.
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)

In [None]:
def q_sample(x_start, t, noise = None):
    """Blend `x_start` with noise using the schedule coefficients selected by `t`.

    Reads `sqrt_alphas_cumprod` and `sqrt_one_minus_alphas_cumprod` from module
    scope; both must be defined before this function is called.

    x_start: tensor to be noised.
    t:       index tensor passed to `extract` to select one coefficient per sample.
    noise:   optional noise tensor; when None, `torch.randn_like(x_start)` is drawn.
    """
    noise = default(noise, lambda: torch.randn_like(x_start))

    signal_scale = extract(sqrt_alphas_cumprod, t, x_start.shape)
    noise_scale = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)
    return signal_scale * x_start + noise_scale * noise

In [None]:
def is_list_str(x):
    """Return True when `x` is a list or tuple whose elements are all exactly `str`.

    An empty list or tuple also yields True; any other container type yields False.
    """
    if isinstance(x, (list, tuple)):
        return all(type(el) == str for el in x)
    return False

In [None]:
def p_losses(x_start, t, cond = None, noise = None, text_use_bert_cls = False, **kwargs):
    """Noise `x_start` at timesteps `t`, run the denoiser, and return the loss against the noise.

    x_start: clean input tensor.
    t:       timestep index tensor forwarded to `q_sample` and to the denoiser.
    cond:    optional conditioning; a list/tuple of strings is first embedded with
             `bert_embed(tokenize(cond))` and moved to the device of `x_start`.
    noise:   optional noise tensor; defaults to `torch.randn_like(x_start)`.
    text_use_bert_cls: forwarded to `bert_embed` as `return_cls_repr`. This used to be
             read from a non-existent `self`, which raised NameError for string
             conditions. (The False default is presumably the library's — TODO confirm.)
    **kwargs: forwarded unchanged to the denoiser.

    Reads the module-level `loss_type` ('l1' or 'l2'); any other value raises
    NotImplementedError.
    """
    device = x_start.device
    noise = default(noise, lambda: torch.randn_like(x_start))

    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)

    if is_list_str(cond):
        cond = bert_embed(tokenize(cond), return_cls_repr = text_use_bert_cls)
        cond = cond.to(device)

    # NOTE(review): `denoise_fn` is not defined in any visible cell — bind it at module
    # scope (presumably to the Unet3D instance) before calling this function.
    x_recon = denoise_fn(x_noisy, t, cond = cond, **kwargs)

    if loss_type == 'l1':
        loss = F.l1_loss(noise, x_recon)
    elif loss_type == 'l2':
        loss = F.mse_loss(noise, x_recon)
    else:
        raise NotImplementedError(f'unsupported loss_type: {loss_type!r}')

    return loss

In [19]:
# 3D U-Net handed to the diffusion wrapper below; cond_dim is 64, the width of `text`.
model = Unet3D(dim = 64, cond_dim = 64, dim_mults = (1, 2, 4, 8))

# Diffusion wrapper around `model`, configured for 5-frame clips of size 32.
diffusion = GaussianDiffusion(
    model,
    image_size = 32,
    num_frames = 5,
    timesteps = 1000,   # number of steps
    loss_type = 'l1',   # L1 or L2
)

In [20]:
# Reference result from the packaged GaussianDiffusion.
loss = diffusion(videos, cond = text)

# Manual replay of the first steps: normalize, pick timesteps, then noise the input.
x_start = normalize_img(videos)
t = torch.randint(0, num_timesteps, (b,), device=device).long()
noise = default(noise, lambda: torch.randn_like(x_start))
# Noise the *normalized* batch. Previously the raw `videos` were passed here and the
# normalized tensor was overwritten without ever being used.
x = diffusion.q_sample(x_start, t, noise=noise)

In [17]:
def prob_mask_like(shape, prob, device):
    """Boolean mask of the given shape whose entries are True with probability `prob`.

    `prob == 1` returns an all-True mask and `prob == 0` an all-False mask without
    drawing random numbers; otherwise uniform samples in [0, 1) are compared to `prob`.
    """
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)
    if prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)

    draws = torch.zeros(shape, device = device).float().uniform_(0, 1)
    return draws < prob

In [22]:
# No mask is supplied and the probability is 0, so the fallback below builds an
# all-False mask with one entry per batch element.
focus_present_mask = None
prob_focus_present = 0
batch = x.shape[0]
device = x.device
focus_present_mask = default(
    focus_present_mask,
    lambda: prob_mask_like((batch,), prob_focus_present, device = device),
)

In [23]:
class RelativePositionBias(nn.Module):
    """Learned per-head bias looked up from bucketed relative positions."""

    def __init__(
        self,
        heads = 8,
        num_buckets = 32,
        max_distance = 128
    ):
        super().__init__()
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        # One learned vector of `heads` values per bucket.
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position, num_buckets = 32, max_distance = 128):
        """Map a tensor of signed relative positions to integer bucket ids.

        Half of the buckets serve each sign of the offset. Within a half, small
        distances get one bucket each and larger ones share logarithmically spaced
        buckets, capped at the last bucket of that half.
        """
        dist = -relative_position

        half = num_buckets // 2
        # Offsets with negative `dist` are shifted into the upper half of the bucket range.
        bucket = (dist < 0).long() * half
        dist = torch.abs(dist)

        exact_limit = half // 2
        log_ratio = torch.log(dist.float() / exact_limit) / math.log(max_distance / exact_limit)
        coarse = exact_limit + (log_ratio * (half - exact_limit)).long()
        coarse = torch.min(coarse, torch.full_like(coarse, half - 1))

        # Distances below `exact_limit` are used directly as their bucket offset.
        return bucket + torch.where(dist < exact_limit, dist, coarse)

    def forward(self, n, device):
        """Return the bias for a length-`n` sequence, arranged as (heads, n, n)."""
        query_idx = torch.arange(n, dtype = torch.long, device = device)
        key_idx = torch.arange(n, dtype = torch.long, device = device)
        # Entry (i, j) holds key position j minus query position i.
        offsets = rearrange(key_idx, 'j -> 1 j') - rearrange(query_idx, 'i -> i 1')
        bucket_ids = self._relative_position_bucket(
            offsets,
            num_buckets = self.num_buckets,
            max_distance = self.max_distance
        )
        bias = self.relative_attention_bias(bucket_ids)
        return rearrange(bias, 'i j h -> h i j')


In [None]:
# Build the temporal relative-position bias and evaluate it straight away for the frame
# axis (dimension 2 of `x`); the name ends up bound to the resulting bias tensor.
time_rel_pos_bias = RelativePositionBias(heads = 8, max_distance = 32)(x.shape[2], device = x.device)
