<a href="https://colab.research.google.com/github/LukasStankevicius/Random-embeddings-baseline-for-text-level-NLP-tasks/blob/main/Random_embeddings_baseline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from collections import defaultdict
from transformers import AutoTokenizer
import numpy as np

class RandomEmbeddingsBaseline:
    """Text-level baseline that averages fixed, randomly drawn token vectors.

    Every token id is assigned one row of a matrix sampled once from a normal
    distribution; a text is represented by the mean of the rows of its tokens.

    Tokenization is selected by ``model_name``:

    * ``"lowersplit"`` -- lower-case the text and split it on single spaces.
      Ids are assigned to tokens in order of first appearance.
    * any name in ``tested_models`` -- the tokenizer loaded through
      ``AutoTokenizer.from_pretrained``.
    """

    def __init__(self, model_name, vector_size, mu=0, sigma=0.1, random_state=0,
                 lowersplit_vocab_size=750_000):
        """Build the tokenizer and sample the embedding matrix.

        Args:
            model_name: ``"lowersplit"`` or one of ``tested_models``.
            vector_size: Number of columns of the embedding matrix.
            mu: Mean of the normal distribution the vectors are drawn from.
            sigma: Standard deviation of that distribution.
            random_state: Seed passed to ``np.random.default_rng``.
            lowersplit_vocab_size: Number of embedding rows allocated in
                ``"lowersplit"`` mode. Ignored for the other tokenizers.

        Raises:
            ValueError: If ``model_name`` is neither ``"lowersplit"`` nor one
                of ``tested_models``, or if ``lowersplit_vocab_size`` is
                smaller than 1.
        """
        self.rng = np.random.default_rng(random_state)
        self.tested_models = ['bert-base-uncased', "t5-small", "google/mt5-small", "google/byt5-small", "meta-llama/Llama-2-7b-hf"]

        if model_name == "lowersplit":
            if lowersplit_vocab_size < 1:
                raise ValueError('lowersplit_vocab_size must be at least 1')
            n_tokens = lowersplit_vocab_size
            # A missing word receives the current dictionary size as its id,
            # i.e. ids are handed out consecutively in order of first use.
            self.word2id = defaultdict(lambda: len(self.word2id))

            def lowersplit_tokenize(text):
                # The vocabulary is open-ended while the matrix has a fixed
                # number of rows, so ids are wrapped into the valid range
                # instead of raising IndexError once the table is exhausted.
                ids = [self.word2id[word] % n_tokens for word in text.lower().split(' ')]
                return {'input_ids': ids}

            self.tokenizer = lowersplit_tokenize
            self.tokenizer_kwargs = {}
        elif model_name in self.tested_models:
            # The fast tokenizer implementation is requested only for mT5.
            use_fast = model_name in ["google/mt5-small"]
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=use_fast)
            # Only the tokens of the text itself should contribute to the mean.
            self.tokenizer_kwargs = dict(add_special_tokens=False)
            n_tokens = self.tokenizer.vocab_size
        else:
            raise ValueError('Untested model')

        # Sample in float64 and cast afterwards, so the drawn values do not
        # depend on the storage dtype.
        self.vectors = self.rng.normal(mu, sigma, (n_tokens, vector_size))
        self.vectors = self.vectors.astype('float32')

    def text_to_vector(self, text):
        """Return the mean embedding of the tokens of ``text``.

        Args:
            text: The string to embed.

        Returns:
            A 1-D ``float32`` array with one entry per embedding column. When
            the tokenizer produces no tokens the result is all zeros. Note
            that ``"lowersplit"`` always produces at least one token, because
            splitting an empty string on ``' '`` yields ``['']``.
        """
        token_indexes = self.tokenizer(text, **self.tokenizer_kwargs)['input_ids']
        if not token_indexes:
            return np.zeros(shape=(self.vectors.shape[1]), dtype='float32')
        return self.vectors[token_indexes].mean(axis=0)

In [3]:
# Demo: embed one sentence with 48-dimensional random vectors, using the
# 'bert-base-uncased' tokenizer to turn the text into token ids.
text = "This is the test sentence."
m = RandomEmbeddingsBaseline(vector_size=48, model_name='bert-base-uncased')
print(m.text_to_vector(text))

[-0.01570123  0.05268576 -0.01440263  0.01644494 -0.02822436  0.00737305
  0.00538737 -0.00791293  0.00618587 -0.05503013 -0.01529175 -0.12204976
  0.01475287 -0.02839016 -0.05140474 -0.09164638  0.00553092  0.09184373
 -0.03690748  0.02093396  0.00842    -0.04734223 -0.02477467  0.01362352
  0.0004928  -0.03422094  0.06709204 -0.02618556 -0.00178645 -0.08016231
 -0.07860541  0.01794009  0.03173523 -0.02214176 -0.01232245  0.00542088
 -0.00080929  0.05717109  0.01065253  0.02649156  0.05100659  0.08113872
 -0.01249488  0.02838389 -0.00670874  0.00496173  0.02301297 -0.03081081]
