In [1]:
import pandas as pd
import time

In [4]:
df = pd.read_csv('data.csv')

In [5]:
df.head()

Unnamed: 0.1,Unnamed: 0,id,episode_id,number,raw_text,timestamp_in_ms,speaking_line,character_id,location_id,raw_character_text,raw_location_text,spoken_words,normalized_text,word_count
0,0,10368,35,29,"Lisa Simpson: Maggie, look. What's that?",235000,True,9,5.0,Lisa Simpson,Simpson Home,"Maggie, look. What's that?",maggie look whats that,4.0
1,1,10369,35,30,Lisa Simpson: Lee-mur. Lee-mur.,237000,True,9,5.0,Lisa Simpson,Simpson Home,Lee-mur. Lee-mur.,lee-mur lee-mur,2.0
2,2,10370,35,31,Lisa Simpson: Zee-boo. Zee-boo.,239000,True,9,5.0,Lisa Simpson,Simpson Home,Zee-boo. Zee-boo.,zee-boo zee-boo,2.0
3,3,10372,35,33,Lisa Simpson: I'm trying to teach Maggie that ...,245000,True,9,5.0,Lisa Simpson,Simpson Home,I'm trying to teach Maggie that nature doesn't...,im trying to teach maggie that nature doesnt e...,24.0
4,4,10374,35,35,"Lisa Simpson: It's like an ox, only it has a h...",254000,True,9,5.0,Lisa Simpson,Simpson Home,"It's like an ox, only it has a hump and a dewl...",its like an ox only it has a hump and a dewlap...,18.0


In [9]:
phrases = df['normalized_text'].tolist()

In [12]:
phrases[:10]

['maggie look whats that',
 'lee-mur lee-mur',
 'zee-boo zee-boo',
 'im trying to teach maggie that nature doesnt end with the barnyard i want her to have all the advantages that i didnt have',
 'its like an ox only it has a hump and a dewlap hump and dew-lap hump and dew-lap',
 'you know his blood type how romantic',
 'oh yeah whats my shoe size',
 'ring',
 'yes dad',
 'ooh look maggie what is that do-dec-ah-edron dodecahedron']

In [29]:
text = [[c for c in ph] for ph in phrases if type(ph) is str]

In [30]:
CHARS = set('abcdefghijklmnopqrstuvwxyz ')

In [31]:
INDEX_TO_CHAR = ['none'] + [w for w in CHARS]

In [32]:
INDEX_TO_CHAR

['none',
 'x',
 'e',
 'a',
 'r',
 'd',
 'q',
 'g',
 'c',
 'v',
 'h',
 'p',
 'w',
 'm',
 'b',
 'i',
 'z',
 's',
 'j',
 'u',
 'f',
 'n',
 'k',
 'o',
 'l',
 ' ',
 't',
 'y']

In [33]:
CHAR_TO_INDEX = {w: i for i, w in enumerate(INDEX_TO_CHAR)}

In [34]:
CHAR_TO_INDEX

{'none': 0,
 'x': 1,
 'e': 2,
 'a': 3,
 'r': 4,
 'd': 5,
 'q': 6,
 'g': 7,
 'c': 8,
 'v': 9,
 'h': 10,
 'p': 11,
 'w': 12,
 'm': 13,
 'b': 14,
 'i': 15,
 'z': 16,
 's': 17,
 'j': 18,
 'u': 19,
 'f': 20,
 'n': 21,
 'k': 22,
 'o': 23,
 'l': 24,
 ' ': 25,
 't': 26,
 'y': 27}

In [35]:
import torch

In [36]:
MAX_LEN = 50

In [40]:
X = torch.zeros((len(text), MAX_LEN), dtype=int)

In [41]:
for i in range(len(text)):
    for j, w in enumerate(text[i]):
        if j >= MAX_LEN:
            break
        X[i, j] = CHAR_TO_INDEX.get(w, CHAR_TO_INDEX['none'])

In [42]:
X[0:1]

tensor([[13,  3,  7,  7, 15,  2, 25, 24, 23, 23, 22, 25, 12, 10,  3, 26, 17, 25,
         26, 10,  3, 26,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]])

In [43]:
class Network(torch.nn.Module):

    def __init__(self):
        super(Network, self).__init__()
        self.word_embeddings = torch.nn.Embedding(len(INDEX_TO_CHAR), 28)
        self.gru = torch.nn.RNN(28, 128, batch_first=True)
        self.hidden2tag = torch.nn.Linear(128, len(INDEX_TO_CHAR))

    def forward(self, sentences):
        embeds = self.word_embeddings(sentences)
        gru_out, state = self.gru(embeds)
        tag_space = self.hidden2tag(gru_out.reshape(-1, 128))
        return tag_space.reshape(sentences.shape[0], sentences.shape[1], -1), state

    def forward_state(self, sentences, state):
        embeds = self.word_embeddings(sentences)
        gru_out, state = self.gru(embeds, state)
        tag_space = self.hidden2tag(gru_out.reshape(-1, 128))
        return tag_space.reshape(sentences.shape[0], sentences.shape[1], -1), state

In [44]:
model = Network()

In [45]:
model.forward(X[0:1])[0].shape

torch.Size([1, 50, 28])

In [47]:
def generate_sentence():
    sentence = ['h', 'e', 'l', 'l', 'o']
    state = None
    for i in range(MAX_LEN):
        X = torch.Tensor([[CHAR_TO_INDEX[sentence[i]]]]).type(torch.long)
        if i == 0:
            result, state = model.forward(X)
        else:
            result, state = model.forward_state(X, state)
        prediction = result[0, -1, :]
        index_of_prediction = prediction.argmax()
        if i >= len(sentence) - 1:
            if index_of_prediction == 0:
                break
        sentence.append(INDEX_TO_CHAR[index_of_prediction])
    print(''.join(sentence))

In [48]:
generate_sentence()

helloltrdkrzz z ggkgkiwzwzgmgugwiwkwmtmggitiwwgggmuwiwi


In [49]:
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=.05)

In [53]:
for ep in range(300):
    start = time.time()
    train_loss = 0.
    train_passed = 0

    for i in range(int(len(X) / 100)):
        batch = X[i * 100:(i + 1) * 100]
        X_batch = batch[:, :-1]
        Y_batch = batch[:, 1:].flatten()

        optimizer.zero_grad()
        answers, _ = model.forward(X_batch)
        answers = answers.view(-1, len(INDEX_TO_CHAR))
        loss = criterion(answers, Y_batch)
        train_loss += loss.item()

        loss.backward()
        optimizer.step()
        train_passed += 1

    print("Epoch {}. Time: {:.3f}, Train loss: {:.3f}".format(ep, time.time() - start, train_loss / train_passed))
    generate_sentence()


Epoch 0. Time: 2.980, Train loss: 1.497
helloa    ntoooo  n  tooaoh n n toaoah tn  thothhet e  
Epoch 1. Time: 2.978, Train loss: 1.489
helloa    ntoooo  n  tooaoh n n toaoah tn  thothhet e  
Epoch 2. Time: 3.105, Train loss: 1.481
helloa l  ntoaho  n  tooaoh n n toaoah tn  thothhet e  
Epoch 3. Time: 3.251, Train loss: 1.473
helloa l  ntoaho  n  tooaoh n n toaoah tn  thothhet e  
Epoch 4. Time: 3.225, Train loss: 1.465
helloa l  ntoaho  n  tooaoh n n toaoah tn  thothhet e  
Epoch 5. Time: 3.365, Train loss: 1.458
helloa l  ntoaho  n  tooaoh n n toaoah tn  thothhet e  
Epoch 6. Time: 3.221, Train loss: 1.451
helloa l  ntoaho  n  tooaoh n n toaoah tn  thothhet e  
Epoch 7. Time: 3.041, Train loss: 1.445
helloa l  ntoaho  n  tooaoh n n toaoah tn  thothhet e  
Epoch 8. Time: 2.970, Train loss: 1.438
helloa l  ntoahoe n   toaooh tnn thotthe   e sooamoun e
Epoch 9. Time: 2.927, Train loss: 1.432
helloa l  ntoahoern     tooooh nnn tooeth n  etotorh   
Epoch 10. Time: 3.327, Train loss: 1.426

KeyboardInterrupt: 

## Полезное

In [None]:
torchnlp.samplers.BucketBatchSampler - для сэмплирования данных