vzantedeschi/LatentTrees

Learning Binary Trees by Argmin Differentiation

Source code for the ICML 2021 paper Learning Binary Trees by Argmin Differentiation.

Dependencies

Install PyTorch by following the official installation guidelines.

On Ubuntu 16.04 and later, make sure that libstdc++.so.6 provides GLIBCXX_3.4.22:

sudo add-apt-repository ppa:ubuntu-toolchain-r/test
sudo apt-get update
sudo apt-get install gcc-4.9
sudo apt-get upgrade libstdc++6
sudo apt-get dist-upgrade
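To confirm that the GLIBCXX_3.4.22 requirement is met after these steps, the usual check is `strings` on the library piped into `grep GLIBCXX`. The Python sketch below does the equivalent byte search. The library path is an assumption (the standard location on 64-bit Ubuntu) and may differ on your system; `has_glibcxx` is a hypothetical helper, not part of the repository.

```python
import os

# Assumed location of libstdc++ on 64-bit Ubuntu; adjust for other layouts.
DEFAULT_LIBSTDCXX = "/usr/lib/x86_64-linux-gnu/libstdc++.so.6"


def has_glibcxx(version="GLIBCXX_3.4.22", path=DEFAULT_LIBSTDCXX):
    """Return True if the shared library at `path` contains the version tag.

    Symbol-version names are stored as NUL-terminated strings inside the
    ELF file, so a plain byte search behaves like `strings ... | grep ...`.
    Appending the NUL byte avoids matching a longer tag by prefix
    (e.g. GLIBCXX_3.4.2 inside GLIBCXX_3.4.22).
    """
    if not os.path.exists(path):
        return False
    needle = version.encode("ascii") + b"\x00"
    with open(path, "rb") as f:
        return needle in f.read()


if __name__ == "__main__":
    print("GLIBCXX_3.4.22 available:", has_glibcxx())
```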

Building the C++ extensions requires gcc-9 or later:

sudo apt install gcc-9
sudo apt install g++-9
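A quick way to check that a compiler on your PATH satisfies the gcc-9 requirement is to parse the output of `gcc -dumpversion`, which prints either a bare major version (e.g. `9`) or a dotted one (e.g. `4.9.3`) depending on the GCC build. The helpers below are a hypothetical sketch, not part of the repository; pass `compiler="g++-9"` or similar to check a specific binary.

```python
import re
import shutil
import subprocess


def parse_major(version):
    """Extract the major version from strings like '9', '9.4.0' or '4.9.3'."""
    match = re.match(r"\s*(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1))


def compiler_is_recent_enough(compiler="gcc", minimum=9):
    """Return True if `compiler -dumpversion` reports a major version >= minimum.

    Returns False when the compiler is not found on PATH.
    """
    if shutil.which(compiler) is None:
        return False
    result = subprocess.run(
        [compiler, "-dumpversion"], capture_output=True, text=True, check=True
    )
    return parse_major(result.stdout) >= minimum


if __name__ == "__main__":
    print("gcc >= 9:", compiler_is_recent_enough("gcc"))
    print("g++ >= 9:", compiler_is_recent_enough("g++"))
```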

Plotting with NetworkX requires the following system packages:

sudo apt-get install python3-dev graphviz libgraphviz-dev pkg-config
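These packages are what is typically needed to build pygraphviz, which NetworkX uses for Graphviz layouts through its `nx_agraph` module. Assuming plotting goes through that route (the original does not say which backend is used), the hypothetical sketch below reports which pieces are present: the Python modules, Graphviz's `dot` binary, and `pkg-config`.

```python
import importlib.util
import shutil


def plotting_prerequisites():
    """Report which graph-plotting prerequisites are available.

    Assumes plots are drawn via pygraphviz (NetworkX's nx_agraph backend),
    which needs Graphviz and its headers at build time.
    """
    return {
        "networkx": importlib.util.find_spec("networkx") is not None,
        "pygraphviz": importlib.util.find_spec("pygraphviz") is not None,
        "dot": shutil.which("dot") is not None,  # Graphviz layout program
        "pkg-config": shutil.which("pkg-config") is not None,
    }


if __name__ == "__main__":
    for name, ok in plotting_prerequisites().items():
        print(f"{name:12s} {'ok' if ok else 'MISSING'}")
```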

Setup

pip3 install -r requirements.txt
CXX=gcc python3 setup.py build_ext --inplace
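The `CXX=gcc` prefix sets the C++ compiler only for that single command. If the default `gcc` on your system is older than version 9, pointing CXX at `gcc-9` may be necessary (an assumption, not stated in the original instructions). The hypothetical helper below runs the same build from Python with a configurable compiler, using the current interpreter in place of `python3`; `dry_run=True` only returns the command and environment without executing anything.

```python
import os
import subprocess
import sys


def build_extensions(cxx="gcc", dry_run=False):
    """Equivalent of `CXX=<cxx> python3 setup.py build_ext --inplace`.

    CXX is set only in the child process environment, like the shell
    prefix form. Returns the command list and the environment used.
    """
    cmd = [sys.executable, "setup.py", "build_ext", "--inplace"]
    env = dict(os.environ, CXX=cxx)
    if not dry_run:
        # Must be run from the repository root, where setup.py lives.
        subprocess.run(cmd, check=True, env=env)
    return cmd, env


if __name__ == "__main__":
    cmd, env = build_extensions("gcc-9", dry_run=True)
    print("CXX=" + env["CXX"], " ".join(cmd))
```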

Train on toy datasets

python3 fit_toyset.py
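The default configuration file is named `default-xor.yaml`, which suggests a 2-D XOR-style task. As a rough, hypothetical picture of such a toy set (not the repository's own generator), the sketch below samples `n` points uniformly in [-1, 1] x [-1, 1] and labels a point 1 when its two coordinates have opposite signs. Here `n` plays the role that `dataset.N` appears to play in the command-line override example; `noise` and `seed` are invented parameters.

```python
import random


def make_xor(n, noise=0.0, seed=0):
    """Hypothetical XOR toy set: label is 1 iff the coordinates differ in sign.

    Labels are computed from the clean coordinates; optional Gaussian noise
    (standard deviation `noise`) is added to the returned features only.
    """
    rng = random.Random(seed)
    X, y = [], []
    for _ in range(n):
        x0, x1 = rng.uniform(-1, 1), rng.uniform(-1, 1)
        X.append((x0 + rng.gauss(0, noise), x1 + rng.gauss(0, noise)))
        y.append(int((x0 > 0) != (x1 > 0)))
    return X, y


if __name__ == "__main__":
    X, y = make_xor(1000)
    print(len(X), "points,", sum(y), "labelled 1")
```

A single axis-aligned split cannot separate these classes, which is what makes XOR a common sanity check for tree models.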

The default configuration is stored in 'config/default-xor.yaml'. You can edit the config file directly or override values from the command line, for example:

python3 fit_toyset.py dataset.N=1000 model.SPLIT=linear

See the Hydra documentation for a tutorial on configuration files and command-line overrides.
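Overrides such as `dataset.N=1000` address nested config keys with dot-separated paths. The following is a simplified, pure-Python illustration of that idea, not Hydra's actual parser: it walks the dotted path, converts the value to a bool, None, int, float or string, and refuses keys that do not already exist (Hydra likewise requires a `+` prefix to add a new key). The starting config values are invented placeholders; the real defaults are in `config/default-xor.yaml`.

```python
import copy


def parse_value(raw):
    """Simplified scalar parsing: bool, null, int, float, otherwise string."""
    lowered = raw.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    if lowered == "null":
        return None
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            pass
    return raw


def apply_overrides(config, overrides):
    """Return a copy of nested dict `config` with 'a.b=value' overrides applied."""
    cfg = copy.deepcopy(config)
    for item in overrides:
        key, sep, raw = item.partition("=")
        if not sep:
            raise ValueError(f"override must look like key=value: {item!r}")
        *parents, leaf = key.split(".")
        node = cfg
        for part in parents:
            node = node[part]  # raises KeyError for an unknown group
        if leaf not in node:
            # Hydra likewise rejects new keys unless written as +key=value.
            raise KeyError(f"unknown config key: {key}")
        node[leaf] = parse_value(raw)
    return cfg


if __name__ == "__main__":
    # Placeholder values for illustration only, not the real defaults.
    defaults = {"dataset": {"N": 100}, "model": {"SPLIT": "placeholder"}}
    print(apply_overrides(defaults, ["dataset.N=1000", "model.SPLIT=linear"]))
    # {'dataset': {'N': 1000}, 'model': {'SPLIT': 'linear'}}
```

Copying the config first keeps the defaults untouched, so the same base config can be reused with different override sets.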

Citation

  @article{zantedeschi2021learning,
    title={Learning Binary Trees by Argmin Differentiation},
    author={Zantedeschi, Valentina and Kusner, Matt J and Niculae, Vlad},
    journal={ICML},
    year={2021}
  }
