In [1]:
import os
import logging

In [2]:
import torch
import pandas as pd
from torch.utils.data import Subset
import pytorch_lightning as pl
from tqdm.notebook import trange
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from IPython.display import clear_output
from sklearn.model_selection import KFold, train_test_split
from sklearn.metrics import roc_auc_score

In [3]:
from system import System
from conflab.data_loading.pose import ConflabPoseExtractor
from conflab.data_loading.accel import ConflabAccelExtractor
from conflab.data_loading.person import ConflabDataset, ConflabSubset
from conflab.data_loading.labels import ConflabLabelExtractor
from conflab.constants import conflab_pose_path, midge_data_path, conflab_speaking_status_path

In [4]:
def do_fold(train_ds, test_ds, model_name='resnet', model_hparams=None, deterministic=False, log_prefix=None):
    """Train and evaluate a single cross-validation fold.

    A random 10% of ``train_ds`` is held out as a validation set, which drives
    early stopping and checkpoint selection on ``val_auc``. The checkpoint with
    the best ``val_auc`` is then evaluated on ``test_ds``.

    Args:
        train_ds: training subset; must provide ``random_split(0.1)``, assumed
            to return ``(val_ds, train_ds)`` with the 10% part first.
        test_ds: held-out subset for this fold.
        model_name: architecture name forwarded to ``System``.
        model_hparams: hyper-parameters forwarded to ``System``; ``None`` means
            an empty dict.
        deterministic: forwarded to ``pl.Trainer(deterministic=...)``.
        log_prefix: TensorBoard logger version for this fold; also used as the
            checkpoint sub-directory so folds/models do not share checkpoints.

    Returns:
        ``system.test_results`` after ``trainer.test`` on the best checkpoint.
    """
    if model_hparams is None:
        # A fresh dict per call: a shared mutable default would carry any
        # changes made to it over into later folds.
        model_hparams = {}

    # Hold out 10% of the training split for validation.
    val_ds, train_ds = train_ds.random_split(0.1)

    def _loader(ds, batch_size, shuffle):
        # All splits use the same worker count and the default collate_fn.
        return torch.utils.data.DataLoader(
            ds, batch_size=batch_size, shuffle=shuffle, num_workers=4)

    data_loader_train = _loader(train_ds, 64, shuffle=True)
    data_loader_val = _loader(val_ds, 128, shuffle=False)
    data_loader_test = _loader(test_ds, 128, shuffle=False)

    system = System(model_name, model_hparams=model_hparams)

    # Keep each fold's best checkpoint in its own directory (keyed like the
    # TensorBoard version) instead of piling every fold/model into one folder.
    ckpt_dir = os.path.join("./checkpoints", log_prefix) if log_prefix else "./checkpoints"
    checkpoint_callback = ModelCheckpoint(dirpath=ckpt_dir, save_top_k=1, mode="max", monitor="val_auc")
    trainer = pl.Trainer(
        callbacks=[
            EarlyStopping(monitor="val_auc", patience=6, mode="max"),
            checkpoint_callback,
        ],
        accelerator='gpu',
        log_every_n_steps=1,
        max_epochs=20,
        deterministic=deterministic,
        logger=pl.loggers.TensorBoardLogger(save_dir="logs/", version=log_prefix))
    trainer.fit(system, data_loader_train, data_loader_val)

    # Test the checkpoint with the highest val_auc, not the last epoch's weights.
    trainer.test(system, data_loader_test, ckpt_path='best')
    return system.test_results

In [5]:
def get_metrics(outputs, labels, type='binary'):
    """Score cross-validated model outputs against ground-truth labels.

    Args:
        outputs: tensor with one model output per example (``do_run`` passes a
            1-D tensor). For ``type='binary'`` a sigmoid is applied here, so the
            values are treated as logits.
        labels: tensor of targets aligned element-wise with ``outputs``.
        type: metric family, ``'binary'`` or ``'regression'``.

    Returns:
        For ``'binary'``: dict with ``'auc'`` (ROC AUC), ``'acc'`` and
        ``'correct'`` (accuracy / count at a 0.5 probability threshold).
        For ``'regression'``: dict with ``'mse'`` and ``'l1'`` as 0-d tensors
        (mean reduction).

    Raises:
        ValueError: if ``type`` is not ``'binary'`` or ``'regression'``.
    """
    if type == 'binary':
        # NOTE(review): do_run passes fold_outputs['proba'] here. If those are
        # already probabilities in (0, 1], sigmoid() maps all of them above 0.5
        # and every example is predicted positive (AUC is unaffected because
        # sigmoid is monotonic). Confirm System stores logits under that key.
        proba = torch.sigmoid(outputs)
        pred = proba > 0.5

        correct = pred.eq(labels).sum().item()
        return {
            'auc': roc_auc_score(labels, proba),
            'acc': correct / len(outputs),
            'correct': correct,
        }
    if type == 'regression':
        return {
            'mse': torch.nn.functional.mse_loss(outputs, labels, reduction='mean'),
            'l1': torch.nn.functional.l1_loss(outputs, labels, reduction='mean'),
        }
    # Fail loudly instead of returning None, which would only surface later as
    # a confusing TypeError when the caller indexes the metrics.
    raise ValueError(f"unknown metrics type {type!r}; expected 'binary' or 'regression'")

In [6]:
def do_run(dataset, model_name, random_state, metrics_name='binary', deterministic=False, log_prefix='cv'):
    """Run participant-grouped 10-fold cross-validation over ``dataset``.

    The distinct pids from ``dataset.get_groups()`` are split into 10 folds, so a
    participant is never on both the train and test side of a fold. Each fold is
    trained/tested with ``do_fold``; the per-example test outputs of all folds are
    stitched back into one tensor aligned with ``dataset`` and scored with
    ``get_metrics``.

    Args:
        dataset: ConflabDataset; ``dataset.examples[i][0]`` is assumed to be the
            pid of example ``i`` (same values as ``dataset.get_groups()``).
        model_name: forwarded to ``do_fold``.
        random_state: seed for the KFold shuffle over participants.
        metrics_name: metric family passed to ``get_metrics``.
        deterministic: forwarded to ``do_fold``.
        log_prefix: prefix of per-fold logger versions (``<prefix>_fold<k>``).

    Returns:
        ``(outputs, run_metrics)``: ``outputs[i]`` is the test-time output for
        example ``i`` from the fold in which its pid was held out.

    Raises:
        RuntimeError: if some example never ended up in a test fold.
    """
    # Sorted so fold assignment does not depend on set iteration order.
    pid_list = sorted(set(dataset.get_groups()))
    # KFold yields *positions* into pid_list; they are mapped back to pid
    # values below before comparing against each example's pid.
    pid_splits = KFold(n_splits=10, random_state=random_state, shuffle=True).split(range(len(pid_list)))

    # Same input width for every fold.
    c_in = dataset.extractors['accel'].num_columns

    # NaN rather than uninitialised memory, plus an explicit coverage mask, so an
    # example that is never tested cannot silently feed garbage into the metrics.
    outputs = torch.full((len(dataset),), float('nan'))
    covered = torch.zeros(len(dataset), dtype=torch.bool)
    for fold, (train_pos, test_pos) in enumerate(pid_splits):
        train_pids = {pid_list[p] for p in train_pos}
        test_pids = {pid_list[p] for p in test_pos}

        # create datasets
        train_idx = [i for i, e in enumerate(dataset.examples) if e[0] in train_pids]
        test_idx = [i for i, e in enumerate(dataset.examples) if e[0] in test_pids]
        print(f'ds split into {len(train_idx)} train and {len(test_idx)} test')

        train_ds = ConflabSubset(dataset, train_idx)
        test_ds = ConflabSubset(dataset, test_idx)

        fold_outputs = do_fold(train_ds, test_ds,
            model_name,
            {'c_in': c_in},  # fresh dict per fold
            deterministic=deterministic,
            log_prefix=log_prefix + f'_fold{fold}')

        outputs[test_idx] = fold_outputs['proba'].cpu()
        covered[test_idx] = True
        clear_output(wait=True)

    if not bool(covered.all()):
        missing = int((~covered).sum())
        raise RuntimeError(
            f'{missing} examples were never assigned to a test fold; '
            'check that examples[i][0] matches dataset.get_groups()')

    labels = torch.Tensor(dataset.get_all_labels())
    run_metrics = get_metrics(outputs, labels, metrics_name)

    return outputs, run_metrics

In [7]:
# Load the pose tracks; the examples are windowed from these tracks further down.
# NOTE(review): load_from_pickle presumably unpickles '../tracks.pkl' — only
# load a trusted file (unpickling can execute arbitrary code). The path is
# relative to the notebook's working directory.
pose_extractor = ConflabPoseExtractor(conflab_pose_path)
pose_extractor.load_from_pickle('../tracks.pkl')

In [8]:
# Wearable-sensor extractors: one restricted to the three accelerometer axes,
# one with the extractor's default column set.
accel_extractor = ConflabAccelExtractor(
    midge_data_path,
    columns=['accelX', 'accelY', 'accelZ'],
)
sensor_extractor = ConflabAccelExtractor(midge_data_path)

# Speaking-status labels are read from the 'speaking' sub-directory.
label_extractor = ConflabLabelExtractor(
    os.path.join(conflab_speaking_status_path, 'speaking')
)

In [9]:
# Window the pose tracks into examples, then build two datasets over the very
# same windows and labels that differ only in the sensor extractor:
# accel_ds (accelerometer axes only) and sensor_ds (default column set).
examples = pose_extractor.make_examples()
accel_ds, sensor_ds = (
    ConflabDataset(examples, {'accel': extractor, 'label': label_extractor})
    for extractor in (accel_extractor, sensor_extractor)
)

100%|██████████| 8/8 [00:00<00:00, 591.13it/s]


In [10]:
# Sanity check: every window fed to the models must have the same length along
# axis 0, in BOTH feature sets that are trained on (not just sensor_ds).
EXPECTED_WINDOW_LEN = 150
for ds_name, ds in (('all', sensor_ds), ('accel', accel_ds)):
    for i in range(len(ds)):
        assert ds[i]['accel'].shape[0] == EXPECTED_WINDOW_LEN, (ds_name, i)

In [13]:
# Results log: each message is written verbatim (no formatter is set, so the
# default '%(message)s' applies), which makes paper_runs.csv a header-less CSV.
# The handler guard stops re-executions of this cell from attaching a second
# FileHandler; note mode='w' therefore truncates the file only once per kernel.
logger = logging.getLogger('paper_runs')
logger.setLevel(logging.INFO)
if len(logger.handlers) == 0:
    f_handler = logging.FileHandler('paper_runs.csv', mode='w')
    logger.addHandler(f_handler)

# Feature-set name -> dataset ('all' uses sensor_extractor's default columns,
# 'accel' only the accelerometer axes).
datasets = dict(all=sensor_ds, accel=accel_ds)

def do_paper_runs(model_names=('inception', 'resnet'), feature_sets=('all', 'accel'), seed=22, output_dir='outputs'):
    """Run the grouped cross-validation for every (model, feature set) pair.

    For each pair: seeds everything with ``seed``, runs ``do_run`` in
    deterministic mode, writes the per-example outputs to
    ``<output_dir>/<model>_<features>.csv`` and logs a
    ``model, features, auc, acc`` line through the module-level ``logger``.

    Args:
        model_names: architectures to evaluate ('minirocket' was excluded from
            the paper runs; add it here to include it).
        feature_sets: keys into the module-level ``datasets`` dict.
        seed: seed for ``pl.seed_everything`` and for the KFold shuffle.
        output_dir: directory for the per-run output CSVs; created if missing.

    Returns:
        Nested dict ``{model_name: {features: metrics}}``.
    """
    # to_csv does not create missing directories; without this the first run
    # would finish training and only then fail when saving its outputs.
    os.makedirs(output_dir, exist_ok=True)

    results = {}
    for model_name in model_names:
        model_results = {}

        for features in feature_sets:
            # Re-seed before every run so each result is reproducible on its
            # own, independent of which runs executed before it.
            pl.seed_everything(seed, workers=True)
            proba, metrics = do_run(
                datasets[features],
                model_name,
                random_state=seed,
                metrics_name='binary',
                deterministic=True,
                log_prefix=f"{model_name}_{features}")
            model_results[features] = metrics
            pd.DataFrame(proba.numpy()).to_csv(
                os.path.join(output_dir, f"{model_name}_{features}.csv"),
                header=False,
                index=False)
            # One CSV row per run: model, features, auc, acc.
            logger.info(f"{model_name}, {features}, {metrics['auc']}, {metrics['acc']}")
        results[model_name] = model_results
    return results

In [14]:
# Full training sweep (every fold of every model/feature-set pair; the Trainer is
# configured with accelerator='gpu'). The returned nested metrics dict is
# displayed as this cell's output.
do_paper_runs()

{'inception': {'all': {'auc': 0.6881288608492249,
   'acc': 0.7291870714985309,
   'correct': 31269},
  'accel': {'auc': 0.7985812018585552,
   'acc': 0.7675714752110443,
   'correct': 32915}},
 'resnet': {'all': {'auc': 0.6860748897127408,
   'acc': 0.7216081339489763,
   'correct': 30944},
  'accel': {'auc': 0.8008403875320723,
   'acc': 0.7731449092859475,
   'correct': 33154}}}