In [None]:
# Mount Google Drive (Colab only) so the project folder under /content/drive is
# reachable from this kernel; the %cd cell below depends on this mount.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip -q install accelerate
!pip -q install bitsandbytes
!pip -q install huggingface_hub

In [None]:
# Authenticate with the Hugging Face Hub (presumably only required if the model repo
# is gated/private). notebook_login() prompts through a notebook widget instead of a
# `!` shell subprocess reading stdin, which not every notebook frontend forwards.
from huggingface_hub import notebook_login

notebook_login()

In [None]:
# Work from the project folder on Drive: the relative paths used later
# ('reddit_l2_w2v_words.txt' and the 'word-embeddings/' output folder) resolve against it.
%cd /content/drive/MyDrive/DSD

In [None]:
import gc
import os
from pickle import dump

import numpy as np
import torch
from sklearn.decomposition import FastICA
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm
from transformers import AutoModel, AutoTokenizer

In [None]:
class CFG:
    """Run configuration for the word-embedding extraction in this notebook."""

    # Hugging Face Hub repo id of the checkpoint used as the encoder.
    model_name = "solidrust/dolphin-2.9.1-llama-3-8b-AWQ"
    # Token budget per word: each word is padded/truncated to this many tokens
    # (including any special tokens the tokenizer adds — TODO confirm for this tokenizer).
    max_len = 4
    batch_size = 32
    # DataLoader worker processes, capped at the CPUs usable by this process.
    # Colab typically exposes only a few; PyTorch warns that creating more workers
    # than that can make the DataLoader slow or even freeze.
    num_workers = min(
        12,
        len(os.sched_getaffinity(0)) if hasattr(os, "sched_getaffinity") else (os.cpu_count() or 1),
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [None]:
# Load the vocabulary to embed: one word per line, surrounding whitespace stripped.
# The encoding is explicit so the result does not depend on the machine's locale.
# Blank lines are skipped; an empty string would otherwise be embedded as if it were a word.
with open('reddit_l2_w2v_words.txt', 'r', encoding='utf-8') as f:
    loaded_words = [word for word in (line.strip() for line in f) if word]

In [None]:
class WordsDataset(Dataset):
    """Map-style dataset that tokenizes one word per item.

    Each item is tokenized with padding='max_length' and truncation to ``CFG.max_len``
    and returned as a dict of ``torch.long`` tensors under the keys "input_ids" and
    "attention_mask", so the default DataLoader collate can stack them per batch.
    """

    def __init__(self, words, tokenizer):
        # `words` is indexed positionally; `tokenizer` is invoked once per item.
        self.words = words
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.words)

    def __getitem__(self, idx):
        encoded = self.tokenizer(
            self.words[idx],
            padding='max_length',
            max_length=CFG.max_len,
            truncation=True,
        )
        return {
            field: torch.tensor(getattr(encoded, field), dtype=torch.long)
            for field in ("input_ids", "attention_mask")
        }

In [None]:
def mean_pooling(model_output, attention_mask):
    """Pool token embeddings into one vector per sequence, weighted by the attention mask.

    For a 0/1 mask this is the mean over unmasked positions.

    Args:
        model_output: indexable model output; element 0 holds the token embeddings
            (assumes (batch, seq, hidden), e.g. a Hugging Face last_hidden_state — TODO confirm).
        attention_mask: (batch, seq) mask; it is broadcast over the hidden dimension.

    Returns:
        float tensor of shape (batch, hidden). Rows whose mask is all zeros come out as
        zeros, because the denominator is clamped to 1e-9 instead of dividing by zero.
    """
    hidden_states = model_output[0]
    weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    weighted_sum = (hidden_states * weights).sum(dim=1)
    token_counts = weights.sum(dim=1).clamp(min=1e-9)
    return weighted_sum / token_counts

class Model(torch.nn.Module):
    """Wrap a Hugging Face ``AutoModel`` and mean-pool its output into one vector per input.

    NOTE(review): the encoder is loaded with ``load_in_8bit=True``. Newer transformers
    releases deprecate this kwarg in favour of ``quantization_config=BitsAndBytesConfig(...)``.
    If ``model_name`` points at an already-quantized checkpoint (e.g. AWQ), confirm which
    quantization is actually applied — TODO verify.
    """

    def __init__(self, model_name):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name, load_in_8bit = True)

    def forward(self, input_ids, attention_mask):
        """Encode the batch and return ``mean_pooling`` of the encoder output over the mask."""
        encoder_output = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        return mean_pooling(encoder_output, attention_mask)

# Load the tokenizer before the model so a misconfigured model name surfaces
# before the model weights are loaded.
tokenizer = AutoTokenizer.from_pretrained(CFG.model_name)
# WordsDataset pads every word (padding='max_length'), which needs a pad token;
# when the tokenizer defines none, reuse its EOS token for padding.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

lm_model = Model(CFG.model_name).to(CFG.device)

In [None]:
def get_word_embeddings(words, model):
    """Embed every word in ``words`` with ``model`` and return the stacked results.

    Words are tokenized by ``WordsDataset`` with the module-level ``tokenizer``. They are
    batched using ``CFG.batch_size`` / ``CFG.num_workers``, and each batch is moved to
    ``CFG.device``.

    Args:
        words: sequence of strings. Batches are built in order (no shuffling), so row i
            of the result corresponds to ``words[i]``.
        model: module called as ``model(input_ids, attention_mask)`` that returns one
            embedding row per input sequence; it is switched to eval mode here.

    Returns:
        numpy array formed by concatenating the per-batch outputs along axis 0.

    Raises:
        ValueError: from ``np.concatenate`` when ``words`` is empty.
    """
    model.eval()
    words_dataset = WordsDataset(words, tokenizer)
    words_dataloader = DataLoader(
        words_dataset,
        batch_size=CFG.batch_size,
        shuffle=False,  # keep output rows aligned with `words`
        num_workers=CFG.num_workers,
    )
    embeddings = []
    # Pure feature extraction: disabling autograd avoids recording the forward pass
    # (and keeping its activations alive) for every batch.
    with torch.no_grad():
        for batch in tqdm(words_dataloader, total=len(words_dataloader)):
            # The default collate already stacks the per-item 1-D tensors into
            # (batch, max_len); no squeeze is needed. Squeezing dim 1 would collapse the
            # sequence axis when max_len == 1.
            input_ids = batch['input_ids'].to(CFG.device)
            attention_mask = batch['attention_mask'].to(CFG.device)
            batch_embeddings = model(input_ids, attention_mask).cpu().numpy()
            embeddings.append(batch_embeddings)
    # concatenate the embeddings into a single numpy array
    embeddings = np.concatenate(embeddings, axis=0)
    del words_dataset, words_dataloader
    gc.collect()
    return embeddings

# Embed the whole vocabulary; row i of the result belongs to loaded_words[i].
word_embeddings = get_word_embeddings(words=loaded_words, model=lm_model)

In [None]:
# Persist a {word: embedding} mapping. Rows of `word_embeddings` follow the order of
# `loaded_words` (get_word_embeddings batches without shuffling), so zipping pairs them
# correctly. The length check guards against zip silently truncating on a mismatch.
# NOTE: if `loaded_words` contains duplicates, later rows overwrite earlier ones.
assert len(loaded_words) == len(word_embeddings), "word/embedding count mismatch"
word_embeddings_dict = dict(zip(loaded_words, word_embeddings))

# The output folder is relative to the working directory. Create it if it is missing,
# so the dump does not fail with FileNotFoundError on a fresh project folder.
os.makedirs("word-embeddings", exist_ok=True)
with open(f"word-embeddings/{CFG.model_name.replace('/','_')}_word_embeddings_reddit-l2.pkl", 'wb') as f:
    dump(word_embeddings_dict, f)