
SaeedNajafi/pytorch-ocd

Optimal Completion Distillation (OCD) Training

Implementation of Optimal Completion Distillation (OCD) for sequence labeling.
Source paper: https://arxiv.org/abs/1810.01398
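At the core of OCD is an edit-distance computation. For a prefix generated by the model, the optimal next tokens are those that keep the minimum achievable edit distance to the gold sequence as small as possible: compute the edit distance between the prefix and every prefix of the gold sequence, and for each gold prefix that attains the minimum, the gold token that follows it is optimal (the end symbol is optimal when the whole gold sequence is matched). The pure-Python sketch below illustrates this, assuming Levenshtein distance as in the paper. The function name `optimal_next_tokens` is hypothetical and is not part of this package's API; pytorch-ocd may compute the same sets differently.

```python
from typing import Sequence, Set


def optimal_next_tokens(prefix: Sequence[int], target: Sequence[int],
                        end_symbol_id: int) -> Set[int]:
    """Hypothetical helper: tokens that extend `prefix` optimally w.r.t. `target`."""
    m = len(target)
    # row[j] = edit distance between the prefix read so far and target[:j]
    row = list(range(m + 1))
    for i, tok in enumerate(prefix, start=1):
        new_row = [i] + [0] * m
        for j in range(1, m + 1):
            cost = 0 if target[j - 1] == tok else 1
            new_row[j] = min(row[j] + 1,          # delete `tok` from the prefix
                             new_row[j - 1] + 1,  # insert target[j - 1]
                             row[j - 1] + cost)   # match or substitute
        row = new_row
    best = min(row)
    optimal = set()
    for j, dist in enumerate(row):
        if dist == best:
            # continue the gold sequence after target[:j]; if all of the
            # target is already matched, ending the sequence is optimal
            optimal.add(target[j] if j < m else end_symbol_id)
    return optimal


# gold sequence [1, 2, 3]; the model wrongly emitted 2 first
print(sorted(optimal_next_tokens([2], [1, 2, 3], end_symbol_id=9)))  # [1, 2, 3]
```

After a mistake several tokens can be equally good, which is why OCD trains against a set of optimal tokens rather than a single teacher-forced label.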

Requirements

Python 3, PyTorch 1.0.0

Install

python3 -m venv env
source env/bin/activate
pip3 install .

How to use

For details, see https://github.com/SaeedNajafi/pytorch-ocd/blob/master/ocd/__init__.py#L50 and the usage in the test suite at
https://github.com/SaeedNajafi/pytorch-ocd/blob/master/tests/test_ocd.py#L132. In outline:

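from ocd import OCD

# vocab_size: number of possible output tokens; end_symbol_id: id of the end-of-sequence token
ocd_trainer = OCD(vocab_size=10, end_symbol_id=9)
...  # your model produces a score for every possible output token at each decoding step
ocd_loss = ocd_trainer(model_scores, gold_output_sequence)
...  # backpropagate using ocd_loss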
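Given the optimal token sets, the paper trains the model by minimizing, at each decoding step, the KL divergence from a target distribution over the optimal tokens to the model's softmax distribution; in the zero-temperature limit the target is uniform over the optimal tokens. In the paper, the prefixes come from the model's own samples rather than from the gold sequence. The pure-Python sketch below illustrates the per-step loss under those assumptions. The function names are hypothetical, and it does not reproduce how `ocd_trainer(...)` reduces over steps or batches, which this README does not specify (here the steps are simply summed).

```python
import math
from typing import List, Sequence, Set


def log_softmax(scores: Sequence[float]) -> List[float]:
    # numerically stable log-softmax over one step's token scores
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_z for s in scores]


def ocd_step_loss(scores: Sequence[float], optimal: Set[int]) -> float:
    """Hypothetical helper: KL(q || p), q uniform over `optimal`, p = softmax(scores)."""
    log_p = log_softmax(scores)
    q = 1.0 / len(optimal)
    # tokens outside `optimal` have q = 0 and contribute nothing to the KL
    return sum(q * (math.log(q) - log_p[a]) for a in optimal)


def ocd_sequence_loss(step_scores: Sequence[Sequence[float]],
                      optimal_sets: Sequence[Set[int]]) -> float:
    # assumption: per-step KL terms are summed over the decoding steps
    return sum(ocd_step_loss(s, o) for s, o in zip(step_scores, optimal_sets))


# 4-token vocabulary, uniform model scores, two optimal tokens
print(round(ocd_step_loss([0.0, 0.0, 0.0, 0.0], {0, 1}), 4))  # 0.6931 (= log 2)
```

The loss is zero only when the model spreads its probability evenly over exactly the optimal tokens; a real implementation would compute the same quantity on batched PyTorch tensors so that it can be backpropagated.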
