In [2]:
import torch
from imagen_pytorch import Unet3D, ElucidatedImagen, ImagenTrainer
from utils import gif75speaker, get_path_of_pretrained
import numpy as np
from torchvision import transforms
# import argparse
from IPython.display import clear_output



## Generate a synthetic clip dataset to compare against the real clips

In [43]:
# Audio embedding to condition on and whether it is pooled; both pick the
# pretrained checkpoint below and are passed on to the dataset loader.
AUDIO_EMB = 'hubert-large'
POOLING = False
# Evaluation split: one of {test-unseenaudio, test-unseensubject, test-unseenboth}.
MODE = 'test-unseenboth'
# Number of clips to generate and score.
LEN_GEN_IMGS = 500

pretrained_lookup = get_path_of_pretrained(AUDIO_EMB, POOLING)
PATH_2_PRETRAINED, LEN_AUDIO_EMB = pretrained_lookup
print(PATH_2_PRETRAINED)

/mnt/c/Users/PCM/Documents/GitHub/SPAN-rtmri/checkpoints/hubert/large/ImagenVideo-Modelhubert-large-PoolingFalse-IgnoreTimeFalse-TwoStepTrue-100


In [44]:
# Evaluation dataset: GIFs from the test directory paired with AUDIO_EMB audio embeddings.
# NOTE(review): image_path is always the `test` directory; presumably MODE selects the
# split inside the loader — confirm against gif75speaker.
dataset_75speaker = gif75speaker(
    image_path='./datasets/preprocessed_dataset/test',
    img_per_gif=10,
    audio_path=f'./datasets/preprocessed_dataset/{AUDIO_EMB}',
    audio_pooling=POOLING,
    mode=MODE,
)

In [45]:
# Two-unet cascade; the sampling cells below only run the first one (stop_at_unet_number=1).
unet1 = Unet3D(dim = 64, dim_mults = (1, 2, 4, 8)).cuda()
unet2 = Unet3D(dim = 64, dim_mults = (1, 2, 4, 8)).cuda()

imagen = ElucidatedImagen(
    text_embed_dim = LEN_AUDIO_EMB,             # conditioning is the audio embedding, not text
    unets = (unet1, unet2),
    image_sizes = (64, 64),
    random_crop_sizes = (None, 16),
    temporal_downsample_factor = (1, 1),        # no temporal downsampling for either unet
    num_sample_steps = 10,
    cond_drop_prob = 0.1,
    sigma_min = 0.002,                          # min noise level
    sigma_max = (80, 160),                      # max noise level, double the max noise level for upsampler
    sigma_data = 0.5,                           # standard deviation of data distribution
    rho = 7,                                    # controls the sampling schedule
    P_mean = -1.2,                              # mean of log-normal distribution from which noise is drawn for training
    P_std = 1.2,                                # standard deviation of log-normal distribution from which noise is drawn for training
    S_churn = 80,                               # parameters for stochastic sampling - depends on dataset, Table 5 in paper
    S_tmin = 0.05,
    S_tmax = 50,
    S_noise = 1.003,
).cuda()

# NOTE(review): weights are loaded into `imagen` directly rather than through the trainer;
# the sampling cell's output reports "unet 1 has not been trained" — confirm that sampling
# uses these loaded weights.
imagen.load_state_dict(torch.load(PATH_2_PRETRAINED))
trainer = ImagenTrainer(imagen,
    split_valid_from_train = True, # whether to split the validation dataset from the training
    dl_tuple_output_keywords_names = ('images', 'text_embeds', 'cond_video_frames')
).cuda()

In [46]:
import os
import shutil
import subprocess

# Start every run from an empty output directory for this embedding / split.
generated_dir = f'./generated_images/{AUDIO_EMB}/{MODE}'
shutil.rmtree(generated_dir, ignore_errors=True)
# makedirs also creates ./generated_images/<AUDIO_EMB> when it is missing. The shell
# `mkdir` used before (no -p) failed in that case, its exit code was ignored, and the
# generation loop then failed when saving the first GIF.
os.makedirs(generated_dir, exist_ok=True)

0

In [47]:
# Generate one clip per dataset entry and save it as an animated GIF named after the entry.
to_pil = transforms.ToPILImage()
for i in range(LEN_GEN_IMGS):
    (_, aud_emb, cond_video_frames) = dataset_75speaker[i]
    clip_name = dataset_75speaker.get_names(i)
    print(f'{clip_name}')
    # Only the first unet of the cascade is run (stop_at_unet_number=1).
    videos = trainer.sample(text_embeds=aud_emb.unsqueeze(0), video_frames=10, stop_at_unet_number=1,
                            batch_size=1, cond_video_frames=cond_video_frames.unsqueeze(0))
    # Swap the first two axes of the single sample (presumably (C, T, H, W) -> (T, C, H, W))
    # so that iterating yields one image per frame.
    imgs = [to_pil(img) for img in torch.transpose(videos[0], 0, 1)]
    # duration is the number of milliseconds between frames: 10 ms, i.e. ~100 fps playback.
    # NOTE(review): the long-clip cells further down use 20 ms (50 fps) — confirm the intended speed.
    imgs[0].save(f'./generated_images/{AUDIO_EMB}/{MODE}/{clip_name}.gif', save_all=True,
                 append_images=imgs[1:], duration=10, loop=0)

    clear_output()

sub067_2drt_21_topic5_video-1180
unet 1 has not been trained
unet 2 has not been trained
when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets


0it [00:00, ?it/s]

sampling time step:   0%|          | 0/10 [00:00<?, ?it/s]

## Calc FVD / SSIM / PSNR for the split selected by `MODE` (e.g. unseen subjects)

In [33]:
import glob
from PIL import Image, ImageSequence
from torchvision import transforms

def load_frames(image: Image.Image, mode='RGB'):
    """Decode every frame of a (possibly animated) PIL image into one numpy array.

    Args:
        image: an open PIL image; all of its frames are visited with
            ``ImageSequence.Iterator``. (Annotated as ``Image.Image``, the image
            class, rather than the ``PIL.Image`` module.)
        mode: PIL mode each frame is converted to before decoding (default ``'RGB'``).

    Returns:
        ``np.ndarray`` with the per-frame arrays stacked along a new leading frame axis.
    """
    return np.array([
        np.array(frame.convert(mode))
        for frame in ImageSequence.Iterator(image)
    ])

def load_frames_tensor(im: Image.Image, mode='RGB', video_len=10):
    """Decode an animated image into a float tensor holding its first ``video_len`` frames.

    Args:
        im: an open PIL image (e.g. a GIF); frames are read with ``ImageSequence.Iterator``.
        mode: PIL mode each frame is converted to. This parameter used to be ignored
            (frames were always converted to ``'RGB'``); the default keeps that behaviour.
        video_len: number of leading frames to keep (slice semantics, so ``None`` keeps all).

    Returns:
        Tensor of shape (frames, C, H, W) produced by ``transforms.ToTensor`` per frame
        (values scaled to [0, 1] for 8-bit modes such as ``'RGB'``).
    """
    to_tensor = transforms.ToTensor()
    frames = [to_tensor(np.array(frame.convert(mode))) for frame in ImageSequence.Iterator(im)]
    return torch.stack(frames)[:video_len]

def get_videos_from_folder(paths):
    """Load GIF clips from ``paths`` and stack them into one batch tensor.

    Each file is decoded with ``load_frames_tensor`` (default mode and clip length),
    and the per-clip (frames, C, H, W) tensors are stacked along a new batch axis.

    Args:
        paths: iterable of GIF file paths; the batch follows this order.

    Returns:
        Tensor of shape (num_clips, frames, C, H, W).

    Raises:
        ValueError: if ``paths`` is empty (typically a glob that matched nothing) or if
            the clips do not all decode to the same shape (e.g. a GIF shorter than the
            requested clip length). The message names the offending file.
    """
    paths = list(paths)
    if not paths:
        raise ValueError("get_videos_from_folder: no paths given (did the glob match any files?)")

    clips = []
    for path in paths:
        with Image.open(path) as im:
            clips.append(load_frames_tensor(im))

    expected_shape = clips[0].shape
    for path, clip in zip(paths, clips):
        if clip.shape != expected_shape:
            raise ValueError(
                f"{path}: decoded clip shape {tuple(clip.shape)} differs from "
                f"{tuple(expected_shape)} of {paths[0]}"
            )
    return torch.stack(clips)

In [34]:
# Load (at most LEN_GEN_IMGS of) the generated GIFs into one video tensor.
generated_gif_paths = glob.glob(f'./generated_images/{AUDIO_EMB}/{MODE}/*')[:LEN_GEN_IMGS]
synthetic_batch = get_videos_from_folder(generated_gif_paths)

In [35]:
# For every generated clip, the test-set GIF with the same file name (presumably its ground truth).
# NOTE(review): this re-globs the generated directory, so its order matches the synthetic
# batch only as long as glob returns the same order both times.
real_test_paths = []
for generated_path in glob.glob(f'./generated_images/{AUDIO_EMB}/{MODE}/*')[:LEN_GEN_IMGS]:
    file_name = generated_path.split('/')[-1]
    real_test_paths.append('./datasets/preprocessed_dataset/test/' + file_name)

In [36]:
# Real clips: the test GIFs matching the generated ones, plus up to LEN_GEN_IMGS train GIFs
# used as the reference set by the metric cell.
real_batch_test = get_videos_from_folder(real_test_paths)
train_gif_paths = glob.glob(f'./datasets/preprocessed_dataset/train/*')[:LEN_GEN_IMGS]
real_batch_train = get_videos_from_folder(train_gif_paths)

In [None]:
# synthetic_batch = []
# synthetic_path = glob.glob(f'./generated_images/*')[:300]
# for names in synthetic_path:
#     with Image.open(names) as im:
#         gif = load_frames_tensor(im)
#         # gif = load_frames_tensor(im)
#         synthetic_batch.append(gif)
# synthetic_batch = torch.stack(synthetic_batch)

In [None]:
# real_path = real_batch#glob.glob(f'./datasets/preprocessed_dataset/test/*')[:300]
# real_batch = []
# for names in real_path[:300]:
#     with Image.open(names) as im:
#         gif = load_frames_tensor(im)
#         # gif = load_frames_tensor(im)
#         real_batch.append(gif)
# real_batch = torch.stack(real_batch)

In [None]:
# real_batch2 = []
# real_path2 = glob.glob(f'./datasets/preprocessed_dataset/train/*')[:300]
# for names in real_path2:
#     with Image.open(names) as im:
#         gif = load_frames_tensor(im)
#         # gif = load_frames_tensor(im)
#         real_batch2.append(gif)
# real_batch2 = torch.stack(real_batch2)

In [37]:
import sys
 
# appending a path
sys.path.append('common_metrics_on_video_quality')
from calculate_fvd import calculate_fvd
from calculate_psnr import calculate_psnr
from calculate_ssim import calculate_ssim
from calculate_lpips import calculate_lpips

In [38]:
import json

device = torch.device("cuda")

# Every score uses the real *train* clips as the reference. The realvsreal entries score
# the real test clips the same way, as a point of comparison for the realvsfake entries.
# NOTE(review): calculate_ssim / calculate_psnr run once per clip (500/500 in the output),
# i.e. they appear to pair videos index-by-index, so generated clip k is scored against an
# unrelated train clip k — confirm this is intended rather than pairing with real_batch_test.
result = {
    'fvd_realvsfake': calculate_fvd(synthetic_batch, real_batch_train, device, method='styleganv'),
    'fvd_realvsreal': calculate_fvd(real_batch_test, real_batch_train, device, method='styleganv'),
    'ssim_realvsfake': calculate_ssim(synthetic_batch, real_batch_train),
    'psnr_realvsfake': calculate_psnr(synthetic_batch, real_batch_train),
    'ssim_realvsreal': calculate_ssim(real_batch_test, real_batch_train),
    'psnr_realvsreal': calculate_psnr(real_batch_test, real_batch_train),
}

calculate_fvd...
/mnt/c/Users/PCM/Documents/GitHub/SPAN-rtmri/common_metrics_on_video_quality/fvd/styleganv/i3d_torchscript.pt


100%|██████████| 1/1 [00:09<00:00,  9.32s/it]


calculate_fvd...
/mnt/c/Users/PCM/Documents/GitHub/SPAN-rtmri/common_metrics_on_video_quality/fvd/styleganv/i3d_torchscript.pt


100%|██████████| 1/1 [00:09<00:00,  9.80s/it]


calculate_ssim...


100%|██████████| 500/500 [00:04<00:00, 104.39it/s]


calculate_psnr...


100%|██████████| 500/500 [00:00<00:00, 3227.94it/s]


calculate_ssim...


100%|██████████| 500/500 [00:04<00:00, 105.79it/s]


calculate_psnr...


100%|██████████| 500/500 [00:00<00:00, 3320.14it/s]


In [39]:
# Mean of the SSIM values stored for generated vs. real train clips.
ssim_fake_values = list(result['ssim_realvsfake']['value'].values())
np.asarray(ssim_fake_values).mean()

0.1108799424802223

In [40]:
# Mean of the SSIM values stored for real test vs. real train clips.
ssim_real_values = list(result['ssim_realvsreal']['value'].values())
np.asarray(ssim_real_values).mean()

0.26404798849270844

In [41]:
# FVD of real test clips vs. real train clips — the comparison point for fvd_realvsfake.
result['fvd_realvsreal']

{'value': {10: 209.97105881616136},
 'video_setting': torch.Size([500, 3, 10, 64, 64]),
 'video_setting_name': 'batch_size, channel, time, heigth, width'}

In [42]:
# FVD of generated clips vs. real train clips; compare with fvd_realvsreal.
result['fvd_realvsfake']

{'value': {10: 1673.3681691459751},
 'video_setting': torch.Size([500, 3, 10, 64, 64]),
 'video_setting_name': 'batch_size, channel, time, heigth, width'}

## Calc FVD for 2 seconds

In [None]:
import subprocess
import glob
def get_length(filename):
    """Return the duration of a media file in seconds, as reported by ffprobe.

    Args:
        filename: path to the video/audio file.

    Returns:
        Duration in seconds as a float.

    Raises:
        ValueError: if ffprobe exits non-zero (the message includes ffprobe's error
            output), or if its output is not a number.
    """
    result = subprocess.run(
        ["ffprobe", "-v", "error", "-show_entries", "format=duration",
         "-of", "default=noprint_wrappers=1:nokey=1", filename],
        stdout=subprocess.PIPE,
        # Keep stderr separate: merging it into stdout (as before) fed ffprobe's error text
        # to float(), which hid the real cause behind "could not convert string to float".
        stderr=subprocess.PIPE,
    )
    if result.returncode != 0:
        message = result.stderr.decode(errors="replace").strip()
        raise ValueError(f"ffprobe failed for {filename!r}: {message}")
    return float(result.stdout)

In [None]:
# MAKE TEST DATASET
# Cut source videos into short overlapping GIF clips for evaluation:
#   * every video of the subjects after the first 60, and
#   * the videos after the first 28 of each of the first 60 subjects.
# NOTE(review): both splits depend on glob order, which is filesystem-dependent; the
# train-dataset cell uses the complementary [:60] / [:28] slices, so any sorting must
# be applied to both cells together.
import os

SPAN_ROOT = '/mnt/c/Users/PCM/Dropbox/span'
TEST_GIF_DIR = './datasets/preprocessed_dataset/test-2'
GIF_FILTER = "fps=50,scale=64:-1:flags=lanczos,split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse"


def _extract_gif_clips(video_path, out_dir, window, overlap):
    """Cut one video into `window`-second GIFs, one starting every `window - overlap` seconds.

    Each clip is written to `<out_dir>/<video stem>-<start frame at 50 fps>.gif`.
    Returns the number of ffmpeg runs that exited with a non-zero status.
    """
    stem = video_path.split('/')[-1].split('.')[0]
    failures = 0
    for skip in np.arange(0, int(get_length(video_path)) - 1, window - overlap):
        # skip is a float from np.arange (e.g. 0.6000000000000001); round() instead of
        # int() so a value landing just below a whole frame is not truncated to the
        # previous frame index.
        start_frame = int(round(skip * 50))
        # Argument list instead of a shell string, so paths and the filter graph need no quoting.
        command = ["ffmpeg", "-y", "-ss", f"{skip}", "-t", f"{window}", "-i", video_path,
                   "-vf", GIF_FILTER, "-loop", "0", f"{out_dir}/{stem}-{start_frame}.gif"]
        if subprocess.call(command) != 0:
            failures += 1
    return failures


os.makedirs(TEST_GIF_DIR, exist_ok=True)
failed_clips = 0
for sub in glob.glob(f'{SPAN_ROOT}/sub*')[60:]:
    for video_path in glob.glob(f'{sub}/2drt/video/*'):
        failed_clips += _extract_gif_clips(video_path, TEST_GIF_DIR, window=0.4, overlap=0.2)
for sub in glob.glob(f'{SPAN_ROOT}/sub*')[:60]:
    for video_path in glob.glob(f'{sub}/2drt/video/*')[28:]:
        failed_clips += _extract_gif_clips(video_path, TEST_GIF_DIR, window=0.4, overlap=0.2)
if failed_clips:
    print(f"ffmpeg exited non-zero for {failed_clips} clip(s)")

In [None]:


# MAKE TRAIN DATASET
# First 28 videos of each of the first 60 subjects, cut into non-overlapping 1 s GIF clips.
import os

TRAIN_GIF_DIR = './datasets/preprocessed_dataset/train-1'
TRAIN_GIF_FILTER = "fps=50,scale=64:-1:flags=lanczos,split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse"
os.makedirs(TRAIN_GIF_DIR, exist_ok=True)

subjects = glob.glob(f'/mnt/c/Users/PCM/Dropbox/span/sub*')[:60]
for sub in subjects:
    vids = glob.glob(f'{sub}/2drt/video/*')[:28]
    window = 1   # clip length in seconds
    overlap = 0  # clips start every (window - overlap) seconds
    for video_path in vids:
        stem = video_path.split('/')[-1].split('.')[0]
        for skip in np.arange(0, int(get_length(video_path)) - 1, window - overlap):
            # Argument list instead of a shell string, so paths and the filter graph need no quoting.
            command = ["ffmpeg", "-y", "-ss", f"{skip}", "-t", f"{window}", "-i", video_path,
                       "-vf", TRAIN_GIF_FILTER, "-loop", "0",
                       f"{TRAIN_GIF_DIR}/{stem}-{int(round(skip * 50))}.gif"]
            subprocess.call(command)

In [None]:
# Recreate an empty output directory for the long-clip samples (pure Python instead of
# `!rm` / `!mkdir`, so the cell does not depend on a POSIX shell).
import os
import shutil

shutil.rmtree('./generated_images_1s', ignore_errors=True)
os.makedirs('./generated_images_1s', exist_ok=True)

In [None]:
# Autoregressive long-clip generation: for each sample, generate CHUNK frames at a time,
# conditioning every chunk on the last frame of the previous one, until MILISECOND
# seconds of video have been produced.
real_path = []

MILISECOND = 1          # clip length in SECONDS (despite the name)
FRAME_SECONDS = 0.02    # time per frame; the frame index also slices the audio embeddings
CHUNK = 10              # frames generated per imagen.sample call
n_frames = round(MILISECOND / FRAME_SECONDS)
to_pil = transforms.ToPILImage()

for i in range(100):
    (_, aud_emb, preceding) = dataset_75speaker[i]
    real_path.append(dataset_75speaker.get_path(i))
    clip_name = dataset_75speaker.get_names(i)
    print(f'{clip_name}')

    aud_embs = dataset_75speaker.get_audio_emb(i)
    start_frame = dataset_75speaker.get_start_index(i)
    listimgs = []
    # The inner index must not be called `i`: reusing it shadowed the sample index, so
    # the GIF was saved under the name of entry <last frame index> instead of entry i.
    for frame_idx in range(start_frame, start_frame + n_frames, CHUNK):
        # Average the CHUNK audio-embedding steps covering this chunk.
        aud_emb = torch.mean(aud_embs[:, frame_idx:frame_idx + CHUNK, :], dim=1).unsqueeze(0)
        sample_img = imagen.sample(text_embeds=aud_emb.cuda(), video_frames=CHUNK, stop_at_unet_number=1,
                                   skip_steps=0, cond_video_frames=preceding.unsqueeze(0).cuda())
        # Condition the next chunk on the last generated frame.
        preceding = sample_img[0, :, -1:, :]
        listimgs.append(sample_img)

    # Stack the chunks, swap axes 1 and 2 (presumably channels <-> time, like the transpose
    # in the first generation cell) and flatten chunks x frames into one frame sequence.
    # reshape(-1, ...) rather than n_frames: each chunk has CHUNK frames, so the total
    # differs from n_frames whenever n_frames is not a multiple of CHUNK.
    frames = torch.transpose(torch.stack(listimgs, dim=0).squeeze(1), 1, 2).reshape(-1, 3, 64, 64)
    imgs = [to_pil(img) for img in frames]
    # duration is milliseconds per frame: FRAME_SECONDS -> 20 ms, i.e. real-time 50 fps playback.
    imgs[0].save(f'./generated_images_1s/{clip_name}.gif', save_all=True, append_images=imgs[1:],
                 duration=round(FRAME_SECONDS * 1000), loop=0)

    clear_output()

In [None]:
# Re-stitch the chunks left in `listimgs` by the last sample of the loop above into one GIF.
a = torch.transpose(torch.stack(listimgs, dim=0).squeeze(1), 1, 2)
# reshape(-1, ...) instead of a hard-coded 200 frames: the count is chunks x frames per chunk
# (50 for the 1 s setting above), so 200 only matched a 4 s run and raised otherwise.
b = a.reshape(-1, 3, 64, 64)
imgs = [transforms.ToPILImage()(img) for img in b]
# duration is the number of milliseconds between frames; 20 ms is 50 frames per second
imgs[0].save('gif-sample-video.gif', save_all=True, append_images=imgs[1:], duration=20, loop=0)