In [513]:
import os
from PIL import Image, ImageSequence
import imageio
import numpy as np
import torch
import gc
from transformers import BertTokenizer, BertModel
from einops import rearrange
from einops_exts import check_shape, rearrange_many
import math
from torch import nn, einsum
import torch.nn.functional as F
from rotary_embedding_torch import RotaryEmbedding
from functools import partial

In [514]:
BERT_MODEL_DIM = 768
TOKENIZER = torch.hub.load('huggingface/pytorch-transformers', 'tokenizer', 'bert-base-cased')
MODEL = torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-base-cased')

Using cache found in C:\Users\david/.cache\torch\hub\huggingface_pytorch-transformers_main
Using cache found in C:\Users\david/.cache\torch\hub\huggingface_pytorch-transformers_main
Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequen

In [515]:
if torch.cuda.is_available():
    MODEL = MODEL.cuda()

In [516]:
def tokenize(texts, add_special_tokens = True):
    if not isinstance(texts, (list, tuple)):
        texts = [texts]

    encoding = TOKENIZER.batch_encode_plus(
        texts,
        add_special_tokens = add_special_tokens,
        padding = True,
        return_tensors = 'pt'
    )

    token_ids = encoding.input_ids
    return token_ids

@torch.no_grad()
def bert_embed(
    token_ids,
    return_cls_repr = False,
    eps = 1e-8,
    pad_id = 0.
):
    mask = token_ids != pad_id

    if torch.cuda.is_available():
        token_ids = token_ids.cuda()
        mask = mask.cuda()

    outputs = MODEL(input_ids = token_ids, attention_mask = mask, output_hidden_states = True)
    hidden_state = outputs.hidden_states[-1]
    cls_repr = outputs[0][:, 0, :] / (outputs[0][:, 0, :].norm(dim = -1, keepdim = True) + eps)

    if return_cls_repr:
        return hidden_state[:, 0]
    
    if mask is None:
        return hidden_state.mean(dim = 1)
    
    mask = rearrange(mask[:, 1:], 'b n -> b n 1')

    denom = mask.sum(dim = 1)
    return (hidden_state[:, 1:] * mask).sum(dim = 1) / (denom + eps)

In [517]:
def default(val, d):
    if val is not None:
        return val
    return d() if callable(d) else d

def Upsample(dim):
    return nn.ConvTranspose3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1))

def Downsample(dim):
    return nn.Conv3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1))

def prob_mask_like(shape, prob, device):
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)
    else:
        return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
    
def normalize_img(t):
    return t * 2 - 1

def unnormalize_img(t):
    return (t + 1) * 0.5

def extract(a, t, shape):
    b, *_ = t.shape
    gathered = a.gather(-1, t)
    return gathered.reshape(b, *((1,) * (len(shape) - 1)))

In [518]:
class RelativePositionBias(torch.nn.Module):
    def __init__(self, heads = 8, buckets = 32, max_distance = 128
    ):
        super().__init__()
        self.buckets = buckets
        self.max_distance = max_distance
        self.relative_attention_bias = torch.nn.Embedding(buckets, heads)

    def forward(self, n, device):

        q_pos = torch.arange(n, dtype = torch.long, device = device)
        k_pos = torch.arange(n, dtype = torch.long, device = device)
        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')

        rp_bucket = 0
        buckets = self.buckets // 2
        rp_bucket += (n < 0).long() * buckets
        n = torch.abs(-rel_pos)

        max_exact = buckets // 2
        is_small = n < max_exact

        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(self.max_distance / max_exact) * (buckets - max_exact)
        ).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, buckets - 1))

        rp_bucket += torch.where(is_small, n, val_if_large)
        return rearrange(self.relative_attention_bias(rp_bucket), 'i j h -> h i j')

In [519]:
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x

In [520]:
class SinusoidalPosEmb(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        dev = x.device
        dim = self.dim // 2
        emb = torch.exp(torch.arange(dim, device=dev) * -(math.log(10000) / (dim - 1)))
        emb = x[:, None] * emb[None, :]
        return torch.cat((emb.sin(), emb.cos()), dim=-1)

In [521]:
class LayerNorm(torch.nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = torch.nn.Parameter(torch.ones(1, dim, 1, 1, 1))

    def forward(self, x):
        variance = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        return (x - torch.mean(x, dim = 1, keepdim = True)) / (variance + self.eps).sqrt() * self.gamma

In [522]:
class PreNorm(torch.nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.layer_norm = LayerNorm(dim)

    def forward(self, x, **kwargs):
        return self.fn(self.layer_norm(x), **kwargs)

In [523]:
class Block(torch.nn.Module):
    def __init__(self, dim_in, dim_out, groups = 8):
        super().__init__()
        self.proj = torch.nn.Conv3d(dim_in, dim_out, (1, 3, 3), padding = (0, 1, 1))
        self.norm = torch.nn.GroupNorm(groups, dim_out)
        self.silu = torch.nn.SiLU()

    def forward(self, x, scale_shift = None):
        x = self.proj(self.norm(x))

        if scale_shift is not None:
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        return self.silu(x)

In [524]:
class ResnetBlock(torch.nn.Module):
    def __init__(self, dim_in, dim_out, *, time = None, groups = 8):
        super().__init__()
        self.seq = torch.nn.Sequential(
            torch.nn.SiLU(),
            torch.nn.Linear(time, dim_out * 2)
        ) if (time is not None) else None

        self.block1 = Block(dim_in, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)
        self.conv = torch.nn.Conv3d(dim_in, dim_out, 1) if dim_in != dim_out else torch.nn.Identity()

    def forward(self, x, time = None):

        scale_shift = None
        if time is not None:
            assert time is not None, 'time embedding must be supplied if using time embedding mlp'
            time = self.seq(time)
            time = rearrange(time, 'b c -> b c 1 1 1')
            scale_shift = time.chunk(2, dim = 1)

        h = self.block1(x, scale_shift = scale_shift)
        return self.block2(h) + self.conv(x)

In [525]:
class SpatialLinearAttention(torch.nn.Module):
    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.conv1 = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
        self.conv2 = torch.nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, h, w = x.shape
        x = rearrange(x, 'b c f h w -> (b f) c h w')

        qkv = self.conv1(x).chunk(3, dim = 1)
        q, k, v = rearrange_many(qkv, 'b (h c) x y -> b h c (x y)', h = self.heads)

        q = q.softmax(dim = -2)
        k = k.softmax(dim = -1)

        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)

        y = torch.einsum('b h d e, b h d n -> b h e n', context, q * self.scale)
        y = self.conv2(rearrange(y, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w))
        return rearrange(y, '(b f) c h w -> b c f h w', b = b)

In [526]:
class EinopsToAndFrom(nn.Module):
    def __init__(self, from_einops, to_einops, fn):
        super().__init__()
        self.from_einops = from_einops
        self.to_einops = to_einops
        self.fn = fn

    def forward(self, x, **kwargs):
        shape = x.shape
        reconstitute_kwargs = dict(tuple(zip(self.from_einops.split(' '), shape)))
        x = self.fn(rearrange(x, f'{self.from_einops} -> {self.to_einops}'), **kwargs)
        return rearrange(x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs)

In [527]:
class Attention(torch.nn.Module):
    def __init__(self, dim, heads = 4, dim_head = 32, rotary_emb = None):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        self.rotary_emb = rotary_emb
        self.fc1 = torch.nn.Linear(dim, hidden_dim * 3, bias = False)
        self.fc2 = torch.nn.Linear(hidden_dim, dim, bias = False)

    def forward(self, x, pos_bias = None, focus_present_mask = None):
        n, device = x.shape[-2], x.device

        qkv = self.fc1(x).chunk(3, dim = -1)

        if (focus_present_mask is not None) and focus_present_mask.all():
            values = qkv[-1]
            return self.fc2(values)

        q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h = self.heads)

        if self.rotary_emb is not None:
            q = self.rotary_emb.rotate_queries_or_keys(q * self.scale)
            k = self.rotary_emb.rotate_queries_or_keys(k)

        sim = einsum('... h i d, ... h j d -> ... h i j', q, k)

        if pos_bias is not None:
            sim = sim + pos_bias

        if (focus_present_mask is not None) and not (~focus_present_mask).all():
            attend_all_mask = torch.ones((n, n), device = device, dtype = torch.bool)
            attend_self_mask = torch.eye(n, device = device, dtype = torch.bool)

            mask = torch.where(
                rearrange(focus_present_mask, 'b -> b 1 1 1 1'),
                rearrange(attend_self_mask, 'i j -> 1 1 1 i j'),
                rearrange(attend_all_mask, 'i j -> 1 1 1 i j'),
            )

            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
        attn = sim.softmax(dim = -1)
        y = einsum('... h i j, ... h j d -> ... h i d', attn, v)
        return self.fc2(rearrange(y, '... h n d -> ... n (h d)'))

In [528]:
class Unet3D(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.has_cond = True
        
        dim = 64 # the width and height of the feature map will be 64
        attn_heads = 8 # the number of heads for the spatial attention
        attn_dim_head = 32 # the dimension of each head for the spatial attention

        # temporal attention and its relative positional encoding
        rotary_emb = RotaryEmbedding(attn_dim_head) # positional encoding for the temporal attention
        temporal_attn = lambda dim: EinopsToAndFrom('b c f h w', 'b (h w) f c', Attention(dim, heads = attn_heads, dim_head = attn_dim_head, rotary_emb = rotary_emb))

        self.time_rel_pos_bias = RelativePositionBias(heads = attn_heads, max_distance = 32)

        # initial conv
        init_kernel_size = 7
        init_padding = init_kernel_size // 2
        self.init_conv = nn.Conv3d(3, dim, (1, init_kernel_size, init_kernel_size), padding = (0, init_padding, init_padding))

        self.init_temporal_attn = Residual(PreNorm(dim, temporal_attn(dim)))

        # dimensions
        dims = [dim, *map(lambda m: dim * m, [1, 2, 4, 8])]
        in_out = list(zip(dims[:-1], dims[1:]))

        # time conditioning
        time_dim = dim * 4
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )

        # text conditioning
        self.null_cond_emb = nn.Parameter(torch.randn(1, BERT_MODEL_DIM))
        cond_dim = time_dim + BERT_MODEL_DIM

        # layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])

        num_resolutions = len(in_out)

        # block type
        block_klass = partial(ResnetBlock, groups=8)
        block_klass_cond = partial(block_klass, time=cond_dim)

        # modules for all layers
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(nn.ModuleList([
                block_klass_cond(dim_in, dim_out),
                block_klass_cond(dim_out, dim_out),
                Residual(PreNorm(dim_out, SpatialLinearAttention(dim_out, heads = attn_heads))),
                Residual(PreNorm(dim_out, temporal_attn(dim_out))),
                Downsample(dim_out) if not is_last else nn.Identity()
            ]))

        mid_dim = dims[-1]
        self.mid_block1 = block_klass_cond(mid_dim, mid_dim)

        spatial_attn = EinopsToAndFrom('b c f h w', 'b f (h w) c', Attention(mid_dim, heads = attn_heads))

        self.mid_spatial_attn = Residual(PreNorm(mid_dim, spatial_attn))
        self.mid_temporal_attn = Residual(PreNorm(mid_dim, temporal_attn(mid_dim)))

        self.mid_block2 = block_klass_cond(mid_dim, mid_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(nn.ModuleList([
                block_klass_cond(dim_out * 2, dim_in),
                block_klass_cond(dim_in, dim_in),
                Residual(PreNorm(dim_in, SpatialLinearAttention(dim_in, heads = attn_heads))),
                Residual(PreNorm(dim_in, temporal_attn(dim_in))),
                Upsample(dim_in) if not is_last else nn.Identity()
            ]))

        out_dim = 3
        self.final_conv = nn.Sequential(
            block_klass(dim * 2, dim),
            nn.Conv3d(dim, out_dim, 1)
        )

    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 2.,
        **kwargs
    ):
        logits = self.forward(*args, null_cond_prob = 0., **kwargs)
        if cond_scale == 1:
            return logits

        null_logits = self.forward(*args, null_cond_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale
    
    def forward(
        self,
        x,
        time,
        cond,
        null_cond_prob = 0.,
    ):
        batch, device = x.shape[0], x.device

        focus_present_mask = (lambda: prob_mask_like((batch,), 0, device = device))()

        time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device = x.device)

        x = self.init_conv(x)

        x = self.init_temporal_attn(x, pos_bias = time_rel_pos_bias)

        r = x.clone()

        t = self.time_mlp(time) if self.time_mlp is not None else None

        # classifier free guidance
        if self.has_cond:
            batch, device = x.shape[0], x.device
            mask = prob_mask_like((batch,), null_cond_prob, device = device)
            cond = torch.where(rearrange(mask, 'b -> b 1'), self.null_cond_emb, cond)
            t = torch.cat((t, cond), dim = -1)

        h = []

        for block1, block2, spatial_attn, temporal_attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = spatial_attn(x)
            x = temporal_attn(x, pos_bias = time_rel_pos_bias, focus_present_mask = focus_present_mask)
            h.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_spatial_attn(x)
        x = self.mid_temporal_attn(x, pos_bias = time_rel_pos_bias, focus_present_mask = focus_present_mask)
        x = self.mid_block2(x, t)

        for block1, block2, spatial_attn, temporal_attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim = 1)
            x = block1(x, t)
            x = block2(x, t)
            x = spatial_attn(x)
            x = temporal_attn(x, pos_bias = time_rel_pos_bias, focus_present_mask = focus_present_mask)
            x = upsample(x)

        x = torch.cat((x, r), dim = 1)
        return self.final_conv(x)

In [529]:
class GaussianDiffusion(nn.Module):
    def __init__(self, denoise, *, text_user_bert_cls = False, channels = 3, num_frames, image_size, timesteps = 1000, loss_type = 'l1', use_dynamic_thres = False, dynamic_thres_percentile = 0.9):
        super().__init__()
        self.denoise = denoise
        self.channels = channels
        self.num_frames = num_frames
        self.image_size = image_size
        self.timesteps = timesteps
        self.loss_type = loss_type
        self.use_dynamic_thres = use_dynamic_thres
        self.dynamic_thres_percentile = dynamic_thres_percentile
        self.text_user_bert_cls = text_user_bert_cls
        self.loss_type = loss_type

        steps = timesteps + 1
        x = torch.linspace(0, timesteps, steps, dtype = torch.float64)
        alphaprod_cum = torch.cos(((x - timesteps) + 0.008) / (1 + 0.008) * torch.pi / 2) ** 2
        alphaprod_cum = alphaprod_cum / alphaprod_cum[0]
        # Calculate betas
        betas = torch.clip(1 - (alphaprod_cum[1:] / alphaprod_cum[:-1]), 0, 0.9999)
        alphas = 1 - betas
        alphaprod_cum = torch.cumprod(alphas, axis=0)
        
        # Calculate previous alphas products
        prev_alphaprod_cum = F.pad(alphaprod_cum[:-1], (1, 0), value = 1)

        timesteps = betas.shape[0]

        self.num_steps = int(timesteps)

        # Write "register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))" in a different way
        def register_buffer(name, val):
            self.register_buffer(name, val.to(torch.float32))

        register_buffer('betas', betas)
        register_buffer('alphas_cumprod', torch.cumprod(alphas, dim = 0))
        register_buffer('alphas_cumprod_prev', prev_alphaprod_cum)

        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphaprod_cum))
        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1 / alphaprod_cum))
        register_buffer('sqrt_alphas_cumprod_minus_one', torch.sqrt(alphaprod_cum - 1))
        # register buffer for 1 - alpha_cumprod and log of it
        register_buffer('one_minus_alphas_cumprod', 1 - alphaprod_cum)
        register_buffer('log_one_minus_alphas_cumprod', torch.log(1 - alphaprod_cum))
        # register buffer for recip of 1 - alpha_cumprod
        register_buffer('recip_one_minus_alphas_cumprod', 1 / (1 - alphaprod_cum))

        posterior_var = betas * (1. - prev_alphaprod_cum) / (1. - alphaprod_cum)

        register_buffer('posterior_variance', posterior_var)

        # Since posterior variance is 0 for the first time step, we need to clamp it for numerical stability
        register_buffer('posterior_variance_clipped', torch.clamp(posterior_var, min = 1e-20))
        register_buffer('log_posterior_variance_clipped', torch.log(torch.clamp(posterior_var, min = 1e-20)))

        # calculate posterior mean coefficients
        mean_coef1 = betas * torch.sqrt(prev_alphaprod_cum) / (1 - alphaprod_cum)
        mean_coef2 = (1 - prev_alphaprod_cum) * torch.sqrt(alphaprod_cum) / (1 - alphaprod_cum)

        register_buffer('posterior_mean_coef1', mean_coef1)
        register_buffer('posterior_mean_coef2', mean_coef2)

        self.text_user_bert_cls = text_user_bert_cls
        self.dynamic_thres_percentile = dynamic_thres_percentile
        self.use_dynamic_thres = use_dynamic_thres
        
    def sample_q(self, x, t, noise=None, *args, **kwargs):
        if noise is not None:
            noise = noise
        else:
            noise = torch.randn_like(x)

        return (
            extract(self.sqrt_alphas_cumprod, t, x.shape) * x +
            extract(self.sqrt_alphas_cumprod_minus_one, t, x.shape) * noise
        )
        
    def p_losses(self, x, t, cond=None, noise=None, *args, **kwargs):
        if noise is not None:
            noise = noise
        else:
            noise = torch.randn_like(x)

        b, c, f, h, w, device = *x.shape, x.device

        noisy_x = self.sample_q(x=x, t=t, noise=noise)

        if isinstance(cond, (list, tuple)) and all(isinstance(s, str) for s in cond):
            cond = bert_embed(tokenize(cond), return_cls_repr = self.text_user_bert_cls)
            cond = cond.to(device)

        output_renoise = self.denoise(noisy_x, t, cond=cond, *args, **kwargs)

        if self.loss_type == 'l1':
            loss = F.l1_loss(noise, output_renoise)
        elif self.loss_type == 'l2':
            loss = F.mse_loss(noise, output_renoise)
        else:
            raise NotImplementedError()
        
        return loss
        
    
    def forward(self, x, *args, **kwargs):
        b, image_size, device = x.shape[0], self.image_size, x.device
        check_shape(x, 'b c f h w', c = self.channels, f = self.num_frames, h = image_size, w = image_size)
        return self.p_losses(x * 2 - 1, torch.randint(0, self.num_steps, (b,), device=device).long(), *args, **kwargs)

In [530]:
model = Unet3D()

In [531]:
diffusion = GaussianDiffusion(
    model,
    image_size=64,
    num_frames=10,
    timesteps = 1000,   # number of steps
    loss_type='l1',     # L1 or L2
).cuda()

In [532]:
gc.collect()
torch.cuda.empty_cache()

In [533]:
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
bert = BertModel.from_pretrained('bert-large-uncased', output_hidden_states=True)

Some weights of the model checkpoint at bert-large-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [534]:
# load in the model
if os.path.exists('diffusion.pt'):
    diffusion.load_state_dict(torch.load('diffusion.pt'))

In [535]:
BATCH = 1
optim = torch.optim.Adam(diffusion.parameters(), lr=1e-4)

# train the diffusion model
for i in range(20):
    gif_names = os.listdir('gifs_64')
    
    # loop through the gifs 32 at a time
    for j in range(0, len(gif_names), BATCH):
        # the conds are the names of the gifs, remove the .gif extension
        names = gif_names[j:j+BATCH]
        conds = [name.split('.')[0] for name in names]
        
        # read the gifs and convert to tensors
        gifs = []
        for name in names:
            gif = imageio.mimread('gifs_64/' + name)
            if len(gif) < 10:
                gif = np.concatenate([gif, np.repeat(gif[-1:], 10 - len(gif), axis=0)])
            gif = gif[:10]
            gif = torch.tensor(gif).cuda()
            gif = gif.permute(3, 0, 1, 2) # (channels, frames, height, width)
            gif = gif / 127.5 - 1
            gifs.append(gif)
        gifs = torch.stack(gifs)
        
        # zero the gradients
        optim.zero_grad()
        
        # get the loss
        loss = diffusion(gifs, cond=conds)
        loss.backward()
        
        # update the weights
        optim.step()
        
        if j % 10 == 0:
            print(f'epoch {i}, batch {j}, loss {loss.item()}')

AttributeError: 'bool' object has no attribute 'long'

In [None]:
# save the model
torch.save(diffusion.state_dict(), 'diffusion.pt')

In [None]:
# get the 230'th name from gifs_64
best = os.listdir('gifs_64')[230].split('.')[0]
print(best)

a baby is playing with an orange toy


In [None]:
txt = 'green'

output_gif = diffusion.sample(cond=[txt])

sampling loop time step: 100%|██████████| 1000/1000 [1:01:20<00:00,  3.68s/it]   


In [None]:
# convert the tensor to a gif that we can see
out = output_gif[0].cpu().numpy()
out = np.transpose(out, (1, 2, 3, 0))
# the output is between -1 and 1, so we need to scale it to 0-255
out = ((out + 1) / 2 * 255).astype(np.uint8)
imageio.mimsave('out.gif', out, duration = 1000, loop = 0)