In [None]:
import torch
import torch.nn as nn
import numpy as np

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
dev = torch.device(_device_name)
print(dev)


class LongShortTermMemoryModel(nn.Module):
    """Single-layer LSTM followed by a dense layer mapping the state back to the encoding.

    The recurrent state (``hidden_state``, ``cell_state``) is kept on the
    instance between calls, so consecutive calls to ``logits``/``f``/``loss``
    continue the same sequence until ``reset`` is called.

    Args:
        encoding_size: Size of each input encoding and of each output row.
        state_size: Size of the LSTM hidden and cell state (default 128).
    """

    def __init__(self, encoding_size, state_size=128):
        super().__init__()

        self.state_size = state_size
        self.lstm = nn.LSTM(encoding_size, state_size)
        self.dense = nn.Linear(state_size, encoding_size)

        # Start with a valid state so that calling logits() before reset()
        # does not fail with AttributeError.
        self.reset()

    def reset(self, batch_size=1):
        """Reset the recurrent state to zeros prior to a new input sequence.

        Args:
            batch_size: Batch size of the sequences that will follow (default 1).
        """
        # Allocate directly on the device the parameters live on, instead of
        # on the CPU followed by a transfer on every forward pass.
        device = self.dense.weight.device
        # Shape: (number of layers, batch size, state size)
        self.hidden_state = torch.zeros(1, batch_size, self.state_size, device=device)
        self.cell_state = torch.zeros(1, batch_size, self.state_size, device=device)

    def logits(self, x):
        """Run ``x`` through the LSTM and dense layer, updating the stored state.

        Args:
            x: Tensor of shape (sequence length, batch size, encoding size).

        Returns:
            Tensor of shape (sequence length * batch size, encoding size).
        """
        # The model may have been moved with .to(...) after the state was
        # created; follow the input's device (a no-op when already there).
        state = (self.hidden_state.to(x.device), self.cell_state.to(x.device))
        out, (self.hidden_state, self.cell_state) = self.lstm(x, state)
        return self.dense(out.reshape(-1, self.state_size))

    def f(self, x):
        """Return the softmax over the encoding dimension of ``logits(x)``.

        Args:
            x: Tensor of shape (sequence length, batch size, encoding size).
        """
        return torch.softmax(self.logits(x), dim=1)

    def loss(self, x, y):
        """Return the cross-entropy between ``logits(x)`` and one-hot targets ``y``.

        Args:
            x: Tensor of shape (sequence length, batch size, encoding size).
            y: Tensor of shape (sequence length, encoding size); the target
                class of each row is taken as its argmax.
        """
        return nn.functional.cross_entropy(self.logits(x), y.argmax(1))


index_to_char = [' ', 'h', 'e', 'l', 'o', 'w', 'r', 'd']
char_encodings = np.eye(len(index_to_char))  # Row i is the one-hot encoding of index_to_char[i]
# Must come after char_encodings is defined; the reverse order raises
# NameError when the file is run on a fresh interpreter.
encoding_size = len(char_encodings)

# Build the training pair from the text itself: the target is the input
# shifted one character to the left, padded with a trailing space.
char_to_index = {char: index for index, char in enumerate(index_to_char)}
train_text = ' hello world'
input_indices = [char_to_index[char] for char in train_text]
target_indices = input_indices[1:] + [char_to_index[' ']]

# x_train shape: (sequence length, batch size = 1, encoding size)
x_train = torch.tensor(char_encodings[input_indices], dtype=torch.float).unsqueeze(1).to(dev)  # ' hello world'
# y_train shape: (sequence length, encoding size)
y_train = torch.tensor(char_encodings[target_indices], dtype=torch.float).to(dev)  # 'hello world '

model = LongShortTermMemoryModel(encoding_size).to(dev)


def _encode_char(index):
    """Return the one-hot encoding of character ``index`` with shape (1, 1, encoding size) on ``dev``."""
    return torch.tensor(char_encodings[index], dtype=torch.float).reshape(1, 1, -1).to(dev)


optimizer = torch.optim.RMSprop(model.parameters(), 0.001)
for epoch in range(500):
    model.reset()
    model.loss(x_train, y_train).backward()
    optimizer.step()
    optimizer.zero_grad()

    if epoch % 10 == 9:
        # Generate characters from the initial characters ' h'
        model.reset()
        text = ' h'
        # No gradients are needed while sampling; this avoids building an
        # autograd graph across all generated steps.
        with torch.no_grad():
            model.f(_encode_char(0))
            y = model.f(_encode_char(1))
            next_index = y.argmax(1).item()
            text += index_to_char[next_index]
            for _ in range(50):
                # Feed the most likely character back in as the next input.
                y = model.f(_encode_char(next_index))
                next_index = y.argmax(1).item()
                text += index_to_char[next_index]
        print(text)




cuda
 hllloo                                              
 hlllo wrll                                          
 hlllo world    dd    d                              
 hello world   rld    rdd   rld   rdd    rdd   rld   
 hello world   rld   rld   rrdd   rld   rld   rld   r
 hello world  wrld   rld   rld   rrld  wrld  wrld   r
 hello world  wrld  wrld  wrld  wrld  wrld  wrrdd  rr
 hello world  wrld  wrld  wrld  wrld  wrld  wrld  wrr
 hello world  wrld  wrld  wrld  wrld  wrld  wrld  wrr
 hello world  wrld  wrld  wrld  wrld  wrld  wrld  wrl
 hello world  wrld  wrld  wrld  wrld  wrld  wrld  wrl
 hello world world  wrld  wrld  wrld  wrld  wrld  wrl
 hello world world world  wrld  wrld  wrld  wrld  wrl
 hello world world world world  wrld  wrld  wrld  wrl
 hello world world world world world world  wrld  wrl
 hello world world world world world world world worl
 hello world world world world world world world worl
 hello world world world world world world world worl
 hello world world worl