In [None]:
!pip install -q wandb pytorch_lightning av

In [2]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from torchvision.datasets.video_utils import VideoClips
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.transforms import Compose, Lambda, Resize, ToTensor, CenterCrop, Grayscale

In [None]:
from google.colab import drive
# Mount Google Drive under /content/drive so files stored there can be read by path.
drive.mount('/content/drive')
# Path of the source video inside the mounted Drive.
steamboat_willie_gdrive_path = '/content/drive/My Drive/SteamboatWillie.mp4'

In [None]:
class SteamboatWillieDataset(Dataset):
    """Dataset of fixed-length, preprocessed grayscale clips cut from video files.

    On construction every clip is decoded, transformed and written to a raw
    float32 file under ``dest_dir``. Items are then served from read-only
    ``np.memmap`` views of those files.

    Args:
        paths: List of video file paths, passed straight to ``VideoClips``.
        clip_length: Number of frames per clip. The step between consecutive
            clips is also ``clip_length``.
        dest_dir: Directory that receives one ``clip_<idx>.bin`` file per clip.
    """

    def __init__(self, paths, clip_length=16, dest_dir='./clips/'):
        super().__init__()
        video_clips = VideoClips(
            paths,
            clip_length_in_frames=clip_length,
            frames_between_clips=clip_length
        )

        transforms = Compose([
            Lambda(lambda x: x.permute(0, 3, 1, 2)),  # (T, H, W, C) to (T, C, H, W) for Grayscale
            Grayscale(num_output_channels=1),         # Convert to grayscale
            Lambda(lambda x: x.permute(1, 0, 2, 3)),  # (T, C, H, W) to (C, T, H, W) for Conv3d
            CenterCrop((480, 575)),                   # Center crop to remove vertical bars
            Resize((256, 256)),                       # Resize frames
            Lambda(lambda x: x / 255.),               # Scale pixel values to [0, 1]
        ])

        self.clips = self.build_clip_refs(self.build_clip_paths(video_clips, transforms, dest_dir))

    def build_clip_paths(self, video_clips, transforms, dest_dir):
        """
        Transform every clip and store it as a raw float32 binary file.

        The raw files carry no header, so the shape of each written array is
        also recorded in ``self.clip_shapes`` (clip index -> shape tuple) for
        ``build_clip_refs`` to use when reopening the files.

        Args:
            video_clips: Object exposing ``num_clips()`` and ``get_clip(idx)``;
                the first element returned by ``get_clip`` is the clip tensor.
            transforms: Callable applied to each clip tensor before writing.
            dest_dir: Output directory; created if it does not exist.

        Returns:
            Dict mapping clip index to the path of its binary file.
        """
        clip_paths = {}
        self.clip_shapes = {}

        os.makedirs(dest_dir, exist_ok=True)

        for idx in range(video_clips.num_clips()):
            # transform clips and write to mmap file
            clip, _, _, _ = video_clips.get_clip(idx)
            clip_np = transforms(clip).numpy()
            mmapped_file_path = os.path.join(dest_dir, f'clip_{idx}.bin')
            fp = np.memmap(mmapped_file_path, dtype='float32', mode='w+', shape=clip_np.shape)
            fp[:] = clip_np[:]
            fp.flush()
            del fp  # drop the writable map before the file is reopened read-only
            clip_paths[idx] = mmapped_file_path
            self.clip_shapes[idx] = tuple(clip_np.shape)

        return clip_paths

    def build_clip_refs(self, clip_paths):
        """
        Open a read-only memory map for every clip file.

        Each map is opened with the shape recorded by ``build_clip_paths``.
        Without an explicit shape ``np.memmap`` returns a flat 1-D array, so
        clips with no recorded shape fall back to that 1-D view.

        Args:
            clip_paths: Dict mapping clip index to binary file path.

        Returns:
            Dict mapping clip index to its ``np.memmap`` array.
        """
        shapes = getattr(self, 'clip_shapes', {})
        return {
            idx: np.memmap(path, dtype='float32', mode='r', shape=shapes.get(idx))
            for idx, path in clip_paths.items()
        }

    def __len__(self):
        """Return the number of stored clips."""
        return len(self.clips)

    def __getitem__(self, idx):
        """Return clip ``idx`` as a float32 tensor copied out of its memory map."""
        return torch.tensor(self.clips[idx], dtype=torch.float32)


class SteamboatWillieDataModule(pl.LightningDataModule):
    """Lightning data module serving train/validation splits of the video clips.

    Args:
        paths: Video file paths forwarded to ``SteamboatWillieDataset``.
        batch_size: Batch size used by both dataloaders.
        train_split: Fraction of clips assigned to the training set; the
            remainder becomes the validation set.
        seed: Seed for the train/validation split. With a fixed seed, repeated
            ``setup`` calls reproduce the same split. Pass ``None`` to draw the
            split from torch's default generator instead.
    """

    def __init__(self, paths, batch_size=8, train_split=0.8, seed=42):
        super().__init__()
        self.paths = paths
        self.batch_size = batch_size
        self.train_split = train_split
        self.seed = seed
        # Populated by prepare_data() / setup().
        self.full_dataset = None
        self.train_dataset = None
        self.val_dataset = None

    def prepare_data(self):
        """Build the full dataset (decodes the videos and writes the clip files)."""
        self.full_dataset = SteamboatWillieDataset(self.paths)

    def setup(self, stage=None):
        """Create the train/validation subsets for the fit and validate stages.

        Args:
            stage: Lightning stage name. The split is built for ``'fit'``,
                ``'validate'`` and ``None``; other stages are ignored.
        """
        if stage not in ('fit', 'validate', None):
            return

        # Lightning's docs advise against relying on state assigned in
        # prepare_data(), so build the dataset here if it is still missing.
        if self.full_dataset is None:
            self.full_dataset = SteamboatWillieDataset(self.paths)

        train_len = int(len(self.full_dataset) * self.train_split)
        val_len = len(self.full_dataset) - train_len

        # A seeded generator keeps the partition identical across setup() calls,
        # so loaders created at different times see the same train/val clips.
        split_kwargs = {}
        if self.seed is not None:
            split_kwargs['generator'] = torch.Generator().manual_seed(self.seed)
        self.train_dataset, self.val_dataset = random_split(
            self.full_dataset, [train_len, val_len], **split_kwargs
        )

    def train_dataloader(self):
        """Return a shuffling dataloader over the training subset."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=0)

    def val_dataloader(self):
        """Return a non-shuffling dataloader over the validation subset."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=0)

In [None]:
# Build the data module for the single source video, then run its two
# preparation hooks by hand: prepare_data() followed by setup().
data_module = SteamboatWillieDataModule(paths=[steamboat_willie_gdrive_path])
data_module.prepare_data()
data_module.setup()

# Dataloaders over the training and validation subsets.
train_loader, val_loader = data_module.train_dataloader(), data_module.val_dataloader()

In [None]:
# Pull the first batch from a fresh iterator over the training loader.
train_batch = next(iter(train_loader))

In [None]:
# Copy the local `clips` directory (the dataset's default dest_dir) to Google Drive.
# NOTE: the .bin files are raw float32 buffers and carry no shape information.
!cp -r clips /content/drive/My\ Drive/