In [29]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_video
from torch import nn
import json
from os import path
from typing import Self, TypedDict, Union, Optional, Literal, List, Tuple
import time

DATA_PATH = 'D:\dfdc'
NUM_PARTS = 50
Split = Union[Literal['train'], Literal['validation']]
Label = Union[Literal[0], Literal[1]]

In [30]:

class FileMetadata(TypedDict):
  path: str
  label: Union[Literal['REAL'], Literal['FAKE']]
  original: Optional[str]

class VideoDataset(Dataset):
  def __get_part_directory(self: Self, index: int) -> str:
    return path.join(self.root_path, f'dfdc_train_part_{index}')

  def __init__(self: Self, root_path: str, split: Split):
    self.root_path = root_path

    # Init metadata
    self.metadata: List[FileMetadata] = []
    for i in range(NUM_PARTS):
      print(f'reading part {i} metadata...')
      start = time.time()
      dir_path = self.__get_part_directory(i)
      metadata_path = path.join(dir_path, 'metadata.json')
      metadata = json.load(open(metadata_path))
      post_load = time.time()
      print(f'loading json took {1_000*(post_load - start):.2f}ms')
      for k, data in metadata.items():
        if data['split'] != split:
          continue

        video_path = path.join(dir_path, k)
        data['path'] = video_path
        del data['split']

        self.metadata.append(data)
      
      post_parse = time.time()
      print(f'parsing json took {1_000*(post_parse - post_load):.2f}ms')


  def __getitem__(self: Self, index: int) -> Tuple[torch.Tensor, Label]:
    metadata = self.metadata[index]
    # How do we handle the audio as well?
    video, audio, _ = read_video(metadata['path'], pts_unit='sec')
    label = 1 if metadata['label'] == 'FAKE' else 0

    return video, label
  
  def __len__(self: Self):
    return len(self.metadata)

In [31]:
# Test the VideoDataset
loader = DataLoader(VideoDataset(DATA_PATH, 'train'))

reading part 0 metadata...
loading json took 3.00ms
parsing json took 4.00ms
reading part 1 metadata...
loading json took 2.00ms
parsing json took 4.00ms
reading part 2 metadata...
loading json took 10.00ms
parsing json took 4.00ms
reading part 3 metadata...
loading json took 3.00ms
parsing json took 3.00ms
reading part 4 metadata...
loading json took 2.00ms
parsing json took 4.00ms
reading part 5 metadata...
loading json took 2.00ms
parsing json took 6.00ms
reading part 6 metadata...
loading json took 3.00ms
parsing json took 7.00ms
reading part 7 metadata...
loading json took 3.00ms
parsing json took 6.00ms
reading part 8 metadata...
loading json took 2.08ms
parsing json took 3.92ms
reading part 9 metadata...
loading json took 2.00ms
parsing json took 4.00ms
reading part 10 metadata...
loading json took 3.00ms
parsing json took 7.00ms
reading part 11 metadata...
loading json took 2.00ms
parsing json took 5.00ms
reading part 12 metadata...
loading json took 3.08ms
parsing json took 4.

In [34]:
from matplotlib import pyplot as plt

# Get some images & labels from the data loader

data, label = next(iter(loader))

# data dimensions are BxTxCxHxW

fig = plt.figure(figsize=(12, 6))
fig.add_subplot(1, 3, 1)
plt.imshow(data[0][0])
fig.add_subplot(1, 3, 2)
plt.imshow(data[0][100])
fig.add_subplot(1, 3, 3)
plt.imshow(data[0][-1])

print(label)


tensor([1])


In [25]:
device = ("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device} device")

Using cuda device


In [26]:
model = nn.Sequential(
    nn.Conv3d(3, 16, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    nn.MaxPool3d(2),
    nn.Conv3d(16, 32, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    nn.MaxPool3d(2),
    nn.Conv3d(32, 64, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    nn.MaxPool3d(2),
    nn.Conv3d(64, 128, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    nn.MaxPool3d(2),
    nn.Flatten(),
    nn.Linear(128, 2),
    nn.Softmax(dim=1)
)

In [27]:
train_dl = DataLoader(VideoDataset(DATA_PATH, 'train'))
validation_dl = DataLoader(VideoDataset(DATA_PATH, 'validation'))

reading part 0 metadata...
loading json took 1.00ms
parsing json took 3.00ms
reading part 1 metadata...
loading json took 1.00ms
parsing json took 4.00ms
reading part 2 metadata...
loading json took 2.00ms
parsing json took 3.00ms
reading part 3 metadata...
loading json took 3.00ms
parsing json took 2.00ms
reading part 4 metadata...
loading json took 2.00ms
parsing json took 3.00ms
reading part 5 metadata...
loading json took 3.00ms
parsing json took 6.00ms
reading part 6 metadata...
loading json took 3.00ms
parsing json took 12.00ms
reading part 7 metadata...
loading json took 4.00ms
parsing json took 6.00ms
reading part 8 metadata...
loading json took 2.00ms
parsing json took 4.00ms
reading part 9 metadata...
loading json took 2.00ms
parsing json took 4.00ms
reading part 10 metadata...
loading json took 3.00ms
parsing json took 7.00ms
reading part 11 metadata...
loading json took 2.00ms
parsing json took 4.00ms
reading part 12 metadata...
loading json took 2.00ms
parsing json took 5.

In [28]:
from train_model import TrainModel

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()
score_fn = nn.Softmax(dim=1)
num_epochs = 10

train_model = TrainModel(model, train_dl, validation_dl, optimizer, num_epochs, loss_fn, score_fn)


RuntimeError: Input type (unsigned char) and bias type (float) should be the same