# Classification
This file includes code which classifies text chunks by author (Austen, Shelley, Kafka, Tolstoy or Dostoyevsky).
The training data consists of text chunks from their respective works _Pride and Prejudice_, _Frankenstein_, _The Trial_, _Anna Karenina_ and _Crime and Punishment_. We obtain the texts from Project Gutenberg.

## Importing the data

In [6]:
import numpy as np
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.distributions.categorical import Categorical

from textdataset import TextDataset
from rnn import RNN

In [7]:
# Helpers
def preprocessing(filepath, text, end=False):
    """
    Retrieve the relevant part of the text.

    The book is recognised by looking for a known key (e.g. ``'kafka'``) as a
    substring of ``filepath``. Every known book has a marker string at which
    the relevant part starts; some also have a marker at which it stops.

    Parameters
    ----------
    filepath : str
        Path or name of the book file. Only used to recognise the book.
    text : str
        Full contents of the file.
    end : int or False, optional
        Cut-off index used for books that have no end marker of their own.
        A falsy value (the default) means "cut at the ``*** END`` footer".
        Ignored for books that define their own end marker.

    Returns
    -------
    str
        ``text`` from the start marker (inclusive) up to the cut-off
        (exclusive).

    Raises
    ------
    Exception
        If ``filepath`` contains none of the known book keys.
    ValueError
        If the start marker of the recognised book is not in ``text``.
    """
    # (key, start marker, end marker or None). Keys are tested in this order
    # and the first one contained in `filepath` wins.
    library = (
        ('austen', "Chapter I.]", None),
        ('dostoyevsky', "CHAPTER I", None),
        ('god', "1:1", "in the sight of all Israel."),  # Only old testament
        ('kafka', "Chapter One", None),
        ('shelley', "_To", None),
        ('tolstoy', "Chapter 1", None),
        ('sturluson', "PREFACE OF SNORRE STURLASON.",
         "SAGA OF HARALD HARDRADE."),  # Only Heimskringla
        ('cervantes', "Idle reader:",
         "Forse altro cantera con miglior plettro."),  # Only Volume I
    )

    for key, start_marker, end_marker in library:
        if key in filepath:
            break
    else:
        raise Exception("This book is not in our library!")

    start = text.find(start_marker)
    if start == -1:
        # str.find returns -1 when the marker is missing; slicing from -1
        # would silently return (almost) nothing.
        raise ValueError(
            f"Start marker {start_marker!r} not found in {filepath!r}")

    if end_marker is not None:
        # Search after the start so an earlier mention cannot produce an
        # empty slice.
        stop = text.find(end_marker, start)
    elif end:
        # Caller-supplied cut-off index, used as-is.
        return text[start:end]
    else:
        stop = -1

    if stop == -1:
        # No (or missing) book-specific end marker: cut at the footer.
        stop = text.find("*** END", start)
    if stop == -1:
        # No footer either: keep everything up to the end of the text.
        stop = len(text)

    return text[start:stop]

In [8]:
def read_file(filepath):
    """
    Read a book from disk and return its relevant part and its character set.

    The file is opened as UTF-8, its full contents are handed to
    ``preprocessing`` together with ``filepath``, and the resulting string is
    returned along with the set of distinct characters it contains.

    Returns
    -------
    tuple of (str, set of str)
        The preprocessed text and the unique characters occurring in it.
    """
    with open(filepath, encoding='utf-8') as handle:
        raw_contents = handle.read()

    # `preprocessing` returns a single string (a slice of the raw contents).
    body = preprocessing(filepath, raw_contents)
    unique_chars = set(body)
    return body, unique_chars

In [50]:
# Location of the text files, relative to this notebook.
folder = ".."
subfolder = "Texts"
# filenames = ['austen', 'dostoyevsky', 'god', 'cervantes', 'sturluson']
filenames = ['dostoyevsky']
filepaths = [os.path.join(folder, subfolder, name) for name in filenames]

# Read every selected book. `text` and `char_set` are overwritten on each
# iteration, so after the loop they hold the values of the LAST file only.
for filepath in filepaths:
    print(filepath)
    text, char_set = read_file(f"{filepath}.txt")


print('Total Length:', len(text))
print('Unique Characters:', len(char_set))

..\Texts\dostoyevsky
Total Length: 1130503
Unique Characters: 91


In [51]:
# Settings for this cell
seq_length = 40         # sequence length
chunk_size = seq_length + 1
batch_size = 64
device = 'cpu'

# Vocabulary: every distinct character gets an integer id (its rank in sorted
# order); `char_array` maps ids back to characters.
chars_sorted = sorted(char_set)
char2int = {symbol: index for index, symbol in enumerate(chars_sorted)}
char_array = np.array(chars_sorted)

# The whole text as a 1-D array of character ids.
text_encoded = np.array([char2int[symbol] for symbol in text], dtype=np.int32)

# Every window of `chunk_size` consecutive ids, moving one character at a time.
text_chunks = [
    text_encoded[begin:begin + chunk_size]
    for begin in range(len(text_encoded) - chunk_size + 1)
]

seq_dataset = TextDataset(torch.tensor(np.array(text_chunks)))

# Seed right before creating the loader, which shuffles and drops the last
# incomplete batch.
torch.manual_seed(1)
seq_dl = DataLoader(
    seq_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

### Creating the model

In [52]:
# Model dimensions: one input/output class per distinct character.
vocab_size = len(char_array)
embed_dim = 256
rnn_hidden_size = 512

# Seed before constructing the model (presumably so that the initial weights
# are reproducible; depends on how RNN initialises its parameters).
torch.manual_seed(1)
model = RNN(vocab_size, embed_dim, rnn_hidden_size)

# Loss and optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

### Training the model

In [53]:
num_epochs = 10_000
torch.manual_seed(1)

def training_model():
    """
    Train the global ``model`` on batches drawn from the global ``seq_dl``.

    Runs ``num_epochs`` iterations. Note that each "epoch" here is a single
    optimisation step on one batch, not a full pass over the dataset. Within
    a step the sequence is fed to the model one position at a time, the
    per-position losses are summed, and one backward pass / optimizer step
    is done on the sum. Every 500 steps the mean per-position loss of the
    current batch is printed.

    Uses the globals ``model``, ``optimizer``, ``loss_fn``, ``seq_dl``,
    ``batch_size``, ``seq_length``, ``device`` and ``num_epochs``.
    """
    # sample() leaves the model in eval mode; make sure training always runs
    # in training mode, also when this cell is re-run after sampling.
    model.train()

    for epoch in range(num_epochs):
        # Fresh recurrent state for every batch.
        hidden, cell = model.init_hidden(batch_size)

        # A new iterator is created on every step, so with shuffle=True this
        # takes the first batch of a new shuffle each time.
        seq_batch, target_batch = next(iter(seq_dl))
        seq_batch = seq_batch.to(device)
        target_batch = target_batch.to(device)
        optimizer.zero_grad()

        # Accumulate the loss over all positions of the sequence.
        loss = 0
        for c in range(seq_length):
            pred, hidden, cell = model(seq_batch[:, c], hidden, cell)
            loss += loss_fn(pred, target_batch[:, c])
        loss.backward()
        optimizer.step()

        # Mean loss per position, kept separate from the loss tensor.
        mean_loss = loss.item() / seq_length
        if epoch % 500 == 0:
            print(f'Epoch {epoch} loss: {mean_loss:.4f}')

training_model()

Epoch 0 loss: 4.5106
Epoch 500 loss: 1.5498
Epoch 1000 loss: 1.4512
Epoch 1500 loss: 1.3801
Epoch 2000 loss: 1.3268
Epoch 2500 loss: 1.3032
Epoch 3000 loss: 1.2754
Epoch 3500 loss: 1.1833
Epoch 4000 loss: 1.1894
Epoch 4500 loss: 1.2225
Epoch 5000 loss: 1.1875
Epoch 5500 loss: 1.1548
Epoch 6000 loss: 1.1714
Epoch 6500 loss: 1.1770
Epoch 7000 loss: 1.1214
Epoch 7500 loss: 1.1169
Epoch 8000 loss: 1.1575
Epoch 8500 loss: 1.1189
Epoch 9000 loss: 1.1035
Epoch 9500 loss: 1.1140


### Saving the model for later use

In [54]:
# File the trained model is written to (relative to the working directory).
path = 'dostoyevsky_generator.pt'
# Saves the whole model object (not just a state_dict), so loading it again
# needs the model's class to be importable.
torch.save(model, path)

In [29]:
# The file holds a fully pickled model object, not a state_dict. Recent
# PyTorch versions default to weights_only=True, which refuses to unpickle
# such files, so full unpickling is requested explicitly.
# WARNING: unpickling can execute arbitrary code - only load files you
# created yourself or otherwise trust.
model = torch.load(path, weights_only=False)

### Generating new text

In [56]:
def sample(model, starting_str, 
           len_generated_text=500, 
           scale_factor=2.0):
    """
    Generate text character by character, continuing ``starting_str``.

    All characters of ``starting_str`` except the last are first fed through
    the model to build up its recurrent state. Then, starting from the last
    character, the next character is repeatedly sampled from the categorical
    distribution given by the scaled logits and fed back in as the next input.

    Relies on the globals ``char2int`` (character -> id) and ``char_array``
    (id -> character). The model is switched to eval mode and left that way.

    Parameters
    ----------
    model : callable
        Called as ``model(input, hidden, cell)`` and must provide
        ``init_hidden(batch_size)`` and ``eval()``.
    starting_str : str
        Non-empty prompt; every character must be a key of ``char2int``
        (otherwise a ``KeyError`` is raised).
    len_generated_text : int, optional
        Number of characters to generate after the prompt.
    scale_factor : float, optional
        Factor the logits are multiplied by before sampling. Values above 1
        sharpen the distribution (less random text), values below 1 flatten it.

    Returns
    -------
    str
        ``starting_str`` followed by the generated characters.

    Raises
    ------
    ValueError
        If ``starting_str`` is empty.
    """
    if not starting_str:
        raise ValueError("starting_str must contain at least one character")

    encoded_input = torch.tensor([char2int[s] for s in starting_str])
    encoded_input = torch.reshape(encoded_input, (1, -1))

    # Collect pieces in a list and join once at the end.
    pieces = [starting_str]

    model.eval()
    hidden, cell = model.init_hidden(1)
    hidden = hidden.to('cpu')
    cell = cell.to('cpu')

    # Inference only: no gradients needed, so do not build an autograd graph
    # across the generation steps.
    with torch.no_grad():
        # Warm up the recurrent state on the prompt (all but the last char).
        for c in range(len(starting_str)-1):
            _, hidden, cell = model(encoded_input[:, c].view(1), hidden, cell)

        last_char = encoded_input[:, -1]
        for _ in range(len_generated_text):
            logits, hidden, cell = model(last_char.view(1), hidden, cell)
            logits = torch.squeeze(logits, 0)
            scaled_logits = logits * scale_factor
            m = Categorical(logits=scaled_logits)
            last_char = m.sample()
            pieces.append(str(char_array[last_char]))

    return ''.join(pieces)

# Seed so the sampled text is repeatable.
torch.manual_seed(1)
# sample() creates its inputs and recurrent state on the CPU, so the model
# has to be there too.
model.to('cpu')
# Generate a continuation of the prompt 'criminal' with the default length
# and scale factor.
print(sample(model, starting_str='criminal'))

criminals laughing at the business is the room was a ragged to him on the stairs. “That’s not worth things, I am afraid of it. I saw you that I am a special fools,’ says he is not at the point of it. If I were to pay him and so on. Do you suppose I shall be seen in the contrary!”

“A lie--there is no one was sentinually as I can always find out for your sister and the workmen and put the whole street to the stairs--all his words.

“It’s nothing about it all too, I want to take the book it is!” He went o
