In [3]:
from video_diffusion_pytorch import Unet3D, GaussianDiffusion, Trainer
import os
from PIL import Image, ImageSequence
import imageio
import numpy as np
import torch
import gc
from transformers import BertTokenizer, BertModel

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
model = Unet3D(
    dim = 64,
    dim_mults=(1, 2, 4, 8),
)

In [5]:
diffusion = GaussianDiffusion(
    model,
    image_size=64,
    num_frames=10,
    timesteps = 1000,   # number of steps
    loss_type='l1',     # L1 or L2
).cuda()

In [6]:
gc.collect()
torch.cuda.empty_cache()

In [7]:
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
bert = BertModel.from_pretrained('bert-large-uncased', output_hidden_states=True)

Some weights of the model checkpoint at bert-large-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [15]:
# load in the model
diffusion.load_state_dict(torch.load('diffusion.pt'))

<All keys matched successfully>

In [16]:
BATCH = 1
optim = torch.optim.Adam(diffusion.parameters(), lr=1e-4)

# train the diffusion model
for i in range(20):
    gif_names = os.listdir('gifs_64')
    
    # loop through the gifs 32 at a time
    for j in range(0, len(gif_names), BATCH):
        # the conds are the names of the gifs, remove the .gif extension
        names = gif_names[j:j+BATCH]
        conds = [name.split('.')[0] for name in names]
        
        # read the gifs and convert to tensors
        gifs = []
        for name in names:
            gif = imageio.mimread('gifs_64/' + name)
            if len(gif) < 10:
                gif = np.concatenate([gif, np.repeat(gif[-1:], 10 - len(gif), axis=0)])
            gif = gif[:10]
            gif = torch.tensor(gif).cuda()
            gif = gif.permute(3, 0, 1, 2) # (channels, frames, height, width)
            gif = gif / 127.5 - 1
            gifs.append(gif)
        gifs = torch.stack(gifs)
        
        # zero the gradients
        optim.zero_grad()
        
        # get the loss
        loss = diffusion(gifs, cond=conds)
        loss.backward()
        
        # update the weights
        optim.step()
        
        if j % 10 == 0:
            print(f'epoch {i}, batch {j}, loss {loss.item()}')

epoch 0, batch 0, loss 0.07453851401805878
epoch 0, batch 10, loss 0.4417616128921509
epoch 0, batch 20, loss 0.04349784925580025
epoch 0, batch 30, loss 0.09936734288930893
epoch 0, batch 40, loss 0.09539846330881119
epoch 0, batch 50, loss 0.1495378315448761
epoch 0, batch 60, loss 0.1764092594385147
epoch 0, batch 70, loss 0.27043285965919495
epoch 0, batch 80, loss 0.07030817121267319
epoch 0, batch 90, loss 0.08462493121623993
epoch 0, batch 100, loss 0.22654083371162415
epoch 0, batch 110, loss 0.20555029809474945
epoch 0, batch 120, loss 0.051872946321964264
epoch 0, batch 130, loss 0.24013517796993256
epoch 0, batch 140, loss 0.5099602937698364
epoch 0, batch 150, loss 0.0448327362537384
epoch 0, batch 160, loss 0.04682619497179985
epoch 0, batch 170, loss 0.08727879822254181
epoch 0, batch 180, loss 0.2072146236896515
epoch 0, batch 190, loss 0.11641300469636917
epoch 0, batch 200, loss 0.15204967558383942
epoch 0, batch 210, loss 0.40269312262535095
epoch 0, batch 220, loss 0

In [17]:
# save the model
torch.save(diffusion.state_dict(), 'diffusion.pt')

In [18]:
# get the 230'th name from gifs_64
best = os.listdir('gifs_64')[230].split('.')[0]
print(best)

a baby is playing with an orange toy


In [23]:
txt = 'green'

output_gif = diffusion.sample(cond=[txt])

sampling loop time step: 100%|██████████| 1000/1000 [1:01:20<00:00,  3.68s/it]   


In [24]:
# convert the tensor to a gif that we can see
out = output_gif[0].cpu().numpy()
out = np.transpose(out, (1, 2, 3, 0))
# the output is between -1 and 1, so we need to scale it to 0-255
out = ((out + 1) / 2 * 255).astype(np.uint8)
imageio.mimsave('out.gif', out, duration = 1000, loop = 0)