In [1]:
import time

import torch
import torch.nn as nn
from torch import nn, optim, Tensor
from torch.utils.data import DataLoader
import torch.nn.functional as F

import numpy as np
from typing import List, Tuple, Any, Optional, Dict

from data.raw_data import Instance, Data
from data.data_utils import Dictionary, triples_to_indices
from data.dataset import CLUTRRdata
from models.encoders import EncoderLSTM
from models.decoders import DecoderLSTM
from training.train import train, predict_raw

In [2]:
test_data_path = ['data/clutrr-data/data_089907f8/1.2_test.csv']
data = Data(test_paths = test_data_path)
dictionary = Dictionary(data)
story, target = data.train[5080].story, data.train[5000].target
s, p, o = target
print(f'Number of entities {dictionary.num_entities}')
print(f'Example story: {story}')
print(f'Example target: {target}')

entity_embeddings = nn.Embedding(dictionary.num_entities, 100, sparse=True)
relation_embeddings = nn.Embedding(dictionary.num_relations, 100, sparse=True)

# Currently no backprop through embeddings
entity_embeddings.weight.requires_grad = False
relation_embeddings.weight.requires_grad = False
print(f'entity embeddings shape: {entity_embeddings.weight.shape}')

indices, ent, rel = triples_to_indices(dictionary, story)

Number of entities 78
Example story: [('Dwight', 'brother', 'Christopher'), ('Christopher', 'daughter', 'Lucille')]
Example target: ('James', 'father', 'Jason')
entity embeddings shape: torch.Size([78, 100])


In [3]:
train_set = CLUTRRdata(data.train, dictionary)

test_data = data.test['data/clutrr-data/data_089907f8/1.2_test.csv']
test_set = CLUTRRdata(test_data, dictionary)

In [4]:
encoder = EncoderLSTM(hidden_size = 100,
                    entity_embeddings = entity_embeddings,
                    relation_embeddings = relation_embeddings)

decoder = DecoderLSTM(hidden_size = 100,
                    entity_embeddings = entity_embeddings,
                    num_relations = dictionary.num_relations)

In [5]:
train(encoder, decoder, train_set, test_set, 4, 0.01)

Training model...
Running epoch [1/4],Step: [100/505],Loss: 103.27               

KeyboardInterrupt: 

In [None]:
predict_raw(test_data[3], encoder, decoder, dictionary)