#### Prerequisites

In [None]:
%%capture

!pip install transformers
!pip install nltk

#### Imports 

In [2]:
from transformers import BertModel, BertTokenizerFast
import transformers
import logging
import torch
import nltk

#### Setup logging

In [3]:
logger = logging.getLogger('sagemaker')
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler())

In [4]:
logger.info(f'Using Transformers: {transformers.__version__}')
logger.info(f'Using Torch: {torch.__version__}')
logger.info(f'Using NLTK: {nltk.__version__}')

Using Transformers: 4.18.0
Using Torch: 1.8.1
Using NLTK: 3.6.7


#### Load BERT and NLTK tokenizers

In [5]:
# bert_tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
# bert_tokenizer.save_pretrained('./data/bert-tokenizer/')

In [6]:
bert_tokenizer = BertTokenizerFast.from_pretrained('./data/bert-tokenizer/')

In [7]:
nltk_tokenizer = nltk.load('./data/english.pickle')
# english.pickle is derived from downloaded artifacts when you run `nltk.download('punkt)`

#### Tokenize text to sentences and tokens

In [8]:
text = "Great product. Good design. But a little pricey!"

In [9]:
sentences = nltk_tokenizer.tokenize(text)
sentences

['Great product.', 'Good design.', 'But a little pricey!']

In [10]:
tokens = bert_tokenizer.tokenize(sentences[0])
tokens

['great', 'product', '.']

In [11]:
input_ids = bert_tokenizer.convert_tokens_to_ids(tokens)
input_ids

[2307, 4031, 1012]

##### BERT has 2 constraints
* All sentences must be padded or truncated to a fixed length
* Max sentence length is 512 tokens

#### Determine the max tokens length

In [12]:
max_len = 0
for sentence in sentences:
    input_ids = bert_tokenizer.encode(sentence, add_special_tokens=True)
    max_len = max(max_len, len(input_ids))
max_len

8

#### Encode sentences

In [13]:
input_ids = []
attention_masks = []

In [14]:
for sentence in sentences:
    encoded_dict = bert_tokenizer.encode_plus(sentence, 
                                              add_special_tokens=True,
                                              max_length=max_len,
                                              padding='max_length',
                                              return_attention_mask=True,
                                              return_tensors='pt',
                                              truncation=True)
    input_ids.append(encoded_dict['input_ids'])
    attention_masks.append(encoded_dict['attention_mask'])

In [15]:
input_ids

[tensor([[ 101, 2307, 4031, 1012,  102,    0,    0,    0]]),
 tensor([[ 101, 2204, 2640, 1012,  102,    0,    0,    0]]),
 tensor([[ 101, 2021, 1037, 2210, 3976, 2100,  999,  102]])]

In [16]:
attention_masks

[tensor([[1, 1, 1, 1, 1, 0, 0, 0]]),
 tensor([[1, 1, 1, 1, 1, 0, 0, 0]]),
 tensor([[1, 1, 1, 1, 1, 1, 1, 1]])]

In [17]:
input_ids = torch.cat(input_ids, dim=0)
attention_masks = torch.cat(attention_masks, dim=0)

In [18]:
input_ids

tensor([[ 101, 2307, 4031, 1012,  102,    0,    0,    0],
        [ 101, 2204, 2640, 1012,  102,    0,    0,    0],
        [ 101, 2021, 1037, 2210, 3976, 2100,  999,  102]])

#### Load BERT model

In [19]:
# model = BertModel.from_pretrained('bert-base-uncased')
# model.save_pretrained('./data/bert-model/')

In [20]:
model = BertModel.from_pretrained('./data/bert-model/')

#### Get last hidden state

In [21]:
with torch.no_grad():
    outputs = model(input_ids, attention_mask=attention_masks)

[2022-12-16 19:18:56.079 pytorch-1-8-gpu-py-ml-g4dn-4xlarge-ebd0c963d7cb3d49c063789f9c22:1901 INFO utils.py:27] RULE_JOB_STOP_SIGNAL_FILENAME: None
[2022-12-16 19:18:56.167 pytorch-1-8-gpu-py-ml-g4dn-4xlarge-ebd0c963d7cb3d49c063789f9c22:1901 INFO profiler_config_parser.py:102] Unable to find config at /opt/ml/input/config/profilerconfig.json. Profiler is disabled.


In [22]:
last_hidden_state = outputs.last_hidden_state[:, 0, :]
last_hidden_state

tensor([[-0.3160,  0.1158, -0.3213,  ..., -0.2726,  0.3246,  0.2845],
        [-0.2732,  0.1145, -0.2756,  ..., -0.2125,  0.0405,  0.6684],
        [-0.2529,  0.1539, -0.4362,  ..., -0.0796,  0.1113,  0.4073]])

In [23]:
last_hidden_state.shape

torch.Size([3, 768])

#### Compute output vector

In [24]:
sentence_vectors = last_hidden_state.detach().numpy()

In [25]:
sentence_vectors

array([[-0.31601778,  0.11579025, -0.32130766, ..., -0.27264526,
         0.32456014,  0.2845447 ],
       [-0.27318975,  0.11449272, -0.27562585, ..., -0.21247703,
         0.04046102,  0.66837615],
       [-0.25293225,  0.15385807, -0.43620795, ..., -0.07957672,
         0.11127631,  0.4073293 ]], dtype=float32)

In [26]:
paragraph_vector = sentence_vectors.mean(axis=0)  # mean of each column
paragraph_vector = paragraph_vector.tolist()
len(paragraph_vector)

768

## Text Encoder
* The text encoder below encapsulates the model, tokenizers and the above code logic into a callable interface.
* It transforms incoming raw text payload into a BERT feature vector [1 x 768].

In [27]:
DATA_PATH = './data'

In [28]:
from transformers import BertModel, BertTokenizerFast
import transformers
import logging
import torch
import nltk

In [29]:
class TextEncoder:
    
    bert_model = None
    bert_tokenizer = None
    nltk_tokenizer = None
    
    @classmethod
    def load_bert_model(cls):
        if cls.bert_model is None:
            cls.bert_model = BertModel.from_pretrained(f'{DATA_PATH}/bert-model/')
        return cls.bert_model
    
    @classmethod
    def load_bert_tokenizer(cls):
        if cls.bert_tokenizer is None:
            cls.bert_tokenizer = BertTokenizerFast.from_pretrained(f'{DATA_PATH}/bert-tokenizer/')
        return cls.bert_tokenizer
    
    @classmethod
    def load_nltk_tokenizer(cls):
        if cls.nltk_tokenizer is None:
            cls.nltk_tokenizer = nltk.load(f'{DATA_PATH}/english.pickle')
        return cls.nltk_tokenizer
            
    @classmethod 
    def encode(cls, text):
        bert_model = cls.load_bert_model()
        bert_tokenizer = cls.load_bert_tokenizer()
        nltk_tokenizer = cls.load_nltk_tokenizer()
        sentences = nltk_tokenizer.tokenize(text)
        
        max_len = 0
        for sentence in sentences:
            input_ids = bert_tokenizer.encode(sentence, add_special_tokens=True)
            max_len = max(max_len, len(input_ids))
        
        input_ids = []
        attention_masks = []
        
        for sentence in sentences:
            encoded_dict = bert_tokenizer.encode_plus(sentence, 
                                                      add_special_tokens=True, 
                                                      max_length=max_len, 
                                                      padding='max_length', 
                                                      return_attention_mask=True, 
                                                      return_tensors='pt', 
                                                      truncation=True)
        input_ids.append(encoded_dict['input_ids'])
        attention_masks.append(encoded_dict['attention_mask'])
        input_ids = torch.cat(input_ids, dim=0)
        attention_masks = torch.cat(attention_masks, dim=0)
        
        with torch.no_grad():
            outputs = bert_model(input_ids, attention_mask=attention_masks)
        
        last_hidden_state = outputs.last_hidden_state[:, 0, :]
        sentence_vectors = last_hidden_state.detach().numpy()
        paragraph_vector = sentence_vectors.mean(axis=0)  # mean of each column
        paragraph_vector = paragraph_vector.tolist()
        return paragraph_vector

In [30]:
def get_encoded_text(text):
    encoded_text = TextEncoder.encode(text)
    return encoded_text

#### Test text encoder

In [31]:
payload = 'I purchased this headphones on Black Friday sale. Poor quality. Not worth the money.'

In [32]:
response = get_encoded_text(payload)
response

[-0.5063672065734863,
 0.02691291831433773,
 -0.21883559226989746,
 -0.04903881996870041,
 -0.45220017433166504,
 -0.012785189785063267,
 0.2532523274421692,
 0.6240672469139099,
 0.23640933632850647,
 -0.18753862380981445,
 0.10757631063461304,
 -0.4399247467517853,
 -0.0724121555685997,
 0.2958056628704071,
 0.1202806830406189,
 -0.12032034993171692,
 0.0006345053552649915,
 0.30661094188690186,
 0.12856058776378632,
 -0.12162799388170242,
 -0.012813800014555454,
 -0.16330821812152863,
 -0.11310204863548279,
 -0.0003417262341827154,
 -0.29698044061660767,
 -0.13436968624591827,
 0.13270355761051178,
 -0.12359733879566193,
 0.14219136536121368,
 0.13176468014717102,
 -0.10959090292453766,
 0.3327617943286896,
 -0.4305773973464966,
 -0.3280176520347595,
 -0.04573799669742584,
 -0.23699122667312622,
 0.28261250257492065,
 0.2805051803588867,
 0.14983497560024261,
 -0.08238425105810165,
 -0.024005969986319542,
 -0.08355547487735748,
 0.17082443833351135,
 0.790075957775116,
 -0.006176157