## Data preprocessing

### Video-to-GIF preprocessing

In [1]:
from PIL import Image
from torch.utils.data import Dataset
import glob, torch
import subprocess
from utils import *
from PIL import Image, ImageSequence
import numpy as np



In [None]:
def with_opencv(filename):
    """Return ``(duration, frame_count)`` for a video file, using OpenCV.

    ``duration`` is in seconds, computed as ``frame_count / fps`` (same unit as
    ``get_length``).  It is ``0.0`` when OpenCV reports no usable frame rate,
    e.g. when the file could not be opened.  ``frame_count`` is the value
    reported for ``CAP_PROP_FRAME_COUNT`` (a float).

    Note: ``CAP_PROP_POS_MSEC`` is the *current read position* of the capture
    (0 right after opening), not the length of the video, so it is not used.
    """
    import cv2

    video = cv2.VideoCapture(filename)
    try:
        frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT)
        fps = video.get(cv2.CAP_PROP_FPS)
    finally:
        # Always free the decoder / file handle, even if a property query fails.
        video.release()

    duration = frame_count / fps if fps > 0 else 0.0
    return duration, frame_count

def get_length(filename):
    """Return the duration of a media file in seconds, as reported by ``ffprobe``.

    Args:
        filename: path to the media file.  It is passed to ffprobe as a single
            argument (no shell), so spaces in the path are fine.

    Returns:
        The container's ``format=duration`` value as a float.

    Raises:
        ValueError: if ffprobe exits with an error, or prints something that is
            not a number (e.g. ``N/A``).  The message includes ffprobe's output.
    """
    result = subprocess.run(
        ["ffprobe", "-v", "error", "-show_entries", "format=duration",
         "-of", "default=noprint_wrappers=1:nokey=1", filename],
        stdout=subprocess.PIPE,
        # Keep diagnostics on a separate stream so they can never be parsed as
        # (or mixed into) the duration value on stdout.
        stderr=subprocess.PIPE,
        text=True,
    )
    output = result.stdout.strip()
    if result.returncode != 0:
        raise ValueError(
            f"ffprobe failed for {filename!r} (exit code {result.returncode}): "
            f"{result.stderr.strip()}"
        )
    try:
        return float(output)
    except ValueError:
        raise ValueError(
            f"ffprobe reported no numeric duration for {filename!r}: {output!r}"
        ) from None

In [None]:
# Cut every source video into short GIF clips for training.
# Each clip is WINDOW_S seconds long, resampled to FPS frames/s and scaled to
# 64 px wide.  Clip starts are spaced (WINDOW_S - OVERLAP_S) seconds apart and
# are computed as integer frame offsets, so the "-<start frame>" suffix in each
# output name is exact (no float truncation).  gif75speaker parses that suffix
# to index into the per-video audio embeddings.
import os
import shutil

TARGET_DIR = './datasets/gifs'
# NOTE(review): absolute, machine-specific path; consider making it configurable.
SOURCE_GLOB = '/mnt/c/Users/PCM/Dropbox/span/sub*'
FPS = 50
WINDOW_S = 0.2   # clip length in seconds (10 frames at 50 fps)
OVERLAP_S = 0    # step between clip starts = WINDOW_S - OVERLAP_S
GIF_FILTER = f'fps={FPS},scale=64:-1:flags=lanczos,split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse'

# Regenerate from scratch: clear and recreate the directory the clips are
# actually written to (and that the dataset reads from).
shutil.rmtree(TARGET_DIR, ignore_errors=True)
os.makedirs(TARGET_DIR)

# Step in whole frames; rounded if (WINDOW_S - OVERLAP_S) * FPS is not integral.
step_frames = round((WINDOW_S - OVERLAP_S) * FPS)
failed_clips = []
for subject_dir in glob.glob(SOURCE_GLOB):
    for video_path in glob.glob(f'{subject_dir}/2drt/video/*'):
        video_stem = video_path.split('/')[-1].split('.')[0]
        # Clip starts cover [0, int(duration) - 1) seconds, as in the original loop.
        end_frame = (int(get_length(video_path)) - 1) * FPS
        for start_frame in range(0, end_frame, step_frames):
            out_path = f'{TARGET_DIR}/{video_stem}-{start_frame}.gif'
            # Argument list instead of a shell string: paths containing spaces
            # or shell metacharacters are passed through unchanged.
            result = subprocess.run([
                'ffmpeg', '-y',
                '-ss', str(start_frame / FPS), '-t', str(WINDOW_S),
                '-i', video_path,
                '-vf', GIF_FILTER,
                '-loop', '0',
                out_path,
            ])
            # Best effort: keep going, but remember which clips failed.
            if result.returncode != 0:
                failed_clips.append(out_path)

print(f'{len(failed_clips)} clip(s) failed')

## Modeling

In [2]:
from PIL import Image
from torch.utils.data import Dataset
import glob, torch

class gif75speaker(Dataset):
    """Dataset pairing GIF clips with slices of per-video audio embeddings.

    Clip files in ``image_path`` are expected to be named
    ``<video stem>-<start>.gif`` (the naming written by the preprocessing cell),
    where ``<start>`` is an integer frame offset.  For each clip the tensor in
    ``<audio_path>/<video stem>.pt`` is sliced along dim 1 from ``<start>`` for
    ``img_per_gif`` steps.

    Args:
        image_path: directory holding the GIF clips.
        audio_path: directory holding one ``.pt`` embedding tensor per video.
        transform: stored as ``self.transform``; NOT currently applied.
        target_transform: stored as ``self.target_transform``; NOT currently applied.
        img_per_gif: number of frames / embedding steps returned per item.
    """

    def __init__(self, image_path = './datasets/gifs', audio_path = './datasets/audios', transform=None, target_transform=None, img_per_gif = 10):
        # Sorted so the index -> clip mapping is reproducible (raw glob order
        # depends on the filesystem).
        self.gifs = sorted(glob.glob(f'{image_path}/*'))
        # Directory searched by __getitem__ for '<video stem>.pt'.
        self.audio_path = audio_path
        self.transform = transform
        self.target_transform = target_transform
        self.img_per_gif = img_per_gif

    @staticmethod
    def _parse_gif_name(path):
        """Split ``.../<video stem>-<start>.gif`` into ``(video stem, int(start))``."""
        stem = path.split('/')[-1].split('.')[0]
        # rsplit: the start offset follows the LAST '-', so video stems that
        # themselves contain '-' are kept intact.
        video_stem, start = stem.rsplit('-', 1)
        return video_stem, int(start)

    def __getitem__(self, index):
        """Return ``(gif, aud_emb, cond_frames)`` for clip ``index``.

        gif: the first ``img_per_gif`` frames, each converted with
            ``transforms.ToTensor`` and stacked, then transposed to
            (channels, frames, height, width).
        aud_emb: row 0 of ``aud_embs[:, start:start + img_per_gif, :]`` (2-D).
        cond_frames: ``gif[:, 0:2, :]``, the clip's first two frames.
        """
        video_stem, start = self._parse_gif_name(self.gifs[index])

        with Image.open(self.gifs[index]) as im:
            frames = self.load_frames(im)

        # NOTE(review): torch.load unpickles its input; only load trusted files.
        aud_embs = torch.load(f'{self.audio_path}/{video_stem}.pt')
        aud_emb = aud_embs[:, start:start + self.img_per_gif, :]
        # NOTE(review): self.transform is stored but not applied here; confirm
        # whether frames should pass through it before ToTensor.
        # NOTE(review): clips shorter than img_per_gif yield shorter tensors,
        # which will not batch with full-length ones.
        to_tensor = transforms.ToTensor()
        gif = torch.transpose(torch.stack([to_tensor(frame) for frame in frames[:self.img_per_gif]]), 0, 1)
        # NOTE(review): the conditioning frames are the clip's own first two
        # frames, not frames preceding the clip; confirm this is intended.
        return (gif, aud_emb[0], gif[:, 0:2, :])

    def __len__(self):
        """Number of GIF clips found in ``image_path``."""
        return len(self.gifs)

    def load_frames(self, image: "Image.Image", mode='RGB'):
        """Decode every frame of an open (animated) PIL image into one array.

        Args:
            image: an open PIL image such as a GIF; frames are read with
                ``ImageSequence.Iterator``.
            mode: PIL mode each frame is converted to before decoding.

        Returns:
            ``np.ndarray`` stacking ``np.array(frame.convert(mode))`` for every
            frame (for ``'RGB'``: frames x height x width x 3).
        """
        return np.array([
            np.array(frame.convert(mode))
            for frame in ImageSequence.Iterator(image)
        ])

In [3]:
# Pull the first item of a 10-frame dataset for quick shape inspection below.
sample_dataset = gif75speaker(transform=data_transforms['val'], img_per_gif=10)
gifs, aud_emb, preceding_frame = next(iter(sample_dataset))

In [4]:
# (channels, frames, height, width) conditioning tensor: the first two frames of the sample clip.
preceding_frame.shape

torch.Size([3, 2, 64, 64])

In [None]:
# Number of GIF clips found under the dataset's default image_path.
len(gif75speaker())

In [None]:
# Audio-embedding slice paired with the sample clip: row 0 of aud_embs[:, start:start + img_per_gif, :] (2-D).
aud_emb.shape

In [None]:
# Scratch (disabled): standalone version of gif75speaker.load_frames, used to
# decode a single GIF by hand for inspection.
# NOTE(review): dead code with an absolute local path; consider deleting.
# from PIL import Image, ImageSequence
# import numpy as np

# def load_frames(image: Image, mode='RGB'):
#     return np.array([
#         np.array(frame.convert(mode))
#         for frame in ImageSequence.Iterator(image)
#     ])
# gifs_name = '/mnt/c/Users/PCM/Documents/GitHub/SPAN-rtmri/datasets/gifs/sub006_2drt_20_topic4_video-550.gif'
# with Image.open(gifs_name) as im:
#     frames = load_frames(im)

In [None]:
from torch.utils.data import DataLoader

# Shuffled mini-batches of 4 clips from the default dataset directories.
# (The training cell hands the dataset to ImagenTrainer directly instead.)
train_dataloader = DataLoader(
    dataset=gif75speaker(),
    batch_size=4,
    shuffle=True,
)

In [None]:
import os

import torch
from imagen_pytorch import Unet3D, ElucidatedImagen, ImagenTrainer

# --- Training configuration (the values this cell previously hard-coded) ---
NUM_TRAIN_STEPS = 20000   # the loop runs steps 1 .. NUM_TRAIN_STEPS - 1
BATCH_SIZE = 8
VALID_EVERY = 50          # report validation loss every N steps
SAMPLE_EVERY = 100        # sample and save a GIF every N steps
SAMPLE_FRAMES = 10        # frames per sampled clip
SAMPLE_DIR = './gif_samples'
GIF_FRAME_MS = 20         # GIF `duration` is ms per frame: 20 ms -> 50 fps

# Two 3D U-Nets for the cascade; only unet 1 is trained below.
unet1 = Unet3D(dim = 64, dim_mults = (1, 2, 4, 8)).cuda()
unet2 = Unet3D(dim = 64, dim_mults = (1, 2, 4, 8)).cuda()

# Elucidated Imagen holding both unets.  Audio embeddings from gif75speaker are
# passed as `text_embeds`, so text_embed_dim presumably must equal their last
# dimension (1024 here) -- TODO confirm against the .pt embedding files.
imagen = ElucidatedImagen(
    text_embed_dim=1024,
    unets = (unet1, unet2),
    image_sizes = (64, 64),
    random_crop_sizes = (None, 16),
    temporal_downsample_factor = (2, 1),        # the first unet receives the video temporally downsampled by 2x
    num_sample_steps = 10,
    cond_drop_prob = 0.1,
    sigma_min = 0.002,                          # min noise level
    sigma_max = (80, 160),                      # max noise level, double the max noise level for upsampler
    sigma_data = 0.5,                           # standard deviation of data distribution
    rho = 7,                                    # controls the sampling schedule
    P_mean = -1.2,                              # mean of log-normal distribution from which noise is drawn for training
    P_std = 1.2,                                # standard deviation of log-normal distribution from which noise is drawn for training
    S_churn = 80,                               # parameters for stochastic sampling - depends on dataset, Table 5 in paper
    S_tmin = 0.05,
    S_tmax = 50,
    S_noise = 1.003,
).cuda()

# Each dataset tuple (gif, aud_emb, first two frames) is mapped onto these keywords.
trainer = ImagenTrainer(imagen,
    split_valid_from_train = True, # whether to split the validation dataset from the training
    dl_tuple_output_keywords_names = ('images', 'text_embeds', 'cond_video_frames')
).cuda()

train_dataset = gif75speaker()
trainer.add_train_dataset(train_dataset, batch_size = BATCH_SIZE)

# Fixed conditioning for the periodic samples: the audio embedding of clip 0,
# read from the dataset so this cell does not rely on `aud_emb` from another cell.
_, sample_aud_emb, _ = train_dataset[0]
os.makedirs(SAMPLE_DIR, exist_ok=True)

# ignore_time=False trains the temporal layers too.  (Ignoring time during early
# training is an option the imagen-pytorch example mentions.)
for i in range(1, NUM_TRAIN_STEPS):
    loss = trainer.train_step(unet_number = 1, max_batch_size = BATCH_SIZE, ignore_time = False)
    print(f'loss: {loss}')

    if not (i % VALID_EVERY):
        valid_loss = trainer.valid_step(unet_number = 1, max_batch_size = BATCH_SIZE)
        print(f'valid loss: {valid_loss}')

    if not (i % SAMPLE_EVERY) and trainer.is_main: # is_main makes sure this can run in distributed
        # unsqueeze(0) adds the batch dimension the sampler expects for batch_size=1.
        videos = trainer.sample(text_embeds = sample_aud_emb.unsqueeze(0), video_frames = SAMPLE_FRAMES, stop_at_unet_number = 1, batch_size = 1)
        # Swap the first two dims of the first video so we can iterate frame by frame.
        imgs = [transforms.ToPILImage()(img) for img in torch.transpose(videos[0], 0, 1)]
        # Previously these frames were built and then discarded; save them so the
        # sampling work every SAMPLE_EVERY steps is actually inspectable.
        imgs[0].save(f'{SAMPLE_DIR}/gif-sample-{i // SAMPLE_EVERY}.gif', save_all=True, append_images=imgs[1:], duration=GIF_FRAME_MS, loop=0)
        # torch.save(imagen.state_dict(), f'./checkpoints/imagen-video-{i}')


In [1]:
# Display the ElucidatedImagen model built in the training cell.
# NOTE(review): the saved NameError output came from running this cell in a fresh
# kernel before the training cell; run the notebook top to bottom.
imagen

NameError: name 'imagen' is not defined

In [None]:
# Sample a 20-frame clip conditioned on `aud_emb`, the single-clip audio embedding
# loaded earlier.  The model was trained on 10-frame clips, so this extrapolates in
# time.  `aud_emb` is 2-D (no batch dimension), so add one with unsqueeze(0), as
# the training loop does for batch_size=1.
videos = trainer.sample(text_embeds = aud_emb.unsqueeze(0), video_frames = 20, stop_at_unet_number  = 1, batch_size = 1)
videos.shape

In [None]:
# Inspect the tensor -> PIL transform used for GIF export below.
# (The previous call `transforms.to_pil_image()` passed no image and always raised.)
transforms.ToPILImage()

In [None]:
from PIL import Image

# Swap the first two dims of the first sampled video so we can iterate frame by
# frame (assumed (channels, frames, H, W) -> (frames, channels, H, W); TODO confirm).
to_pil = transforms.ToPILImage()
imgs = [to_pil(frame) for frame in videos[0].transpose(0, 1)]
# GIF `duration` is milliseconds per frame: 20 ms -> 50 frames per second.
imgs[0].save(
    "sample-imagen-video.gif",
    save_all=True,
    append_images=imgs[1:],
    duration=20,
    loop=0,
)

In [None]:
# Shape of the first sampled video after swapping dims 0 and 1: the per-frame layout iterated for GIF export.
torch.transpose(videos[0], 0, 1).shape

In [None]:
# Raw tensor returned by the most recent trainer.sample call.
videos