In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# re-imported before each cell executes, so edits to the local project
# packages take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [4]:
import os
import torch 
import pytorch_lightning as pl

from pytorchvideo.data import LabeledVideoDataset
from pytorchvideo.data.clip_sampling import UniformClipSampler
from pytorchvideo.transforms import ApplyTransformToKey, UniformTemporalSubsample
from torchvision.transforms import Compose, Resize

from utils.config import TRAIN_LABELS_FILE, TEST_FOLDER, TRAIN_FOLDER

In [5]:
# Build the (video_path, label_dict) pairs that are later handed to
# LabeledVideoDataset. Every usable line of TRAIN_LABELS_FILE is expected to
# look like "<file name>,<integer label>".
fname_label_pairs = []
skipped_lines = []  # (line number, raw text) of rows that could not be parsed

with open(TRAIN_LABELS_FILE, 'r') as f:
    for line_no, line in enumerate(f, start=1):
        row = line.strip()
        if not row:
            # Blank line: nothing to parse, not worth reporting.
            continue
        try:
            fname, label = row.split(',')
            label_id = int(label)
        except ValueError:
            # Wrong number of comma-separated fields or a non-integer label
            # (for example a header row). Keep the best-effort behaviour of
            # skipping the row, but remember it instead of hiding it.
            skipped_lines.append((line_no, row))
            continue

        fname = os.path.join(TRAIN_FOLDER, fname)
        # The label is wrapped in a dict holding a 1-element int32 tensor.
        label = {"label": torch.tensor([label_id], dtype=torch.int32)}
        fname_label_pairs.append((fname, label))

if skipped_lines:
    print(
        f"Skipped {len(skipped_lines)} unparseable line(s) in {TRAIN_LABELS_FILE}; "
        f"first one: {skipped_lines[0]}"
    )

In [6]:
# Per-clip preprocessing. Only the "video" entry of each sample dict is
# transformed: it is uniformly subsampled to 16 frames and then resized with
# torchvision's Resize(128) (an int size matches the shorter side to 128).
transform = Compose(
    transforms=[
        ApplyTransformToKey(
            key="video",
            transform=Compose(
                transforms=[
                    UniformTemporalSubsample(num_samples=16),
                    Resize(size=128),
                ]
            ),
        ),
    ]
)

In [7]:
# Dataset over every parsed (path, label) pair. Clips are cut with a uniform
# sampler using a 5-second clip duration, run through `transform`, and the
# audio track is not decoded.
train_dataset = LabeledVideoDataset(
    labeled_video_paths=fname_label_pairs,
    clip_sampler=UniformClipSampler(clip_duration=5.0),
    transform=transform,
    decode_audio=False,
)

In [8]:
from torch.utils.data import  DataLoader

# Collate clips from train_dataset into batches of 10; every other DataLoader
# option is left at its default.
train_loader = DataLoader(dataset=train_dataset, batch_size=10)

In [9]:
# Explicit iterator over train_loader so single batches can be pulled by hand
# with next() for inspection.
loader = iter(train_loader)

In [10]:
# Pull one batch for inspection. NOTE: this advances `loader`, so re-running
# this cell yields the *next* batch, not the same one again.
batch = next(loader)

In [12]:
# Display the collated labels of the sampled batch.
batch['label']

tensor([[0],
        [1],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0]], dtype=torch.int32)

In [35]:
# Check the dtype of the collated video tensor.
batch['video'].dtype

torch.float32

In [186]:
from sklearn.model_selection import train_test_split

In [189]:
# Number of (path, label) pairs parsed from the labels file.
# NOTE(review): `length` is not used in any of the cells visible here — confirm
# whether it is still needed.
length = len(fname_label_pairs)

In [196]:
# Hold out 10% of the labelled videos for validation.
# A fixed random_state makes the shuffle, and therefore the train/val
# membership, identical on every run; without it each kernel restart produces
# a different split and results cannot be compared between runs.
SPLIT_SEED = 42

train, val = train_test_split(
    fname_label_pairs,
    test_size=0.1,
    train_size=0.9,
    random_state=SPLIT_SEED,
)

# Training

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

from models.model import FakeVideoDetector, create_resnet
from utils.datasets import FakeVideoDataModule

In [2]:
# Seed the Python, NumPy and PyTorch random generators *before* the model is
# constructed, so that weight initialisation is repeatable. The Trainer in
# this notebook is created with deterministic=True, but that flag does not
# pick a seed by itself.
pl.seed_everything(42)

# Model and data module used for training (both come from the project's own
# packages); the data module is configured with one worker and batches of 4.
classifier = FakeVideoDetector()
datamodule = FakeVideoDataModule(num_workers=1, batch_size=4)

In [None]:
# Both callbacks watch the logged metric "AUROC" and treat larger as better:
#   * EarlyStopping stops training after 20 checks without improvement;
#   * ModelCheckpoint saves into the "checkpoints" directory, naming files
#     after the epoch and the metric value.
callbacks = [
    EarlyStopping(patience=20, mode="max", monitor="AUROC"),
    ModelCheckpoint(
        monitor="AUROC",
        mode="max",
        dirpath='checkpoints',
        filename='{epoch}--{AUROC:.3f}',
    ),
]

# One GPU, 16-bit precision, deterministic mode, metrics logged every 5 steps.
trainer = Trainer(
    callbacks=callbacks,
    gpus=1,
    precision=16,
    deterministic=True,
    log_every_n_steps=5,
)


In [None]:
# Run the training loop on `classifier`, taking its data from `datamodule`.
trainer.fit(model=classifier, datamodule=datamodule)