In [None]:
import os
import re
import torch
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from collections import Counter
from nltk.tokenize import word_tokenize

In [None]:
import nltk
# word_tokenize needs the Punkt sentence models. NLTK >= 3.8.2 loads them from
# the "punkt_tab" resource; "punkt" is still fetched for older NLTK releases.
nltk.download('punkt')
nltk.download('punkt_tab')

### Functions to save and load files

In [None]:
def save_file(name, obj):
    """
    Pickle ``obj`` and write it to the file at path ``name``.

    The file is opened in binary write mode, so any existing file at that
    path is overwritten.
    """
    with open(name, mode='wb') as out_file:
        pickle.Pickler(out_file).dump(obj)


def load_file(name):
    """
    Function to load a pickle object.

    Opens ``name`` in binary mode, unpickles a single object and closes the
    file handle deterministically via ``with``.

    Security: ``pickle.load`` can execute arbitrary code, so only load files
    produced by this notebook or another trusted source.
    """
    with open(name, "rb") as f:
        return pickle.load(f)

# Data Processing

In [None]:
# Input CSV, the free-text column to clean and tokenize, and where the
# resulting token lists are cached.
file_path = "Input/complaints.csv"
col_name = "Consumer complaint narrative"
tokens_path = "Output/tokens.pkl"

In [None]:
data = pd.read_csv(filepath_or_buffer=file_path)  # load the raw CSV

In [None]:
data.shape  # (rows, columns) of the CSV as loaded, before missing narratives are dropped

### Drop missing values

In [None]:
# Keep only rows that have a narrative; rebinding avoids inplace=True.
data = data.dropna(subset=[col_name])

In [None]:
# Sanity check: every remaining row has a narrative to process.
assert data[col_name].notna().all(), f"missing values remain in {col_name!r}"
data.shape

In [None]:
input_text = data.loc[:, col_name]  # narrative column as a Series

### Convert text to lower case

In [None]:
input_text = list(map(str.lower, tqdm(input_text)))

### Remove punctuation except apostrophes

In [None]:
# Every run of characters that is not a word character, digit, apostrophe or
# whitespace becomes a single space.
input_text = [re.sub(pattern=r"[^\w\d'\s]+", repl=" ", string=text)
              for text in tqdm(input_text)]

### Remove digits

In [None]:
# Raw string: "\d" inside a normal string literal is an invalid escape
# sequence (SyntaxWarning on Python 3.12+).
input_text = [re.sub(r"\d+", "", text) for text in tqdm(input_text)]

### Remove 'xxxx' placeholders from the text

In [None]:
# Remove whole tokens made of two or more "x" characters (e.g. "xxxx", "xx").
# The word boundaries keep real words that merely contain "xx" intact
# (e.g. "exxon" is no longer turned into "eon").
input_text = [re.sub(r"\bx{2,}\b", "", text) for text in tqdm(input_text)]

### Collapse repeated spaces

In [None]:
# Collapse every run of two or more spaces into a single space.
input_text = [re.sub(r" {2,}", " ", text) for text in tqdm(input_text)]

### Tokenize the text

In [None]:
# Cap on how many cleaned narratives are tokenized: 100 keeps experiments
# quick; set to None to tokenize the full corpus.
max_docs = 100
tokens = [word_tokenize(text) for text in tqdm(input_text[:max_docs])]

### Save tokens

In [None]:
# Create the output folder if needed (e.g. on a fresh checkout) before writing.
os.makedirs(os.path.dirname(tokens_path) or ".", exist_ok=True)
save_file(tokens_path, tokens)

# Data loader

In [None]:
# Skip-gram dataset settings.
context_window = 5  # words taken on each side of the centre word
t = 0.00001         # forwarded to SkipGramDataset as its `t` argument
k = 10              # negative samples drawn per (centre, context) pair

In [None]:
class SkipGramDataset(torch.utils.data.Dataset):
    """
    Skip-gram (centre word, context word) pairs with negative samples.

    Item ``idx`` is a tuple ``(center, candidates)``: ``center`` is a scalar
    LongTensor with the centre word's index and ``candidates`` is a 1-D
    LongTensor of length ``k + 1`` whose first entry is the true context word
    and whose remaining ``k`` entries are negative samples.

    Parameters
    ----------
    input_data : list of list of str
        Tokenized documents. It is iterated twice (word counting and pair
        generation), so pass a re-iterable collection, not a generator.
    context_window : int
        Number of words taken on each side of a centre word.
    out_path : str
        Directory that receives ``word2idx.pkl`` and ``idx2word.pkl``;
        created if it does not exist.
    t : float
        Accepted for backward compatibility; it does not affect sampling.
    k : int
        Number of negative samples drawn per positive pair.
    neg_power : float
        Exponent applied to word counts to form the negative-sampling
        distribution (``count ** 0.75`` is the word2vec choice).
    """

    def __init__(self, input_data, context_window=5, out_path="Output",
                 t=1e-5, k=10, neg_power=0.75):
        self.k = k
        self.t = t
        self.context_window = context_window
        # Get word count
        print("Counting word tokens...")
        counter = Counter(token for doc in tqdm(input_data) for token in doc)
        self.vocab_count = len(counter)
        print(f"Unique words in the corpus: {self.vocab_count}")
        print("Creating data samples...")
        self.samples = self.positive_samples(input_data)
        print("Generating vocabulary...")
        # Most frequent word first; equal counts keep first-seen order.
        vocab = counter.most_common()
        self.word2idx = {word: idx for idx, (word, _) in enumerate(vocab)}
        self.idx2word = {idx: word for idx, (word, _) in enumerate(vocab)}
        print("Calculating sampling probabilities...")
        # Negatives follow the smoothed unigram distribution, so frequent words
        # are the most likely negatives. A weight of sqrt(t / count) would
        # instead favour rare words, and its `t` would cancel when normalised.
        self.sampling_prob = self._noise_distribution(
            [count for _, count in vocab], neg_power)
        print("Saving files...")
        self.save_files(out_path)

    def __len__(self):
        return self.samples.shape[0]

    def __getitem__(self, idx):
        """Return ``(center_idx, [context_idx] + k negative indices)`` as tensors."""
        neg_words = self.negative_samples()
        center_word = self.word2idx[self.samples.loc[idx, "center_word"]]
        context_word = self.word2idx[self.samples.loc[idx, "context_word"]]
        return torch.tensor(center_word), torch.tensor([context_word] + neg_words)

    @staticmethod
    def _noise_distribution(counts, power=0.75):
        """Return ``counts ** power`` normalised to sum to 1."""
        weights = np.asarray(counts, dtype=np.float64) ** power
        return weights / weights.sum()

    def positive_samples(self, input_data):
        """
        Build a DataFrame with one row per (center_word, context_word) pair.

        Each document is padded with ``context_window`` ``None`` values on both
        sides; pairs whose context falls into the padding are dropped. The
        index is reset to ``0..n-1`` so ``__getitem__`` can look rows up with
        ``.loc``.
        """
        samples = []
        cw = self.context_window
        for data in tqdm(input_data):
            text = [None] * cw + list(data) + [None] * cw
            for i in range(cw, len(text) - cw):
                samples.append((text[i], text[i - cw:i] + text[i + 1: i + cw + 1]))
        samples = pd.DataFrame(samples, columns=["center_word", "context_word"])
        samples = samples.explode("context_word")
        return samples.dropna().reset_index(drop=True)

    def negative_samples(self):
        """Draw ``k`` word indices (with replacement) from ``sampling_prob``."""
        return list(np.random.choice(self.vocab_count, size=self.k,
                                     p=self.sampling_prob))

    def save_files(self, out_path="Output"):
        """Pickle ``word2idx`` and ``idx2word`` into ``out_path`` (created if missing)."""
        os.makedirs(out_path, exist_ok=True)
        save_file(os.path.join(out_path, "word2idx.pkl"), self.word2idx)
        save_file(os.path.join(out_path, "idx2word.pkl"), self.idx2word)

# Skip-Gram Model

In [None]:
embedding_size = 64  # length of each learned word vector (columns of the nn.Embedding table)

In [None]:
class SkipGram(nn.Module):
    """
    Skip-gram model that scores a set of candidate context words per centre word.

    ``embeddings`` holds one input vector per word; ``weights`` holds one
    output vector per word, stored column-wise with shape
    ``(embedding_size, vocab_len)``. ``forward`` returns log-probabilities over
    the candidates (true context word in column 0, negatives after it) together
    with an all-zero target, ready for ``nn.NLLLoss``.
    """

    def __init__(self, vocab_len, embedding_size=64):
        super(SkipGram, self).__init__()
        self.embeddings = nn.Embedding(vocab_len, embedding_size)
        # nn.Parameter registers the output matrix with the module, so it is
        # returned by model.parameters() (and therefore optimised), follows
        # model.to(device) and is part of state_dict(). A plain tensor
        # attribute gets none of these.
        self.weights = nn.Parameter(torch.empty(embedding_size, vocab_len))
        nn.init.normal_(self.weights)
        # NLLLoss expects log-probabilities. Normalising over the candidates
        # (dim=1) makes the loss for target 0 depend on every candidate, so the
        # true context word is pushed up and the negatives are pushed down. An
        # element-wise LogSigmoid would give the negatives no gradient at all.
        self.out = nn.LogSoftmax(dim=1)

    def forward(self, center_word, context_words):
        """
        Parameters
        ----------
        center_word : LongTensor of shape (batch,)
        context_words : LongTensor of shape (batch, n_candidates)
            True context word first, then the negative samples.

        Returns
        -------
        (log_probs, true_y)
            ``log_probs`` has shape (batch, n_candidates); ``true_y`` is an
            all-zero int64 target of shape (batch,) on the same device.
        """
        embeddings_ = self.embeddings(center_word)      # (batch, emb)
        weights_ = self.weights[:, context_words]       # (emb, batch, n_candidates)
        output = torch.einsum('bi,ibo->bo', embeddings_, weights_)
        true_y = torch.zeros(output.shape[0], dtype=torch.int64, device=output.device)
        return self.out(output), true_y

    def save_files(self, out_path="Output"):
        """
        Pickle the input embedding layer (``emb.pkl``) and the output weight
        matrix (``weights.pkl``) into ``out_path``, creating it if needed.
        CPU copies are written so the files load on machines without a GPU.
        """
        import copy
        os.makedirs(out_path, exist_ok=True)
        save_file(os.path.join(out_path, "emb.pkl"),
                  copy.deepcopy(self.embeddings).cpu())
        save_file(os.path.join(out_path, "weights.pkl"),
                  self.weights.detach().cpu())

# Training

In [None]:
# Training hyper-parameters. `k` and `context_window` are set once in the data
# loader configuration cell above and are not repeated here, so the two cells
# cannot silently disagree.
lr = 0.01
num_epochs = 2
batch_size = 128
out_path = "Output"

In [None]:
# Prefer the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
def train_sg(dataloader, model, criterion, optimizer, device, num_epochs,
             patience_limit=5, out_path="Output"):
    """
    Train a skip-gram model, then save it with ``model.save_files``.

    Parameters
    ----------
    dataloader : iterable of ``(center_word, context_words)`` tensor batches
    model : module returning ``(output, true_y)`` and providing ``save_files``
    criterion : loss called as ``criterion(output, true_y)``
    optimizer : optimizer over ``model``'s parameters
    device : torch.device that the model and every batch are moved to
    num_epochs : maximum number of epochs
    patience_limit : stop after this many consecutive epochs whose mean loss
        does not improve on the best so far
    out_path : directory passed to ``model.save_files``

    Raises
    ------
    ValueError
        If ``dataloader`` yields no batches in an epoch.
    """
    model.to(device)
    model.train()
    best_loss = float("inf")
    patience = 0
    for epoch in range(num_epochs):
        batch_losses = []
        print(f"Epoch {epoch+1} of {num_epochs}")
        for center_word, context_words in tqdm(dataloader):
            center_word = center_word.to(device)
            context_words = context_words.to(device)
            output, true_y = model(center_word, context_words)
            # The target may have been created on the CPU by the model.
            loss = criterion(output, true_y.to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())
        if not batch_losses:
            raise ValueError("dataloader yielded no batches (e.g. fewer samples "
                             "than batch_size with drop_last=True)")
        epoch_loss = np.mean(batch_losses)
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            patience = 0
        else:
            patience += 1
        print(f"Loss: {epoch_loss}")
        if patience >= patience_limit:
            print("Early stopping...")
            break
    model.save_files(out_path)

In [None]:
# Build the (centre, context + negatives) training pairs from the tokens.
dataset = SkipGramDataset(
    tokens,
    context_window=context_window,
    out_path=out_path,
    k=k,
    t=t,
)

In [None]:
# Shuffled mini-batches; an incomplete final batch is discarded.
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    drop_last=True,
    shuffle=True,
)

In [None]:
model = SkipGram(vocab_len=dataset.vocab_count, embedding_size=embedding_size)

In [None]:
criterion = nn.NLLLoss()
# Use the `lr` from the training configuration instead of repeating 0.01.
optimizer = optim.Adam(model.parameters(), lr=lr)

In [None]:
train_sg(
    dataloader=dataloader,
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    num_epochs=num_epochs,
)

# Using embeddings to get word vectors

In [None]:
# Build the path from the same `out_path` the dataset was saved to.
word2idx = load_file(os.path.join(out_path, "word2idx.pkl"))

In [None]:
word2idx["payments"]  # row index of "payments" in the embedding table (0 = most frequent word)

In [None]:
# Read from the configured `out_path` instead of repeating the folder name.
embeddings = load_file(os.path.join(out_path, "emb.pkl"))

In [None]:
# Look the index up instead of hard-coding it: indices follow word-frequency
# rank and change whenever the corpus or preprocessing changes.
embeddings(torch.tensor(word2idx["payments"])).detach()