ae-foster/cresp


On Contrastive Representations of Stochastic Processes

License: MIT | Python 3.8+

This project is based on PyTorch, Hydra and PyTorch Lightning.

Project organisation

config/ -> Hydra project configurations

src/ -> Everything that relates to models

utils/ -> Helper functions

experiments/ -> Where experiments are saved (logs, checkpoints, configs, etc.)

Install

virtualenv -p python3.8 venv
source venv/bin/activate
pip install -r requirements.txt
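After installing, you can run a quick sanity check that confirms the interpreter is at least Python 3.8 and that the three frameworks named above are importable. This is a minimal sketch. It assumes the top-level module names `torch`, `hydra` and `pytorch_lightning` (check `requirements.txt` for the exact packages), and `check_environment` is a hypothetical helper that is not part of this repository.

```python
import importlib.util
import sys

# Assumed top-level module names for PyTorch, Hydra and PyTorch Lightning.
REQUIRED_MODULES = ("torch", "hydra", "pytorch_lightning")


def check_environment(modules=REQUIRED_MODULES, min_python=(3, 8)):
    """Return a dict saying whether the interpreter and each module are usable."""
    report = {"python_ok": sys.version_info[:2] >= min_python}
    for name in modules:
        # find_spec locates a top-level module without importing it; None means missing.
        report[name] = importlib.util.find_spec(name) is not None
    return report


if __name__ == "__main__":
    for key, ok in check_environment().items():
        print(f"{key:20s} {'OK' if ok else 'MISSING'}")
```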

Experiments

Sinusoids

  • Run the different models with:
python main.py -m +experiment=sine model=ssl,cnp self_attn=id,on  # Untargeted CReSP, FCLR, ANCP, CNP
  • Figure 2b results can be obtained by varying the distance between the modes:
python main.py -m +experiment=sine model=cnp,ssl self_attn=on dataset.eps=0.0,0.5,1.,2.,5.,8.,10
  • Figure 2c results can be obtained by varying the number of training views:
python main.py -m +experiment=sine model=cnp,ssl self_attn=on n_views_train=2,5,10,20,50
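With `-m` (multirun), Hydra launches one run for each combination of the comma-separated override values, i.e. their Cartesian product. So `model=ssl,cnp self_attn=id,on` gives 4 runs, and the Figure 2b sweep gives 2 x 7 = 14. The sketch below reproduces that expansion so you can count runs before launching them. `expand_sweep` is a hypothetical helper, not part of Hydra or this repository. It only splits on commas and ignores Hydra's richer sweep syntax (`range(...)`, quoting, etc.). Rejecting a repeated key is a choice made by this sketch, not a description of Hydra's behaviour.

```python
import itertools


def expand_sweep(overrides):
    """Expand 'key=v1,v2,...' overrides into one dict per run (Cartesian product).

    Simplified model of a multirun sweep: values are split on commas only.
    """
    keys, value_lists = [], []
    for override in overrides:
        key, _, values = override.partition("=")
        if key in keys:
            # This sketch refuses to guess which of two values for one key was meant.
            raise ValueError(f"duplicate override for {key!r}")
        keys.append(key)
        value_lists.append(values.split(","))
    return [dict(zip(keys, combo)) for combo in itertools.product(*value_lists)]


# Overrides from the Figure 2b command: 2 models x 1 self_attn x 7 values of dataset.eps.
runs = expand_sweep(["model=cnp,ssl", "self_attn=on",
                     "dataset.eps=0.0,0.5,1.,2.,5.,8.,10"])
print(len(runs))  # 14
print(runs[0])    # {'model': 'cnp', 'self_attn': 'on', 'dataset.eps': '0.0'}
```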

ShapeNet

  • Download and decompress the R2N2 dataset:
wget http://cvgl.stanford.edu/data2/ShapeNetRendering.tgz -P data
tar zxvf data/ShapeNetRendering.tgz -C data
  • Run the different models with:
python main.py +experiment=shapenet model=ssl targeted=True  # Targeted CReSP
python main.py -m +experiment=shapenet model=ssl,cnp self_attn=id,on  # Untargeted CReSP, FCLR, ANCP, CNP
python main.py +experiment=shapenet model=sup fix_clf_train=True  # Supervised
  • Figure 4a and 4b results can be obtained by varying the colour distortion strength:
python main.py -m +experiment=nocolour self_attn=on,id targeted=False,True model=ssl,cnp
python main.py -m +experiment=shapenet self_attn=on,id targeted=False,True model=ssl,cnp dataset.distortion_s=0.5,1.0,1.5
  • Figure 4c results can be obtained by varying the number of training views:
python main.py -m +experiment=shapenet targeted=False self_attn=on,id model=ssl n_views_train=6,12,24
  • Figure 5a results can be obtained by varying the fraction of labels available:
python main.py -m +experiment=shapenet targeted=False model=ssl,cnp clf.prop=0.01,0.02,0.04,0.1,0.2,0.4,1.0
  • Figure 5b results can be obtained by varying the number of test views:
python main.py -m +experiment=shapenet targeted=False model=ssl,cnp n_views_test=1,2,4,10,20
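If `wget` or `tar` is not available, the download and extraction step can be reproduced with the Python standard library. This is a sketch: the URL is copied from the `wget` command above, the `data` directory mirrors `-P data` / `-C data`, and `fetch_and_extract` is a hypothetical helper that is not part of this repository.

```python
import tarfile
import urllib.parse
import urllib.request
from pathlib import Path

# URL copied from the wget command in the ShapeNet instructions.
R2N2_URL = "http://cvgl.stanford.edu/data2/ShapeNetRendering.tgz"


def fetch_and_extract(url=R2N2_URL, data_dir="data"):
    """Download a .tgz archive into data_dir and unpack it there.

    Mirrors `wget URL -P data` followed by `tar zxvf ARCHIVE -C data`.
    An existing archive is reused, so delete it after an interrupted download.
    """
    data_dir = Path(data_dir)
    data_dir.mkdir(parents=True, exist_ok=True)
    # Name the local file after the last path component of the URL.
    archive = data_dir / Path(urllib.parse.urlparse(url).path).name
    if not archive.exists():
        urllib.request.urlretrieve(url, archive)
    with tarfile.open(archive, "r:gz") as tar:
        # Use the safer "data" extraction filter where this Python version provides it.
        if hasattr(tarfile, "data_filter"):
            tar.extractall(data_dir, filter="data")
        else:
            tar.extractall(data_dir)
    return data_dir
```

Calling `fetch_and_extract()` from the repository root downloads the archive into `data/` and unpacks it next to it, like the two shell commands.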

Snooker

  • Table 3 results can be obtained with:
python main.py +experiment=snooker model=cnp  # CNP
python main.py +experiment=snooker targeted=false  # Untargeted CReSP
python main.py +experiment=snooker  # Targeted CReSP
python main.py +experiment=snooker agg=kernel enc=simple  # MetaCDE

Miscellaneous

Logging

The logger can be selected with logger=tensorboard; by default, logger=csv is used.

  • TensorBoard: run the following and, when working on a remote machine, forward its port (6006 by default) to view live logs:
tensorboard --logdir experiments/
  • CSV: metrics are additionally saved under .../NAME_OF_EXPERIMENT/RUN_ID/logs/metrics.csv
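To compare runs without TensorBoard, the CSV logs can be collected with a few lines of Python. This sketch assumes each `metrics.csv` has a header row of metric names with empty cells where a metric was not logged at that step (the layout written by PyTorch Lightning's CSV logger). The metric name `val_loss` is a hypothetical placeholder, so check the header of your own file. `last_logged` and `collect` are illustrative helpers, not part of this repository.

```python
import csv
from pathlib import Path


def last_logged(metrics_csv, metric):
    """Return the last non-empty value of `metric` in a metrics.csv file, or None."""
    last = None
    with open(metrics_csv, newline="") as f:
        for row in csv.DictReader(f):
            value = row.get(metric, "")
            # Skip rows where this metric was not logged (empty or missing cell).
            if value not in ("", None):
                last = float(value)
    return last


def collect(root="experiments", metric="val_loss"):
    """Map each run directory (the parent of logs/) to the last value of `metric`."""
    results = {}
    for path in sorted(Path(root).rglob("logs/metrics.csv")):
        results[str(path.parent.parent)] = last_logged(path, metric)
    return results
```

For example, `collect("experiments", "val_loss")` returns one entry per `RUN_ID` directory, which is convenient for tabulating the sweeps above.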

About

Code for 'On Contrastive Representations of Stochastic Processes' https://arxiv.org/abs/2106.10052
