<a href="https://colab.research.google.com/github/MandilKarki/character-level-text-generation/blob/main/4_character_text_generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# char-RNN: Character-level text generation

Generate weight-loss articles character by character with a recurrent neural network (a GRU).

For background, see Andrej Karpathy's classic post [The Unreasonable Effectiveness of Recurrent Neural Networks](http://karpathy.github.io/2015/05/21/rnn-effectiveness/).

In [1]:
!pip install boltons -q

[?25l[K     |██                              | 10kB 21.2MB/s eta 0:00:01[K     |███▉                            | 20kB 21.5MB/s eta 0:00:01[K     |█████▊                          | 30kB 10.5MB/s eta 0:00:01[K     |███████▊                        | 40kB 8.6MB/s eta 0:00:01[K     |█████████▋                      | 51kB 7.4MB/s eta 0:00:01[K     |███████████▌                    | 61kB 7.5MB/s eta 0:00:01[K     |█████████████▌                  | 71kB 8.1MB/s eta 0:00:01[K     |███████████████▍                | 81kB 8.6MB/s eta 0:00:01[K     |█████████████████▎              | 92kB 8.0MB/s eta 0:00:01[K     |███████████████████▎            | 102kB 8.2MB/s eta 0:00:01[K     |█████████████████████▏          | 112kB 8.2MB/s eta 0:00:01[K     |███████████████████████         | 122kB 8.2MB/s eta 0:00:01[K     |█████████████████████████       | 133kB 8.2MB/s eta 0:00:01[K     |███████████████████████████     | 143kB 8.2MB/s eta 0:00:01[K     |████████████████████████

In [2]:
import string
from pathlib import Path
from textwrap import wrap


import numpy as np
import pandas as pd
from boltons.iterutils import windowed
from tqdm import tqdm, tqdm_notebook

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import random_split
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from google_drive_downloader import GoogleDriveDownloader as gdd

In [3]:
# Use the GPU whenever PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cuda')

In [4]:
DATA_PATH = 'data/weight_loss/articles.jsonl'

# Skip the Google Drive round-trip when the articles file is already on disk.
already_extracted = Path(DATA_PATH).is_file()
if not already_extracted:
    # unzip=True asks the downloader to unpack the archive after fetching it;
    # presumably the archive contains articles.jsonl — TODO confirm layout.
    gdd.download_file_from_google_drive(
        file_id='1mafPreWzE-FyLI0K-MUsXPcnUI0epIcI',
        unzip=True,
        dest_path='data/weight_loss/weight_loss_articles.zip',
    )

Downloading 1mafPreWzE-FyLI0K-MUsXPcnUI0epIcI into data/weight_loss/weight_loss_articles.zip... Done.
Unzipping...Done.


In [5]:
def load_data(path, sequence_length=125, n_samples=100, random_state=None):
    """Load articles and cut them into overlapping, printable-only character windows.

    Parameters
    ----------
    path : str or Path
        File readable by ``pd.read_json`` that has a ``text`` column.
    sequence_length : int
        Length of every window produced by ``boltons.iterutils.windowed``.
    n_samples : int
        Number of articles to sample (previously hard-coded to 100). If the file
        holds fewer articles than this, all of them are used instead of raising.
    random_state : int or None
        Seed forwarded to ``Series.sample`` so the article subset is reproducible;
        ``None`` keeps the original unseeded behaviour.

    Returns
    -------
    list of tuple of str
        Lower-cased windows in which every character is in ``string.printable``.
    """
    articles = pd.read_json(path).text
    n_samples = min(n_samples, len(articles))
    texts = articles.sample(n_samples, random_state=random_state).str.lower().tolist()

    # windowed() yields every length-`sequence_length` window of a text; a text
    # shorter than that contributes no windows at all.
    all_chars_windowed = [
        window for text in texts for window in windowed(text, sequence_length)
    ]

    # Set lookup is O(1); `char in string.printable` scans the whole string each time.
    printable = frozenset(string.printable)
    return [
        sequence for sequence in tqdm_notebook(all_chars_windowed)
        if all(char in printable for char in sequence)
    ]


def get_unique_chars(sequences):
    """Return the set of distinct characters that occur in any of ``sequences``."""
    # set.union accepts any number of iterables; with no sequences it yields an empty set.
    return set().union(*sequences)


def create_char2idx(sequences):
    """Map every distinct character to a dense index, assigned in sorted character order."""
    vocabulary = sorted(get_unique_chars(sequences))
    return dict(zip(vocabulary, range(len(vocabulary))))


def encode_sequence(sequence, char2idx):
    """Translate a character sequence into a list of indices.

    Raises ``KeyError`` if a character is missing from ``char2idx``.
    """
    encoded = []
    for char in sequence:
        encoded.append(char2idx[char])
    return encoded


def encode_sequences(sequences, char2idx):
    """Encode equal-length character sequences into a 2-D int64 index matrix.

    For non-empty input the result has shape ``(len(sequences), sequence_length)``.
    The dtype is pinned to int64 because ``nn.Embedding`` and class-index
    ``nn.CrossEntropyLoss`` targets expect Long tensors, while NumPy's default
    integer is 32-bit on some platforms (e.g. Windows with NumPy < 2).
    """
    return np.array(
        [encode_sequence(sequence, char2idx) for sequence in tqdm_notebook(sequences)],
        dtype=np.int64,
    )


class Sequences(Dataset):
    """Character-window dataset for next-character prediction.

    Item ``i`` is a pair ``(input, target)`` taken from encoded window ``i``:
    ``target`` is ``input`` shifted left by one character.
    """

    def __init__(self, path, sequence_length=125):
        # Honour the `path` argument (it was previously ignored in favour of the
        # global DATA_PATH).
        self.sequences = load_data(path, sequence_length=sequence_length)
        self.char2idx = create_char2idx(self.sequences)
        # char2idx has one key per distinct character, so there is no need to
        # rescan every window to count the vocabulary.
        self.vocab_size = len(self.char2idx)
        self.idx2char = {idx: char for char, idx in self.char2idx.items()}
        self.encoded = encode_sequences(self.sequences, self.char2idx)

    def __getitem__(self, i):
        # Inputs are characters 0..L-2 and targets characters 1..L-1 of window i.
        return self.encoded[i, :-1], self.encoded[i, 1:]

    def __len__(self):
        return len(self.encoded)

In [6]:
dataset = Sequences(DATA_PATH, sequence_length=128)
# Windows are cut with stride 1 (boltons `windowed` in load_data), so consecutive
# samples overlap almost completely; shuffle so a 4096-sample batch is not one
# contiguous stretch of a few articles.
train_loader = DataLoader(dataset, batch_size=4096, shuffle=True)
# Last expression of the cell, so the notebook actually displays the size.
len(dataset)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  


HBox(children=(FloatProgress(value=0.0, max=272790.0), HTML(value='')))




Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, max=268752.0), HTML(value='')))




## GRU

![](images/gru_equations.png)

![](images/gru_diagram.png)

In [7]:
class RNN(nn.Module):
    """Character-level language model: embedding -> GRU -> linear decoder.

    Args:
        vocab_size: number of distinct character ids (embedding rows and output logits).
        embedding_dimension: size of each character embedding.
        hidden_size: size of the GRU hidden state.
        n_layers: number of stacked GRU layers.
        device: device on which ``init_hidden`` allocates the initial state.
    """

    def __init__(
        self,
        vocab_size,
        embedding_dimension=100,
        hidden_size=128,
        n_layers=1,
        device='cpu',
    ):
        super().__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.device = device

        self.encoder = nn.Embedding(vocab_size, embedding_dimension)
        self.rnn = nn.GRU(
            embedding_dimension,
            hidden_size,
            num_layers=n_layers,
            batch_first=True,
        )
        self.decoder = nn.Linear(hidden_size, vocab_size)

    def init_hidden(self, batch_size):
        """Return an all-zero initial state of shape ``(n_layers, batch_size, hidden_size)``.

        A GRU state is a convex mix of the previous state and a tanh output, so it
        normally stays in (-1, 1). Standard-normal noise (the previous behaviour)
        started it outside that range and made each sequence start differently.
        """
        return torch.zeros(self.n_layers, batch_size, self.hidden_size, device=self.device)

    def forward(self, input_, hidden):
        """Run one time step, or a whole batch of sequences at once.

        Args:
            input_: LongTensor of character ids. Shape ``(batch,)`` is a single
                step; shape ``(batch, seq_len)`` is a full sequence.
            hidden: hidden state of shape ``(n_layers, batch, hidden_size)``.

        Returns:
            ``(logits, hidden)``. ``logits`` is ``(batch, vocab_size)`` for 1-D
            input and ``(batch, seq_len, vocab_size)`` for 2-D input.
        """
        single_step = input_.dim() == 1
        if single_step:
            # Present a single step to the batch_first GRU as a length-1 sequence.
            input_ = input_.unsqueeze(1)
        encoded = self.encoder(input_)
        output, hidden = self.rnn(encoded, hidden)
        output = self.decoder(output)
        if single_step:
            output = output.squeeze(1)
        return output, hidden

In [8]:
model = RNN(vocab_size=dataset.vocab_size, device=device).to(device)

# Next-character prediction is a per-step classification over the vocabulary.
criterion = nn.CrossEntropyLoss()
# Adam over every parameter that receives gradients.
optimizer = optim.Adam(
    (param for param in model.parameters() if param.requires_grad),
    lr=1e-3,
)

In [9]:
trainable_names = [name for name, param in model.named_parameters() if param.requires_grad]

# Architecture summary, a blank line, then one bullet per trainable tensor.
print(model, end='\n\n')
print('Trainable parameters:')
print('\n'.join(f' * {name}' for name in trainable_names))

RNN(
  (encoder): Embedding(60, 100)
  (rnn): GRU(100, 128, batch_first=True)
  (decoder): Linear(in_features=128, out_features=60, bias=True)
)

Trainable parameters:
 * encoder.weight
 * rnn.weight_ih_l0
 * rnn.weight_hh_l0
 * rnn.bias_ih_l0
 * rnn.bias_hh_l0
 * decoder.weight
 * decoder.bias


![](images/char_rnn_diagram.png)

In [10]:
N_EPOCHS = 50

model.train()
train_losses = []
for epoch in range(N_EPOCHS):
    progress_bar = tqdm_notebook(train_loader, leave=False)
    losses = []
    for inputs, targets in progress_bar:
        # Copy the whole batch to the device once, instead of one column per time step.
        inputs = inputs.to(device)
        targets = targets.to(device)
        seq_len = inputs.size(1)
        hidden = model.init_hidden(inputs.size(0))

        optimizer.zero_grad()

        # Back-propagate through the whole window: sum the per-step cross-entropy.
        loss = 0
        for char_idx in range(seq_len):
            output, hidden = model(inputs[:, char_idx], hidden)
            loss += criterion(output, targets[:, char_idx])

        loss.backward()
        optimizer.step()

        # Mean per-character loss for this batch (the summed loss divided by steps).
        avg_loss = loss.item() / seq_len
        progress_bar.set_description(f'Loss: {avg_loss:.3f}')
        losses.append(avg_loss)

    epoch_loss = sum(losses) / len(losses)
    train_losses.append(epoch_loss)

    tqdm.write(f'Epoch #{epoch + 1}\tTrain Loss: {epoch_loss:.3f}')

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  after removing the cwd from sys.path.


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #1	Train Loss: 2.838


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #2	Train Loss: 2.267


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #3	Train Loss: 2.081


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #4	Train Loss: 1.956


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #5	Train Loss: 1.864


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #6	Train Loss: 1.793


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #7	Train Loss: 1.736


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #8	Train Loss: 1.689


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #9	Train Loss: 1.650


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #10	Train Loss: 1.616


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #11	Train Loss: 1.587


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #12	Train Loss: 1.562


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #13	Train Loss: 1.540


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #14	Train Loss: 1.520


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #15	Train Loss: 1.502


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #16	Train Loss: 1.486


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #17	Train Loss: 1.471


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #18	Train Loss: 1.457


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #19	Train Loss: 1.445


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #20	Train Loss: 1.433


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #21	Train Loss: 1.423


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #22	Train Loss: 1.413


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #23	Train Loss: 1.404


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #24	Train Loss: 1.395


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #25	Train Loss: 1.387


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #26	Train Loss: 1.379


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #27	Train Loss: 1.372


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #28	Train Loss: 1.365


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #29	Train Loss: 1.358


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #30	Train Loss: 1.352


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #31	Train Loss: 1.346


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #32	Train Loss: 1.341


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #33	Train Loss: 1.336


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #34	Train Loss: 1.330


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #35	Train Loss: 1.326


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #36	Train Loss: 1.321


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #37	Train Loss: 1.316


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #38	Train Loss: 1.312


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #39	Train Loss: 1.308


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #40	Train Loss: 1.304


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #41	Train Loss: 1.300


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #42	Train Loss: 1.296


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #43	Train Loss: 1.293


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #44	Train Loss: 1.289


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #45	Train Loss: 1.286


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #46	Train Loss: 1.282


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #47	Train Loss: 1.279


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #48	Train Loss: 1.276


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #49	Train Loss: 1.273


HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

Epoch #50	Train Loss: 1.270


In [11]:
def pretty_print(text):
    """Print ``text`` with each newline-separated paragraph wrapped by ``textwrap.wrap``.

    Every paragraph (including empty ones, which become blank lines) is followed
    by a newline before the whole result is printed.
    """
    wrapped_paragraphs = ('\n'.join(wrap(paragraph)) + '\n' for paragraph in text.split('\n'))
    print(''.join(wrapped_paragraphs))


temperature = 1.0   # >1 flattens the sampling distribution, <1 sharpens it
seed = '\n'         # priming text; every character must be a key of dataset.char2idx
n_generated = 1000  # number of characters to sample
assert seed, 'seed must contain at least one character'

model.eval()
generated = []
with torch.no_grad():
    hidden = model.init_hidden(1)
    # Feed all but the last seed character only to build up the hidden state;
    # the last one is the first input of the sampling loop.
    for char in seed[:-1]:
        _, hidden = model(torch.tensor([dataset.char2idx[char]], device=device), hidden)
    last_char = dataset.char2idx[seed[-1]]
    for _ in range(n_generated):
        output, hidden = model(torch.tensor([last_char], device=device), hidden)
        # softmax is computed in a numerically stable way, whereas exp(logits / T)
        # can overflow to inf for small temperatures and break multinomial().
        probabilities = F.softmax(output.squeeze(0) / temperature, dim=-1)
        last_char = torch.multinomial(probabilities, 1).item()
        generated.append(dataset.idx2char[last_char])
text = ''.join(generated)

pretty_print(text)

sted certanding, become, kild for many recising your bo side and
cranes to be faulthing is the appetes to meat and leaps down.
3he saltwentrate try told is plan. while your effort on (4; purmily
deceat may high for your and which trapelifit. yet, your whole salain
on has the small brown up over directy, i'm energy" permared ungread,
long? as alruitable book and keep your with a high discome out whyle
know that desired if your calories that habso 360% busting, you
treaprace? chickened some individuying you as your body are a sizes
breakn they are which in worke faster than soda, diet paisor, they are
find. talk?
how a day, that wempt these like to lose weight? so eat) anfitum, not
exercise) are snacts, the metabolic rate to water.
1. expened workor that you use things, i don't even very eactively
lose is natural enessidles top disease in to the key to be hyou work
and feew menotive crink, or 30 - day controlugoning food calthy plan
is some use the meals ennoten as about life-things as c