
JRC1995/BeamTreeRecursiveCells


Official Code for Beam Tree Recursive Cells (ICML 2023)

ArXiv Link

For an extended repository, see: https://github.com/JRC1995/BeamRecursionFamily/tree/main

Credits:

Requirements

  • torch==1.10.0
  • tqdm==4.62.3
  • jsonlines==2.0.0
  • torchtext==0.8.1
  • ninja==1.10.2
  • typing-extensions==4.5.0
  • psutil==5.8.0
  • tensorflow-datasets==4.5.2
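The pins above are exact, so comparing them against the installed packages can catch a mismatched environment. The following is a minimal sketch that uses only the standard library's `importlib.metadata`. The name `find_mismatches` is hypothetical and not part of this repository. The sketch ignores local version labels, such as the `+cu113` suffix that CUDA builds of torch report.

```python
from importlib.metadata import PackageNotFoundError, version

# Exact pins transcribed from the Requirements list of this README.
PINS = {
    "torch": "1.10.0",
    "tqdm": "4.62.3",
    "jsonlines": "2.0.0",
    "torchtext": "0.8.1",
    "ninja": "1.10.2",
    "typing-extensions": "4.5.0",
    "psutil": "5.8.0",
    "tensorflow-datasets": "4.5.2",
}


def find_mismatches(pins, lookup=version):
    """Return {package: (expected, installed)} for every pin that is not met.

    `installed` is None when the package is missing. Local version labels
    (the part after "+", e.g. "1.10.0+cu113") are ignored when comparing.
    """
    mismatches = {}
    for name, expected in pins.items():
        try:
            installed = lookup(name)
        except PackageNotFoundError:
            installed = None
        public = installed.split("+", 1)[0] if installed else None
        if public != expected:
            mismatches[name] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    problems = find_mismatches(PINS)
    for name, (expected, installed) in problems.items():
        print(f"{name}: expected {expected}, found {installed or 'not installed'}")
    if not problems:
        print("All pinned versions match.")
```

The `lookup` parameter exists only so the check can be exercised without touching the real environment. By default, it queries the installed distributions.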

Data Setup

You can verify that the data is set up correctly by comparing it against the directory tree here.

Processing Data

  • Go to preprocess/ and run each preprocessing script to process the corresponding data. The scripts can be run in any order, except that process_SNLI_addon.py must be run after process_SNLI.py.
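The ordering rule can be automated with a small runner. This sketch is not part of the repository, and both `order_scripts` and `run_all` are hypothetical names. It assumes that every `.py` file in `preprocess/` (other than dunder files such as `__init__.py`) is a standalone script that takes no arguments and expects to be run from inside that directory.

```python
import subprocess
import sys
from pathlib import Path


def order_scripts(names):
    """Sort script names, keeping process_SNLI_addon.py right after process_SNLI.py.

    Names starting with "__" (e.g. __init__.py) are skipped.
    """
    ordered = sorted(n for n in names if not n.startswith("__"))
    first, addon = "process_SNLI.py", "process_SNLI_addon.py"
    if first in ordered and addon in ordered:
        # Enforce the only documented ordering constraint explicitly.
        ordered.remove(addon)
        ordered.insert(ordered.index(first) + 1, addon)
    return ordered


def run_all(preprocess_dir="preprocess"):
    """Run each script from inside preprocess_dir, stopping at the first failure."""
    scripts = order_scripts(p.name for p in Path(preprocess_dir).glob("*.py"))
    for name in scripts:
        print(f"Running {name} ...")
        # check=True raises CalledProcessError if a script exits with a non-zero status.
        subprocess.run([sys.executable, name], cwd=preprocess_dir, check=True)
```

Call `run_all()` from the outermost project directory. If `preprocess/` also contains helper modules that are not meant to be executed, filter them out before running.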

We share some of the processed data, with its exact splits, here. Put the processed_data folder in the outermost project directory.

How to train

Train: python train.py --model=[insert model name] --dataset=[insert dataset name] --times=[insert total runs] --device=[insert device name] --model_type=[classifier/sentence_pair]

  • Check argparser.py for exact options.
  • --model_type=sentence_pair is for sentence-matching tasks such as NLI; --model_type=classifier is for simple sentence classification tasks.
  • We generally set --times=3 (three total runs).
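When you launch many runs, it can help to build the training command programmatically. This is a minimal sketch. `train_command` is a hypothetical helper, the script name is assumed to be `train.py`, and the `cuda:0` device string is an assumed format (check argparser.py for what `--device` actually accepts). It only assembles the flags shown in the training command and validates `--model_type`.

```python
import shlex

MODEL_TYPES = ("classifier", "sentence_pair")


def train_command(model, dataset, device="cuda:0", times=3, model_type="classifier"):
    """Return the argv list for one training invocation (defaults to 3 total runs)."""
    if model_type not in MODEL_TYPES:
        raise ValueError(f"model_type must be one of {MODEL_TYPES}, got {model_type!r}")
    return [
        "python", "train.py",
        f"--model={model}",
        f"--dataset={dataset}",
        f"--times={times}",
        f"--device={device}",
        f"--model_type={model_type}",
    ]


if __name__ == "__main__":
    # BT-GRC (beam 5) on ListOps (single-sequence classification) ...
    print(shlex.join(train_command("BeamTreeCell", "listopsc")))
    # ... and on MNLI, an NLI task, which uses the sentence_pair model type.
    print(shlex.join(train_command("BeamTreeCell", "MNLIdev", model_type="sentence_pair")))
```

Returning an argument list rather than a single string means the command can be passed straight to `subprocess.run`, while `shlex.join` renders it for copying into a shell.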

Tree Parsing

  • For tree parsing from a classifier model: python extract_trees_classifier.py --model=[insert model name] --device=[insert device name] --dataset=[insert dataset name]
  • For tree parsing from a sentence-pair matching model: python extract_trees_nli.py --model=[insert model name] --device=[insert device name] --dataset=[insert dataset name]

To change the inputs to parse, edit the list inside extract_trees_classifier.py or extract_trees_nli.py (line 66).
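The choice between the two extraction scripts follows the task type used for training. The sketch below makes that mapping explicit. `extract_trees_command` is a hypothetical helper, and the `cuda:0` default is an assumed device format.

```python
import shlex

# Extraction script for each --model_type the model was trained with.
EXTRACT_SCRIPTS = {
    "classifier": "extract_trees_classifier.py",
    "sentence_pair": "extract_trees_nli.py",
}


def extract_trees_command(model, dataset, model_type, device="cuda:0"):
    """Return the argv list for tree extraction, choosing the script by task type."""
    if model_type not in EXTRACT_SCRIPTS:
        raise ValueError(
            f"model_type must be one of {sorted(EXTRACT_SCRIPTS)}, got {model_type!r}"
        )
    return [
        "python", EXTRACT_SCRIPTS[model_type],
        f"--model={model}",
        f"--device={device}",
        f"--dataset={dataset}",
    ]


if __name__ == "__main__":
    print(shlex.join(extract_trees_command("BeamTreeCell", "listopsc", "classifier")))
```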

Dataset Nomenclature

The dataset names in the codebase differ slightly from those in the paper. The mapping is given below in the form ([codebase dataset name] == [paper dataset name]).

  • listopsc == ListOps
  • listopsd == ListOps-DG
  • listops_ndr50 == ListOps-DG1
  • listops_ndr100 == ListOps-DG2
  • proplogic == Logical Inference (Operator generalization split)
  • proplogic_C == Logical Inference (C-split for systematic generalization)
  • SST2 == SST2
  • SST5 == SST5
  • IMDB == IMDB
  • MNLIdev == MNLI

Dataset names with a speed suffix are used for the stress tests.
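When reporting results, the dataset mapping above can be kept as a lookup table. This is a small sketch with hypothetical names (`DATASET_PAPER_NAMES`, `paper_dataset_name`). Unlisted names, such as the speed-suffixed stress-test datasets, are returned unchanged rather than raising an error.

```python
# Codebase --dataset value -> dataset name used in the paper (from the list above).
DATASET_PAPER_NAMES = {
    "listopsc": "ListOps",
    "listopsd": "ListOps-DG",
    "listops_ndr50": "ListOps-DG1",
    "listops_ndr100": "ListOps-DG2",
    "proplogic": "Logical Inference (Operator generalization split)",
    "proplogic_C": "Logical Inference (C-split for systematic generalization)",
    "SST2": "SST2",
    "SST5": "SST5",
    "IMDB": "IMDB",
    "MNLIdev": "MNLI",
}


def paper_dataset_name(codebase_name):
    """Return the paper's name for a --dataset value, or the value itself if unlisted."""
    return DATASET_PAPER_NAMES.get(codebase_name, codebase_name)
```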

Model Nomenclature

The model names in the codebase differ slightly from those in the paper. The mapping is given below in the form ([codebase model name] == [paper model name]).

  • RCell == RecurrentGRC
  • BalancedTreeCell == BalancedTreeGRC
  • RandomTreeCell == RandomTreeGRC
  • GoldTreeCell == GoldTreeGRC
  • GumbelTreeLSTM == GumbelTreeLSTM
  • GumbelTreeCell == GumbelTreeGRC
  • MCGumbelTreeCell == MCGumbelTreeGRC
  • CYKCell == CYK-GRC
  • OrderedMemory == Ordered Memory
  • CRvNN == CRvNN
  • CRvNN_worst == CRvNN without halt (during stress test)
  • BSRPCell == BSRP-GRC (beam 5)
  • BigBSRPCell == BSRP-GRC (beam 8)
  • NDR == NDR (Neural Data Router)
  • BeamTreeLSTM == BT-LSTM (beam 5)
  • BeamTreeCell == BT-GRC (beam 5)
  • SmallerBeamTreeCell == BT-GRC (beam 2)
  • DiffBeamTreeCell == BT-GRC + OneSoft (beam 5)
  • SmallerDiffBeamTreeCell == BT-GRC + OneSoft (beam 2)
  • DiffSortBeamTreeCell == BT-GRC + SOFT (beam 5)
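To reproduce a specific row from the paper, you need the reverse direction: paper model name to `--model` value. The following is a minimal sketch with hypothetical names (`MODEL_PAPER_NAMES`, `codebase_model_name`). It relies on every paper name in the list above being unique, which the module checks at import time.

```python
# Codebase --model value -> model name used in the paper (from the list above).
MODEL_PAPER_NAMES = {
    "RCell": "RecurrentGRC",
    "BalancedTreeCell": "BalancedTreeGRC",
    "RandomTreeCell": "RandomTreeGRC",
    "GoldTreeCell": "GoldTreeGRC",
    "GumbelTreeLSTM": "GumbelTreeLSTM",
    "GumbelTreeCell": "GumbelTreeGRC",
    "MCGumbelTreeCell": "MCGumbelTreeGRC",
    "CYKCell": "CYK-GRC",
    "OrderedMemory": "Ordered Memory",
    "CRvNN": "CRvNN",
    "CRvNN_worst": "CRvNN without halt (during stress test)",
    "BSRPCell": "BSRP-GRC (beam 5)",
    "BigBSRPCell": "BSRP-GRC (beam 8)",
    "NDR": "NDR (Neural Data Router)",
    "BeamTreeLSTM": "BT-LSTM (beam 5)",
    "BeamTreeCell": "BT-GRC (beam 5)",
    "SmallerBeamTreeCell": "BT-GRC (beam 2)",
    "DiffBeamTreeCell": "BT-GRC + OneSoft (beam 5)",
    "SmallerDiffBeamTreeCell": "BT-GRC + OneSoft (beam 2)",
    "DiffSortBeamTreeCell": "BT-GRC + SOFT (beam 5)",
}

# Reverse map: paper name -> codebase name.
CODEBASE_MODEL_NAMES = {paper: code for code, paper in MODEL_PAPER_NAMES.items()}
# If two codebase names shared a paper name, the reverse map would silently drop one.
assert len(CODEBASE_MODEL_NAMES) == len(MODEL_PAPER_NAMES), "paper names must be unique"


def codebase_model_name(paper_name):
    """Return the --model value for a paper model name such as 'BT-GRC (beam 5)'."""
    try:
        return CODEBASE_MODEL_NAMES[paper_name]
    except KeyError:
        raise KeyError(f"unknown paper model name: {paper_name!r}") from None
```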

Citation

@InProceedings{Chowdhury2023beam,
  title = 	 {Beam Tree Recursive Cells},
  author =       {Ray Chowdhury, Jishnu and Caragea, Cornelia},
  booktitle = 	 {Proceedings of the 40th International Conference on Machine Learning},
  year = 	 {2023}
}

For any questions or issues, contact the email address associated with the GitHub account.
