In [1]:
import argparse
import ast
import json
import os
import time
from collections import OrderedDict
from datetime import timedelta
from pprint import pprint

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    AutoModel,
    WhisperProcessor
)

from config import WhiSBERTConfig, CACHE_DIR, CHECKPOINT_DIR
from model import WhiSBERTModel
from data import AudioDataset, collate_train
from train import load_models
from utils import (
    mean_pooling,
    cos_sim_loss,
    sim_clr_loss,
    norm_temp_ce_loss
)

# Force TOKENIZERS_PARALLELISM off for this process.
# NOTE(review): presumably this silences the tokenizers fork warning when the
# DataLoader spawns workers — confirm it is still needed.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})



In [116]:
# Load aad_hitop.csv and keep it under a name so later cells can reuse it.
# The first CSV column has no header (pandas shows it as 'Unnamed: 0') and looks
# like a saved row index, so it is read back as the index instead of as data.
hitop_scores = pd.read_csv('aad_hitop.csv', index_col=0)
hitop_scores

Unnamed: 0,message_id,ang,anx,dep,ang_norm,anx_norm,dep_norm
0,PP572_0_1,2,2,2,2.346886,3.314925,3.479106
1,PP572_0_8,1,1,1,2.603998,3.166203,3.160666
2,PP572_0_9,5,5,5,-0.740902,3.049605,5.081950
3,PP572_0_10,12,12,12,2.209402,2.016459,-0.138630
4,PP572_0_11,4,4,4,1.375462,1.966641,1.700260
...,...,...,...,...,...,...,...
366071,P692_0_1183,3,3,3,6.851456,2.012588,-0.135535
366072,P692_0_1187,3,3,3,-2.776020,0.317252,3.993117
366073,P692_0_1188,3,3,3,-1.467750,-0.131700,1.204343
366074,P692_0_1191,1,1,1,4.560347,4.227186,4.138852


In [2]:
# Settings for this exploratory run: Whisper-base backbone, mean pooling,
# 7 extra output dimensions, small unshuffled batches on CPU.
# (`use_sbert_encoder=True` stays disabled, as in the original cell.)
whisbert_settings = dict(
    whisper_model_id='openai/whisper-base',
    pooling_mode='mean',
    n_new_dims=7,
    batch_size=8,
    shuffle=False,
    device='cpu',
)
config = WhiSBERTConfig(**whisbert_settings)

# NOTE(review): the empty string is presumably "no checkpoint to restore" —
# confirm against train.load_models.
processor, whisbert, tokenizer, sbert = load_models(config, '')



In [3]:
# Fixed seed so the 10% subsample and the 80/20 split select the same items on
# every Restart & Run All; without it each kernel restart inspects different data.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)

print('Preprocessing AudioDataset...')
dataset = AudioDataset(processor)

# Keep a random 10% of the full dataset and discard the rest.
mini_size = int(0.1 * len(dataset))
drop_size = len(dataset) - mini_size
mini_dataset, _dropped = torch.utils.data.random_split(
    dataset, [mini_size, drop_size], generator=split_generator
)

# Calculate lengths for the train/val split (80:20)
total_size = len(mini_dataset)
train_size = int(0.8 * total_size)  # 80% for training
val_size = total_size - train_size  # 20% for validation

# Perform the split, drawing from the same seeded generator.
train_dataset, val_dataset = torch.utils.data.random_split(
    mini_dataset, [train_size, val_size], generator=split_generator
)

# `total_size` is the size of the 10% subsample, not of the full dataset.
print(f'\tTotal dataset size (N): {total_size}')
print(f'\tTraining dataset size (N): {train_size}')
print(f'\tValidation dataset size (N): {val_size}')

Preprocessing AudioDataset...
	Total dataset size (N): 50352
	Training dataset size (N): 40281
	Validation dataset size (N): 10071


In [4]:
# Loader over the validation split; batching options come from `config`
# and batches are assembled by the project's `collate_train`.
loader_options = {
    'batch_size': config.batch_size,
    'num_workers': config.num_workers,
    'shuffle': config.shuffle,
    'collate_fn': collate_train,
}
data_loader = torch.utils.data.DataLoader(val_dataset, **loader_options)

In [6]:
# Pull one batch and report the size of each field the later cells use.
batch = next(iter(data_loader))
audio_shape = batch['audio_inputs'].shape
n_messages = len(batch['message'])
outcomes_shape = batch['outcomes'].shape
print(audio_shape, n_messages, outcomes_shape, sep='\n')

torch.Size([8, 80, 3000])
8
torch.Size([8, 7])


In [7]:
# Tokenize the batch's messages with the SBERT tokenizer (padded, truncated,
# returned as PyTorch tensors) and move them to the configured device.
sbert_tokenizer_kwargs = dict(padding=True, truncation=True, return_tensors='pt')
encoded_input = tokenizer(batch['message'], **sbert_tokenizer_kwargs)
encoded_input = encoded_input.to(config.device)
encoded_input['input_ids'].shape

torch.Size([8, 26])

In [8]:
# Run SBERT without tracking gradients, then pool the token embeddings into
# one vector per message using the attention mask.
with torch.no_grad():
    sbert_output = sbert(**encoded_input)

token_embeddings = sbert_output.last_hidden_state
print(token_embeddings.shape)

sentence_embeddings = mean_pooling(token_embeddings, encoded_input['attention_mask'])
print(sentence_embeddings.shape)

torch.Size([8, 26, 384])
torch.Size([8, 384])


In [9]:
# Tokenize the same messages with the Whisper processor's tokenizer.
# Despite its name, `outputs` holds the tokenized *inputs* (input_ids and
# attention_mask) that are passed to WhiSBERT in the next cell.
whisper_tokenizer_kwargs = {
    'padding': True,
    'truncation': True,
    'max_length': 512,
    'return_tensors': 'pt',
}
with torch.no_grad():
    outputs = processor.tokenizer(batch['message'], **whisper_tokenizer_kwargs)
    outputs = outputs.to(config.device)

for field in ('input_ids', 'attention_mask'):
    print(outputs[field].shape)

torch.Size([8, 23])
torch.Size([8, 23])


In [10]:
# Forward the audio features and Whisper-tokenized text through WhiSBERT to get
# its pooled embedding (MEAN or LAST token, depending on the pooling mode).
# Unlike the SBERT cell, this call is not wrapped in torch.no_grad().
audio_features = batch['audio_inputs'].to(config.device)
whis_embs = whisbert(audio_features, outputs['input_ids'], outputs['attention_mask'])
whis_embs.shape

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


torch.Size([8, 391])

In [19]:
import json
from pprint import pprint

json_string = """
{'agr_mean': {(): {1: {'N': 1941,
                       'R': 0.6781856326039291,
                       'R2': 0.4599357522703914,
                       'R2_folds': 0.453639004616472,
                       'mae': 0.43537655507589523,
                       'mae_folds': 0.43537220698634294,
                       'mse': 0.3532908578549623,
                       'mse_folds': 0.35325036550347877,
                       'num_features': 384,
                       'r': 0.6787795547361397,
                       'r_folds': 0.6892055515849448,
                       'r_p': 2.5301844440603368e-262,
                       'r_p_folds': 9.948183028442736e-24,
                       'rho': 0.6955800032617787,
                       'rho_p': 9.264737934159614e-281,
                       'se_R2_folds': 0.019282732542315308,
                       'se_mae_folds': 0.00658957020554258,
                       'se_mse_folds': 0.01582590901910876,
                       'se_r_folds': 0.012218876227776765,
                       'se_r_p_folds': 6.5535193745646996e-24,
                       'se_train_mean_mae_folds': 0.011128111445277766,
                       'test_size': 195,
                       'train_mean_mae': 0.4333476080420102,
                       'train_mean_mae_folds': 0.6137855903151233,
                       'train_size': 1746,
                       '{modelFS_desc}': 'None'}}},
 'arousal_mean': {(): {1: {'N': 1941,
                           'R': 0.617598901571267,
                           'R2': 0.38142840322203553,
                           'R2_folds': 0.3490178392132953,
                           'mae': 0.054574852440665414,
                           'mae_folds': 0.05457506819011182,
                           'mse': 0.00789516409945485,
                           'mse_folds': 0.007895874935219847,
                           'num_features': 384,
                           'r': 0.6251156790734829,
                           'r_folds': 0.6417872788821157,
                           'r_p': 6.419578196009967e-211,
                           'r_p_folds': 4.1661066060463354e-19,
                           'rho': 0.6196590449478352,
                           'rho_p': 3.0170211158753247e-206,
                           'se_R2_folds': 0.021121273153522663,
                           'se_mae_folds': 0.0009046017261080018,
                           'se_mse_folds': 0.0017486554986763873,
                           'se_r_folds': 0.018111836549284596,
                           'se_r_p_folds': 3.6802273399741352e-19,
                           'se_train_mean_mae_folds': 0.0018380968091404,
                           'test_size': 195,
                           'train_mean_mae': 0.05761304522024313,
                           'train_mean_mae_folds': 0.06855474153595784,
                           'train_size': 1746,
                           '{modelFS_desc}': 'None'}}},
 'con_mean': {(): {1: {'N': 1941,
                       'R': 0.7493123684266637,
                       'R2': 0.5614690254771763,
                       'R2_folds': 0.5660593191636039,
                       'mae': 0.4458016341599674,
                       'mae_folds': 0.44577665831957436,
                       'mse': 0.40493397044206375,
                       'mse_folds': 0.4047287449266862,
                       'num_features': 384,
                       'r': 0.7501943618096237,
                       'r_folds': 0.765424274695027,
                       'r_p': 0.0,
                       'r_p_folds': 4.03134921378977e-28,
                       'rho': 0.7838881550659071,
                       'rho_p': 0.0,
                       'se_R2_folds': 0.03375709496680231,
                       'se_mae_folds': 0.009145447688730183,
                       'se_mse_folds': 0.04594218143826922,
                       'se_r_folds': 0.015399687846529345,
                       'se_r_p_folds': 3.809976898638083e-28,
                       'se_train_mean_mae_folds': 0.012997766058122941,
                       'test_size': 195,
                       'train_mean_mae': 0.5589895738388265,
                       'train_mean_mae_folds': 0.7325708337507906,
                       'train_size': 1746,
                       '{modelFS_desc}': 'None'}}},
 'ext_mean': {(): {1: {'N': 1941,
                       'R': 0.7272885481795476,
                       'R2': 0.5289486323131141,
                       'R2_folds': 0.5150440308294684,
                       'mae': 0.4419681643148561,
                       'mae_folds': 0.4419519709944864,
                       'mse': 0.3832523699673135,
                       'mse_folds': 0.38310053595384536,
                       'num_features': 384,
                       'r': 0.7315056441406073,
                       'r_folds': 0.7361770570901929,
                       'r_p': 0.0,
                       'r_p_folds': 9.893712344371962e-22,
                       'rho': 0.695890958000543,
                       'rho_p': 4.1071674040184995e-281,
                       'se_R2_folds': 0.05134779704767456,
                       'se_mae_folds': 0.010537451350460398,
                       'se_mse_folds': 0.035510759790296255,
                       'se_r_folds': 0.024589363796900878,
                       'se_r_p_folds': 9.365373296442015e-22,
                       'se_train_mean_mae_folds': 0.00956655106201722,
                       'test_size': 195,
                       'train_mean_mae': 0.5244290069724853,
                       'train_mean_mae_folds': 0.64755677469714,
                       'train_size': 1746,
                       '{modelFS_desc}': 'None'}}},
 'neu_mean': {(): {1: {'N': 1941,
                       'R': 0.6966799799391793,
                       'R2': 0.4853629944480553,
                       'R2_folds': 0.4713593460022267,
                       'mae': 0.4776359984885277,
                       'mae_folds': 0.4776401782030811,
                       'mse': 0.42887439186964965,
                       'mse_folds': 0.42881173879799783,
                       'num_features': 384,
                       'r': 0.6977851091499924,
                       'r_folds': 0.709841039617225,
                       'r_p': 2.82918116816316e-283,
                       'r_p_folds': 1.0839497129899245e-22,
                       'rho': 0.7047032133666933,
                       'rho_p': 2.5655737240011162e-291,
                       'se_R2_folds': 0.036404206874198,
                       'se_mae_folds': 0.011543887215630388,
                       'se_mse_folds': 0.02443126820996445,
                       'se_r_folds': 0.018361671045662352,
                       'se_r_p_folds': 1.0101732389789841e-22,
                       'se_train_mean_mae_folds': 0.01718468826074042,
                       'test_size': 195,
                       'train_mean_mae': 0.508786044043443,
                       'train_mean_mae_folds': 0.6880450788033405,
                       'train_size': 1746,
                       '{modelFS_desc}': 'None'}}},
 'ope_mean': {(): {1: {'N': 1941,
                       'R': 0.687173454147489,
                       'R2': 0.4722073560849911,
                       'R2_folds': 0.4673015827091218,
                       'mae': 0.4772704839058447,
                       'mae_folds': 0.47727723957952134,
                       'mse': 0.41368918346585476,
                       'mse_folds': 0.41370495149217523,
                       'num_features': 384,
                       'r': 0.6902178930977013,
                       'r_folds': 0.6941406897230195,
                       'r_p': 9.736823034412e-275,
                       'r_p_folds': 1.7709972322454935e-21,
                       'rho': 0.6457935165027847,
                       'rho_p': 1.6895507852241627e-229,
                       'se_R2_folds': 0.022460942187813547,
                       'se_mae_folds': 0.011904136663584926,
                       'se_mse_folds': 0.02532620779800574,
                       'se_r_folds': 0.015713329261864323,
                       'se_r_p_folds': 1.6794999944648988e-21,
                       'se_train_mean_mae_folds': 0.010627934853089976,
                       'test_size': 195,
                       'train_mean_mae': 0.49347467540502726,
                       'train_mean_mae_folds': 0.6282224203465161,
                       'train_size': 1746,
                       '{modelFS_desc}': 'None'}}},
 'pcl_score': {(): {1: {'N': 1407,
                        'R': 0.3432733230392371,
                        'R2': 0.1178365743104004,
                        'R2_folds': 0.106078036191816,
                        'mae': 8.219562694190266,
                        'mae_folds': 8.226491548005512,
                        'mse': 118.95726581099895,
                        'mse_folds': 119.13272828915247,
                        'num_features': 384,
                        'r': 0.3451403353677704,
                        'r_folds': 0.34873488175121886,
                        'r_p': 1.2335338379227026e-40,
                        'r_p_folds': 0.0008440256643428273,
                        'rho': 0.35432841464121084,
                        'rho_p': 7.011310494289468e-43,
                        'se_R2_folds': 0.01957191336658385,
                        'se_mae_folds': 0.1524234309824447,
                        'se_mse_folds': 5.220218239017663,
                        'se_r_folds': 0.023686210413116266,
                        'se_r_p_folds': 0.0005774695994057005,
                        'se_train_mean_mae_folds': 0.1334656746805633,
                        'test_size': 139,
                        'train_mean_mae': 3.415169929989724,
                        'train_mean_mae_folds': 8.9465991318074,
                        'train_size': 1268,
                        '{modelFS_desc}': 'None'}}},
 'valence_mean': {(): {1: {'N': 1941,
                           'R': 0.7302928590327442,
                           'R2': 0.5333276599542196,
                           'R2_folds': 0.5158404208074534,
                           'mae': 0.06420540013049159,
                           'mae_folds': 0.0642013774872571,
                           'mse': 0.008094708544819768,
                           'mse_folds': 0.008092333743545396,
                           'num_features': 384,
                           'r': 0.7307417301811368,
                           'r_folds': 0.7248172388175128,
                           'r_p': 1e-323,
                           'r_p_folds': 1.2995976901915584e-27,
                           'rho': 0.7278030592756253,
                           'rho_p': 5.832e-320,
                           'se_R2_folds': 0.020893744746816673,
                           'se_mae_folds': 0.0013748976450061688,
                           'se_mse_folds': 0.0006252243906719673,
                           'se_r_folds': 0.015197630855852347,
                           'se_r_p_folds': 1.1542791194444408e-27,
                           'se_train_mean_mae_folds': 0.0021706548075603686,
                           'test_size': 195,
                           'train_mean_mae': 0.07123230177174238,
                           'train_mean_mae_folds': 0.09731299072758595,
                           'train_size': 1746,
                           '{modelFS_desc}': 'None'}}}}
""".replace('()', '\'1\'').replace('{1:', '{\'1\':').replace('\'', '\"').strip()

dlatk_dict = json.loads(json_string)
pprint(dlatk_dict)

{'agr_mean': {'1': {'1': {'N': 1941,
                          'R': 0.6781856326039291,
                          'R2': 0.4599357522703914,
                          'R2_folds': 0.453639004616472,
                          'mae': 0.43537655507589523,
                          'mae_folds': 0.43537220698634294,
                          'mse': 0.3532908578549623,
                          'mse_folds': 0.35325036550347877,
                          'num_features': 384,
                          'r': 0.6787795547361397,
                          'r_folds': 0.6892055515849448,
                          'r_p': 2.5301844440603368e-262,
                          'r_p_folds': 9.948183028442736e-24,
                          'rho': 0.6955800032617787,
                          'rho_p': 9.264737934159614e-281,
                          'se_R2_folds': 0.019282732542315308,
                          'se_mae_folds': 0.00658957020554258,
                          'se_mse_folds': 0.01582590901910876,


In [25]:
# Print Pearson r and MSE for each outcome, followed by a blank line.
for outcome, nested in dlatk_dict.items():
    stats = nested['1']['1']
    print(
        f'Outcome: {outcome}',
        f'r: {stats["r"]}',
        f'mse: {stats["mse"]}',
        '',
        sep='\n',
    )

Outcome: agr_mean
r: 0.6787795547361397
mse: 0.3532908578549623

Outcome: arousal_mean
r: 0.6251156790734829
mse: 0.00789516409945485

Outcome: con_mean
r: 0.7501943618096237
mse: 0.40493397044206375

Outcome: ext_mean
r: 0.7315056441406073
mse: 0.3832523699673135

Outcome: neu_mean
r: 0.6977851091499924
mse: 0.42887439186964965

Outcome: ope_mean
r: 0.6902178930977013
mse: 0.41368918346585476

Outcome: pcl_score
r: 0.3451403353677704
mse: 118.95726581099895

Outcome: valence_mean
r: 0.7307417301811368
mse: 0.008094708544819768

