In [1]:
from video_diffusion_pytorch import Unet3D, GaussianDiffusion, Trainer
import os
from PIL import Image, ImageSequence
import imageio
import numpy as np
import torch
import gc
from transformers import BertTokenizer, BertModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model = Unet3D(
    dim = 64,
    dim_mults=(1, 2, 4, 8),
)

In [3]:
diffusion = GaussianDiffusion(
    model,
    image_size=64,
    num_frames=10,
    timesteps = 1000,   # number of steps
    loss_type='l1',     # L1 or L2
).cuda()

In [4]:
# create gifs_64 folder if it doesn't exist
if not os.path.exists('gifs_64'):
    os.makedirs('gifs_64')

def resize(frames):
    for frame in frames:
        cpy = frame.copy()
        cpy = cpy.resize((64, 64))
        yield cpy

# create a new folder gifs_64 that contains 64x64 resized versions of the gifs in the folder gifs
for file in os.listdir('gifs'):
    if file.endswith('.gif'):
        # read each frame of the gif
        gif = Image.open('gifs/' + file)
        
        # get the frames
        frames = ImageSequence.Iterator(gif)
        
        # resize
        frames = resize(frames)
        
        # save the resized frames
        om = next(frames)
        om.info = gif.info
        om.save('gifs_64/' + file, save_all=True, append_images=list(frames), loop=0)

KeyboardInterrupt: 

In [4]:
gc.collect()
torch.cuda.empty_cache()

In [5]:
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
bert = BertModel.from_pretrained('bert-large-uncased', output_hidden_states=True)

Some weights of the model checkpoint at bert-large-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
BATCH = 1
optim = torch.optim.Adam(diffusion.parameters(), lr=1e-4)

# train the diffusion model
for i in range(10):
    gif_names = os.listdir('gifs_64')
    
    # loop through the gifs 32 at a time
    for j in range(0, len(gif_names), BATCH):
        # the conds are the names of the gifs, remove the .gif extension
        names = gif_names[j:j+BATCH]
        conds = [name.split('.')[0] for name in names]
        
        # read the gifs and convert to tensors
        gifs = []
        for name in names:
            gif = imageio.mimread('gifs_64/' + name)
            if len(gif) < 10:
                gif = np.concatenate([gif, np.repeat(gif[-1:], 10 - len(gif), axis=0)])
            gif = gif[:10]
            gif = torch.tensor(gif).cuda()
            gif = gif.permute(3, 0, 1, 2) # (channels, frames, height, width)
            gif = gif / 127.5 - 1
            gifs.append(gif)
        gifs = torch.stack(gifs)
        
        # zero the gradients
        optim.zero_grad()
        
        # get the loss
        loss = diffusion(gifs, cond=conds)
        loss.backward()
        
        # update the weights
        optim.step()
        
        if j % 10 == 0:
            print(f'epoch {i}, batch {j}, loss {loss.item()}')

  gif = torch.tensor(gif).cuda()
Using cache found in C:\Users\david/.cache\torch\hub\huggingface_pytorch-transformers_main
Using cache found in C:\Users\david/.cache\torch\hub\huggingface_pytorch-transformers_main
Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassifi

epoch 0, batch 0, loss 0.8698789477348328
epoch 0, batch 10, loss 0.7602686882019043
epoch 0, batch 20, loss 0.6494357585906982
epoch 0, batch 30, loss 0.6202204823493958
epoch 0, batch 40, loss 0.5340890288352966
epoch 0, batch 50, loss 0.383981317281723
epoch 0, batch 60, loss 0.4594486355781555
epoch 0, batch 70, loss 0.34664252400398254
epoch 0, batch 80, loss 0.5275373458862305
epoch 0, batch 90, loss 0.2943350076675415
epoch 0, batch 100, loss 0.39578238129615784
epoch 0, batch 110, loss 0.25666141510009766
epoch 0, batch 120, loss 0.2454884946346283
epoch 0, batch 130, loss 0.21062105894088745
epoch 0, batch 140, loss 0.19643940031528473
epoch 0, batch 150, loss 0.45845767855644226
epoch 0, batch 160, loss 0.19678707420825958
epoch 0, batch 170, loss 0.537451446056366
epoch 0, batch 180, loss 0.4582083821296692
epoch 0, batch 190, loss 0.634088933467865
epoch 0, batch 200, loss 0.2907816171646118
epoch 0, batch 210, loss 0.5245013236999512
epoch 0, batch 220, loss 0.181067973375

KeyboardInterrupt: 

In [7]:
txt = 'man jumps up and down'

output_gif = diffusion.sample(cond=[txt])

sampling loop time step: 100%|██████████| 1000/1000 [06:54<00:00,  2.41it/s]


In [9]:
# convert the tensor to a gif that we can see
out = output_gif[0].cpu().numpy()
out = np.transpose(out, (1, 2, 3, 0))
# the output is between -1 and 1, so we need to scale it to 0-255
out = ((out + 1) / 2 * 255).astype(np.uint8)
imageio.mimsave('out.gif', out, duration = 1000, loop = 0)