In [2]:
import torch
import torch.nn as nn
import numpy as np

In [3]:
class LongShortTermMemoryModel(nn.Module):
    """Single-layer LSTM followed by a linear layer that scores emoji classes.

    The recurrent state is kept on the instance between calls, so a sequence
    can be fed one step at a time. Call `reset()` before every new sequence.
    """

    def __init__(self, encoding_size, emoji_size, state_size=128):
        """Build the network.

        Args:
            encoding_size: Length of each input (character) encoding vector.
            emoji_size: Number of output classes (emojis).
            state_size: Size of the LSTM hidden/cell state. Defaults to 128.
        """
        super(LongShortTermMemoryModel, self).__init__()

        self.state_size = state_size
        self.lstm = nn.LSTM(encoding_size, state_size)
        self.dense = nn.Linear(state_size, emoji_size)

        # Start from a defined state so logits() works even before an explicit reset()
        self.reset()

    def reset(self):  # Reset states prior to new input sequence
        """Zero the hidden and cell state (batch size is fixed to 1)."""
        # Match the parameters' dtype/device so the states stay compatible with
        # the layers after e.g. model.to(device) or model.double()
        reference = self.dense.weight
        shape = (1, 1, self.state_size)  # Shape: (number of layers, batch size, state size)
        # Two separate tensors: hidden and cell state must not share storage
        self.hidden_state = torch.zeros(shape, dtype=reference.dtype, device=reference.device)
        self.cell_state = torch.zeros(shape, dtype=reference.dtype, device=reference.device)

    def logits(self, x):  # x shape: (sequence length, batch size, encoding size)
        """Run `x` through the LSTM, continuing from the stored state.

        The stored hidden/cell state is replaced by the state after the last step.

        Returns:
            Tensor of shape (sequence length * batch size, emoji size) with
            unnormalised class scores, one row per time step.
        """
        out, (self.hidden_state, self.cell_state) = self.lstm(x, (self.hidden_state, self.cell_state))
        return self.dense(out.reshape(-1, self.state_size))

    def f(self, x):  # x shape: (sequence length, batch size, encoding size)
        """Return per-step class probabilities (softmax over the emoji dimension)."""
        return torch.softmax(self.logits(x), dim=1)

    def loss(self, x, y):  # x shape: (sequence length, batch size, encoding size), y shape: (sequence length, emoji size)
        """Cross-entropy between the per-step logits and the one-hot targets `y`.

        `y` is converted to class indices with argmax over dimension 1.
        """
        return nn.functional.cross_entropy(self.logits(x), y.argmax(1))

In [38]:
chars = [" ", "h", "a", "t", "r", "c", "f", "l"]

# One-hot encode the alphabet: row k has a single 1.0 in column k,
# so char_encodings[k] is the encoding of chars[k].
alphabet_size = len(chars)
char_encodings = [
    [1.0 if column == row else 0.0 for column in range(alphabet_size)]
    for row in range(alphabet_size)
]

char_encodings

[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
 [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]]

In [32]:
# Training vocabulary: each word maps to the emoji the model should predict.
# Insertion order matters: later cells derive class indices from it.
emojis = dict(
    [
        ("hat", "🎩"),
        ("rat", "🐀"),
        ("cat", "🐈"),
        ("flat", "🏢"),
    ]
)
for word, emoji in emojis.items():
    print(word, "->", emoji)

hat -> 🎩
rat -> 🐀
cat -> 🐈
flat -> 🏢


In [33]:
values = emojis.values()

# One-hot encode the emoji classes: encodings[k] marks the k-th emoji
# in the insertion order of `emojis`.
emoji_count = len(values)
encodings = [
    [float(position == target) for position in range(emoji_count)]
    for target in range(emoji_count)
]
encodings

[[1.0, 0.0, 0.0, 0.0],
 [0.0, 1.0, 0.0, 0.0],
 [0.0, 0.0, 1.0, 0.0],
 [0.0, 0.0, 0.0, 1.0]]

In [34]:
# Indices into `char_encodings`, one row per training word (4 characters each).
char_index_rows = [
    [1, 2, 3, 0],  # "hat "
    [4, 2, 3, 0],  # "rat "
    [5, 2, 3, 0],  # "cat "
    [6, 7, 2, 3],  # "flat"
]

# Shape: (word, sequence length, batch size = 1, encoding size)
x_train = torch.tensor(
    [[[char_encodings[index]] for index in row] for row in char_index_rows],
    dtype=torch.float,
)

# The target emoji is repeated for every time step of its word.
# Shape: (word, sequence length, emoji size)
y_train = torch.tensor(
    [[encodings[label]] * 4 for label in range(4)],
    dtype=torch.float,
)

In [35]:
# Seed before the model is built so the random weight initialisation, and
# therefore the trained result, is reproducible on Restart & Run All.
torch.manual_seed(0)

model = LongShortTermMemoryModel(len(char_encodings), len(encodings))
optimizer = torch.optim.RMSprop(model.parameters(), 0.001)  # learning rate 0.001
for epoch in range(500):
    # One optimisation step per training word; iterate the paired tensors
    # instead of a hard-coded count so the loop follows the data size.
    for x_sequence, y_sequence in zip(x_train, y_train):
        model.reset()  # every word starts from a zeroed LSTM state
        model.loss(x_sequence, y_sequence).backward()
        optimizer.step()
        optimizer.zero_grad()

In [36]:
# (word, emoji) pairs in the insertion order of `emojis`, i.e. the same order
# the one-hot emoji classes were built in.
emoji_values = list(emojis.items())


def predict(x: str):
    """Print the (word, emoji) pair the model predicts after reading `x`.

    The string is fed one character at a time from a freshly reset state and
    the prediction made after the last character is printed.

    Args:
        x: Non-empty string consisting only of characters present in `chars`.

    Raises:
        ValueError: If `x` is empty, or contains a character that is not in
            `chars` (raised by `chars.index`).
    """
    if not x:
        # Without any input there is no model output to take the argmax of
        raise ValueError("predict() needs at least one character")

    model.reset()
    with torch.no_grad():  # inference only: no autograd graph needed
        for char in x:
            y = model.f(torch.tensor([[char_encodings[chars.index(char)]]], dtype=torch.float))
    # y has a single row (one time step, batch size 1); .item() yields a plain int index
    print(emoji_values[y.argmax(1).item()])

In [37]:
# Try full training words, prefixes of them, and one word with an extra letter.
for sample in ["hat", "ha", "cat", "catt", "ca"]:
    predict(sample)

('hat', '🎩')
('hat', '🎩')
('cat', '🐈')
('cat', '🐈')
('cat', '🐈')
