In [1]:
import os

In [2]:
import torch
from torch.utils.data import Subset
import pytorch_lightning as pl
from tqdm.notebook import trange
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from IPython.display import clear_output
from sklearn.model_selection import KFold
from sklearn.metrics import roc_auc_score

In [3]:
from system import System
from conflab.data_loaders.pose import ConflabPoseExtractor
from conflab.data_loaders.accel import ConflabAccelExtractor
from conflab.data_loaders.person import ConflabMultimodalDataset, ConflabLabelExtractor
from conflab.constants import conflab_pose_path, midge_data_path, conflab_speaking_status_path
from conflab.constants import vid2_start, vid3_start

In [4]:
from datetime import timedelta

In [5]:
# Offset into video 2 where (judging by the name) segment 8 begins: 14 minutes.
# NOTE(review): the 14-minute figure is a hard-coded assumption — TODO confirm.
vid2_seg8_start = vid2_start + timedelta(seconds=14 * 60)
# Seconds from that offset until video 3 starts (td / 1s == td.total_seconds()).
vid2_len = (vid3_start - vid2_seg8_start) / timedelta(seconds=1)

In [6]:
# Show the computed duration in seconds.
vid2_len

217.0

In [7]:
def do_fold(train_ds, test_ds, model_name='resnet', deterministic=False,
            batch_size=100, num_workers=10, accelerator='gpu'):
    """Train a fresh ``System`` on one CV fold and evaluate it on the held-out fold.

    Args:
        train_ds: dataset used for training (shuffled each epoch).
        test_ds: held-out dataset. It is used BOTH as the Lightning validation
            set (driving early stopping on ``val_loss``) and as the test set.
        model_name: model identifier forwarded to ``System``.
        deterministic: forwarded to ``pl.Trainer(deterministic=...)``.
        batch_size: DataLoader batch size for both loaders.
        num_workers: DataLoader worker processes for both loaders.
        accelerator: Lightning accelerator, e.g. ``'gpu'`` or ``'cpu'``.

    Returns:
        ``system.test_results`` after ``trainer.test``. ``do_run`` reads a
        ``'proba'`` tensor from it — presumably populated by ``System`` during
        testing; TODO confirm.
    """
    data_loader_train = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    data_loader_val = torch.utils.data.DataLoader(
        test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    system = System(model_name)
    trainer = pl.Trainer(
        # NOTE(review): early stopping watches the held-out fold, so the score
        # reported by trainer.test below is optimistically biased. Consider
        # carving a separate validation split out of train_ds.
        callbacks=[EarlyStopping(monitor="val_loss", mode="min")],
        accelerator=accelerator,
        log_every_n_steps=1,
        # -1 = no epoch cap; training length is bounded by EarlyStopping.
        max_epochs=-1,
        deterministic=deterministic)
    trainer.fit(system, data_loader_train, data_loader_val)

    trainer.test(system, data_loader_val)
    return system.test_results

In [8]:
def get_metrics(outputs, labels, type='binary'):
    """Compute evaluation metrics for pooled model outputs.

    Args:
        outputs: float tensor of model outputs, one per example. For
            ``type='binary'`` they are passed through ``torch.sigmoid``,
            i.e. treated as logits.
        labels: tensor of ground-truth targets aligned with ``outputs``.
        type: ``'binary'`` for classification metrics or ``'regression'``
            for error metrics.

    Returns:
        dict: for ``'binary'``, ``{'auc': ..., 'correct': int}``, where
        ``correct`` counts thresholded predictions that match the label.
        For ``'regression'``, ``{'mse': Tensor, 'l1': Tensor}`` (means).

    Raises:
        ValueError: if ``type`` is not ``'binary'`` or ``'regression'``.
    """
    if type == 'binary':
        # NOTE(review): do_run fills `outputs` from fold_outputs['proba']. If
        # those are already probabilities rather than logits, this sigmoid
        # squashes them twice and the 0.5 threshold misfires. AUC is
        # unaffected because sigmoid is monotonic. TODO confirm in System.
        proba = torch.sigmoid(outputs)
        pred = proba > 0.5

        # Accuracy must compare predictions against the ground-truth labels.
        correct = pred.eq(labels.bool()).sum().item()
        return {
            'auc': roc_auc_score(labels, proba),
            'correct': correct
        }
    if type == 'regression':
        return {
            'mse': torch.nn.functional.mse_loss(outputs, labels, reduction='mean'),
            'l1': torch.nn.functional.l1_loss(outputs, labels, reduction='mean')
        }
    raise ValueError(f"unknown metrics type: {type!r}")

In [9]:
def do_run(dataset, model_name, random_state, metrics_name='binary', deterministic=False,
           n_splits=4):
    """Run shuffled K-fold cross-validation and score the pooled out-of-fold outputs.

    Each example's output comes from the model trained on the folds that
    exclude it. Metrics are then computed once over all examples.

    Args:
        dataset: indexable dataset with ``__len__`` and ``get_all_labels()``.
        model_name: model identifier forwarded to ``do_fold``.
        random_state: seed for the shuffled ``KFold`` split.
        metrics_name: metric family passed to ``get_metrics``.
        deterministic: forwarded to ``do_fold`` (and on to the trainer).
        n_splits: number of CV folds (default 4).

    Returns:
        tuple: ``(outputs, run_metrics)``. ``outputs`` is a float tensor with
        one out-of-fold value per example; ``run_metrics`` is the
        ``get_metrics`` dict.
    """
    # Use the `dataset` argument throughout; the original read the global `ds`.
    n_examples = len(dataset)
    cv_splits = KFold(n_splits=n_splits, random_state=random_state, shuffle=True).split(range(n_examples))

    outputs = torch.empty((n_examples,))
    for train_idx, test_idx in cv_splits:
        # create datasets
        train_ds = Subset(dataset, train_idx)
        test_ds = Subset(dataset, test_idx)

        fold_outputs = do_fold(train_ds, test_ds, model_name, deterministic=deterministic)
        outputs[test_idx] = fold_outputs['proba'].cpu()
        # Keep the notebook output from accumulating per-fold training logs.
        clear_output(wait=True)

    labels = torch.Tensor(dataset.get_all_labels())
    run_metrics = get_metrics(outputs, labels, metrics_name)

    return outputs, run_metrics

In [10]:
# Pose extractor rooted at conflab_pose_path; its tracks are then loaded from ./tracks.pkl.
# NOTE(review): load_from_pickle presumably unpickles this file — only use a trusted tracks.pkl.
pose_extractor = ConflabPoseExtractor(conflab_pose_path)
pose_extractor.load_from_pickle('./tracks.pkl')

In [11]:
# Extractors for the accelerometer data (under midge_data_path) and for the
# speaking-status labels (under the 'speaking' entry of conflab_speaking_status_path).
accel_extractor = ConflabAccelExtractor(midge_data_path)
label_extractor = ConflabLabelExtractor(os.path.join(conflab_speaking_status_path, 'speaking'))

In [12]:
# make windowed examples using the pose tracks.
examples = pose_extractor.make_examples()
# compose the dataset from those examples plus the 'accel' and 'label' extractors;
# do_run later reads the targets back via ds.get_all_labels().
ds = ConflabMultimodalDataset(examples, {
    'accel': accel_extractor,
    'label': label_extractor
})

100%|██████████| 8/8 [00:00<00:00, 544.04it/s]


In [13]:
# Seed all RNGs (and DataLoader workers) via Lightning's public helper, then
# run the cross-validated experiment with deterministic training.
seed = 22
pl.seed_everything(seed, workers=True)
outputs, metrics = do_run(ds, 'alexnet', random_state=seed, metrics_name='binary', deterministic=True)

  labels = torch.Tensor(ds.get_all_labels())
