In [9]:
import nltk
from nltk.corpus import gutenberg
from nltk import word_tokenize
# List every text bundled with NLTK's Gutenberg corpus (shown as the cell's output).
gutenberg.fileids()

['austen-emma.txt',
 'austen-persuasion.txt',
 'austen-sense.txt',
 'bible-kjv.txt',
 'blake-poems.txt',
 'bryant-stories.txt',
 'burgess-busterbrown.txt',
 'carroll-alice.txt',
 'chesterton-ball.txt',
 'chesterton-brown.txt',
 'chesterton-thursday.txt',
 'edgeworth-parents.txt',
 'melville-moby_dick.txt',
 'milton-paradise.txt',
 'shakespeare-caesar.txt',
 'shakespeare-hamlet.txt',
 'shakespeare-macbeth.txt',
 'whitman-leaves.txt']

In [2]:
# Keep the prose texts from the corpus: drop the Bible and the two poetry collections.
# NOTE(review): milton-paradise.txt (verse) is still included — confirm that is intended.
excluded_fileids = {'bible-kjv.txt', 'blake-poems.txt', 'whitman-leaves.txt'}
fileids = [fid for fid in gutenberg.fileids() if fid not in excluded_fileids]

# Word count of each retained text, in corpus order.
for fileid in fileids:
    print(len(gutenberg.words(fileid)))

192427
98171
141576
55563
18963
34110
96996
86063
69213
210663
260819
96825
25833
37360
23140


In [25]:
 # Raw Gutenberg text dump; read line by line and closed by the story-loading cell below.
 stories_file = open("stories.csv", mode="r", encoding="utf8")

In [26]:
# Split the raw dump into the first `training_story_count` stories, one list of lines each.
# Only lines between a book's START marker and its END marker are kept, so Gutenberg
# license text and the next book's header do not end up in the training data.
training_story_count = 50
START_MARKER = "START OF THIS PROJECT GUTENBERG EBOOK"
END_MARKER = "END OF THIS PROJECT GUTENBERG EBOOK"

# Every story starts as an EMPTY list (the old `[[""] * i ...]` pre-filled story i
# with i empty strings, which later became padding-only training rows).
training_stories = [[] for _ in range(training_story_count)]
story_num = -1   # index of the story being read; -1 while still in the file preamble
in_body = False  # True between a START marker and the matching END marker

try:
    stories = stories_file.readlines()
    for line in stories:
        if START_MARKER in line:
            if story_num >= 0:
                print("Loaded story " + str(story_num) + "...")
            story_num += 1
            in_body = True
            if story_num >= training_story_count:
                # The next book starts past our quota: stop reading.
                break
        elif END_MARKER in line:
            in_body = False
        elif in_body:
            training_stories[story_num].append(line)
finally:
    # Close even if parsing fails part-way.
    stories_file.close()

Loaded story 0...
Loaded story 1...
Loaded story 2...
Loaded story 3...
Loaded story 4...
Loaded story 5...
Loaded story 6...
Loaded story 7...
Loaded story 8...
Loaded story 9...
Loaded story 10...
Loaded story 11...
Loaded story 12...
Loaded story 13...
Loaded story 14...
Loaded story 15...
Loaded story 16...
Loaded story 17...
Loaded story 18...
Loaded story 19...
Loaded story 20...
Loaded story 21...
Loaded story 22...
Loaded story 23...
Loaded story 24...
Loaded story 25...
Loaded story 26...
Loaded story 27...
Loaded story 28...
Loaded story 29...
Loaded story 30...
Loaded story 31...
Loaded story 32...
Loaded story 33...
Loaded story 34...
Loaded story 35...
Loaded story 36...
Loaded story 37...
Loaded story 38...
Loaded story 39...
Loaded story 40...
Loaded story 41...
Loaded story 42...
Loaded story 43...
Loaded story 44...
Loaded story 45...
Loaded story 46...
Loaded story 47...
Loaded story 48...
Loaded story 49...


In [33]:
import pickle

# Cache the parsed stories to disk, then reload them so later cells use the cached copy.
# `with` guarantees both handles are closed (the original load left its file open).
with open("training_stories.pickle", "wb") as pickle_training_stories:
    pickle.dump(training_stories, pickle_training_stories)

# NOTE: pickle.load can execute arbitrary code — only load pickles this notebook wrote itself.
with open("training_stories.pickle", "rb") as training_stories_pickle:
    training_stories = pickle.load(training_stories_pickle)

In [7]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [34]:
import torch
from torch.utils.data import Dataset

class StoryDataset(Dataset):
    """Map-style ``Dataset`` that exposes a sequence of text items by index.

    Items are returned unchanged; batching/collation is left to the ``DataLoader``.

    Args:
        text: an indexable, sized sequence (e.g. a story's list of lines).
    """

    def __init__(self, text):
        # Kept by reference (not copied), so later changes to the sequence are visible here.
        self.text = text

    def __getitem__(self, idx):
        items = self.text
        return items[idx]

    def __len__(self):
        items = self.text
        return len(items)

In [44]:
from transformers import AdamW
import re

class StoryGenerator:
    """Wraps a GPT-2 language model for sentence-by-sentence story generation and fine-tuning.

    Any component not supplied is created with a default: the pretrained ``gpt2``
    tokenizer/model, an ``AdamW`` optimizer (lr=1e-5) and a ``CrossEntropyLoss``.
    The loss is stored as ``self.loss`` but ``fine_tune`` uses the model's built-in
    LM loss instead.

    Args:
        tokenizer: tokenizer used to encode prompts, decode generations and batch text.
        model: causal LM exposing ``generate`` and ``forward(..., labels=...)``.
        optimizer: optimizer stepped by ``fine_tune``.
        loss: loss module, kept for API compatibility.
        alt_sent_gen_enabled: if True, ``generate_sentence`` uses the single-call
            ``alt_sent_gen`` strategy instead of token-by-token ``sent_gen``.
    """

    def __init__(self, tokenizer=None, model=None, optimizer=None, loss=None, alt_sent_gen_enabled=False):
        self.tokenizer = tokenizer
        if tokenizer is None:
            self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
            # GPT-2 ships without a pad token, but fine_tune pads batches: reuse EOS.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = model
        if model is None:
            self.model = GPT2LMHeadModel.from_pretrained('gpt2', pad_token_id=self.tokenizer.eos_token_id)

        self.optimizer = optimizer
        if optimizer is None:
            self.optimizer = AdamW(self.model.parameters(), lr=1e-5)

        self.loss = loss
        if loss is None:
            self.loss = torch.nn.CrossEntropyLoss()

        self.alt_sent_gen_enabled = alt_sent_gen_enabled

    def generate_sentences(self, sent, sent_end_symbols=".?!", max_len=64, num_sentences=1):
        """Extend ``sent`` with ``num_sentences`` generated sentences.

        Each new sentence is conditioned on everything generated so far.

        Returns:
            str: the original text followed by all generated continuations.
        """
        for _ in range(num_sentences):
            sent += self.generate_sentence(sent=sent, sent_end_symbols=sent_end_symbols, max_len=max_len)
        return sent

    def generate_sentence(self, sent, sent_end_symbols=".?!", max_len=64):
        """Generate one sentence continuing ``sent`` and return only the new text."""
        # Forward the caller's settings (they used to be silently dropped).
        if self.alt_sent_gen_enabled:
            return self.alt_sent_gen(sent=sent, sent_end_symbols=sent_end_symbols, max_len=max_len)
        return self.sent_gen(sent=sent, sent_end_symbols=sent_end_symbols, max_len=max_len)

    def sent_gen(self, sent, sent_end_symbols=".?!", max_len=64):
        """Grow ``sent`` one generated token per step until it ends a sentence.

        Stops once the decoded text ends with any character of ``sent_end_symbols``
        (each character is matched literally), or after ``max_len`` steps. At least
        one step always runs.

        Returns:
            str: only the newly generated text.
        """
        enders = tuple(sent_end_symbols)  # str.endswith(()) is False, so "" never matches
        start_pos = len(sent)
        decoded_output = sent
        for _ in range(max(max_len, 1)):
            input_ids = self.tokenizer.encode(decoded_output, return_tensors="pt")
            # Ask for exactly one more token than the current prompt holds.
            output_length = input_ids.size()[1] + 1
            output = self.model.generate(input_ids, min_length=output_length, max_length=output_length, num_beams=3, do_sample=True, repetition_penalty=4.0)
            decoded_output = self.tokenizer.decode(output[0])
            if decoded_output.endswith(enders):
                break
        return decoded_output[start_pos:]

    def alt_sent_gen(self, sent, sent_end_symbols=".?!", max_len=64, max_attempts=5):
        """Generate up to ``max_len`` new tokens in one call and keep the first sentence.

        The continuation is cut just after the earliest character of
        ``sent_end_symbols``. If a sampled continuation contains none of them,
        sampling is retried, up to ``max_attempts`` attempts in total. If every
        attempt lacks an end symbol, the last continuation is returned whole.

        Returns:
            str: only the newly generated text.
        """
        input_ids = self.tokenizer.encode(sent, return_tensors="pt")
        # generate's max_length includes the prompt, so budget max_len tokens *beyond*
        # it; otherwise long prompts (which grow every sentence) leave no room.
        total_length = input_ids.shape[-1] + max_len
        new_text = ""
        for _ in range(max(max_attempts, 1)):
            output = self.model.generate(input_ids, max_length=total_length, num_beams=3, do_sample=True, early_stopping=True, repetition_penalty=4.0)
            # Assumes decode(encode(sent)) == sent, so the continuation starts at len(sent).
            new_text = self.tokenizer.decode(output[0])[len(sent):]
            hits = [pos for pos in (new_text.find(sym) for sym in sent_end_symbols) if pos >= 0]
            if hits:
                return new_text[:min(hits) + 1]
        return new_text

    def fine_tune(self, story_loader):
        """Run one pass of causal-LM fine-tuning over ``story_loader``.

        Each batch (a list of strings) is tokenized with padding/truncation, and the
        model is trained to predict its own input ids. Padding positions are excluded
        from the loss. The model is switched to training mode for the pass and
        restored to its previous mode afterwards.

        Returns:
            list[float]: the loss of every batch, in order.
        """
        losses = []
        was_training = self.model.training
        self.model.train()
        try:
            for story_batch in story_loader:
                inputs = self.tokenizer(story_batch, padding=True, truncation=True, return_tensors="pt")
                labels = inputs["input_ids"].clone()
                if "attention_mask" in inputs:
                    # Label -100 is ignored by the LM loss, so padding is not learned as text.
                    labels[inputs["attention_mask"] == 0] = -100
                # Clear the previous batch's gradients; otherwise they accumulate and
                # every step applies the sum of all past gradients.
                self.optimizer.zero_grad()
                outputs = self.model(**inputs, labels=labels)
                outputs.loss.backward()
                self.optimizer.step()
                losses.append(outputs.loss.item())
        finally:
            self.model.train(was_training)
        return losses


In [45]:
# GPT-2 has no dedicated padding token, so "<|endoftext|>" doubles as padding
# for both the tokenizer and the model's generation config.
gpt2_checkpoint = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(gpt2_checkpoint, pad_token="<|endoftext|>")
model = GPT2LMHeadModel.from_pretrained(
    gpt2_checkpoint,
    pad_token_id=tokenizer.eos_token_id,
)

In [46]:
# Reuse the padded tokenizer and model loaded above; optimizer and loss use the class defaults.
story_generator = StoryGenerator(tokenizer, model)

In [None]:
# Fine-tune on each story in turn; every batch holds up to 64 lines of a single story.
from tqdm.auto import tqdm
from torch.utils.data import DataLoader

for story in tqdm(training_stories, desc="fine-tuning"):
    # Skip blank/empty lines: they carry no story text, and empty strings
    # tokenize to padding-only rows.
    story_lines = [line for line in story if line.strip()]
    if not story_lines:
        continue
    story_dataset = StoryDataset(text=story_lines)
    story_loader = DataLoader(story_dataset, batch_size=64)
    story_generator.fine_tune(story_loader=story_loader)

  0%|                                                                                           | 0/50 [00:00<?, ?it/s]

In [29]:
# Prompt that the generator extends with four sampled sentences (output varies per run).
mysent = "My name is Ricky Bobby, I'm from Louisiana and my best friends name is bob. Me and bob enjoy hunting mountain lions in the hills."
story_generator.generate_sentences(mysent, num_sentences=4)

My name is Ricky Bobby, I'm from Louisiana and my best friends name is bob. Me and bob enjoy hunting mountain lions in the hills. We love to hunt but we don't always have a lot of time for it.
My name is Ricky Bobby, I'm from Louisiana and my best friends name is bob. Me and bob enjoy hunting mountain lions in the hills. We love to hunt but we don't always have a lot of time for it.


I am so excited to be back with you guys!
My name is Ricky Bobby, I'm from Louisiana and my best friends name is bob. Me and bob enjoy hunting mountain lions in the hills. We love to hunt but we don't always have a lot of time for it.


I am so excited to be back with you guys! It's been great being here since day one when all our animals were safe and happy.
My name is Ricky Bobby, I'm from Louisiana and my best friends name is bob. Me and bob enjoy hunting mountain lions in the hills. We love to hunt but we don't always have a lot of time for it.


I am so excited to be back with you guys! It's been gre

"My name is Ricky Bobby, I'm from Louisiana and my best friends name is bob. Me and bob enjoy hunting mountain lions in the hills. We love to hunt but we don't always have a lot of time for it.\n\n\nI am so excited to be back with you guys! It's been great being here since day one when all our animals were safe and happy. Our dogs are loving their new home and now they're getting ready to get out there and play again."