In [2]:
import os
from typing import List, Tuple

import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoModelForMaskedLM, AutoTokenizer
from transformers import LongformerTokenizer, LongformerModel, LongformerForMaskedLM

In [25]:
class SparseModel:
  """Encode texts into sparse vocabulary-weight vectors with a Longformer masked-LM head.

  For every text the per-token logits are mapped through ``log(1 + relu(x))``,
  padded positions are zeroed with the attention mask, and the maximum over the
  token axis is taken. The non-zero entries of that vector are returned as a
  ``{'indices': [...], 'values': [...]}`` dictionary.
  """

  def __init__(self,model_path,max_length=512):
    """Load the masked-LM model and its tokenizer.

    Args:
      model_path: Path or identifier handed to ``from_pretrained``.
      max_length: Maximum number of tokens kept per text; longer inputs are
        truncated when encoding.
    """
    self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    self.model = LongformerForMaskedLM.from_pretrained(model_path).to(self.device)
    # This class only runs inference, so keep dropout and similar layers disabled.
    self.model.eval()
    self.max_length = max_length
    # Truncation and padding are per-call tokenizer options, so they are passed
    # in encode_texts instead of here.
    self.sparse_tokenizer = LongformerTokenizer.from_pretrained(model_path)
    # Length of a dense vector: must match the model's vocabulary, not a
    # hard-coded size borrowed from another model family.
    self.vocab_size = self.model.config.vocab_size

  def decode_sparse_dict(self, sparse_dict,trim=None):
    """Expand one sparse dictionary into a dense vector.

    Args:
      sparse_dict: Mapping with 'indices' and 'values', as returned by
        ``encode_texts``.
      trim: If given, keep only the ``trim`` largest entries and zero the rest.

    Returns:
      A 1-D NumPy float array of length ``self.vocab_size``.

    Raises:
      ValueError: If ``trim`` is negative.
    """
    dense = np.zeros(self.vocab_size)
    dense[sparse_dict['indices']] = sparse_dict['values']
    if trim is not None:
      if trim < 0:
        raise ValueError(f"trim must be non-negative, got {trim}")
      # Count the entries to drop from the front of the ascending order. Slicing
      # with [:-trim] selects nothing when trim == 0 and would keep everything.
      n_drop = max(dense.size - trim, 0)
      dense[np.argsort(dense)[:n_drop]] = 0
    return dense

  def decode_sparse_dicts(self, sparse_dicts,trim=None):
    """Decode several sparse dictionaries.

    Returns:
      A list with one dense vector (as a plain Python list) per input dictionary.
    """
    return [
        self.decode_sparse_dict(sparse_dict, trim).tolist()
        for sparse_dict in sparse_dicts
    ]

  def formalize(self, sparse_dict):
    """Translate a sparse dictionary into ``{token: weight}``, highest weight first."""
    idx2token = {idx: token for token, idx in self.sparse_tokenizer.get_vocab().items()}
    weighted_tokens = {
        idx2token[idx]: weight
        for idx, weight in zip(sparse_dict['indices'], sparse_dict['values'])
    }
    return dict(
        sorted(weighted_tokens.items(), key=lambda item: item[1], reverse=True)
    )

  def encode_texts(self, texts:List[str]):
    """Encode a list of texts into sparse vectors in one forward pass.

    All texts are processed as a single batch, padded to the longest text in the
    batch and truncated to ``self.max_length`` tokens. Intended for GPU use,
    where one large batch is cheaper than many small ones.

    Args:
      texts: The texts to encode.

    Returns:
      One ``{'indices': list, 'values': list}`` dictionary per input text, in
      input order. An empty input yields an empty list.
    """
    if len(texts) == 0:
      return []

    # padding=True pads to the longest sequence in the batch; padded positions
    # are zeroed by the attention mask below, so they never win the max.
    encoded = self.sparse_tokenizer(
        texts,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=self.max_length,
    ).to(self.device)

    with torch.no_grad():
      logits = self.model(**encoded).logits

    token_mask = encoded['attention_mask'].unsqueeze(-1)
    weights = torch.log(1 + torch.relu(logits)) * token_mask
    # Reduce over the token axis (dim=1), leaving one vocabulary-sized row per text.
    sparse_vecs = torch.max(weights, dim=1).values.cpu()

    sparse_dicts = []
    for sparse_vec in sparse_vecs:
      # as_tuple=True always gives a 1-D index tensor, so a vector with exactly
      # one non-zero entry still produces lists rather than bare scalars.
      nonzero_idx = sparse_vec.nonzero(as_tuple=True)[0]
      sparse_dicts.append({
          'indices': nonzero_idx.tolist(),
          'values': sparse_vec[nonzero_idx].tolist(),
      })

    return sparse_dicts

  def encode_text(self, text:str):
    """Encode a single text and return its sparse dictionary."""
    return self.encode_texts([text])[0]

  def encode_text_list(self, texts:list, batch_size:int=1):
    """Encode a list of texts in consecutive batches, showing a progress bar.

    With the default ``batch_size=1`` each text is encoded on its own, which
    needs no padding; the original author recommends this setting on CPU.

    Args:
      texts: The texts to encode.
      batch_size: Number of texts passed to ``encode_texts`` per step.

    Returns:
      One sparse dictionary per input text, in input order.
    """
    sparse_dicts = []

    for start in tqdm(range(0, len(texts), batch_size)):
      sparse_dicts += self.encode_texts(texts[start:start+batch_size])

    return sparse_dicts

In [None]:
# Location of the Longformer checkpoint. Set the LONGFORMER_MODEL_PATH environment
# variable to use a checkpoint stored elsewhere; the original local path is kept
# as the fallback so existing setups keep working.
sample_model_path = os.environ.get(
    "LONGFORMER_MODEL_PATH",
    "/Users/debasmitroy/Desktop/dbfte/longformer-base-4096",
)
# max_length=4096 lets texts of up to 4096 tokens through before truncation.
sparse_model = SparseModel(sample_model_path,max_length=4096)

# sample_tokenizer = LongformerTokenizer.from_pretrained(sample_model_path,local_files_only=True)
# sample_model = LongformerForMaskedLM.from_pretrained(sample_model_path,local_files_only=True)

In [22]:
# sample_input = sample_tokenizer("Hello, my dog is cute", return_tensors="pt")

In [23]:
# with torch.no_grad():
#     sample_logits = sample_model(**sample_input)

In [24]:
# sample_logits.logits.shape

In [28]:
# Encode one example sentence into its sparse {'indices': ..., 'values': ...} form.
sample_emb = sparse_model.encode_text("Hello, my dog is cute")

In [30]:
# Display the embedding as token -> weight pairs, highest weight first.
sparse_model.formalize(sample_emb)

{'Ġis': 3.4714324474334717,
 ',': 3.463989019393921,
 'Ġdog': 3.4164011478424072,
 'Ġmy': 3.3851404190063477,
 '<s>': 3.2049334049224854,
 'Ġcute': 3.13922381401062,
 'Ġyour': 3.1321141719818115,
 'Hello': 3.0231399536132812,
 'Ġdogs': 2.9717912673950195,
 'Ġwas': 2.9472768306732178,
 'ĠDog': 2.946162700653076,
 'Ġpet': 2.935741424560547,
 'ĠMy': 2.9351789951324463,
 'Ġpuppy': 2.929840087890625,
 'is': 2.9272842407226562,
 'Ġare': 2.9129085540771484,
 'Ġhas': 2.909193992614746,
 'Ġbrother': 2.903501033782959,
 'Ġlooks': 2.903412103652954,
 'ĠIs': 2.902571439743042,
 'ĠIS': 2.897960901260376,
 'Ġa': 2.8849568367004395,
 'Ġam': 2.8732097148895264,
 'Ġhusband': 2.869065999984741,
 'Ġand': 2.8671536445617676,
 'Ġhis': 2.8611834049224854,
 'Hi': 2.842146873474121,
 "'s": 2.8371596336364746,
 'Ġguy': 2.834458351135254,
 'Ġthe': 2.829115629196167,
 'Ġdoes': 2.827543020248413,
 'ĠMY': 2.820107936859131,
 'Ġes': 2.8191583156585693,
 'Ġ,': 2.8188655376434326,
 'Ġdaughter': 2.8031203746795654,
 '