<a href="https://colab.research.google.com/github/NastasiaMazur/StoryTeller/blob/main/storyTeller2noValidation_demoFile.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [16]:
#read-PDF imports here
!pip install PyPDF2
from PyPDF2 import PdfReader

#pre-processing imports here
import re
import nltk
nltk.download('punkt')
from nltk.tokenize import word_tokenize
import string



[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
#mount Google drive
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [17]:
#file locations on drive
grimm_url = '/content/drive/MyDrive/Story_Teller/FairytalesByTheBrothersGrimm.txt'
coraline_url = '/content/drive/MyDrive/Story_Teller/Coraline.pdf'
alice_url = '/content/drive/MyDrive/Story_Teller/AlicesAdvanturesInWonderland.txt'

In [18]:
#load punctuation symbols
punct = string.punctuation

# **Pre-processing Coraline**

In [25]:
#a function to pre process Coraline by Neil Gaiman

def preprocess_coraline(book):
  '''
  param book: url od a PDF book file
  '''
  output = ""
  data = open(book, 'rb')
  data = PyPDF2.PdfReader(data)
  npages = len(data.pages)
  for i in range(npages):
    page_i = data.pages[i].extract_text()
    output += page_i
  output = output[1227:]
  output = output.lower()
  for word in output:
    for char in word:
        if char in punct:
            word = word.replace(char, "")
  remove_punct = "".join([word for word in output if word not in punct])
  processed = word_tokenize(remove_punct)
  print('Coraline database includes {} tokens, and {} unique tokens after editing'.format(len(processed), len(set(processed))))
  return processed

coraline = preprocess_coraline(coraline_url)

Coraline database includes 33352 tokens, and 3660 unique tokens after editing


## **Preprocessing Alice in Wonderland**

In [26]:
#a function to pre process Alice's Advantures in Wonderland by Lewis Carroll

def load_alice(text_file, punct, not_a_word):
    '''
    param text_file: url to Project Gutenberg's text file for Alice's Advantures in Wonderland by Lewis Carroll
    param punct: a string of punctuation characters we'd like to filter
    param not_a_word: a list of words we'd like to filter
    '''
    book = open(text_file, 'r')
    book = book.read()
    book = book[715:145060]
    book_edit = re.sub('[+]', '', book)
    book_edit = re.sub(r'(CHAPTER \w+.\s)', '', book)
    words = word_tokenize(book_edit.lower())

    word_list = []

    # filtering punctuation and non-words
    for word in words:
        for char in word:
            if char in punct:
                word = word.replace(char, "")
        if word not in punct and word not in not_a_word:
            word_list.append(word)

    print('Alice database includes {} tokens, and {} unique tokens after editing'.format(len(word_list), len(set(word_list))))
    return word_list

alice = load_alice(alice_url, (punct.replace('-', "") + '’' + '‘'), ['s', '--', 'nt', 've', 'll', 'd'])

Alice database includes 26612 tokens, and 2596 unique tokens after editing


# **Preprocessing Grimm**


In [27]:
def load_fairytales(text_file):
    '''
    param text_file: url to Project Gutenberg's text file for Fairytales by The Brothers Grimm
    '''
    book = open(text_file, encoding='cp1252')
    book = book.read()
    book = book[2376:519859]
    book_edit = re.sub('[(+*)]', '', book)
    words = word_tokenize(book_edit.lower())

    # filtering punctuation inside tokens (example: didn't or wow!)
    for word in words:
        for char in word:
            if char in punct:
                word = word.replace(char, "")

    # filtering punctuation as alone standing tokens(example: \ or ,)
    words = [word for word in words if word not in punct]

    print('Fairytales database includes {} tokens, and {} unique tokens after editing'.format(len(words), len(set(words))))
    return words

brothers_grimm = load_fairytales(grimm_url)

Fairytales database includes 106324 tokens, and 5335 unique tokens after editing


# **Combined database including all books**

In [29]:
data = coraline + alice + brothers_grimm
data[:10]

['beaten',
 '—g',
 'k',
 'chesterton',
 '1',
 'coraline',
 'discovered',
 'the',
 'door',
 'a']

# **Convert Data into Numeric Values**

In [36]:
vocab = set(data)
vocab_size = len(data)

word_to_index = {word: i for i, word in enumerate(vocab)}
data = [word_to_index[word] for word in data]    # list comprehension

data [:10]

[6892, 2690, 3772, 2496, 3802, 7389, 7973, 2727, 3603, 1833]

In [37]:
word_to_index['beaten']

6892

# **Batching Data**

In [40]:
batch_size = 5 # look into first 5 words in each batch

train_data = [([data[i], data[i+1],data[i+2], data[i+3], data[i+4]], data[i+5]) for i in range(vocab_size - batch_size)] #features + target word

train_data[:10]

[([6892, 2690, 3772, 2496, 3802], 7389),
 ([2690, 3772, 2496, 3802, 7389], 7973),
 ([3772, 2496, 3802, 7389, 7973], 2727),
 ([2496, 3802, 7389, 7973, 2727], 3603),
 ([3802, 7389, 7973, 2727, 3603], 1833),
 ([7389, 7973, 2727, 3603, 1833], 4439),
 ([7973, 2727, 3603, 1833, 4439], 2282),
 ([2727, 3603, 1833, 4439, 2282], 6611),
 ([3603, 1833, 4439, 2282, 6611], 582),
 ([1833, 4439, 2282, 6611, 582], 5460)]

# **Defining the Neural Network**

In [43]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time

embedding_dim = 5

class StoryTeller(nn.Module):
  def __init__ (self, vocab_size, embedding_dim, batch_size):
    super(StoryTeller, self).__init__()
    self.embeddings = nn.Embedding(vocab_size, embedding_dim)
    self.linear1 = nn.Linear(batch_size * embedding_dim, 128)
    self.linear2 = nn.Linear(128, 512)
    self.linear3 = nn.Linear(512, vocab_size)

  def forward(self, inputs):
    embeds = self.embeddings(inputs).view((1,-1))
    out = F.relu(self.liniar1(embeds))
    out = F.relu(self.liniar2(out))
    out = self.linear3(out)
    log_probs = F.log_softmax(out, dim=1)
    return log_probs




In [44]:
model = StoryTeller(vocab_size, embedding_dim, batch_size)
model

StoryTeller(
  (embeddings): Embedding(166288, 5)
  (linear1): Linear(in_features=25, out_features=128, bias=True)
  (linear2): Linear(in_features=128, out_features=512, bias=True)
  (linear3): Linear(in_features=512, out_features=166288, bias=True)
)

# **Defining Training Function**

# **Train Model**

# **Save Checkpoint**

In [None]:
checkpoint_url = '/content/drive/My Drive/Lessons/storyTeller/checkpoint3.pth'

checkpoint = {'model': model,
              'state_dict': model.state_dict(),
              'word_to_index': word_to_index,
              'index_to_word': {i: word for i, word in enumerate(vocab)},
              'epochs': epochs,
              'average_loss': average_loss,
              'device': device,
              'optimizer_state': optimizer.state_dict(),
              'batch_size': batch_size}

torch.save(checkpoint, checkpoint_url)

# **Load Checkpoint**

In [None]:
def load_checkpoint(filepath):
    checkpoint = torch.load(filepath)
    model = checkpoint['model']
    model.optimizer_state = checkpoint['optimizer_state']
    model.load_state_dict(checkpoint['state_dict'])
    model.device = checkpoint['device']
    model.word_to_index = checkpoint['word_to_idxx']
    model.index_to_word = checkpoint['idx_to_word']
    model.average_loss = checkpoint['average_loss']
    return model

checkpoint_url = '/content/drive/My Drive/Lessons/storyTeller/checkpoint5.pth'
model = load_checkpoint(checkpoint_url)
index_to_word = model.index_to_word
model

AttributeError: ignored

In [None]:
import pandas as pd

loss_plot = pd.DataFrame(model.average_loss)
loss_plot.plot()

# **Predict Function**

In [None]:
def predict(model, first_words ,story_len ,top_k):
    '''
    param model: trained model
    param first_words: a string of 5 (n_feature) words to begin the story
    param story_len: an integer symbolizing the number of words you'd like the story to have
    param top_k: the number of top probabilities per word that the network will randomly select from
    '''
    feature = (first_words.lower()).split(" ")
    for i in feature:
        story.append(i)
    for i in range(story_len):
        feature_idx = torch.tensor([word_to_index[word] for word in feature], dtype=torch.long)
        feature_idx = feature_idx.to(device)
        with torch.no_grad():
            output = model.double().forward(feature_idx)
        ps = torch.exp(output)
        topk_combined = ps.topk(top_k, sorted=True)
        #top kk probabilities
        topk_ps = topk_combined[0][0]
        #top kk classes
        topk_class = topk_combined[1][0]
        topk_class = [index_to_word[int(i)] for i in topk_class]
        next_word = random.choice(topk_class)
        feature = feature[1:]
        feature.append(next_word)
        story.append(next_word)
    return story

# **Predict**

In [None]:
import random
first_words = input('Type the first {} words to start the story:\nexample: A lovely day at the\n'.format(batch_size))

top_k = 3
story_len = 50
story = []
device = 'cuda:0'

#Predicting and Handling User-Input Errors
try:
    prediction = predict(model, first_words, story_len, top_k)
except KeyError as error:
    print('Oops, looks like you\'ve selected a word that the network does not understand yet: ', error)
    if story[0] != "":
        story = story[len(first_words):]
    first_words = input('please select a different word:\nexample: A lovely day at the\n')
    prediction = predict(model, first_words, story_len, top_k)
except KeyError and RuntimeError:
    if story[0] != "":
        story = story[len(first_words):]
    first_words = input('Oops, looks like you\'ve typed {} words instead of {}!\n\nType the first 5 words to start the story:\nexample: A lovely day at the\n'.format(len(first_words.split(" ")), n_features))
    prediction = predict(model, first_words, story_len, top_k)

print('-----------------------------------------------------\n The STORY \n-----------------------------------------------------')
print(" ".join(story))