In [2]:
import torch
import torch.nn as nn
import numpy as np

In [3]:
class LongShortTermMemoryModel(nn.Module):

    def __init__(self, encoding_size):
        super(LongShortTermMemoryModel, self).__init__()

        self.lstm = nn.LSTM(encoding_size, 128)  # 128 is the state size
        self.dense = nn.Linear(128, encoding_size)  # 128 is the state size

    def reset(self):  # Reset states prior to new input sequence
        zero_state = torch.zeros(1, 1, 128)  # Shape: (number of layers, batch size, state size)
        self.hidden_state = zero_state
        self.cell_state = zero_state

    def logits(self, x):  # x shape: (sequence length, batch size, encoding size)
        out, (self.hidden_state, self.cell_state) = self.lstm(x, (self.hidden_state, self.cell_state))
        return self.dense(out.reshape(-1, 128))

    def f(self, x):  # x shape: (sequence length, batch size, encoding size)
        return torch.softmax(self.logits(x), dim=1)

    def loss(self, x, y):  # x shape: (sequence length, batch size, encoding size), y shape: (sequence length, encoding size)
        return nn.functional.cross_entropy(self.logits(x), y.argmax(1))

In [18]:
chars = [' ', 'h', 'e', 'l', 'o', 'w', 'r', 'd']
char_encodings = []
for i, _ in enumerate(chars):
    encoding = [0.0 for _ in range(i)]
    encoding.append(1.0)
    for _ in range(len(chars) - i - 1):
        encoding.append(0.0)
    char_encodings.append(encoding)

char_encodings

[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
 [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]]

In [34]:
def code_char(character, nested):
    if nested:
        return [char_encodings[chars.index(character)]]
    else:
        return char_encodings[chars.index(character)]

def encode_str(string, nest=False):
    if nest:
        return [[char_encodings[chars.index(c)]] for c in string]
    else:
        return [char_encodings[chars.index(c)] for c in string]

In [35]:
x_train = torch.tensor(encode_str(" hello world", nest=True))
y_train = torch.tensor(encode_str("hello world "))

In [41]:
model = LongShortTermMemoryModel(len(chars))

optimizer = torch.optim.RMSprop(model.parameters(), 0.001)
epochs = 500
for epoch in range(epochs):
    model.reset()
    model.loss(x_train, y_train).backward()
    optimizer.step()
    optimizer.zero_grad()

    if epoch % 10 == 9:
        text = ' h'
        model.f(torch.tensor([[char_encodings[0]]]))
        y = model.f(torch.tensor([[char_encodings[1]]]))
        text += chars[y.argmax(1)]
        for c in range(50):
            y = model.f(torch.tensor([[char_encodings[y.argmax(1)]]]))
            text += chars[y.argmax(1)]
            
        print(text)

 h                                                   
 h  rdd   rrdd   rrd    rdd   rrdd   rrd    rdd   rrd
 h rrld  wrrld  wrrld  wrrld  wrrld  wrrld  wrrld  wr
 hwrrld  wrld  world  wrrld  wrrld  wrld  world  wrrl
 hwrrld world  wrrld  wrld  world  wrrld  wrld  world
 hwrld  world  wrlld world  wrrld world  wrrld world 
 horld  world world  world world  world world  world 
 horld  wrrld world  wrrld world  wrlld world  wrlld 
 horld  wrrld world  wrlld world  wrlld world  wrlld 
 horld  wrlld world world  world world  wrlld world w
 horld  wrlld world world  wrrld world world  world w
 horld  wrlld world world  wrlld world world  wrlld w
 horld world  wrrld world world  wrlld world world  w
 horld world  wrlld world world  wrlld world world  w
 horld world  wrlld world world world  wrlld world wo
 horld world world  wrlld world world world  wrlld wo
 horld world world  wrlld world world world  wrlld wo
 horld world world world  wrlld world world world wor
 horld world world world  wr