In [17]:
import torch
import torch.nn as nn

from torch.utils.data import Dataset

import numpy as np

import re

import linecache

from tqdm import tqdm

In [2]:
device = torch.device("cpu")

In [195]:
class SimpleGRU(nn.Module):
    def __init__(self, hidden_size, embedding_path, vocab_path, num_layers=4):
        super(SimpleGRU, self).__init__()

        self.num_layers = num_layers
        self.hidden_size = hidden_size

        word_embeddings = np.load(embedding_path)
        self.vocab = np.load(vocab_path)

        self.embedding_layer = nn.Embedding.from_pretrained(torch.from_numpy(word_embeddings).float())
        self.gru = nn.GRU(hidden_size, hidden_size, num_layers=num_layers)

        self.pool = nn.AdaptiveMaxPool1d(1)
        self.sigmoid = nn.Sigmoid()


    def forward(self, x, hidden):
        embeddings = self.embedding_layer(x)

        final_output = torch.zeros_like(embeddings)
        
        for i in range(embeddings.shape[1]):
            output, hidden = self.gru(embeddings[:,i,:].view(1, 1, -1), hidden)
            final_output[:,i,:] = output
        
        final_output = self.pool(final_output)
        final_output = self.sigmoid(final_output)

        return final_output.squeeze(2)


    def init_hidden(self):
        return torch.zeros(self.num_layers, 1, self.hidden_size, device=device)


    def get_indices(self, string):
        indices = []

        string_arr = string.split(" ")

        for word in string_arr:
            if np.where(self.vocab == word)[0].shape[0] != 0:
                indices += [np.where(self.vocab == word)[0]]
            else:
                indices += [np.array([0])]
        
        return torch.tensor(indices).squeeze(1)

In [234]:
class EntityDataset(Dataset):
    # overriden methods
    def __init__(self, file_path, model):
        self.file_path = file_path

        self.model = model


    def __len__(self):
        with open(self.file_path, "rbU") as f:
            num_lines = sum(1 for _ in f)
        
        # don't count first line
        return num_lines - 1


    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        particular_line = linecache.getline(self.file_path, idx+1)
        cleaned_sample = self.clean_up(particular_line)

        input_sentence = cleaned_sample[4]
        entity1 = cleaned_sample[1]
        entity2 = cleaned_sample[2]

        sentence_arr = self.model.get_indices(input_sentence)
        entity1_arr = self.model.get_indices(entity1)
        entity2_arr = self.model.get_indices(entity2)

        '''labels = torch.zeros_like(sentence_arr)

        for _, word_index in enumerate(sentence_arr):
            for entity_index in entity1_arr:
                if entity_index == word_index:
                    labels[_] = 1
            for entity_index in entity2_arr:
                if entity_index == word_index:
                    labels[_] = 1'''

        labels1 = self.generate_labels(sentence_arr, entity1_arr)
        labels2 = self.generate_labels(sentence_arr, entity2_arr)

        labels = labels1 | labels2

        return sentence_arr, labels.type(torch.float32)
    

    # first instance of entity in sentence
    def generate_labels(self, sentence_arr, entity_arr):
        correct = []

        for _, token in enumerate(sentence_arr):
            if _ < sentence_arr.shape[0] - entity_arr.shape[0]:
                if token == entity_arr[0]:
                    not_equal = False
                    for i, val in enumerate(entity_arr):
                        if sentence_arr[_+i] != val:
                            not_equal = True
                    if not not_equal:
                        correct += [_]
            else:
                break
        
        labels = torch.zeros_like(sentence_arr)
        for i in correct:
            labels[i:i+entity_arr.shape[0]] = 1
        
        return labels
        

    # helper
    def clean_up(self, line):
        remove_chars = ["<e1>", "</e1>", "<e2>", "</e2>"]

        line = line.strip()

        for char in remove_chars:
            line = line.replace(char, "")
        
        # string clean up
        line = re.sub(r'[^a-zA-Z1-9\s]', '', line)
        line = re.sub(' +', ' ', line)
        line = line.lower()

        line_data = line.split("\t")
        
        return line_data

In [235]:
test_string = "hello how are you"

model = SimpleGRU(50, "utils/embs_npa.npy", "utils/vocab_npa.npy")

indices = model.get_indices(test_string)

print(indices.shape)

start_hidden = model.init_hidden()
output = model(indices.unsqueeze(0), start_hidden)

print(output.shape)

torch.Size([4])
torch.Size([1, 4])


In [236]:
batch_size = 1

dataset = EntityDataset('data/en_corpora_test.txt', model)

train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

  with open(self.file_path, "rbU") as f:


In [237]:
sample = next(iter(train_loader))

print(sample)

  with open(self.file_path, "rbU") as f:


[tensor([[189584,     17,     31,   2490,   1139,    777,   1349,      8,  48947,
            747,      8,    949,      2,    777,   5711,      5,   1486,  13356,
           6639,   3824,    281,  44075,   7302,      7,    686,      0,   4684,
           6851,   6639]]), tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0.]])]


In [238]:
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

In [239]:
EPOCHS = 10

for epoch in range(EPOCHS):
    for _, sample in tqdm(enumerate(train_loader), total=len(train_loader)):
        optimizer.zero_grad()

        input = sample[0]
        labels = sample[1]

        start_hidden = model.init_hidden()
        output = model(input, start_hidden)

        loss = criterion(output, labels)

        loss.backward()

        optimizer.step()

        if _ % 50 == 0:
            print(loss.item())

  with open(self.file_path, "rbU") as f:
  0%|          | 1/5461 [00:00<27:15,  3.34it/s]

0.8746582269668579


  1%|          | 51/5461 [00:15<24:24,  3.69it/s]

0.7528480887413025


  2%|▏         | 101/5461 [00:31<26:05,  3.42it/s]

0.7176336050033569


  3%|▎         | 151/5461 [00:45<21:57,  4.03it/s]

0.7151199579238892


  4%|▎         | 201/5461 [01:00<31:06,  2.82it/s]

0.6894572973251343


  5%|▍         | 251/5461 [01:15<23:27,  3.70it/s]

0.6850866675376892


  6%|▌         | 302/5461 [01:33<24:23,  3.52it/s]

0.6643531918525696


  6%|▋         | 351/5461 [01:49<30:26,  2.80it/s]

0.6368778347969055


  7%|▋         | 401/5461 [02:04<27:39,  3.05it/s]

0.5844876170158386


  8%|▊         | 451/5461 [02:20<24:52,  3.36it/s]

0.5559438467025757


  9%|▉         | 501/5461 [02:35<28:34,  2.89it/s]

0.5631774067878723


 10%|█         | 551/5461 [02:52<20:28,  4.00it/s]

0.4965883493423462


 11%|█         | 601/5461 [03:08<20:05,  4.03it/s]

0.47169849276542664


 12%|█▏        | 651/5461 [03:24<29:41,  2.70it/s]

0.4506806433200836


 13%|█▎        | 701/5461 [03:40<29:28,  2.69it/s]

0.5040118098258972


 14%|█▍        | 751/5461 [03:55<23:16,  3.37it/s]

0.4600246548652649


 15%|█▍        | 801/5461 [04:11<25:42,  3.02it/s]

0.5288750529289246


 16%|█▌        | 851/5461 [04:27<24:21,  3.15it/s]

0.41640231013298035


 16%|█▋        | 901/5461 [04:41<20:58,  3.62it/s]

0.43765848875045776


 17%|█▋        | 951/5461 [04:57<24:15,  3.10it/s]

0.4671437442302704


 18%|█▊        | 1002/5461 [05:13<22:11,  3.35it/s]

0.46147453784942627


 19%|█▉        | 1051/5461 [05:28<18:02,  4.08it/s]

0.4397002160549164


 20%|██        | 1101/5461 [05:44<30:25,  2.39it/s]

0.3948560953140259


 21%|██        | 1151/5461 [05:59<19:17,  3.72it/s]

0.44869381189346313


 22%|██▏       | 1201/5461 [06:15<20:19,  3.49it/s]

0.5957033038139343


 23%|██▎       | 1251/5461 [06:33<21:20,  3.29it/s]

0.4940255284309387


 24%|██▍       | 1301/5461 [06:48<21:41,  3.20it/s]

0.5244684815406799


 24%|██▍       | 1321/5461 [06:54<19:08,  3.61it/s]