In [1]:
import numpy as np
import pandas as pd
import re
import spacy
import torch.utils.data as tud
# Spanish spaCy pipeline shared by every cell below (QADataset tokenises with it).
nlp = spacy.load('es')

# Attach pre-trained word vectors to the pipeline's vocabulary so that
# token.vector / token.has_vector are populated during tokenisation.
vec_path = "../wordvecs/wiki.es/wiki.es.nospace.vec"
with open(vec_path) as vec_file:
    nlp.vocab.load_vectors(vec_file)

In [60]:
class QADataset(tud.Dataset):
    """Dataset that encodes (story, question, answer) rows as fixed-size arrays.

    Each item is built lazily from one row of ``data_df``: the story and the
    question are lower-cased, tokenised with the module-level spaCy pipeline
    ``nlp`` and written into zero-padded numpy arrays.

    ``__getitem__`` returns the 10-tuple
    ``(s, q, s_len, q_len, s_mask, q_mask, s_var, q_var, q_ph, answer)``:

    * ``s`` / ``q``           -- (MAX_*_LEN, EMBED_DIM) float arrays of token vectors.
    * ``s_len`` / ``q_len``   -- number of real (non-padding) tokens.
    * ``s_mask`` / ``q_mask`` -- int32, 1 where the token has NO word vector.
    * ``s_var`` / ``q_var``   -- int32, ``N + 1`` at positions holding ``@entityN``
      and 0 everywhere else (0 is therefore the "not an entity" value).
    * ``q_ph``                -- int32, 1 at positions holding ``@placeholder``.
    * ``answer``              -- the integer ``N`` of the answer's ``@entityN``.

    NOTE(review): ``answer`` is the raw entity number while ``s_var``/``q_var``
    store the number shifted by one -- confirm the consumer accounts for this.

    A text with more tokens than its padded length raises ``ValueError`` from
    the numpy slice assignment; inputs are not truncated.
    """

    # Padded sizes of the returned arrays.
    MAX_STORY_LEN = 2000
    MAX_QUESTION_LEN = 50
    EMBED_DIM = 300

    def __init__(self, data_df):
        """Store the frame; it must provide 'story', 'question' and 'answer' columns."""
        self.data_df = data_df

    def __len__(self):
        """Return the number of rows (examples) in the wrapped frame."""
        return self.data_df.shape[0]

    def _encode(self, text, max_len):
        """Tokenise ``text`` and return its padded arrays.

        Returns ``(vectors, n_tokens, oov_mask, entity_ids, placeholder)`` where
        every array has ``max_len`` rows and is zero beyond ``n_tokens``.
        """
        # The keyword flags presumably switch off the parser, tagger and NER
        # (spaCy 1.x call style) -- only tokenisation and vectors are used here.
        doc = nlp(text.lower(), parse=False, tag=False, entity=False)
        n_tokens = len(doc)

        vectors = np.zeros((max_len, self.EMBED_DIM))
        oov_mask = np.zeros(max_len, dtype=np.int32)
        entity_ids = np.zeros(max_len, dtype=np.int32)
        placeholder = np.zeros(max_len, dtype=np.int32)

        # Done first so that an over-long text fails with the same ValueError
        # (shape mismatch) before anything else is written.
        oov_mask[:n_tokens] = [not tok.has_vector for tok in doc]

        for pos, tok in enumerate(doc):
            if tok.text.startswith('@entity'):
                # Shift by one so that 0 keeps meaning "no entity here".
                entity_ids[pos] = int(re.search(r'\d+', tok.text).group(0)) + 1
            elif tok.text == '@placeholder':
                placeholder[pos] = 1

        # np.stack rejects an empty sequence, so leave the zeros for empty text.
        if n_tokens:
            vectors[:n_tokens, :] = np.stack([tok.vector for tok in doc])

        return vectors, n_tokens, oov_mask, entity_ids, placeholder

    def __getitem__(self, i):
        """Encode the ``i``-th row (positional) and return the 10-tuple described on the class."""
        # The story's placeholder flags are not part of the returned tuple.
        s, s_len, s_mask, s_var, _ = self._encode(
            self.data_df['story'].iloc[i], self.MAX_STORY_LEN)

        # Question entity IDs go into q_var; they must not touch s_var.
        q, q_len, q_mask, q_var, q_ph = self._encode(
            self.data_df['question'].iloc[i], self.MAX_QUESTION_LEN)

        answer = int(re.search(r'\d+', self.data_df['answer'].iloc[i]).group(0))

        return s, q, s_len, q_len, s_mask, q_mask, s_var, q_var, q_ph, answer

In [61]:
# Load the pickled DataFrame of QA examples.
# NOTE: unpickling can execute arbitrary code -- only load files from a trusted source.
train = pd.read_pickle("../input_data/input.pkl")
# Preview the first rows (rendered as the cell output).
train.head()

Unnamed: 0,question,answer,story,story_length,question_length
6549,VIAJE AL REINO HERMÉTICO: Una española en @pla...,@entity80,"En @entity388 se aprende a llorar en silencio,...",2000,9
6550,El estudiante de EEUU detenido en @placeholder...,@entity212,"En @entity212 se aprende a llorar en silencio,...",2000,13
8163,Volver a nacer dos veces en un mismo siglo en ...,@entity2,Si se callejea por el casco antiguo de @entity...,1998,11
25466,"\nMientras el @entity458, @placeholder y @enti...",@entity170,"Desde que la crisis se hizo evidente, la @enti...",1993,21
25465,"\nMientras el @placeholder, @entity10 y @entit...",@entity126,"Desde que la crisis se hizo evidente, la @enti...",1993,21


In [62]:
# Smoke test: wrap the frame and inspect one encoded example.
ds = QADataset(train)
print(len(ds))

s, q, sl, ql, sm, qm, sv, qv, qph, a = ds[2]

# Embedding matrices, then the answer id and the two true token counts.
print(s.shape)
print(q.shape)
print(a)
print(sl)
print(ql)

# Per-token masks / entity ids / placeholder flags: shapes only.
for arr in (sm, qm, sv, qv, qph):
    print(arr.shape)

56361
(2000, 300)
(50, 300)
2
1998
11
(2000,)
(50,)
(2000,)
(50,)
(50,)


In [63]:
# Batch the dataset and pull one batch to check the collated shapes.
qa_loader = tud.DataLoader(ds, batch_size=20)
first_batch = next(iter(qa_loader))
s, q, sl, ql, sm, qm, sv, qv, qph, a = first_batch

print(a)
for batch_tensor in (s, q):
    print(batch_tensor.shape)


  80
 212
   2
 170
 126
 101
 136
  91
  68
 496
 255
 471
 420
  51
 464
 337
 483
 362
 237
 460
[torch.LongTensor of size 20]

torch.Size([20, 2000, 300])
torch.Size([20, 50, 300])
