In [10]:
def evaluate(model, n):
    x, y = dg.batch_to_tensor(n)
    x, y = x.to(device), y.to(device)
    y_pred = model(x)
    for i in range(n):
        print(dg.tensor_to_string(x[i]), dg.tensor_to_string(y_pred[i]), dg.tensor_to_string(y[i]))

import random
import torch
import string
from torch import nn
from torch import optim
from matplotlib import pyplot as plt
from torch.nn.utils.rnn import pad_sequence

device = 'cuda' if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else 'cpu'
device

allowed_chars = string.digits + '+'

class Generator:
    def __init__(self) -> None:
        pass

    def sample(self):
        s1 = random.randint(100, 99999)
        s2 = random.randint(100, 99999)
        r = s1 + s2
        input = str(str(s1) + "+" + str(s2))
        output = str(r)
        return self.string_to_tensor(input), self.string_to_tensor(output)

    def batch(self, n):
        inputs = []
        outputs = []
        for _ in range(n):
            input, output = self.sample()
            inputs.append(input)
            outputs.append(output)
        return inputs, outputs

    def string_to_tensor(self, s):
        tensor = torch.zeros(len(s), len(allowed_chars))
        for i, char in enumerate(s):
            tensor[i, allowed_chars.index(char)] = 1
        return tensor

    def tensor_to_string(self, tensor):
        _, max_idx = tensor.max(1)
        return ''.join([allowed_chars[i] for i in max_idx])

    def batch_to_tensor(self, n):
        seq_in = []
        seq_out = []
        inputs, outputs = self.batch(n)
        for input, output in zip(inputs, outputs):
            seq_in.append(input)
            seq_out.append(output)
        return pad_sequence(seq_in, batch_first=True), pad_sequence(seq_out, batch_first=True)

class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.rnn = nn.LSTM(input_dim, hidden_dim, batch_first=True)

    def forward(self, x, hidden):
        output, (hidden, cell) = self.rnn(x, hidden)
        return output, (hidden, cell)

class Attention(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder_hidden_states = None
        self.softmax = nn.Softmax(dim=1)
        self.tanh = nn.Tanh()

    def assign_encoder_hidden_states(self, hidden_states) -> None:
        self.encoder_hidden_states = hidden_states

    def calculate_score(self, decoder_hidden_states):
        return torch.bmm(decoder_hidden_states, self.encoder_hidden_states.transpose(1, 2))

    def source_context(self, decoder_hidden_states):
        return self.softmax(self.calculate_score(decoder_hidden_states))

    def forward(self, decoder_hidden_state):
        return self.source_context(decoder_hidden_state)


class Decoder(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.rnn = nn.LSTM(input_dim, hidden_dim, batch_first=True)

    def forward(self, x, hidden, cell):
        output, (hidden, cell) = self.rnn(x, (hidden, cell))
        return output, (hidden, cell)

class Seq2Seq(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.encoder = Encoder(input_size, hidden_size)
        self.attention = Attention()
        self.decoder = Decoder(hidden_size, hidden_size)
        self.reduce_dimension = nn.Linear(hidden_size, output_size)
        self.output = nn.Linear(output_size, output_size)
        self.output_size = output_size
    
    def forward(self, input, hidden=None):
        output_enc, (hn_enc, cn_enc) = self.encoder(input, hidden)
        self.attention.assign_encoder_hidden_states(output_enc)
        latent_tensor = hn_enc[0].unsqueeze(1).repeat(1, 6, 1)
        out_dec, (_, _) = self.decoder(latent_tensor, hn_enc, cn_enc)
        attention_output = self.attention(out_dec)
        return self.output(self.reduce_dimension(out_dec) * attention_output)

model = Seq2Seq(input_size=len(allowed_chars), hidden_size=128, output_size=len(allowed_chars))

history = []

def train(model, optimizer, loss_fn, n_epochs, batch_size):
    for epoch in range(n_epochs):
        total_loss = 0

        optimizer.zero_grad()
        x, y = dg.batch_to_tensor(batch_size)
        x, y = x.to(device), y.to(device)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        print(y_pred.shape)
        print(y.shape)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        if epoch % 100 == 0:
            print("Epoch: {}, Loss: {}".format(epoch, total_loss))
            history.append(total_loss)

optimizer = optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()
dg = Generator()

train(model, optimizer, loss_fn, 45_000, 128)
plt.plot(history, label='loss')

evaluate(model, 10)

torch.Size([128, 6, 11])
torch.Size([128, 6, 11])
Epoch: 0, Loss: 0.11836356669664383
torch.Size([128, 6, 11])
torch.Size([128, 6, 11])
torch.Size([128, 6, 11])
torch.Size([128, 6, 11])
torch.Size([128, 6, 11])
torch.Size([128, 6, 11])
torch.Size([128, 6, 11])
torch.Size([128, 6, 11])
torch.Size([128, 6, 11])
torch.Size([128, 6, 11])
torch.Size([128, 6, 11])
torch.Size([128, 6, 11])
torch.Size([128, 6, 11])
torch.Size([128, 6, 11])
torch.Size([128, 6, 11])
torch.Size([128, 6, 11])
torch.Size([128, 6, 11])
torch.Size([128, 6, 11])
torch.Size([128, 6, 11])
torch.Size([128, 6, 11])
torch.Size([128, 6, 11])
torch.Size([128, 6, 11])
torch.Size([128, 6, 11])
torch.Size([128, 6, 11])
torch.Size([128, 6, 11])
torch.Size([128, 6, 11])
torch.Size([128, 6, 11])
torch.Size([128, 6, 11])
torch.Size([128, 6, 11])
torch.Size([128, 6, 11])
torch.Size([128, 6, 11])
torch.Size([128, 6, 11])
torch.Size([128, 6, 11])
torch.Size([128, 6, 11])
torch.Size([128, 6, 11])
torch.Size([128, 6, 11])
torch.Size([12

KeyboardInterrupt: 