In [1]:
import re
import time
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
# --- Data source and cipher settings ---------------------------------------
# Plain-text corpus read by load_and_vanilla_preprocess.
FILE_NAME = './data/nietzsche.txt'
# Number of positions every character is shifted by when building the inputs.
CAESAR_OFFSET = 2

# --- Windowing --------------------------------------------------------------
# Length of each training window and the distance between window starts;
# with STEP == MAX_LEN the windows do not overlap.
MAX_LEN = 50
STEP = 50

# --- Alphabet ---------------------------------------------------------------
# Cipher alphabet in shift order: 'a'..'z' followed by the space character.
CHARS = list('abcdefghijklmnopqrstuvwxyz ')
# Model vocabulary in sorted order, so the space character gets index 0.
INDEX_TO_CHAR = sorted(CHARS)
CHAR_TO_INDEX = {symbol: position for position, symbol in enumerate(INDEX_TO_CHAR)}

# --- Training ---------------------------------------------------------------
BATCH_SIZE = 512
NUM_EPOCHS = 15

In [8]:
def load_and_vanilla_preprocess(txt_path):
    """Read a UTF-8 text file and normalise it to lowercase letters and spaces.

    Args:
        txt_path: Path of the text file to read.

    Returns:
        The file content lower-cased, with every character outside ``a-z`` and
        the space character replaced by a space, and every run of whitespace
        collapsed into a single space.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # The context manager closes the file; no explicit close() is required.
    with open(txt_path, encoding='utf-8') as txt_file:
        text = txt_file.read().lower()
    text = re.sub(r'[^a-z ]', ' ', text)
    # Raw string: '\s' in a normal literal is an invalid escape sequence.
    return re.sub(r'\s+', ' ', text)

def text_to_sentences(text, max_len=None, step=None):
    """Cut ``text`` into fixed-length windows.

    Args:
        text: String (or other sliceable sequence) to split.
        max_len: Length of every window. Defaults to the module-level
            ``MAX_LEN`` when omitted.
        step: Distance between the starts of consecutive windows. Defaults to
            the module-level ``STEP`` when omitted.

    Returns:
        A list of slices of ``text``, each exactly ``max_len`` long. A trailing
        remainder shorter than ``max_len`` is dropped; input shorter than
        ``max_len`` yields an empty list.
    """
    if max_len is None:
        max_len = MAX_LEN
    if step is None:
        step = STEP
    # Last index at which a full window still fits. The "+ 1" below keeps the
    # window that ends exactly at the end of the text, which a bound of
    # len(text) - max_len would silently drop.
    last_start = len(text) - max_len
    return [text[start:start + max_len] for start in range(0, last_start + 1, step)]

def tokenize(text):
    """Split every string in ``text`` into a list of its characters.

    Items whose exact type is not ``str`` are skipped.

    Returns:
        A list with one list of single-character strings per kept item.
    """
    tokens = []
    for phrase in text:
        if type(phrase) is not str:
            continue
        tokens.append(list(phrase))
    return tokens

def caesor(s, shift):
    """Apply a Caesar shift over the ``CHARS`` alphabet to every character of ``s``.

    Characters found in ``CHARS`` are replaced by the entry ``shift`` positions
    further along, wrapping around at the end of the alphabet. Characters that
    are not in ``CHARS`` are replaced by a single space.

    Returns:
        The shifted string, the same length as ``s``.
    """
    shifted = []
    for ch in s:
        if ch in CHARS:
            # lower() leaves members of the all-lowercase CHARS alphabet unchanged.
            position = CHARS.index(ch.lower())
            shifted.append(CHARS[(position + shift) % len(CHARS)])
        else:
            shifted.append(' ')
    return ''.join(shifted)

# Read the corpus and reduce it to lowercase letters and single spaces.
raw_text = load_and_vanilla_preprocess(FILE_NAME)

# Slice the corpus into windows of MAX_LEN characters taken every STEP characters.
sentences = text_to_sentences(raw_text)

# Turn every window into a list of single characters.
text = tokenize(sentences)


In [9]:
def vectorize(text):
    """Encode character sequences as index tensors for the cipher task.

    Args:
        text: Sequence of character sequences (one per sample).

    Returns:
        A pair ``(X, Y)`` of integer tensors of shape ``(len(text), MAX_LEN)``.
        ``X`` holds the vocabulary indices of the Caesar-shifted characters and
        ``Y`` the indices of the original characters. Characters missing from
        ``CHAR_TO_INDEX`` are encoded as the space index. Positions beyond the
        end of a short sample stay 0; characters past ``MAX_LEN`` are ignored.
    """
    n_rows = len(text)
    X = torch.zeros((n_rows, MAX_LEN), dtype=int)
    Y = torch.zeros((n_rows, MAX_LEN), dtype=int)

    for row, chars in enumerate(text):
        # Pairing with range(MAX_LEN) stops after MAX_LEN characters.
        for col, ch in zip(range(MAX_LEN), chars):
            encrypted = caesor(ch, CAESAR_OFFSET)
            fallback = CHAR_TO_INDEX[' ']
            X[row, col] = CHAR_TO_INDEX.get(encrypted, fallback)
            Y[row, col] = CHAR_TO_INDEX.get(ch, fallback)
    return X, Y

# X: indices of the Caesar-shifted characters (model input);
# Y: indices of the original characters (targets).
X, Y = vectorize(text)

# Pair each encrypted window with its plain-text window.
dataset = torch.utils.data.TensorDataset(X, Y)

In [10]:
# Fixed seed so the train/test partition is identical on every run
# (Restart Kernel -> Run All reproduces the same split).
SPLIT_SEED = 0

# 80 % of the samples for training, the remainder for testing.
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_dl = torch.utils.data.DataLoader(train_dataset, BATCH_SIZE, shuffle=True)
# Evaluation gains nothing from shuffling; a fixed order keeps the per-batch
# averages reported for the test set stable between runs.
test_dl = torch.utils.data.DataLoader(test_dataset, BATCH_SIZE, shuffle=False)

In [11]:
class RNNModel(torch.nn.Module):
    """Character-level RNN: embedding -> single-layer RNN -> per-step linear classifier."""

    def __init__(self):
        super().__init__()
        # The embedding table and the output layer share this size. It is kept
        # at len(CHARS) + CAESAR_OFFSET because the training loop reshapes the
        # logits to exactly this width.
        vocab_size = len(CHARS) + CAESAR_OFFSET
        self.embed = torch.nn.Embedding(vocab_size, 32)
        self.rnn = torch.nn.RNN(32, 128, batch_first=True)
        self.linear = torch.nn.Linear(128, vocab_size)

    def forward(self, sentence, state=None):
        """Compute class logits for every position of ``sentence``.

        Args:
            sentence: Integer index tensor; the RNN is batch-first, so the
                layout is ``(batch, seq_len)``.
            state: Optional initial hidden state handed to the RNN. ``None``
                (the default) lets the RNN start from its zero state.

        Returns:
            Logits of shape ``(batch, seq_len, len(CHARS) + CAESAR_OFFSET)``.
        """
        embed = self.embed(sentence)
        # Forward the caller-supplied initial state; it used to be ignored.
        o, h = self.rnn(embed, state)
        return self.linear(o)

In [12]:
# Fresh instance of the character-level RNN defined in the previous cell.
model = RNNModel()
# Cross-entropy between the per-character logits and the target indices.
criterion = torch.nn.CrossEntropyLoss()
# Plain SGD over all model parameters with a learning rate of 0.05.
optimizer = torch.optim.SGD(model.parameters(), lr=.05)

In [13]:
# Train for NUM_EPOCHS epochs; after every epoch evaluate on the test loader.
# The printed loss values are sums over batches, while the accuracies are the
# mean of the per-batch accuracies.
for epoch in range(NUM_EPOCHS):
    start_epoch_time = time.time()

    # ---- training pass -----------------------------------------------------
    train_loss, train_acc, iter_num = .0, .0, 0
    model.train()
    for x_in, y_in in train_dl:
        # Flatten targets to (batch * seq_len,). reshape(-1) stays 1-D even for
        # a single element, unlike view(1, -1).squeeze().
        y_in = y_in.reshape(-1)
        optimizer.zero_grad()
        # Call the module itself (not .forward) so registered hooks run.
        out = model(x_in)
        # Flatten logits to (batch * seq_len, num_classes).
        out = out.reshape(-1, out.shape[-1])
        loss = criterion(out, y_in)
        train_loss += loss.item()
        batch_acc = (out.argmax(dim=1) == y_in)
        train_acc += batch_acc.sum().item() / batch_acc.shape[0]
        loss.backward()
        optimizer.step()
        iter_num += 1
    print(
        f"Epoch: {epoch}, loss: {train_loss:.4f}, acc: "
        # max(..., 1) avoids a ZeroDivisionError when the loader is empty.
        f"{train_acc / max(iter_num, 1):.4f}",
        end=" | "
    )

    # ---- evaluation pass ---------------------------------------------------
    test_loss, test_acc, iter_num = .0, .0, 0
    model.eval()
    # No gradients are needed for evaluation; skipping autograd bookkeeping
    # saves memory and time.
    with torch.no_grad():
        for x_in, y_in in test_dl:
            y_in = y_in.reshape(-1)
            out = model(x_in)
            out = out.reshape(-1, out.shape[-1])
            loss = criterion(out, y_in)
            test_loss += loss.item()
            batch_acc = (out.argmax(dim=1) == y_in)
            test_acc += batch_acc.sum().item() / batch_acc.shape[0]
            iter_num += 1
    print(
        f"test loss: {test_loss:.4f}, test acc: {test_acc / max(iter_num, 1):.4f} | "
        f"{time.time() - start_epoch_time:.2f} sec."
    )

Epoch: 0, loss: 51.3414, acc: 0.4844 | test loss: 10.3585, test acc: 0.7190 | 1.92 sec.
Epoch: 1, loss: 32.3631, acc: 0.7679 | test loss: 6.9035, test acc: 0.7958 | 1.56 sec.
Epoch: 2, loss: 22.1778, acc: 0.8137 | test loss: 4.9036, test acc: 0.8378 | 1.48 sec.
Epoch: 3, loss: 16.2316, acc: 0.8612 | test loss: 3.7295, test acc: 0.8823 | 1.52 sec.
Epoch: 4, loss: 12.6501, acc: 0.8963 | test loss: 2.9649, test acc: 0.9089 | 1.46 sec.
Epoch: 5, loss: 10.2176, acc: 0.9214 | test loss: 2.4302, test acc: 0.9378 | 1.78 sec.
Epoch: 6, loss: 8.4100, acc: 0.9495 | test loss: 2.0202, test acc: 0.9565 | 1.47 sec.
Epoch: 7, loss: 7.0602, acc: 0.9606 | test loss: 1.7117, test acc: 0.9636 | 1.51 sec.
Epoch: 8, loss: 6.0229, acc: 0.9660 | test loss: 1.4689, test acc: 0.9703 | 1.42 sec.
Epoch: 9, loss: 5.2029, acc: 0.9744 | test loss: 1.2791, test acc: 0.9792 | 1.41 sec.
Epoch: 10, loss: 4.5598, acc: 0.9821 | test loss: 1.1272, test acc: 0.9848 | 1.43 sec.
Epoch: 11, loss: 4.0344, acc: 0.9878 | test lo

In [16]:
def decoder(phrase, model):
    """Decode an encrypted phrase with the trained model.

    Args:
        phrase: String made of characters present in ``CHAR_TO_INDEX``.
        model: Callable mapping an index tensor of shape ``(1, len(phrase))``
            to logits of shape ``(1, len(phrase), num_classes)``.

    Returns:
        The string formed by the most likely character at every position;
        an empty string for an empty ``phrase``.

    Raises:
        KeyError: If ``phrase`` contains a character missing from
            ``CHAR_TO_INDEX``.
        IndexError: If the model predicts a class index that has no entry in
            ``INDEX_TO_CHAR``.
    """
    # torch.tensor([[]]) would be a float tensor, which the embedding layer
    # rejects, so an empty phrase is answered directly.
    if not phrase:
        return ""
    phrase_idx = [CHAR_TO_INDEX[ch] for ch in phrase]
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        # Call the module itself (not .forward) so registered hooks run.
        predicted = model(torch.tensor([phrase_idx])).argmax(dim=2)
    return "".join(INDEX_TO_CHAR[idx] for idx in predicted[0].tolist())

# Ciphertext for the demo: 'hello world' shifted by two positions in CHARS
# (the space wraps around to 'b').
phrase = 'jgnnqbyqtnf'

# Last expression of the cell, so the notebook displays the decoded string.
decoder(phrase, model)

'hello world'