In [1]:
from transformers import pipeline

class QualityControlPipeline:
    """Controlled paraphrase generation on top of an ``ibm/qcpg-*`` checkpoint.

    Each call turns three quality knobs (semantic similarity, lexical diversity,
    syntactic diversity) into ``COND_*`` control tokens, prepends them to the
    input text and runs the wrapped ``text2text-generation`` pipeline.
    """

    # Allowed [low, high] control values (in percent) per checkpoint flavour.
    # Requested values are clamped into these ranges before building tokens.
    _RANGES = {
        'captions': {'lex': [0, 90], 'syn': [0, 80], 'sem': [0, 95]},
        'sentences': {'lex': [0, 100], 'syn': [0, 80], 'sem': [0, 95]},
        'questions': {'lex': [0, 90], 'syn': [0, 75], 'sem': [0, 95]},
    }

    # Order of control tokens in the prompt. The first three characters of each
    # name ('sem', 'lex', 'syn') select the matching entry in ``self.ranges``.
    _CONTROL_NAMES = ('semantic_sim', 'lexical_div', 'syntactic_div')

    def __init__(self, type):
        """Load the ``ibm/qcpg-{type}`` model.

        Args:
            type: one of ``'captions'``, ``'questions'`` or ``'sentences'``.

        Raises:
            AssertionError: if ``type`` is not one of the supported flavours.
        """
        assert type in ['captions', 'questions', 'sentences']
        self.pipe = pipeline('text2text-generation', model=f'ibm/qcpg-{type}')
        # Copy so that mutating one instance's ranges never leaks into the
        # shared class table or into other instances.
        self.ranges = {key: list(bounds) for key, bounds in self._RANGES[type].items()}

    def __call__(self, text, lexical, syntactic, semantic, **kwargs):
        """Paraphrase ``text`` under the requested quality controls.

        Args:
            text: a single input string, or an iterable of strings (batch).
            lexical: desired lexical diversity in [0, 1].
            syntactic: desired syntactic diversity in [0, 1].
            semantic: desired semantic similarity in [0, 1].
            **kwargs: forwarded unchanged to the wrapped pipeline.

        Returns:
            Whatever the wrapped pipeline returns for the prepared input(s).

        Raises:
            AssertionError: if a control value lies outside [0, 1], or a
                resulting control token is not a special token of the tokenizer.
        """
        assert all([0 <= val <= 1 for val in [lexical, syntactic, semantic]]), \
                 f' control values must be between 0 and 1, got {lexical}, {syntactic}, {semantic}'
        tokens = []
        for name, val in zip(self._CONTROL_NAMES, (semantic, lexical, syntactic)):
            # Quantize to the 5-point grid used by the control tokens.
            # NOTE: built-in round() rounds halves to even (e.g. 0.125 -> 10).
            percent = int(5 * round(val * 100 / 5))
            low, high = self.ranges[name[:3]]
            percent = max(min(percent, high), low)
            tokens.append(f'COND_{name.upper()}_{percent}')
        assert all(cond in self.pipe.tokenizer.additional_special_tokens for cond in tokens)
        prefix = ' '.join(tokens)
        # The prefix is concatenated directly onto the text, with no separator,
        # exactly as the original single-string path did.
        if isinstance(text, str):
            text = prefix + text
        else:
            # Previously each batch element was dropped and only the prefix was
            # sent; now every element gets the same prefix as the string path.
            text = [prefix + t for t in text]
        return self.pipe(text, **kwargs)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Paraphraser tuned for generic sentences (loads the ibm/qcpg-sentences checkpoint).
model = QualityControlPipeline(type='sentences')

In [3]:
# Keep the meaning close (semantic=0.8) while allowing moderate lexical and syntactic change.
model(
    'Is this going to work or what are we doing here?',
    semantic=0.8,
    lexical=0.3,
    syntactic=0.5,
)

[{'generated_text': "Is that going to work or what is it we're doing?"}]

In [8]:
# Same control settings as the previous example, applied to a longer sentence.
model(
    'How are the things going to work if your progress is so weak?',
    semantic=0.8,
    lexical=0.3,
    syntactic=0.5,
)

[{'generated_text': "How will things work if you're slow on your own progress?"}]

In [9]:
# Switch to the question-tuned checkpoint (ibm/qcpg-questions).
# This rebinds `model`, replacing the sentences pipeline created earlier.
model = QualityControlPipeline(type='questions')

Downloading (…)lve/main/config.json: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 1.50k/1.50k [00:00<00:00, 131kB/s]
Downloading (…)"pytorch_model.bin";: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 892M/892M [03:03<00:00, 4.85MB/s]
Downloading (…)okenizer_config.json: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 1.62k/1.62k [00:00<00:00, 317kB/s]
Downloading (…)"spiece.model";: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 792k/792k [00:01<00:00, 633kB/s]
Downloading (…)/main/tokenizer.json: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 1.39M/1.39M [00:02<00:00, 635kB/s]
Downloading (…)in/added_tokens.json: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 1.7

In [11]:
# Paraphrase a short question with the question-tuned model and the same controls.
model(
    'how to install windows?',
    semantic=0.8,
    lexical=0.3,
    syntactic=0.5,
)

[{'generated_text': 'How do you install windows?'}]