# Run all cells in this section

In [3]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [4]:
import torch
from torch.utils.data import Dataset

class StoryDataset(Dataset):
    def __init__(self,text):
        self.text = text
        
    def __len__(self):
        return len(self.text)
    
    def __getitem__(self,idx):
        return self.text[idx]

In [5]:
from transformers import AdamW
from tqdm import tqdm
import re

class StoryGenerator:
    def __init__(self,tokenizer=None,model=None,optimizer=None,loss=None,alt_sent_gen_enabled=False):
        self.tokenizer = tokenizer
        if tokenizer == None:
            self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
            
        self.model = model
        if model == None:
            self.model = GPT2LMHeadModel.from_pretrained('gpt2',pad_token_id=self.tokenizer.eos_token_id)
        
        self.optimizer = optimizer
        if optimizer == None:
            self.optimizer = AdamW(self.model.parameters(),lr=1e-5)
            
        self.loss = loss
        if loss == None:
            self.loss = torch.nn.CrossEntropyLoss()
            
        self.alt_sent_gen_enabled = alt_sent_gen_enabled
        
    def generate_sentences(self,sent,sent_end_symbols=".?!",max_len=64,num_sentences=1):
        for i in range(0,num_sentences):
            sent += self.generate_sentence(sent=sent,sent_end_symbols=sent_end_symbols,max_len=max_len)
        return sent
        
    def generate_sentence(self,sent,sent_end_symbols=".?!",max_len=64):
        if self.alt_sent_gen_enabled:
            return self.alt_sent_gen(sent=sent)
        return self.sent_gen(sent=sent)
    
    def sent_gen(self,sent,sent_end_symbols=".?!",max_len=64):
        end_symbols = re.compile('['+ sent_end_symbols + ']')
        
        sent_len = 0
        decoded_output = sent
        end_symbol = None
        start_pos = len(decoded_output)
        while end_symbol is None or (sent_len < max_len and end_symbols.match(end_symbol) is None):
            input_ids = self.tokenizer.encode(decoded_output, return_tensors="pt")
            output_length = input_ids.size()[1]+1
            output = self.model.generate(input_ids, min_length=output_length,max_length=output_length, num_beams=3, do_sample=True, repetition_penalty=4.0)
            decoded_output = self.tokenizer.decode(output[0])
            end_symbol = decoded_output[len(decoded_output)-1]
            #right now just add 1 for every token added
            sent_len = sent_len + 1 
        return decoded_output[start_pos:]
    
    def alt_sent_gen(self,sent,sent_end_symbols=".?!",max_len=64):
        end_of_sent = len(sent)-1
        
        decoded_output = sent
        input_ids = self.tokenizer.encode(decoded_output, return_tensors="pt")
        output = self.model.generate(input_ids, max_length=max_len, num_beams=3, do_sample=True, early_stopping=True, repetition_penalty=4.0)
        decoded_output = self.tokenizer.decode(output[0])
        
        all_new_additions = decoded_output[end_of_sent+1:]
        symbol_first_index = -1
        for symbol in sent_end_symbols:
            try:
                new_symbol_first_index = all_new_additions.index(symbol)
                if symbol_first_index == -1 or new_symbol_first_index < symbol_first_index:
                    symbol_first_index = new_symbol_first_index
            except ValueError:
                pass
        if symbol_first_index < 0:
            self.alt_sent_gen(sent=sent)
        new_sentence = all_new_additions[:symbol_first_index+1]
        return new_sentence
    
    def fine_tune(self,story_loader):
        for story_batch in tqdm(story_loader):
            inputs = self.tokenizer(story_batch,padding=True,truncation=True,return_tensors="pt")
            outputs = self.model(**inputs, labels=inputs["input_ids"])
            loss = outputs.loss
            print(loss)
            loss.backward()
            self.optimizer.step()

In [6]:
def get_sent_embedding(sent,story_generator=None):
    #returns a list of word embeddings
    if story_generator is None: 
        return None
    if len(sent) == 0:
        return ""
    sent_embedding = []
    words = get_tokenized_sent(sent,story_generator)
    for word in words:
#         sent_embedding.append(get_word_embedding(word,story_generator))
        sent_embedding += get_word_embedding(word,story_generator)
    return sent_embedding

In [7]:
def get_tokenized_sent(sent,story_generator=None):
    #returns a list of tokens/words tokenized by gpt-2 tokenizer
    if story_generator is None: 
        return None
    if len(sent) == 0:
        return ""
    tokenized_sent = []
#     tokenized_sent = tf.Tensor([],dtype=tf.int64)
    encoded_sent = story_generator.tokenizer.encode(sent)
    for encoded_word in encoded_sent:
        tokenized_sent.append(story_generator.tokenizer.decode(encoded_word))
    return tokenized_sent

In [8]:
import numpy as np
def get_word_embedding(word,story_generator=None):
    #returns the gpt-2 embedding for a word
    if len(word) == 0 or story_generator is None:
        return None
    w_encoded = story_generator.tokenizer(word)['input_ids']
    w_tensor = torch.LongTensor(w_encoded)
    return story_generator.model.transformer.wte(w_tensor).data.numpy().tolist()

# Don't run cells in this section

In [214]:
# tokenizer = GPT2Tokenizer.from_pretrained('gpt2', pad_token = tokenizer.eos_token,padding_side="right")
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', pad_token = "<|endoftext|>",padding_side="right")
model = GPT2LMHeadModel.from_pretrained('gpt2', pad_token_id=tokenizer.eos_token_id)

In [9]:
import pickle

In [10]:
story_generator = StoryGenerator(tokenizer=tokenizer,model=model)
# story_generator = pickle.load(open("trained_generator.pickle","rb"))

In [11]:
 stories_file = open("stories.csv","r",encoding="utf8")

In [12]:
training_story_count = 4
training_stories = [[""] * i for i in range(0,training_story_count+1)]

curr_story = ""
prev_story = ""
story_num = -1

stories = stories_file.readlines()
for line in stories:
    if story_num < training_story_count:
        if line.find("START OF THIS PROJECT GUTENBERG EBOOK") != -1:
            if story_num+1 != 0:
                print("Loaded story " + str(story_num) + "...")
            story_num = story_num + 1
        elif line != "\n":            
            training_stories[story_num].append(line.strip())
    else:
        break
        
training_stories.pop()
stories_file.close()

Loaded story 0...
Loaded story 1...
Loaded story 2...
Loaded story 3...


In [13]:
import statistics
import copy
from copy import deepcopy
mean_line_len = statistics.mean([len(line) for line in training_stories[0]])
print(mean_line_len)
cohesion_training_stories = copy.deepcopy(training_stories)
cohesion_text_len = int(2 * mean_line_len)

56.28947368421053


In [14]:
cohesion_training_text = []
cohesion_training_labels = []
for story in cohesion_training_stories[:training_story_count]:
    cohesion_training_sents = []
    curr_len = 0
    for sent in story:
        if curr_len >= cohesion_text_len:
            cohesion_training_text.append(cohesion_training_sents)
            cohesion_training_labels.append(1)
            cohesion_training_sents = []
            curr_len = 0
        sent_tokenized = get_sent_embedding(sent,story_generator)
        if len(sent_tokenized) + curr_len <= cohesion_text_len:
            cohesion_training_sents += sent_tokenized
            curr_len += len(sent_tokenized) 
        else:
            for word in sent_tokenized:
                if curr_len + 1 <= cohesion_text_len:
                    cohesion_training_sents.append(word)
                    curr_len += 1
                else:
                    break

In [15]:
pickle_cohesion_training_text = open("cohesion_training_text.pickle","wb")
pickle.dump(cohesion_training_text,pickle_cohesion_training_text)
pickle_cohesion_training_text.close()

# cohesion_training_stories = pickle.load(open("cohesion_training_text.pickle","rb"))

In [16]:
from random import shuffle

neg_cohesion_train_txt = []
neg_cohesion_train_labels = []
for story in cohesion_training_stories[:training_story_count]:
    neg_cohesion_train_sents = []
    curr_len = 0
    for sent in story:
        if curr_len >= cohesion_text_len:
            neg_cohesion_train_txt.append(neg_cohesion_train_sents)
            neg_cohesion_train_labels.append(0)
            neg_cohesion_train_sents = []
            curr_len = 0
        sent_tokenized = get_sent_embedding(sent,story_generator)
        shuffle(sent_tokenized)
        if len(sent_tokenized) + curr_len <= cohesion_text_len:
            neg_cohesion_train_sents += sent_tokenized
            curr_len += len(sent_tokenized) 
        else:
            for word in sent_tokenized:
                if curr_len + 1 <= cohesion_text_len:
                    neg_cohesion_train_sents.append(word)
                    curr_len += 1
                else:
                    break

In [17]:
pickle_neg_cohesion_train_txt = open("neg_cohesion_train_txt.pickle","wb")
pickle.dump(neg_cohesion_train_txt,pickle_neg_cohesion_train_txt)
pickle_neg_cohesion_train_txt.close()

# neg_cohesion_train_txt = pickle.load(open("neg_cohesion_train_txt.pickle","rb"))

In [18]:
line_batch_size = 8

curr_batch_size = 0
for i,story in enumerate(training_stories):
    batched_story = [""]
    batch_num = 0
    for line in story:
        if curr_batch_size < line_batch_size:
            batched_story[batch_num] += line
            curr_batch_size += 1
        else:
            batched_story.append(line)
            batch_num += 1
            curr_batch_size = 0
    training_stories[i] = batched_story
            

In [19]:
pickle_training_batched_stories = open("training_batched_stories.pickle","wb")
pickle.dump(training_stories,pickle_training_batched_stories)
pickle_training_batched_stories.close()

# training_stories = pickle.load(open("batched_training_stories.pickle","rb"))

In [20]:
for i,batched_story in enumerate(training_stories):
    print("Story " + str(i) + " num batches: " + str(len(batched_story)))

Story 0 num batches: 55
Story 1 num batches: 75
Story 2 num batches: 468
Story 3 num batches: 61


In [230]:
from torch.utils.data import DataLoader

for story in tqdm(training_stories[:training_story_count]):
    story_dataset = StoryDataset(text=story)
    story_loader = DataLoader(story_dataset, batch_size=8)
    pickle_trained_model.close()
    

  0%|                                                                                            | 0/4 [00:00<?, ?it/s]
  0%|                                                                                            | 0/7 [00:00<?, ?it/s][A

tensor(5.1517, grad_fn=<NllLossBackward>)



 14%|████████████                                                                        | 1/7 [00:29<02:55, 29.33s/it][A

tensor(5.2697, grad_fn=<NllLossBackward>)



 29%|████████████████████████                                                            | 2/7 [01:12<03:06, 37.33s/it][A

tensor(4.9130, grad_fn=<NllLossBackward>)



 43%|████████████████████████████████████                                                | 3/7 [01:21<01:38, 24.68s/it][A

tensor(4.2903, grad_fn=<NllLossBackward>)



 57%|████████████████████████████████████████████████                                    | 4/7 [01:30<00:54, 18.21s/it][A

tensor(4.5014, grad_fn=<NllLossBackward>)



 71%|████████████████████████████████████████████████████████████                        | 5/7 [01:39<00:29, 14.92s/it][A

tensor(4.5524, grad_fn=<NllLossBackward>)



 86%|████████████████████████████████████████████████████████████████████████            | 6/7 [01:47<00:12, 12.81s/it][A

tensor(4.1080, grad_fn=<NllLossBackward>)



100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [01:56<00:00, 16.64s/it][A
 25%|████████████████████▊                                                              | 1/4 [02:03<06:09, 123.24s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s][A

tensor(4.3936, grad_fn=<NllLossBackward>)



 10%|████████▎                                                                          | 1/10 [00:36<05:31, 36.79s/it][A

tensor(4.4063, grad_fn=<NllLossBackward>)



 20%|████████████████▌                                                                  | 2/10 [00:49<03:00, 22.59s/it][A

tensor(4.6129, grad_fn=<NllLossBackward>)



 30%|████████████████████████▉                                                          | 3/10 [00:59<01:56, 16.65s/it][A

tensor(4.1948, grad_fn=<NllLossBackward>)



 40%|█████████████████████████████████▏                                                 | 4/10 [01:08<01:21, 13.66s/it][A

tensor(3.8186, grad_fn=<NllLossBackward>)



 50%|█████████████████████████████████████████▌                                         | 5/10 [01:18<01:01, 12.33s/it][A

tensor(4.0076, grad_fn=<NllLossBackward>)



 60%|█████████████████████████████████████████████████▊                                 | 6/10 [01:27<00:44, 11.19s/it][A

tensor(3.8339, grad_fn=<NllLossBackward>)



 70%|██████████████████████████████████████████████████████████                         | 7/10 [01:36<00:31, 10.57s/it][A

tensor(3.8414, grad_fn=<NllLossBackward>)



 80%|██████████████████████████████████████████████████████████████████▍                | 8/10 [01:44<00:19,  9.92s/it][A

tensor(3.7470, grad_fn=<NllLossBackward>)



 90%|██████████████████████████████████████████████████████████████████████████▋        | 9/10 [01:53<00:09,  9.60s/it][A

tensor(3.5422, grad_fn=<NllLossBackward>)



100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [01:58<00:00, 11.80s/it][A
 50%|█████████████████████████████████████████▌                                         | 2/4 [04:10<04:11, 125.66s/it]
  0%|                                                                                           | 0/59 [00:00<?, ?it/s][A

tensor(3.5656, grad_fn=<NllLossBackward>)



  2%|█▍                                                                                 | 1/59 [00:12<12:30, 12.94s/it][A

tensor(4.3569, grad_fn=<NllLossBackward>)



  3%|██▊                                                                                | 2/59 [00:22<10:22, 10.92s/it][A

tensor(4.5159, grad_fn=<NllLossBackward>)



  5%|████▏                                                                              | 3/59 [00:31<09:15,  9.91s/it][A

tensor(4.1798, grad_fn=<NllLossBackward>)



  7%|█████▋                                                                             | 4/59 [00:39<08:41,  9.48s/it][A

tensor(4.0197, grad_fn=<NllLossBackward>)



  8%|███████                                                                            | 5/59 [00:49<08:26,  9.38s/it][A

tensor(3.8293, grad_fn=<NllLossBackward>)



 10%|████████▍                                                                          | 6/59 [00:58<08:20,  9.44s/it][A

tensor(4.1790, grad_fn=<NllLossBackward>)



 12%|█████████▊                                                                         | 7/59 [01:07<07:56,  9.17s/it][A

tensor(4.2840, grad_fn=<NllLossBackward>)



 14%|███████████▎                                                                       | 8/59 [01:16<07:40,  9.03s/it][A

tensor(3.9582, grad_fn=<NllLossBackward>)



 15%|████████████▋                                                                      | 9/59 [01:25<07:33,  9.07s/it][A

tensor(4.3196, grad_fn=<NllLossBackward>)



 17%|█████████████▉                                                                    | 10/59 [01:33<07:16,  8.90s/it][A

tensor(4.3060, grad_fn=<NllLossBackward>)



 19%|███████████████▎                                                                  | 11/59 [01:42<07:08,  8.92s/it][A

tensor(4.4543, grad_fn=<NllLossBackward>)



 20%|████████████████▋                                                                 | 12/59 [01:51<06:51,  8.75s/it][A

tensor(4.4223, grad_fn=<NllLossBackward>)



 22%|██████████████████                                                                | 13/59 [01:59<06:38,  8.66s/it][A

tensor(4.1209, grad_fn=<NllLossBackward>)



 24%|███████████████████▍                                                              | 14/59 [02:08<06:32,  8.73s/it][A

tensor(4.1693, grad_fn=<NllLossBackward>)



 25%|████████████████████▊                                                             | 15/59 [02:17<06:27,  8.80s/it][A

tensor(4.2553, grad_fn=<NllLossBackward>)



 27%|██████████████████████▏                                                           | 16/59 [02:26<06:16,  8.76s/it][A

tensor(4.1750, grad_fn=<NllLossBackward>)



 29%|███████████████████████▋                                                          | 17/59 [02:34<06:05,  8.70s/it][A

tensor(4.0904, grad_fn=<NllLossBackward>)



 31%|█████████████████████████                                                         | 18/59 [02:43<06:02,  8.83s/it][A

tensor(3.4691, grad_fn=<NllLossBackward>)



 32%|██████████████████████████▍                                                       | 19/59 [02:54<06:11,  9.29s/it][A

tensor(3.3307, grad_fn=<NllLossBackward>)



 34%|███████████████████████████▊                                                      | 20/59 [03:04<06:19,  9.72s/it][A

tensor(4.1985, grad_fn=<NllLossBackward>)



 36%|█████████████████████████████▏                                                    | 21/59 [03:15<06:21, 10.04s/it][A

tensor(4.3596, grad_fn=<NllLossBackward>)



 37%|██████████████████████████████▌                                                   | 22/59 [03:25<06:13, 10.10s/it][A

tensor(4.2334, grad_fn=<NllLossBackward>)



 39%|███████████████████████████████▉                                                  | 23/59 [03:34<05:52,  9.80s/it][A

tensor(3.9161, grad_fn=<NllLossBackward>)



 41%|█████████████████████████████████▎                                                | 24/59 [03:44<05:37,  9.63s/it][A

tensor(3.9677, grad_fn=<NllLossBackward>)



 42%|██████████████████████████████████▋                                               | 25/59 [03:53<05:25,  9.57s/it][A

tensor(3.7736, grad_fn=<NllLossBackward>)



 44%|████████████████████████████████████▏                                             | 26/59 [04:02<05:10,  9.42s/it][A

tensor(3.6413, grad_fn=<NllLossBackward>)



 46%|█████████████████████████████████████▌                                            | 27/59 [04:11<04:54,  9.20s/it][A

tensor(3.3650, grad_fn=<NllLossBackward>)



 47%|██████████████████████████████████████▉                                           | 28/59 [04:20<04:46,  9.23s/it][A

tensor(3.9202, grad_fn=<NllLossBackward>)



 49%|████████████████████████████████████████▎                                         | 29/59 [04:30<04:38,  9.28s/it][A

tensor(3.8724, grad_fn=<NllLossBackward>)



 51%|█████████████████████████████████████████▋                                        | 30/59 [04:39<04:29,  9.28s/it][A

tensor(3.8366, grad_fn=<NllLossBackward>)



 53%|███████████████████████████████████████████                                       | 31/59 [04:49<04:26,  9.52s/it][A

tensor(4.2266, grad_fn=<NllLossBackward>)



 54%|████████████████████████████████████████████▍                                     | 32/59 [04:58<04:15,  9.45s/it][A

tensor(3.8641, grad_fn=<NllLossBackward>)



 56%|█████████████████████████████████████████████▊                                    | 33/59 [05:08<04:07,  9.51s/it][A

tensor(4.3032, grad_fn=<NllLossBackward>)



 58%|███████████████████████████████████████████████▎                                  | 34/59 [05:16<03:50,  9.21s/it][A

tensor(4.2017, grad_fn=<NllLossBackward>)



 59%|████████████████████████████████████████████████▋                                 | 35/59 [05:25<03:39,  9.14s/it][A

tensor(3.9788, grad_fn=<NllLossBackward>)



 61%|██████████████████████████████████████████████████                                | 36/59 [05:34<03:29,  9.10s/it][A

tensor(4.2756, grad_fn=<NllLossBackward>)



 63%|███████████████████████████████████████████████████▍                              | 37/59 [05:43<03:15,  8.87s/it][A

tensor(4.0164, grad_fn=<NllLossBackward>)



 64%|████████████████████████████████████████████████████▊                             | 38/59 [05:52<03:09,  9.03s/it][A

tensor(3.9204, grad_fn=<NllLossBackward>)



 66%|██████████████████████████████████████████████████████▏                           | 39/59 [06:01<03:01,  9.07s/it][A

tensor(4.1120, grad_fn=<NllLossBackward>)



 68%|███████████████████████████████████████████████████████▌                          | 40/59 [06:10<02:51,  9.05s/it][A

tensor(3.7469, grad_fn=<NllLossBackward>)



 69%|████████████████████████████████████████████████████████▉                         | 41/59 [06:21<02:52,  9.59s/it][A

tensor(4.2515, grad_fn=<NllLossBackward>)



 71%|██████████████████████████████████████████████████████████▎                       | 42/59 [06:31<02:44,  9.70s/it][A

tensor(4.1121, grad_fn=<NllLossBackward>)



 73%|███████████████████████████████████████████████████████████▊                      | 43/59 [06:41<02:36,  9.80s/it][A

tensor(3.8199, grad_fn=<NllLossBackward>)



 75%|█████████████████████████████████████████████████████████████▏                    | 44/59 [06:50<02:23,  9.55s/it][A

tensor(4.1552, grad_fn=<NllLossBackward>)



 76%|██████████████████████████████████████████████████████████████▌                   | 45/59 [06:59<02:11,  9.39s/it][A

tensor(4.1169, grad_fn=<NllLossBackward>)



 78%|███████████████████████████████████████████████████████████████▉                  | 46/59 [07:09<02:04,  9.59s/it][A

tensor(4.4599, grad_fn=<NllLossBackward>)



 80%|█████████████████████████████████████████████████████████████████▎                | 47/59 [07:17<01:49,  9.15s/it][A

tensor(3.7889, grad_fn=<NllLossBackward>)



 81%|██████████████████████████████████████████████████████████████████▋               | 48/59 [07:26<01:40,  9.11s/it][A

tensor(3.9770, grad_fn=<NllLossBackward>)



 83%|████████████████████████████████████████████████████████████████████              | 49/59 [07:34<01:27,  8.74s/it][A

tensor(4.0562, grad_fn=<NllLossBackward>)



 85%|█████████████████████████████████████████████████████████████████████▍            | 50/59 [07:43<01:19,  8.85s/it][A

tensor(4.1945, grad_fn=<NllLossBackward>)



 86%|██████████████████████████████████████████████████████████████████████▉           | 51/59 [07:52<01:09,  8.73s/it][A

tensor(3.7120, grad_fn=<NllLossBackward>)



 88%|████████████████████████████████████████████████████████████████████████▎         | 52/59 [08:01<01:01,  8.75s/it][A

tensor(4.3706, grad_fn=<NllLossBackward>)



 90%|█████████████████████████████████████████████████████████████████████████▋        | 53/59 [08:10<00:53,  8.86s/it][A

tensor(4.0372, grad_fn=<NllLossBackward>)



 92%|███████████████████████████████████████████████████████████████████████████       | 54/59 [08:19<00:45,  9.00s/it][A

tensor(3.9116, grad_fn=<NllLossBackward>)



 93%|████████████████████████████████████████████████████████████████████████████▍     | 55/59 [08:29<00:37,  9.34s/it][A

tensor(3.5831, grad_fn=<NllLossBackward>)



 95%|█████████████████████████████████████████████████████████████████████████████▊    | 56/59 [08:41<00:29,  9.98s/it][A

tensor(5.2889, grad_fn=<NllLossBackward>)



 97%|███████████████████████████████████████████████████████████████████████████████▏  | 57/59 [08:51<00:20, 10.12s/it][A

tensor(4.8042, grad_fn=<NllLossBackward>)



 98%|████████████████████████████████████████████████████████████████████████████████▌ | 58/59 [09:00<00:09,  9.71s/it][A

tensor(5.0818, grad_fn=<NllLossBackward>)



100%|██████████████████████████████████████████████████████████████████████████████████| 59/59 [09:04<00:00,  9.23s/it][A
 75%|██████████████████████████████████████████████████████████████▎                    | 3/4 [13:25<05:21, 321.83s/it]
  0%|                                                                                            | 0/8 [00:00<?, ?it/s][A

tensor(2.2563, grad_fn=<NllLossBackward>)



 12%|██████████▌                                                                         | 1/8 [01:12<08:13, 70.49s/it][A

tensor(4.3347, grad_fn=<NllLossBackward>)



 25%|█████████████████████                                                               | 2/8 [01:35<04:22, 43.82s/it][A

tensor(4.3662, grad_fn=<NllLossBackward>)



 38%|███████████████████████████████▌                                                    | 3/8 [01:45<02:21, 28.23s/it][A

tensor(4.3422, grad_fn=<NllLossBackward>)



 50%|██████████████████████████████████████████                                          | 4/8 [01:54<01:22, 20.61s/it][A

tensor(4.3298, grad_fn=<NllLossBackward>)



 62%|████████████████████████████████████████████████████▌                               | 5/8 [02:04<00:50, 16.93s/it][A

tensor(4.1923, grad_fn=<NllLossBackward>)



 75%|███████████████████████████████████████████████████████████████                     | 6/8 [02:13<00:28, 14.32s/it][A

tensor(4.5829, grad_fn=<NllLossBackward>)



 88%|█████████████████████████████████████████████████████████████████████████▌          | 7/8 [02:23<00:12, 12.89s/it][A

tensor(3.9135, grad_fn=<NllLossBackward>)



100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [02:31<00:00, 18.96s/it][A
100%|███████████████████████████████████████████████████████████████████████████████████| 4/4 [16:28<00:00, 247.07s/it]


In [231]:
# pickle_trained_generator = open("trained_generator.pickle","wb")
# pickle.dump(story_generator,pickle_trained_generator)
# pickle_trained_generator.close()

In [21]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import Sequential

In [73]:
print(cohesion_text_len)

112


In [69]:
embeddings_dim = story_generator.model.transformer.wte.embedding_dim
cohesion_model = Sequential()
# cohesion_model.add(layers.Embedding(input_dim=embeddings_dim,output_dim=64))
cohesion_model.add(layers.SimpleRNN(input_shape=(cohesion_text_len,embeddings_dim),units=64))
# cohesion_model.add(layers.SimpleRNN(input_shape=(cohesion_text_len),units=64))
cohesion_model.add(layers.Dense(32, activation='relu'))
cohesion_model.add(layers.Dense(16, activation='sigmoid'))
cohesion_model.add(layers.Dense(8, activation='relu'))
cohesion_model.add(layers.Dense(1, activation='softmax'))
cohesion_model.summary()
cohesion_model.compile(optimizer='adam',loss="binary_crossentropy",metrics=['accuracy'])

Model: "sequential_15"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
simple_rnn_15 (SimpleRNN)    (None, 64)                53312     
_________________________________________________________________
dense_60 (Dense)             (None, 32)                2080      
_________________________________________________________________
dense_61 (Dense)             (None, 16)                528       
_________________________________________________________________
dense_62 (Dense)             (None, 8)                 136       
_________________________________________________________________
dense_63 (Dense)             (None, 1)                 9         
Total params: 56,065
Trainable params: 56,065
Non-trainable params: 0
_________________________________________________________________


In [90]:
# train_data_x = cohesion_training_text + neg_cohesion_train_txt
# train_data_y = cohesion_training_labels + neg_cohesion_train_labels

In [94]:
import random

train_data_x = []
train_data_y = []
for i in range(0,100):
    rand_sample = random.randint(0,len(cohesion_training_text)-1)
    train_data_x.append(cohesion_training_text[rand_sample])
    train_data_x.append(neg_cohesion_train_txt[rand_sample])
    train_data_y.append(cohesion_training_labels[rand_sample])
    train_data_y.append(neg_cohesion_train_labels[rand_sample])

In [96]:
print(len(train_data_x))
print(len(train_data_y))

200
200


In [98]:
cohesion_model.fit(x=train_data_x,y=train_data_y,batch_size=64,epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x2136a5d2188>

In [99]:
cohesion_model.save("cohesion_model")
# cohesion_model = tf.keras.models.load_model("cohesion_model")

INFO:tensorflow:Assets written to: cohesion_model\assets


In [100]:
cohesion_model.predict([cohesion_training_text[1]])

array([[0.9153782]], dtype=float32)

# Run all cells in this section

In [10]:
import pickle
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import Sequential

cohesion_text_len = 112
story_generator = pickle.load(open("trained_generator.pickle","rb"))
embeddings_dim = embeddings_dim = story_generator.model.transformer.wte.embedding_dim
# cohesion_training_stories = pickle.load(open("cohesion_training_text.pickle","rb"))
# neg_cohesion_train_txt = pickle.load(open("neg_cohesion_train_txt.pickle","rb"))
# training_stories = pickle.load(open("batched_training_stories.pickle","rb"))
cohesion_model = tf.keras.models.load_model("cohesion_model")

In [70]:
def get_cohesion_value(text,story_generator,cohesion_model,block_len=112):
    # text parameter is a string
    # This section of code groups the story text into blocks of
    # 112 word embeddings (average words per line * 2) 
    cohesion_blocks = []
    text_as_lines = text.splitlines()
    
    cohesion_usr_sents = []
    curr_len = 0
    for line in text_as_lines:
        if curr_len >= block_len:
            cohesion_blocks.append(cohesion_usr_sents)
            cohesion_usr_sents = []
            curr_len = 0
        line_tokenized = get_sent_embedding(line,story_generator)
        if len(line_tokenized) + curr_len <= block_len:
            cohesion_usr_sents += line_tokenized
            curr_len += len(line_tokenized) 
        else:
            for word in line_tokenized:
                if curr_len + 1 <= block_len:
                    cohesion_usr_sents.append(word)
                    curr_len += 1
                else:
                    break
    if len(cohesion_blocks) == 0:
        raise RuntimeError("block length too small, try using a larger text or lowering block__len")
        
    # This section of code uses the cohesion_model (cohesion RNN) to
    # classify the text and return the likelyhood that the text is cohesive
    raw_predictions = cohesion_model.predict(cohesion_blocks)
    return raw_predictions[0][0]

# Don't run cells in this section

In [75]:
test_story = "I am already far north of London, and as I walk in the streets of \n\
Petersburgh, I feel a cold northern breeze play upon my cheeks, which \n \
braces my nerves and fills me with delight. Do you understand this \n\
feeling?\n"

print(get_cohesion_value(test_story,story_generator,cohesion_model,block_len=20))

0.8917127


In [87]:
test_story = "I far already London dog of north am, I as and the in walk of streets \n\
Peteear burgh, a el brether cold cat m play which cheeks, \n \
nervbra wolf m fills and delight with. rstand pig you Do this \n\
feeling?\n"

print(get_cohesion_value(test_story,story_generator,cohesion_model,block_len=20))

0.9098319


# Run these cells to generate and evaluate a story

In [99]:
user_sent = input("Enter the first few sentences of your story as a prompt: ")
user_story_len = input("Enter the length of your story (number of sentences): ")
story_generator.alt_sent_gen_enabled = False
story = story_generator.generate_sentences(sent=user_sent,num_sentences=int(user_story_len))
print(story)
print()

Enter the first few sentences of your story as a prompt: I am already far north of London, and as I walk in the streets of Petersburgh, I feel a cold northern breeze play upon my cheeks, which braces my nerves and fills me with delight. Do you understand this feeling?
Enter the length of your story (number of sentences): 20
I am already far north of London, and as I walk in the streets of Petersburgh, I feel a cold northern breeze play upon my cheeks, which braces my nerves and fills me with delight. Do you understand this feeling? It is not just that it was always there; it has been here for years now. The first thing to do was to make sure that she had no other choice but to go on her own way. She did not want to be alone: she wanted to get out of the house by herself. But why should she? Why should she stay at home? And what else could she do? What else would she do? There were three reasons--the one being that she didn't know how to keep up with the new arrivals or who they were go

In [131]:
import nltk
story_as_lines = nltk.sent_tokenize(story)
story_with_lines = '\n'.join(story_as_lines)
print(get_cohesion_value(text=story_with_lines,story_generator=story_generator,cohesion_model=cohesion_model,block_len=20))

0.88715446
