In [34]:
import torch
import torch.nn as nn

# Vocabulary: position k in this list is the index of that character's one-hot encoding.
index_to_char = [' ', 'h', 'a', 't', 'r', 'c', 'f', 'l', 'm', 'p', 's', 'o', 'n']

# One-hot encoding per character (identity matrix): char_encodings[k] encodes index_to_char[k].
char_encodings = torch.eye(len(index_to_char)).tolist()
char_encoding_size = len(char_encodings[0])  # width of one input vector


def encode_word(word):
    """Encode ``word`` as nested lists of shape (len(word), 1, char_encoding_size).

    Each character becomes one time step with batch size 1. Raises ValueError
    for characters that are not in ``index_to_char``.
    """
    return [[char_encodings[index_to_char.index(ch)]] for ch in word]


# Training words; every word is 4 characters long, shorter ones end with a space.
hat = encode_word('hat ')
rat = encode_word('rat ')
cat = encode_word('cat ')
flat = encode_word('flat')
matt = encode_word('matt')
cap = encode_word('cap ')
son = encode_word('son ')

emojis = {
    'hat': '\U0001F3A9',
    'cat': '\U0001F408',
    'rat': '\U0001F400',
    'flat': '\U0001F3E2',
    'matt': '\U0001F468',
    'cap': '\U0001F9E2',
    'son': '\U0001F466'
}

# Output classes: label k means index_to_emoji[k].
index_to_emoji = [emojis['hat'], emojis['rat'], emojis['cat'], emojis['flat'], emojis['matt'], emojis['cap'], emojis['son']]

# One-hot encoding per output class: emoji_encoding[k] encodes index_to_emoji[k]
# (order: hat, rat, cat, flat, matt, cap, son).
emoji_encoding = torch.eye(len(index_to_emoji)).tolist()
emoji_encoding_size = len(emoji_encoding)


def label_sequence(label_index, length):
    """Target for one word: emoji_encoding[label_index] repeated for each of ``length`` time steps."""
    return [emoji_encoding[label_index] for _ in range(length)]


# Targets are emoji encodings (width emoji_encoding_size), indexed like index_to_emoji.
hat_y = label_sequence(0, len(hat))
rat_y = label_sequence(1, len(rat))
cat_y = label_sequence(2, len(cat))
flat_y = label_sequence(3, len(flat))
matt_y = label_sequence(4, len(matt))
cap_y = label_sequence(5, len(cap))
son_y = label_sequence(6, len(son))

# Inputs for all words: shape (number of words, sequence length, batch size 1, char_encoding_size).
x_train = torch.tensor([hat,
                        rat,
                        cat,
                        flat,
                        matt,
                        cap,
                        son])
# Targets for all words: shape (number of words, sequence length, emoji_encoding_size).
y_train = torch.tensor([hat_y,
                        rat_y,
                        cat_y,
                        flat_y,
                        matt_y,
                        cap_y,
                        son_y])

test_words_list = ["rt", "rats", "sn", "at", "cat", "hats"]

In [35]:
class LongShortTermMemoryModel(nn.Module):
    """Single-layer LSTM whose per-step output is mapped to class logits by a linear layer.

    Args:
        encoding_size: width of each input vector fed to the LSTM.
        label_size: number of output classes.
        state_size: size of the LSTM hidden/cell state (default 128).
    """

    def __init__(self, encoding_size, label_size, state_size=128):
        super(LongShortTermMemoryModel, self).__init__()

        self.state_size = state_size
        self.lstm = nn.LSTM(encoding_size, state_size)
        self.dense = nn.Linear(state_size, label_size)

    def reset(self, batch_size=1):
        """Zero the recurrent state; call before feeding a new input sequence."""
        weight = self.dense.weight
        # Shape: (number of layers, batch size, state size). Match the model's dtype/device.
        self.hidden_state = torch.zeros(1, batch_size, self.state_size,
                                        dtype=weight.dtype, device=weight.device)
        # Separate tensor so hidden and cell state never alias each other.
        self.cell_state = torch.zeros_like(self.hidden_state)

    def logits(self, x):
        """Run the LSTM on ``x`` and return logits of shape (sequence length * batch size, label_size).

        x shape: (sequence length, batch size, encoding size). The final hidden and
        cell states are stored, so consecutive calls continue the same sequence
        until ``reset`` is called.
        """
        out, (self.hidden_state, self.cell_state) = self.lstm(x, (self.hidden_state, self.cell_state))
        return self.dense(out.reshape(-1, self.state_size))

    def f(self, x):
        """Class probabilities (softmax over classes) for every time step of ``x``."""
        return torch.softmax(self.logits(x), dim=1)

    def loss(self, x, y):
        """Cross-entropy loss.

        x shape: (sequence length, batch size, encoding size).
        y is one-hot with shape (sequence length * batch size, label_size);
        its argmax gives the class index for each row.
        """
        return nn.functional.cross_entropy(self.logits(x), y.argmax(1))

In [36]:
# Fixed seed so weight initialization, and therefore the printed predictions, is reproducible.
torch.manual_seed(0)
model = LongShortTermMemoryModel(char_encoding_size, emoji_encoding_size)

optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001)


def test_words(word_list):
    """Print the global ``model``'s emoji prediction for each word in ``word_list``.

    Each word is encoded with ``char_encodings`` and fed to the model as one
    sequence with batch size 1. The prediction is the most probable class after
    the word's last character. Empty words produce no output. A character that
    is not in ``index_to_char`` raises ValueError.
    """
    with torch.no_grad():  # inference only: no need to build an autograd graph
        for word in word_list:
            if not word:
                continue
            model.reset()
            x = torch.tensor([[char_encodings[index_to_char.index(ch)]] for ch in word])
            probabilities = model.f(x)  # one row of class probabilities per character
            emoji_index = probabilities[-1].argmax().item()
            print("word: %s prediction: %s" % (word, index_to_emoji[emoji_index]))


# Train on one word (sequence) at a time; every 100 epochs, print predictions for test words.
for epoch in range(501):
    for x_word, y_word in zip(x_train, y_train):
        model.reset()  # every word starts from a zero state
        loss = model.loss(x_word, y_word)
        optimizer.zero_grad()  # clear any stale gradients before computing new ones
        loss.backward()
        optimizer.step()

    if epoch % 100 == 0:
        print("Epoch: %s" % epoch)
        test_words(test_words_list)
        print("\n")
        print("\n")

Epoch: 0
word: rt prediction: 🎩
word: rats prediction: 🎩
word: sn prediction: 🎩
word: at prediction: 🎩
word: cat prediction: 🎩
word: hats prediction: 🎩


Epoch: 100
word: rt prediction: 🐀
word: rats prediction: 🐀
word: sn prediction: 👦
word: at prediction: 👨
word: cat prediction: 🐈
word: hats prediction: 🎩


Epoch: 200
word: rt prediction: 🐀
word: rats prediction: 🐀
word: sn prediction: 👦
word: at prediction: 🐀
word: cat prediction: 🐈
word: hats prediction: 🎩


Epoch: 300
word: rt prediction: 🐀
word: rats prediction: 🐀
word: sn prediction: 👦
word: at prediction: 🎩
word: cat prediction: 🐈
word: hats prediction: 🎩


Epoch: 400
word: rt prediction: 🐀
word: rats prediction: 🐀
word: sn prediction: 👦
word: at prediction: 🐀
word: cat prediction: 🐈
word: hats prediction: 🎩


Epoch: 500
word: rt prediction: 🐀
word: rats prediction: 🐀
word: sn prediction: 👦
word: at prediction: 🐀
word: cat prediction: 🐈
word: hats prediction: 🎩


