In [1]:
import torch

device = torch.device("cpu")

class GRUModel(torch.nn.Module):
    def __init__(self, num_layers, hidden_size, dict_size):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.dict_size = dict_size
        self.embeddings = torch.nn.Embedding(dict_size, hidden_size)
        self.gru = torch.nn.GRU(
            input_size = hidden_size,
            hidden_size = hidden_size,
            num_layers = num_layers,
            batch_first = True,
        )
        self.projection = torch.nn.Linear(hidden_size, dict_size)

    def forward(self, x, h):
        x = self.embeddings(x)
        y, h = self.gru(x, h)
        return self.projection(y), h

    def zero_state(self, batch_size):
        return torch.zeros(self.num_layers, batch_size, self.hidden_size) if batch_size is not None else torch.zeros(self.num_layers, self.hidden_size)

In [2]:
import requests

all_shakespeare = requests.get("https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt").content.decode()
print(len(all_shakespeare))

5458199


In [3]:
dictionary = list(set(all_shakespeare))
dictionary.append("<start>")
dictionary.append("<end>")
dictionary.append("<empty>")
print(dictionary)
print(len(dictionary))

sym2idx = {s: i for i, s in enumerate(dictionary)}
print(sym2idx)

['f', ')', '|', '6', 'q', "'", 'U', 'Q', 'u', 'T', 'p', ',', 'S', 'y', ' ', 'N', '(', 'P', 'X', 'r', 'n', 'M', ':', 'v', '"', '[', 'A', 'd', '2', 'F', 'W', '>', 'K', 'h', 'E', 'c', 'e', 'k', 'O', '~', '!', '8', 'C', 'm', 'B', '7', '`', 'L', 'j', '=', '.', 'a', 's', '_', '3', 'Y', 'i', 'x', 'I', '*', 'Z', 'w', '@', '9', 't', 'l', '}', '0', '5', ']', '<', '?', 'G', 'z', 'R', 'J', 'o', '\n', '1', '4', ';', '#', '%', '&', 'H', 'V', 'D', '/', '-', 'b', 'g', '<start>', '<end>', '<empty>']
94
{'f': 0, ')': 1, '|': 2, '6': 3, 'q': 4, "'": 5, 'U': 6, 'Q': 7, 'u': 8, 'T': 9, 'p': 10, ',': 11, 'S': 12, 'y': 13, ' ': 14, 'N': 15, '(': 16, 'P': 17, 'X': 18, 'r': 19, 'n': 20, 'M': 21, ':': 22, 'v': 23, '"': 24, '[': 25, 'A': 26, 'd': 27, '2': 28, 'F': 29, 'W': 30, '>': 31, 'K': 32, 'h': 33, 'E': 34, 'c': 35, 'e': 36, 'k': 37, 'O': 38, '~': 39, '!': 40, '8': 41, 'C': 42, 'm': 43, 'B': 44, '7': 45, '`': 46, 'L': 47, 'j': 48, '=': 49, '.': 50, 'a': 51, 's': 52, '_': 53, '3': 54, 'Y': 55, 'i': 56, 'x': 

In [4]:
model = GRUModel(2, 128, len(dictionary)).to(device)
for param in model.parameters():
    print(param)

Parameter containing:
tensor([[-1.2531, -0.2977, -0.7140,  ..., -0.3242,  0.2216,  0.6172],
        [-0.0093,  1.2123,  0.7450,  ...,  1.6020,  0.4401,  0.2959],
        [-0.8178,  0.3585, -0.9875,  ..., -0.5761,  0.0409, -0.4539],
        ...,
        [-0.6046,  1.1757,  0.9760,  ..., -0.1582, -0.5503,  0.8111],
        [-0.3206,  0.3883, -0.5169,  ..., -0.1127,  0.5922,  0.7392],
        [ 1.4290,  1.1875,  0.2454,  ...,  0.6470,  1.1186,  0.5240]],
       requires_grad=True)
Parameter containing:
tensor([[-0.0857, -0.0580, -0.0820,  ...,  0.0054, -0.0045,  0.0139],
        [ 0.0072, -0.0568,  0.0318,  ..., -0.0798,  0.0056, -0.0280],
        [-0.0411, -0.0668,  0.0286,  ..., -0.0289,  0.0539, -0.0095],
        ...,
        [ 0.0362,  0.0871, -0.0809,  ...,  0.0282,  0.0326, -0.0023],
        [ 0.0553, -0.0223, -0.0495,  ..., -0.0634,  0.0587,  0.0255],
        [ 0.0491, -0.0529,  0.0266,  ..., -0.0712, -0.0881,  0.0243]],
       requires_grad=True)
Parameter containing:
tensor([[-0.

In [5]:
import random

random.seed(42)

data = all_shakespeare.split("\n\n")
data = list(filter(lambda x: x, data))
random.shuffle(data)

print(len(data))
print(data[128])

6483
  SICINIUS. Well, here he comes.
  MENENIUS. Calmly, I do beseech you.
  CORIOLANUS. Ay, as an ostler, that for th' poorest piece
    Will bear the knave by th' volume. Th' honour'd gods
    Keep Rome in safety, and the chairs of justice
    Supplied with worthy men! plant love among's!
    Throng our large temples with the shows of peace,
    And not our streets with war!
  FIRST SENATOR. Amen, amen!
  MENENIUS. A noble wish.


In [6]:
train = [data[i] for i in range(len(data)) if i % 10 != 0]
test = [data[i] for i in range(len(data)) if i % 10 == 0]

print(train[-5])
print("")
print(test[-5])

  MONTJOY. You know me by my habit.
  KING HENRY. Well then, I know thee; what shall I know of thee?
  MONTJOY. My master's mind.
  KING HENRY. Unfold it.
  MONTJOY. Thus says my king. Say thou to Harry of England: Though we
    seem'd dead we did but sleep; advantage is a better soldier than
    rashness. Tell him we could have rebuk'd him at Harfleur, but  
    that we thought not good to bruise an injury till it were full
    ripe. Now we speak upon our cue, and our voice is imperial:
    England shall repent his folly, see his weakness, and admire our
    sufferance. Bid him therefore consider of his ransom, which must
    proportion the losses we have borne, the subjects we have lost,
    the disgrace we have digested; which, in weight to re-answer, his
    pettiness would bow under. For our losses his exchequer is too
    poor; for th' effusion of our blood, the muster of his kingdom
    too faint a number; and for our disgrace, his own person kneeling
    at our feet but a weak 

In [7]:
import numpy as np

def generate(model, len_limit):
    model.eval()
    with torch.no_grad():
        result = ""
        state = model.zero_state(None).to(device)
        x = "<start>"
        while len(result) < len_limit:
            x = torch.tensor([sym2idx[x]]).to(device)
            y, state = model(x, state)
            y = y[0].cpu().numpy()
            y = np.exp(y)
            y /= np.sum(y)
            x = dictionary[np.random.choice(y.shape[0], p = y)]
            if x in ["<start>", "<end>", "<empty>"]:
                break
            result += x
        return result

print(generate(model, 1000))

U=Wtabcy[u3Gc7MOxNSAj?]YN[X_|iEnCp7,


In [8]:
import tqdm

def iterate_batches(data, batch_size, device):
    x, y, max_len = [], [], 0
    for k in tqdm.tqdm(range(len(data))):
        item = data[k]
        x.append([sym2idx[sym] for sym in ["<start>"] + list(item)])
        y.append([sym2idx[sym] for sym in list(item) + ["<end>"]])
        max_len = max(max_len, len(x[-1]))
        if len(x) == batch_size or k + 1 == len(data):
            for i in range(len(x)):
                x[i] = x[i] + [sym2idx["<empty>"] for _ in range(max_len - len(x[i]))]
                y[i] = y[i] + [sym2idx["<empty>"] for _ in range(max_len - len(y[i]))]
            x = torch.tensor(x).to(device)
            y = torch.tensor(y).to(device)
            yield x, y
            x, y, max_len = [], [], 0
        

def train_epoch(data, model):
    model.train()
    loss_function = torch.nn.CrossEntropyLoss()
    #optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    optimizer = torch.optim.AdamW(model.parameters())
    total_loss, total_count = 0.0, 1e-38
    random.shuffle(data)
    for inputs, answers in iterate_batches(data, 64, device):
        optimizer.zero_grad()
        state = model.zero_state(inputs.shape[0]).to(device)
        #print(inputs.shape)
        #print(answers.shape)
        outputs, state = model(inputs, state)
        outputs = outputs.transpose(1, 2)
        #print(outputs.shape)
        loss = loss_function(outputs, answers)
        total_loss += (loss.item() * inputs.shape[0])
        total_count += inputs.shape[0]
        loss.backward()
        optimizer.step()
    return total_loss / total_count

def test_epoch(data, model):
    with torch.no_grad():
        model.eval()
        loss_function = torch.nn.CrossEntropyLoss()
        total_loss, total_count = 0.0, 1e-38
        for inputs, answers in iterate_batches(data, 64, device):
            state = model.zero_state(inputs.shape[0]).to(device)
            outputs, state = model(inputs, state)
            outputs = outputs.transpose(1, 2)
            loss = loss_function(outputs, answers)
            total_loss += (loss.item() * inputs.shape[0])
            total_count += inputs.shape[0]
        return total_loss / total_count

for i in range(10):
    train_loss = train_epoch(train[:100], model)
    test_loss = test_epoch(test[:100], model)
    print("Epoch {} loss: {:.5f} {:.5f}".format(i, train_loss, test_loss))
    print(generate(model, 1000))
    print("")

100%|█████████████████████████████████████████| 100/100 [00:19<00:00,  5.09it/s]
100%|█████████████████████████████████████████| 100/100 [00:02<00:00, 34.17it/s]


Epoch 0 loss: 4.49801 3.56978
%B`Z;,os!0EG@HN<DR`CO--]"q`Hs;Vd6yy2Wtw



100%|█████████████████████████████████████████| 100/100 [00:19<00:00,  5.06it/s]
100%|█████████████████████████████████████████| 100/100 [00:02<00:00, 33.63it/s]


Epoch 1 loss: 3.33407 2.32542




 63%|██████████████████████████▍               | 63/100 [00:08<00:05,  7.38it/s]


KeyboardInterrupt: 