In [1]:
# import argparse
import os
import shutil

import numpy as np
import torch
from IPython.display import clear_output
from imagen_pytorch import Unet3D, ElucidatedImagen, ImagenTrainer
from torchvision import transforms

from utils import gif75speaker, get_path_of_pretrained



## Generate the synthetic dataset used for comparison

In [2]:
# --- Evaluation configuration ---

# Audio-embedding model that conditions the generator, and whether its
# features are pooled (passed as `audio_pooling` to the dataset).
AUDIO_EMB = 'wav2vec2-l60-pho'
POOLING = False

# Test split to evaluate; one of
# {test-unseenaudio, test-unseensubject, test-unseenboth}.
MODE = 'test-unseenboth'

# Number of generated clips used for evaluation.
LEN_GEN_IMGS = 300

# Checkpoint to load into the model and the embedding width it expects
# (used as `text_embed_dim`), looked up for this AUDIO_EMB / POOLING pair.
PATH_2_PRETRAINED, LEN_AUDIO_EMB = get_path_of_pretrained(AUDIO_EMB, POOLING)

In [3]:
# Test-split dataset. Each item is unpacked by the generation loop as
# (_, audio_embedding, conditioning_frames); img_per_gif=10 presumably sets
# the clip length in frames — TODO confirm against gif75speaker.
dataset_75speaker = gif75speaker(
    image_path='./datasets/preprocessed_dataset/test',
    img_per_gif=10,
    audio_path=f'./datasets/preprocessed_dataset/{AUDIO_EMB}',
    audio_pooling=POOLING,
    mode=MODE,
)

In [4]:
# Two-stage video diffusion model: two 3D U-Nets wrapped in an ElucidatedImagen
# whose conditioning embedding width is LEN_AUDIO_EMB (audio embeddings are fed
# in through the `text_embeds` argument at sampling time).
unet1 = Unet3D(dim = 64, dim_mults = (1, 2, 4, 8)).cuda()
unet2 = Unet3D(dim = 64, dim_mults = (1, 2, 4, 8)).cuda()

imagen = ElucidatedImagen(
    text_embed_dim = LEN_AUDIO_EMB,
    unets = (unet1, unet2),
    image_sizes = (64, 64),
    random_crop_sizes = (None, 16),
    temporal_downsample_factor = (1, 1),        # no temporal downsampling for either unet (a factor of 2 would halve the frame count)
    num_sample_steps = 10,
    cond_drop_prob = 0.1,
    sigma_min = 0.002,                          # min noise level
    sigma_max = (80, 160),                      # max noise level, double the max noise level for upsampler
    sigma_data = 0.5,                           # standard deviation of data distribution
    rho = 7,                                    # controls the sampling schedule
    P_mean = -1.2,                              # mean of log-normal distribution from which noise is drawn for training
    P_std = 1.2,                                # standard deviation of log-normal distribution from which noise is drawn for training
    S_churn = 80,                               # parameters for stochastic sampling - depends on dataset, Table 5 in paper
    S_tmin = 0.05,
    S_tmax = 50,
    S_noise = 1.003,
).cuda()

# Load pretrained weights; the state dict must match the architecture above.
# torch.load unpickles the file, so only load checkpoints from trusted sources.
imagen.load_state_dict(torch.load(PATH_2_PRETRAINED))
# In this notebook the trainer is only used for its `sample` helper.
trainer = ImagenTrainer(imagen,
    split_valid_from_train = True, # whether to split the validation dataset from the training
    dl_tuple_output_keywords_names = ('images', 'text_embeds', 'cond_video_frames')
).cuda()

The base dimension of your u-net should ideally be no smaller than 128, as recommended by a professional DDPM trainer https://nonint.com/2022/05/04/friends-dont-let-friends-train-small-diffusion-models/


dataloader_config = DataLoaderConfiguration(split_batches=True)


In [None]:
# Start from an empty output directory so GIFs left over from earlier runs are
# not picked up when the generated clips are loaded for evaluation.
# Pure-Python instead of `!rm -rf` / `!mkdir` so it does not depend on a POSIX shell.
if os.path.isdir('./generated_images'):
    shutil.rmtree('./generated_images')
os.makedirs('./generated_images', exist_ok=True)

In [None]:
# Sample one video per test item with the base U-Net, save it as a GIF named
# after the item, and record the matching ground-truth path for evaluation.
real_path = []
to_pil = transforms.ToPILImage()

# Generate exactly LEN_GEN_IMGS clips so the evaluation cells see the configured count.
for i in range(LEN_GEN_IMGS):
    (_, aud_emb, cond_video_frames) = dataset_75speaker[i]
    sample_name = dataset_75speaker.get_names(i)
    real_path.append(dataset_75speaker.get_path(i))
    print(sample_name)
    # stop_at_unet_number=1: only the first (base) U-Net is run.
    videos = trainer.sample(
        text_embeds = aud_emb.unsqueeze(0),
        video_frames = 10,
        stop_at_unet_number = 1,
        batch_size = 1,
        cond_video_frames = cond_video_frames.unsqueeze(0),
    )
    # NOTE(review): assumes `videos` is (batch, channels, frames, H, W); swapping
    # dims 0 and 1 of videos[0] then yields one (C, H, W) tensor per frame.
    imgs = [to_pil(frame) for frame in torch.transpose(videos[0], 0, 1)]
    # PIL's `duration` is the per-frame display time in milliseconds:
    # 10 ms is a nominal 100 fps (many GIF viewers clamp very short delays).
    imgs[0].save(f'./generated_images/{sample_name}.gif', save_all=True, append_images=imgs[1:], duration=10, loop=0)

    clear_output()

## Calculate FVD, SSIM and PSNR on the test split selected by `MODE` (e.g. unseen subjects)

In [32]:
import glob
from PIL import Image, ImageSequence
from torchvision import transforms

def load_frames(image: Image.Image, mode='RGB'):
    """Decode every frame of a (possibly animated) PIL image into one numpy array.

    Args:
        image: An open PIL image; all of its frames are visited with
            ``ImageSequence.Iterator``.
        mode: PIL mode each frame is converted to before decoding (default ``'RGB'``).

    Returns:
        ``np.ndarray`` built from the per-frame arrays; for equally sized RGB
        frames this is ``(num_frames, H, W, 3)`` uint8.
    """
    return np.array([
        np.array(frame.convert(mode))
        for frame in ImageSequence.Iterator(image)
    ])

def load_frames_tensor(image: Image.Image, mode='RGB', video_len=10):
    """Decode the frames of a PIL image into a float tensor, keeping the first ``video_len``.

    Args:
        image: An open (typically animated GIF) PIL image.
        mode: PIL mode each frame is converted to before decoding (default ``'RGB'``).
        video_len: Number of leading frames to keep; any slice bound works
            (``None`` keeps all frames).

    Returns:
        ``torch.Tensor`` of shape ``(frames, C, H, W)``. ``transforms.ToTensor``
        scales uint8 pixels into ``[0, 1]``.

    Raises:
        RuntimeError: If the image yields no frames (``torch.stack`` of an empty list).
    """
    to_tensor = transforms.ToTensor()
    # Use the `image` argument (not a global) and honour `mode`.
    frames = [to_tensor(np.array(frame.convert(mode))) for frame in ImageSequence.Iterator(image)]
    # Slice after stacking so `video_len` keeps full slice semantics.
    return torch.stack(frames)[:video_len]

# def get_videos_from_folder(path, size_batch):
#     synthetic_batch = []
#     print(f'{path}/*')
#     synthetic_path = glob.glob(path + '/*')[:10]
#     for names in synthetic_path:
#         with Image.open(names) as im:
#             gif = load_frames_tensor(im)
#             synthetic_batch.append(gif)
#     synthetic_batch = torch.stack(synthetic_batch)
#     return synthetic_batch

In [33]:
# Generated clips, stacked to (N, frames, C, H, W).
# glob.glob returns paths in arbitrary, filesystem-dependent order, so sort them
# to make the selected subset (and the clip order) reproducible across runs.
synthetic_path = sorted(glob.glob('./generated_images/*.gif'))[:LEN_GEN_IMGS]
if not synthetic_path:
    raise FileNotFoundError('No generated GIFs found in ./generated_images')
synthetic_batch = []
for gif_path in synthetic_path:
    with Image.open(gif_path) as im:
        gif = load_frames_tensor(im)
        synthetic_batch.append(gif)
synthetic_batch = torch.stack(synthetic_batch)

In [34]:
# Ground-truth test clips for the generated samples, in generation order,
# stacked to (N, frames, C, H, W).
if not real_path:
    raise RuntimeError('real_path is empty; run the generation cell first')
real_batch = []
for clip_path in real_path[:LEN_GEN_IMGS]:
    with Image.open(clip_path) as im:
        gif = load_frames_tensor(im)
        real_batch.append(gif)
real_batch = torch.stack(real_batch)

In [35]:
# Reference clips from the training split, stacked to (N, frames, C, H, W).
# Sorted so the selected subset is reproducible (glob order is arbitrary).
real_path2 = sorted(glob.glob('./datasets/preprocessed_dataset/train/*'))[:LEN_GEN_IMGS]
if not real_path2:
    raise FileNotFoundError('No clips found in ./datasets/preprocessed_dataset/train')
real_batch2 = []
for clip_path in real_path2:
    with Image.open(clip_path) as im:
        gif = load_frames_tensor(im)
        real_batch2.append(gif)
real_batch2 = torch.stack(real_batch2)

In [20]:
import sys
 
# Make the `common_metrics_on_video_quality` checkout importable; a relative
# sys.path entry is resolved against the kernel's current working directory.
sys.path.append('common_metrics_on_video_quality')
# Video-quality metric helpers provided by that checkout.
from calculate_fvd import calculate_fvd
from calculate_psnr import calculate_psnr
from calculate_ssim import calculate_ssim
from calculate_lpips import calculate_lpips

In [36]:
import json

device = torch.device("cuda")

# "realvsfake" metrics score the generated clips (synthetic_batch) against the
# train clips (real_batch2); "realvsreal" metrics score the test ground truth
# (real_batch) against the same train clips as a reference point.
# NOTE(review): if calculate_ssim / calculate_psnr pair clip i with clip i, the
# train clips are not the ground truth of the generated clips — confirm this
# unpaired comparison is intended rather than comparing against real_batch.
result = {
    'fvd_realvsfake': calculate_fvd(synthetic_batch, real_batch2, device, method='styleganv'),
    'fvd_realvsreal': calculate_fvd(real_batch, real_batch2, device, method='styleganv'),
    'ssim_realvsfake': calculate_ssim(synthetic_batch, real_batch2),
    'psnr_realvsfake': calculate_psnr(synthetic_batch, real_batch2),
    'ssim_realvsreal': calculate_ssim(real_batch, real_batch2),
    'psnr_realvsreal': calculate_psnr(real_batch, real_batch2),
}

calculate_fvd...
/mnt/c/Users/PCM/Documents/GitHub/SPAN-rtmri/common_metrics_on_video_quality/fvd/styleganv/i3d_torchscript.pt


100%|██████████| 1/1 [00:30<00:00, 30.97s/it]


calculate_fvd...
/mnt/c/Users/PCM/Documents/GitHub/SPAN-rtmri/common_metrics_on_video_quality/fvd/styleganv/i3d_torchscript.pt


100%|██████████| 1/1 [00:24<00:00, 24.46s/it]


calculate_ssim...


100%|██████████| 300/300 [00:08<00:00, 34.09it/s]


calculate_psnr...


100%|██████████| 300/300 [00:01<00:00, 274.62it/s]


calculate_ssim...


100%|██████████| 300/300 [00:09<00:00, 30.08it/s]


calculate_psnr...


100%|██████████| 300/300 [00:00<00:00, 1003.12it/s]


In [14]:
# Mean of the SSIM entries for generated vs. train clips.
np.mean([*result['ssim_realvsfake']['value'].values()])

0.14528324416866542

In [15]:
# Mean of the SSIM entries for real test vs. train clips (reference point).
np.mean([*result['ssim_realvsreal']['value'].values()])

0.3241815948654033

In [37]:
# Reference FVD: real test clips (real_batch) vs. real train clips (real_batch2).
result['fvd_realvsreal']

{'value': {10: 295.53774701782646},
 'video_setting': torch.Size([300, 3, 10, 64, 64]),
 'video_setting_name': 'batch_size, channel, time, heigth, width'}

In [38]:
# FVD of generated clips (synthetic_batch) vs. real train clips (real_batch2).
result['fvd_realvsfake']

{'value': {10: 1701.0833649052056},
 'video_setting': torch.Size([300, 3, 10, 64, 64]),
 'video_setting_name': 'batch_size, channel, time, heigth, width'}

In [18]:
# Shape of the most recently loaded clip: (frames, channels, height, width).
gif.shape

torch.Size([10, 3, 64, 64])

## Calculate FVD on 2-second clips