In [1]:
import time

import torch
import torch.nn as nn
from torch import nn, optim, Tensor
from torch.utils.data import DataLoader
import torch.nn.functional as F

import numpy as np
from typing import List, Tuple, Any, Optional, Dict

from data.raw_data import Instance, Data
from data.data_utils import Dictionary, triples_to_indices
from data.dataset import CLUTRRdata
from models.encoders import EncoderLSTM
from models.decoders import DecoderLSTM
from training.train import train, predict_raw

# A fact is a 3-tuple of strings; the example data in this notebook has the
# form (entity, relation, entity), e.g. ('James', 'sister', 'Lisa').
Fact = Tuple[str, str, str]
# `Triple` is an alternative name for the very same type.
Triple = Fact
# A story is an ordered list of facts.
Story = List[Fact]

In [23]:
# --- Load the data and build the symbol dictionary -------------------------
test_data_path = ['data/clutrr-data/data_089907f8/1.2_test.csv']
data = Data(test_paths=test_data_path)
dictionary = Dictionary(data)

# --- Inspect one training instance ------------------------------------------
# Index 5000 is the first instance after the `data.train[:5000]` slice that
# the training-set cell uses.
example = data.train[5000]
story, target = example.story, example.target
s, p, o = target  # presumably subject / predicate / object -- TODO confirm
print(f'Number of entities {dictionary.num_entities}')
print(f'Example story: {story}')
print(f'Example target: {target}')

# --- Embedding tables (100-dimensional, sparse gradients) -------------------
# NOTE(review): re-running this cell creates brand-new embedding objects; any
# encoder/decoder built from the previous ones would presumably need to be
# rebuilt as well -- confirm before relying on out-of-order execution.
entity_embeddings = nn.Embedding(
    num_embeddings=dictionary.num_entities,
    embedding_dim=100,
    sparse=True,
)
relation_embeddings = nn.Embedding(
    num_embeddings=dictionary.num_relations,
    embedding_dim=100,
    sparse=True,
)

# Currently no backprop through embeddings: freeze both weight matrices.
for frozen_table in (entity_embeddings, relation_embeddings):
    frozen_table.weight.requires_grad = False
print(f'entity embeddings shape: {entity_embeddings.weight.shape}')

# Convert the example story to index form using the project helper.
indices, ent, rel = triples_to_indices(dictionary, story)

Number of entities 78
Example story: [('James', 'sister', 'Lisa'), ('Lisa', 'father', 'Jason')]
Example target: ('James', 'father', 'Jason')
entity embeddings shape: torch.Size([78, 100])


In [24]:
# Fetch the test instances for the (single) test file. `data.test` is indexed
# by the path string, so reuse `test_data_path` from the data-loading cell
# instead of repeating the literal -- the two can then never drift apart.
test_data = data.test[test_data_path[0]]

In [4]:
# Wrap the first 5000 training instances in the project dataset class.
train_set = CLUTRRdata(data.train[:5000], dictionary)
# NOTE(review): the DataLoader below is disabled; `train(...)` is handed the
# dataset object directly later in the notebook.
#trainloader = DataLoader(dataset = train_set,
#                        batch_size = 5,
#                        shuffle = True,)

In [5]:
# Sanity check: number of items in the training set (5000 in the saved output).
len(train_set)

5000

In [6]:
# Wrap the test instances in the same dataset class, sharing the dictionary
# used for the training set.
test_set = CLUTRRdata(test_data, dictionary)
# NOTE(review): the DataLoader below is disabled; `train(...)` is handed the
# dataset object directly later in the notebook.
#testloader = DataLoader(dataset = test_set,
#                        batch_size = 5,
#                        shuffle = True,)

In [7]:
# Build the encoder/decoder pair. Both receive the same entity embedding
# object; hidden_size is 100, equal to the embedding dimension used when the
# embedding tables were created.
encoder = EncoderLSTM(
    hidden_size=100,
    entity_embeddings=entity_embeddings,
    relation_embeddings=relation_embeddings,
)

# The decoder is given the number of relations rather than the relation
# embedding table.
decoder = DecoderLSTM(
    hidden_size=100,
    entity_embeddings=entity_embeddings,
    num_relations=dictionary.num_relations,
)

In [8]:
# Train the encoder/decoder on train_set, reporting loss on test_set.
# The 4 matches the "Epoch [k/4]" lines in the saved output, i.e. the number
# of epochs; 0.01 is presumably the learning rate -- TODO confirm against
# training.train.train and consider passing both as keyword arguments.
train(encoder, decoder, train_set, test_set, 4, 0.01)

Training model...
Epoch [1/4], Train loss: 27.3562, Test loss: 0.1362                        
Epoch [2/4], Train loss: 0.0524, Test loss: 0.0385                        
Epoch [3/4], Train loss: 0.0189, Test loss: 0.0162                        
Epoch [4/4], Train loss: 0.0089, Test loss: 0.0082                        
Finished training



In [31]:
# Qualitative check on the first test instance; per the saved output this
# prints the story, the target triple and the predicted relation.
predict_raw(test_data[0], encoder, decoder, dictionary)

Story: [('Jason', 'grandson', 'Scott'), ('Scott', 'brother', 'Lewis')]
Target: ('Jason', 'grandson', 'Lewis')
Predicted relation: grandson


In [22]:
# Show the story (list of facts) of the fifth test instance.
test_data[4].story

[('Jason', 'grandson', 'Donald'), ('Donald', 'brother', 'Russell')]

In [9]:
# NOTE(review): `a` is not read by any other cell visible here; this looks
# like leftover scratch -- confirm and delete the cell if unused.
a = 3

In [25]:
# Print the first ten test instances, one per line, for a quick visual check.
for position in range(10):
    instance = test_data[position]
    print(instance)

[('Jason', 'grandson', 'Scott'), ('Scott', 'brother', 'Lewis')]	('Jason', 'grandson', 'Lewis')
[('Gabrielle', 'husband', 'Jason'), ('Jason', 'daughter', 'Lisa')]	('Gabrielle', 'daughter', 'Lisa')
[('Gabrielle', 'husband', 'Jason'), ('Jason', 'daughter', 'Myrna')]	('Gabrielle', 'daughter', 'Myrna')
[('Gabrielle', 'grandson', 'David'), ('David', 'brother', 'Joe')]	('Gabrielle', 'grandson', 'Joe')
[('Jason', 'grandson', 'Donald'), ('Donald', 'brother', 'Russell')]	('Jason', 'grandson', 'Russell')
[('Dorothy', 'husband', 'James'), ('James', 'daughter', 'Theresa')]	('Dorothy', 'daughter', 'Theresa')
[('Myrna', 'husband', 'Christopher'), ('Christopher', 'daughter', 'Lucille')]	('Myrna', 'daughter', 'Lucille')
[('Gabrielle', 'grandson', 'Kevin'), ('Kevin', 'brother', 'Dan')]	('Gabrielle', 'grandson', 'Dan')
[('Gabrielle', 'grandson', 'Dan'), ('Dan', 'brother', 'Micheal')]	('Gabrielle', 'grandson', 'Micheal')
[('Gabrielle', 'grandson', 'Dan'), ('Dan', 'brother', 'Kevin')]	('Gabrielle', 'grands

In [26]:
# Number of test instances loaded from the test file (38 in the saved output).
len(test_data)

38