In [51]:
#from data_generator_with_sos import generate
# data generator in next cell. It's just to translate number words (e.g. 'eight','nine',etc) to digits
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
import math

device =  torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device in use:", device)

NUM_INPUTS = 28 #No. of possible characters
NUM_OUTPUTS = 12  # (0-9 + '#')

### Hyperparameters and general configs
MAX_SEQ_LEN = 8
MIN_SEQ_LEN = 5
TRAINING_SIZE = 100
LEARNING_RATE = 0.003
N_LAYERS = 2
DROPOUT = 0.5

# Hidden size of enc and dec need to be equal if last hidden of encoder becomes init hidden of decoder
# Otherwise we would need e.g. a linear layer to map to a space with the correct dimension
NUM_UNITS_ENC = NUM_UNITS_DEC = 256
HIDDEN_DIM = 512
TEST_SIZE = 100
EPOCHS = 10
TEACHER_FORCING = 0.5
NUM_OF_BATCHES=8

# assert TRAINING_SIZE % NUM_OF_BATCHES == 0

Device in use: cpu


In [None]:
target_to_text = {
    "1": "one",
    "2": "two",
    "3": "three",
    "4": "four",
    "5": "five",
    "6": "six",
    "7": "seven",
    "8": "eight",
    "9": "nine",
}
EOS = "#"
SOS='*'
PAD = "0"

input_characters = " ".join(target_to_text.values())
valid_characters = [
    PAD,
    "1",
    "2",
    "3",
    "4",
    "5",
    "6",
    "7",
    "8",
    "9",
    EOS,
    SOS,
] + list(set(input_characters))


def print_valid_characters():
    l = ""
    for i, c in enumerate(valid_characters):
        l += "'%s'=%i,\t" % (c, i)
    print("Number of valid characters:", len(valid_characters))
    print(l)


ninput_chars = len(valid_characters)


def generate(
    num_batches=10, batch_size=100, min_len=3, max_len=3, invalid_set=set()
):
    """
    Generates random sequences of integers and translates them to text i.e. 1->'one'.

    :param batch_size: number of samples to return
    :param min_len: minimum length of target
    :param max_len: maximum length of target
    :param invalid_set: set of invalid 'text targets out', e.g. if we want to avoid samples
    that are already in the training set
    """

    text_inputs = []
    int_inputs = []
    text_targets_in = []
    text_targets_out = []
    int_targets_in = []
    int_targets_out = []
    inputs = []
    targets_in = []
    targets_out = []
    batch_target_max_len = np.zeros((1, num_batches))
    targets_mask = []
    batch_input_max_len = np.zeros((1, num_batches))
    inputs_len = []
    _printed_warning = False
    # loop through number of batches
    for i in range(num_batches):
        temp_text_inputs = []
        temp_int_inputs = []
        temp_text_targets_in = []
        temp_text_targets_out = []
        temp_int_targets_in = []
        temp_int_targets_out = []
        iterations = 0
        # while loop until number of rows per batch is reached
        while len(temp_text_inputs) < batch_size:
            iterations += 1

            # choose random sequence length
            tar_len = np.random.randint(min_len, max_len + 1)

            # list of text digits
            text_target = inp_str = "".join(
                map(str, np.random.randint(1, 10, tar_len))
            )
            text_target_in = SOS + text_target
            text_target_out = text_target + EOS

            # generate the targets as a list of integers
            int_target_in = map(
                lambda c: valid_characters.index(c), text_target_in
            )
            int_target_in = list(int_target_in)

            int_target_out = map(
                lambda c: valid_characters.index(c), text_target_out
            )
            int_target_out = list(int_target_out)

            # generate the text input
            text_input = " ".join(map(lambda k: target_to_text[k], inp_str))

            # generate the inputs as a list of integers
            int_input = map(lambda c: valid_characters.index(c), text_input)
            int_input = list(int_input)

            if not _printed_warning and iterations > 5 * batch_size:
                print(
                    "WARNING: doing a lot of iterations because I'm trying to generate a batch that does not"
                    " contain samples from 'invalid_set'."
                )
                _printed_warning = True

            if text_target_out in invalid_set:
                continue
            # append created row to temp arrays
            temp_text_inputs.append(text_input)
            temp_int_inputs.append(int_input)
            temp_text_targets_in.append(text_target_in)
            temp_text_targets_out.append(text_target_out)
            temp_int_targets_in.append(int_target_in)
            temp_int_targets_out.append(int_target_out)
        # turn the temp arrays into tensors and pad them to max length

        # append completed temp batch array to full array of batches
        text_inputs.append(temp_text_inputs)
        int_inputs.append(temp_int_inputs)
        text_targets_in.append(temp_text_targets_in)
        text_targets_out.append(temp_text_targets_out)
        int_targets_in.append(temp_int_targets_in)
        int_targets_out.append(temp_int_targets_out)

        max_target_out_len = max(map(len, int_targets_out[-1]))
        max_input_len = max(map(len, int_inputs[-1]))
        targets_mask_tmp = np.zeros((batch_size, max_target_out_len))
        add_targets_out = np.full((batch_size, max_target_out_len), int(PAD))
        len_arr = [-len(thing) for thing in temp_int_inputs]
        sorted_arr = np.argsort(len_arr)
        add_targets_in = np.full((batch_size, max_target_out_len), int(PAD))
        add_inputs = np.full((batch_size, max_input_len), int(PAD))
        tmp_inputs_len = np.zeros(len(sorted_arr))
        for short_index, row in enumerate(sorted_arr):
            tmp_element = temp_int_inputs[row]
            add_inputs[short_index, : len(tmp_element)] = tmp_element
            tmp_inputs_len[short_index] = len(tmp_element)

            tmp_element = temp_int_targets_in[row]
            add_targets_in[short_index, : len(tmp_element)] = tmp_element

            tmp_element = temp_int_targets_out[row]
            add_targets_out[short_index, : len(tmp_element)] = tmp_element
            targets_mask_tmp[short_index, : len(tmp_element)] = 1
        inputs_len.append(tmp_inputs_len)
        targets_mask.append(targets_mask_tmp)
        inputs.append(add_inputs.astype("int32"))
        targets_in.append(add_targets_in.astype("int32"))
        targets_out.append(add_targets_out.astype("int32"))
        target_in_seq_lengths = torch.LongTensor(
            list(map(len, temp_int_targets_in))
        )
        input_seq_lengths = torch.LongTensor(list(map(len, temp_int_inputs)))

        batch_target_max_len[0, i] = target_in_seq_lengths.max()
        batch_input_max_len[0, i] = input_seq_lengths.max()
    return (
        inputs,
        batch_input_max_len.astype("int32"),
        targets_in,
        targets_out,
        batch_target_max_len.astype("int32"),
        targets_mask,
        text_inputs,
        text_targets_in,
        text_targets_out,
        inputs_len,
    )



def main():
    batch_size = 3
    (
        inputs,
        inputs_seqlen,
        targets_in,
        targets_out,
        targets_seqlen,
        targets_mask,
        text_inputs,
        text_targets_in,
        text_targets_out,
        inputs_len,
    ) = generate(8, 10, min_len=1, max_len=2)

    print_valid_characters()
    print("Stop/start character = #")

    for i in range(batch_size):
        print("\nSAMPLE", i)
        print("TEXT INPUTS:\t\t\t", text_inputs[i])
        print("ENCODED INPUTS:\t\t\t\n", inputs[i])
        print("INPUTS SEQUENCE LENGTH:\t\n", inputs_seqlen)
        print("TEXT TARGETS INPUT:\t\t", text_targets_in[i])
        print("TEXT TARGETS OUTPUT:\t", text_targets_out[i])
        print("ENCODED TARGETS INPUT:\t\n", targets_in[i])
        print("ENCODED TARGETS OUTPUT:\t\n", targets_out[i])
        print("TARGETS SEQUENCE LENGTH:", targets_seqlen)
        print("TARGETS MASK:\t\t\t\n", targets_mask[i])
        print("INPUTS LEN:\t\t\t\n", inputs_len[i])


if __name__ == "__main__":
    main()


In [53]:
class EncoderRNN(nn.Module):
    def __init__(self, input_size, emb_size, hidden_size, dropout):
        super().__init__()
        self.hidden_size = hidden_size
        self.emb_size = emb_size

        self.embedding = nn.Embedding(input_size, self.emb_size)
        self.rnn = nn.GRU(
            self.emb_size,
            self.hidden_size,
            bidirectional=True,
            batch_first=True
        )

        self.fc = nn.Linear(self.hidden_size * 2, self.hidden_size)

        self.dropout = nn.Dropout(dropout)

    def forward(self, inputs, hidden, inputs_len):
        # Input shape [batch, seq_in_len]
        # inputs = [inputs[0],inputs[2]]
        inputs = inputs.long()

        # Embedded shape [batch, seq_in_len, embed]
        embedded = self.dropout(self.embedding(inputs))
        # embedded = embedded.view(embedded.shape[0]*embedded.shape[1],embedded.shape[2],embedded.shape[3])

        # pack_padded_sequence so that padded items in the sequence won't be shown to the LSTM
        packed_embedded = torch.nn.utils.rnn.pack_padded_sequence(
            embedded, inputs_len, batch_first=True
        )

        packed_outputs, hidden = self.rnn(packed_embedded)

        # Output shape [batch, seq_in_len, embed]
        # Hidden shape [1, batch, embed], last hidden state of the GRU cell
       
        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            packed_outputs, batch_first=True
        )
        hidden = torch.tanh(self.fc(torch.cat((hidden[-2], hidden[-1]), dim=1)))
        return outputs, hidden

    def init_hidden(self, batch_size):
        init = torch.zeros(1, batch_size, self.hidden_size, device=device)
        return init


In [54]:
class Attention(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()

        self.attn = nn.Linear((hidden_size * 2) + hidden_size, hidden_size)
        self.v = nn.Parameter(torch.rand(hidden_size))

    def forward(self, hidden, encoder_outputs, mask):

        # hidden = [batch size, dec hid dim]
        # encoder_outputs = [src sent len, batch size, enc hid dim * 2]
        # mask = [batch size, src sent len]

        batch_size = encoder_outputs.shape[0]
        src_len = encoder_outputs.shape[1]

        # encoder output =  [33, 8, 512], hidden = [8, 256]
        # print(encoder_outputs.shape[0], encoder_outputs.shape[1])

        # repeat encoder hidden state src_len times
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)

        # hidden = [batch size, src sent len, dec hid dim]
        # encoder_outputs = [batch size, src sent len, enc hid dim * 2]

        energy = torch.tanh(
            self.attn(torch.cat((hidden, encoder_outputs), dim=2))
        )

        # energy = [batch size, src sent len, dec hid dim]
        energy = energy.permute(0, 2, 1)
        # energy = [batch size, dec hid dim, src sent len]

        # v = [dec hid dim]
        v = self.v.repeat(batch_size, 1).unsqueeze(1)
        # v = [batch size, 1, dec hid dim]

        attention = torch.bmm(v, energy).squeeze(1)
        # attention = [batch size, src sent len]
        attention = attention.masked_fill(mask == 0, -1e10)

        return F.softmax(attention, dim=1)

In [55]:
class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, emb_size, output_size, dropout, attention):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.emb_size = emb_size
        self.attention = attention

        self.embedding = nn.Embedding(self.output_size, self.emb_size)
        self.out = nn.Linear(
            (self.hidden_size * 2) + self.hidden_size + self.emb_size,
            output_size
        )
        self.rnn = nn.GRU((self.hidden_size * 2) + self.emb_size, self.hidden_size,num_layers=1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, inputs, hidden, encoder_outputs, mask):
        # Input shape: [batch, output_len]
        # Hidden shape: [seq_len=1, batch_size, hidden_dim] (the last hidden state of the encoder)
        dec_input = inputs.unsqueeze(1)
        embedded = self.dropout(self.embedding(dec_input))
        embedded = embedded.permute(1, 0, 2)
      
        a = self.attention(hidden, encoder_outputs, mask)
        a = a.unsqueeze(1)
    
        weighted = torch.bmm(a, encoder_outputs)
        weighted = weighted.permute(1, 0, 2)
        rnn_input = torch.cat((embedded, weighted), dim=2)
        out, hidden = self.rnn(rnn_input,hidden.unsqueeze(0))
        assert (out == hidden).all()
        embedded = embedded.squeeze(0)
        out = out.squeeze(0)
        weighted = weighted.squeeze(0)

        output = self.out(torch.cat((out, weighted, embedded), dim=1))
      # [batch_size x seq_len x output_size]
        hidden = hidden.squeeze(0)

        return output, hidden.squeeze(0), a.squeeze(1)

In [57]:
def create_mask(src):
    mask = src != 0  # .permute(1, 0)
    return mask

In [64]:
def forward_pass(
    encoder, decoder, x, t, t_in, x_len, criterion, teacher_forcing_ratio=0.5
):
    """
    Executes a forward pass through the whole model.

    :param encoder:
    :param decoder:
    :param x: input to the encoder, shape [batch, seq_in_len]
    :param t: target output predictions for decoder, shape [batch, seq_t_len]
    :param criterion: loss function
    :param max_t_len: maximum target length

    :return: output (after log-softmax), loss, accuracy (per-symbol)
    """
    batch_size = x.shape[0]
    trg_len = t_in.shape[1]
    trg_vocab_size = NUM_OUTPUTS

    # tensor to store decoder outputs
    outputs = torch.zeros(trg_len, batch_size, trg_vocab_size)

    # Run encoder and get last hidden state (and output)

    enc_h = encoder.init_hidden(batch_size)
    enc_out, enc_h = encoder(x, enc_h, x_len)

    # first input to the decoder is the <sos> tokens
    inputs = t_in[:, 0]
    dec_h = enc_h
    mask = create_mask(x)
    for i in range(1, trg_len + 1):

        # insert input token embedding, previous hidden state, all encoder hidden states
        #  and mask
        # receive output tensor (predictions) and new hidden state
        output, dec_h, _ = decoder(inputs, dec_h, enc_out, mask)

        # place predictions in a tensor holding predictions for each token
        outputs[i - 1] = output

        # decide if we are going to use teacher forcing or not
        teacher_force = random.random() < teacher_forcing_ratio

        # get the highest predicted token from our predictions
        top1 = output.argmax(1)
        # if teacher forcing, use actual next token as next input
        # if not, use predicted token
        if i < trg_len:
            inputs = t_in[:, i] if teacher_force else top1

    out = outputs.permute(1, 2, 0)
    # Shape: [batch_size x num_classes x out_sequence_len], with second dim containing log probabilities
    loss = criterion(out, t)
    pred = get_pred(log_probs=out)
    accuracy = (pred == t).type(torch.FloatTensor).mean()

    return out, loss, accuracy

In [60]:
def train(
    encoder,
    decoder,
    inputs,
    targets,
    targets_in,
    criterion,
    enc_optimizer,
    dec_optimizer,
    epoch,
    inputs_len
):
    encoder.train()
    decoder.train()
    epoch_loss = 0
    for batch_idx, (x, t, t_in, x_len) in enumerate(
        zip(inputs, targets, targets_in, inputs_len)
    ):
        # print(x.shape)
        x = torch.LongTensor(x).to(device)
        t = torch.LongTensor(t).to(device)
        t_in = torch.LongTensor(t_in).to(device)
        x_len = torch.LongTensor(x_len).to(device)

        enc_optimizer.zero_grad()
        dec_optimizer.zero_grad()

        # print(batch_idx)
        #         inputs = inputs.to(device)
        #         targets = targets.long()
        #         targets_in = targets_in.long()
        out, loss, accuracy = forward_pass(
            encoder,
            decoder,
            x,
            t,
            t_in,
            x_len,
            criterion,
            teacher_forcing_ratio=TEACHER_FORCING
        )

        loss.backward()
        enc_optimizer.step()
        dec_optimizer.step()
        if batch_idx % 200 == 0:
            print(
                "Epoch {} [{}/{} ({:.0f}%)]\tTraining loss: {:.4f} \tTraining accuracy: {:.1f}%".format(
                    epoch,
                    batch_idx * len(x),
                    TRAINING_SIZE * NUM_OF_BATCHES,
                    100.0
                    * batch_idx
                    * len(x)
                    / (TRAINING_SIZE * NUM_OF_BATCHES),
                    loss.item(),
                    100.0 * accuracy.item()
                )
            )

In [61]:
def test(encoder, decoder, inputs, targets, targets_in, inputs_len, criterion):
    encoder.eval()
    decoder.eval()
    epoch_loss = 0
    with torch.no_grad():
        inputs = inputs.view(inputs.shape[1], inputs.shape[2])
        targets = targets.view(targets.shape[1], targets.shape[2])
        targets_in = targets_in.view(targets_in.shape[1], targets_in.shape[2])
        inputs_len = torch.LongTensor(inputs_len[0]).to(device)

        out, loss, accuracy = forward_pass(encoder,decoder,inputs, targets,targets_in,
                                inputs_len,criterion,teacher_forcing_ratio=TEACHER_FORCING
        )
        # print(out.shape,targets_in.shape)
    return out, loss, accuracy

In [62]:
def numbers_to_text(seq):
    return "".join([str(to_np(i)) if to_np(i) != 10 else "#" for i in seq])


def to_np(x):
    return x.cpu().numpy()


def get_pred(log_probs):
    """
    Get class prediction (digit prediction) from the net's output (the log_probs)
    :param log_probs: Tensor of shape [batch_size x n_classes x sequence_len]
    :return:
    """
    return torch.argmax(log_probs, dim=1)

In [65]:
attn = Attention(NUM_UNITS_ENC)
encoder = EncoderRNN(NUM_INPUTS, HIDDEN_DIM, NUM_UNITS_ENC, DROPOUT).to(device)
decoder = DecoderRNN(NUM_UNITS_DEC, HIDDEN_DIM, NUM_OUTPUTS, DROPOUT, attn).to(
    device
)
enc_optimizer = optim.RMSprop(encoder.parameters(), lr=LEARNING_RATE)
dec_optimizer = optim.RMSprop(decoder.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss(reduction="mean", ignore_index=0)

# Get training set

inputs,_,targets_in,targets,targets_seqlen,_,text,_,text_targ,inputs_len= generate(
    TRAINING_SIZE, NUM_OF_BATCHES, min_len=MIN_SEQ_LEN, max_len=MAX_SEQ_LEN
)
max_target_len = max(targets_seqlen)
unique_text_targets = set([i for x in text_targ for i in x])

# Get validation set
(
    val_inputs,_,val_targets_in,val_targets,val_targets_seqlen,_,
    val_text_in,_,val_text_targ,val_inputs_len
) = generate(
    1,
    TEST_SIZE,
    min_len=MIN_SEQ_LEN,
    max_len=MAX_SEQ_LEN,
    invalid_set=unique_text_targets,
)

val_inputs = torch.LongTensor(val_inputs).to(device)
val_targets = torch.LongTensor(val_targets).to(device)
val_targets_in = torch.LongTensor(val_targets_in).to(device)
val_inputs_len = torch.LongTensor(val_inputs_len).to(device)
max_val_target_len = max(val_targets_seqlen)


# Quick and dirty - just loop over training set without reshuffling
for epoch in range(1, EPOCHS + 1):
    train(
        encoder,
        decoder,
        inputs,
        targets,
        targets_in,
        criterion,
        enc_optimizer,
        dec_optimizer,
        epoch,
        inputs_len
    )
    _, loss, accuracy = test(
        encoder,
        decoder,
        val_inputs,
        val_targets,
        val_targets_in,
        val_inputs_len,
        criterion
    )
    print(
        "\nTest set: Average loss: {:.4f} \tAccuracy: {:.3f}%\n".format(
            loss, accuracy.item() * 100.0
        )
    )

    # Show examples
    print("Examples: prediction | input")
    out, _, _ = test(
        encoder,
        decoder,
        val_inputs[:10],
        val_targets[:10],
        val_targets_in[:10],
        val_inputs_len[:10],
        criterion
    )
    pred = get_pred(out)
    pred_text = [numbers_to_text(sample) for sample in pred]
    for i in range(9):
        print(pred_text[i], "\t", val_text_in[0][i])
    print()


Test set: Average loss: 2.2217 	Accuracy: 24.333%

Examples: prediction | input
68873#3#6 	 four seven nine five seven six
1779#7##9 	 one six one eight four three three nine
777337### 	 five seven one nine eight seven
9775#7#7# 	 seven eight seven three one three five five
771797##9 	 nine two seven eight five six two four
566#6#66# 	 five four one four five
77675#6#6 	 five six two nine three eight two
261616##6 	 three four nine three one
499#9#9#9 	 six six six one one three


Test set: Average loss: 1.8242 	Accuracy: 32.111%

Examples: prediction | input
388888888 	 four seven nine five seven six
199979### 	 one six one eight four three three nine
77175#### 	 five seven one nine eight seven
399779797 	 seven eight seven three one three five five
877999999 	 nine two seven eight five six two four
56963#### 	 five four one four five
7765791## 	 five six two nine three eight two
29199963# 	 three four nine three one
499999999 	 six six six one one three


Test set: Average loss: 1.6