# Task A - Many to many

In [60]:
import torch
import torch.nn as nn

## Training data

In [61]:
# Vocabulary: a character's position in this list is its class index.
index_to_char = [' ', 'h', 'e', 'l', 'o', 'w', 'r', 'd']
encoding_size = len(index_to_char)

# One-hot table: row i holds 1.0 in column i and 0.0 everywhere else,
# so char_encodings[i] is the encoding of index_to_char[i].
char_encodings = [
    [1. if column == row else 0. for column in range(encoding_size)]
    for row in range(encoding_size)
]

In [72]:
# Input sequence ' hello world': each character is looked up in the vocabulary
# and wrapped in a length-1 batch dimension, giving shape
# (sequence length, batch size, encoding size) = (12, 1, 8).
x_train = torch.tensor(
    [[char_encodings[index_to_char.index(ch)]] for ch in ' hello world']
)

In [63]:
# Target sequence 'hello world ': the input shifted one character to the left,
# one one-hot row per time step, shape (sequence length, encoding size) = (12, 8).
y_train = torch.tensor(
    [char_encodings[index_to_char.index(ch)] for ch in 'hello world ']
)

## Model definition

In [64]:
class LongShortTermMemoryModel(nn.Module):
    """Single-layer LSTM followed by a dense layer that maps each state to character logits.

    The recurrent state is kept on the instance (``hidden_state`` / ``cell_state``)
    and is carried over between calls, so feeding a sequence in several calls
    continues where the previous call stopped. Call ``reset()`` before starting
    a new, unrelated sequence.

    Args:
        encoding_size: Length of one one-hot encoded character (number of classes).
        state_size: Size of the LSTM hidden and cell state. Defaults to 128.
    """

    def __init__(self, encoding_size, state_size=128):
        super(LongShortTermMemoryModel, self).__init__()

        self.state_size = state_size
        self.lstm = nn.LSTM(encoding_size, state_size)
        self.dense = nn.Linear(state_size, encoding_size)

        # Start with a valid zero state so the model is usable right after construction.
        self.reset()

    def reset(self, batch_size=1):
        """Zero the recurrent state prior to a new input sequence.

        Args:
            batch_size: Batch size of the sequences that will be fed next. Defaults to 1.
        """
        reference = self.dense.weight  # create the state on the parameters' device/dtype
        state_shape = (1, batch_size, self.state_size)  # (number of layers, batch size, state size)
        # Two separate tensors: hidden and cell state must not alias each other.
        self.hidden_state = torch.zeros(state_shape, dtype=reference.dtype, device=reference.device)
        self.cell_state = torch.zeros(state_shape, dtype=reference.dtype, device=reference.device)

    def logits(self, x):
        """Run ``x`` through the LSTM and dense layer, updating the stored state.

        Args:
            x: Tensor of shape (sequence length, batch size, encoding size).

        Returns:
            Tensor of shape (sequence length * batch size, encoding size) with
            unnormalised class scores, one row per time step and batch entry.
        """
        out, (self.hidden_state, self.cell_state) = self.lstm(x, (self.hidden_state, self.cell_state))
        return self.dense(out.reshape(-1, self.state_size))

    def f(self, x):
        """Return class probabilities (softmax over the logits) for every row of ``logits(x)``.

        Args:
            x: Tensor of shape (sequence length, batch size, encoding size).
        """
        return torch.softmax(self.logits(x), dim=1)

    def loss(self, x, y):
        """Cross-entropy between the predictions for ``x`` and the one-hot targets ``y``.

        Args:
            x: Tensor of shape (sequence length, batch size, encoding size).
            y: One-hot targets of shape (sequence length, encoding size); they are
                converted to class indices with ``argmax`` before the loss is computed.
        """
        return nn.functional.cross_entropy(self.logits(x), y.argmax(1))

In [65]:
# Seed PyTorch's random number generator before the weights are created so that
# parameter initialisation is the same on every Restart & Run All.
torch.manual_seed(0)
model = LongShortTermMemoryModel(encoding_size)

## Model training

In [66]:
# Optimisation hyperparameters
epochs = 500  # number of training iterations over the single training sequence
learning_rate = 1e-3  # step size handed to the optimizer

In [67]:
# RMSprop over every trainable parameter of the model (LSTM and dense layer).
optimizer = torch.optim.RMSprop(
    params=model.parameters(),
    lr=learning_rate,
)

In [68]:
for epoch in range(epochs):
    # One optimisation step on the full training sequence, starting from a zero state.
    model.reset()
    model.loss(x_train, y_train).backward()
    optimizer.step()
    optimizer.zero_grad()

    if epoch % 10 == 9:
        # Every 10th epoch: generate characters from the initial characters ' h'
        model.reset()
        text = ' h'
        with torch.no_grad():  # sampling only: do not build autograd graphs
            model.f(torch.tensor([[char_encodings[0]]]))  # feed ' ' to advance the state
            y = model.f(torch.tensor([[char_encodings[1]]]))  # feed 'h'; y predicts the next character
            next_index = y.argmax(1).item()  # plain int, safe for indexing Python lists
            text += index_to_char[next_index]
            for step in range(50):
                # Feed the model's own most likely prediction back in as the next input
                y = model.f(torch.tensor([[char_encodings[next_index]]]))
                next_index = y.argmax(1).item()
                text += index_to_char[next_index]
        print(text)

 hlllooolddd d d d d d d d d d d d d d d d d d d d d 
 hlllo world      ll     ll     ll     ll     ll     
 hlllo world    rld    rld    rld    rld    rld    rl
 hello world   wrld   wrld   rrld   rlld  wrld   wrld
 hello world   rrld  wrld   wrld  wrrld  wrld   wrld 
 hello world  wrrld  wrld  world  wrld  world  wrld  
 hello world  wrld  world  wrld  world  wrld  world  
 hello world  wrld  world  wrld  world  wrld  world  
 hello world  wrld  world  wrld  world  wrld  world  
 hello world  wrld  world  wrld  wrlld world  wrld  w
 hello world  wrld  wrlld world  wrld  world world  w
 hello world  wrld  wrll  world  wrld  world world  w
 hello world  wrld  wrll  world world  wrld  world wo
 hello world  wrld  wrld  world world  wrld  world wo
 hello world world  wrld  wolld world world  wrld  wo
 hello world world  wrld  wrld  world world world  wr
 hello world world world  wrld  wolld world world wor
 hello world world world world world  wrld  wolld wor
 hello world world world wor

## Model testing 

In [74]:
# Greedily generate 50 more characters after priming the trained model with ' h'.
model.reset()
text = ' h'
with torch.no_grad():  # inference only: do not build autograd graphs
    model.f(torch.tensor([[char_encodings[0]]]))  # feed ' ' to advance the state
    y = model.f(torch.tensor([[char_encodings[1]]]))  # feed 'h'; y predicts the next character
    next_index = y.argmax(1).item()  # plain int, safe for indexing Python lists
    text += index_to_char[next_index]

    for step in range(50):
        # Feed the model's own most likely prediction back in as the next input
        y = model.f(torch.tensor([[char_encodings[next_index]]]))
        next_index = y.argmax(1).item()
        text += index_to_char[next_index]

print(text)

 hello world world world world world world world worl
