# Log line embedding using a BERT-based encoder

This notebook shows how to use a pretrained DistilBERT-based model to embed log lines into a vector space, using the Hugging Face Transformers and Datasets libraries.

Note: This notebook assumes the [Cookiecutter Data Science](https://drivendata.github.io/cookiecutter-data-science/) project directory structure and expects to be located in the `/notebooks/` folder.

In [1]:
from datasets import load_dataset
import numpy as np
from dataclasses import dataclass
from typing import List, Union, Dict, Optional
import torch
from transformers import DistilBertTokenizerFast, DistilBertPreTrainedModel, DistilBertModel
from transformers.file_utils import ModelOutput
from pathlib import Path
import re

Set up commonly used objects and constants.

In [2]:
project_base_dir = Path.cwd().parent
data_dir = project_base_dir / 'data'
base_pretrained_model_name = "distilbert-base-cased"

## Dataset preparation
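First we load the HDFS1 dataset. Note that the cell below loads the entire file; to work with only the first 1000 lines as a demonstrative subset, a split slice such as `split='train[:1000]'` can be passed to `load_dataset` instead.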

In [3]:
dataset_path = data_dir / 'interim' / 'HDFS1' / 'train-data-HDFS1-cv1-1.log'  # resolved from the project layout rather than a user-specific absolute path
dataset_name = dataset_path.stem
dataset = load_dataset('text', data_files=str(dataset_path), split='train')

Using custom data configuration default-b1cdfca81a0e7c30
Reusing dataset text (/home/cernypro/.cache/huggingface/datasets/text/default-b1cdfca81a0e7c30/0.0.0/44d63bd03e7e554f16131765a251f2d8333a5fe8a73f6ea3de012dbc49443691)


Next we perform rudimentary log-line preprocessing by removing the timestamp from each line. (The model used in this notebook was pretrained on data preprocessed this way.)

In [4]:
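# Matches the three leading space-separated numeric fields of an HDFS1 log line
HDFS1_TIMESTAMP_PATTERN = re.compile(r'^(\d+) (\d+) (\d+) ')

def remove_timestamp(example):
    # Strip the matched prefix, keeping the rest of the line unchanged
    example['text'] = HDFS1_TIMESTAMP_PATTERN.sub('', example['text'])
    return example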

cleaned_dataset = dataset.map(remove_timestamp)

Loading cached processed dataset at /home/cernypro/.cache/huggingface/datasets/text/default-b1cdfca81a0e7c30/0.0.0/44d63bd03e7e554f16131765a251f2d8333a5fe8a73f6ea3de012dbc49443691/cache-c890ed1c4904eef4.arrow
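
As a quick check of what the pattern strips, the sketch below applies the same regular expression to a single line. The sample line is an illustrative assumption in the HDFS1 format (three leading numeric fields followed by the log level), not a line taken from the dataset used here.

```python
import re

# Same pattern as in the cell above: three space-separated numeric fields at the start of the line
HDFS1_TIMESTAMP_PATTERN = re.compile(r'^(\d+) (\d+) (\d+) ')

# Hypothetical sample line, for illustration only
sample_line = '081109 203518 143 INFO dfs.DataNode$DataXceiver: Receiving block blk_123 src: /10.0.0.1:1234'

cleaned_line = HDFS1_TIMESTAMP_PATTERN.sub('', sample_line)
print(cleaned_line)

# A line that does not start with the three numeric fields is left unchanged
unchanged_line = HDFS1_TIMESTAMP_PATTERN.sub('', 'INFO dfs.FSDataset: Deleting block blk_456')
print(unchanged_line)
```

Because the pattern is anchored with `^`, only a prefix at the very start of the line is removed; numbers appearing later in the message (such as block IDs) are untouched.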


## Transformer model preparation
Here we prepare the Transformer model classes.

In [5]:
@dataclass
class EmbeddingOutput(ModelOutput):
    """
    Output class following the Hugging Face Transformers ModelOutput conventions.
    May be replaced by a suitable existing class from the library, if one exists.
    """
    embedding: Optional[torch.FloatTensor] = None

class DistilBertForClsEmbedding(DistilBertPreTrainedModel):
    """
    DistilBertModel with a linear layer applied to [CLS] token.
    Initialize using .from_pretrained(path_or_model_name) method
    """
    def __init__(self, config):
        super().__init__(config)
        if config.task_specific_params is None:
            config.task_specific_params = dict()

        self.distilbert = DistilBertModel(config)
        self.cls_projector = torch.nn.Linear(config.dim, config.task_specific_params.setdefault('cls_embedding_dimension', 512))

        self.init_weights()
    
    def forward(self, input_ids, attention_mask):
        bert_output = self.distilbert(input_ids=input_ids, attention_mask=attention_mask)
        cls_token_embedding = bert_output.last_hidden_state[:, 0]
        cls_encoding = self.cls_projector(cls_token_embedding)
        return EmbeddingOutput(embedding=cls_encoding)

Now load the model from a checkpoint and prepare its tokenizer.

In [6]:
embedding_model_directory = project_base_dir / 'models' / 'LogEncoder_from_1T_Eps_1_M_basic_chunked_10_Seed-42_T-len_512_C-len_512_Tr-batch_64_Ev-b_64_O-dim_100'

encoder_model = DistilBertForClsEmbedding.from_pretrained(embedding_model_directory).to('cuda')
tokenizer = DistilBertTokenizerFast.from_pretrained(base_pretrained_model_name)  # The tokenizer must match the one used for the saved model; this model uses the distilbert-base-cased tokenizer

Here we prepare the `encode` function, which we will map over our dataset. It adds an `embedding` column containing the vector embedding of each log line.

We apply this function in batches (for faster processing), as both the tokenizer and the model can handle batched data. The batch size was chosen arbitrarily.

The `encode` function takes two additional arguments, which have to be passed to `map` as a dict via `fn_kwargs`. (We could also use closures, but I find this cleaner and easier to copy from a notebook into a script.)

See the [Datasets .map documentation](https://huggingface.co/docs/datasets/processing.html#processing-data-with-map) for more info.

In [10]:
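def encode(examples, tokenizer, encoder):
    # Tokenize the whole batch at once, padding to the longest line in the batch
    inputs = tokenizer(examples['text'],
                       return_tensors='pt',
                       truncation=True,
                       padding=True).to('cuda')
    with torch.no_grad():  # inference only, so no gradients are tracked
        embedding = encoder(**inputs).embedding.cpu().numpy().tolist()
    return {'embedding': embedding}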

encoder_model.eval()
embedded_dataset = cleaned_dataset.map(encode,
                                       fn_kwargs={'tokenizer': tokenizer,
                                                  'encoder': encoder_model},
                                       batched=True,
                                       batch_size=256)
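
To make the roles of `batched`, `batch_size` and `fn_kwargs` concrete, here is a simplified, library-free sketch of how a batched map calls the mapped function. The `batched_map` helper and the `count_characters` function are hypothetical stand-ins written for illustration; this is not how Datasets implements `.map`, only a picture of the calling convention it follows.

```python
from typing import Callable, Dict, List


def batched_map(columns: Dict[str, List], function: Callable,
                fn_kwargs: Dict, batch_size: int) -> Dict[str, List]:
    """Hypothetical stand-in for a batched map over a column-oriented table."""
    num_rows = len(next(iter(columns.values())))
    added: Dict[str, List] = {}
    for start in range(0, num_rows, batch_size):
        # Each batch is a dict of column name -> list of values, like `examples` in `encode`
        batch = {name: values[start:start + batch_size] for name, values in columns.items()}
        # The extra keyword arguments are forwarded unchanged to every call
        for name, values in function(batch, **fn_kwargs).items():
            added.setdefault(name, []).extend(values)
    # Returned columns are merged with the existing ones
    return {**columns, **added}


def count_characters(examples, offset):
    # Hypothetical mapped function: returns one output value per input line
    return {'length': [len(text) + offset for text in examples['text']]}


table = {'text': ['a', 'bb', 'ccc']}
mapped = batched_map(table, count_characters, fn_kwargs={'offset': 0}, batch_size=2)
print(mapped)
```

The mapped function receives lists rather than single values, which is why `encode` can hand `examples['text']` straight to the tokenizer.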

The first three embeddings can be viewed alongside their original lines using slicing notation, e.g. `embedded_dataset[:3]`, which Datasets supports.
The returned object is a dict with column names as keys and lists of the column contents as values.

In [11]:
save_path = data_dir / 'processed' / dataset_name / f"embedding_from_{embedding_model_directory.stem}"
embedded_dataset.save_to_disk(save_path)

In [36]:
tokenizer.decode(tokenizer("a a", add_special_tokens=False, truncation=True, return_attention_mask=False)['input_ids'], clean_up_tokenization_spaces=True)

'a a'

In [29]:
embedded_dataset[-1]['text']

'INFO dfs.FSDataset: Deleting block blk_-7013325917247206057 file /mnt/hadoop/dfs/data/current/subdir10/blk_-7013325917247206057'