<a href="https://colab.research.google.com/github/alberto-paparella/nrr95p/blob/main/vdm_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Video Diffusion Model Training

This notebook experiments with the extracted dataset, with the goal of training a video diffusion model that, given the first few frames of a gif showing how nrr95p evolves over several days, predicts the following frames (i.e., the likely evolution of the data over the next few days).
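
To make the prediction task concrete, the sketch below splits a clip into the frames that are observed and the frames to be predicted. It assumes the (channels, frames, height, width) layout used later in this notebook; `split_context_future` and `num_context` are hypothetical names, and NumPy only provides a stand-in array (the same slicing applies to a torch tensor).

```python
import numpy as np

def split_context_future(video, num_context):
    """Split a (channels, frames, height, width) clip along the frame axis."""
    num_frames = video.shape[1]
    if not 0 < num_context < num_frames:
        raise ValueError("num_context must be between 1 and num_frames - 1")
    context = video[:, :num_context]   # frames shown to the model
    future = video[:, num_context:]    # frames the model should predict
    return context, future

# Stand-in clip: 3 channels, 10 frames, 48x64 pixels (sizes chosen for illustration)
clip = np.zeros((3, 10, 48, 64), dtype=np.float32)
context, future = split_context_future(clip, num_context=4)
print(context.shape, future.shape)  # (3, 4, 48, 64) (3, 6, 48, 64)
```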

In the first experiment, a video diffusion model is trained with 2018 data; then, the model is evaluated on 2019 data.
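
One way to keep the two years apart is to group the gif folders by the year prefix in their name. This is only a sketch: `split_by_year` is a hypothetical helper, and `2019_gifs_a` is an assumed folder name that mirrors the 2018 folders used later in this notebook.

```python
from pathlib import Path

def split_by_year(gif_dirs, train_year="2018", eval_year="2019"):
    """Group directories by the year prefix of their name (e.g. '2018_gifs_a')."""
    train, evaluation = [], []
    for d in map(Path, gif_dirs):
        year = d.name.split("_", 1)[0]
        if year == train_year:
            train.append(d)
        elif year == eval_year:
            evaluation.append(d)
    return train, evaluation

# '2019_gifs_a' is a hypothetical directory name mirroring the 2018 ones
dirs = ["nrr95p/2018_gifs_a", "nrr95p/2018_gifs_b", "nrr95p/2019_gifs_a"]
train_dirs, eval_dirs = split_by_year(dirs)
print([d.name for d in train_dirs])  # ['2018_gifs_a', '2018_gifs_b']
print([d.name for d in eval_dirs])   # ['2019_gifs_a']
```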

## Initialization

Let's start by installing [video_diffusion_pytorch](https://github.com/lucidrains/video-diffusion-pytorch), a PyTorch implementation of video diffusion models developed by [Phil Wang](https://github.com/lucidrains).

In [None]:
%pip install video-diffusion-pytorch

## Setup on Google Colab

To connect Google Drive (GDrive) with Colab, execute the following two lines of code in Colab:

In [3]:
from google.colab import drive
drive.mount("/content/gdrive")

Mounted at /content/gdrive


Running the cell above may display a URL and ask for an authorization code. If it does, follow the link, sign in to your Google account, and copy the authorization code. Paste the code into the prompt, and Google Drive will be mounted at /content/gdrive. Note that the files in the drive are under the folder /content/gdrive/MyDrive/. We can now read files stored in GDrive, for example with a library like Pandas.
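
Before loading anything, it can help to check that the expected folder is actually visible under the mount point. A minimal sketch using only the standard library; `list_gifs` is a hypothetical helper, not part of Colab or of the installed package.

```python
from pathlib import Path

def list_gifs(folder):
    """Return the .gif files in `folder`, sorted by name; fail early if it is missing."""
    folder = Path(folder)
    if not folder.is_dir():
        raise FileNotFoundError(f"{folder} not found; is Google Drive mounted?")
    return sorted(
        p for p in folder.iterdir()
        if p.is_file() and p.suffix.lower() == ".gif"
    )
```

In Colab, this could be called as `list_gifs("/content/gdrive/MyDrive/nrr95p/2018_gifs_a")` once the drive is mounted.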

## Import 2018 gifs

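The cell below reads the 2018 gifs with imageio, which returns each gif as an array of shape (num_frames, height, width, channels). However, the installed implementation requires a different structure, which is built in the next section.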

In [None]:
import imageio.v3 as iio
from pathlib import Path

videos = []
# The 2018 gifs are split across three folders
for folder in ("2018_gifs_a", "2018_gifs_b", "2018_gifs_c"):
    for video in Path("nrr95p", folder).iterdir():
        if not video.is_file():
            continue
        # index=None means: read all images in the file and stack along first axis
        videos.append(iio.imread(video, index=None))

# Each entry is an ndarray with shape (num_frames, height, width, channels)
print(videos[0].shape)
print(len(videos))
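
The layout mismatch can be illustrated with a small conversion. This is a sketch under assumptions: it treats each gif as a `uint8` array of shape (num_frames, height, width, channels), which is what imageio typically returns, and converts it to a (channels, frames, height, width) float array in [0, 1]. `to_channels_first` is a hypothetical name, and a random array stands in for a real gif.

```python
import numpy as np

def to_channels_first(frames):
    """(frames, height, width, channels) uint8 -> (channels, frames, height, width) float32 in [0, 1]."""
    if frames.ndim != 4:
        raise ValueError(f"expected a 4-D array, got shape {frames.shape}")
    return np.transpose(frames, (3, 0, 1, 2)).astype(np.float32) / 255.0

rng = np.random.default_rng(0)
fake_gif = rng.integers(0, 256, size=(10, 48, 64, 3), dtype=np.uint8)  # stand-in data
video = to_channels_first(fake_gif)
print(video.shape)  # (3, 10, 48, 64)
```

The result could then be wrapped with `torch.from_numpy(video)` to obtain a tensor.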

## From gifs to tensors

Define a function that converts a gif to a tensor; it takes the gif path as an argument and returns a tensor of shape (channels, frames, height, width).

In [4]:
import torch
from torchvision import transforms as T
from PIL import Image, ImageSequence

CHANNELS_TO_MODE = {
    1 : 'L',
    3 : 'RGB',
    4 : 'RGBA'
}

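def seek_all_images(img, channels = 3):
    # Yield every frame of a multi-frame image, converted to the requested mode
    assert channels in CHANNELS_TO_MODE, f'channels {channels} invalid'
    mode = CHANNELS_TO_MODE[channels]

    i = 0
    while True:
        try:
            img.seek(i)
            yield img.convert(mode)
        except EOFError:
            # seek() raises EOFError past the last frame
            break
        i += 1

def gif_to_tensor(path, channels = 3, transform = T.ToTensor()):
    # Each frame becomes a (channels, height, width) tensor; stacking on dim 1
    # gives (channels, frames, height, width)
    img = Image.open(path)
    tensors = tuple(map(transform, seek_all_images(img, channels = channels)))
    return torch.stack(tensors, dim = 1)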


In [7]:
from pathlib import Path

gifs = []
for gif in Path("/content/gdrive/MyDrive/nrr95p/2018_gifs_a").iterdir():
    if not gif.is_file():
        continue
    gifs.append(gif_to_tensor(gif))

videos = torch.stack(gifs)
print(videos.shape)

torch.Size([125, 3, 10, 480, 640])
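
The shape printed above also gives a rough idea of how much memory the stacked tensor needs. The sketch below does the arithmetic for `float32` values (4 bytes each, which is what `T.ToTensor()` produces); `tensor_megabytes` is a hypothetical helper, and the figures cover the input only, not the model or its activations.

```python
import math

def tensor_megabytes(shape, bytes_per_element=4):
    """Size in MB (10**6 bytes) of a dense tensor; 4 bytes per element for float32."""
    return math.prod(shape) * bytes_per_element / 1e6

batch_shape = (125, 3, 10, 480, 640)
print(tensor_megabytes(batch_shape))      # 4608.0  (all 125 gifs, about 4.6 GB)
print(tensor_megabytes(batch_shape[1:]))  # 36.864  (a single gif)
```

At roughly 4.6 GB for the inputs alone, feeding the model a few gifs at a time is likely to be more practical than passing the whole tensor at once.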


In [None]:
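# Resize a gif frame by frame and save the result
def thumbnails(frames, size = (640, 640)):
    # Wrap on-the-fly thumbnail generator
    for frame in frames:
        thumbnail = frame.copy()
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
        thumbnail.thumbnail(size, Image.LANCZOS)
        yield thumbnail

def resize_gif(src, dst, size = (640, 640)):
    img = Image.open(src)
    # Get sequence iterator
    frames = thumbnails(ImageSequence.Iterator(img), size)
    om = next(frames) # Handle first frame separately
    om.info = img.info # Copy sequence info
    om.save(dst, save_all=True, append_images=list(frames))

# Example: resize_gif(path_to_gif, "out.gif")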
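
Note that `Image.thumbnail` keeps the aspect ratio and never enlarges an image, so a 640x480 frame stays 640x480 after the cell above. If square frames are wanted, one option is to pad each frame instead. The sketch below is one possible approach, not the notebook's method: `pad_to_square` is a hypothetical helper and the black fill is an arbitrary choice.

```python
from PIL import Image

def pad_to_square(frame, fill=(0, 0, 0)):
    """Center `frame` on a square RGB canvas whose side is the longer edge."""
    side = max(frame.size)
    canvas = Image.new("RGB", (side, side), fill)
    offset = ((side - frame.width) // 2, (side - frame.height) // 2)
    canvas.paste(frame.convert("RGB"), offset)
    return canvas

frame = Image.new("RGB", (64, 48), (255, 0, 0))  # small stand-in for a 640x480 frame
square = pad_to_square(frame)
print(square.size)  # (64, 64)
```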

## Usage

In [None]:
#import torch
from video_diffusion_pytorch import Unet3D, GaussianDiffusion

model = Unet3D(
    dim = 64,
    dim_mults = (1, 2, 4, 8)
)

diffusion = GaussianDiffusion(
    model,
    image_size = 640,
    num_frames = 10,
    timesteps = 1000,   # number of steps
    loss_type = 'l1'    # L1 or L2
)

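# videos has shape (batch, channels, frames, height, width); the tensor built above is
# 480 x 640 per frame, so it may need resizing or padding to match image_size = 640
#videos = torch.randn(1, 3, 5, 32, 32) # video (batch, channels, frames, height, width) - normalized from -1 to +1
loss = diffusion(videos)
loss.backward()
# after a lot of training

sampled_videos = diffusion.sample(batch_size = 4)
sampled_videos.shape # (4, 3, 10, 640, 640) with the settings above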
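
The cell above computes gradients a single time; "a lot of training" means repeating that step over mini-batches with an optimizer. The following is a minimal sketch under stated assumptions: `train` is a hypothetical helper, the Adam optimizer, learning rate and batch size are arbitrary choices, and a tiny stand-in module replaces `diffusion` so the block runs without the library. It relies only on the call pattern shown above, where calling the object on a batch of videos returns a scalar loss.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

def train(loss_module, videos, epochs=1, batch_size=2, lr=1e-4):
    """Repeat forward/backward/step over mini-batches; returns the per-step losses."""
    loader = DataLoader(TensorDataset(videos), batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(loss_module.parameters(), lr=lr)
    history = []
    for _ in range(epochs):
        for (batch,) in loader:
            optimizer.zero_grad()
            loss = loss_module(batch)   # same call pattern as diffusion(videos)
            loss.backward()
            optimizer.step()
            history.append(loss.item())
    return history

class ToyLoss(nn.Module):
    """Stand-in for the diffusion object: maps a batch of videos to a scalar loss."""
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, x):
        return (self.scale * x).abs().mean()

toy_videos = torch.rand(6, 3, 4, 8, 8)  # (batch, channels, frames, height, width)
losses = train(ToyLoss(), toy_videos, epochs=2)
print(len(losses))  # 6 (3 mini-batches per epoch, 2 epochs)
```

With the real objects, the call would look like `train(diffusion, videos, ...)`, with a batch size small enough to fit in memory.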