In [1]:
from types import SimpleNamespace
from math import floor
from dataclasses import dataclass

import torch
import torchvision

import torchvision.transforms as T
import torch.nn.functional as F

device = torch.device("cuda")
rng = torch.manual_seed(42)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
print("Torch version: ", torch.__version__)
print("CUDA available: ", torch.cuda.is_available())

Torch version:  1.13.1+cu117
CUDA available:  True


- See [./code/utils/arguments.py](./code/utils/arguments.py) for arguments used in upstream.
  - Upstream author was lazy and passed around the args object, so it is hard to trace the parameters of stuff.
- Upstream author modified torchvision loader from 2020: <https://github.com/ajabri/videowalk/blob/master/code/data/kinetics.py>
  - Current torchvision loader is now: <https://pytorch.org/vision/stable/generated/torchvision.datasets.Kinetics.html>

#### Preprocessing Transforms

In [3]:
class MapVideoTransform:
    """Map transform over THWC tensor."""

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, video):
        return torch.stack([self.transform(frame) for frame in video])

In [4]:
# TODO: check if its RGB or BGR order.
# TODO: check if you missed preprocessing where THWC -> TCHW since upstream authors use THWC.
# TODO: steal the augmentation transforms & figure out why its repeated twice.

to_tensor = T.ToTensor()
to_PIL = T.ToPILImage()
norm_size = T.Resize((640, 640)) # works only on PIL/numpy # upstream uses 256x256
norm_color = T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
transform = MapVideoTransform(T.Compose([to_PIL, norm_size, to_tensor, norm_color]))

#### Dataloading

In [5]:
# Can re-use metadata from previous pickled dataset object if need to reconfigure.
try:
    kinetics400 = torch.load("datasets/kinetics400.pt")
except:
    kinetics400 = None



In [6]:
kinetics400 = torchvision.datasets.Kinetics(
  root="/home/jovyan/downloads/kinetics400",
  frames_per_clip=4,
  num_classes="400",
  split="val",
  frame_rate=8,
  step_between_clips=8,
  # download=True,
  transform=transform,
  num_workers=16,
  num_download_workers=16,
  output_format="TCHW",
  _precomputed_metadata=dict(
    video_paths=kinetics400.video_clips.video_paths,
    video_fps=kinetics400.video_clips.video_fps,
    video_pts=kinetics400.video_clips.video_pts
  )
)
torch.save(kinetics400, "datasets/kinetics400.pt")

In [7]:
from torchvision.datasets.samplers import RandomClipSampler
from torch.utils.data import DataLoader, default_collate

def collate(batch):
  """torchvision.datasets.video_utils.VideoClips returns metadata along with video tensor. Select video tensor & stack into batch."""
  # See https://github.com/pytorch/vision/blob/707457050620e1f70ab1b187dad81cc36a7f9180/torchvision/datasets/video_utils.py#L289
  # list of (T, C, H, W) videos
  batch = [c[0] for c in batch]
  # let torchvision handle conversion to tensor
  return default_collate(batch)

subset_idx = torch.randperm(kinetics400.video_clips.num_videos(), generator=rng)[:5000].tolist()
sampler = RandomClipSampler(kinetics400.video_clips.subset(subset_idx), 1)
dataloader = DataLoader(
    kinetics400,
    batch_size=6,
    sampler=sampler,
    num_workers=16,
    collate_fn=collate,
    pin_memory=True,
    generator=rng
)

print("Total videos: ", kinetics400.video_clips.num_videos())
print("Total clips: ", len(kinetics400))
print("Filtered clips: ", len(dataloader)*dataloader.batch_size)

Total videos:  19881
Total clips:  190357
Filtered clips:  5004


In [8]:
sample = next(iter(dataloader))
print("Dataloader tensor shape: ", sample.shape)



Dataloader tensor shape:  torch.Size([6, 4, 3, 640, 640])


#### Visualizer

In [9]:
# from types import SimpleNamespace
# from videowalk.utils.visualize import Visualize
# 
# # author was lazy
# args = SimpleNamespace(
#     name="videowalk-test",
#     port=80,
#     server="localhost"
# )
# 
# # doesnt work if no ports are available...
# 
# viz = Visualize(args)
# 

#### Build Model

In [3]:
from videowalk.model import CRW

args = SimpleNamespace(
    dropout=0,
    featdrop=0.0,
    temperature=0.07,
    head_depth=0,
    device=device,
    flip=False,
    sk_targets=False,
    model_type="imagenet18",
    remove_layers=[],
)

model = CRW(args).to(device)

stride Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
stride Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
stride Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
stride Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
stride Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
stride Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
stride Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
stride Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
stride Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
stride Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)


In [11]:
# TODO: lr_milestones should be per total epochs, not per epoch.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
lr_milestones = [0.2, 0.8] # 20% and 80% of epoch
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[floor(len(dataloader) * m) for m in lr_milestones], gamma=0.1
)

#### Checkpointer

In [12]:
@dataclass
class Checkpointer:
    model: torch.nn.Module
    optimizer: torch.optim.Optimizer
    scheduler: torch.optim.lr_scheduler._LRScheduler
    epoch: int
    args: SimpleNamespace

    def save(self, path):
        torch.save({
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'lr_scheduler': self.scheduler.state_dict(),
            'epoch': self.epoch,
            'args': self.args
        }, path)

    def load(self, path):
        checkpoint = torch.load(path)
        self.model.load_state_dict(checkpoint['model'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.scheduler.load_state_dict(checkpoint['lr_scheduler'])
        self.epoch = checkpoint['epoch']
        self.args = checkpoint['args']

#### Model Training

In [13]:
from tqdm import tqdm
checkpointer = Checkpointer(model, optimizer, lr_scheduler, 0, args)

def train_one_epoch():
    model.train()
    print("Epoch: ", checkpointer.epoch)
    for video in tqdm(dataloader):
        video = video.to(device)

        _, loss, _ = model(video)
        loss = loss.mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

    checkpointer.epoch += 1
    checkpointer.save(f"epoch{checkpointer.epoch}.pt")

In [14]:
train_one_epoch()



Epoch:  0


100%|██████████| 834/834 [1:26:54<00:00,  6.25s/it]


#### Check Encoder Architecture

In [4]:
from videowalk.utils import make_encoder

encoder = make_encoder(args)
print(encoder)

stride Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
stride Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
stride Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
stride Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
stride Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
stride Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
stride Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
stride Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
stride Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
stride Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
From3D(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentu

In [5]:
from torchinfo import summary

# Dataloader output tensor: NCTHW
summary(encoder, input_size=(1, 3, 8, 64, 64))

Layer (type:depth-idx)                        Output Shape              Param #
From3D                                        [1, 512, 8, 8, 8]         --
├─ResNet: 1-1                                 [8, 512, 8, 8]            --
│    └─Conv2d: 2-1                            [8, 64, 32, 32]           9,408
│    └─BatchNorm2d: 2-2                       [8, 64, 32, 32]           128
│    └─ReLU: 2-3                              [8, 64, 32, 32]           --
│    └─MaxPool2d: 2-4                         [8, 64, 16, 16]           --
│    └─Sequential: 2-5                        [8, 64, 16, 16]           --
│    │    └─BasicBlock: 3-1                   [8, 64, 16, 16]           73,984
│    │    └─BasicBlock: 3-2                   [8, 64, 16, 16]           73,984
│    └─Sequential: 2-6                        [8, 128, 8, 8]            --
│    │    └─BasicBlock: 3-3                   [8, 128, 8, 8]            230,144
│    │    └─BasicBlock: 3-4                   [8, 128, 8, 8]            295,42