In [17]:
import torch
import torch.nn as nn

from torch.utils.data import Dataset

import numpy as np

import re

import linecache

from tqdm import tqdm

In [2]:
device = torch.device("cpu")

In [195]:
class SimpleGRU(nn.Module):
    def __init__(self, hidden_size, embedding_path, vocab_path, num_layers=4):
        super(SimpleGRU, self).__init__()

        self.num_layers = num_layers
        self.hidden_size = hidden_size

        word_embeddings = np.load(embedding_path)
        self.vocab = np.load(vocab_path)

        self.embedding_layer = nn.Embedding.from_pretrained(torch.from_numpy(word_embeddings).float())
        self.gru = nn.GRU(hidden_size, hidden_size, num_layers=num_layers)

        self.pool = nn.AdaptiveMaxPool1d(1)
        self.sigmoid = nn.Sigmoid()


    def forward(self, x, hidden):
        embeddings = self.embedding_layer(x)

        final_output = torch.zeros_like(embeddings)
        
        for i in range(embeddings.shape[1]):
            output, hidden = self.gru(embeddings[:,i,:].view(1, 1, -1), hidden)
            final_output[:,i,:] = output
        
        final_output = self.pool(final_output)
        final_output = self.sigmoid(final_output)

        return final_output.squeeze(2)


    def init_hidden(self):
        return torch.zeros(self.num_layers, 1, self.hidden_size, device=device)


    def get_indices(self, string):
        indices = []

        string_arr = string.split(" ")

        for word in string_arr:
            if np.where(self.vocab == word)[0].shape[0] != 0:
                indices += [np.where(self.vocab == word)[0]]
            else:
                indices += [np.array([0])]
        
        return torch.tensor(indices).squeeze(1)

In [234]:
class EntityDataset(Dataset):
    # overriden methods
    def __init__(self, file_path, model):
        self.file_path = file_path

        self.model = model


    def __len__(self):
        with open(self.file_path, "rbU") as f:
            num_lines = sum(1 for _ in f)
        
        # don't count first line
        return num_lines - 1


    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        particular_line = linecache.getline(self.file_path, idx+1)
        cleaned_sample = self.clean_up(particular_line)

        input_sentence = cleaned_sample[4]
        entity1 = cleaned_sample[1]
        entity2 = cleaned_sample[2]

        sentence_arr = self.model.get_indices(input_sentence)
        entity1_arr = self.model.get_indices(entity1)
        entity2_arr = self.model.get_indices(entity2)

        '''labels = torch.zeros_like(sentence_arr)

        for _, word_index in enumerate(sentence_arr):
            for entity_index in entity1_arr:
                if entity_index == word_index:
                    labels[_] = 1
            for entity_index in entity2_arr:
                if entity_index == word_index:
                    labels[_] = 1'''

        labels1 = self.generate_labels(sentence_arr, entity1_arr)
        labels2 = self.generate_labels(sentence_arr, entity2_arr)

        labels = labels1 | labels2

        return sentence_arr, labels.type(torch.float32)
    

    # first instance of entity in sentence
    def generate_labels(self, sentence_arr, entity_arr):
        correct = []

        for _, token in enumerate(sentence_arr):
            if _ < sentence_arr.shape[0] - entity_arr.shape[0]:
                if token == entity_arr[0]:
                    not_equal = False
                    for i, val in enumerate(entity_arr):
                        if sentence_arr[_+i] != val:
                            not_equal = True
                    if not not_equal:
                        correct += [_]
            else:
                break
        
        labels = torch.zeros_like(sentence_arr)
        for i in correct:
            labels[i:i+entity_arr.shape[0]] = 1
        
        return labels
        

    # helper
    def clean_up(self, line):
        remove_chars = ["<e1>", "</e1>", "<e2>", "</e2>"]

        line = line.strip()

        for char in remove_chars:
            line = line.replace(char, "")
        
        # string clean up
        line = re.sub(r'[^a-zA-Z1-9\s]', '', line)
        line = re.sub(' +', ' ', line)
        line = line.lower()

        line_data = line.split("\t")
        
        return line_data

In [240]:
test_string = "hello how are you"

model = SimpleGRU(50, "utils/embs_npa.npy", "utils/vocab_npa.npy")

indices = model.get_indices(test_string)

print(indices.shape)

start_hidden = model.init_hidden()
output = model(indices.unsqueeze(0), start_hidden)

print(output.shape)

torch.Size([4])
torch.Size([1, 4])


In [241]:
batch_size = 1

dataset = EntityDataset('data/en_corpora_test.txt', model)

train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

  with open(self.file_path, "rbU") as f:


In [242]:
sample = next(iter(train_loader))

print(sample)

[tensor([[  261,  2488,  3107,  1305,  2885,    16,     9, 11357,   245,  2311]]), tensor([[1., 1., 1., 0., 0., 0., 0., 1., 0., 0.]])]


  with open(self.file_path, "rbU") as f:


In [243]:
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

In [244]:
EPOCHS = 1

for epoch in range(EPOCHS):
    for _, sample in tqdm(enumerate(train_loader), total=len(train_loader)):
        optimizer.zero_grad()

        input = sample[0]
        labels = sample[1]

        start_hidden = model.init_hidden()
        output = model(input, start_hidden)

        loss = criterion(output, labels)

        loss.backward()

        optimizer.step()

        if _ % 50 == 0:
            print(loss.item())

  with open(self.file_path, "rbU") as f:
  0%|          | 1/5461 [00:00<34:59,  2.60it/s]

0.7963269352912903


  1%|          | 51/5461 [00:16<29:27,  3.06it/s]

0.7449396252632141


  2%|▏         | 101/5461 [00:32<26:34,  3.36it/s]

0.7097101807594299


  3%|▎         | 151/5461 [00:47<26:24,  3.35it/s]

0.7184340953826904


  4%|▎         | 201/5461 [01:02<30:38,  2.86it/s]

0.7189067006111145


  5%|▍         | 251/5461 [01:17<25:02,  3.47it/s]

0.7064711451530457


  6%|▌         | 301/5461 [01:34<22:12,  3.87it/s]

0.6930516958236694


  6%|▋         | 351/5461 [01:50<23:10,  3.68it/s]

0.6752590537071228


  7%|▋         | 401/5461 [02:06<26:50,  3.14it/s]

0.6595116257667542


  8%|▊         | 451/5461 [02:21<26:58,  3.09it/s]

0.6264247298240662


  9%|▉         | 502/5461 [02:38<22:37,  3.65it/s]

0.5769418478012085


 10%|█         | 551/5461 [02:52<20:07,  4.07it/s]

0.5429247617721558


 11%|█         | 601/5461 [03:08<26:46,  3.03it/s]

0.4897362291812897


 12%|█▏        | 651/5461 [03:24<23:13,  3.45it/s]

0.4838390052318573


 13%|█▎        | 701/5461 [03:41<28:14,  2.81it/s]

0.4204164445400238


 14%|█▍        | 751/5461 [03:57<27:01,  2.91it/s]

0.501850962638855


 15%|█▍        | 801/5461 [04:12<27:43,  2.80it/s]

0.5096961855888367


 16%|█▌        | 851/5461 [04:28<21:31,  3.57it/s]

0.45458701252937317


 16%|█▋        | 901/5461 [04:43<29:53,  2.54it/s]

0.38806697726249695


 17%|█▋        | 951/5461 [04:58<16:00,  4.69it/s]

0.6070653796195984


 18%|█▊        | 1001/5461 [05:14<23:34,  3.15it/s]

0.5304690003395081


 19%|█▉        | 1051/5461 [05:30<18:44,  3.92it/s]

0.41750457882881165


 20%|██        | 1101/5461 [05:45<16:07,  4.51it/s]

0.47457680106163025


 21%|██        | 1151/5461 [06:01<26:16,  2.73it/s]

0.3770817220211029


 22%|██▏       | 1202/5461 [06:17<16:46,  4.23it/s]

0.4663456082344055


 23%|██▎       | 1251/5461 [06:32<23:44,  2.96it/s]

0.4909062087535858


 24%|██▍       | 1301/5461 [06:48<19:56,  3.48it/s]

0.5039552450180054


 25%|██▍       | 1351/5461 [07:04<22:02,  3.11it/s]

0.3598646819591522


 26%|██▌       | 1401/5461 [07:19<17:24,  3.89it/s]

0.3822510540485382


 27%|██▋       | 1451/5461 [07:37<26:48,  2.49it/s]

0.37120920419692993


 27%|██▋       | 1501/5461 [07:52<20:19,  3.25it/s]

0.42139241099357605


 28%|██▊       | 1551/5461 [08:09<25:13,  2.58it/s]

0.5351325869560242


 29%|██▉       | 1601/5461 [08:24<20:36,  3.12it/s]

0.3829941749572754


 30%|███       | 1651/5461 [08:41<17:53,  3.55it/s]

0.6592323780059814


 31%|███       | 1701/5461 [08:56<18:00,  3.48it/s]

0.38424304127693176


 32%|███▏      | 1751/5461 [09:12<22:15,  2.78it/s]

0.533634603023529


 33%|███▎      | 1801/5461 [09:27<15:10,  4.02it/s]

0.35946032404899597


 34%|███▍      | 1851/5461 [09:42<16:12,  3.71it/s]

0.4089503884315491


 35%|███▍      | 1901/5461 [09:58<16:10,  3.67it/s]

0.490263432264328


 36%|███▌      | 1951/5461 [10:14<18:44,  3.12it/s]

0.487600177526474


 37%|███▋      | 2001/5461 [10:29<17:52,  3.23it/s]

0.4117589592933655


 38%|███▊      | 2051/5461 [10:45<14:03,  4.04it/s]

0.5000258684158325


 38%|███▊      | 2101/5461 [11:02<17:03,  3.28it/s]

0.3940788805484772


 39%|███▉      | 2151/5461 [11:18<16:02,  3.44it/s]

0.3786279559135437


 40%|████      | 2201/5461 [11:33<16:23,  3.31it/s]

0.3976755142211914


 41%|████      | 2252/5461 [11:50<15:05,  3.54it/s]

0.3423773944377899


 42%|████▏     | 2301/5461 [12:05<15:23,  3.42it/s]

0.3438480794429779


 43%|████▎     | 2351/5461 [12:22<15:13,  3.41it/s]

0.3373469412326813


 44%|████▍     | 2401/5461 [12:37<19:13,  2.65it/s]

0.6316428780555725


 45%|████▍     | 2451/5461 [12:53<14:26,  3.47it/s]

0.5344651341438293


 46%|████▌     | 2501/5461 [13:09<18:14,  2.70it/s]

0.40964561700820923


 47%|████▋     | 2551/5461 [13:24<12:42,  3.82it/s]

0.41632455587387085


 48%|████▊     | 2601/5461 [13:41<18:14,  2.61it/s]

0.786083459854126


 49%|████▊     | 2651/5461 [13:57<13:18,  3.52it/s]

0.3808371126651764


 49%|████▉     | 2701/5461 [14:11<14:46,  3.11it/s]

0.4261455833911896


 50%|█████     | 2751/5461 [14:27<17:18,  2.61it/s]

0.9540546536445618


 51%|█████▏    | 2801/5461 [14:42<13:48,  3.21it/s]

0.3818776607513428


 52%|█████▏    | 2851/5461 [14:57<11:53,  3.66it/s]

0.41829773783683777


 53%|█████▎    | 2901/5461 [15:12<14:13,  3.00it/s]

0.3806592524051666


 54%|█████▍    | 2951/5461 [15:28<13:23,  3.13it/s]

0.4615945518016815


 55%|█████▍    | 3001/5461 [15:45<11:02,  3.71it/s]

0.3293856978416443


 56%|█████▌    | 3051/5461 [16:00<11:25,  3.51it/s]

0.48591864109039307


 57%|█████▋    | 3101/5461 [16:17<13:40,  2.88it/s]

0.40889230370521545


 58%|█████▊    | 3151/5461 [16:33<11:21,  3.39it/s]

0.5565140247344971


 59%|█████▊    | 3201/5461 [16:48<12:36,  2.99it/s]

0.5043671131134033


 60%|█████▉    | 3251/5461 [17:04<11:37,  3.17it/s]

0.3691061735153198


 60%|██████    | 3301/5461 [17:21<14:09,  2.54it/s]

0.37234199047088623


 61%|██████▏   | 3351/5461 [17:37<12:05,  2.91it/s]

0.6648101210594177


 62%|██████▏   | 3401/5461 [17:54<09:21,  3.67it/s]

0.42945584654808044


 63%|██████▎   | 3452/5461 [18:10<07:29,  4.47it/s]

0.5584418773651123


 64%|██████▍   | 3501/5461 [18:24<09:21,  3.49it/s]

0.3690791428089142


 65%|██████▌   | 3551/5461 [18:40<11:41,  2.72it/s]

0.5330342054367065


 66%|██████▌   | 3601/5461 [18:56<09:54,  3.13it/s]

0.3515178859233856


 67%|██████▋   | 3651/5461 [19:11<08:01,  3.76it/s]

0.3617291748523712


 68%|██████▊   | 3701/5461 [19:28<10:03,  2.92it/s]

0.3816923499107361


 69%|██████▊   | 3751/5461 [19:43<06:59,  4.08it/s]

0.37208691239356995


 70%|██████▉   | 3801/5461 [19:58<07:32,  3.67it/s]

0.5812370777130127


 71%|███████   | 3851/5461 [20:15<08:02,  3.34it/s]

0.3786599636077881


 71%|███████▏  | 3901/5461 [20:31<07:28,  3.48it/s]

0.38698649406433105


 72%|███████▏  | 3951/5461 [20:46<07:18,  3.44it/s]

0.43494200706481934


 73%|███████▎  | 4001/5461 [21:03<09:21,  2.60it/s]

0.39720574021339417


 74%|███████▍  | 4052/5461 [21:19<07:51,  2.99it/s]

0.4362933933734894


 75%|███████▌  | 4101/5461 [21:35<06:55,  3.27it/s]

0.3547227680683136


 76%|███████▌  | 4152/5461 [21:51<05:48,  3.76it/s]

0.3972003161907196


 77%|███████▋  | 4201/5461 [22:06<05:58,  3.52it/s]

0.3440462052822113


 78%|███████▊  | 4251/5461 [22:22<07:00,  2.88it/s]

0.5277103781700134


 79%|███████▉  | 4301/5461 [22:38<06:47,  2.85it/s]

0.5264571309089661


 80%|███████▉  | 4351/5461 [22:54<06:54,  2.68it/s]

0.4653143882751465


 81%|████████  | 4401/5461 [23:10<05:25,  3.26it/s]

0.40359973907470703


 82%|████████▏ | 4452/5461 [23:27<05:05,  3.31it/s]

0.3766101002693176


 82%|████████▏ | 4501/5461 [23:43<04:39,  3.44it/s]

0.38809752464294434


 83%|████████▎ | 4551/5461 [23:59<05:21,  2.83it/s]

0.31971976161003113


 84%|████████▍ | 4602/5461 [24:14<03:53,  3.68it/s]

0.5704258680343628


 85%|████████▌ | 4651/5461 [24:30<03:45,  3.58it/s]

0.381595253944397


 86%|████████▌ | 4701/5461 [24:45<03:51,  3.28it/s]

0.36729696393013


 87%|████████▋ | 4752/5461 [25:00<03:10,  3.73it/s]

0.4121834933757782


 88%|████████▊ | 4801/5461 [25:15<03:48,  2.89it/s]

0.5052014589309692


 89%|████████▉ | 4851/5461 [25:32<02:58,  3.42it/s]

0.48709964752197266


 90%|████████▉ | 4901/5461 [25:46<02:59,  3.12it/s]

0.3289797902107239


 91%|█████████ | 4951/5461 [26:03<03:25,  2.49it/s]

0.4433615803718567


 92%|█████████▏| 5001/5461 [26:19<02:22,  3.24it/s]

0.39484626054763794


 92%|█████████▏| 5051/5461 [26:35<02:24,  2.85it/s]

0.38547343015670776


 93%|█████████▎| 5101/5461 [26:51<01:51,  3.22it/s]

0.40412116050720215


 94%|█████████▍| 5151/5461 [27:06<01:18,  3.97it/s]

0.599708080291748


 95%|█████████▌| 5201/5461 [27:23<01:33,  2.78it/s]

0.34666335582733154


 96%|█████████▌| 5251/5461 [27:39<01:27,  2.40it/s]

0.4209079146385193


 97%|█████████▋| 5301/5461 [27:54<01:01,  2.59it/s]

0.3551273047924042


 98%|█████████▊| 5351/5461 [28:10<00:36,  2.99it/s]

0.36369073390960693


 99%|█████████▉| 5401/5461 [28:26<00:17,  3.49it/s]

0.3585550785064697


100%|█████████▉| 5451/5461 [28:42<00:03,  3.02it/s]

0.6047248840332031


100%|██████████| 5461/5461 [28:45<00:00,  3.16it/s]


In [250]:
test_string = "I am a very cool person"

indices = model.get_indices(test_string)

start_hidden = model.init_hidden()
output = model(indices.unsqueeze(0), start_hidden)

print(output)

tensor([[0.6960, 0.5403, 0.2758, 0.2697, 0.2694, 0.2693]],
       grad_fn=<SqueezeBackward1>)
