In [19]:
from datasets import load_dataset
import numpy as np
from dataclasses import dataclass
from typing import List, Union, Dict
import torch
from transformers import DistilBertTokenizerFast, AutoModel
from pathlib import Path

In [2]:
# Load the raw HDFS log file through the `text` dataset builder and keep its 'train' split.
# The cells below read each example's content from the 'text' key.
hdfs1_dataset = load_dataset('text', data_files='../data/raw/HDFS1/HDFS.log', split='train')

Using custom data configuration default-f7d20bad4b8d075b
Reusing dataset text (/home/cernypro/.cache/huggingface/datasets/text/default-f7d20bad4b8d075b/0.0.0/44d63bd03e7e554f16131765a251f2d8333a5fe8a73f6ea3de012dbc49443691)


In [3]:
# Number of leading log lines used for this notebook's small-scale run.
SMALL_SAMPLE_SIZE = 1000
# Clamp to the dataset length so that a log shorter than SMALL_SAMPLE_SIZE
# does not make `select` fail on out-of-range indices.
small_raw_dataset = hdfs1_dataset.select(range(min(SMALL_SAMPLE_SIZE, len(hdfs1_dataset))))

In [4]:
def remove_timestamp(example):
    """
    Drop the leading three space-separated fields of a log line.

    Everything up to and including the third single space of
    ``example['text']`` is removed (per the function name, that prefix is the
    timestamp part of the line). Lines containing fewer than three spaces are
    left unchanged instead of being sliced at an arbitrary position.

    Args:
        example: mapping with a string under the ``'text'`` key.

    Returns:
        The same ``example`` object, with ``example['text']`` updated in place.
    """
    # maxsplit=3 keeps everything after the third space intact in fields[3];
    # with an explicit ' ' separator, runs of spaces are not collapsed, so each
    # space still counts as one delimiter.
    fields = example['text'].split(' ', 3)
    if len(fields) == 4:
        example['text'] = fields[3]
    return example

# Apply `remove_timestamp` to every example of the small subset (one example per call).
small_cleaned_dataset = small_raw_dataset.map(remove_timestamp)

HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




In [7]:
# Name of the pretrained checkpoint; reused below when the model is constructed.
pretrained_model_name = "distilbert-base-cased"
# Tokenizer loaded from that same checkpoint name.
tokenizer = DistilBertTokenizerFast.from_pretrained(pretrained_model_name)

In [8]:
class ClsEncoderTower(torch.nn.Module):
    """
    Simple model on top of a BERT-like model.

    A linear layer is applied to the backbone's hidden state at sequence
    position 0 (the [CLS] token) of each input sequence.

    Args:
        pretrained_model_name_or_path: identifier or path forwarded to
            ``AutoModel.from_pretrained``.
        output_encode_dimension: size of the produced encoding (default 512).
    """
    def __init__(self, pretrained_model_name_or_path, output_encode_dimension=512):
        super().__init__()
        self.bert = AutoModel.from_pretrained(pretrained_model_name_or_path)
        # Most configs expose the hidden width as `hidden_size`; fall back to
        # `dim` for configs that only define that name. Likely 768 for the
        # base-sized checkpoints.
        config = self.bert.config
        hidden_size = getattr(config, 'hidden_size', None)
        if hidden_size is None:
            hidden_size = config.dim
        self.linear = torch.nn.Linear(hidden_size, output_encode_dimension)

    def forward(self, input_ids, attention_mask):
        """
        Encode a batch of token sequences.

        Both arguments are passed to the backbone unchanged. The first element
        of the backbone output is indexed with ``[:, 0]``, i.e. position 0 of
        the second axis is taken for every row (assumes that element is laid
        out as (batch, sequence, hidden) -- TODO confirm for other backbones).

        Returns:
            The linear projection of the selected hidden states.
        """
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_token_embedding = bert_output[0][:, 0]
        return self.linear(cls_token_embedding)
    
class OneTowerICT(torch.nn.Module):
    """
    Single-tower network for the Inverse Cloze Task (ICT).

    One shared ClsEncoderTower produces the encodings of both the target and
    the context sentences (the "query" and "document" in the nomenclature of
    the original paper). Training uses a cross entropy loss.
    """
    def __init__(self, pretrained_model_name_or_path, output_encode_dimension=512):
        super().__init__()
        self.tower = ClsEncoderTower(pretrained_model_name_or_path, output_encode_dimension)
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, target, target_mask, context, context_mask, correct_class):
        """
        Encode targets and contexts with the shared tower and compute the loss.

        The logits are the matrix product of the target encodings with the
        transposed context encodings, so entry (i, j) scores target i against
        context j; ``correct_class`` is handed to the cross entropy loss as the
        expected class per target.

        Returns:
            Tuple ``(loss, target encodings, context encodings)``.
        """
        encoded_target = self.tower(input_ids=target, attention_mask=target_mask)
        encoded_context = self.tower(input_ids=context, attention_mask=context_mask)

        # Pairwise scores between every target and every context encoding.
        similarity = encoded_target @ encoded_context.transpose(-1, -2)
        return self.loss_fn(similarity, correct_class), encoded_target, encoded_context

In [15]:
# The checkpoint directory sits under `models/` in the parent of the current
# working directory; the weights file is `pytorch_model.bin` inside it.
saved_model_dir = Path.cwd().parent.joinpath(
    'models',
    '1T_Eps_2_Lines_8000000_T-len_512_C-len_512_Tr-batch_64_Ev-b_64_O-dim_512',
)
saved_model_file = saved_model_dir.joinpath('pytorch_model.bin')

In [16]:
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
# map_location='cpu' restores the tensors on the CPU regardless of the device
# they were saved from, so the checkpoint also loads on hosts without CUDA and
# no extra GPU memory is held by this dict.
state_dict = torch.load(saved_model_file, map_location='cpu')

In [20]:
# Build the network from the pretrained checkpoint name (output dimension left at its default),
# then overwrite its weights with the loaded state dict.
model = OneTowerICT(pretrained_model_name)
# Last expression of the cell: its return value is displayed below.
model.load_state_dict(state_dict)

<All keys matched successfully>

In [38]:
# Only the shared tower is needed to produce embeddings.
encoder = model.tower
# Use the GPU when one is available instead of failing on CPU-only hosts;
# eval() makes inference mode explicit (dropout disabled in every submodule).
encoder = encoder.to('cuda' if torch.cuda.is_available() else 'cpu').eval()

In [50]:
def encode(examples, tokenizer, encoder):
    """
    Embed a batch of texts with ``encoder``.

    Args:
        examples: batch mapping whose ``'text'`` entry is passed to the tokenizer.
        tokenizer: callable invoked with ``return_tensors='pt'``, truncation and
            padding enabled; its result must support ``.to(device)`` and ``**``
            unpacking into the encoder's keyword arguments.
        encoder: ``torch.nn.Module`` producing one encoding per input text.

    Returns:
        ``{'embedding': ...}`` where the value is the encoder output converted
        to nested Python lists.
    """
    # Feed the inputs to whichever device holds the encoder's weights rather
    # than assuming CUDA; an encoder without parameters runs on the CPU.
    first_param = next(encoder.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    batch = tokenizer(examples['text'], return_tensors='pt', truncation=True, padding=True).to(device)

    # Inference only: skip building the autograd graph to save memory.
    with torch.no_grad():
        embeddings = encoder(**batch)
    return {'embedding': embeddings.cpu().tolist()}

# Embed the cleaned log lines in batches of 128; `fn_kwargs` forwards the tokenizer and encoder to `encode`.
small_embedded_dataset = small_cleaned_dataset.map(encode, fn_kwargs={'tokenizer': tokenizer, 'encoder': encoder}, batched=True, batch_size=128)

Loading cached processed dataset at /home/cernypro/.cache/huggingface/datasets/text/default-f7d20bad4b8d075b/0.0.0/44d63bd03e7e554f16131765a251f2d8333a5fe8a73f6ea3de012dbc49443691/cache-9d1fd67dfb165774.arrow


In [52]:
# Inspect the first example; the displayed output shows its 'embedding' list.
small_embedded_dataset[0]

{'embedding': [-0.07249397784471512,
  0.08337511122226715,
  0.44695860147476196,
  -0.0203128382563591,
  -0.0008805245161056519,
  0.033679693937301636,
  -0.2584163248538971,
  -0.01713814213871956,
  -0.13464701175689697,
  0.2772810161113739,
  -0.10751543939113617,
  -0.02893763780593872,
  -0.1402347981929779,
  -0.1562652587890625,
  -0.2721223533153534,
  0.07620158791542053,
  -0.14240868389606476,
  0.3091081380844116,
  -0.00977490097284317,
  -0.027673259377479553,
  0.22512134909629822,
  0.02421267330646515,
  -0.1515868902206421,
  0.02557799220085144,
  -0.059474557638168335,
  0.13023941218852997,
  -0.11966827511787415,
  -0.2382989227771759,
  -0.04261612892150879,
  0.11512461304664612,
  -0.2678832411766052,
  0.27379319071769714,
  0.47852030396461487,
  0.007744520902633667,
  0.2810533046722412,
  -0.08203402161598206,
  0.027505144476890564,
  -0.1304769665002823,
  -0.06472301483154297,
  0.19899003207683563,
  -0.08140881359577179,
  -0.07877187430858612,
 