# Generative networks

Recurrent Neural Networks (RNNs) and their gated variants, such as Long Short-Term Memory (LSTM) cells and Gated Recurrent Units (GRUs), provide a mechanism for language modeling: they can learn word order and predict the next word in a sequence. This capability allows us to use RNNs for **generative tasks**, such as ordinary text generation, machine translation, and even image captioning.

In the RNN architecture we discussed in the previous unit, each RNN unit produced the next hidden state as its output. However, we can also add another output to each recurrent unit, enabling us to generate a **sequence** (of the same length as the original sequence). Additionally, we can use RNN units that do not take an input at every step but instead start with an initial state vector and then generate a sequence of outputs.

In this notebook, we will focus on simple generative models that help us create text. To keep things straightforward, we will build a **character-level network**, which generates text one letter at a time. During training, we will take a text corpus and split it into sequences of letters.


In [1]:
import collections  # used below to count characters
import torch
import torchtext
import numpy as np
from torchnlp import *
train_dataset,test_dataset,classes,vocab = load_dataset()

Loading dataset...
Building vocab...


## Building character vocabulary

To create a character-level generative network, the text needs to be divided into individual characters rather than words. This can be achieved by specifying a different tokenizer:


In [2]:
def char_tokenizer(words):
    return list(words)  # split the string into individual characters

counter = collections.Counter()
for (label, line) in train_dataset:
    counter.update(char_tokenizer(line))
vocab = torchtext.vocab.vocab(counter)

vocab_size = len(vocab)
print(f"Vocabulary size = {vocab_size}")
print(f"Encoding of 'a' is {vocab.get_stoi()['a']}")
print(f"Character with code 13 is {vocab.get_itos()[13]}")

Vocabulary size = 82
Encoding of 'a' is 1
Character with code 13 is c
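
To make the idea of a character vocabulary concrete without `torchtext`, here is a minimal standard-library sketch. The helper `build_char_vocab` and the two sample strings are hypothetical and only illustrate the two mappings (character to index and index to character); this is not how `torchtext` is implemented internally.

```python
import collections

def build_char_vocab(lines):
    """Return (stoi, itos) built from the characters seen in `lines`."""
    counter = collections.Counter()
    for line in lines:
        counter.update(list(line))  # one token per character
    # Counter keeps first-seen order, so indices follow the order of appearance
    itos = list(counter)
    stoi = {ch: i for i, ch in enumerate(itos)}
    return stoi, itos

stoi, itos = build_char_vocab(["hello world", "hold on"])
encoded = [stoi[ch] for ch in "hello"]
print(len(itos))                          # number of distinct characters
print(encoded)                            # indices of the characters in "hello"
print("".join(itos[i] for i in encoded))  # decoding restores the text
```

Decoding with `itos` is the same lookup that the generation code later performs to turn a predicted index back into a character.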


Here is how a text from our dataset is encoded:


In [3]:
def enc(x):
    return torch.LongTensor(encode(x,voc=vocab,tokenizer=char_tokenizer))

enc(train_dataset[0][1])

tensor([ 0,  1,  2,  2,  3,  4,  5,  6,  3,  7,  8,  1,  9, 10,  3, 11,  2,  1,
        12,  3,  7,  1, 13, 14,  3, 15, 16,  5, 17,  3,  5, 18,  8,  3,  7,  2,
         1, 13, 14,  3, 19, 20,  8, 21,  5,  8,  9, 10, 22,  3, 20,  8, 21,  5,
         8,  9, 10,  3, 23,  3,  4, 18, 17,  9,  5, 23, 10,  8,  2,  2,  8,  9,
        10, 24,  3,  0,  1,  2,  2,  3,  4,  5,  9,  8,  8,  5, 25, 10,  3, 26,
        12, 27, 16, 26,  2, 27, 16, 28, 29, 30,  1, 16, 26,  3, 17, 31,  3, 21,
         2,  5,  9,  1, 23, 13, 32, 16, 27, 13, 10, 24,  3,  1,  9,  8,  3, 10,
         8,  8, 27, 16, 28,  3, 28,  9,  8,  8, 16,  3,  1, 28,  1, 27, 16,  6])

## Training a generative RNN

The method we will use to train the RNN to generate text works as follows. At each step, we will take a sequence of characters of length `nchars` and ask the network to predict the next output character for each input character:

![Image showing an example RNN generation of the word 'HELLO'.](../../../../../translated_images/rnn-generate.56c54afb52f9781d63a7c16ea9c1b86cb70e6e1eae6a742b56b7b37468576b17.en.png)

Depending on the specific use case, we might also want to include special characters, such as *end-of-sequence* `<eos>`. In our case, we aim to train the network for continuous text generation, so we will set the size of each sequence to a fixed length of `nchars` tokens. As a result, each training example will consist of `nchars` inputs and `nchars` outputs (where the output sequence is the input sequence shifted one character to the left). A minibatch will contain several such sequences.

To generate minibatches, we will take each news text of length `l` and create all possible input-output combinations from it (there will be `l-nchars` such combinations). These combinations will form one minibatch, and the size of the minibatches will vary at each training step.


In [4]:
nchars = 100

def get_batch(s,nchars=nchars):
    ins = torch.zeros(len(s)-nchars,nchars,dtype=torch.long,device=device)
    outs = torch.zeros(len(s)-nchars,nchars,dtype=torch.long,device=device)
    for i in range(len(s)-nchars):
        ins[i] = enc(s[i:i+nchars])
        outs[i] = enc(s[i+1:i+nchars+1])
    return ins,outs

get_batch(train_dataset[0][1])

(tensor([[ 0,  1,  2,  ..., 28, 29, 30],
         [ 1,  2,  2,  ..., 29, 30,  1],
         [ 2,  2,  3,  ..., 30,  1, 16],
         ...,
         [20,  8, 21,  ...,  1, 28,  1],
         [ 8, 21,  5,  ..., 28,  1, 27],
         [21,  5,  8,  ...,  1, 27, 16]]),
 tensor([[ 1,  2,  2,  ..., 29, 30,  1],
         [ 2,  2,  3,  ..., 30,  1, 16],
         [ 2,  3,  4,  ...,  1, 16, 26],
         ...,
         [ 8, 21,  5,  ..., 28,  1, 27],
         [21,  5,  8,  ...,  1, 27, 16],
         [ 5,  8,  9,  ..., 27, 16,  6]]))
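
The windowing logic is easier to follow on plain strings. The sketch below uses a hypothetical helper `make_pairs` and a toy text to show that a text of length `l` yields `l-nchars` pairs, and that each target is the input shifted by one character. It mirrors the slicing in `get_batch` but skips the encoding step.

```python
def make_pairs(text, nchars):
    """Return all (input, target) windows; the target is the input shifted by one character."""
    pairs = []
    for i in range(len(text) - nchars):
        pairs.append((text[i:i + nchars], text[i + 1:i + nchars + 1]))
    return pairs

pairs = make_pairs("hello world", nchars=5)
for src, tgt in pairs[:3]:
    print(repr(src), "->", repr(tgt))
print(len(pairs))  # len("hello world") - 5 windows
```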

Now let's define the generator network. It can be based on any recurrent cell we discussed in the previous unit (simple, LSTM, or GRU). In our example, we will use LSTM.

Because the network takes characters as input and the vocabulary is relatively small, we don't need an embedding layer; one-hot-encoded input can be fed directly into the LSTM cell. Since we pass character indices to the network, we convert them to one-hot vectors by calling the `one_hot` function inside the `forward` pass. The output layer is a linear layer that maps each hidden state to a vector of `vocab_size` unnormalized scores (logits), one per character.


In [5]:
class LSTMGenerator(torch.nn.Module):
    def __init__(self, vocab_size, hidden_dim):
        super().__init__()
        self.rnn = torch.nn.LSTM(vocab_size,hidden_dim,batch_first=True)
        self.fc = torch.nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, s=None):
        x = torch.nn.functional.one_hot(x,vocab_size).to(torch.float32)
        x,s = self.rnn(x,s)
        return self.fc(x),s
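
One way to see why one-hot input can stand in for an embedding layer: multiplying a one-hot vector by a weight matrix simply selects one row of that matrix. The sketch below demonstrates this in plain Python. The helpers `one_hot` and `vec_mat` and the weight values are made up for illustration and are not part of the notebook.

```python
def one_hot(index, size):
    """Return a list of `size` zeros with a single 1.0 at position `index`."""
    vec = [0.0] * size
    vec[index] = 1.0
    return vec

def vec_mat(vec, matrix):
    """Multiply a row vector by a matrix given as a list of rows."""
    cols = len(matrix[0])
    return [sum(vec[r] * matrix[r][c] for r in range(len(matrix))) for c in range(cols)]

# 3 characters, 2 hidden features (made-up numbers)
weights = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
x = one_hot(2, size=3)
print(x)
print(vec_mat(x, weights))  # equals the row of `weights` for character 2
```

In this sense, the input weights of the first recurrent layer play the role that a separate embedding table would otherwise play.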

During training, we want to be able to sample generated text. To achieve this, we will define the `generate` function, which will produce an output string of length `size`, starting from the initial string `start`.

Here’s how it works: first, we pass the entire start string through the network, obtaining the output state `s` and the predictions `out`. The last row of `out` holds one score per vocabulary character for the next position, so we use `argmax` to find the index `nc` of the most likely character. Then we use `itos` to look up the actual character and append it to the list `chars`. We feed `nc` back into the network together with the state `s`, and repeat this step `size` times to produce the desired number of characters.


In [8]:
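def generate(net,size=100,start='today '):
    chars = list(start)
    # run the whole start string through the network to get the initial state
    out, s = net(enc(chars).view(1,-1).to(device))
    for i in range(size):
        # pick the highest-scoring character at the last position
        nc = torch.argmax(out[0][-1])
        chars.append(vocab.get_itos()[nc])
        # feed the chosen character back in, carrying the state forward
        out, s = net(nc.view(1,-1),s)
    return ''.join(chars)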
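
The greedy decoding loop can be tried without a trained network. The toy below is only a stand-in: the `SCORES` table is invented, and it conditions on the previous character alone, whereas the real model carries a hidden state. It still shows the feed-back pattern of "take the `argmax`, append, repeat".

```python
# Hypothetical next-character scores keyed by the previous character only.
SCORES = {
    "t": {"h": 2.0, "o": 1.5},
    "h": {"e": 3.0, "a": 1.0},
    "e": {" ": 2.5, "r": 2.0},
    " ": {"t": 2.0, "a": 1.0},
}

def greedy_generate(start, size):
    """Append `size` characters, always choosing the highest-scoring candidate."""
    chars = list(start)
    for _ in range(size):
        scores = SCORES[chars[-1]]
        chars.append(max(scores, key=scores.get))  # argmax over candidates
    return "".join(chars)

print(repr(greedy_generate("t", 11)))
```

Because the choice is deterministic, the toy quickly falls into a cycle.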

Now let's start the training! The training loop is almost identical to what we've used in previous examples, but instead of accuracy, we print the current loss and a sample of generated text every 1000 training samples.

Pay special attention to how we calculate the loss. It is computed from the raw network output `out` (unnormalized scores for each character) and the expected text `text_out`, which is a tensor of character indices. Fortunately, the `cross_entropy` function takes the unnormalized network output as its first argument and the class number as its second argument, which is exactly what we need. It also averages over all predictions in the minibatch by default.
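
As a rough illustration of what this loss computes, here is a pure-Python sketch: for each row of unnormalized scores it applies a log-softmax, takes the negative log-probability of the target index, and averages over the rows. The function name and the numbers are illustrative only. In the notebook the same calculation is done by PyTorch on the flattened tensors, with one row per character position.

```python
import math

def cross_entropy(logits, targets):
    """Mean negative log-likelihood of `targets` under softmax(logits)."""
    total = 0.0
    for row, target in zip(logits, targets):
        m = max(row)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[target]
    return total / len(targets)

# two positions, three possible characters (made-up scores)
logits = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
targets = [0, 2]
print(cross_entropy(logits, targets))
```

For a row of equal scores the loss is `log(vocab_size)`. This is why an untrained network over 82 characters starts at a loss of about `log(82) ≈ 4.4`.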

To avoid waiting too long, we limit the training to `samples_to_train` samples. We encourage you to experiment and try longer training sessions, possibly spanning several epochs (in which case you'd need to create an additional loop around this code).


In [9]:
net = LSTMGenerator(vocab_size,64).to(device)

samples_to_train = 10000
optimizer = torch.optim.Adam(net.parameters(),lr=0.01)
net.train()
for i,x in enumerate(train_dataset):
    # x[0] is class label, x[1] is text
    if len(x[1])-nchars<10:
        continue
    samples_to_train-=1
    if not samples_to_train: break
    text_in, text_out = get_batch(x[1])
    optimizer.zero_grad()
    out,s = net(text_in)
    # flatten to (batch*nchars, vocab_size) scores and (batch*nchars,) target indices
    loss = torch.nn.functional.cross_entropy(out.view(-1,vocab_size),text_out.flatten())
    loss.backward()
    optimizer.step()
    if i%1000==0:
        print(f"Current loss = {loss.item()}")
        print(generate(net))

Current loss = 4.398899078369141
today sr sr sr sr sr sr sr sr sr sr sr sr sr sr sr sr sr sr sr sr sr sr sr sr sr sr sr sr sr sr sr sr sr s
Current loss = 2.161320447921753
today and to the tor to to the tor to to the tor to to the tor to to the tor to to the tor to to the tor t
Current loss = 1.6722588539123535
today and the court to the could to the could to the could to the could to the could to the could to the c
Current loss = 2.423795223236084
today and a second to the conternation of the conternation of the conternation of the conternation of the 
Current loss = 1.702607274055481
today and the company to the company to the company to the company to the company to the company to the co
Current loss = 1.692358136177063
today and the company to the company to the company to the company to the company to the company to the co
Current loss = 1.9722288846969604
today and the control the control the control the control the control the control the control the control 
Current loss = 1.8

This example already generates some pretty good text, but there are several ways to improve it further:

* **Better minibatch generation**. The way we prepared data for training was to generate one minibatch from one sample. This approach is not ideal because minibatches end up being of different sizes, and some cannot even be generated if the text is smaller than `nchars`. Additionally, small minibatches do not utilize the GPU efficiently. A better approach would be to take one large chunk of text from all samples, generate all input-output pairs, shuffle them, and then create minibatches of equal size.

* **Multilayer LSTM**. It might be worth trying 2 or 3 layers of LSTM cells. As mentioned in the previous section, each LSTM layer extracts specific patterns from the text. In the case of a character-level generator, the lower LSTM layers might focus on extracting syllables, while the higher layers could handle words and word combinations. This can be easily implemented by passing a number-of-layers parameter to the LSTM constructor.

* You could also experiment with **GRU units** to see which performs better, as well as with **different hidden layer sizes**. A hidden layer that is too large might lead to overfitting (e.g., the network memorizing the exact text), while a smaller size might not yield good results.
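
The first suggestion, better minibatch generation, can be sketched as follows. This is one possible approach under stated assumptions, not the course's implementation. The helper `make_batches` is hypothetical, texts are joined with a single space (so some windows span document boundaries), and the incomplete final batch is dropped so that all batches have the same size.

```python
import random

def make_batches(texts, nchars, batch_size, seed=0):
    """Build shuffled, equal-size batches of (input, target) windows from all texts."""
    corpus = " ".join(texts)  # assumption: one long chunk, texts separated by a space
    pairs = [(corpus[i:i + nchars], corpus[i + 1:i + nchars + 1])
             for i in range(len(corpus) - nchars)]
    random.Random(seed).shuffle(pairs)
    usable = len(pairs) - len(pairs) % batch_size  # drop the incomplete last batch
    return [pairs[i:i + batch_size] for i in range(0, usable, batch_size)]

batches = make_batches(["first short text", "second text", "third"],
                       nchars=8, batch_size=4)
print(len(batches), [len(b) for b in batches])
```

Each pair would still need to be encoded with `enc` and stacked into tensors before being passed to the network.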


## Soft text generation and temperature

In the previous definition of `generate`, we always selected the character with the highest probability as the next character in the generated text. This often caused the text to "loop" through the same character sequences repeatedly, as shown in this example:
```
today of the second the company and a second the company ...
```

However, if we examine the probability distribution for the next character, we may find that the top probabilities are close to each other: one character might have a probability of 0.2, another 0.19, and so on. For example, after the sequence '*play*', the next character could equally well be a space or **e** (as in the word *player*).

This brings us to the realization that it isn't always "fair" to pick the character with the highest probability, as choosing the second-highest might still result in meaningful text. A better approach is to **sample** characters from the probability distribution provided by the network's output.
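
The difference between picking the maximum and sampling can be seen with a small made-up distribution. The characters and probabilities below are hypothetical, and the standard library's `random.choices` stands in for the sampling step.

```python
import random

chars = [" ", "e", "s", "f"]
probs = [0.40, 0.35, 0.15, 0.10]  # hypothetical next-character probabilities

# Greedy choice: always the single most likely character
greedy = chars[probs.index(max(probs))]

# Sampling: each character is drawn in proportion to its probability
rng = random.Random(0)  # fixed seed so the run is repeatable
samples = rng.choices(chars, weights=probs, k=1000)
counts = {ch: samples.count(ch) for ch in chars}

print(repr(greedy))
print(counts)
```

With greedy selection the near-tied runner-up is never chosen, whereas sampling picks it almost as often as the top candidate.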

This sampling can be done with the `torch.multinomial` function, which draws indices at random according to the weights it is given (a **multinomial distribution**). Below is a function that performs this **soft** text generation:


In [10]:
def generate_soft(net,size=100,start='today ',temperature=1.0):
    chars = list(start)
    out, s = net(enc(chars).view(1,-1).to(device))
    for i in range(size):
        # scale the scores by the temperature and exponentiate to get
        # unnormalized probabilities (torch.multinomial accepts weights)
        out_dist = out[0][-1].div(temperature).exp()
        nc = torch.multinomial(out_dist,1)[0]
        chars.append(vocab.get_itos()[nc])
        out, s = net(nc.view(1,-1),s)
    return ''.join(chars)

for t in [0.3,0.8,1.0,1.3,1.8]:
    print(f"--- Temperature = {t}\n{generate_soft(net,size=300,start='Today ',temperature=t)}\n")

--- Temperature = 0.3
Today and a company and complete an all the land the restrational the as a security and has provers the pay to and a report and the computer in the stand has filities and working the law the stations for a company and with the company and the final the first company and refight of the state and and workin

--- Temperature = 0.8
Today he oniis its first to Aus bomblaties the marmation a to manan  boogot that pirate assaid a relaid their that goverfin the the Cappets Ecrotional Assonia Cition targets it annight the w scyments Blamity #39;s TVeer Diercheg Reserals fran envyuil that of ster said access what succers of Dour-provelith

--- Temperature = 1.0
Today holy they a 11 will meda a toket subsuaties, engins for Chanos, they's has stainger past to opening orital his thempting new Nattona was al innerforder advan-than #36;s night year his religuled talitatian what the but with Wednesday to Justment will wemen of Mark CCC Camp as Timed Nae wome a leaders

--- Temper

We have introduced one more parameter called **temperature**, which controls how strongly we stick to the highest-probability character. If the temperature is 1.0, we do fair multinomial sampling; as the temperature goes to infinity, all probabilities become equal and the next character is selected uniformly at random. In the example above, we can observe that the text becomes meaningless when we increase the temperature too much, and it resembles the "cycled" hard-generated text when the temperature gets closer to 0.
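
The effect of temperature on a distribution can be checked numerically. The sketch below divides a few made-up scores by the temperature and then applies a softmax. The function name `softmax_with_temperature` is illustrative. In `generate_soft` the explicit normalization is not needed, because `torch.multinomial` accepts unnormalized weights.

```python
import math

def softmax_with_temperature(logits, temperature):
    """Turn raw scores into probabilities after dividing them by `temperature`."""
    scaled = [v / temperature for v in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.9, 0.5]  # made-up scores for three characters
for t in (0.1, 1.0, 10.0):
    print(t, [round(p, 3) for p in softmax_with_temperature(logits, t)])
```

At a low temperature most of the probability mass moves to the top-scoring character, which approaches greedy selection. At a high temperature the three probabilities become nearly equal.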



