In [39]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader, Dataset
import nmslib

In [40]:
class WordDataset(Dataset):
    def __init__(self, words):
        """
        Initializes the dataset with a list of words.
        :param words: List of words to be converted into OHE tensors.
        """
        self.words = [word.lower() for word in words if isinstance(word, str)]
        
        # Vocabulary setup
        self.vocab = "abcdefghijklmnopqrstuvwxyz"
        self.vocab_size = len(self.vocab) + 1  # +1 for unknown characters
        self.ctoi = {char: idx for idx, char in enumerate(self.vocab)}
        
        # One-Hot Encoding Matrix
        self.create_OHE()
    
    def create_OHE(self):
        """Creates a One-Hot Encoding matrix for the vocabulary."""
        self.OHE = torch.eye(self.vocab_size)  # Identity matrix for one-hot encoding
    
    def get_OHE(self, word):
        """Converts a word into a one-hot encoding tensor."""
        emb = [self.OHE[self.ctoi.get(char, self.vocab_size - 1)] for char in word]
        return torch.stack(emb) if emb else torch.zeros((1, self.vocab_size))
    
    def __len__(self):
        return len(self.words)
    
    def __getitem__(self, idx):
        """Returns One-Hot Encoding tensor for a word."""
        return self.get_OHE(self.words[idx])

def collate_fn(batch):
    """Pads sequences in a batch to the max length in the batch."""
    max_len = max(word.shape[0] for word in batch)
    vocab_size = batch[0].shape[1]
    pad_tensor = torch.zeros((max_len, vocab_size))
    
    padded_batch = [torch.cat((word, pad_tensor[:max_len - word.shape[0]]), dim=0) for word in batch]
    
    return torch.stack(padded_batch)

In [33]:
class CustomModel(nn.Module):
    def __init__(self, vocab_size=27, emb_dim=300, num_epochs=15, lr=0.001):
        super().__init__()

        self.vocab = "abcdefghijklmnopqrstuvwxyz"
        self.vocab_size = len(self.vocab) + 1
        self.ctoi = {char: idx for idx, char in enumerate(self.vocab)}

        self.num_epochs = num_epochs
        self.vocab_size = vocab_size
        self.lstm1 = nn.LSTM(input_size=self.vocab_size, hidden_size=emb_dim, batch_first=True)
        self.lstm2 = nn.LSTM(input_size=emb_dim, hidden_size=emb_dim, batch_first=True)
        self.fc = nn.Linear(1, 1)

        self.loss_fn = nn.L1Loss()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.to(self.device)  

        
        self.create_OHE()

    def create_OHE(self):
        """Creates a One-Hot Encoding matrix for the vocabulary."""
        self.OHE = torch.zeros((self.vocab_size, self.vocab_size))
        for i in range(self.vocab_size):
            self.OHE[i, i] = 1

    def get_OHE(self, word):
        """Converts a word into a one-hot encoding tensor."""
        emb = [self.OHE[self.ctoi.get(char, self.vocab_size - 1)] for char in word]
        return torch.stack(emb)
    
    def fit(self, batched_data):
        for epoch in range(self.num_epochs):
            epoch_loss = 0.0
            for x1, x2, target_batch in batched_data:
                x1, x2, target_batch = x1.to(self.device), x2.to(self.device), target_batch.to(self.device)

                self.optimizer.zero_grad()

                outputs = self.forward(x1, x2)
                loss = self.loss_fn(outputs, target_batch)
                loss.backward()
                self.optimizer.step()

                epoch_loss += loss.item()

            print(f"{epoch+1}/{self.num_epochs} - Loss: {epoch_loss / len(batched_data)}")

    def get_embedding(self, x):
        out1, _ = self.lstm1(x)
        out2, (hn, _) = self.lstm2(out1)
        return hn.squeeze(0) 

    def forward(self, x1, x2):
        emb1 = self.get_embedding(x1)
        emb2 = self.get_embedding(x2)

        diff = emb1 - emb2
        squared_norm = torch.sum(diff ** 2, dim=1, keepdim=True)

        out = torch.sigmoid(self.fc(squared_norm))
        return out
    
    def save_model(self, model_name):
        torch.save(self.state_dict(), model_name)
        print(f"Model saved to {model_name}")

    def load_model(self, model_name):
        self.load_state_dict(torch.load(model_name))
        self.eval()  # Set the model to evaluation mode after loading
        print(f"Model loaded from {model_name}")

In [41]:
loaded_model = CustomModel() 
loaded_model.load_model("./saved_model/char2vec.pth")
loaded_model

Model loaded from ./saved_model/char2vec.pth


  self.load_state_dict(torch.load(model_name))


CustomModel(
  (lstm1): LSTM(27, 300, batch_first=True)
  (lstm2): LSTM(300, 300, batch_first=True)
  (fc): Linear(in_features=1, out_features=1, bias=True)
  (loss_fn): L1Loss()
)

In [42]:
words = ['cattle', 'beautiful', 'input', 'daring', 'predict', 'giraffe', 'blend', 'simplify', 'knack', 'parent', 'elevate', 'incentive', 'gather', 'clumsy', 'zoom', 'aluminum', 'example', 'town', 'clerk', 'work', 'disease', 'solve', 'change', 'flavor', 'vase', 'attach', 'assemble', 'interest', 'react', 'wince', 'show', 'sensitive', 'mount', 'luck', 'wager', 'pretty', 'aspect', 'camera', 'zodiac', 'zinc', 'explore', 'mile', 'grid', 'poor', 'violent', 'absorb', 'desire', 'zero', 'charitable', 'tolerate', 'doubt', 'guitar', 'enrich', 'clamp', 'across', 'active', 'friend', 'feature', 'never', 'blanket', 'middle', 'flood', 'kitchen', 'candidate', 'cease', 'leads', 'bravery', 'rescue', 'outdoor', 'update', 'trust', 'neutral', 'mobile', 'roam', 'dance', 'bitter', 'energy', 'together', 'discover', 'couch', 'link', 'model', 'finance', 'citizen', 'horizon', 'cliff', 'ability', 'cool', 'latter', 'exhibit', 'relate', 'crisis', 'enemy', 'arrange', 'project', 'rate', 'zone', 'visit', 'combine', 'athlete', 'clock', 'hurdle', 'banker', 'brisk', 'stranger', 'buy', 'skip', 'assess', 'carbon', 'console', 'alike', 'waste', 'expect', 'emerge', 'drink', 'house', 'false', 'death', 'youth', 'rival', 'explode', 'wonder', 'heart', 'myth', 'tension', 'charm', 'expose', 'result', 'stabilize', 'synthesis', 'jumpy', 'excel', 'community', 'combat', 'error', 'bake', 'return', 'record', 'accept', 'meadow', 'create', 'bottle', 'favor', 'private', 'actor', 'demand', 'dear', 'chamber', 'warn', 'puppet', 'academy', 'collect', 'store', 'worse', 'variable', 'ride', 'human', 'excited', 'text', 'vote', 'debt', 'yearly', 'natural', 'improve', 'imagine', 'occupy', 'stuck', 'flame', 'forward', 'society', 'your', 'jump', 'mental', 'volume', 'advance', 'lacks', 'yell', 'judge', 'cage', 'hospital', 'billion', 'barrier', 'primary', 'season', 'clean', 'spirit', 'grace', 'crown', 'leader', 'court', 'winter', 'formula', 'run', 'stiff', 'shelter', 'depth', 'peach', 'resource', 'engage', 'mention', 'better', 'dangerous', 'review', 'obtain', 'noble', 'clue', 'truth', 'stare', 'attack', 'student', 'health', 'agency', 'literary', 'average', 'wish', 'action', 'accuse', 'abuse', 'yellow', 'common', 'bicycle', 'habit', 'smell', 'impress', 'quality', 'profit', 'trade', 'mild', 'clash', 'place', 'wheel', 'listen', 'fitness', 'eternal', 'order', 'fluent', 'tourist', 'picture', 'shine', 'fishing', 'social', 'stress', 'whole', 'bargain', 'bizarre', 'chill', 'survive', 'calibrate', 'biology', 'appliance', 'foolish', 'kite', 'cloud', 'nature', 'tender', 'image', 'exact', 'vacation', 'reduce', 'universe', 'amount', 'experiment', 'limit', 'humor', 'front', 'apply', 'anxiety', 'artist', 'bricks', 'post', 'belief', 'journal', 'public', 'noir', 'fantastic', 'garden', 'future', 'yoga', 'select', 'object', 'effect', 'toll', 'element', 'vary', 'diverse', 'approach', 'allege', 'navigate', 'start', 'begin', 'morning', 'choice', 'approve', 'smart', 'elate', 'airport', 'angry', 'balloon', 'single', 'justify', 'fruit', 'vast', 'further', 'become', 'support', 'school', 'mind', 'compete', 'government', 'anniversary', 'notable', 'agree', 'utility', 'essay', 'control', 'position', 'value', 'voice', 'class', 'custom', 'issue', 'mature', 'symbol', 'lemon', 'joke', 'doubtful', 'blink', 'council', 'snow', 'restore', 'edgy', 'group', 'honor', 'allergy', 'tackle', 'plane', 'piano', 'adapt', 'earth', 'damage', 'neat', 'believe', 'transport', 'allegiance', 'option', 'danger', 'test', 'cherry', 'study', 'welcome', 'reason', 'total', 'trouble', 'fact', 'commerce', 'jolly', 'vision', 'program', 'system', 'enforce', 'cheap', 'vivid', 'mansion', 'table', 'found', 'knight', 'yarn', 'aware', 'advice', 'verify', 'cancer', 'company', 'commodity', 'lives', 'impact', 'account', 'outcome', 'respect', 'mob', 'runway', 'architecture', 'careful', 'outline', 'trick', 'chance', 'silent', 'badly', 'workshop', 'address', 'affect', 'life', 'focus', 'today', 'celebrate', 'weary', 'entry', 'near', 'excuse', 'decade', 'unit', 'quote', 'allies', 'notice', 'water', 'other', 'wave', 'argue', 'clarify', 'classic', 'journey', 'calm', 'guard', 'sense', 'beyond', 'migrate', 'applause', 'beauty', 'accident', 'brother', 'alter', 'threat', 'reach', 'duty', 'bachelor', 'mice', 'brown', 'plenty', 'absence', 'union', 'price', 'roar', 'deficit', 'move', 'policy', 'joyful', 'density', 'banish', 'virus', 'river', 'taste', 'climate', 'event', 'king', 'equal', 'moment', 'target', 'cube', 'yes', 'logic', 'gesture', 'true', 'stock', 'divide', 'western', 'laugh', 'minute', 'celestial', 'rain', 'open', 'moral', 'devil', 'expand', 'mail', 'prove', 'type', 'access', 'remove', 'budget', 'peace', 'witness', 'asset', 'delicate', 'pulse', 'charter', 'point', 'inspire', 'anger', 'difference', 'grate', 'freedom', 'comet', 'station', 'current', 'juror', 'already', 'tooth', 'personal', 'outer', 'partner', 'suffer', 'ideal', 'knee', 'older', 'color', 'close', 'tempt', 'walk', 'obese', 'mourn', 'assault', 'worker', 'harsh', 'scare', 'adopt', 'ready', 'dominate', 'happen', 'absent', 'privacy', 'ignore', 'magic', 'method', 'scale', 'settle', 'host', 'fashion', 'debate', 'gender', 'tonight', 'breathe', 'guilt', 'cycle', 'moat', 'vibe', 'teeth', 'unite', 'against', 'analysis', 'chicken', 'bother', 'mask', 'global', 'trend', 'analyze', 'nail', 'provoke', 'cabin', 'assistant', 'unique', 'rosy', 'slope', 'street', 'ancient', 'abandon', 'step', 'sick', 'deal', 'yesterday', 'array', 'march', 'defend', 'dungeon', 'pattern', 'round', 'majority', 'motel', 'smooth', 'adult', 'appeal', 'sound', 'bring', 'recall', 'ladder', 'culture', 'design', 'news', 'dialogue', 'oxygen', 'alert', 'almost', 'torch', 'whisper', 'hike', 'separate', 'learn', 'market', 'label', 'catch', 'final', 'dreamer', 'handle', 'buyer', 'brief', 'forecast', 'cousin', 'acquire', 'second', 'query', 'agile', 'knock', 'power', 'advantage', 'fine', 'lender', 'pet', 'capture', 'orange', 'planet', 'june', 'fetch', 'flock', 'zebra', 'brand', 'occasions', 'team', 'kiss']

In [43]:
dataset = WordDataset(words)
dataloader = DataLoader(dataset, batch_size=128, collate_fn=collate_fn)

In [44]:
for index, batch in enumerate(dataloader):
    batch_embds = loaded_model.get_embedding(batch.to(loaded_model.device))
    if index == 0:
        embds = batch_embds  
    else:
        embds = torch.cat((embds, batch_embds), dim=0)  

In [45]:
embds.shape

torch.Size([606, 300])

In [46]:
index = nmslib.init(method="hnsw", space="cosinesimil")
index.addDataPointBatch(embds.detach().cpu())
index.createIndex({'post': 2}, print_progress=False)

# Save index and word list
index.saveIndex("./saved_model/word_index.bin", save_data=True)
np.save("./saved_model/word_list.npy", words)  # Save word order


In [47]:
index = nmslib.init(method="hnsw", space="cosinesimil")
index.loadIndex("./saved_model/word_index.bin", load_data=True)
words = np.load("./saved_model/word_list.npy", allow_pickle=True)

In [48]:
query_vector = loaded_model.get_embedding(loaded_model.get_OHE("acquery").to(loaded_model.device))

In [49]:
ids, distances = index.knnQuery(query_vector.detach().cpu(), k=3)
print([words[i] for i in ids], distances)

['acquire', 'academy', 'query'] [0.27343154 0.31734115 0.31778264]
