In [1]:
import torch

device = torch.device("cuda")

class RNNModel(torch.nn.Module):
    def __init__(self, hidden_size, dict_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.dict_size = dict_size
        self.embeddings = torch.nn.Embedding(dict_size, hidden_size)
        self.wh = torch.nn.Linear(hidden_size, hidden_size, bias=False)
        self.wy = torch.nn.Linear(hidden_size, hidden_size, bias=False)
        self.uh = torch.nn.Linear(hidden_size, hidden_size, bias=False)
        self.bh = torch.nn.Parameter(torch.randn(hidden_size))
        self.by = torch.nn.Parameter(torch.randn(hidden_size))
        self.projection = torch.nn.Linear(hidden_size, dict_size)

    def forward(self, x, h):
        x = self.embeddings(x)
        h = torch.sigmoid(self.wh(x) + self.uh(h) + self.bh)
        y = self.projection(torch.sigmoid(self.wy(h) + self.by))
        return y, h

    def zero_state(self, batch_size):
        return torch.zeros(batch_size, self.hidden_size)

In [2]:
import requests

all_shakespeare = requests.get("https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt").content.decode()
print(len(all_shakespeare))

5458199


In [3]:
dictionary = list(set(all_shakespeare))
dictionary.append("<start>")
dictionary.append("<end>")
dictionary.append("<empty>")
print(dictionary)
print(len(dictionary))

sym2idx = {s: i for i, s in enumerate(dictionary)}
print(sym2idx)

['P', 'D', 'B', 'Z', 'T', ';', 'W', '\n', 'S', '>', 'p', 'u', 'g', 'd', 'E', '|', '(', 'n', 'O', '9', '6', '/', '`', ')', 't', 'V', 'q', 'A', 'm', '1', ':', 'j', 'w', 'h', '-', '*', '@', 'R', 'Y', 'N', 'c', "'", '?', ',', '%', 'f', 'i', 'a', 'I', '!', ']', '7', 's', 'G', 'F', '2', '[', '5', 'M', 'X', 'k', 'C', '.', 'x', 'r', 'y', 'z', 'Q', '_', '8', 'J', 'L', 'U', '"', 'K', '&', 'b', ' ', '=', '0', '~', '4', 'v', 'e', 'H', '<', '}', '#', 'o', '3', 'l', '<start>', '<end>', '<empty>']
94
{'P': 0, 'D': 1, 'B': 2, 'Z': 3, 'T': 4, ';': 5, 'W': 6, '\n': 7, 'S': 8, '>': 9, 'p': 10, 'u': 11, 'g': 12, 'd': 13, 'E': 14, '|': 15, '(': 16, 'n': 17, 'O': 18, '9': 19, '6': 20, '/': 21, '`': 22, ')': 23, 't': 24, 'V': 25, 'q': 26, 'A': 27, 'm': 28, '1': 29, ':': 30, 'j': 31, 'w': 32, 'h': 33, '-': 34, '*': 35, '@': 36, 'R': 37, 'Y': 38, 'N': 39, 'c': 40, "'": 41, '?': 42, ',': 43, '%': 44, 'f': 45, 'i': 46, 'a': 47, 'I': 48, '!': 49, ']': 50, '7': 51, 's': 52, 'G': 53, 'F': 54, '2': 55, '[': 56, '5':

In [4]:
model = RNNModel(128, len(dictionary)).to(device)
for param in model.parameters():
    print(param)

Parameter containing:
tensor([ 0.8450,  0.4010, -2.1780, -0.1664, -0.4047, -0.7855,  1.5931, -0.5891,
         0.6570,  0.1699, -1.3531, -0.4122,  1.1594, -0.6741, -1.5409, -0.2923,
         0.8097,  0.1572, -1.3926,  1.0935,  0.5900, -0.7343, -1.0554, -0.0265,
         0.4737,  0.5921, -1.2053, -0.1082,  0.7025,  0.4492, -0.9323,  1.1177,
         0.5482, -0.6988,  0.4176,  0.5328,  1.0138, -0.5300,  0.2484,  0.1550,
        -0.2713, -0.9858,  0.9586,  0.5857,  0.8949, -1.1385,  0.6538,  0.1927,
        -1.0815,  0.7078,  0.6650,  0.2809, -0.8523,  0.3628, -0.2790, -0.1556,
         1.4167, -0.1194,  0.6041,  0.6225,  0.6362, -0.6690,  1.2722, -0.0845,
        -0.6244, -0.7992, -1.5093, -0.2532, -0.6413,  0.9036,  1.0816,  0.9328,
         1.4068, -0.0289,  0.7068,  0.4110,  0.2703,  0.0034,  0.2026, -0.6429,
        -0.0257,  1.2472,  1.1970, -0.1720, -0.8664, -0.5702, -0.0184, -0.7364,
        -0.8061,  0.5111,  1.0159, -0.7839, -0.0441,  0.7904,  1.3348, -0.1857,
        -0.3873, -

In [5]:
import random

random.seed(42)

data = all_shakespeare.split("\n\n")
data = list(filter(lambda x: x, data))
random.shuffle(data)

print(len(data))
print(data[128])

6483
  SICINIUS. Well, here he comes.
  MENENIUS. Calmly, I do beseech you.
  CORIOLANUS. Ay, as an ostler, that for th' poorest piece
    Will bear the knave by th' volume. Th' honour'd gods
    Keep Rome in safety, and the chairs of justice
    Supplied with worthy men! plant love among's!
    Throng our large temples with the shows of peace,
    And not our streets with war!
  FIRST SENATOR. Amen, amen!
  MENENIUS. A noble wish.


In [6]:
train = [data[i] for i in range(len(data)) if i % 10 != 0]
test = [data[i] for i in range(len(data)) if i % 10 == 0]

print(train[-5])
print("")
print(test[-5])

  MONTJOY. You know me by my habit.
  KING HENRY. Well then, I know thee; what shall I know of thee?
  MONTJOY. My master's mind.
  KING HENRY. Unfold it.
  MONTJOY. Thus says my king. Say thou to Harry of England: Though we
    seem'd dead we did but sleep; advantage is a better soldier than
    rashness. Tell him we could have rebuk'd him at Harfleur, but  
    that we thought not good to bruise an injury till it were full
    ripe. Now we speak upon our cue, and our voice is imperial:
    England shall repent his folly, see his weakness, and admire our
    sufferance. Bid him therefore consider of his ransom, which must
    proportion the losses we have borne, the subjects we have lost,
    the disgrace we have digested; which, in weight to re-answer, his
    pettiness would bow under. For our losses his exchequer is too
    poor; for th' effusion of our blood, the muster of his kingdom
    too faint a number; and for our disgrace, his own person kneeling
    at our feet but a weak 

In [7]:
import numpy as np

def generate(model, len_limit):
    model.eval()
    with torch.no_grad():
        result = ""
        state = model.zero_state(1).to(device)
        x = "<start>"
        while len(result) < len_limit:
            x = torch.tensor(sym2idx[x]).to(device)
            y, state = model(x, state)
            y = y[0].cpu().numpy()
            y = np.exp(y)
            y /= np.sum(y)
            x = dictionary[np.random.choice(y.shape[0], p = y)]
            if x in ["<start>", "<end>", "<empty>"]:
                break
            result += x
        return result

print(generate(model, 1000))

1xQ9xGVHq(C.-H4"4R83&peAo,G_#Z
@>L(n"d46I)k!FVb#A


In [8]:
import tqdm

def iterate_batches(data, batch_size, device):
    x, y, max_len = [], [], 0
    for k in tqdm.tqdm(range(len(data))):
        item = data[k]
        x.append([sym2idx[sym] for sym in ["<start>"] + list(item)])
        y.append([sym2idx[sym] for sym in list(item) + ["<end>"]])
        max_len = max(max_len, len(x[-1]))
        if len(x) == batch_size or k + 1 == len(data):
            for i in range(len(x)):
                x[i] = x[i] + [sym2idx["<empty>"] for _ in range(max_len - len(x[i]))]
                y[i] = y[i] + [sym2idx["<empty>"] for _ in range(max_len - len(y[i]))]
            x = torch.tensor(x).to(device)
            y = torch.tensor(y).to(device)
            yield x, y
            x, y, max_len = [], [], 0
        

def train_epoch(data, model):
    model.train()
    loss_function = torch.nn.CrossEntropyLoss()
    #optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    optimizer = torch.optim.AdamW(model.parameters())
    total_loss, total_count = 0.0, 1e-38
    random.shuffle(data)
    for inputs, answers in iterate_batches(data, 64, device):
        optimizer.zero_grad()
        outputs = []
        state = model.zero_state(inputs.shape[0]).to(device)
        inputs = inputs.transpose(1, 0)
        #print(inputs.shape)
        #print(answers.shape)
        for i in range(inputs.shape[0]):
            y, state = model(inputs[i], state)
            #print(y.shape)
            outputs.append(y)
        outputs = torch.stack(outputs).transpose(1, 0).transpose(1, 2)
        #print(outputs.shape)
        loss = loss_function(outputs, answers)
        total_loss += (loss.item() * inputs.shape[0])
        total_count += inputs.shape[0]
        loss.backward()
        optimizer.step()
    return total_loss / total_count

def test_epoch(data, model):
    with torch.no_grad():
        model.eval()
        loss_function = torch.nn.CrossEntropyLoss()
        total_loss, total_count = 0.0, 1e-38
        for inputs, answers in iterate_batches(data, 64, device):
            outputs = []
            state = model.zero_state(inputs.shape[0]).to(device)
            inputs = inputs.transpose(1, 0)
            for i in range(inputs.shape[0]):
                y, state = model(inputs[i], state)
                outputs.append(y)
            outputs = torch.stack(outputs).transpose(1, 0).transpose(1, 2)
            loss = loss_function(outputs, answers)
            total_loss += (loss.item() * inputs.shape[0])
            total_count += inputs.shape[0]
        return total_loss / total_count

for i in range(10):
    train_loss = train_epoch(train, model)
    test_loss = test_epoch(test, model)
    print("Epoch {} loss: {:.5f} {:.5f}".format(i, train_loss, test_loss))
    print(generate(model, 1000))
    print("")

100%|██████████| 5834/5834 [05:04<00:00, 19.14it/s]
100%|██████████| 649/649 [00:09<00:00, 67.23it/s] 


Epoch 0 loss: 1.20524 0.55670




100%|██████████| 5834/5834 [05:08<00:00, 18.93it/s]
100%|██████████| 649/649 [00:09<00:00, 67.26it/s] 


Epoch 1 loss: 0.36456 0.37489
n kOa.
h

Det' a DeI



100%|██████████| 5834/5834 [05:06<00:00, 19.02it/s]
100%|██████████| 649/649 [00:09<00:00, 67.20it/s] 


Epoch 2 loss: 0.30614 0.30822
 me Pd
 ear gem fism Tour. mor  tI   satesryhunet,. UWwOT
Oyesnedrr   ikdlitive wikK! heWonee  , onlIw ;hiteuwCAthD;  I S  sto Pern sRE,  



100%|██████████| 5834/5834 [05:12<00:00, 18.67it/s]
100%|██████████| 649/649 [00:09<00:00, 67.15it/s] 


Epoch 3 loss: 0.25233 0.27135
 dinl.
    Tirimh me thand's avithes or danteas lre's tharmend fou thimeor jat nen thee Gharlrroref s;ungher neeerpares maye fardreu Rayollld omh reit tivewe lo siwid enir, lon to he wiy the meme mal'g heer]. LO.-ACIOEVunt wo theeserraaetindendclouce diy mideethormar  Cichheile benh cor brendolit ily buho fadb. Aot to mordend hou  ser laco ath bicatra norrenstt igoranth; savesy coos shetre thes pised deyeko ponmeuimd 'mw, Ayoiws wan winse ang cadT 



100%|██████████| 5834/5834 [05:04<00:00, 19.19it/s]
100%|██████████| 649/649 [00:09<00:00, 67.06it/s] 


Epoch 4 loss: 0.23435 0.25567
 pang kke shek trolccony ir preathsy feclien thor;?
  AR woy wiser, aHibuh,
   AVPAd Gilte at i yulis, meard
    IW Pwiken ilccort me yous; maod or tid, i,
      We gand be hole hall at fere, thim; ow sherise's, trome Leyentayd manget an wind ato, to perestouch nom an halk se rind wonot ppuald, it ar noss meast Jowt ik
         may thad diinir bleat cont:, ,
         .
  mosas oanmit.
  Cant memlich? Anlcoven'em hito the hou; kirus serree irs enlelce I wiln soas ath henone hrath fore Of gou to hat-.
  An to ad . I 's mated thu? KUUFES I Shaget Moof gote, ru! Dlay mpot;
       ster toud thow thanteine luprit's;
   Null urlfig wor yis yoont      Yous bond if dedthenr yo fon wir
                            Thebecun gofribith meparnid it lot
s ide ser roEs on seer ir itins bayco, come! millwes to sour for's.        ho woans afauvy thor
   Gnemdirg. EX Dy Lovied gothe sey llame cold chichcireo'gs weat mores sfevey tharchgrarnemnhereen sho de?
  Wame derom the b

100%|██████████| 5834/5834 [05:04<00:00, 19.14it/s]
100%|██████████| 649/649 [00:09<00:00, 67.16it/s] 


Epoch 5 loss: 0.22047 0.24698
     Whou sen!
  SYA. Brptoret hit wate I for sone haide fy! sir maect pely
    wit whausce hoonond a fian sato hik When to no woic, ang Maliny erghe, ing
    Whang thave blead forve sheper. Adhboce thiid pon corly lo; stace as bepon wish she in atry aln, coach mus to nes lispe-ben a for the hale! Bumellabe ang homcetrey me sparenenglor geEnis, meoter Amind
      ansgatl,
   [AMRHORARD. CIE(R  EDE Jacrovothmine dini the turercrerittous ib's, my be
            Th, the,
     Bother thatt mout'dom he is tpims owang Cham gatrir thI. Whid thrleek one. ONUENRSD I. Thild arl muth ulener afe mo are ave poun! Andy dim hy so ige mat hy
    Thid chat covand hirtip'es, cuit buveshy paer coud lady or prele be it titich, of the Mest
     
   Nive bonofruventess.   Swarm twere
    I noption facdele of As, yering mism
  Tleoveras win bestinorly,  [ANCCIANVAUR. SI' Siten tall me fouge- you will as- the, my
     To mapodnpeat ba. I ar swour, satles,
     ske.
  Bu by goshco

100%|██████████| 5834/5834 [05:04<00:00, 19.19it/s]
100%|██████████| 649/649 [00:09<00:00, 67.75it/s] 


Epoch 6 loss: 0.21618 0.24179
 Othentlest thy
    Senchckebou,
  
  If sit dim, jecc! O He thal. ATh.
  ARURHDENUMEDIS DOCTEUH ORD



100%|██████████| 5834/5834 [04:58<00:00, 19.57it/s]
100%|██████████| 649/649 [00:09<00:00, 67.53it/s] 


Epoch 7 loss: 0.21398 0.23767
 Ove wirE. Of my Gigramt me Lean,
    This bry I the for
    Mes, lennot thit sele.



100%|██████████| 5834/5834 [04:57<00:00, 19.64it/s]
100%|██████████| 649/649 [00:09<00:00, 67.51it/s] 


Epoch 8 loss: 0.21068 0.23402
ANTTEMOPCOTUOd elfwono dear sut'd befire ble shat! Fray mrey
                whis. Werst
    To lil she hern hion 'Dentere thous stof GROS. Who
      Ind FAt and hen of hur nees hee sert is. Conl and rorck
   And 'nfich wive the and ip they,
    Whell,
       [Ist, Hesseaush hend lefcee I the, gotcowle cowelo wour buts batcaon,
    To and;
    Hatan afir'd nod                  [And ofuse
    my weass conge
    Loce wors sure ciood.
  
TOLD EURSHIRCOGIBETIOG. Whell of and mamerese my my
    Vowe a stiin?
  ANCAGRIAS
  SIGKORSI.      you beartwerscere and do whige to strece,
    Sire his retly thos her pravinter mave,
    foung wore thull I af bepithad I onfoy me.
  Wil-hit; I comslos chere? A spar you shighim dover;
    hind Sary hove if in mars bay senght fnot my
     Ame
    ERIB. Ot oscey no, and heno, as my thissey; p'onse beht; mworse you.
  WCALA'. Whees the
    You aple,  sword ceeal
    Bulk whoth.  [SEILRIO TIADHDENCANLECOS No san that the antuly i

100%|██████████| 5834/5834 [04:59<00:00, 19.51it/s]
100%|██████████| 649/649 [00:09<00:00, 67.38it/s] 

Epoch 9 loss: 0.20791 0.23083
      To fly,
    Muctertisher you frotitherry dats
     mo?
  BRLLy, bening thath. TRirking ofsey thour.
    creen and whis in
    Cown Leath sorret.
  BOCIVOCNUD BRIANA. At cathe thes;
              he to list wilf seave hasten, harpow nit my yous
    wortore. I limlas,
    toth all wis tE ont stee tou best;
    Tho hit it with'n we? Exook thiy ou swak't deaed ass a dremwast liithee with firse of have pod;
    The be dwight, I dey in theof wald you. Hith othise; fall nofre Haf'd ous so nagres youl leceven, mus fare?




