In [3]:
import math
import copy
import torch
from torch import nn, einsum
import torch.nn.functional as F
from functools import partial

from torch.utils import data
from pathlib import Path
from torch.optim import Adam
from torchvision import transforms as T, utils
from torch.cuda.amp import autocast, GradScaler
from PIL import Image

from tqdm import tqdm
from einops import rearrange
from einops_exts import check_shape, rearrange_many

from rotary_embedding_torch import RotaryEmbedding

from video_diffusion_pytorch.text import tokenize, bert_embed, BERT_MODEL_DIM

  from .autonotebook import tqdm as notebook_tqdm


In [9]:
class Block(nn.Module):
    def __init__(self, dim, dim_out, groups = 8):
        super().__init__()
        self.proj = nn.Conv3d(dim, dim_out, (1, 3, 3), padding = (0, 1, 1))
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift = None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        return self.act(x)

class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, dim_out * 2)
        ) if exists(time_emb_dim) else None

        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)
        self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb = None):

        scale_shift = None
        if exists(self.mlp):
            assert exists(time_emb), 'time emb must be passed in'
            time_emb = self.mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b c 1 1 1')
            scale_shift = time_emb.chunk(2, dim = 1)

        h = self.block1(x, scale_shift = scale_shift)

        h = self.block2(h)
        return h + self.res_conv(x)

In [10]:
A = torch.randn(3, 4)
B = torch.randn(4, 5)
A

tensor([[ 0.8416, -1.2116, -0.8289, -0.7209],
        [-0.1592, -1.2871, -1.9590, -1.4572],
        [ 1.1402,  0.3616,  0.3398,  1.2067]])

In [11]:
B

tensor([[-2.0876,  0.3313, -0.5288,  0.6752,  0.4685],
        [ 0.3515,  1.9028,  0.4115,  0.8687, -0.7149],
        [-0.3824, -0.5538,  0.3706, -0.0309,  1.0972],
        [ 0.2812,  1.0520, -0.7916,  0.3152,  0.8573]])

In [12]:
torch.einsum('ik,kj->ij', A, B)

tensor([[-2.0685, -2.3259, -0.6802, -0.6859, -0.2670],
        [ 0.2193, -2.9499, -0.0179, -1.6243, -2.5531],
        [-2.0438,  2.1471, -1.2834,  1.4538,  1.6830]])

In [14]:
import torch
from video_diffusion_pytorch import Unet3D, GaussianDiffusion

model = Unet3D(
    dim = 64,
    use_bert_text_cond = True,  # this must be set to True to auto-use the bert model dimensions
    dim_mults = (1, 2, 4, 8),
)

diffusion = GaussianDiffusion(
    model,
    image_size = 32,    # height and width of frames
    num_frames = 5,     # number of video frames
    timesteps = 1000,   # number of steps
    loss_type = 'l1'    # L1 or L2
)

videos = torch.randn(3, 3, 5, 32, 32) # video (batch, channels, frames, height, width)

text = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles'
]

In [16]:
loss = diffusion(videos, cond = text)

HTTPError: HTTP Error 403: rate limit exceeded

In [1]:
def check_shape(tensor, pattern, **kwargs):
    return rearrange(tensor, f"{pattern} -> {pattern}", **kwargs)

In [7]:
x= torch.randn(2, 3, 5, 32, 32)
x

tensor([[[[[-8.7177e-01,  1.4804e+00,  4.6724e-01,  ..., -3.6518e-01,
             8.6641e-01,  6.4620e-01],
           [ 4.7141e-01,  1.2190e+00, -1.4006e+00,  ..., -8.5927e-01,
             1.1078e+00,  1.0962e+00],
           [-3.9610e-01,  2.9317e-01, -1.6449e-01,  ..., -1.0622e+00,
             3.9326e-01, -1.7272e+00],
           ...,
           [ 1.2459e+00,  2.5615e-03, -1.3402e+00,  ...,  1.0643e+00,
             1.0806e+00,  5.2228e-01],
           [ 2.2934e+00, -5.0693e-01, -1.5094e-01,  ..., -1.7517e-01,
             9.5116e-01,  4.2418e-01],
           [-6.2055e-01,  6.4942e-01, -1.7191e+00,  ..., -9.8483e-03,
             1.1283e+00, -1.2933e+00]],

          [[ 5.9575e-01, -1.4136e+00,  2.3844e-01,  ..., -1.7149e+00,
             1.4036e+00,  1.1823e+00],
           [-1.1108e+00,  3.0106e-01,  5.5318e-01,  ...,  2.8564e-01,
             5.1822e-01, -1.0516e+00],
           [-1.3084e+00,  6.4336e-01,  1.1504e+00,  ..., -1.2653e+00,
             1.1751e+00, -7.8095e-01],
 

In [8]:
b, device, img_size, = x.shape[0], x.device, 32 # b=2, device=cuda:0, img_size=32 
check_shape(x, 'b c f h w', c = 3, f = 5, h = img_size, w = img_size)

tensor([[[[[-8.7177e-01,  1.4804e+00,  4.6724e-01,  ..., -3.6518e-01,
             8.6641e-01,  6.4620e-01],
           [ 4.7141e-01,  1.2190e+00, -1.4006e+00,  ..., -8.5927e-01,
             1.1078e+00,  1.0962e+00],
           [-3.9610e-01,  2.9317e-01, -1.6449e-01,  ..., -1.0622e+00,
             3.9326e-01, -1.7272e+00],
           ...,
           [ 1.2459e+00,  2.5615e-03, -1.3402e+00,  ...,  1.0643e+00,
             1.0806e+00,  5.2228e-01],
           [ 2.2934e+00, -5.0693e-01, -1.5094e-01,  ..., -1.7517e-01,
             9.5116e-01,  4.2418e-01],
           [-6.2055e-01,  6.4942e-01, -1.7191e+00,  ..., -9.8483e-03,
             1.1283e+00, -1.2933e+00]],

          [[ 5.9575e-01, -1.4136e+00,  2.3844e-01,  ..., -1.7149e+00,
             1.4036e+00,  1.1823e+00],
           [-1.1108e+00,  3.0106e-01,  5.5318e-01,  ...,  2.8564e-01,
             5.1822e-01, -1.0516e+00],
           [-1.3084e+00,  6.4336e-01,  1.1504e+00,  ..., -1.2653e+00,
             1.1751e+00, -7.8095e-01],
 

In [6]:
f"{'b c f h w'} -> {'b c f h w'}"

'b c f h w -> b c f h w'

In [9]:
torch.randint(0, 1000, (b,), device=device).long()

tensor([389, 494])

In [10]:
x*2-1

tensor([[[[[-2.7435,  1.9609, -0.0655,  ..., -1.7304,  0.7328,  0.2924],
           [-0.0572,  1.4380, -3.8012,  ..., -2.7185,  1.2157,  1.1923],
           [-1.7922, -0.4137, -1.3290,  ..., -3.1245, -0.2135, -4.4543],
           ...,
           [ 1.4918, -0.9949, -3.6804,  ...,  1.1285,  1.1611,  0.0446],
           [ 3.5869, -2.0139, -1.3019,  ..., -1.3503,  0.9023, -0.1516],
           [-2.2411,  0.2988, -4.4382,  ..., -1.0197,  1.2566, -3.5866]],

          [[ 0.1915, -3.8271, -0.5231,  ..., -4.4299,  1.8073,  1.3646],
           [-3.2216, -0.3979,  0.1064,  ..., -0.4287,  0.0364, -3.1033],
           [-3.6169,  0.2867,  1.3009,  ..., -3.5306,  1.3502, -2.5619],
           ...,
           [-1.4969,  1.7879,  0.6920,  ..., -0.0137, -2.9727,  1.1013],
           [-2.1190,  1.6372,  0.7390,  ...,  0.6157, -2.1099, -1.6527],
           [ 1.0462, -2.7141, -1.1130,  ..., -0.3001,  0.0547, -2.3800]],

          [[ 1.6507, -2.0926, -0.1980,  ..., -1.3553,  2.1718, -1.7503],
           [-4.

In [3]:
import torch
torch.cuda.is_available()

AttributeError: module 'torch' has no attribute 'cuda'