In [5]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader, Dataset
import nmslib

In [6]:
class WordDataset(Dataset):
    def __init__(self, words):
        """
        Initializes the dataset with a list of words.
        :param words: List of words to be converted into OHE tensors.
        """
        self.words = [word.lower() for word in words if isinstance(word, str)]
        
        # Vocabulary setup
        self.vocab = "abcdefghijklmnopqrstuvwxyz"
        self.vocab_size = len(self.vocab) + 1  # +1 for unknown characters
        self.ctoi = {char: idx for idx, char in enumerate(self.vocab)}
        
        # One-Hot Encoding Matrix
        self.create_OHE()
    
    def create_OHE(self):
        """Creates a One-Hot Encoding matrix for the vocabulary."""
        self.OHE = torch.eye(self.vocab_size)  # Identity matrix for one-hot encoding
    
    def get_OHE(self, word):
        """Converts a word into a one-hot encoding tensor."""
        emb = [self.OHE[self.ctoi.get(char, self.vocab_size - 1)] for char in word]
        return torch.stack(emb) if emb else torch.zeros((1, self.vocab_size))
    
    def __len__(self):
        return len(self.words)
    
    def __getitem__(self, idx):
        """Returns One-Hot Encoding tensor for a word."""
        return self.get_OHE(self.words[idx])

def collate_fn(batch):
    """Pads sequences in a batch to the max length in the batch."""
    max_len = max(word.shape[0] for word in batch)
    vocab_size = batch[0].shape[1]
    pad_tensor = torch.zeros((max_len, vocab_size))
    
    padded_batch = [torch.cat((word, pad_tensor[:max_len - word.shape[0]]), dim=0) for word in batch]
    
    return torch.stack(padded_batch)

In [7]:
class CustomModel(nn.Module):
    def __init__(self, vocab_size=27, emb_dim=100, num_epochs=15, lr=0.001):
        super().__init__()

        self.vocab = "abcdefghijklmnopqrstuvwxyz"
        self.vocab_size = len(self.vocab) + 1
        self.ctoi = {char: idx for idx, char in enumerate(self.vocab)}

        self.num_epochs = num_epochs
        self.vocab_size = vocab_size
        self.lstm1 = nn.LSTM(input_size=self.vocab_size, hidden_size=emb_dim, batch_first=True)
        self.lstm2 = nn.LSTM(input_size=emb_dim, hidden_size=emb_dim, batch_first=True)
        self.fc = nn.Linear(1, 1)

        self.loss_fn = nn.L1Loss()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.to(self.device)  

        
        self.create_OHE()

    def create_OHE(self):
        """Creates a One-Hot Encoding matrix for the vocabulary."""
        self.OHE = torch.zeros((self.vocab_size, self.vocab_size))
        for i in range(self.vocab_size):
            self.OHE[i, i] = 1

    def get_OHE(self, word):
        """Converts a word into a one-hot encoding tensor."""
        emb = [self.OHE[self.ctoi.get(char, self.vocab_size - 1)] for char in word]
        return torch.stack(emb)
    
    def fit(self, batched_data):
        for epoch in range(self.num_epochs):
            epoch_loss = 0.0
            for x1, x2, target_batch in batched_data:
                x1, x2, target_batch = x1.to(self.device), x2.to(self.device), target_batch.to(self.device)

                self.optimizer.zero_grad()

                outputs = self.forward(x1, x2)
                loss = self.loss_fn(outputs, target_batch)
                loss.backward()
                self.optimizer.step()

                epoch_loss += loss.item()

            print(f"{epoch+1}/{self.num_epochs} - Loss: {epoch_loss / len(batched_data)}")

    def get_embedding(self, x):
        out1, _ = self.lstm1(x)
        out2, (hn, _) = self.lstm2(out1)
        return hn.squeeze(0) 

    def forward(self, x1, x2):
        emb1 = self.get_embedding(x1)
        emb2 = self.get_embedding(x2)

        diff = emb1 - emb2
        squared_norm = torch.sum(diff ** 2, dim=1, keepdim=True)

        out = torch.sigmoid(self.fc(squared_norm))
        return out
    
    def save_model(self, model_name):
        torch.save(self.state_dict(), model_name)
        print(f"Model saved to {model_name}")

    def load_model(self, model_name):
        self.load_state_dict(torch.load(model_name))
        self.eval()  # Set the model to evaluation mode after loading
        print(f"Model loaded from {model_name}")

In [8]:
loaded_model = CustomModel() 
loaded_model.load_model("./saved_model_3/char2vec.pth")
loaded_model

Model loaded from ./saved_model_3/char2vec.pth


  self.load_state_dict(torch.load(model_name))


CustomModel(
  (lstm1): LSTM(27, 100, batch_first=True)
  (lstm2): LSTM(100, 100, batch_first=True)
  (fc): Linear(in_features=1, out_features=1, bias=True)
  (loss_fn): L1Loss()
)

In [9]:
# dataset = pd.read_csv("./datasets/dict.csv")
# words = dataset["word"]
# words = list(set(words.to_numpy()))
# words = [item for item in words if isinstance(item, str) and len(item)>2]

words = np.load("./datasets/text_unique_tokens.npy",allow_pickle = True) 

In [10]:
dataset = WordDataset(words)
dataloader = DataLoader(dataset, batch_size=128, collate_fn=collate_fn)

In [11]:
for index, batch in enumerate(dataloader):
    batch_embds = loaded_model.get_embedding(batch.to(loaded_model.device))
    if index == 0:
        embds = batch_embds  
    else:
        embds = torch.cat((embds, batch_embds), dim=0)  

In [12]:
embds.shape

torch.Size([8001, 100])

In [13]:
index = nmslib.init(method="hnsw", space="cosinesimil")
index.addDataPointBatch(embds.detach().cpu())
index.createIndex({'post': 2}, print_progress=False)

# Save index and word list
index.saveIndex("./saved_model_3/word_index.bin", save_data=True)
np.save("./saved_model_3/word_list.npy", words)  # Save word order

In [14]:
index = nmslib.init(method="hnsw", space="cosinesimil")
index.loadIndex("./saved_model_3/word_index.bin", load_data=True)
words = np.load("./saved_model_3/word_list.npy", allow_pickle=True)

In [28]:
query_vector = loaded_model.get_embedding(loaded_model.get_OHE("fuck").to(loaded_model.device))

In [29]:
ids, distances = index.knnQuery(query_vector.detach().cpu(), k=100)
print([words[i] for i in ids], distances)

['stuck', 'luck', 'jack', 'tack', 'dock', 'such', 'touch', 'pluck', 'thick', 'quick', 'struck', 'much', 'neck', 'rack', 'sick', 'back', 'pick', 'lack', 'stick', 'track', 'lock', 'bunch', 'unpack', 'pack', 'stock', 'lunch', 'trick', 'sunk', 'flock', 'drunk', 'frock', 'each', 'wrack', 'hum', 'shock', 'trunk', 'teach', 'lucky', 'dark', 'dusk', 'bark', 'dank', 'rank', 'smack', 'park', 'couch', 'hugh', 'ink', 'black', 'prick', 'cock', 'knock', 'pinch', 'thank', 'block', 'think', 'fetch', 'brick', 'munich', 'bank', 'frank', 'crack', 'oak', 'jug', 'which', 'sum', 'lank', 'hawk', 'pink', 'wink', 'check', 'clock', 'though', 'church', 'laugh', 'ask', 'push', 'stark', 'etc', 'weak', 'murky', 'cubic', 'talk', 'cough', 'folk', 'network', 'lucy', 'trough', 'link', 'tricky', 'march', 'hush', 'freak', 'bulky', 'porch', 'branch', 'cusack', 'shrunk', 'disk', 'wreck'] [0.46391565 0.46976966 0.4967293  0.49677145 0.49921048 0.50254285
 0.50434566 0.5050125  0.5102318  0.51305234 0.5141005  0.5143329
 0.52