In [1]:
import os
import pandas as pd
import torch
from torch import nn 
import time
import numpy as np 

In [2]:
simpson_data = pd.read_csv('Data/data.csv').iloc[:, 1:]

In [3]:
simpson_data.head()

Unnamed: 0,id,episode_id,number,raw_text,timestamp_in_ms,speaking_line,character_id,location_id,raw_character_text,raw_location_text,spoken_words,normalized_text,word_count
0,10368,35,29,"Lisa Simpson: Maggie, look. What's that?",235000,True,9,5.0,Lisa Simpson,Simpson Home,"Maggie, look. What's that?",maggie look whats that,4.0
1,10369,35,30,Lisa Simpson: Lee-mur. Lee-mur.,237000,True,9,5.0,Lisa Simpson,Simpson Home,Lee-mur. Lee-mur.,lee-mur lee-mur,2.0
2,10370,35,31,Lisa Simpson: Zee-boo. Zee-boo.,239000,True,9,5.0,Lisa Simpson,Simpson Home,Zee-boo. Zee-boo.,zee-boo zee-boo,2.0
3,10372,35,33,Lisa Simpson: I'm trying to teach Maggie that ...,245000,True,9,5.0,Lisa Simpson,Simpson Home,I'm trying to teach Maggie that nature doesn't...,im trying to teach maggie that nature doesnt e...,24.0
4,10374,35,35,"Lisa Simpson: It's like an ox, only it has a h...",254000,True,9,5.0,Lisa Simpson,Simpson Home,"It's like an ox, only it has a hump and a dewl...",its like an ox only it has a hump and a dewlap...,18.0


In [4]:
phrases = simpson_data.loc[:, 'normalized_text'].to_list()

In [5]:
phrases[:3]

['maggie look whats that', 'lee-mur lee-mur', 'zee-boo zee-boo']

In [6]:
phrases_cleaned = [*filter(lambda el: isinstance(el, str), phrases)]

In [7]:
text = [[sym for sym in p] for p in phrases if isinstance(p, str)]

In [8]:
CHARS = sorted(set('abcdefghijklmnopqrstuvwxyz '))
INDEX_TO_CHAR = ['none'] + [w for w in CHARS]
CHAR_TO_INDEX = {i:w for w, i in enumerate(INDEX_TO_CHAR)}

In [9]:
alphabet = set('abcdefghijklmnopqrstuvwxyz')

In [10]:
orig = np.array(sorted([*alphabet]))
shifted = np.roll(orig, 13)

In [11]:
orig_shift = {l_or:l_sh for l_or, l_sh in zip(orig, shifted)}
orig_shift

{'a': 'n',
 'b': 'o',
 'c': 'p',
 'd': 'q',
 'e': 'r',
 'f': 's',
 'g': 't',
 'h': 'u',
 'i': 'v',
 'j': 'w',
 'k': 'x',
 'l': 'y',
 'm': 'z',
 'n': 'a',
 'o': 'b',
 'p': 'c',
 'q': 'd',
 'r': 'e',
 's': 'f',
 't': 'g',
 'u': 'h',
 'v': 'i',
 'w': 'j',
 'x': 'k',
 'y': 'l',
 'z': 'm'}

In [12]:
shifted_phrases = [''.join([orig_shift.get(i, ' ') for i in phrase]) for phrase in phrases_cleaned]

In [13]:
shifted_phrases[:10]

['znttvr ybbx jungf gung',
 'yrr zhe yrr zhe',
 'mrr obb mrr obb',
 'vz gelvat gb grnpu znttvr gung angher qbrfag raq jvgu gur onealneq v jnag ure gb unir nyy gur nqinagntrf gung v qvqag unir',
 'vgf yvxr na bk bayl vg unf n uhzc naq n qrjync uhzc naq qrj ync uhzc naq qrj ync',
 'lbh xabj uvf oybbq glcr ubj ebznagvp',
 'bu lrnu jungf zl fubr fvmr',
 'evat',
 'lrf qnq',
 'bbu ybbx znttvr jung vf gung qb qrp nu rqeba qbqrpnurqeba']

In [14]:
max_len = 50 
Y = torch.zeros([len(phrases_cleaned), max_len], dtype=int)
for i, sentence in enumerate(phrases_cleaned):
    for j, num in enumerate([CHAR_TO_INDEX.get(letter, CHAR_TO_INDEX['none']) for letter in sentence[:50] if isinstance(letter, str)]):
        Y[i, j] = num

In [15]:
max_len = 50 
X = torch.zeros([len(shifted_phrases), max_len], dtype=int)
for i, sentence in enumerate(shifted_phrases):
    for j, num in enumerate([CHAR_TO_INDEX.get(letter, CHAR_TO_INDEX['none']) for letter in sentence[:50] if isinstance(letter, str)]):
        X[i, j] = num

In [16]:
class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.embed = nn.Embedding(len(CHAR_TO_INDEX), 28)
        self.rnn = nn.RNN(28, 128, batch_first=True)
        self.linear = nn.Linear(128, len(CHAR_TO_INDEX))
        
    def forward(self, sentences, state=None):
        embed = self.embed(sentences)
        output, hn = self.rnn(embed)
        out = self.linear(output)
        return out 

In [17]:
model = Network()

In [18]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)

In [19]:
for ep in range(10 + 1):
    start = time.time()
    train_loss = 0.
    train_passed = 0 
    
    for i in range(int(len(X)/ 100)):
        batch_X = X[i*100:(i+1)*100]
        batch_Y = Y[i*100:(i+1)*100]
        X_batch = batch_X[:, :]
        Y_batch = batch_Y[:, :].flatten()
        
        optimizer.zero_grad()
        
        answers = model(X_batch)
        answers = answers.view(-1, len(INDEX_TO_CHAR))
        loss = criterion(answers, Y_batch)
        train_loss += loss.item()
        
        loss.backward()
        optimizer.step()
        train_passed += 1
        
    print(f'\nEpoch {ep}, Time: {time.time() - start:.3f}, Train loss: {train_loss / train_passed}')


Epoch 0, Time: 2.282, Train loss: 1.3573947212210409

Epoch 1, Time: 2.255, Train loss: 0.46284454067548114

Epoch 2, Time: 2.250, Train loss: 0.22225435381686245

Epoch 3, Time: 2.250, Train loss: 0.13364730161373262

Epoch 4, Time: 2.265, Train loss: 0.09236302823518161

Epoch 5, Time: 2.247, Train loss: 0.07025524749662038

Epoch 6, Time: 2.239, Train loss: 0.05717332650803857

Epoch 7, Time: 2.239, Train loss: 0.048726750227312245

Epoch 8, Time: 2.248, Train loss: 0.042856611863330556

Epoch 9, Time: 2.250, Train loss: 0.03853362096749522

Epoch 10, Time: 2.283, Train loss: 0.03520585175741602


In [20]:
text = """
species of flowering plants with showy flowers. It takes its name from the Greek word for a rainbow, which is also the name for the Greek goddess of the rainbow, Iris. Some authors state that the name refers to the wide variety of flower colors found among the many species.[3] As well as being the scientific name, iris is also widely used as a common name for all Iris species, as well as some belonging to other closely related genera. A common name for some species is 'flags', while the plants of the subgenus Scorpiris are widely known as 'junos', particularly in horticulture. It is a popular garden flower.
"""

text_norm = ''.join([letter for letter in text.lower().strip() if letter in INDEX_TO_CHAR])
test_phrase = ''.join([orig_shift.get(i, ' ') for i in text_norm])

In [21]:
max_len = 50 
X_test = torch.zeros([1, max_len], dtype=int)
for j, num in enumerate([CHAR_TO_INDEX.get(letter, CHAR_TO_INDEX['none']) for letter in test_phrase]):
    if j >= max_len:
        break
    X_test[0, j] = num

In [22]:
for i, j in zip( X_test[0], torch.argmax(model(X_test)[0], axis=1) ):
    print(INDEX_TO_CHAR[i], INDEX_TO_CHAR[j])

f s
c p
r e
p c
v i
r e
f s
   
b o
s f
   
s f
y l
b o
j w
r e
e r
v i
a n
t g
   
c p
y l
n a
a n
g t
f s
   
j w
v i
g t
u h
   
f s
u h
b o
j w
l y
   
s f
y l
b o
j w
r e
e r
f s
   
v i
g t
   


In [26]:
orig_text = text_norm[:50]
predicted_text = ''.join([INDEX_TO_CHAR[j] for j in torch.argmax(model(X_test)[0], axis=1)])

In [29]:
print('\n', orig_text, '\n', predicted_text)


 species of flowering plants with showy flowers it  
 species of flowering plants with showy flowers it 


In [30]:
orig_text == predicted_text 

True