In [1]:
import torch
from torch.nn import Module, Dropout, ReLU, Embedding, Sequential, Linear
from torch.nn.functional import normalize
from torch.utils.data import Dataset, DataLoader, random_split

from sklearn.metrics import accuracy_score
import pandas as pd
import numpy as np
from typing import Tuple, List

from torchtext import data
from torchtext.vocab import Vectors
from torch.nn import init
from tqdm import tqdm

In [19]:
def create_fields():
    """Build the two torchtext fields shared by the corpus datasets.

    Returns:
        A ``(path_field, entity_field)`` tuple. ``path_field`` is a
        sequential field that lower-cases its input, tokenizes it with
        ``str.split()`` and is configured with ``fix_length=10``.
        ``entity_field`` is a non-sequential field with default settings.
    """
    def split_on_whitespace(text):
        # Same tokenizer as before: plain whitespace split of the raw string.
        return text.split()

    entity_field = data.Field(sequential=False)
    path_field = data.Field(
        sequential=True,
        tokenize=split_on_whitespace,
        lower=True,
        fix_length=10,
    )
    return path_field, entity_field

class MyDataset(data.Dataset):
    """torchtext dataset built from a CSV file read with pandas.

    Each CSV row becomes one ``data.Example`` with the attributes ``path``,
    ``subjs`` and ``objs`` (the ``id`` slot is paired with a ``None`` field).
    """

    def __init__(self, corpus_path:str, path_field:data.Field, entity_field:data.Field, test:bool=False, **kwargs):
        """Load ``corpus_path`` and convert its rows into examples.

        Args:
            corpus_path: Path of the CSV file passed to ``pd.read_csv``.
            path_field: Field applied to the ``path`` column.
            entity_field: Field applied to both ``subjs`` and ``objs``.
            test: When true only the ``path`` column is read and the
                entity slots are filled with ``None``; otherwise the
                ``path``, ``subj`` and ``obj`` columns are all read.
            **kwargs: Forwarded unchanged to ``data.Dataset.__init__``.
        """
        fields = [('id', None), ('path', path_field), ('subjs', entity_field), ('objs', entity_field)]
        corpus_data = pd.read_csv(corpus_path)

        # First describe each example as a plain 4-item row matching `fields`,
        # then turn every row into an Example in a single place below.
        if test:
            rows = ([None, text, None, None] for text in tqdm(corpus_data['path']))
        else:
            columns = zip(corpus_data['path'], corpus_data['subj'], corpus_data['obj'])
            rows = ([None, path, subj, obj] for path, subj, obj in tqdm(columns))

        examples = [data.Example.fromlist(row, fields=fields) for row in rows]
        super().__init__(examples=examples, fields=fields, **kwargs)



In [16]:
# Shuffle the full corpus and write an 80/20 train/validation split to disk.
SPLIT_SEED = 42        # fixed seed so a fresh kernel reproduces the same split
TRAIN_FRACTION = 0.8   # share of rows that goes to the training file

df = pd.read_csv('../data/corpus/dataset.csv')
df = df.sample(frac=1, random_state=SPLIT_SEED).reset_index(drop=True)
total_num = len(df)
split_at = int(total_num * TRAIN_FRACTION)
train_df = df[:split_at]
valid_df = df[split_at:]
# index=False keeps the row index out of the files, so train.csv / valid.csv
# carry exactly the columns of the frame instead of an extra unnamed column.
train_df.to_csv('../data/corpus/train.csv', index=False)
valid_df.to_csv('../data/corpus/valid.csv', index=False)

In [20]:
# Create the shared fields once, then load the train and validation CSVs
# (in that order) through the same fields.
path_field, entity_field = create_fields()
train_data, valid_data = (
    MyDataset(csv_path, path_field=path_field, entity_field=entity_field, test=False)
    for csv_path in ('../data/corpus/train.csv', '../data/corpus/valid.csv')
)

689292it [00:07, 93436.13it/s] 
172324it [00:02, 84563.39it/s] 


In [21]:
# Build both vocabularies from the training split only.
path_field.build_vocab(train_data)
entity_field.build_vocab(train_data)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_iter, val_iter = data.BucketIterator.splits(
    (train_data, valid_data),
    batch_sizes=(32, 32),
    device=device,
    # Examples expose their values under the names used in MyDataset's
    # `fields` list ('path', 'subjs', 'objs'). They have no `path_field`
    # attribute, so the sort key must read `x.path`.
    sort_key=lambda x: len(x.path),
    sort_within_batch=True,
    repeat=False,
)




In [23]:
# Display the size of the entity vocabulary built from train_data above.
# NOTE(review): presumably this count includes torchtext's special tokens
# (e.g. <unk>) in addition to the entities themselves — confirm.
len(entity_field.vocab)

31604