In [None]:
!pip install -q wandb pytorch_lightning av

In [2]:
import os

import torch
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from torchvision.datasets.video_utils import VideoClips
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.transforms import Compose, Lambda, Resize, ToTensor, CenterCrop, Grayscale

In [None]:
from google.colab import drive

# Mount Google Drive so the video stored there is reachable through the local filesystem.
drive.mount('/content/drive')
steamboat_willie_path = '/content/drive/My Drive/SteamboatWillie.mp4'

# Fail fast with an explicit message if the video is not where we expect it,
# before any dataset is built from this path.
if not os.path.isfile(steamboat_willie_path):
    raise FileNotFoundError(
        f'Video not found at {steamboat_willie_path!r}; '
        'upload it to Google Drive or update steamboat_willie_path'
    )

In [None]:
class SteamboatWillieDataset(Dataset):
    """Fixed-length grayscale clips cut from one or more video files.

    Clips are indexed by ``VideoClips`` with a step equal to the clip length,
    so consecutive clips do not overlap. Each item is one clip passed through
    ``self.transforms`` (grayscale, centre crop, resize, divide by 255).

    Args:
        paths: List of video file paths handed to ``VideoClips``.
        clip_length: Number of frames per clip; also the stride between clips.
        crop_size: ``(height, width)`` of the centre crop applied to every frame
            before resizing. The default removes the vertical bars of the
            source video.
        frame_size: ``(height, width)`` every frame is resized to.
    """

    def __init__(self, paths, clip_length=16, crop_size=(480, 575), frame_size=(256, 256)):
        super().__init__()
        self.video_clips = VideoClips(
            paths,
            clip_length_in_frames=clip_length,
            frames_between_clips=clip_length,
            num_workers=2
        )

        # The crop/resize transforms are built once here instead of being
        # re-instantiated inside a lambda on every __getitem__ call.
        self.transforms = Compose([
            Lambda(lambda x: x.permute(0, 3, 1, 2)),  # (T, H, W, C) to (T, C, H, W) for Grayscale
            Grayscale(num_output_channels=1),         # Convert to grayscale
            Lambda(lambda x: x.permute(1, 0, 2, 3)),  # (T, C, H, W) to (C, T, H, W) for Conv3d
            CenterCrop(crop_size),                    # Center crop to remove vertical bars
            Resize(frame_size),                       # Resize frames
            Lambda(lambda x: x / 255.),               # Scale pixel values to [0, 1]
        ])

    def __len__(self):
        """Return the number of clips available across all videos."""
        return self.video_clips.num_clips()

    def __getitem__(self, idx):
        """Return the transformed clip at position ``idx``.

        Raises:
            IndexError: Propagated unchanged from ``VideoClips.get_clip`` so that
                sequence-style iteration over the dataset still terminates.
            RuntimeError: If decoding or transforming the clip fails for any
                other reason; the original exception is attached as the cause.
        """
        try:
            clip, _, _, _ = self.video_clips.get_clip(idx)
            return self.transforms(clip)
        except IndexError:
            raise
        except Exception as err:
            # Returning None here (as a silent swallow would) only defers the
            # failure to batch collation and hides the real cause.
            raise RuntimeError(f'error getting idx: {idx}') from err


class SteamboatWillieDataModule(pl.LightningDataModule):
    """Lightning data module that splits ``SteamboatWillieDataset`` into train/val.

    Args:
        paths: List of video file paths forwarded to ``SteamboatWillieDataset``.
        batch_size: Batch size used by both dataloaders.
        train_split: Fraction of clips assigned to the training set; the
            remainder becomes the validation set.
        num_workers: Worker process count for both dataloaders. Pass ``0`` to
            load in the main process.
        seed: Seed for the generator that drives the train/val split, so the
            split is identical across runs and across repeated ``setup`` calls.
    """

    def __init__(self, paths, batch_size=32, train_split=0.8, num_workers=2, seed=42):
        super().__init__()
        self.paths = paths
        self.batch_size = batch_size
        self.train_split = train_split
        self.num_workers = num_workers
        self.seed = seed
        # Populated by setup(); declared here so the attributes always exist.
        self.full_dataset = None
        self.train_dataset = None
        self.val_dataset = None

    def prepare_data(self):
        """Intentionally a no-op, kept so existing callers keep working.

        Lightning advises against assigning state in ``prepare_data``, so the
        dataset is constructed in ``setup`` instead.
        """

    def setup(self, stage=None):
        """Build the dataset (once) and split it into train and validation subsets.

        Args:
            stage: ``'fit'``, ``'validate'`` or ``None`` trigger the split; any
                other stage leaves the module untouched.
        """
        if stage not in ('fit', 'validate', None):
            return

        if self.full_dataset is None:
            self.full_dataset = SteamboatWillieDataset(self.paths)

        total_len = len(self.full_dataset)
        train_len = int(total_len * self.train_split)
        val_len = total_len - train_len
        # A freshly seeded generator makes the split reproducible.
        split_generator = torch.Generator().manual_seed(self.seed)
        self.train_dataset, self.val_dataset = random_split(
            self.full_dataset, [train_len, val_len], generator=split_generator
        )
        print(f'Train: {len(self.train_dataset)}, Val: {len(self.val_dataset)}')

    def _make_loader(self, dataset, shuffle):
        """Create a DataLoader over ``dataset`` with this module's shared settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        """Return a shuffled dataloader over the training subset."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return an unshuffled dataloader over the validation subset."""
        return self._make_loader(self.val_dataset, shuffle=False)

In [None]:
# Build the data module for the single Steamboat Willie video and run the
# Lightning hooks by hand, since no Trainer is driving them here.
data_module = SteamboatWillieDataModule(paths=[steamboat_willie_path])
data_module.prepare_data()
data_module.setup(stage=None)

# Train loader is created first, then the validation loader.
train_loader, val_loader = (
    data_module.train_dataloader(),
    data_module.val_dataloader(),
)

In [None]:
# Fetch the first batch from the training loader.
# NOTE: the original author observed a segfault here whenever the DataLoader
# is configured with more than 0 workers.
train_batch = next(iter(train_loader))