# ConBERT inference
This notebook illustrates inference with the Conditional BERT model only.

Important note: the notebook uses a pseudo-absolute path (it changes the working directory to the parent folder), so it should be run only once per kernel. If you want to run it a second time, restart the kernel first.

In [1]:
import os

# Move up to the project root (the folder that contains models/Conbert) so that
# the relative paths used below, e.g. 'models/Conbert/vocab/', resolve.
# Guarded so that re-running the cell does not climb one more level each time.
if not os.path.isdir(os.path.join('models', 'Conbert')):
    os.chdir('..')
print(os.getcwd())

/home/leon/Projects/Programming/Study/Python/ML_Inno/PMLDL/PML_ASS_1


In [None]:
import torch
import numpy as np

def manual_seed(seed):
    """
    Seed the NumPy and PyTorch random number generators for reproducibility
    :param seed: seed value
    :return: None
    """
    # NumPy global generator
    np.random.seed(seed)

    # PyTorch CPU generator
    torch.manual_seed(seed)

    # Nothing more to seed when no CUDA device is available
    if not torch.cuda.is_available():
        return

    # PyTorch CUDA generators (current device, then all devices)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

# Seed value handed to manual_seed
seed = 42

# Seed the random number generators right away
manual_seed(seed)

In [2]:
# Expose only GPU 0 to this process.
# NOTE(review): torch is imported in an earlier cell; this setting is expected to
# matter only if CUDA has not been initialised yet - confirm the cell order.
os.environ.update({'CUDA_VISIBLE_DEVICES': '0'})

In [3]:
from importlib import reload

In [4]:
# Import the module, reload it, then import the class from the reloaded module
# (presumably so that edits to conbert.py are picked up without restarting the kernel - confirm)
import models.Conbert.conbert
reload(models.Conbert.conbert)
from models.Conbert.conbert import CondBertRewriter

In [5]:
from transformers import BertTokenizer, BertForMaskedLM
import pickle
from tqdm.auto import tqdm, trange

In [6]:
# Use the GPU when PyTorch reports CUDA as available, otherwise fall back to the CPU
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

### Load the model

In [7]:
# Name of the pretrained checkpoint; the tokenizer is loaded from it here
model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)

In [8]:
# Masked-LM BERT loaded from the checkpoint named by model_name
model = BertForMaskedLM.from_pretrained(model_name)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [9]:
# Move the model to the selected device; the trailing ';' suppresses the cell's output
model.to(device);

#### Load vocabularies for span detection

In [10]:
# Folder with the vocabulary files, relative to the working directory.
# Keeps its trailing slash because file names are appended by string concatenation.
vocab_root = 'models/Conbert/vocab/'

In [11]:
def _read_word_list(path):
    """
    Read a vocabulary file that stores one entry per line
    :param path: path to the text file
    :return: list of the file's lines with the trailing newline removed
    """
    with open(path, "r") as f:
        # rstrip('\n') rather than x[:-1]: a final line that has no trailing
        # newline must not lose its last character
        return [line.rstrip('\n') for line in f]

# Negative vocabulary: the negative word list extended with the toxic word list
negative_words = _read_word_list(vocab_root + "negative-words.txt")
negative_words += _read_word_list(vocab_root + "toxic_words.txt")

# Positive vocabulary
positive_words = _read_word_list(vocab_root + "positive-words.txt")

In [12]:
import pickle
# Load the word2coef object stored as a pickle in the vocabulary folder.
# NOTE(review): pickle.load can execute arbitrary code while deserialising -
# only point this at a trusted file.
with open(vocab_root + 'word2coef.pkl', 'rb') as f:
    word2coef = pickle.load(f)

In [13]:
# One toxicity value per line of the file, parsed as float
with open(vocab_root + 'token_toxicities.txt', 'r') as f:
    token_toxicities = np.array([float(line) for line in f])
# NOTE(review): values of exactly 0 or 1 would divide by zero here - confirm the file never contains them
token_toxicities = np.maximum(0, np.log(1/(1/token_toxicities-1)))   # log odds ratio

# tokenizer.encode(tok) returns a list of ids; element [1] is used as the id of
# the token itself (assumes each tok below maps to a single wordpiece
# surrounded by the special tokens - TODO confirm).
# The id must be taken BEFORE indexing the array: indexing with the whole list
# yields a copy, so the previous `token_toxicities[ids][1] = value` assigned
# into a temporary and left token_toxicities unchanged.

# discourage meaningless tokens
for tok in ['.', ',', '-']:
    token_toxicities[tokenizer.encode(tok)[1]] = 3

for tok in ['you']:
    token_toxicities[tokenizer.encode(tok)[1]] = 0

### Applying the model

In [14]:
# Reload the module and re-import the class from it before building the rewriter
reload(models.Conbert.conbert)
from models.Conbert.conbert import CondBertRewriter

# Collect the constructor arguments first, then build the rewriter from them
rewriter_kwargs = {
    'model': model,
    'tokenizer': tokenizer,
    'device': device,
    'neg_words': negative_words,
    'pos_words': positive_words,
    'word2coef': word2coef,
    'token_toxicities': token_toxicities,
}
editor = CondBertRewriter(**rewriter_kwargs)

In [15]:
# Rewrite a sample sentence with the editor and print the result
detoxified = editor.translate('You are an idiot!', prnt=False)
print(detoxified)

you are an the !


### Multiunit

In [16]:
# Same arguments as the single-token editor, plus an explicit predictor=None
multiunit_kwargs = {
    'model': model,
    'tokenizer': tokenizer,
    'device': device,
    'neg_words': negative_words,
    'pos_words': positive_words,
    'word2coef': word2coef,
    'token_toxicities': token_toxicities,
    'predictor': None,
}
editor = CondBertRewriter(**multiunit_kwargs)

In [17]:
# Import the module, reload it, then import the class from the reloaded module
# (presumably so that edits to masked_token_predictor_bert.py are picked up without restarting the kernel - confirm)
from models.Conbert.multiword import masked_token_predictor_bert
reload(masked_token_predictor_bert)
from models.Conbert.multiword.masked_token_predictor_bert import MaskedTokenPredictorBert

In [18]:
# Predictor built on the BERT model and tokenizer already loaded in this notebook
predictor = MaskedTokenPredictorBert(
    model,
    tokenizer,
    max_len=250,
    device=device,
    label=0,
    contrast_penalty=0.0,
)

# Attach the predictor to the editor
editor.predictor = predictor

def adjust_logits(logits, label):
    """
    Subtract a toxicity penalty from the logits of the Conbert model
    :param logits: the logits from the model
    :param label: the label (not used by this implementation)
    :return: logits minus three times the editor's token toxicities
    """
    # Reads the toxicities from the global `editor` at call time
    toxicity_penalty = 3 * editor.token_toxicities
    return logits - toxicity_penalty

# Register adjust_logits as the predictor's logits post-processor
# (presumably the predictor applies it to its logits - confirm in MaskedTokenPredictorBert)
predictor.logits_postprocessor = adjust_logits

# Rewrite a sample sentence with the multi-unit loop and print the result
rewritten = editor.replacement_loop('You are an idiot!', verbose=False)
print(rewritten)

you are an old man !


In [19]:
%%time
# Time the rewrite of a sample sentence with n_units=1
rewritten_1 = editor.replacement_loop('You are an idiot!', verbose=False, n_units=1)
print(rewritten_1)

you are an old man !
CPU times: user 283 ms, sys: 68.8 ms, total: 352 ms
Wall time: 350 ms


In [20]:
%%time
# Time the rewrite of a sample sentence with n_units=3
rewritten_3 = editor.replacement_loop('You are an idiot!', verbose=False, n_units=3)
print(rewritten_3)

you are an old man !
CPU times: user 825 ms, sys: 264 ms, total: 1.09 s
Wall time: 1.1 s


In [21]:
%%time
# Time the rewrite of a sample sentence with n_units=10
rewritten_10 = editor.replacement_loop('You are an idiot!', verbose=False, n_units=10)
print(rewritten_10)

you are an old man !
CPU times: user 962 ms, sys: 181 ms, total: 1.14 s
Wall time: 1.14 s


# Simplified inference
As a simpler way to use this model, you can use the custom wrapper around it.

In [22]:
from models.Conbert.conbert_wrapper import Conbert

# Folder handed to the Conbert wrapper, relative to the working directory
conbert_dir = 'models/Conbert'
# NOTE(review): this rebinds `model`, which earlier held the BertForMaskedLM instance, to the wrapper object
model = Conbert(device, conbert_dir)

Loading BERT tokenizer...


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Loading BERT vocabularies...


In [23]:
# Detoxify a sample sentence through the wrapper; the bare expression is shown as the cell output
model.detoxicate('You are an idiot!')

'you are an the !'