<a href="https://colab.research.google.com/github/IsaacRe/Syntactically-Constrained-Sampling/blob/main/notebooks/Adding_a_New_Constraint.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Adding a New Syntax Constraint

This notebook goes through the steps of adding a new syntax constraint that forces each new sentence to be on a new line.

In [None]:
!pip install git+https://github.com/IsaacRe/transformers@syntactically-constrained-sampling
!pip install git+https://github.com/IsaacRe/Syntactically-Constrained-Sampling

In [1]:
from scs.incremental_parse import IncrementalParser, SpecialToken, ParseFailure
from typing import Union

Core logic is defined in the `IncrementalParser`. Subclasses must implement the following methods:
- `_copy_from(self, other)` - copy state from one parser to another
    - `other`: `IncrementalParser` - parser of the same subclass from which to copy parse state
- `_append(self, char)` - continue parsing, raising a `ParseFailure` when the given character deviates from the defined syntax
    - `char`: `Union[str, SpecialToken]` - character or special token to continue parsing

The parser is used to check validity of candidate tokens in the tokenizer's vocab during generation.
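
To make the role of these two methods concrete, here is a minimal, self-contained sketch of how a validity check over a vocabulary might use them. `ToyParser`, `ToyParseFailure`, `copy_from`, `append` and `valid_tokens` are hypothetical stand-ins written for this illustration. They are not the `scs` classes, and the library's actual checking logic may differ.

```python
from typing import List


class ToyParseFailure(Exception):
    """Stand-in for ParseFailure (hypothetical, for illustration only)."""


class ToyParser:
    """Toy parser that only accepts digit characters."""

    def __init__(self) -> None:
        self.n_chars = 0

    def copy_from(self, other: "ToyParser") -> None:
        # Copy parse state so a candidate can be tested without touching `other`.
        self.n_chars = other.n_chars

    def append(self, char: str) -> None:
        if not char.isdigit():
            raise ToyParseFailure(f"Unexpected character {char!r}")
        self.n_chars += 1


def valid_tokens(parser: ToyParser, vocab: List[str]) -> List[str]:
    """Return the tokens that the parser accepts from its current state."""
    accepted = []
    for token in vocab:
        trial = ToyParser()
        trial.copy_from(parser)  # work on a copy; the running parser is unchanged
        try:
            for char in token:
                trial.append(char)
        except ToyParseFailure:
            continue  # at least one character broke the syntax
        accepted.append(token)
    return accepted


running = ToyParser()
print(valid_tokens(running, ["12", "3a", "7", "x"]))  # ['12', '7']
```

Each candidate is tested on a copy, which is why a state-copying method is needed: a rejected token must leave the running parser exactly as it was.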

In [2]:
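# Characters that mark the end of a sentence
END_PUNCT = ['.', '?', '!']

class NewLineParser(IncrementalParser):

    def __init__(self):
        super().__init__()
        # True while the most recently parsed character was end punctuation
        self.finished_sentence = False

    def _copy_from(self, other: "NewLineParser"):
        # The only parse state to carry over is the sentence-end flag
        self.finished_sentence = other.finished_sentence

    def _append(self, char: Union[str, SpecialToken]):
        if char in END_PUNCT:
            self.finished_sentence = True
        else:
            # Anything other than a newline directly after end punctuation is invalid
            if self.finished_sentence and char != '\n':
                raise ParseFailure('Expected newline')
            self.finished_sentence = False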
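
To see which strings this rule accepts, the sketch below replays the same logic as a plain function that needs no `scs` import. `follows_newline_rule` is a hypothetical helper written for this illustration. It returns `False` at the point where the parser would raise `ParseFailure`.

```python
END_PUNCT = ['.', '?', '!']


def follows_newline_rule(text: str) -> bool:
    """Return True if end punctuation is only ever followed by more end
    punctuation, a newline, or the end of the text."""
    finished_sentence = False
    for char in text:
        if char in END_PUNCT:
            finished_sentence = True
        else:
            if finished_sentence and char != '\n':
                return False  # the parser would raise ParseFailure here
            finished_sentence = False
    return True


print(follows_newline_rule("Hi.\nBye."))     # True: newline follows the period
print(follows_newline_rule("Hi. Bye."))      # False: a space follows the period
print(follows_newline_rule("Wait...\nOk"))   # True: repeated punctuation is allowed
print(follows_newline_rule("Pi is 3.14"))    # False: the rule also rejects decimals
```

The last case shows a side effect of such a simple rule: any period counts as a sentence end, so text like `3.14` is rejected as well.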

In [3]:
from scs.handler import SyntaxValidityCheckFactory, SyntaxValidityCheckHandler
from scs.constraint import SyntaxConstraint

The `SyntaxValidityCheckFactory` is used to create a newly initialized parser at the beginning of each new generation, wrapped in a `SyntaxConstraint`.

In [4]:
class NewLineCheckFactory(SyntaxValidityCheckFactory):
    
    def __call__(self) -> SyntaxConstraint:
        return SyntaxConstraint(NewLineParser())
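
The reason for using a factory rather than a single parser instance can be sketched with stand-in classes. `ToyState` and `ToyFactory` are hypothetical names used only for this illustration and are not part of `scs`.

```python
class ToyState:
    """Hypothetical stand-in for a parser with mutable state."""

    def __init__(self) -> None:
        self.finished_sentence = False


class ToyFactory:
    """Hypothetical factory: every call returns a fresh ToyState."""

    def __call__(self) -> ToyState:
        return ToyState()


factory = ToyFactory()
first = factory()
first.finished_sentence = True   # state mutated during one generation
second = factory()               # the next generation starts from a clean state
print(second.finished_sentence)  # False
```

Because each call builds a new object, state left over from one generation cannot leak into the next.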

In [5]:
import torch
from transformers.pipelines import pipeline
from transformers.generation.output_validity import get_token_vocab

In [6]:
pipe = pipeline(model='gpt2')
tokenizer = pipe.tokenizer
model = pipe.model


The `SyntaxValidityCheckHandler` maintains a reference to the running parser and the token ids. It is the link between the model's logit output and the token input to the parser.

In [7]:
handler = SyntaxValidityCheckHandler(
    get_token_vocab(tokenizer),
    NewLineCheckFactory(),
)
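
One common way to connect a validity check to logits is to mask out the tokens that the parser rejects before sampling. The sketch below shows that idea in plain Python. `mask_invalid_logits` is a hypothetical helper, and whether the handler applies its constraint in exactly this way is an assumption, not something this notebook confirms.

```python
import math
from typing import List


def mask_invalid_logits(logits: List[float], valid_ids: List[int]) -> List[float]:
    """Set the logit of every token id not in `valid_ids` to -inf."""
    allowed = set(valid_ids)
    return [
        logit if token_id in allowed else -math.inf
        for token_id, logit in enumerate(logits)
    ]


logits = [1.5, 0.2, -0.7, 3.0]
masked = mask_invalid_logits(logits, valid_ids=[0, 2])
print(masked)  # [1.5, -inf, -0.7, -inf]
```

A logit of `-inf` becomes probability zero after a softmax, so a masked token can never be sampled.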

Once created, the handler can be passed directly to `generate` with the `output_validity_check` keyword.

In [10]:
token_out = tokenizer(['The king in Spain'])
out = model.generate(torch.LongTensor(token_out['input_ids']), output_validity_check=handler)
result, = tokenizer.batch_decode(out.numpy())
print(result)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


in generate
Generating with sample
STOPPING gen
The king in Spain, however, decided, and has long since forgotten, that the French have in no way ever offered their aid to Germany.

The French have made it their duty to work with them in Italy and to work with Austria while
