In [1]:
import time

import torch
import torch.nn as nn
from torch import nn, optim, Tensor
from torch.utils.data import DataLoader
import torch.nn.functional as F

import numpy as np
from typing import List, Tuple, Any, Optional, Dict

from training.data import Instance, Data, Dictionary, triples_to_indices
from training.dataset import CLUTRRdata
from training.train import train, predict_raw
from models.encoders import EncoderLSTM
from models.decoders import DecoderLSTM

# A fact is a 3-tuple of strings, e.g. ('James', 'sister', 'Lisa').
Fact = Tuple[str, str, str]
# `Triple` is a second name for exactly the same type object.
Triple = Fact
# A story is an ordered list of facts.
Story = List[Fact]

In [2]:
# Seed the global RNGs before anything random happens in this notebook. The
# embedding tables below are randomly initialised, so without a fixed seed a
# Restart Kernel -> Run All would not reproduce the same run.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Width of every entity / relation embedding vector.
EMBEDDING_DIM = 100
# Position in data.train of the instance printed below as an illustration.
EXAMPLE_INDEX = 5000

data = Data()
dictionary = Dictionary(data)

# Fetch the example once, then unpack its story and its target triple.
example = data.train[EXAMPLE_INDEX]
story, target = example.story, example.target
s, p, o = target
print(f'Number of entities {dictionary.num_entities}')
print(f'Example story: {story}')
print(f'Example target: {target}')

# One embedding row per entity / relation known to the dictionary.
entity_embeddings = nn.Embedding(dictionary.num_entities, EMBEDDING_DIM, sparse=True)
relation_embeddings = nn.Embedding(dictionary.num_relations, EMBEDDING_DIM, sparse=True)

# Currently no backprop through embeddings: both weight tables are excluded
# from gradient computation.
entity_embeddings.weight.requires_grad = False
relation_embeddings.weight.requires_grad = False
print(f'entity embeddings shape: {entity_embeddings.weight.shape}')

# Run the example story through the project helper, which yields three values.
# NOTE(review): their exact meaning is defined in training.data, not here -- confirm there.
indices, ent, rel = triples_to_indices(dictionary, story)

Number of entities 78
Example story: [('James', 'sister', 'Lisa'), ('Lisa', 'father', 'Jason')]
Example target: ('James', 'father', 'Jason')
entity embeddings shape: torch.Size([78, 100])


In [3]:
# Wrap the first 5000 training instances in the project dataset class and
# serve them in shuffled mini-batches of five.
train_instances = data.train[:5000]
train_set = CLUTRRdata(train_instances, dictionary)
trainloader = DataLoader(train_set, batch_size=5, shuffle=True)

In [4]:
# Encoder and decoder are built with the same hidden size and the same entity
# embedding table; the encoder additionally receives the relation embeddings,
# the decoder the number of relations.
shared_kwargs = dict(hidden_size=100, entity_embeddings=entity_embeddings)

encoder = EncoderLSTM(relation_embeddings=relation_embeddings, **shared_kwargs)
decoder = DecoderLSTM(num_relations=dictionary.num_relations, **shared_kwargs)

In [5]:
# Positional arguments 4 and 5 of the project's train() helper.
# NOTE(review): presumably the epoch count and the learning rate -- confirm
# against training.train.train. The saved output of this cell reports 5 epochs
# although 4 is passed; either the output is stale or the count is inclusive.
num_epochs = 4
learning_rate = 0.01
train(encoder, decoder, trainloader, num_epochs, learning_rate)

Training model...
Epoch [1/5], Train loss: 145.04                                  
Epoch [2/5], Train loss: 0.27                                  
Epoch [3/5], Train loss: 0.10                                  
Epoch [4/5], Train loss: 0.04                                  
Epoch [5/5], Train loss: 0.02                                  
Finished training



In [6]:
# Show the model's prediction for a single instance.
# NOTE(review): index 10 lies inside the data.train[:5000] slice the training
# set is built from, so this checks fit on seen data rather than generalisation.
sample = data.train[10]
predict_raw(sample, encoder, decoder, dictionary)

Story: [('Lisa', 'son', 'Joe'), ('Joe', 'grandmother', 'Gabrielle')]
tensor(12)
Target: ('Lisa', 'mother', 'Gabrielle')
Predicted relation: mother
