In [None]:
!pip install -q wandb pytorch_lightning av

In [2]:
import os
import shutil

import numpy as np
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
from IPython.display import display, HTML
from matplotlib.animation import FuncAnimation
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.datasets.video_utils import VideoClips
from torchvision.transforms import Compose, Lambda, Resize, ToTensor, CenterCrop, Grayscale
from tqdm.notebook import tqdm

In [None]:
from google.colab import drive

# Mount Google Drive (Colab only) and locate the source video.
drive.mount('/content/drive')
GDRIVE_DIR = '/content/drive/My Drive/SteamboatWillie'
steamboat_willie_gdrive_path = os.path.join(GDRIVE_DIR, 'SteamboatWillie.mp4')

# If pre-processed clips were saved to Drive, copy them next to the notebook so
# SteamboatWillieDataset (dest_dir='./clips/') reuses them instead of re-encoding
# the video. Guarded so a missing cache is not an error: the dataset rebuilds it.
cached_clips_dir = os.path.join(GDRIVE_DIR, 'clips')
if os.path.isdir(cached_clips_dir):
    shutil.copytree(cached_clips_dir, './clips', dirs_exist_ok=True)

In [None]:
class SteamboatWillieDataset(Dataset):
    """
    Fixed-length grayscale clips of a video, cached on disk as raw float32 files.

    On first use the video(s) in ``paths`` are cut into non-overlapping clips of
    ``clip_length`` frames, preprocessed (grayscale, center crop, resize, pixel
    values scaled from [0, 255] to [0, 1]) and written to ``dest_dir`` as
    ``clip_<idx>.bin``. If ``dest_dir`` already holds such files they are reused
    and ``paths`` is not read.

    Each item is a float32 tensor of shape (1, frames, frame_size, frame_size).
    NOTE: a cache built with a different frame_size is not detected; delete
    ``dest_dir`` when changing it.

    Args:
        paths: video file paths passed to torchvision's VideoClips.
        clip_length: frames per clip (also the stride between clips).
        dest_dir: directory holding the cached ``.bin`` files.
        frame_size: side length every frame is resized to.
    """

    # Center crop applied before resizing, to remove the vertical bars.
    CROP_SIZE = (480, 575)

    def __init__(self, paths, clip_length=16, dest_dir='./clips/', frame_size=256):
        super().__init__()
        self.frame_size = frame_size

        # Only trust the cache if it actually contains clip files; an empty or
        # freshly-created directory would otherwise yield an empty dataset.
        clip_paths = self.build_existing_clip_paths(dest_dir) if os.path.isdir(dest_dir) else {}
        if not clip_paths:
            video_clips = VideoClips(
                paths,
                clip_length_in_frames=clip_length,
                frames_between_clips=clip_length
            )

            transforms = Compose([
                Lambda(lambda x: x.permute(0, 3, 1, 2)),  # (T, H, W, C) to (T, C, H, W) for Grayscale
                Grayscale(num_output_channels=1),         # Convert to grayscale
                Lambda(lambda x: x.permute(1, 0, 2, 3)),  # (T, C, H, W) to (C, T, H, W) for Conv3d
                CenterCrop(self.CROP_SIZE),               # Center crop to remove vertical bars
                Resize((frame_size, frame_size)),         # Resize frames
                Lambda(lambda x: x / 255.),               # Scale pixel values to [0, 1]
            ])

            clip_paths = self.build_clip_paths(video_clips, transforms, dest_dir)

        # Re-key to 0..n-1 so every index in range(len(self)) (e.g. from
        # random_split) maps to a clip even if some clip_<idx>.bin files are missing.
        self.clips = self._reindex(self.build_clip_refs(clip_paths))

    @staticmethod
    def _reindex(clips):
        """Return ``clips`` re-keyed to consecutive ints 0..n-1, ordered by original key."""
        return {pos: clips[key] for pos, key in enumerate(sorted(clips))}

    def build_clip_paths(self, video_clips, transforms, dest_dir):
        """
        Preprocess every clip in ``video_clips`` and write it to ``dest_dir``.

        Each transformed clip is stored as raw float32 values (no shape header)
        in ``clip_<idx>.bin``; __getitem__ restores the shape.
        returns dict of clip_idx -> mmapped file path
        """
        os.makedirs(dest_dir, exist_ok=True)

        clip_paths = {}
        for idx in tqdm(range(video_clips.num_clips()), desc='Creating clip .bin files'):
            # transform clip and write it to its own mmap file
            clip, _, _, _ = video_clips.get_clip(idx)
            clip_np = transforms(clip).numpy()

            mmapped_file_path = os.path.join(dest_dir, f'clip_{idx}.bin')
            fp = np.memmap(mmapped_file_path, dtype='float32', mode='w+', shape=clip_np.shape)
            fp[:] = clip_np[:]
            fp.flush()
            del fp  # drop the writable mapping before moving on
            clip_paths[idx] = mmapped_file_path

        return clip_paths

    def build_existing_clip_paths(self, dest_dir):
        """
        Scan ``dest_dir`` for previously cached clips.

        returns dict of clip_idx -> mmapped file path for every file named
        ``clip_<integer>.bin``; any other file is ignored.
        """
        prefix, suffix = 'clip_', '.bin'
        clip_paths = {}
        for filename in os.listdir(dest_dir):
            if not (filename.startswith(prefix) and filename.endswith(suffix)):
                continue
            stem = filename[len(prefix):-len(suffix)]
            # Skip malformed names (e.g. 'clip_.bin', 'clip_x.bin') instead of
            # crashing in int().
            if stem.isdecimal():
                clip_paths[int(stem)] = os.path.join(dest_dir, filename)

        return clip_paths

    def build_clip_refs(self, clip_paths):
        """
        Build read-only references to the mmap files.

        returns dict of clip_idx -> flat float32 np.memmap of the respective file
        """
        clips = {}
        for idx, path in tqdm(clip_paths.items(), desc='Building clip refs'):
            clips[idx] = np.memmap(path, dtype='float32', mode='r')

        return clips

    def __len__(self):
        return len(self.clips)

    def __getitem__(self, idx):
        # Files are flat; restore (channels=1, frames, H, W), inferring frames.
        frame_size = self.frame_size
        return torch.tensor(self.clips[idx], dtype=torch.float32).view(1, -1, frame_size, frame_size)


class SteamboatWillieDataModule(pl.LightningDataModule):
    """
    LightningDataModule wrapping SteamboatWillieDataset with a train/val split.

    Args:
        paths: video paths forwarded to SteamboatWillieDataset.
        batch_size: batch size for both dataloaders.
        train_split: fraction of clips used for training; the rest is validation.
        seed: seed for the train/val split so it is identical on every run.
    """

    def __init__(self, paths, batch_size=8, train_split=0.8, seed=42):
        super().__init__()
        self.paths = paths
        self.batch_size = batch_size
        self.train_split = train_split
        self.seed = seed
        self.full_dataset = None

    def prepare_data(self):
        # Materialises the on-disk clip cache. Lightning runs this on a single
        # process, so setup() rebuilds the dataset itself when this attribute is
        # missing rather than depending on it.
        self.full_dataset = SteamboatWillieDataset(self.paths)

    def setup(self, stage=None):
        if stage not in ('fit', 'validate', None):
            return
        if self.full_dataset is None:
            # e.g. other processes in distributed runs, or setup() without prepare_data().
            self.full_dataset = SteamboatWillieDataset(self.paths)

        train_len = int(len(self.full_dataset) * self.train_split)
        val_len = len(self.full_dataset) - train_len
        # Seeded generator: the same clips land in train/val across kernel restarts.
        generator = torch.Generator().manual_seed(self.seed)
        self.train_dataset, self.val_dataset = random_split(
            self.full_dataset, [train_len, val_len], generator=generator
        )

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=0)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=0)

In [None]:
# Build (or reuse from ./clips/) the clip dataset, split it, and get both loaders.
data_module = SteamboatWillieDataModule(paths=[steamboat_willie_gdrive_path])
data_module.prepare_data()
data_module.setup(stage=None)

train_loader, val_loader = data_module.train_dataloader(), data_module.val_dataloader()

In [None]:
# One training batch for the shape sanity checks below.
train_batch = next(iter(train_loader), None)
if train_batch is None:
    raise RuntimeError('train_loader yielded no batches; check that clips were created/loaded')
# Dataset items are (1, T, H, W), so a batch is (B, 1, T, H, W).
assert train_batch.ndim == 5 and train_batch.shape[1] == 1, tuple(train_batch.shape)

# Model

In [None]:
class ResBlock(nn.Module):
    """
    3D residual block: Conv3d(k=3, pad=1) -> BatchNorm -> ReLU -> Conv3d(k=1) -> BatchNorm,
    added to a skip path and passed through a final ReLU.

    The skip path is a 1x1x1 Conv3d when the channel count changes, otherwise an
    identity. Spatial/temporal size is preserved.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Registered first so parameter order and state_dict keys stay the same.
        self.identity = (
            nn.Identity()
            if in_channels == out_channels
            else nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
        )

        mix_conv = nn.Conv3d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, padding=1)
        mix_norm = nn.BatchNorm3d(in_channels)
        proj_conv = nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
        proj_norm = nn.BatchNorm3d(out_channels)
        # Index layout (0 conv, 1 bn, 2 relu, 3 conv, 4 bn) matches the state_dict keys.
        self.block = nn.Sequential(mix_conv, mix_norm, nn.ReLU(), proj_conv, proj_norm)
        self.res_act = nn.ReLU()

    def forward(self, x):
        branch = self.block(x)
        return self.res_act(branch + self.identity(x))

In [None]:
class Encoder(nn.Module):
    """
    3D-conv encoder: ``nlayers`` downsampling stages followed by ``nblocks`` ResBlocks.

    Each stage is Conv3d(k=4, stride=2, pad=1) -> BatchNorm3d -> ReLU, which
    halves (floor) T, H and W. Only the last ResBlock maps to ``out_channels``.
    """

    def __init__(self, in_channels=1, hidden_channels=32, out_channels=16, nlayers=4, nblocks=1):
        super().__init__()
        stages = []
        for layer_idx in range(nlayers):
            conv = nn.Conv3d(
                in_channels=hidden_channels if layer_idx else in_channels,
                out_channels=hidden_channels,
                kernel_size=(4, 4, 4),
                stride=(2, 2, 2),
                padding=(1, 1, 1)
            )
            stages.append(nn.Sequential(conv, nn.BatchNorm3d(hidden_channels), nn.ReLU()))
        self.downsample_blocks = nn.Sequential(*stages)

        res_stack = []
        for block_idx in range(nblocks):
            is_last = block_idx == nblocks - 1
            res_stack.append(ResBlock(
                in_channels=hidden_channels,
                out_channels=out_channels if is_last else hidden_channels
            ))
        self.res_blocks = nn.Sequential(*res_stack)

    def forward(self, x):
        return self.res_blocks(self.downsample_blocks(x))

In [None]:
class Decoder(nn.Module):
    """
    Decoder counterpart to Encoder: ``nblocks`` ResBlocks, then ``nlayers`` upsampling stages.

    The first ``nlayers - 1`` stages are ConvTranspose3d(k=4, stride=2, pad=1) ->
    BatchNorm3d -> ReLU; the final stage is a bare ConvTranspose3d to
    ``out_channels`` with no normalisation or activation. Every stage doubles
    T, H and W.
    """

    def __init__(self, in_channels=16, hidden_channels=32, out_channels=1, nlayers=4, nblocks=1):
        super().__init__()
        res_stack = []
        for block_idx in range(nblocks):
            res_stack.append(ResBlock(
                in_channels=hidden_channels if block_idx else in_channels,
                out_channels=hidden_channels,
            ))
        self.res_blocks = nn.Sequential(*res_stack)

        def upsample(channels_out):
            return nn.ConvTranspose3d(
                in_channels=hidden_channels,
                out_channels=channels_out,
                kernel_size=(4, 4, 4),
                stride=(2, 2, 2),
                padding=(1, 1, 1)
            )

        self.upsample_blocks = nn.Sequential(*[
            nn.Sequential(upsample(hidden_channels), nn.BatchNorm3d(hidden_channels), nn.ReLU())
            for _ in range(nlayers - 1)
        ])
        self.out_layer = upsample(out_channels)

    def forward(self, x):
        for stage in (self.res_blocks, self.upsample_blocks, self.out_layer):
            x = stage(x)
        return x

In [None]:
# Defaults spelled out: grayscale input, 4 stride-2 stages, 16-channel latent.
encoder = Encoder(in_channels=1, hidden_channels=32, out_channels=16, nlayers=4, nblocks=1)

In [None]:
# Shape check only: no_grad skips building an autograd graph for this pass.
# (BatchNorm layers are still in train mode, so their running stats update.)
with torch.no_grad():
    sample_enc = encoder(train_batch)
# Each stride-2 stage floor-halves T, H and W.
downsample = 2 ** len(encoder.downsample_blocks)
assert sample_enc.shape[0] == train_batch.shape[0]
assert tuple(sample_enc.shape[2:]) == tuple(d // downsample for d in train_batch.shape[2:]), sample_enc.shape
sample_enc.shape

In [None]:
decoder = Decoder()
# no_grad also keeps recon_clip a plain tensor, so it can be passed to display_clip.
with torch.no_grad():
    recon_clip = decoder(sample_enc)
# Decoder() undoes the default Encoder()'s 16x downsampling, so shapes must round-trip.
assert recon_clip.shape == train_batch.shape, (recon_clip.shape, train_batch.shape)
recon_clip.shape

# Display Clip

In [None]:
def display_clip(clip, interval=50, vmin=0.0, vmax=1.0):
    """
    Render a clip inline as an HTML5 video.

    Args:
        clip: tensor laid out as (C, T, H, W), e.g. one SteamboatWillieDataset
            item or one sample of a decoder output. A single channel is shown in
            grayscale. Tensors that require grad or live on a GPU are accepted.
        interval: delay between frames in milliseconds.
        vmin, vmax: fixed intensity range used for every frame, so brightness
            doesn't flicker from per-frame autoscaling. Defaults match the
            dataset's [0, 1] scaling; pass None to take the range from frame 0.

    Raises:
        ValueError: if the clip has no frames.
    """
    # detach()/cpu() so model outputs can be shown without a .numpy() error.
    frames = clip.detach().cpu().permute(1, 2, 3, 0).numpy()  # (T, H, W, C)
    if frames.shape[0] == 0:
        raise ValueError('clip has no frames')
    if frames.shape[-1] == 1:
        frames = frames[..., 0]  # (T, H, W): single channel goes through the colormap

    fig, ax = plt.subplots()
    ax.axis('off')
    # Draw once, then only swap pixel data per frame (no clear + re-imshow).
    image = ax.imshow(frames[0], cmap='gray', vmin=vmin, vmax=vmax)

    def update(frame_idx):
        image.set_data(frames[frame_idx])
        return (image,)

    ani = FuncAnimation(fig, update, frames=range(frames.shape[0]), interval=interval)
    plt.close(fig)  # keep the inline backend from also showing the static figure
    display(HTML(ani.to_html5_video()))

In [None]:
# Preview the last sample of the batch (index -1 is valid for any batch size).
# Items are already (1, T, H, W), so no reshape is needed.
clip = train_batch[-1]
display_clip(clip)