In [1]:
import torch
from imagen_pytorch import Unet3D, ElucidatedImagen, ImagenTrainer
from utils import gif75speaker
import numpy as np
from torchvision import transforms
import argparse



## Create generated dataset to compare

In [2]:
dataset_75speaker = gif75speaker(image_path = './datasets/preprocessed_dataset/test', 
                                img_per_gif=10, 
                                audio_path = './datasets/preprocessed_dataset/wav2vec2-l60', 
                                audio_pooling = None)

In [3]:
unet1 = Unet3D(dim = 64, dim_mults = (1, 2, 4, 8)).cuda()
unet2 = Unet3D(dim = 64, dim_mults = (1, 2, 4, 8)).cuda()

imagen = ElucidatedImagen(
    text_embed_dim=1024,
    unets = (unet1, unet2),
    image_sizes = (64, 64),
    random_crop_sizes = (None, 16),
    temporal_downsample_factor = (1, 1),        # in this example, the first unet would receive the video temporally downsampled by 2x
    num_sample_steps = 10,
    cond_drop_prob = 0.1,
    sigma_min = 0.002,                          # min noise level
    sigma_max = (80, 160),                      # max noise level, double the max noise level for upsampler
    sigma_data = 0.5,                           # standard deviation of data distribution
    rho = 7,                                    # controls the sampling schedule
    P_mean = -1.2,                              # mean of log-normal distribution from which noise is drawn for training
    P_std = 1.2,                                # standard deviation of log-normal distribution from which noise is drawn for training
    S_churn = 80,                               # parameters for stochastic sampling - depends on dataset, Table 5 in apper
    S_tmin = 0.05,
    S_tmax = 50,
    S_noise = 1.003,
).cuda()

imagen.load_state_dict(torch.load('./checkpoints/ImagenVideo-Modelwav2vec2-l60-pho-PoolingFalse-IgnoreTimeFalse-TwoStepFalse-100'))
trainer = ImagenTrainer(imagen,
    split_valid_from_train = True, # whether to split the validation dataset from the training
    dl_tuple_output_keywords_names = ('images', 'text_embeds', 'cond_video_frames')
).cuda()

The base dimension of your u-net should ideally be no smaller than 128, as recommended by a professional DDPM trainer https://nonint.com/2022/05/04/friends-dont-let-friends-train-small-diffusion-models/


dataloader_config = DataLoaderConfiguration(split_batches=True)


In [4]:
!rm -rf ./generated_images
!mkdir ./generated_images

In [6]:
for i in range(10):
    (_, aud_emb, cond_video_frames) = dataset_75speaker[i]
    print(f'{dataset_75speaker.get_names(i)}')
    videos = trainer.sample(text_embeds = aud_emb.unsqueeze(0), video_frames = 10, stop_at_unet_number  = 1, batch_size = 1, cond_video_frames=cond_video_frames.unsqueeze(0))
    imgs = torch.transpose(videos[0], 0, 1)
    imgs = [transforms.ToPILImage()(img) for img in imgs]
    # duration is the number of milliseconds between frames; this is 40 frames per second
    # model_name = opt.audio_path.split('/')[-1]
    imgs[0].save(f'./generated_images/{dataset_75speaker.get_names(i)}.gif', save_all=True, append_images=imgs[1:], duration=10, loop=0)

sub001_2drt_18_topic2_video-0
unet 1 has not been trained
unet 2 has not been trained
when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets


0it [00:00, ?it/s]

sampling time step:   0%|          | 0/10 [00:00<?, ?it/s]

sub001_2drt_18_topic2_video-10
unet 1 has not been trained
unet 2 has not been trained
when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets


0it [00:00, ?it/s]

sampling time step:   0%|          | 0/10 [00:00<?, ?it/s]

sub001_2drt_18_topic2_video-100
unet 1 has not been trained
unet 2 has not been trained
when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets


0it [00:00, ?it/s]

sampling time step:   0%|          | 0/10 [00:00<?, ?it/s]

sub001_2drt_18_topic2_video-1000
unet 1 has not been trained
unet 2 has not been trained
when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets


0it [00:00, ?it/s]

sampling time step:   0%|          | 0/10 [00:00<?, ?it/s]

sub001_2drt_18_topic2_video-1010
unet 1 has not been trained
unet 2 has not been trained
when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets


0it [00:00, ?it/s]

sampling time step:   0%|          | 0/10 [00:00<?, ?it/s]

sub001_2drt_18_topic2_video-1020
unet 1 has not been trained
unet 2 has not been trained
when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets


0it [00:00, ?it/s]

sampling time step:   0%|          | 0/10 [00:00<?, ?it/s]

sub001_2drt_18_topic2_video-1030
unet 1 has not been trained
unet 2 has not been trained
when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets


0it [00:00, ?it/s]

sampling time step:   0%|          | 0/10 [00:00<?, ?it/s]

sub001_2drt_18_topic2_video-1040
unet 1 has not been trained
unet 2 has not been trained
when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets


0it [00:00, ?it/s]

sampling time step:   0%|          | 0/10 [00:00<?, ?it/s]

sub001_2drt_18_topic2_video-1050
unet 1 has not been trained
unet 2 has not been trained
when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets


0it [00:00, ?it/s]

sampling time step:   0%|          | 0/10 [00:00<?, ?it/s]

sub001_2drt_18_topic2_video-1060
unet 1 has not been trained
unet 2 has not been trained
when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets


0it [00:00, ?it/s]

sampling time step:   0%|          | 0/10 [00:00<?, ?it/s]

In [9]:
dataset_75speaker[10]

AttributeError: 'list' object has no attribute 'split'

In [8]:
(gifs, aud_emb, cond_video_frames) = dataset_75speaker[2]

## Calc FVD

In [95]:
import glob
from PIL import Image, ImageSequence
from torchvision import transforms

def load_frames(image: Image, mode='RGB'):
    # ret = 
    # if self.transform:
    #     gif = self.transform(gif)
    return np.array([
        np.array(frame.convert(mode))
        for frame in ImageSequence.Iterator(image)
    ])

def load_frames_tensor(image: Image, mode='RGB', video_len=10):
    return torch.stack([transforms.ToTensor()(np.array(frame.convert('RGB'))) for frame in ImageSequence.Iterator(im)])[:video_len]

# def get_videos_from_folder(path, size_batch):
#     synthetic_batch = []
#     print(f'{path}/*')
#     synthetic_path = glob.glob(path + '/*')[:10]
#     for names in synthetic_path:
#         with Image.open(names) as im:
#             gif = load_frames_tensor(im)
#             synthetic_batch.append(gif)
#     synthetic_batch = torch.stack(synthetic_batch)
#     return synthetic_batch

In [96]:
synthetic_batch = []
synthetic_path = glob.glob(f'./generated_images/*')
for names in synthetic_path:
    with Image.open(names) as im:
        gif = load_frames_tensor(im)
        # gif = load_frames_tensor(im)
        synthetic_batch.append(gif)
synthetic_batch = torch.stack(synthetic_batch)

In [97]:
real_batch = []
real_path = glob.glob(f'./datasets/test/gifs/*')[:10]
for names in real_path:
    with Image.open(names) as im:
        gif = load_frames_tensor(im)
        # gif = load_frames_tensor(im)
        real_batch.append(gif)
real_batch = torch.stack(real_batch)

In [108]:
real_batch2 = []
real_path2 = glob.glob(f'./datasets/test/gifs/*')[100:110]
for names in real_path2:
    with Image.open(names) as im:
        gif = load_frames_tensor(im)
        # gif = load_frames_tensor(im)
        real_batch2.append(gif)
real_batch2 = torch.stack(real_batch2)

In [102]:
import sys
 
# appending a path
sys.path.append('common_metrics_on_video_quality')
from calculate_fvd import calculate_fvd
from calculate_psnr import calculate_psnr
from calculate_ssim import calculate_ssim
from calculate_lpips import calculate_lpips

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [on]


Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /home/hongn/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth
100%|██████████| 233M/233M [00:08<00:00, 30.1MB/s] 


Loading model from: /home/hongn/miniconda3/envs/genai/lib/python3.8/site-packages/lpips/weights/v0.1/alex.pth


In [109]:
device = torch.device("cuda")
import json
result = {}
result['fvd_realvsfake'] = calculate_fvd(synthetic_batch, real_batch2, device, method='styleganv')
result['fvd_realvsreal'] = calculate_fvd(real_batch, real_batch2, device, method='styleganv')

calculate_fvd...
/mnt/c/Users/PCM/Documents/GitHub/SPAN-rtmri/common_metrics_on_video_quality/fvd/styleganv/i3d_torchscript.pt


100%|██████████| 1/1 [00:01<00:00,  1.95s/it]


calculate_fvd...
/mnt/c/Users/PCM/Documents/GitHub/SPAN-rtmri/common_metrics_on_video_quality/fvd/styleganv/i3d_torchscript.pt


100%|██████████| 1/1 [00:01<00:00,  1.79s/it]


In [110]:
result['fvd_realvsfake']

{'value': {10: 1360.9338710416423},
 'video_setting': torch.Size([10, 3, 10, 64, 64]),
 'video_setting_name': 'batch_size, channel, time, heigth, width'}

In [111]:
result['fvd_realvsreal']

{'value': {10: 429.49936907293204},
 'video_setting': torch.Size([10, 3, 10, 64, 64]),
 'video_setting_name': 'batch_size, channel, time, heigth, width'}

In [80]:
real_batch.shape

AttributeError: 'list' object has no attribute 'shape'

In [21]:
gif.shape

(10, 64, 64, 3)