This project is described in our paper titled "Relational Visual Information explains Human Social Inference: A Graph Neural Network model for Social Interaction Recognition". Link to Preprint
This repository contains code to train and test our SocialGNN models (as well as baseline VisualRNN models) on animated and natural stimuli.
Note: the VisualRNN models are named CueBasedLSTM in the code.
Conda environment files:
- For macOS: Conda Environments/condaenv_macbook_gnnEnv_Oct17_2022.yml
- For Linux: Conda Environments/condaenv_rockfish_gnnEnv_Jul1423.yml
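As a convenience, the environment files listed above can be selected automatically for the current OS. This is a minimal sketch: `conda_create_command` is a hypothetical helper (not part of the repository), and it only builds the standard `conda env create -f <file>` command. The environment name is defined inside each .yml file (presumably gnnEnv).

```python
import platform
import shlex

# Environment files listed above, keyed by the value of platform.system().
ENV_FILES = {
    "Darwin": "Conda Environments/condaenv_macbook_gnnEnv_Oct17_2022.yml",  # macOS
    "Linux": "Conda Environments/condaenv_rockfish_gnnEnv_Jul1423.yml",
}


def conda_create_command(system=None):
    """Hypothetical helper: return the `conda env create` argv for an OS (default: current OS)."""
    system = system or platform.system()
    if system not in ENV_FILES:
        raise ValueError(f"No conda environment file is listed for {system!r}")
    return ["conda", "env", "create", "-f", ENV_FILES[system]]


if __name__ == "__main__":
    # shlex.join quotes the path, which contains a space ("Conda Environments").
    print(shlex.join(conda_create_command()))
```

Paste the printed command into a terminal, then activate the environment by the name given in the .yml file.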
Please refer to our paper for the terminology used here and for changes in parameter settings.
Testing on the PHASE main set (accuracy and predictions):
python get_accuracy_predictions_PHASE_mainset.py --model_name=SocialGNN_E --train_datetime=20230503 --context_info=True --bootstrap_no=0 --save_predictions=False
Options:
--model_name: SocialGNN_V / SocialGNN_E / CueBasedLSTM / CueBasedLSTM-Relation / SocialGNN_V_onlyagents / SocialGNN_E_onlyagentedges
--train_datetime: 20230503 (use 20230617 for SocialGNN_V_onlyagents and SocialGNN_E_onlyagentedges)
--bootstrap_no: 0 to 9
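To evaluate one model on all 10 bootstrap splits, the command above can be generated in a loop. This is a sketch under assumptions: `mainset_test_commands` and `run_all` are hypothetical helpers, and the only flag values used are the ones documented above (model names, the two training dates, bootstraps 0 to 9). It defaults to a dry run that just prints the commands.

```python
import shlex
import subprocess
import sys

MODELS = {
    "SocialGNN_V", "SocialGNN_E", "CueBasedLSTM", "CueBasedLSTM-Relation",
    "SocialGNN_V_onlyagents", "SocialGNN_E_onlyagentedges",
}
# These two variants use a different --train_datetime (see the options above).
TRAINED_20230617 = {"SocialGNN_V_onlyagents", "SocialGNN_E_onlyagentedges"}


def mainset_test_commands(model_name, context_info=True, save_predictions=False):
    """Hypothetical helper: one argv per bootstrap split (0-9) for the main-set test script."""
    if model_name not in MODELS:
        raise ValueError(f"Unknown model name: {model_name!r}")
    train_datetime = "20230617" if model_name in TRAINED_20230617 else "20230503"
    return [
        [sys.executable, "get_accuracy_predictions_PHASE_mainset.py",
         f"--model_name={model_name}", f"--train_datetime={train_datetime}",
         f"--context_info={context_info}", f"--bootstrap_no={b}",
         f"--save_predictions={save_predictions}"]
        for b in range(10)
    ]


def run_all(model_name, dry_run=True):
    """Print (dry run) or execute the test command for every bootstrap split."""
    for cmd in mainset_test_commands(model_name):
        if dry_run:
            print(shlex.join(cmd))
        else:
            subprocess.run(cmd, check=True)  # stop on the first failing split
```

For example, `run_all("SocialGNN_E", dry_run=False)` runs the 10 evaluations sequentially from the repository root.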
Training on the PHASE main set (bootstrap splits):
python traintest_bootstrapsplits_PHASE_mainset.py --model_name=SocialGNN_V --context_info=True
Testing and training on the PHASE generalization set:
python traintest_PHASE_genset.py --mode=test --model_name=SocialGNN_E --context_info=True
python traintest_PHASE_genset.py --mode=train --model_name=SocialGNN_E --context_info=True
Extracting SocialGNN activations (main set, then generalization set):
python SocialGNN_get_activations.py --model_name=SocialGNN_E --context_info=True --bootstrap_no=0 --dataset=main_set --train_datetime=20230503 --activation_type=RNN
python SocialGNN_get_activations.py --model_name=SocialGNN_E --context_info=True --dataset=generalization_set --train_datetime=20230515 --activation_type=RNN
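The two activation commands differ in which flags they need: the main set takes a --bootstrap_no and models trained on 20230503, while the generalization set takes no bootstrap and uses 20230515. A minimal sketch of that rule follows; `activation_command` is a hypothetical helper, and its default dates are copied from the SocialGNN_E examples above, so other models may need an explicit `train_datetime`.

```python
import sys

# Training dates taken from the SocialGNN_E examples above (assumption for other models).
DEFAULT_TRAIN_DATETIME = {"main_set": "20230503", "generalization_set": "20230515"}


def activation_command(dataset, bootstrap_no=0, model_name="SocialGNN_E",
                       activation_type="RNN", train_datetime=None):
    """Hypothetical helper: build the SocialGNN_get_activations.py argv for one dataset."""
    if dataset not in DEFAULT_TRAIN_DATETIME:
        raise ValueError(f"Unknown dataset: {dataset!r}")
    cmd = [sys.executable, "SocialGNN_get_activations.py",
           f"--model_name={model_name}", "--context_info=True"]
    if dataset == "main_set":
        # Only the main set is split into bootstraps.
        cmd.append(f"--bootstrap_no={bootstrap_no}")
    cmd += [f"--dataset={dataset}",
            f"--train_datetime={train_datetime or DEFAULT_TRAIN_DATETIME[dataset]}",
            f"--activation_type={activation_type}"]
    return cmd
```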
Note: the motion energy files must be downloaded from the OSF folder into 'Activations'.
Representational similarity analysis (RSA): RSA_Github.ipynb
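For readers unfamiliar with RSA, the core computation can be sketched with NumPy alone: build a representational dissimilarity matrix (RDM) for each feature space (for example, model activations and motion energy features, one row per video) and correlate their upper triangles. This is an illustrative sketch, not the notebook's code: the choice of 1 - Pearson r as the dissimilarity and Spearman correlation for the comparison are common conventions assumed here, and the function names are hypothetical.

```python
import numpy as np


def rdm(features):
    """RDM: 1 - Pearson correlation between rows (one row per stimulus)."""
    return 1.0 - np.corrcoef(features)


def upper_triangle(m):
    """Off-diagonal upper-triangle entries of a square matrix."""
    i, j = np.triu_indices(m.shape[0], k=1)
    return m[i, j]


def _ranks(x):
    # Rank transform without tie handling; ties are unlikely for continuous dissimilarities.
    r = np.empty(len(x))
    r[np.argsort(x)] = np.arange(len(x))
    return r


def rsa_score(features_a, features_b):
    """Spearman correlation between the upper triangles of two RDMs."""
    a = _ranks(upper_triangle(rdm(features_a)))
    b = _ranks(upper_triangle(rdm(features_b)))
    return float(np.corrcoef(a, b)[0, 1])
```

Both inputs must have the same number of rows (stimuli), in the same order; the number of columns (features) can differ.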
Gaze dataset:
Set <prediction_type> to 2 for social vs. non-social classification, or to 5 to classify into the 5 gaze labels. The first 10 bootstraps correspond to <dataset>=5Jun23 and the next 10 to <dataset>=14Jun23.
Test:
python traintest_bootstrapsplits_Gaze.py test <model_to_test> <prediction_type> <dataset>
Example:
python traintest_bootstrapsplits_Gaze.py test CueBasedLSTM-Relation 5 5Jun23
Train:
python traintest_bootstrapsplits_Gaze.py train <model_to_train> <prediction_type> <dataset>
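The gaze script takes positional arguments, so the rules above (valid prediction types, and which dataset a given bootstrap belongs to) are easy to get wrong by hand. A minimal sketch, assuming 20 bootstraps numbered 0 to 19 in the order described above; `gaze_dataset_for_bootstrap` and `gaze_command` are hypothetical helpers, not part of the repository.

```python
import sys

# <prediction_type> values documented above.
PREDICTION_TYPES = {2: "social vs. non-social", 5: "5 gaze labels"}
DATASETS = ("5Jun23", "14Jun23")


def gaze_dataset_for_bootstrap(bootstrap_index):
    """Bootstraps 0-9 -> '5Jun23', 10-19 -> '14Jun23' (0-based numbering assumed)."""
    if not 0 <= bootstrap_index < 20:
        raise ValueError(f"Bootstrap index out of range: {bootstrap_index}")
    return DATASETS[0] if bootstrap_index < 10 else DATASETS[1]


def gaze_command(mode, model, prediction_type, dataset):
    """Hypothetical helper: argv for traintest_bootstrapsplits_Gaze.py (positional args)."""
    if mode not in ("train", "test"):
        raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")
    if prediction_type not in PREDICTION_TYPES:
        raise ValueError(f"prediction_type must be one of {sorted(PREDICTION_TYPES)}")
    if dataset not in DATASETS:
        raise ValueError(f"dataset must be one of {DATASETS}")
    return [sys.executable, "traintest_bootstrapsplits_Gaze.py",
            mode, model, str(prediction_type), dataset]
```

For instance, `gaze_command("test", "CueBasedLSTM-Relation", 5, "5Jun23")` reproduces the example command above.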
VGG19 (gaze dataset):
python VGG19full_traintest_gaze.py --mode=test --dataset=5Jun23 --output_type=2
Note 1: you may need to run this outside the gnnEnv conda environment.
Note 2: to rerun everything, download the original .pik files from the PHASE dataset.
Generating the plots: SocialGNN_Generating_Plots_Github.ipynb