## Many thanks to [this GitHub issue](https://github.com/huggingface/datasets/issues/224) and the [resulting repository](https://github.com/lucadiliello/bleurt-pytorch)
### This notebook is an extended copy of [this notebook](https://colab.research.google.com/drive/1KsCUkFW45d5_ROSv2aHtXgeBa2Z98r03?usp=sharing)

In [None]:
# Option 1: Install bleurt repository
!pip install --upgrade pip  # ensures that pip is current
!git clone https://github.com/google-research/bleurt.git
!pip install ./bleurt

In [1]:
# Option 2: Clone repository from https://github.com/google-research/bleurt and append it to path
import sys

# Location of the cloned BLEURT repository, relative to this notebook
BLEURT_REPO_DIR = "../bleurtMaster"
# The guard keeps the cell idempotent: re-running it does not stack duplicate entries on sys.path
if BLEURT_REPO_DIR not in sys.path:
    sys.path.append(BLEURT_REPO_DIR)

In [2]:
# BLEURT package and its scoring module; `bleurt_score.BleurtScorer` serves as the
# reference implementation in the comparison step further down.
import bleurt
from bleurt import score as bleurt_score
import sys
# Keep only the first argv entry and drop whatever else the kernel was launched with.
# Presumably this stops BLEURT's command-line flag parsing from failing on the
# notebook kernel's own arguments - see the linked issue.
sys.argv = sys.argv[:1] ##thanks https://github.com/google-research/bleurt/issues/4

In [3]:
## Step 1: Convert model to torch

# Libraries used for the TF -> PyTorch conversion
import tensorflow.compat.v1 as tf
import torch
import torch.nn as nn
import transformers

# Path to the saved BLEURT model directory
checkpoint = "../neg_bleurt_500"
# Load the TF SavedModel; its variables are read out in the conversion loop below
imported = tf.saved_model.load_v2(checkpoint)

class BleurtModel(nn.Module):
    """BERT encoder with a single-output linear layer on top.

    The encoder's pooled output is passed through ``dense`` to produce one
    value per input row.
    """

    def __init__(self, config):
        """Create the encoder and the linear head.

        Args:
            config: configuration object handed to ``transformers.BertModel``;
                its ``hidden_size`` sets the input width of the linear head.
        """
        super().__init__()
        self.config = config
        self.bert = transformers.BertModel(config)
        self.dense = nn.Linear(config.hidden_size, 1)

    def forward(self, input_ids, input_mask, segment_ids):
        """Run the encoder and map its pooled output to one value per row.

        Args:
            input_ids: token ids, passed to the encoder as ``input_ids``.
            input_mask: passed to the encoder as ``attention_mask``.
            segment_ids: passed to the encoder as ``token_type_ids``.

        Returns:
            Output of ``self.dense`` applied to the encoder's ``pooler_output``.
        """
        encoder_output = self.bert(
            input_ids=input_ids,
            attention_mask=input_mask,
            token_type_ids=segment_ids,
        )
        # Fix #2: take the pooler output. Indexing the raw encoder output
        # ([0][:, 0] or [1]) did not work.
        pooled = encoder_output.pooler_output
        return self.dense(pooled)

# Translate every TF variable into a PyTorch state-dict entry (name -> tensor)
state_dict = {}
for variable in imported.variables:
    tf_name = variable.name
    # Variables whose name starts with 'global' are not model weights
    # (presumably the training step counter) and are left out.
    if tf_name.startswith('global'):
        continue

    # Fix #1: transpose every 'kernel' variable (not only those named 'dense'),
    # so the matrix matches nn.Linear's (out_features, in_features) layout.
    weights = variable.numpy()
    if 'kernel' in tf_name:
        weights = weights.T

    # Strip the ':<index>' suffix, then turn the TF scope path into dotted
    # PyTorch module names.
    torch_name = (
        tf_name.split(':')[0]
        .replace('/', '.')
        .replace('_', '.')
        .replace('kernel', 'weight')
    )

    if 'LayerNorm' in torch_name:
        torch_name = torch_name.replace('beta', 'bias').replace('gamma', 'weight')
    elif 'embeddings' in torch_name:
        # The blanket '_' -> '.' replacement above also split the embedding
        # table names; restore the underscores.
        for dotted, underscored in (
            ('word.embeddings', 'word_embeddings'),
            ('position.embeddings', 'position_embeddings'),
            ('token.type.embeddings', 'token_type_embeddings'),
        ):
            torch_name = torch_name.replace(dotted, underscored)
        torch_name = torch_name + '.weight'

    state_dict[torch_name] = torch.from_numpy(weights)

In [4]:
import warnings
from transformers import BertForSequenceClassification

# Architecture hyper-parameters shared by both model variants built in this cell.
# NOTE(review): these are hard-coded and have to match the checkpoint being converted -
# confirm against the checkpoint's own BERT config when switching checkpoints.
bert_config_kwargs = dict(
    hidden_size=128,
    hidden_act="gelu",
    initializer_range=0.02,
    vocab_size=30522,
    hidden_dropout_prob=0.1,
    num_attention_heads=2,
    type_vocab_size=2,
    max_position_embeddings=512,
    num_hidden_layers=2,
    intermediate_size=512,
    attention_probs_dropout_prob=0.1,
)


def _warn_on_key_mismatch(load_result, model_label):
    """Report what ``load_state_dict(..., strict=False)`` would otherwise hide.

    Args:
        load_result: return value of ``load_state_dict`` (has ``missing_keys``
            and ``unexpected_keys``).
        model_label: name used in the warning text.

    Keys ending in ``position_ids`` are not reported: presumably this is the
    non-trained buffer that made strict loading fail - see
    https://github.com/huggingface/transformers/issues/6882#issuecomment-884730078
    """
    missing = [key for key in load_result.missing_keys if not key.endswith("position_ids")]
    unexpected = list(load_result.unexpected_keys)
    if missing or unexpected:
        warnings.warn(
            f"{model_label}: missing keys {missing} keep their random initialisation; "
            f"unexpected keys {unexpected} were ignored"
        )


# Variant 1: the hand-written BleurtModel, frozen and switched to inference mode.
config = transformers.BertConfig(**bert_config_kwargs)
bleurt_torch_model = BleurtModel(config)
# strict=False added otherwise crashes; the result is inspected instead of being dropped.
_warn_on_key_mismatch(
    bleurt_torch_model.load_state_dict(state_dict, strict=False), "BleurtModel")
for param in bleurt_torch_model.parameters():
    param.requires_grad = False
bleurt_torch_model.eval()

# Variant 2: the same weights in a stock BertForSequenceClassification, whose head is
# called `classifier` instead of `dense`. The renamed copy leaves `state_dict` untouched,
# so this cell can be re-run without a KeyError.
head_renames = {'dense.weight': 'classifier.weight', 'dense.bias': 'classifier.bias'}
hf_state_dict = {head_renames.get(name, name): tensor for name, tensor in state_dict.items()}

config = transformers.BertConfig(**bert_config_kwargs, num_labels=1)
bleurt_model = BertForSequenceClassification(config)
_warn_on_key_mismatch(
    bleurt_model.load_state_dict(hf_state_dict, strict=False), "BertForSequenceClassification")

bleurt_model.save_pretrained("negBLEURT") # Note: this saves a pytorch model but is missing the tokenizer info

In [5]:
## Step 2: Create a tokenizer

import json
from transformers import BertTokenizerFast

# bleurt_config.json inside the checkpoint directory holds the tokenizer settings.
# The encoding is given explicitly so the read does not depend on the platform default.
with open(f'{checkpoint}/bleurt_config.json', 'r', encoding='utf-8') as f:
    bleurt_config = json.load(f)

max_seq_length = bleurt_config["max_seq_length"]
vocab_file = f'{checkpoint}/{bleurt_config["vocab_file"]}'
do_lower_case = bleurt_config["do_lower_case"]

# Tokenizer as built by the BLEURT library itself; the remaining cells shown here do not use it.
tokenizer = bleurt.lib.tokenizers.create_tokenizer(
    vocab_file=vocab_file, do_lower_case=do_lower_case, sp_model=None)

# The Hugging Face tokenizer takes its length limit through `model_max_length`;
# a `max_seq_length` keyword is not a tokenizer setting and would leave the limit unset.
mytok = BertTokenizerFast(vocab_file=vocab_file, do_lower_case=do_lower_case, model_max_length=max_seq_length)
mytok.save_pretrained("negBLEURT")

INFO:tensorflow:Creating WordPiece tokenizer.
INFO:tensorflow:WordPiece tokenizer instantiated.


('negBLEURT\\tokenizer_config.json',
 'negBLEURT\\special_tokens_map.json',
 'negBLEURT\\vocab.txt',
 'negBLEURT\\added_tokens.json',
 'negBLEURT\\tokenizer.json')

In [7]:
## Step 3: Compare model outputs of Pytorch transformer model and BleurtScorer
# Each entry is one (reference, candidate) pair to be scored
sentence_pairs = [
    ("a bird chirps by the window", "a bird chirps by the window"),
    ("This is a test.", "This isn't a test."),
]
references = [reference for reference, _candidate in sentence_pairs]
candidates = [candidate for _reference, candidate in sentence_pairs]

# Scores from the original BLEURT implementation, used as the reference values
scorer = bleurt_score.BleurtScorer(checkpoint)
scores = scorer.score(candidates=candidates, references=references)
print(scores)

INFO:tensorflow:Reading checkpoint neg_bleurt_500.
INFO:tensorflow:Config file found, reading.
INFO:tensorflow:Will load checkpoint bert_custom
INFO:tensorflow:Loads full paths and checks that files exists.
INFO:tensorflow:... name:bert_custom
INFO:tensorflow:... bert_config_file:bert_config.json
INFO:tensorflow:... max_seq_length:512
INFO:tensorflow:... vocab_file:vocab.txt
INFO:tensorflow:... do_lower_case:True
INFO:tensorflow:... sp_model:None
INFO:tensorflow:... dynamic_seq_length:False
INFO:tensorflow:Creating BLEURT scorer.
INFO:tensorflow:Creating WordPiece tokenizer.
INFO:tensorflow:WordPiece tokenizer instantiated.
INFO:tensorflow:Creating Eager Mode predictor.
INFO:tensorflow:Loading model.
INFO:tensorflow:BLEURT initialized.
[0.9107678532600403, 0.12219361960887909]


In [8]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Reload tokenizer and model from the directory written in the previous steps
bleurt_tok = AutoTokenizer.from_pretrained("negBLEURT")
bleurt_model = AutoModelForSequenceClassification.from_pretrained("negBLEURT")
bleurt_model.eval()  # inference mode (disables dropout)

# Encode each (reference, candidate) pair as one two-segment input, padded to a common length
encoding = bleurt_tok(references, candidates, padding=True, return_tensors='pt')

# No gradients are needed for scoring, so skip building the autograd graph
with torch.no_grad():
    torch_scores = bleurt_model(**encoding).logits.flatten().tolist()
torch_scores

[0.9107482433319092, 0.12223134934902191]