<a href="https://colab.research.google.com/github/AmaruEscalante/VideoGPT/blob/master/Using_VideoGPT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Using VideoGPT
This notebook demonstrates how to use VideoGPT and its pretrained models. Make sure you are running on a GPU instance: **Change Runtime Type -> GPU**

## Installation
First, we install the necessary packages

In [None]:
# NOTE(review): on a fresh runtime this cell runs before the repository is
# cloned (the clone happens in the following cell), so the directory may not
# exist yet. The recorded outputs suggest it was executed out of order, which
# then nests the clone as VideoGPT/VideoGPT -- confirm whether this cell is
# still needed.
%cd VideoGPT/

/content/VideoGPT


In [None]:
!git clone https://github.com/amaruescalante/VideoGPT.git
%cd VideoGPT

Cloning into 'VideoGPT'...
remote: Enumerating objects: 398, done.[K
remote: Counting objects: 100% (150/150), done.[K
remote: Compressing objects: 100% (88/88), done.[K
remote: Total 398 (delta 79), reused 97 (delta 62), pack-reused 248[K
Receiving objects: 100% (398/398), 4.05 MiB | 24.83 MiB/s, done.
Resolving deltas: 100% (219/219), done.
/content/VideoGPT/VideoGPT


In [None]:
! pip install git+https://github.com/amaruescalante/VideoGPT.git
! pip install scikit-video ava
! pip install --upgrade --no-cache-dir gdown

In [None]:
!sh scripts/preprocess/ucf101/create_ucf_dataset.sh datasets/ucf101

In [None]:
!sh scripts/preprocess/msrvtt/create_msrvtt_dataset.sh datasets/msrvtt

In [None]:
# Train VQ-VAE
! python scripts/train_vqvae.py --data_path datasets/msrvtt --accelerator gpu --batch_size 16 --gpus 1 --auto_select_gpus true

In [None]:
! python scripts/train_videogpt.py --data_path datasets/msrvtt --accelerator gpu --batch_size 16 --gpus 1 --auto_select_gpus true

In [None]:
%matplotlib inline

from matplotlib import pyplot as plt
from matplotlib import animation
from IPython.display import HTML

import os
import torch
from torchvision.io import read_video, read_video_timestamps

from videogpt import download, load_vqvae, load_videogpt
from videogpt.data import preprocess

VIDEOS = {
    'breakdancing': '1OZBnG235-J9LgB_qHv-waHZ4tjofiDgj',
    'bear': '16nIaqq2vbPh-WMo_7hs9feVSe0jWVXLF',
    'jaywalking': '1UxKCVrbyXhvMz_H7dI4w5hjPpRGCAApy',
    'cartoon': '1ONcTMSEuGuLYIDbX-KeFqd390vbTIH9d'
}

ROOT = 'pretrained_models'

## Downloading a Pretrained VQ-VAE
There are four pretrained models available: `bair_stride4x2x2`, `ucf101_stride4x4x4`, `kinetics_stride4x4x4`, and `kinetics_stride2x4x4`. BAIR was trained on 64 x 64 video, and the rest on 128 x 128. The `stride` component represents the THW downsampling the VQ-VAE performs on the video tensor.

In [None]:
%reload_ext autoreload
from videogpt.vqvae import VQVAE
# This notebook assumes a GPU runtime; there is no CPU fallback here.
device = torch.device('cuda')
# Alternative (disabled): load a named pretrained model through the helper.
# vqvae = load_vqvae('kinetics_stride2x4x4', device=device, root=ROOT).to(device)
# vqvae = load_vqvae('ucf101_stride4x4x4', device=device, root=ROOT).to(device)

# Download VQ-VAE
# `download` takes an id (presumably a Google Drive file id -- confirm) and a
# target file name; its return value is used as the checkpoint path below.
filepath = download("1FNWJtWDTX5CcVSSlINK1ZFFHuBgjBZfB", "ucf101_stride4x4x4")
vqvae = VQVAE.load_from_checkpoint(filepath).to(device)

  rank_zero_warn(
INFO:pytorch_lightning.utilities.migration.utils:Lightning automatically upgraded your loaded checkpoint from v1.1.1 to v1.9.5. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file ../../root/.cache/videogpt/ucf101_stride4x4x4`


## Video Loading and Preprocessing
The code below downloads, loads, and preprocesses a given `mp4` file.

In [None]:
%reload_ext autoreload
video_name = 'jaywalking'
# `resolution` must be divisible by the encoder image stride
# `sequence_length` must be divisible by the encoder temporal stride
resolution, sequence_length = vqvae.args.resolution, 16

video_filename = download(VIDEOS[video_name], f'{video_name}.mp4')
pts = read_video_timestamps(video_filename, pts_unit='sec')[0]
video = read_video(video_filename, pts_unit='sec', start_pts=pts[0], end_pts=pts[sequence_length - 1])[0]
video = preprocess(video, resolution, sequence_length).unsqueeze(0).to(device)

Access denied with the following error:



 	Too many users have viewed or downloaded this file recently. Please
	try accessing the file again later. If the file you are trying to
	access is particularly large or is shared with many people, it may
	take up to 24 hours to be able to view or download the file. If you
	still can't access a file after 24 hours, contact your domain
	administrator. 

You may still be able to access the file from the browser:

	 https://drive.google.com/uc?id=1UxKCVrbyXhvMz_H7dI4w5hjPpRGCAApy 



IndexError: ignored

## VQ-VAE Encoding and Decoding
Now we can encode the video with the `encode` function. `encode` also accepts an optional argument `include_embeddings` (default `False`); when it is set, the function additionally returns the codebook embeddings that correspond to the encodings.

In [None]:
# Inference only: gradients are not needed for encoding/decoding.
with torch.no_grad():
    encodings = vqvae.encode(video)
    video_recon = vqvae.decode(encodings)
    # Keep the reconstruction inside [-0.5, 0.5]; the visualization code
    # shifts by 0.5 and scales by 255 before casting to uint8.
    video_recon = torch.clamp(video_recon, -0.5, 0.5)

NameError: ignored

## Visualizing Reconstructions

In [None]:
# Put the input clip and its reconstruction side by side along the width
# axis, take the single batch item, and turn it into uint8 frames.
videos = torch.cat((video, video_recon), dim=-1)[0]
videos = videos.permute(1, 2, 3, 0)  # CTHW -> THWC
videos = ((videos + 0.5) * 255).cpu().numpy().astype('uint8')

# Draw the first frame once; the animation only swaps the image data.
fig = plt.figure()
plt.title('real (left), reconstruction (right)')
plt.axis('off')
im = plt.imshow(videos[0])
plt.close()  # avoid showing the static figure in addition to the animation


def init():
    """Reset the image artist to the first frame."""
    im.set_data(videos[0])


def animate(i):
    """Display frame `i` and return the updated image artist."""
    im.set_data(videos[i])
    return im


n_frames = videos.shape[0]
anim = animation.FuncAnimation(fig, animate, init_func=init, frames=n_frames, interval=200)  # 200ms = 5 fps
HTML(anim.to_html5_video())

NameError: ignored

# Using Pretrained VideoGPT Models

The model currently available for download is the unconditional UCF-101 model, `ucf101_uncond_gpt`.

In [None]:
%reload_ext autoreload
from videogpt.gpt import VideoGPT
from videogpt import download, load_vqvae, load_videogpt
device = torch.device('cuda')
# Download the unconditional UCF-101 VideoGPT checkpoint; the returned path is
# passed to `load_from_checkpoint` to restore the model on the GPU.
filepath = download("1c4CYL1joN5KDC5VYJIilFYWcDOmjWtgE", "ucf101_uncond_gpt")
gpt = VideoGPT.load_from_checkpoint(filepath).to(device)
# Switch to evaluation mode before sampling.
gpt.eval()
# Alternative (disabled): load a named pretrained model through the helper.
# gpt = load_videogpt('ucf101_uncond_gpt', device=device).to(device)
# gpt = load_videogpt('bair_gpt', device=device).to(device)

The `VideoGPT.sample` method returns generated samples of shape BCTHW with values in the range [0, 1].

In [None]:
!sudo apt-get install llvm-9-dev

In [None]:
%cd VideoGPT

In [None]:
samples = gpt.sample(16) # unconditional model does not require batch input

100%|██████████| 4096/4096 [02:13<00:00, 30.70it/s]


In [None]:
import math
import numpy as np

# `samples` is BCTHW in [0, 1]. Convert to uint8 BTHWC under a NEW name so
# that re-running this cell does not call `.permute` on a NumPy array.
b, c, t, h, w = samples.shape
samples_uint8 = (samples.permute(0, 2, 3, 4, 1).cpu().numpy() * 255).astype('uint8')

# Tile the samples on a near-square grid (4 x 4 for the 16 samples drawn
# above), leaving a 1-pixel black gap after each tile.
n_cols = math.ceil(math.sqrt(b))
n_rows = math.ceil(b / n_cols)
video = np.zeros((t, (1 + h) * n_rows + 1, (1 + w) * n_cols + 1, c), dtype='uint8')
for i in range(b):
    # `row`/`col` (not `r`/`c`) so the channel count `c` is not shadowed.
    row, col = divmod(i, n_cols)
    start_r, start_c = (1 + h) * row, (1 + w) * col
    video[:, start_r:start_r + h, start_c:start_c + w] = samples_uint8[i]

# Draw the first frame once; the animation only swaps the image data.
fig = plt.figure()
plt.title('ucf101 unconditional samples')
plt.axis('off')
im = plt.imshow(video[0, :, :, :])
plt.close()  # avoid showing the static figure in addition to the animation

def init():
    """Reset the image artist to the first frame of the grid video."""
    im.set_data(video[0, :, :, :])

def animate(i):
    """Display frame `i` of the grid video and return the image artist."""
    im.set_data(video[i, :, :, :])
    return im

anim = animation.FuncAnimation(fig, animate, init_func=init, frames=video.shape[0], interval=200) # 200ms = 5 fps
HTML(anim.to_html5_video())

# Computing FVD on UCF101

In [None]:
# NOTE: the bare `!git clone` that was here had no repository URL and always
# failed; the repository is cloned in the next cell.

In [None]:
# !pip install git+https://github.com/amaruescalante/VideoGPT.git@ff13f8b43b316086fa04d3adf468d187ceecac76
# !pip install git+https://github.com/amaruescalante/VideoGPT.git
!git clone https://github.com/amaruescalante/VideoGPT.git
%cd VideoGPT

Cloning into 'VideoGPT'...
remote: Enumerating objects: 398, done.[K
remote: Counting objects: 100% (133/133), done.[K
remote: Compressing objects: 100% (87/87), done.[K
remote: Total 398 (delta 62), reused 81 (delta 46), pack-reused 265[K
Receiving objects: 100% (398/398), 4.05 MiB | 20.32 MiB/s, done.
Resolving deltas: 100% (219/219), done.
/content/VideoGPT


In [None]:
!pip install -r requirements.txt

In [None]:
from videogpt.download import load_i3d_pretrained
from videogpt.fvd.fvd import get_fvd_logits, frechet_distance

In [None]:
# Load the pretrained I3D network used for FVD features onto the GPU.
device = torch.device('cuda')
i3d = load_i3d_pretrained(device=device)

In [None]:
%reload_ext autoreload

In [None]:
# `gpt.hparams['args']` is an argparse.Namespace (see the printed type).
# Assigning to it mutates the namespace held by the loaded model in place;
# note that `main` below sets `batch_size` again before building its loader.
hparams = gpt.hparams['args']
print("hparams", type(hparams))
hparams.batch_size = 32

hparams <class 'argparse.Namespace'>


In [None]:
import gc

# Collect unreachable Python objects first so that any CUDA tensors they still
# hold are freed, and only then release the now-unused cached blocks.
gc.collect()
torch.cuda.empty_cache()

  0%|          | 0/10 [01:14<?, ?it/s]


19095

In [None]:
import os
from videogpt.download import load_i3d_pretrained
# from tqdm import tqdm # this is for script version
from tqdm.notebook import tqdm  # Use this version of tqdm for Jupyter notebooks
import numpy as np

import torch
import torch.distributed as dist

from videogpt.fvd.fvd import get_fvd_logits, frechet_distance
from videogpt import VideoData, VideoGPT, load_videogpt

MAX_BATCH = 4

def main(ckpt='bair_gpt', n_trials=1, port=23452):
    """Estimate FVD and FVD* for the VideoGPT model already loaded in the notebook.

    Calls `eval_fvd` `n_trials` times, shows the running mean and standard
    deviation of both metrics in the progress-bar description, and prints the
    final statistics.

    Args:
        ckpt: Currently unused. The model is taken from the notebook-level
            `gpt` variable; the checkpoint-loading code is disabled.
        n_trials: Number of evaluation rounds; must be at least 1.
        port: Currently unused (presumably a leftover from a distributed
            version of this script -- confirm before removing).

    Raises:
        ValueError: If `n_trials` is smaller than 1. Previously this case
            surfaced as a NameError on the final print.
    """
    if n_trials < 1:
        raise ValueError(f"n_trials must be >= 1, got {n_trials}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Scoped `no_grad` instead of `torch.set_grad_enabled(False)`: autograd is
    # switched back on when this function returns or raises, so later cells
    # in the same kernel are not silently left without gradients.
    with torch.no_grad():
        #################### Load VideoGPT ########################################
        # if not os.path.exists(ckpt):
            # gpt = load_videogpt(ckpt, device=device)
        # else:
            # gpt = VideoGPT.load_from_checkpoint(ckpt).to(device)
        gpt.eval()
        hparams = gpt.hparams['args']
        # Every trial evaluates a single batch of this size (see `eval_fvd`).
        batch_size = 4
        hparams.batch_size = batch_size  # mutates the model's stored args in place
        loader = VideoData(hparams).test_dataloader()

        #################### Load I3D ########################################
        i3d = load_i3d_pretrained(device)

        #################### Compute FVD ###############################
        fvds = []
        fvds_star = []
        # Context manager guarantees the bar is closed even if a trial fails.
        with tqdm(total=n_trials) as pbar:
            for _ in range(n_trials):
                fvd, fvd_star = eval_fvd(i3d, gpt, loader, device)
                fvds.append(fvd)
                fvds_star.append(fvd_star)

                pbar.update(1)
                fvd_mean = np.mean(fvds)
                fvd_std = np.std(fvds)

                fvd_star_mean = np.mean(fvds_star)
                fvd_star_std = np.std(fvds_star)

                pbar.set_description(f"FVD {fvd_mean:.2f} +/- {fvd_std:.2f}, FVD* {fvd_star_mean:.2f} +/- {fvd_star_std:.2f}")
    print(f"Final FVD {fvd_mean:.2f} +/- {fvd_std:.2f}, FVD* {fvd_star_mean:.2f} +/- {fvd_star_std:.2f}")

def all_gather(tensor):
    """Gather `tensor` from every process and concatenate the results.

    Requires an initialized `torch.distributed` process group. The sequential
    code in this notebook does not call it.

    Args:
        tensor: Tensor contributed by the calling process.

    Returns:
        The gathered tensors (one per process) concatenated along dim 0.
    """
    world_size = dist.get_world_size()
    # One receive buffer per process, shaped like the local tensor.
    gathered = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.all_gather(gathered, tensor)
    return torch.cat(gathered)


def eval_fvd(i3d, videogpt, loader, device):
    """Compute FVD and FVD* on the first batch drawn from `loader`.

    FVD compares generated samples against the real videos; FVD* compares
    them against the model's reconstructions of those videos. Sampling and
    reconstruction are done in chunks of at most `MAX_BATCH` videos.

    Args:
        i3d: Feature network handed to `get_fvd_logits`.
        videogpt: Model providing `sample` and `get_reconstruction`.
        loader: Data loader yielding dict batches that contain a 'video' entry.
        device: Device the batch and the feature extraction run on.

    Returns:
        Tuple `(fvd, fvd_star)` of Python floats.
    """
    # rank, size = dist.get_rank(), dist.get_world_size()  # Removed distributed parts
    # is_root = rank == 0  # Not needed in sequential execution

    batch = next(iter(loader))
    batch = {k: v.to(device) for k, v in batch.items()}
    real = batch['video']  # already on `device` after the line above
    n_videos = real.shape[0]

    fake_embeddings = []
    for i in range(0, n_videos, MAX_BATCH):
        chunk = {k: v[i:i+MAX_BATCH] for k, v in batch.items()}
        # Sample exactly as many videos as this chunk holds. Requesting
        # MAX_BATCH for a shorter final chunk would break the equal-count
        # assertion below.
        fake = videogpt.sample(chunk['video'].shape[0], chunk)
        fake = torch.repeat_interleave(fake, 4, dim=2) # TODO: check correctness
        fake = fake.permute(0, 2, 3, 4, 1).cpu().numpy() # BCTHW -> BTHWC
        fake = (fake * 255).astype('uint8')
        fake_embeddings.append(get_fvd_logits(fake, i3d=i3d, device=device))
    fake_embeddings = torch.cat(fake_embeddings)

    real_recon_embeddings = []
    for i in range(0, n_videos, MAX_BATCH):
        # Reconstructions are shifted by 0.5 and clamped into [0, 1].
        real_recon = (videogpt.get_reconstruction(real[i:i+MAX_BATCH]) + 0.5).clamp(0, 1)
        real_recon = torch.repeat_interleave(real_recon, 4, dim=2)
        real_recon = real_recon.permute(0, 2, 3, 4, 1).cpu().numpy()
        real_recon = (real_recon * 255).astype('uint8')
        real_recon_embeddings.append(get_fvd_logits(real_recon, i3d=i3d, device=device))
    real_recon_embeddings = torch.cat(real_recon_embeddings)

    # NOTE(review): unlike `fake` and `real_recon`, the real videos are not
    # repeated 4x along the time axis before feature extraction -- confirm
    # this asymmetry is intended (see the TODO above).
    real = real + 0.5
    real = real.permute(0, 2, 3, 4, 1).cpu().numpy() # BCTHW -> BTHWC
    real = (real * 255).astype('uint8')
    real_embeddings = get_fvd_logits(real, i3d=i3d, device=device)

    # fake_embeddings = all_gather(fake_embeddings)  # Not needed in sequential execution
    # real_recon_embeddings = all_gather(real_recon_embeddings)  # Not needed in sequential execution
    # real_embeddings = all_gather(real_embeddings)  # Not needed in sequential execution

    # Ensure that fake_embeddings and real_embeddings have the same number of items
    assert fake_embeddings.shape[0] == real_recon_embeddings.shape[0] == real_embeddings.shape[0]

    # `.clone()` keeps each distance computation from touching the shared
    # fake-embedding tensor (presumably frechet_distance may modify its input).
    fvd = frechet_distance(fake_embeddings.clone(), real_embeddings)
    fvd_star = frechet_distance(fake_embeddings.clone(), real_recon_embeddings)
    return fvd.item(), fvd_star.item()


In [None]:
# `ckpt` is only a label here: `main` evaluates the `gpt` model that is
# already loaded in this notebook (the unconditional UCF-101 checkpoint).
main(ckpt='ucf101_uncond_gpt', n_trials=10, port=12345)

  0%|          | 0/10 [00:00<?, ?it/s]