In [20]:
import torch, os, re
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
#tensorboard --logdir 'runs\RNN_torch' --host localhost --port 8888
from collections import Counter
from typing import List
import datetime
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from IPython.display import clear_output

In [21]:
data = open(os.path.join('data_processing','media','tinyshakespeare.txt'), 'r').read() # should be simple plain text file
data = re.sub(r'[^a-zA-Z\s]', '', data) # remove non-alphabet
seq_length = 8
words_all = data.split()
words = list(set([w for w in words_all if len(w) == seq_length]))
chars_all = [c for w in words for c in w] #[*data]
chars = list(set(chars_all))  #list(set(data))
data_size, vocab_size = len(words), len(chars)
print('data has %d words of size %d, and %d unique characters.' % (len(words), seq_length, vocab_size))
'Chars: ' + ' '.join(chars)

data has 1975 words of size 8, and 50 unique characters.


'Chars: g D P m L A K F z j p U Z h u N e J k Y B r x C f o c d v t G M n i T I q l H W O E s y a R w b V S'

In [22]:
frequency = dict(Counter(chars_all))
vocab_unique = list(sorted(frequency.keys(), key=lambda x: frequency[x], reverse=True))
char_to_ix = { ch:i for i,ch in enumerate(vocab_unique) }
ix_to_char = { i:ch for i,ch in enumerate(vocab_unique) }
'Sorted chars: ' + ' '.join(vocab_unique)

'Sorted chars: e s r i n t a o d l c u g h p m f b y v w k C S A R M I q E B T D P L O j x N U H F z W V G K J Y Z'

In [23]:
def oneHotEncode(chars: List[str]) -> torch.Tensor:
    """
        lets hold encoded vectors as columns
    """
    encode = torch.zeros(len(chars),vocab_size)
    x,y = torch.tensor([(i,char_to_ix[a]) for i,a in enumerate(chars)]).T
    encode[x,y] = 1
    return encode.T
split_word = [*words[0]];print(split_word)
word_index = [char_to_ix[a] for a in split_word];print(word_index)
oneHotEncode(split_word)[:10,:].to(int)

['d', 'i', 'r', 'e', 'c', 't', 'l', 'y']
[8, 3, 2, 0, 10, 5, 9, 18]


tensor([[0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 1, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0]])

In [24]:
def i2s(inp):
    # index to string
    pos = torch.topk(inp, dim = 1, k = 1)[1].view(-1)
    return ''.join([ix_to_char[int(i)] for i in pos]) 

In [25]:
# changed char columns to rows
num_words_max = 40
X_train = torch.zeros(size=(num_words_max, seq_length, vocab_size), device=device) # words, chars, seq of chars
for i in range(num_words_max):
    X_train[i] = oneHotEncode(words[i]).T

X_train.shape

torch.Size([40, 8, 50])

In [26]:
class wordsDataset(Dataset):
    def __init__(self, data_in):
        self.x = data_in

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return (self.x[idx])[:-1], (self.x[idx])[1:]

dataset = wordsDataset(X_train)
BATCH_SIZE = 1
data_loader_train   = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [27]:
class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(SimpleRNN, self).__init__()
        self.hidden_size= hidden_size
        self.num_layers = num_layers
        self.rnn        = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        self.fc         = nn.Linear(hidden_size, output_size)
                
    def forward(self, x):
        hidden =  torch.zeros(1, self.hidden_size).to(device)#x.size(0)
        out, h = self.rnn(x, hidden)
        y = torch.softmax(self.fc(out),dim=1)
        #print(torch.allclose( out[[-1]], h )) # >>> True
        return y, h
    
    @torch.no_grad
    def predict(self, x, hidden ):
        out, h = self.rnn(x, hidden)
        y = torch.softmax(self.fc(out),dim=1)
        return y, h

hidden_size = 150 
net = SimpleRNN(input_size = vocab_size, hidden_size = hidden_size, 
                num_layers = 1, output_size = vocab_size).to(device)

In [32]:
num_epochs = 15301
criterion = nn.CrossEntropyLoss()#nn.MSELoss()
optimizer = Adam(net.parameters(), lr = 1e-3)#, weight_decay=1e-5)
now = datetime.datetime.now()
s2 = now.strftime("%H_%M_%S")
writer = SummaryWriter(fr'runs/RNN_torch/{s2}')
step = 0
top = f'{"left":^{seq_length}}|{"right":^{seq_length}}|{"guess":^{seq_length+1}}'
top2 = '-'*len(top)
for i in range(num_epochs):
    
    loss2 = 0
    if i % (num_epochs//5) == 0:
        print(top);print(top2)
    for x,y in data_loader_train:
        x = x.squeeze(0)    # (seq_len - 1, len_vocab)
        y = y.squeeze(0)
        output, hidden = net(x)
        loss   = criterion(output, y)
        loss2 += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % (num_epochs//5) == 0:
            ss = i2s(x) + '_|_' + i2s(y) + '|_'  + i2s(output) + '|' 
            print(ss)
            
    if i % (num_epochs//5) == 0:
        print(f'epoch: {str(i):<4}, loss: {loss2:0.3f}', end="")
        clear_output(wait=True)
        
    writer.add_scalar('Training Loss', loss2, global_step=step)
    step += 1

  left  | right  |  guess  
---------------------------
Childre_|_hildren|_oildren|
singula_|_ingular|_onguVar|
prettil_|_rettily|_rottily|
chaplai_|_haplain|_roplain|
enointe_|_nointed|_nointed|
Benvoli_|_envolio|_envJlio|
Conside_|_onsider|_onsider|
unsever_|_nseverd|_nsqverd|
meeting_|_eetings|_eetings|
Bequeat_|_equeath|_engmjth|
returns_|_eturnst|_emfrnst|
remembe_|_emember|_emember|
directl_|_irectly|_erectly|
bruisin_|_ruising|_rwIsing|
Flander_|_landers|_landers|
quaintl_|_uaintly|_uaintly|
LUCENTI_|_UCENTIO|_UCENTIO|
pickloc_|_icklock|_rclloJk|
misgive_|_isgives|_eUgives|
ravenou_|_avenous|_ebenous|
soldier_|_oldiers|_oldiers|
publicl_|_ublicly|_rblicly|
Howling_|_owlings|_owlings|
overcom_|_vercome|_uercgme|
grinnin_|_rinning|_renning|
thinkin_|_hinking|_hinking|
greates_|_reatest|_renteZt|
Standin_|_tanding|_tanding|
delicat_|_elicate|_elicate|
outstri_|_utstrip|_utsAOip|
forcefu_|_orceful|_orceful|
Welshma_|_elshman|_elsHman|
Beaumon_|_eaumond|_envmond|
standin_|_tanding|_o

Supply RNN **first_n_chars** first characters of a word to generate a hidden state.<br>
It will generate first prediction. Use hidden state to iterate forward and update state.

In [33]:
# changed from original in 'neural_networks/RNN_recurrent-NN/RNN_from_scratch.ipynb'
def test(input, num_words = 5, first_n_chars = 1):
    shuffle_idx = torch.randperm(len(input))[:num_words]
    for i in shuffle_idx:
        word = input[i]

        # store known word start in solution
        store = torch.zeros_like(word, device=device)
        store[:first_n_chars] = word[:first_n_chars]

        
        # generate hidden state for known part
        hidden =  torch.zeros(1, hidden_size).to(device)
        prediction, hidden_state = net.predict(word[:first_n_chars], hidden)
        prediction = prediction[[-1]]
        store[[first_n_chars]] = prediction
    
        word_whole = i2s(word)
        word_start = i2s(word[:first_n_chars])
        word_guess_1 = i2s(prediction)    
    
        # continue predicting next chars based on latest hidden state
        for i in range(first_n_chars + 1, store.size(0)):
            prediction, hidden_state = net.predict(prediction, hidden_state)
            store[[i]] = prediction
            #prediction, hidden_state = net.predict(store[:i], hidden_state)
            #store[[i]] = prediction[[-1]]

        outp    =  word_whole + '; guess: ('
        outp += word_start + ')' + f'({word_guess_1})'
        outp += i2s(store[first_n_chars+1:])

        print(outp)
    return
    
print('whole word; guess: (given)(1st guess)rest')
test(X_train, num_words = 5, first_n_chars = 3)
X_train.shape

whole word; guess: (given)(1st guess)rest
prettily; guess: (pre)(t)tily
ravenous; guess: (rav)(e)nous
Bequeath; guess: (Beq)(c)mond
enointed; guess: (eno)(i)nted
soldiers; guess: (sol)(d)iers


torch.Size([40, 8, 50])