# Homework 3 - Text generation with LSTM and Transformer networks



## Installs the unidecode library and downloads the Shakespeare dataset.

In [35]:
# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

## LSTM implementation

For this task you will implement the LSTM neural network architecture and train it on the task of character-level text generation. Implement a single layer LSTM and optionally extend your implementation to multiple layers to generate better results.

Links:

- https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html -- Lists the equations for each component of the LSTM cell.
- http://colah.github.io/posts/2015-08-Understanding-LSTMs/ -- Intuitive explanation of LSTM
- http://karpathy.github.io/2015/05/21/rnn-effectiveness/ -- Explanation and uses of RNNs.


Implement the initialization and the forward pass of a LSTMCell and use it as part of the LSTMSimple network class.

The input of the LSTM network will be a sequence of characters, whereas the input of the LSTMCell will be a single input character (x), the output of the previous iteration (C) and the hidden state of the previous iteration (h). Iteratively process the entire input character sequence and calculate the loss based on the prediction at each time step.

### Do NOT use the torch.nn.LSTM class.


In [36]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset

class LSTMCell(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):

        super(LSTMCell, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        # compute the forgetting gate
        self.forgetting_gate = nn.Sequential(
            nn.Linear(self.input_dim+self.hidden_dim, self.output_dim),
            nn.Sigmoid()
        )

        # compute the input gate and new memories
        self.input_gate = nn.Sequential(
            nn.Linear(self.input_dim+self.hidden_dim, self.output_dim),
            nn.Sigmoid()
        )
        self.new_memories = nn.Sequential(
            nn.Linear(self.input_dim+self.hidden_dim, self.output_dim),
            nn.Tanh()
        )

        # compute the new hidden state
        self.new_hidden_mask = nn.Sequential(
            nn.Linear(self.input_dim+self.hidden_dim, self.hidden_dim),
            nn.Sigmoid()
        )
        self.new_hidden_vals = nn.Sequential(
            nn.Linear(self.input_dim+self.hidden_dim, self.hidden_dim),
            nn.Tanh()
        )

    def forward(self, x:torch.tensor, C:torch.tensor, h:torch.tensor):
        # x - batch of encoded characters
        # C - Cell state of the previous iteration
        # h - Hidden state of the previous iteration

        # concat h_t-1 and x 
        hidden_stack = torch.concat((x,h), dim=1)
        
        # calculate forgetting mask
        forgetting_mask = self.forgetting_gate(hidden_stack)
        # forget some C by the mask
        forgotten_C = C * forgetting_mask

        # calculat new memories
        input_mask = self.input_gate(hidden_stack)
        input_vals = self.new_memories(hidden_stack)
        masked_new_vals = input_mask * input_vals


        # modify cell state with new values
        new_C = forgotten_C + masked_new_vals
        
        # calculat new hidden dim
        hiden_mask = self.new_hidden_mask(hidden_stack)
        hiden_vals = self.new_hidden_vals(hidden_stack)
        new_hidden_state = hiden_mask * hiden_vals

        return new_C, new_hidden_state


class LSTMSimple(nn.Module):
    def __init__(self, seq_length, input_dim, hidden_dim, output_dim, batch_size):
        super(LSTMSimple, self).__init__()

        self.seq_length = seq_length
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        self.lstm_cell = LSTMCell(input_dim, hidden_dim, output_dim)
        self.proj = nn.Sequential(
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        # x - One hot encoded batch - Shape: (batch, seq_len,)
        batch, tokens, features = x.shape
        c = torch.zeros((batch, self.output_dim)).cuda()
        h = torch.zeros((batch, self.hidden_dim)).cuda()
        out = torch.zeros((batch, self.seq_length, self.output_dim)).cuda()

        for t in range(tokens):
            x_t = x[:, t, :]
            c, h = self.lstm_cell(x_t,c,h)

            o = self.proj(h)
            out[:, t,:] = o
        
        return  out, (c,h)


        # Returns the predicted next character for each character in the
        # sequence (outputs), also returns the cell state and hidden state of the
        # LSTMCell call on the last character. -- outputs, (c,t)



### LSTM Sampling Code

To generate text the network must predict the next character in a sequence, however networks do not produce a single character but rather estimate the likelihood for each possible character. Sampling characters from the network output can be done in different ways with common ones being the Greedy sampling process and Top-K sampling.

In the simple greedy sampling method the network takes a text prompt as input and generates an additional N tokens by always taking the token with the highest prediction score as the next token.

In the Top-K sampling, randomness is added to the sampling process as the network samples from K most likely predicitons at each step. This alleviates the problem of generative models repeating text but may generate incorrect text by sampling inappropriate tokens.


In [37]:
def greedy_sampling_lstm(lstm, x, num_chars):
    # x -- b x onehot_char
    outputs = torch.zeros((1,num_chars,x.shape[2]))
    t_outputs, (cell_state, hidden) = lstm(x.float())
    for c in range(num_chars):
        output_tmp = torch.softmax(lstm.proj(hidden),dim=1)
        top_ind = torch.argmax(output_tmp,dim=1)[0]
        tmp = torch.zeros_like(x[:,0,:]).cuda()
        tmp[:,top_ind] = 1
        outputs[:,c] = tmp

        cell_state, hidden = lstm.lstm_cell(tmp,cell_state,hidden)
    return outputs

def topk_sampling_lstm(lstm, x, num_chars):
    # x -- b x onehot_char
    outputs = torch.zeros((1,num_chars,x.shape[2]))
    t_outputs, (cell_state, hidden) = lstm(x.float())
    for c in range(num_chars):
        output_vals, output_ind = torch.topk(lstm.proj(hidden), 5, dim=1)
        output_tmp = torch.softmax(output_vals,dim=1)
        top_ind = torch.multinomial(output_tmp[0], 1)[0]
        tmp = torch.zeros_like(x[:,0,:]).cuda()
        tmp[:,output_ind[0,top_ind]] = 1
        outputs[:,c] = tmp

        cell_state, hidden = lstm.lstm_cell(tmp,cell_state,hidden)

    return outputs

### LSTM Dataset Code

In [38]:
import unidecode
import string
import random
from torch.autograd import Variable
from torch.utils.data import Dataset


class LSTMDataset(Dataset):
    def __init__(self, chunk_len=200, padded_chunks=False):
        # Character based dataset
        dataset_path = "./input.txt"
        # The tokens in the vocabulary (all_characters)
        # are just the printable characters of the string class
        self.all_characters = string.printable
        self.n_characters = len(self.all_characters)
        # Maps characters to indices
        self.char_dict = {x:i for i,x in enumerate(self.all_characters)}
        self.file, self.file_len = self.read_file(dataset_path)
        # Sequence length of the input
        self.chunk_len = chunk_len

    def read_file(self,filename):
        file = unidecode.unidecode(open(filename).read())
        return file, len(file)

    def char_tensor(self,in_str):
        # in_str - input sequence - String
        # Return one-hot encoded characters of in_str
        tensor = torch.zeros(len(in_str),self.n_characters).long()
        char_ind = [self.char_dict[c] for c in in_str]
        tensor[torch.arange(tensor.shape[0]),char_ind] = 1
        return tensor

    def __getitem__(self, idx):
        inp, target = self.get_random_text()
        return {"input":inp, "target":target}

    def __len__(self):
        return 10000

    def get_random_text(self):
        # Pick a random string of length self.chunk_len from the dataset
        start_index = np.random.randint(0, self.file_len - self.chunk_len)
        end_index = start_index + self.chunk_len + 1
        chunk = self.file[start_index:end_index]
        # One-hot encode the chosen string
        inp = self.char_tensor(chunk[:-1])
        # The target string is the same as the
        # input string but shifted by 1 character
        target = self.char_tensor(chunk[1:])
        inp = Variable(inp).cuda()
        target = Variable(target).cuda()
        return inp, target


### LSTM Training loop

With a correct implementation you should get sensible text generation results with the set parameters, however you should experiment with various parameters,
especially with the sequence length (chunk_len) used during training.

In [39]:
from tqdm import tqdm
import torch.optim as optim

batch_size = 256
chunk_len = 128
model_name = "LSTM"
train_dataset = LSTMDataset(chunk_len=chunk_len)
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=0, drop_last=True)

#Sample parameters, use whatever you see fit.
input_dim = train_dataset.n_characters
hidden_dim = 256
output_dim = train_dataset.n_characters
learning_rate = 0.005
model = LSTMSimple(chunk_len,input_dim, hidden_dim, output_dim,batch_size)
model.train()
model.cuda()

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

epochs=30

for epoch in range(epochs):
    with tqdm(total=len(trainloader.dataset), desc ='Training - Epoch: '+str(epoch)+"/"+str(epochs), unit='chunks') as prog_bar:
        for i, data in enumerate(trainloader, 0):
            inputs = data['input'].float()
            labels = data['target'].float()
            # b x chunk_len x len(dataset.all_characters)
            optimizer.zero_grad()
            outputs, _ = model(inputs)
            target = torch.argmax(labels,dim=2)

            loss = criterion(outputs.view(inputs.shape[0]*inputs.shape[1],-1),target.view(labels.shape[0]*labels.shape[1]))
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                      max_norm=10.0)
            optimizer.step()
            prog_bar.set_postfix(**{'run:': model_name,'lr': learning_rate,
                                    'loss': loss.item()
                                    })
            prog_bar.update(batch_size)
        # Intermediate output
        sample_text = "O Romeo, wherefore art thou"
        inp = train_dataset.char_tensor(sample_text)
        sample_input = Variable(inp).cuda().unsqueeze(0).float()
        out_test = topk_sampling_lstm(model,sample_input, 300)[0]
        out_char_index = torch.argmax(out_test, dim=1).detach().cpu().numpy()
        out_chars = sample_text+"".join([train_dataset.all_characters[i] for i in out_char_index])
        print("Top-K sampling -----------------")
        print(out_chars)

        out_test = greedy_sampling_lstm(model,sample_input, 300)[0]
        out_char_index = torch.argmax(out_test, dim=1).detach().cpu().numpy()
        out_chars = sample_text+"".join([train_dataset.all_characters[i] for i in out_char_index])
        print("Greedy sampling ----------------")
        print(out_chars)


Training - Epoch: 0/30:   3%|▎         | 256/10000 [00:00<00:08, 1180.97chunks/s, loss=4.62, lr=0.005, run:=LSTM]

Training - Epoch: 0/30: 100%|█████████▉| 9984/10000 [00:08<00:00, 1160.24chunks/s, loss=2.93, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou  he 
erit 
 aan 
h t soun 
ee ser taee 
 mirerert
er sino eo he therese to ne
 hore
 a  e
tae mor 
erd mond
sha s matt to thase thi  e  oet ta ne


o e shas t   a tee
 t end mo d tee  aon   he shan 
ha s tho eon 
ene  ou torot  he done 
hou 
 mar  oen hhat    eer sat eet eoutt e ners
 hert

hone 
a


Training - Epoch: 0/30: 100%|█████████▉| 9984/10000 [00:09<00:00, 1095.26chunks/s, loss=2.93, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou  he the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the


Training - Epoch: 1/30: 100%|█████████▉| 9984/10000 [00:08<00:00, 1161.33chunks/s, loss=2.41, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou wor arine waris int therin sheree this, thet an te the sind are hower the meresteres wall thethar and sord ang andend arethan sathe so moterars irith as aretine shand tarestin singert of he myerint ind wous ar ing wer thit thet mares alit we me the the shot, and woun and thee whis,
I tises ant oule


Training - Epoch: 1/30: 100%|█████████▉| 9984/10000 [00:09<00:00, 1098.32chunks/s, loss=2.41, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the


Training - Epoch: 2/30: 100%|█████████▉| 9984/10000 [00:08<00:00, 1119.57chunks/s, loss=2.24, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thoul singe wate ther thale thond sith me peat se thit thee tere, ther at the state, sor that the sore seer the weat hot the thin ther shat tho ghar houn mate the ware ale that ste this at an the ceat tall tite, tot he me herithast ta the coure,, at in ther and he mongen stiss and toust at this wer whes


Training - Epoch: 2/30: 100%|█████████▉| 9984/10000 [00:09<00:00, 1082.07chunks/s, loss=2.24, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thour the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the th


Training - Epoch: 3/30: 100%|█████████▉| 9984/10000 [00:08<00:00, 1145.89chunks/s, loss=2.1, lr=0.005, run:=LSTM] 

Top-K sampling -----------------
O Romeo, wherefore art thoug seane his thes is not to thou hom heane and shing ho her, ast well the thin shing ate ho hant to the camy tour shat shald sord se shall th then tha ling than th the hom of
With stell wallend, best anghes at oul ser sill we cite to shoredest and the serente, buces inge this seare at ald but him,
Th


Training - Epoch: 3/30: 100%|█████████▉| 9984/10000 [00:09<00:00, 1092.12chunks/s, loss=2.1, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou hat the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the


Training - Epoch: 4/30: 100%|█████████▉| 9984/10000 [00:08<00:00, 1138.05chunks/s, loss=2.02, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art though,
Tis tar hom art think the meandst with mis seate hive the ment to mers and thinest all the price, as in the thee fall,
And hither soull wingst ang she souss thou monged sersed to mond to to boorse songone thee he prates and have and hemp at made that well the cond is heme thit dintend,
In shared


Training - Epoch: 4/30: 100%|█████████▉| 9984/10000 [00:09<00:00, 1071.53chunks/s, loss=2.02, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou dous the sond the sond the sond the sond the sond the sond the sond the sond the sond the sond the sond the sond the sond the sond the sond the sond the sond the sond the sond the sond the sond the sond the sond the sond the sond the sond the sond the sond the sond the sond the sond the sond the so


Training - Epoch: 5/30: 100%|█████████▉| 9984/10000 [00:08<00:00, 1152.61chunks/s, loss=1.91, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art though sene frot make
That and mare of hit fay toun beer, teer a well ta bear my herse were how fraid of
Tid, are bute the parsunce are to hem to sence in thing thee hered, thy hors, wele hath thin me mighter and ald be antere, singhtre siess.

PRONGIO:
A dell of my has a fartull at more thom ar ithor o


Training - Epoch: 5/30: 100%|█████████▉| 9984/10000 [00:09<00:00, 1068.59chunks/s, loss=1.91, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou do the such and the such and the such and the such and the such and the such and the such and the such and the such and the such and the such and the such and the such and the such and the such and the such and the such and the such and the such and the such and the such and the such and the such a


Training - Epoch: 6/30: 100%|█████████▉| 9984/10000 [00:08<00:00, 1069.23chunks/s, loss=1.86, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou dear the know the men the more and wild brith him thy, buldes trust, this sour him,
Will how, be to him seave you so thot she the kinged no the king at you shant thing thou werl with mess that heat,
And thougst of your gave my lay, told thougs are an my lord, wher have shard be my look, and the sea


Training - Epoch: 6/30: 100%|█████████▉| 9984/10000 [00:09<00:00, 1056.04chunks/s, loss=1.86, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou do the stould the sear the sear the sear the sear the sear the sear the sear the sear the sear the sear the sear the sear the sear the sear the sear the sear the sear the sear the sear the sear the sear the sear the sear the sear the sear the sear the sear the sear the sear the sear the sear the se


Training - Epoch: 7/30: 100%|█████████▉| 9984/10000 [00:08<00:00, 1151.00chunks/s, loss=1.79, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou a trount in thy she lurest, my larght to this wanger, thought many to man the right to thy lorst this art the roshed.

PORINIO:
No shall as there is not.

BRANUEN:
What santing is the preest and mith off tares and me brenged there thene hence of the surder, brind,
I have you hishel thise to bear a 


Training - Epoch: 7/30: 100%|█████████▉| 9984/10000 [00:09<00:00, 1085.87chunks/s, loss=1.79, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou have the counters and the stand the stand the stand the stand the stand the stand the stand the stand the stand the stand the stand the stand the stand the stand the stand the stand the stand the stand the stand the stand the stand the stand the stand the stand the stand the stand the stand the sta


Training - Epoch: 8/30: 100%|█████████▉| 9984/10000 [00:08<00:00, 1177.90chunks/s, loss=1.76, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thoughth with mad mest man me to bleader thou dester him.

SAMPLANUS:
They goverest best oneredist onde this to morr thee tome, thou wollors and she tiget that tremest as a manters to be true, and stant the kisterthen to stould, mine and, and that hearter have the caute he wather,
And ther as holestat h


Training - Epoch: 8/30: 100%|█████████▉| 9984/10000 [00:09<00:00, 1105.54chunks/s, loss=1.76, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou dost the sears the sears the sears the sears the sears the sears the sears the sears the sears the sears the sears the sears the sears the sears the sears the sears the sears the sears the sears the sears the sears the sears the sears the sears the sears the sears the sears the sears the sears the 


Training - Epoch: 9/30: 100%|█████████▉| 9984/10000 [00:08<00:00, 1135.52chunks/s, loss=1.7, lr=0.005, run:=LSTM] 

Top-K sampling -----------------
O Romeo, wherefore art thou hash a tord are it to here angrouss, but he would the wints, they hast themess.

GREMIO:
Nay, are to thy were your friath.

MENRES:
Will, that wouss, beiste so,
Be conelless
Think not my this done the rager that wo should.

GOLIO:
I am tithie,
Thou a blee a sirs, there this tounts an mold wing of h


Training - Epoch: 9/30: 100%|█████████▉| 9984/10000 [00:09<00:00, 1099.55chunks/s, loss=1.7, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou have a served and the prove the words and the prove the words and the prove the words and the prove the words and the prove the words and the prove the words and the prove the words and the prove the words and the prove the words and the prove the words and the prove the words and the prove the wor


Training - Epoch: 10/30: 100%|█████████▉| 9984/10000 [00:08<00:00, 1129.42chunks/s, loss=1.67, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou surdion the stourd, and though dead were a trush, to she this arr my sousies some
The stilt and the worncass to deld to do my forench,
If years of she he seevy to then, to shate, sir, will sare,
I will he deed and but him,
That hade the hang to she with his subjeat,
And hath show them here.

SICINI


Training - Epoch: 10/30: 100%|█████████▉| 9984/10000 [00:09<00:00, 1098.51chunks/s, loss=1.67, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou art the promes of the sears of the sears of the sears of the sears of the sears of the sears of the sears of the sears of the sears of the sears of the sears of the sears of the sears of the sears of the sears of the sears of the sears of the sears of the sears of the sears of the sears of the sear


Training - Epoch: 11/30: 100%|█████████▉| 9984/10000 [00:08<00:00, 1122.47chunks/s, loss=1.63, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou speak.

DOKE:
Then you shall briok my so fint sone,
And I have this would stay this wornt a dost bear by this, his could holl he shring the hower to the suffer all by the stations fall it well think this in some stay.

BOUKE:
And shall stay is now stint have hear to-marriteds to to her.

COMINIUS:



Training - Epoch: 11/30: 100%|█████████▉| 9984/10000 [00:09<00:00, 1083.42chunks/s, loss=1.63, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou shall be the state the state the state the state the state the state the state the state the state the state the state the state the state the state the state the state the state the state the state the state the state the state the state the state the state the state the state the state the state 


Training - Epoch: 12/30: 100%|█████████▉| 9984/10000 [00:08<00:00, 1041.50chunks/s, loss=1.57, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thougs all the words as infall once in so, that we hear's my live a toult bus you this as I be to the sif, there is not to bath his for thain, with all my leass if this is thy pray on all to mean to deep,
And so therefild made to-merrion of a sunder time the chereal of all to the promishis
Wasting here 


Training - Epoch: 12/30: 100%|█████████▉| 9984/10000 [00:09<00:00, 1075.19chunks/s, loss=1.57, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou shall be the prove to the prove to the prove to the prove to the prove to the prove to the prove to the prove to the prove to the prove to the prove to the prove to the prove to the prove to the prove to the prove to the prove to the prove to the prove to the prove to the prove to the prove to the 


Training - Epoch: 13/30: 100%|█████████▉| 9984/10000 [00:08<00:00, 1149.20chunks/s, loss=1.57, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou bear.

CORIOLANUS:
Will husble with hen to this way;
I wollors as this my leest,
And make me some, beith of his bodes
Or shills it is athemble this ass a server'd be of this.

BUCKINGHAM:
Than I am a cheath, and be though you well,
As the world and him tell here, though you arms.

KING RICHARD III:


Training - Epoch: 13/30: 100%|█████████▉| 9984/10000 [00:09<00:00, 1072.18chunks/s, loss=1.57, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou shall be the sear the state to the country shall be the sear the state to the country shall be the sear the state to the country shall be the sear the state to the country shall be the sear the state to the country shall be the sear the state to the country shall be the sear the state to the countr


Training - Epoch: 14/30: 100%|█████████▉| 9984/10000 [00:08<00:00, 1182.72chunks/s, loss=1.53, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou hast thee?

GRUMIO:
I am a cortease their honour. But, sirve the shappary an is alouted with my friends of mine hather sauntion we as he wear to me.

GLOUCESTER:
That stander tham I lame the colding and him that they was sour than a word you, men you have shouse that thy hands, tire, my forgors:
Wh


Training - Epoch: 14/30: 100%|█████████▉| 9984/10000 [00:09<00:00, 1087.59chunks/s, loss=1.53, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou hast the rest the rest the king of the state of the sentless the sentent the rest and the state of the sentless the sentent the rest and the state of the sentless the sentent the rest and the state of the sentless the sentent the rest and the state of the sentless the sentent the rest and the state


Training - Epoch: 15/30: 100%|█████████▉| 9984/10000 [00:08<00:00, 1171.53chunks/s, loss=1.52, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou them are he doot me to my barks all the wars, the tone of the past o' the world where's the bardens and talk thee words thou denited that he would treath of his foirings,
She would never be the fitted that had suct the pears there's a bettard, make his some twath to he drubted to me,
I couldon, is 


Training - Epoch: 15/30: 100%|█████████▉| 9984/10000 [00:09<00:00, 1091.23chunks/s, loss=1.52, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou shalt stay the prove to the world the marries the world the marries the world the marries the world the marries the world the marries the world the marries the world the marries the world the marries the world the marries the world the marries the world the marries the world the marries the world t


Training - Epoch: 16/30: 100%|█████████▉| 9984/10000 [00:08<00:00, 1175.99chunks/s, loss=1.49, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou hope thou hish poot to my speeting soul will the corrow thee there.

LUCIO:
I would have stown to be here.

KING HENRY VI:
This is the cried for some sit, and they stand of the wild with hand that she's hath,
To make the childs, and to make heaven, there's the wooness.

PETRUCHIO:
Andess take thin 


Training - Epoch: 16/30: 100%|█████████▉| 9984/10000 [00:09<00:00, 1092.88chunks/s, loss=1.49, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou have a sword and the sund the country's son the command and the country's to the world,
And then the world the command and the world,
And then the world the command and the world,
And then the world the command and the world,
And then the world the command and the world,
And then the world the comm


Training - Epoch: 17/30: 100%|█████████▉| 9984/10000 [00:08<00:00, 1118.92chunks/s, loss=1.5, lr=0.005, run:=LSTM] 

Top-K sampling -----------------
O Romeo, wherefore art thou, is she hath news warther hurband,
And be thou shall fair with his last is some to his honour than thin hath diserved is the baster and the warling stread of the brothers to him home heart of a powarned, whither or endors,
Ans to be a servician and man as deeds of my day!
There were thy first,
But 


Training - Epoch: 17/30: 100%|█████████▉| 9984/10000 [00:09<00:00, 1086.76chunks/s, loss=1.5, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou hast should be so record and soul the command and the country's son the command and the country's son the command and the country's son the command and the country's son the command and the country's son the command and the country's son the command and the country's son the command and the country


Training - Epoch: 18/30: 100%|█████████▉| 9984/10000 [00:08<00:00, 1155.35chunks/s, loss=1.47, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou say,
This is as these stander transof liggly death.

KING RICHARD III:
Make a sweet some, siles,
Why, thou's the surst
To perter the day a hard,
Too hath the dost, thou strittle.

COMINIUS:
Why what all a service here,
I'll there a burged to hath for a must belight this true.

BENVOLIO:
To making h


Training - Epoch: 18/30: 100%|█████████▉| 9984/10000 [00:09<00:00, 1090.38chunks/s, loss=1.47, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou hast the world,
That we may be so the married the strength,
That we may be so the married the strength,
That we may be so the married the strength,
That we may be so the married the strength,
That we may be so the married the strength,
That we may be so the married the strength,
That we may be so t


Training - Epoch: 19/30: 100%|█████████▉| 9984/10000 [00:08<00:00, 1163.73chunks/s, loss=1.45, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou hast me but to-morrow to and the preasing
Agone,
To see thy since anot thou art,--

Richar:
Is thy sound whole that will
To be a most buin to them to mare. The man it so, there, when,
Is thy for thee.

POMPEY:
I was nother.

MISTRENNE:
Who sour heart, my fortune than any as they to
the sends.

COMI


Training - Epoch: 19/30: 100%|█████████▉| 9984/10000 [00:09<00:00, 1092.12chunks/s, loss=1.45, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou art the state,
And there is a soldier to the sunters and the sun the sun the sun the sun the sun the sun the sun the sun the sun the sun the sun the sun the sun the sun the sun the sun the sun the sun the sun the sun the sun the sun the sun the sun the sun the sun the sun the sun the sun the sun th


Training - Epoch: 20/30: 100%|█████████▉| 9984/10000 [00:08<00:00, 1142.16chunks/s, loss=1.46, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou whom thy honour,
To more than we see my sovereign men:
Then, and so soul abood. Thy father and serves how that, I'll stays, and they true monoth to make my possare and the wife,
The brother with the tangard,
In pains. What, if it be talk the tame,
I was the cloods,
And there word made a persead, al


Training - Epoch: 20/30: 100%|█████████▉| 9984/10000 [00:09<00:00, 1073.25chunks/s, loss=1.46, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou hast thou shall be so to the prince the sease the sease the sease the sease the sease the sease the sease the sease the sease the sease the sease the sease the sease the sease the sease the sease the sease the sease the sease the sease the sease the sease the sease the sease the sease the sease the


Training - Epoch: 21/30: 100%|█████████▉| 9984/10000 [00:08<00:00, 1122.01chunks/s, loss=1.42, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou have a should he his down thy fore against a friar,
And nays say that I had
He shall therefore that he his engething that an all.

KING RENCEPEY:
Your changed to be so.

CORIOLANUS:
Have, worsh our hand to speak and toneral
he heard the while I him a words that speching treams;
Frunt it, be not see


Training - Epoch: 21/30: 100%|█████████▉| 9984/10000 [00:09<00:00, 1073.45chunks/s, loss=1.42, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou hast thou the world that the world the sen to the world that the world the sen to the world that the world the sen to the world that the world the sen to the world that the world the sen to the world that the world the sen to the world that the world the sen to the world that the world the sen to t


Training - Epoch: 22/30: 100%|█████████▉| 9984/10000 [00:08<00:00, 1156.81chunks/s, loss=1.43, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou, the wish any sign of a bear how
Or sea strung himself here burn to deep than hears to my fault
Then strokents for his son tell and time of hath find,
The gracessish'd and have much makes a posce last shall have the than trumpets.

LADY ANNE:
Why, hast thou wilt for my counternice of my love;
And t


Training - Epoch: 22/30: 100%|█████████▉| 9984/10000 [00:09<00:00, 1087.58chunks/s, loss=1.43, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou art the words and the good to the world and the good to the world and the good to the world and the good to the world and the good to the world and the good to the world and the good to the world and the good to the world and the good to the world and the good to the world and the good to the world


Training - Epoch: 23/30: 100%|█████████▉| 9984/10000 [00:09<00:00, 1117.74chunks/s, loss=1.4, lr=0.005, run:=LSTM] 

Top-K sampling -----------------
O Romeo, wherefore art thou shalt should break'st brother will to to mine, and the words. Clarities.

KING RICHARD III:
A petit him of the sur,
Ay.
My free my free my signing to thy son were all to my stands
The crown is that would tale her son, I am this better bleath the stame that them hark,
I will their streathest, and wi


Training - Epoch: 23/30: 100%|█████████▉| 9984/10000 [00:09<00:00, 1052.25chunks/s, loss=1.4, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou shalt be so far as the death and the death and the death and the death and the death and the death and the death and the death and the death and the death and the death and the death and the death and the death and the death and the death and the death and the death and the death and the death and 


Training - Epoch: 24/30: 100%|█████████▉| 9984/10000 [00:08<00:00, 1151.48chunks/s, loss=1.4, lr=0.005, run:=LSTM] 

Top-K sampling -----------------
O Romeo, wherefore art thou, and. What conscience a tile some fair and her
shep a shale, that so such me and the conquirence in them;
In that thou wilt a tortal diech of his execution withis
Till house when he stall
Fear you the prince his strember.

Clown:
I will to thy hand,
True time himself.-
Both true my goant.

LUCENTIO


Training - Epoch: 24/30: 100%|█████████▉| 9984/10000 [00:09<00:00, 1098.80chunks/s, loss=1.4, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou art the sun.

CORIOLANUS:
What shall be so fare and the seas and the sun of the world,
And the prince and the sun of the world,
And the prince and the sun of the world,
And the prince and the sun of the world,
And the prince and the sun of the world,
And the prince and the sun of the world,
And the


Training - Epoch: 25/30: 100%|█████████▉| 9984/10000 [00:08<00:00, 1155.44chunks/s, loss=1.38, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou art
True--

KI GoE:
The plood and businems,
Inlest thou, to have you, so long again. I will be so it is son,
So her him to be most good for thy bready:
Thy holours baugh at to my husband
That I will see that him,
I am not there,
We cannot do in thee, are thou have
I would hover a down.

DUKE OF YOR


Training - Epoch: 25/30: 100%|█████████▉| 9984/10000 [00:09<00:00, 1086.69chunks/s, loss=1.38, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou be the world and the sen to the senous and the sen to the senous and the sen to the senous and the sen to the senous and the sen to the senous and the sen to the senous and the sen to the senous and the sen to the senous and the sen to the senous and the sen to the senous and the sen to the senous 


Training - Epoch: 26/30: 100%|█████████▉| 9984/10000 [00:08<00:00, 1164.18chunks/s, loss=1.42, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou, and have you bear his begghence,
Though it is mother and at the sin will not, sister of a scall.

LEONTES:
Tell yet he is, mine eyes of this woold, that I al acts and they so much and what thou hast make me thy more with the precealy it.

CLow:
My brows of continulance, and with many.

LEONTESO:
A


Training - Epoch: 26/30: 100%|█████████▉| 9984/10000 [00:09<00:00, 1103.74chunks/s, loss=1.42, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou with the sun to the world,
And then the great double and a prince the law from the common for the sun
Which is the sun to the world and the secret and the sun to the world,
And then the great double and a prince the law from the common for the sun
Which is the sun to the world and the secret and th


Training - Epoch: 27/30: 100%|█████████▉| 9984/10000 [00:08<00:00, 1166.78chunks/s, loss=1.4, lr=0.005, run:=LSTM] 

Top-K sampling -----------------
O Romeo, wherefore art thou some mine,
And me that said wishom or that he doth.

Second Murderer:
I have done without
The woundly father, a man be say nor his still,
And there that honest that take hearing all myself
Would hath but there?

Provost:
He seem' to me;
And then whom I am part and seen of the service.

CORIOL:
We a


Training - Epoch: 27/30: 100%|█████████▉| 9984/10000 [00:09<00:00, 1089.78chunks/s, loss=1.4, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou have seen the world in the world,
And there is a word and heart that show the prince and so strange of the world,
And there is a word and heart that show the prince and so strange of the world,
And there is a word and heart that show the prince and so strange of the world,
And there is a word and h


Training - Epoch: 28/30: 100%|█████████▉| 9984/10000 [00:08<00:00, 1161.41chunks/s, loss=1.35, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou whom I have tening mine
Whom thou art thou dost thou whom will'st this,
In they are to her:
In the sore to him he they are this itself.

POLIXENES:
That I must break too
Trunt and to thy pains:
If this bear his seen of thee, and that say that I did monest our curse your child; and see thee
I devil 


Training - Epoch: 28/30: 100%|█████████▉| 9984/10000 [00:09<00:00, 1091.81chunks/s, loss=1.35, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou worthy son, and the state and the sun the rest to be married to his court, and the sun the rest to be married to his court, and the sun the rest to be married to his court, and the sun the rest to be married to his court, and the sun the rest to be married to his court, and the sun the rest to be m


Training - Epoch: 29/30: 100%|█████████▉| 9984/10000 [00:08<00:00, 1158.24chunks/s, loss=1.36, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou, Is there soul to plament,
And they do nature
In sometitions.

COMINIUS:
Therefore for him hence taskel and
Thou art.

PETRUCHIO:
You shall play the way.
I call he dowly son are, and by you.

First Catizen:
His wrong'd thy hope,
If you have been a part,
But tell the words, sir!

PAULINA:
No, nake a


Training - Epoch: 29/30: 100%|█████████▉| 9984/10000 [00:09<00:00, 1090.13chunks/s, loss=1.36, lr=0.005, run:=LSTM]

Greedy sampling ----------------
O Romeo, wherefore art thou be the world and the seatent of the world and the seatent of the world and the seatent of the world and the seatent of the world and the seatent of the world and the seatent of the world and the seatent of the world and the seatent of the world and the seatent of the world and the seatent of the wo





# Task 2: Character generation transformer network implementation
Our simple transformer-like network will take as input a sequence of characters and predict the next character in the sequence. To ensure an efficient training procedure, masked attention modules will be used as in the [GPT model](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf).

For this task you must implement the Scaled dot product attention module and the Masked multi-head attention module. Both of these modules are described in the [Attention is all you need](https://arxiv.org/pdf/1706.03762.pdf) paper (See Figure 2 in the paper as well as Sections 3.2.1, 3.2.2 and 3.2.3). They are the core operations of transformers. As we will use our model for text generation also add the masking operation shown as (mask opt.) in Figure 2, implemented as AttentionMasking in the code.

**Implement the modules in the ScaledDotProductAttention class and the MultiHeadAttention class.**

Read the GPT paper and the Attention is all you need paper for a better understanding of the components. For a more high level overview, this [post](https://jalammar.github.io/illustrated-gpt2/) may also be helpful.


In [40]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset
import math
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=1000):
        super().__init__()
        # Positional encoding adds the positional information to the
        # embedding. Without it the model would not be able to differentiate
        # between different characters orders such as between "dog" and "god".
        position = torch.arange(max_len).unsqueeze(1).float()
        div_term = 10000.0**(torch.arange(0,d_model,2).float()/d_model)
        print(div_term.shape)
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position / div_term)
        pe[:, 1::2] = torch.cos(position / div_term)
        pe = pe.unsqueeze(0)
        self.pe = pe.cuda()
        self.pe.requires_grad = False

    def forward(self, x):
        p = self.pe[:, :x.size(1)]
        return p

class AttentionMasking(nn.Module):
    def __init__(self, max_len):
        super(AttentionMasking, self).__init__()
        self.register_buffer("mask", torch.tril(torch.ones(max_len, max_len))
                                     .view(1, 1, max_len, max_len))
    def forward(self,x):
        length = x.shape[-1]
        out = x.masked_fill(self.mask[:,:,:length,:length] == 0, float('-inf'))
        return out


class ScaledDotProductAttention(nn.Module):
    def __init__(self, max_len):
        super(ScaledDotProductAttention, self).__init__()
        self.softmax = nn.Softmax(dim=-1)
        # Multiply with an upper triangular
        # matrix of dimensions (length x length) after the scale operation
        # in Figure 2 of the paper.
        self.mask_opt = AttentionMasking(max_len)


    def forward(self,q,k,v):
      # length = number of input tokens
      batch_size,num_heads,length,num_neuron = k.size()
      # TODO: Implement the scaled dot product attention as described in
      # the Attention is all you need paper in Equation 1
      pass


class MultiHeadAttention(nn.Module):
    def __init__(self, dim_model, num_neuron, n_head, max_len):
        super(MultiHeadAttention, self).__init__()
        self.n_head = n_head
        self.num_neuron = num_neuron

        # TODO: Initialize the ScaledDotProductAttention and other
        # necessary components.

    def split(self,tensor):
        batch_size, length, total_dim = tensor.size()
        # Reshape the tensor to enable the use in
        # the ScaledDotProductAttention module
        split_tensor = tensor.view(batch_size, length, self.n_head, self.num_neuron).transpose(1,2)
        return split_tensor

    def concat(self,tensor):
        batch_size, num_heads, length, num_neuron = tensor.size()
        # Reshape the tensor to its original size before the split operation.
        concat_tensor = tensor.transpose(1,2).contiguous().view(batch_size, length, self.n_head*self.num_neuron)
        return concat_tensor


    def forward(self, q, k, v):
        # TODO: Implement the Masked Multi-head attention module as described in the
        # Attention is all you need paper in Figure 1 and Section 3.2.2.
        pass



class PositionFeedForwardNet(nn.Module):
    def __init__(self, dim_model):
        super(PositionFeedForwardNet, self).__init__()
        self.ff_net1 = nn.Linear(dim_model, dim_model*4)
        self.ff_net2 = nn.Linear(dim_model*4, dim_model)
    def forward(self,x):
        ff_out = self.ff_net1(x)
        ff_out = torch.nn.functional.relu(ff_out)
        ff_out = self.ff_net2(ff_out)
        return ff_out


class TransformerBlock(nn.Module):
    def __init__(self, dim_model, num_neuron, n_head, max_len):
        super(TransformerBlock, self).__init__()
        self.mha = MultiHeadAttention(dim_model, num_neuron, n_head, max_len)
        self.l_norm = torch.nn.LayerNorm(dim_model)
        self.l_norm2 = torch.nn.LayerNorm(dim_model)
        self.ff_net = PositionFeedForwardNet(dim_model)
        # b, len_seq, n_head, num_neuron

    def forward(self, x):
      # A Transformer block as described in the
      # Attention is all you need paper. In Figure 1 the transformer
      # block is marked with a gray rectangle right of the text "Nx"
      _x = x
      mha1 = self.mha(x,x,x)
      lnorm = self.l_norm(_x+mha1)
      _x = lnorm
      ff_out = self.ff_net(lnorm)
      out = self.l_norm2(ff_out+_x)

      return out

class TransformerSimple(nn.Module):
    def __init__(self, seq_length, input_dim, output_dim,
                 batch_size):
        super(TransformerSimple, self).__init__()
        num_neuron = 64
        n_head = 8
        dim_model=256
        max_len = 512
        self.start_embedding = nn.Embedding(input_dim, dim_model)

        self.pos_embedding = PositionalEncoding(dim_model)

        # b x l x c*n_head
        self.t_block1 = TransformerBlock(dim_model, num_neuron, n_head, max_len)
        self.t_block2 = TransformerBlock(dim_model, num_neuron, n_head, max_len)
        self.t_block3 = TransformerBlock(dim_model, num_neuron, n_head, max_len)
        self.t_block4 = TransformerBlock(dim_model, num_neuron, n_head, max_len)
        self.t_block5 = TransformerBlock(dim_model, num_neuron, n_head, max_len)

        #self.out_layer_1 = nn.Linear(dim_model, dim_model)
        self.output_layer = nn.Linear(dim_model,output_dim)

    def forward(self,x):
      # x - Tensor - (b, seq_len)
      # Embeds the input tensor from tokens to features
      s_emb = self.start_embedding(x)
      # Adds positional embeddings
      p_emb = self.pos_embedding(s_emb)
      b_out = p_emb + s_emb
      # Transformer blocks - You can experiment with varying depth
      # For example GPT uses 12 blocks but this might be a bit memory intensive
      b_out = self.t_block1(b_out)
      b_out = self.t_block2(b_out)
      b_out = self.t_block3(b_out)
      b_out = self.t_block4(b_out)
      b_out = self.t_block5(b_out)

      # Output mapping to a classification of output tokens
      # For each token the network tries to predict the next token
      # based only on the previous tokens.
      # Output shape: (b x seq_len x vocabulary_size)
      out = self.output_layer(b_out)

      return out






## Dataset class


In [41]:
import unidecode
import string
import random
from torch.autograd import Variable
from torch.utils.data import Dataset

class TextDataset(Dataset):
    def __init__(self, chunk_len=200, padded_chunks=False):
        # Character based dataset
        dataset_path = "./input.txt"
        # The tokens in the vocabulary (all_characters)
        # are just the printable characters of the string class
        self.all_characters = string.printable
        self.n_characters = len(self.all_characters)
        # Maps characters to indices
        self.char_dict = {x:i for i,x in enumerate(self.all_characters)}
        self.file, self.file_len = self.read_file(dataset_path)
        # Sequence length of the input
        self.chunk_len = chunk_len
        self.encoded_file = [self.char_dict[x] for x in self.file]

    def read_file(self,filename):
        file = unidecode.unidecode(open(filename).read())
        return file, len(file)

    def encode_text(self,in_str):
        # in_str - input sequence - String
        # Returns - in_str mapped to tokens in char_dict
        tensor = torch.LongTensor([self.char_dict[x] for x in in_str])
        return tensor

    def __getitem__(self, idx):
        inp, target = self.get_random_text()
        return {"input":inp, "target":target}

    def __len__(self):
        return 10000

    def get_random_text(self):
        # Pick a random string of length self.chunk_len from the dataset
        start_index = np.random.randint(0, self.file_len - self.chunk_len)
        end_index = start_index + self.chunk_len + 1
        chunk = self.encoded_file[start_index:end_index]
        # input_tokens - random sequence of tokens from the dataset
        input_tokens = torch.LongTensor(chunk[:-1])
        # target - input token sequence shifted by 1
        # the idea is to predict next token for each token in the input sequence
        # therefore if the input is [1,2,3,4] the target is [2,3,4,5]
        target = torch.LongTensor(chunk[1:])
        input_tokens = input_tokens.cuda()
        target = target.cuda()
        return input_tokens, target


## Character sampling

To generate text the network must predict the next character in a sequence, however networks do not produce a single character but rather estimate the likelihood for each possible character. Sampling characters from the network output can be done in different ways with common ones being the Greedy sampling process and Top-K sampling.

In the simple greedy sampling method the network takes a text prompt as input and generates an additional N tokens by always taking the token with the highest prediction score as the next token.

In the Top-K sampling, randomness is added to the sampling process as the network samples from K most likely predicitons at each step. This alleviates the problem of generative models repeating text but may generate incorrect text by sampling inappropriate tokens.


In [42]:
def topk_sampling_iter_transformer(model, x, num_chars, chunk_len, output_token):
    # x -- b x onehot_char
    # x = b x l
    outputs = torch.zeros((1,num_chars))
    inp = x

    for t in range(num_chars):
        # b x onehot_char
        output = model(inp.long())[0,-1:]
        #output = torch.softmax(output, dim=1)
        # b x 3
        output_vals, output_ind = torch.topk(output, 5, dim=1)
        # 3 -> int
        output_vals = torch.softmax(output_vals, dim=1)
        top_ind = torch.multinomial(output_vals[0], 1)[0]
        # int
        out_char_index = output_ind[0,top_ind]
        # int -> 1
        out_char_index = torch.ones(1).cuda() * out_char_index

        outputs[:,t] = out_char_index.item()
        if inp.shape[1] > chunk_len:
          inp = torch.cat((inp[:,1:], out_char_index.unsqueeze(0)), dim=1)
        else:
          inp = torch.cat((inp, out_char_index.unsqueeze(0)), dim=1)

    return outputs


def greedy_sampling_iter_transformer(model, x, num_chars, chunk_len, output_token):
    # x -- shape (batch, tokens in x)
    outputs = torch.zeros((1,num_chars))
    inp = x

    for t in range(num_chars):
        # b x l x onehot_char
        output = model(inp.long())[0,-1:]
        output = torch.softmax(output, dim=1)
        out_char_index = torch.argmax(output, dim=1)
        outputs[:,t] = out_char_index.item()
        if inp.shape[1] > chunk_len:
          inp = torch.cat((inp[:,1:], out_char_index.unsqueeze(0)), dim=1)
        else:
          inp = torch.cat((inp, out_char_index.unsqueeze(0)), dim=1)

    return outputs






## Transformer model training

With a correct implementation you should get sensible text generation results with the set parameters, however you should experiment with various parameters,
especially with the sequence length (chunk_len) used during training.

In [43]:
from tqdm import tqdm
import torch.optim as optim


#Sample parameters, use whatever you see fit.
batch_size = 256
chunk_len = 128
train_dataset = TextDataset(chunk_len=chunk_len)
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=0)

input_dim = train_dataset.n_characters
output_dim = train_dataset.n_characters
learning_rate = 0.0006

model = TransformerSimple(chunk_len, input_dim, output_dim,batch_size)
model.train()
model.cuda()

criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters(), lr=learning_rate)

epochs=50

for epoch in range(epochs):
    with tqdm(total=len(trainloader.dataset), desc ='Training - Epoch: '+str(epoch)+"/"+str(epochs), unit='chunks') as prog_bar:
        for i, data in enumerate(trainloader, 0):
            # inputs - shape (batch_size, chunk_len) - Tensor of vocabulary tokens
            inputs = data['input'].long()
            # labels - shape (batch_size, chunk_len) - Tensor of vocabulary tokens
            labels = data['target'].long()

            optimizer.zero_grad()
            outputs = model(inputs)
            target_t = labels
            loss = criterion(outputs.view(inputs.shape[0]*inputs.shape[1],-1),target_t.view(labels.shape[0]*labels.shape[1]))
            loss.backward()
            optimizer.step()
            prog_bar.set_postfix(**{'run:': "Transformer", 'lr': learning_rate,
                                    'loss': loss.item()
                                    })
            prog_bar.update(batch_size)

        # Intermediate text output
        sample_texts = ["What authority surfeits on",
                        "I say unto you, what he hath done famously, he did it to that end:",
                        "That in submission will return to us: And then, as we have ta'en the sacrament,"]
        output_token = torch.zeros(1,1).cuda()
        output_token[0,0] = train_dataset.n_characters-1
        print("Top-K sampling")
        for sample_text in sample_texts:
            sample_encoding = train_dataset.encode_text(sample_text)
            sample_input = Variable(sample_encoding).cuda().unsqueeze(0).long()

            #out_test= greedy_sampling_iter_transformer(model, sample_input, 400, chunk_len, output_token)[0]
            out_test= topk_sampling_iter_transformer(model, sample_input, 400, chunk_len, output_token)[0]
            out_char_index = out_test.long().detach().cpu().numpy()
            out_chars = sample_text+" "+"".join([train_dataset.all_characters[i] for i in out_char_index])
            print("----------------------------------------")
            print(out_chars)




torch.Size([128])


Training - Epoch: 0/50:   0%|          | 0/10000 [00:00<?, ?chunks/s]


TypeError: unsupported operand type(s) for +: 'Tensor' and 'NoneType'

## Text sampling - Transformers


In [None]:
sample_text = "Here's to my love! O true apothecary! Thy drugs are quick."
sample_encoding = train_dataset.encode_text(sample_text)
sample_input = Variable(sample_encoding).cuda().unsqueeze(0).long()

#out_test= greedy_sampling_iter_transformer(model, sample_input, 400, chunk_len, output_token)[0]
out_test= topk_sampling_iter_transformer(model, sample_input, 400, chunk_len, output_token)[0]
out_char_index = out_test.long().detach().cpu().numpy()
out_chars = sample_text+" "+"".join([train_dataset.all_characters[i] for i in out_char_index])
print("----------------------------------------")
print(out_chars)


: 