In [1]:
import torch
from torch import nn
import torch.nn.functional as F

In [2]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
#--------------------------------Get Text Data------------------------------
with open('../Data/shakespeare.txt','r',encoding='utf8') as f:
    text = f.read()

In [4]:
type(text)

str

In [5]:
text[:1000] #1st 1k char --raw str

"\n                     1\n  From fairest creatures we desire increase,\n  That thereby beauty's rose might never die,\n  But as the riper should by time decease,\n  His tender heir might bear his memory:\n  But thou contracted to thine own bright eyes,\n  Feed'st thy light's flame with self-substantial fuel,\n  Making a famine where abundance lies,\n  Thy self thy foe, to thy sweet self too cruel:\n  Thou that art now the world's fresh ornament,\n  And only herald to the gaudy spring,\n  Within thine own bud buriest thy content,\n  And tender churl mak'st waste in niggarding:\n    Pity the world, or else this glutton be,\n    To eat the world's due, by the grave and thee.\n\n\n                     2\n  When forty winters shall besiege thy brow,\n  And dig deep trenches in thy beauty's field,\n  Thy youth's proud livery so gazed on now,\n  Will be a tattered weed of small worth held:  \n  Then being asked, where all thy beauty lies,\n  Where all the treasure of thy lusty days;\n  To sa

In [6]:
print(text[:1000])


                     1
  From fairest creatures we desire increase,
  That thereby beauty's rose might never die,
  But as the riper should by time decease,
  His tender heir might bear his memory:
  But thou contracted to thine own bright eyes,
  Feed'st thy light's flame with self-substantial fuel,
  Making a famine where abundance lies,
  Thy self thy foe, to thy sweet self too cruel:
  Thou that art now the world's fresh ornament,
  And only herald to the gaudy spring,
  Within thine own bud buriest thy content,
  And tender churl mak'st waste in niggarding:
    Pity the world, or else this glutton be,
    To eat the world's due, by the grave and thee.


                     2
  When forty winters shall besiege thy brow,
  And dig deep trenches in thy beauty's field,
  Thy youth's proud livery so gazed on now,
  Will be a tattered weed of small worth held:  
  Then being asked, where all thy beauty lies,
  Where all the treasure of thy lusty days;
  To say within thine own deep su

In [7]:
len(text)

5445609

In [8]:
#----------------------------------Encode Entire Text-----------------------
#Figure out all unique characters in the text
all_characters = set(text)

In [9]:
all_characters

{'\n',
 ' ',
 '!',
 '"',
 '&',
 "'",
 '(',
 ')',
 ',',
 '-',
 '.',
 '0',
 '1',
 '2',
 '3',
 '4',
 '5',
 '6',
 '7',
 '8',
 '9',
 ':',
 ';',
 '<',
 '>',
 '?',
 'A',
 'B',
 'C',
 'D',
 'E',
 'F',
 'G',
 'H',
 'I',
 'J',
 'K',
 'L',
 'M',
 'N',
 'O',
 'P',
 'Q',
 'R',
 'S',
 'T',
 'U',
 'V',
 'W',
 'X',
 'Y',
 'Z',
 '[',
 ']',
 '_',
 '`',
 'a',
 'b',
 'c',
 'd',
 'e',
 'f',
 'g',
 'h',
 'i',
 'j',
 'k',
 'l',
 'm',
 'n',
 'o',
 'p',
 'q',
 'r',
 's',
 't',
 'u',
 'v',
 'w',
 'x',
 'y',
 'z',
 '|',
 '}'}

In [10]:
len(all_characters) #total 84 unique char

84

In [11]:
#Encoder will take a char & returns its encoded number
#Decoder would do the opposite

In [12]:
#enumerate(all_characters) would assign number to unique char
for pair in enumerate(all_characters):
    print(pair)

(0, 'I')
(1, 'o')
(2, '.')
(3, ';')
(4, 'T')
(5, 'f')
(6, 'O')
(7, '?')
(8, '5')
(9, 'K')
(10, 'p')
(11, 'k')
(12, 'R')
(13, 'Q')
(14, '[')
(15, 't')
(16, 'J')
(17, '_')
(18, 'M')
(19, 'a')
(20, '1')
(21, 'C')
(22, 'h')
(23, '|')
(24, 'g')
(25, '0')
(26, 'i')
(27, 'S')
(28, 'F')
(29, '\n')
(30, 'z')
(31, '<')
(32, '3')
(33, 'Z')
(34, 'H')
(35, 'u')
(36, '!')
(37, 'A')
(38, 'q')
(39, 'E')
(40, '8')
(41, '6')
(42, ':')
(43, '(')
(44, 'G')
(45, ']')
(46, "'")
(47, '-')
(48, 'Y')
(49, 'w')
(50, 'c')
(51, 'j')
(52, 'v')
(53, '>')
(54, 'r')
(55, 'd')
(56, 'x')
(57, 'B')
(58, '7')
(59, 's')
(60, 'b')
(61, ',')
(62, 'N')
(63, '9')
(64, ' ')
(65, '2')
(66, 'm')
(67, 'y')
(68, 'X')
(69, 'n')
(70, 'l')
(71, '4')
(72, 'V')
(73, '}')
(74, 'e')
(75, ')')
(76, 'L')
(77, 'U')
(78, 'P')
(79, 'D')
(80, '"')
(81, '&')
(82, '`')
(83, 'W')


In [13]:
# num --> Letter
decoder = dict(enumerate(all_characters))

In [14]:
decoder

{0: 'I',
 1: 'o',
 2: '.',
 3: ';',
 4: 'T',
 5: 'f',
 6: 'O',
 7: '?',
 8: '5',
 9: 'K',
 10: 'p',
 11: 'k',
 12: 'R',
 13: 'Q',
 14: '[',
 15: 't',
 16: 'J',
 17: '_',
 18: 'M',
 19: 'a',
 20: '1',
 21: 'C',
 22: 'h',
 23: '|',
 24: 'g',
 25: '0',
 26: 'i',
 27: 'S',
 28: 'F',
 29: '\n',
 30: 'z',
 31: '<',
 32: '3',
 33: 'Z',
 34: 'H',
 35: 'u',
 36: '!',
 37: 'A',
 38: 'q',
 39: 'E',
 40: '8',
 41: '6',
 42: ':',
 43: '(',
 44: 'G',
 45: ']',
 46: "'",
 47: '-',
 48: 'Y',
 49: 'w',
 50: 'c',
 51: 'j',
 52: 'v',
 53: '>',
 54: 'r',
 55: 'd',
 56: 'x',
 57: 'B',
 58: '7',
 59: 's',
 60: 'b',
 61: ',',
 62: 'N',
 63: '9',
 64: ' ',
 65: '2',
 66: 'm',
 67: 'y',
 68: 'X',
 69: 'n',
 70: 'l',
 71: '4',
 72: 'V',
 73: '}',
 74: 'e',
 75: ')',
 76: 'L',
 77: 'U',
 78: 'P',
 79: 'D',
 80: '"',
 81: '&',
 82: '`',
 83: 'W'}

In [15]:
#Letter --> num
encoder = {char: ind for ind,char in decoder.items()}

In [16]:
encoder

{'I': 0,
 'o': 1,
 '.': 2,
 ';': 3,
 'T': 4,
 'f': 5,
 'O': 6,
 '?': 7,
 '5': 8,
 'K': 9,
 'p': 10,
 'k': 11,
 'R': 12,
 'Q': 13,
 '[': 14,
 't': 15,
 'J': 16,
 '_': 17,
 'M': 18,
 'a': 19,
 '1': 20,
 'C': 21,
 'h': 22,
 '|': 23,
 'g': 24,
 '0': 25,
 'i': 26,
 'S': 27,
 'F': 28,
 '\n': 29,
 'z': 30,
 '<': 31,
 '3': 32,
 'Z': 33,
 'H': 34,
 'u': 35,
 '!': 36,
 'A': 37,
 'q': 38,
 'E': 39,
 '8': 40,
 '6': 41,
 ':': 42,
 '(': 43,
 'G': 44,
 ']': 45,
 "'": 46,
 '-': 47,
 'Y': 48,
 'w': 49,
 'c': 50,
 'j': 51,
 'v': 52,
 '>': 53,
 'r': 54,
 'd': 55,
 'x': 56,
 'B': 57,
 '7': 58,
 's': 59,
 'b': 60,
 ',': 61,
 'N': 62,
 '9': 63,
 ' ': 64,
 '2': 65,
 'm': 66,
 'y': 67,
 'X': 68,
 'n': 69,
 'l': 70,
 '4': 71,
 'V': 72,
 '}': 73,
 'e': 74,
 ')': 75,
 'L': 76,
 'U': 77,
 'P': 78,
 'D': 79,
 '"': 80,
 '&': 81,
 '`': 82,
 'W': 83}

In [17]:
#ensure encoder & decoder match -->represent the same char with num

In [18]:
# take all txt & encode it so it has numerical info aspect to it
encoded_text = np.array([encoder[char] for char in text])

In [19]:
encoded_text[:500]

array([29, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
       64, 64, 64, 64, 64, 20, 29, 64, 64, 28, 54,  1, 66, 64,  5, 19, 26,
       54, 74, 59, 15, 64, 50, 54, 74, 19, 15, 35, 54, 74, 59, 64, 49, 74,
       64, 55, 74, 59, 26, 54, 74, 64, 26, 69, 50, 54, 74, 19, 59, 74, 61,
       29, 64, 64,  4, 22, 19, 15, 64, 15, 22, 74, 54, 74, 60, 67, 64, 60,
       74, 19, 35, 15, 67, 46, 59, 64, 54,  1, 59, 74, 64, 66, 26, 24, 22,
       15, 64, 69, 74, 52, 74, 54, 64, 55, 26, 74, 61, 29, 64, 64, 57, 35,
       15, 64, 19, 59, 64, 15, 22, 74, 64, 54, 26, 10, 74, 54, 64, 59, 22,
        1, 35, 70, 55, 64, 60, 67, 64, 15, 26, 66, 74, 64, 55, 74, 50, 74,
       19, 59, 74, 61, 29, 64, 64, 34, 26, 59, 64, 15, 74, 69, 55, 74, 54,
       64, 22, 74, 26, 54, 64, 66, 26, 24, 22, 15, 64, 60, 74, 19, 54, 64,
       22, 26, 59, 64, 66, 74, 66,  1, 54, 67, 42, 29, 64, 64, 57, 35, 15,
       64, 15, 22,  1, 35, 64, 50,  1, 69, 15, 54, 19, 50, 15, 74, 55, 64,
       15,  1, 64, 15, 22

In [20]:
#----------------------------One Hot Encoding-------------------------------
#we need to one-hot encode our data inorder for it to work with the network structure. 
def one_hot_encoder(encoded_text, num_uni_chars):
    #encode_text --> batch of encoded text
    #num_uni_chars --> len(set(text))
    
    
    '''
    encoded_text : batch of encoded text
    
    num_uni_chars = number of unique characters (len(set(text)))
    '''
    
    # METHOD FROM:
    # https://stackoverflow.com/questions/29831489/convert-encoded_textay-of-indices-to-1-hot-encoded-numpy-encoded_textay
      
    # Create a placeholder for zeros.
    one_hot = np.zeros((encoded_text.size, num_uni_chars))
    
    # Convert data type for later use with pytorch (errors if we dont!)
    one_hot = one_hot.astype(np.float32)

    # Using fancy indexing fill in the 1s at the correct index locations
    one_hot[np.arange(one_hot.shape[0]), encoded_text.flatten()] = 1.0
    

    # Reshape it so it matches the batch sahe
    one_hot = one_hot.reshape((*encoded_text.shape, num_uni_chars))
    
    return one_hot

In [21]:
arr = np.array([1,2,0])

In [22]:
arr

array([1, 2, 0])

In [23]:
one_hot_encoder(arr,3)

array([[0., 1., 0.],
       [0., 0., 1.],
       [1., 0., 0.]], dtype=float32)

In [24]:
#----------------------------Creating Training Batches----------------------
#create a function that will generate batches of characters along with the next character in the sequence as a label
example_text = np.arange(10)

In [25]:
example_text

array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [26]:
# If we wanted 5 batches
example_text.reshape((5,-1))

array([[0, 1],
       [2, 3],
       [4, 5],
       [6, 7],
       [8, 9]])

In [27]:
def generate_batches(encoded_text, samp_per_batch=10, seq_len=50):
    
    '''
    Generate (using yield) batches for training.
    
    X: Encoded Text of length seq_len
    Y: Encoded Text shifted by one
    
    Example:
    
    X:
    
    [[1 2 3]]
    
    Y:
    
    [[ 2 3 4]]
    
    encoded_text : Complete Encoded Text to make batches from
    batch_size : Number of samples per batch
    seq_len : Length of character sequence
       
    '''
    
    # Total number of characters per batch
    # Example: If samp_per_batch is 2 and seq_len is 50, then 100
    # characters come out per batch.
    char_per_batch = samp_per_batch * seq_len
    
    
    # Number of batches available to make
    # Use int() to roun to nearest integer
    num_batches_avail = int(len(encoded_text)/char_per_batch)
    
    # Cut off end of encoded_text that
    # won't fit evenly into a batch
    encoded_text = encoded_text[:num_batches_avail * char_per_batch]
    
    
    # Reshape text into rows the size of a batch
    encoded_text = encoded_text.reshape((samp_per_batch, -1))
    

    # Go through each row in array.
    for n in range(0, encoded_text.shape[1], seq_len):
        
        # Grab feature characters
        x = encoded_text[:, n:n+seq_len]
        
        # y is the target shifted over by 1
        y = np.zeros_like(x)
       
        #
        try:
            y[:, :-1] = x[:, 1:]
            y[:, -1]  = encoded_text[:, n+seq_len]
            
        # FOR POTENTIAL INDEXING ERROR AT THE END    
        except:
            y[:, :-1] = x[:, 1:]
            y[:, -1] = encoded_text[:, 0]
            
        yield x, y

In [28]:
#----------------------Example of generating a batch-----------------------
sample_text = np.arange(20)

In [29]:
sample_text

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19])

In [30]:
batch_generator = generate_batches(sample_text,samp_per_batch=2,seq_len=5)

In [31]:
# Grab first batch
x, y = next(batch_generator)

In [32]:
x

array([[ 0,  1,  2,  3,  4],
       [10, 11, 12, 13, 14]])

In [33]:
y

array([[ 1,  2,  3,  4,  5],
       [11, 12, 13, 14, 15]])

In [34]:
#--------------------------------Create LSTM Model--------------------------
#rather similar as RNN but will be focused on the text itself

class CharModel(nn.Module):
    
    def __init__(self, all_chars, num_hidden=256, num_layers=4,drop_prob=0.5,use_gpu=False):
        
        
        # SET UP ATTRIBUTES
        super().__init__()
        self.drop_prob = drop_prob
        self.num_layers = num_layers
        self.num_hidden = num_hidden
        self.use_gpu = use_gpu
        
        #CHARACTER SET, ENCODER, and DECODER
        self.all_chars = all_chars
        self.decoder = dict(enumerate(all_chars))
        self.encoder = {char: ind for ind,char in decoder.items()}
        
        
        self.lstm = nn.LSTM(len(self.all_chars), num_hidden, num_layers, dropout=drop_prob, batch_first=True)
        
        self.dropout = nn.Dropout(drop_prob)
        
        self.fc_linear = nn.Linear(num_hidden, len(self.all_chars))
      
    
    def forward(self, x, hidden):
                  
        
        lstm_output, hidden = self.lstm(x, hidden)
        
        
        drop_output = self.dropout(lstm_output)
        
        drop_output = drop_output.contiguous().view(-1, self.num_hidden)
        
        
        final_out = self.fc_linear(drop_output)
        
        
        return final_out, hidden
    
    
    def hidden_state(self, batch_size):
        '''
        Used as separate method to account for both GPU and CPU users.
        '''
        
        if self.use_gpu:
            
            hidden = (torch.zeros(self.num_layers,batch_size,self.num_hidden).cuda(),
                     torch.zeros(self.num_layers,batch_size,self.num_hidden).cuda())
        else:
            hidden = (torch.zeros(self.num_layers,batch_size,self.num_hidden),
                     torch.zeros(self.num_layers,batch_size,self.num_hidden))
        
        return hidden
        

In [35]:
#-------------------------Model Instance-----------------------------------
model = CharModel(
    all_chars=all_characters,
    num_hidden=512,
    num_layers=3,
    drop_prob=0.5,
    use_gpu=True,
)

In [36]:
total_param  = []
for p in model.parameters():
    total_param.append(int(p.numel()))

In [37]:
# make the total_parameters to be roughly the same magnitude as the number of characters in the text.

In [38]:
sum(total_param)

5470292

In [39]:
len(encoded_text)

5445609

In [40]:
#Optimizer and Loss fn
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
criterion = nn.CrossEntropyLoss()

In [41]:
#-------------------------Training and Validation of Data------------------
# percentage of data to be used for training
train_percent = 0.1

In [42]:
len(encoded_text)

5445609

In [43]:
int(len(encoded_text) * (train_percent))

544560

In [44]:
train_ind = int(len(encoded_text) * (train_percent))

In [45]:
train_data = encoded_text[:train_ind]
val_data = encoded_text[train_ind:]