In [None]:
from catalyst import dl

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader

from mmcv.utils import Config

from liedet.data import VideoReader
from liedet.datasets import build_dataset
from liedet.models.e2e import LieDetectorRunner
from liedet.models.registry import build

### Choose a config file

In [None]:
# Select exactly one pipeline by leaving a single `cfg_path` assignment uncommented.

# video --> landmarks + angles --> transformer --> linear --> probs
# cfg_path = "configs/landmarks_transformer.py"

# video --> landmarks + angles --
#                                --> concat --> transformer --> linear --> probs
# audio --> features ------------
cfg_path = "configs/landmarks_audio_transformer.py"

# video --> TinaFace --(faces)--> ResNet3D --> linear --> probs
# cfg_path = "configs/tinaface_r3d.py"

# video --> TinaFace --(faces)--> TimeSformer --> linear --> probs
# cfg_path = "configs/tinaface_timesformer.py"

# video --> TinaFace --(faces)--> ResNet50 --(face features)--> Transformer --> linear --> probs
# cfg_path = "configs/tinaface_r50_transformer.py"

# Parse the selected file; every later cell reads its settings from `cfg`.
# The path lives in its own name (`cfg_path`) so `cfg` is always a Config object.
cfg = Config.fromfile(cfg_path)

# Remove the model's `init_cfg` entry before the model is built.
# The `None` default keeps this cell working for configs that define no `init_cfg`
# (a bare `pop("init_cfg")` would raise KeyError there).
cfg["model"].pop("init_cfg", None)

### Build dataset and dataloaders

In [None]:
# Build the dataset described by the config, then split it into training and
# validation subsets using the split settings stored under `cfg.dataset.split`.
dataset = build_dataset(cfg.dataset)
train_set, valid_set = dataset.split(**cfg.dataset.split)

# Loaders are keyed by the names the runner refers to later ("train_loader" / "valid_loader").
# Both load in the main process (num_workers=0); only the training loader drops a
# trailing incomplete batch.
loaders = {
    "train_loader": DataLoader(
        train_set,
        batch_size=cfg.batch_size,
        num_workers=0,
        drop_last=True,
    ),
    "valid_loader": DataLoader(
        valid_set,
        batch_size=cfg.batch_size,
        num_workers=0,
    ),
}

### Build model

In [None]:
# Instantiate the model from the `model` section of the config via the project registry's `build`.
model = build(cfg.model)

### Build optimizer and criterion

In [None]:
# Training loss: cross-entropy.
criterion = torch.nn.CrossEntropyLoss()

# Adam over all model parameters with a fixed learning rate of 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

### Build runner

In [None]:
# Create the project's runner; the cells below use it for training, evaluation and per-sample prediction.
runner = LieDetectorRunner()

### Train model

In [None]:
# Callbacks handed to the runner, in the order they are passed.
# Early stopping and checkpointing both track the validation "loss" (minimized),
# while the runner-level validation metric below is "accuracy01" (maximized).
train_callbacks = [
    # Criterion on "logits" vs "labels", stored under the "loss" metric key.
    dl.CriterionCallback(input_key="logits", target_key="labels", metric_key="loss"),
    # Backward pass and optimizer step, both driven by "loss".
    dl.BackwardCallback(metric_key="loss"),
    dl.OptimizerCallback(metric_key="loss"),
    # Two-class accuracy on the raw logits.
    dl.AccuracyCallback(input_key="logits", target_key="labels", num_classes=2),
    # Stop after 15 epochs without improvement of the validation loss.
    dl.EarlyStoppingCallback(
        patience=15,
        loader_key="valid_loader",
        metric_key="loss",
        minimize=True,
    ),
    # Keep only the single best checkpoint by validation loss.
    dl.CheckpointCallback(
        logdir="./logs",
        loader_key="valid_loader",
        metric_key="loss",
        minimize=True,
        topk=1,
    ),
]

# NOTE(review): a CheckpointCallback is supplied explicitly above — confirm that the
# runner-level `load_best_on_end=True` is still honoured by the installed Catalyst version.
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    loaders=loaders,
    logdir="./logs",
    num_epochs=200,
    valid_loader="valid_loader",
    valid_metric="accuracy01",
    minimize_valid_metric=False,
    callbacks=train_callbacks,
    load_best_on_end=True,
    verbose=True,
)

### Evaluate model on valid loader

In [None]:
# Evaluation callbacks: map "logits" to "scores" with an element-wise sigmoid at the
# end of each batch, then compute two-class accuracy on those scores.
# Sigmoid is monotonic, so the per-sample ordering of class scores matches the logits.
eval_callbacks = [
    dl.BatchTransformCallback(
        input_key="logits",
        output_key="scores",
        scope="on_batch_end",
        transform=torch.sigmoid,
    ),
    dl.AccuracyCallback(input_key="scores", target_key="labels", num_classes=2),
]

# Kept as the cell's last expression so the returned value is displayed.
runner.evaluate_loader(
    loader=loaders["valid_loader"],
    callbacks=eval_callbacks,
    verbose=True,
)

### Run inference on a custom video

#### Load video from file

In [None]:
# Path of the video file to run inference on; change it to point at your own video.
video_path = "assets/example.mp4"

In [None]:
# Open the video, forwarding every entry of the dataset config as keyword arguments.
# NOTE(review): `cfg.dataset` is the same mapping passed to `build_dataset` and it carries
# a `split` entry — confirm VideoReader accepts (or ignores) all of these keys.
vr = VideoReader(uri=video_path, **cfg.dataset)
# Length reported by the reader (presumably a frame count — TODO confirm);
# used as the upper bound when slicing the video into windows.
length = len(vr)

#### Generate predictions

In [None]:
# Walk the video in consecutive, non-overlapping windows of `cfg.window` items and
# print the runner's prediction for each window.
# NOTE(review): when `length` is not a multiple of `cfg.window`, the last slice extends
# past `length` — confirm how VideoReader and the model handle a shorter final window.
window_size = cfg.window
for window_start in range(0, length, window_size):
    window_end = window_start + window_size
    sample = vr[window_start:window_end]

    prediction = runner.predict_sample(sample)
    print(prediction)