In [51]:
##First we'll import all our tools

import numpy as np
import pandas as pd
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.layers import LSTM
from keras.callbacks import ModelCheckpoint
from keras.utils import np_utils

In [52]:
##Let's get our scripts ready to go, we'll read in all the 
def getData(csvName= 'transcripts.csv'):
    dataFrame = pd.read_csv(csvName)
    allScripts = dataFrame['transcript'].tolist()
    allScripts = [script.lower() for script in allScripts]
    return allScripts


useAll = False #Set to true to use all transcripts in your training data, false uses the first transcript

allScripts = getData()
transcript = ''

if not useAll:
    transcript = allScripts[0]
else:
    transcript = "\n".join(allScripts) #Joining our transcripts seperated by new lines

print(len(transcript))

17409


In [53]:
##Let's look at a list of all unique characters in our scripts, 
##we'll eventually need to one hot encode them to make training easier:
uniqueChars = sorted(list(set(transcript)))
numUniqueChars = len(chars)
print(uniqueChars)

[' ', '!', '"', "'", '(', ')', ',', '-', '.', '0', '1', '2', '3', '5', '6', '9', ':', ';', '?', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '—']


In [54]:
##Lets make a mapping of each character to a specific number, this will help our training since we need numerical data:
charsToInt = dict((char, i) for i, char in enumerate(uniqueChars))
print(charsToInt)

{'t': 38, '"': 2, '?': 18, 'g': 25, "'": 3, '-': 7, 'q': 35, '9': 15, '.': 8, 'p': 34, 'd': 22, 'x': 42, 'o': 33, 's': 37, 'w': 41, 'k': 29, 'n': 32, 'i': 27, 'v': 40, ';': 17, 'm': 31, 'h': 26, '2': 11, '1': 10, ',': 6, 'e': 23, ')': 5, 'r': 36, 'z': 44, ' ': 0, '!': 1, '3': 12, '—': 45, 'b': 20, ':': 16, '(': 4, 'j': 28, '5': 13, '0': 9, 'u': 39, '6': 14, 'c': 21, 'l': 30, 'y': 43, 'f': 24, 'a': 19}


In [55]:
##For our network we are going to train it by feeding it strings of characters and have it predict what the next
##character in the sequence will be. So to generate these sequences we will copy 100 characters from our transcript
##into a sequence, take the next character as the target answer, and then shift our window by one character and do that
##over and over again. 

lengthOfSequence = 100
def prepSequences(rawText, encoding, sequenceLength = 100): 
    data = []
    targets = []
    for i in range(0, len(rawText) - sequenceLength, 1):
        sequence = rawText[i: i+sequenceLength]
        target = rawText[i + sequenceLength]
        data.append([encoding[char] for char in sequence]) #Here we are encoding the characters to their previous assigned values
        targets.append(encoding[target])                   #Same with the target answer

    return data, targets

data, targets = prepSequences(transcript, charsToInt, lengthOfSequence)
print(data[0])

[25, 33, 33, 22, 0, 31, 33, 36, 32, 27, 32, 25, 8, 0, 26, 33, 41, 0, 19, 36, 23, 0, 43, 33, 39, 18, 4, 30, 19, 39, 25, 26, 38, 23, 36, 5, 27, 38, 3, 37, 0, 20, 23, 23, 32, 0, 25, 36, 23, 19, 38, 6, 0, 26, 19, 37, 32, 3, 38, 0, 27, 38, 18, 0, 27, 3, 40, 23, 0, 20, 23, 23, 32, 0, 20, 30, 33, 41, 32, 0, 19, 41, 19, 43, 0, 20, 43, 0, 38, 26, 23, 0, 41, 26, 33, 30, 23, 0, 38, 26]


In [56]:
##To finish off prepping our data, we need to convert x to be [samples, time steps, features]
##and we need to convert our training answers to a one hot encoding
def prepX(data, lengthOfSequence, numUniqueChars):
    data = np.reshape(data, (len(data), lengthOfSequence, 1))
    data = data / float(numUniqueChars)
    return data

def prepY(targets):
    targets = np_utils.to_categorical(targets)
    return targets

preppedX = prepX(data, lengthOfSequence, numUniqueChars)
preppedY = prepY(targets)


##The last thing we can do before we train is get our model set up
def generateModel(X, y):
    model = Sequential()
    model.add(LSTM(256, input_shape=(X.shape[1], X.shape[2]), return_sequences=True))
    model.add(Dropout(0.2))
    model.add(LSTM(256))
    model.add(Dropout(0.2))
    model.add(Dense(y.shape[1], activation='softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='adam')
    return model

model = generateModel(preppedX, preppedY)

In [None]:
##Training time!

def trainModel(model, X, y, numEpochs= 5, batchSize= 64):
    filepath="weights-improvement-{epoch:02d}-{loss:.4f}.hdf5" #replace with lowest loss file
    checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1, save_best_only=True, mode='min')
    callbacks_list = [checkpoint]

    model.fit(X, y, epochs = numEpochs, batch_size= batchSize, callbacks=callbacks_list)
    return model

model = trainModel(model, preppedX, preppedY)

Epoch 1/5

Epoch 00001: loss improved from inf to 2.97876, saving model to weights-improvement-01-2.9788.hdf5
Epoch 2/5
 1408/17309 [=>............................] - ETA: 1:58 - loss: 2.8759

In [28]:
##The code below loads back in the best weights we add
filename = "weights-improvement-01-3.0445.hdf5" #replace with best weights file for your training
model.load_weights(filename)
model.compile(loss='categorical_crossentropy', optimizer='adam')

In [50]:
##Now for text generation
def generateSeedFromData(data):
    start = np.random.randint(0, len(data)-1)
    pattern = data[start]
    print("Starting Seed: ", ''.join([intToChar[value] for value in pattern]), end= '\n\n\n')
    return pattern

def generateText(model, pattern, decoding, length= 1000, vocabSize= 47):
    text= ''
    for i in range(length):
        preppedPattern = prepSeed(pattern, vocabSize)
        prediction = model.predict(preppedPattern, verbose= 0)
        index = np.argmax(prediction)
        result = decoding[index]
        print('here: {0}\n{1}\n{2}'.format(prediction, index, result), end='\n\n\n')
        text += result
        pattern.append(index)
        pattern = pattern[1:]
        
    return text


def prepPattern(pattern, vocabSize):
    pattern = np.reshape(pattern, (1, len(pattern), 1))
    pattern = pattern / float(vocabSize)
    return pattern

intToChar = dict((i, char) for i, char in enumerate(chars))  #creating a demapping of our original encoding
seed = generateSeedFromData(data) #get a random starting point from our paper and let the network continue the writing
numCharacters= 100   #length of each window the network will use to predict the output
text = generateText(model, seed, intToChar, length= numCharacters, vocabSize= numUniqueChars)
print(text)

Starting Seed:  rather important. i think math is very important, but so is dance. children dance all the time if th


decoding:  
here: [[0.16069968 0.00024052 0.00366729 0.00779924 0.00337215 0.00393083
  0.01234884 0.0010801  0.00842702 0.0008899  0.00060277 0.00038621
  0.00038903 0.00054768 0.00033563 0.00028802 0.00044392 0.00101168
  0.00224491 0.07540478 0.00969278 0.02164156 0.03777151 0.0788682
  0.01399329 0.01848534 0.04577643 0.06300785 0.00143031 0.00873528
  0.02766859 0.0217376  0.04765964 0.06054126 0.01582511 0.00060229
  0.04061932 0.03810853 0.08485095 0.02450631 0.00897338 0.02286338
  0.00115706 0.02007045 0.00035867 0.00094481]]
0
 


decoding:  
here: [[0.15108167 0.00024878 0.00371219 0.00782949 0.00344875 0.0039597
  0.01233361 0.00112249 0.00834205 0.00091561 0.00061884 0.00039621
  0.00039728 0.00056383 0.00034368 0.00029644 0.00045594 0.00103588
  0.00229112 0.0774603  0.00979515 0.02202372 0.03821505 0.0764974
  0.01436997 0.0184656  0.046386   0.06362532 

decoding:  
here: [[0.13744469 0.00027812 0.00391722 0.00801287 0.00361294 0.00414407
  0.01242938 0.00122675 0.008524   0.00099167 0.00065881 0.00043253
  0.00043559 0.00061051 0.00037416 0.00032268 0.00049707 0.00111376
  0.00244708 0.07806856 0.00994289 0.02207505 0.03826647 0.07739669
  0.01489598 0.01846727 0.04780008 0.06488546 0.00165428 0.00916013
  0.02830059 0.02308447 0.0495527  0.06175409 0.01671325 0.00068004
  0.04078141 0.03859005 0.08812035 0.0256051  0.00950004 0.02445419
  0.00128941 0.02003097 0.00040826 0.0010482 ]]
0
 


decoding:  
here: [[0.13741775 0.00027819 0.00391823 0.0080135  0.00361375 0.00414438
  0.01243085 0.0012272  0.00852529 0.00099189 0.00065902 0.00043263
  0.00043571 0.00061069 0.00037429 0.00032278 0.00049724 0.00111404
  0.00244747 0.07807674 0.00994507 0.02208142 0.03827148 0.07736567
  0.01489889 0.01846824 0.04779579 0.06488308 0.00165471 0.0091601
  0.02830182 0.02308747 0.04955142 0.06175628 0.01671532 0.00068025
  0.04077733 0.03859854 0.0

decoding:  
here: [[0.13703544 0.00027903 0.00392481 0.00801881 0.00361847 0.00414855
  0.01243693 0.0012308  0.0085288  0.00099426 0.00066056 0.0004338
  0.00043685 0.00061221 0.00037522 0.00032365 0.00049858 0.00111652
  0.00245208 0.07813925 0.00995625 0.0221111  0.03829596 0.0772511
  0.01492402 0.01847031 0.04782151 0.06490803 0.0016594  0.00916657
  0.02831481 0.0231201  0.04956192 0.06177857 0.01673495 0.00068235
  0.04076203 0.03861926 0.0881703  0.02561097 0.00951292 0.02450539
  0.00129217 0.02004483 0.00040956 0.00105092]]
0
 


decoding:  
here: [[0.1370355  0.00027904 0.00392485 0.00801888 0.00361851 0.0041486
  0.01243702 0.00123083 0.00852886 0.00099428 0.00066058 0.00043381
  0.00043686 0.00061223 0.00037524 0.00032366 0.00049859 0.00111654
  0.00245212 0.07813884 0.00995631 0.02211118 0.03829598 0.07725102
  0.01492414 0.01847045 0.04782136 0.06490778 0.00165942 0.00916664
  0.02831488 0.02312014 0.04956191 0.06177817 0.01673505 0.00068237
  0.04076209 0.0386193  0.088

decoding:  
here: [[0.1370307  0.00027908 0.00392512 0.00801917 0.00361877 0.00414887
  0.01243746 0.00123097 0.0085292  0.00099439 0.00066065 0.00043387
  0.00043693 0.00061231 0.0003753  0.00032371 0.00049867 0.00111667
  0.00245241 0.07813779 0.00995674 0.02211178 0.03829654 0.07724943
  0.01492502 0.01847114 0.04782106 0.06490685 0.00165959 0.00916708
  0.02831534 0.02312075 0.04956198 0.06177659 0.01673583 0.00068247
  0.04076209 0.03861972 0.0881682  0.02561146 0.00951348 0.0245057
  0.00129231 0.02004604 0.00040962 0.00105107]]
0
 


decoding:  
here: [[0.13703069 0.00027908 0.00392512 0.00801918 0.00361878 0.00414887
  0.01243746 0.00123097 0.0085292  0.00099439 0.00066065 0.00043387
  0.00043693 0.00061231 0.0003753  0.00032371 0.00049867 0.00111668
  0.00245241 0.07813779 0.00995674 0.02211178 0.03829655 0.07724945
  0.01492503 0.01847114 0.04782106 0.06490686 0.0016596  0.00916708
  0.02831534 0.02312075 0.04956197 0.06177657 0.01673583 0.00068247
  0.04076209 0.03861973 0.0

decoding:  
here: [[0.13703042 0.00027909 0.00392514 0.00801918 0.00361879 0.00414889
  0.01243748 0.00123098 0.00852922 0.0009944  0.00066066 0.00043388
  0.00043693 0.00061231 0.0003753  0.00032371 0.00049867 0.00111668
  0.00245243 0.07813772 0.00995677 0.02211181 0.0382966  0.07724944
  0.01492508 0.01847119 0.04782103 0.06490679 0.00165961 0.00916711
  0.02831537 0.02312078 0.04956198 0.06177647 0.01673589 0.00068247
  0.04076208 0.03861976 0.08816808 0.02561149 0.00951352 0.02450569
  0.00129232 0.02004612 0.00040963 0.00105108]]
0
 


decoding:  
here: [[0.13703042 0.00027909 0.00392514 0.00801918 0.00361879 0.00414889
  0.01243748 0.00123098 0.00852922 0.0009944  0.00066066 0.00043388
  0.00043693 0.00061231 0.0003753  0.00032371 0.00049867 0.00111668
  0.00245243 0.07813772 0.00995677 0.02211181 0.0382966  0.07724944
  0.01492508 0.01847119 0.04782104 0.06490679 0.00165961 0.00916711
  0.02831537 0.02312078 0.04956198 0.06177647 0.01673589 0.00068247
  0.04076208 0.03861976 0.

decoding:  
here: [[0.1370304  0.00027909 0.00392514 0.00801919 0.0036188  0.00414889
  0.01243748 0.00123098 0.00852922 0.0009944  0.00066066 0.00043388
  0.00043693 0.00061231 0.0003753  0.00032371 0.00049867 0.00111668
  0.00245243 0.07813773 0.00995677 0.02211181 0.03829661 0.07724944
  0.01492508 0.01847119 0.04782103 0.06490681 0.00165961 0.00916712
  0.02831537 0.02312078 0.04956198 0.06177644 0.01673589 0.00068247
  0.04076208 0.03861976 0.08816807 0.02561149 0.00951352 0.02450569
  0.00129232 0.02004613 0.00040963 0.00105108]]
0
 


decoding:  
here: [[0.13703044 0.00027909 0.00392514 0.00801918 0.00361879 0.00414889
  0.01243748 0.00123098 0.00852922 0.0009944  0.00066066 0.00043388
  0.00043693 0.00061231 0.0003753  0.00032371 0.00049867 0.00111668
  0.00245243 0.07813771 0.00995677 0.02211181 0.0382966  0.07724944
  0.01492508 0.01847119 0.04782103 0.06490679 0.00165961 0.00916711
  0.02831537 0.02312078 0.04956198 0.06177644 0.01673589 0.00068247
  0.04076208 0.03861976 0.