In [40]:
# import libraries

import torch
from torch import nn # network cell, for LSTM
from torch import optim

In [41]:
text_name = "peter_pan"

with open(f"texts/{text_name}.txt", "r") as file:
    text = file.read()

print("# of characters:", len(text))

unique_characters = set(text)
INPUT_SIZE = len(unique_characters)
print("# of unique characters (INPUT_SIZE):", INPUT_SIZE)

ordered_characters = sorted(unique_characters)

CHARACTER_ENCODING = dict(zip(ordered_characters, list(range(len(ordered_characters)))))

# of characters: 254514
# of unique characters (INPUT_SIZE): 76


In [42]:
def encode_char(character): # one hot
    encoding = torch.zeros(INPUT_SIZE)
    encoding[CHARACTER_ENCODING[character]] = 1
    return encoding

def encode_string(string):
    encoding = torch.zeros(len(string), INPUT_SIZE)
    for i in range(len(string)):
        encoding[i] = encode_char(string[i])
    return encoding

In [43]:
X = []
y = []

INPUT_SEQUENCE_LENGTH = 10

for i in range(len(text) - INPUT_SIZE):
    # The input sequence
    sequence = encode_string(text[i: i + INPUT_SEQUENCE_LENGTH])
    # The next character (one-hot encoded) as label
    next_character = encode_char(text[i + INPUT_SEQUENCE_LENGTH])    

    X.append(sequence)
    y.append(next_character)

X = torch.stack(X)  # Shape: (num_samples, sequence_length, INPUT_SIZE)
y = torch.stack(y)  # Shape: (num_samples, INPUT_SIZE)

print("X shape:", X.shape)
print("y shape:", y.shape)


X shape: torch.Size([254438, 10, 76])
y shape: torch.Size([254438, 76])


In [44]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using mps device


In [45]:
class GRUCharPredictor(nn.Module):
    def __init__(self):
        super(GRUCharPredictor, self).__init__()
        self.lstm = nn.GRU(input_size=INPUT_SIZE, hidden_size=256, num_layers=3, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(256 * 2, INPUT_SIZE)  

    def forward(self, x):
        lstm_out, _ = self.lstm(x)

        linear_out = self.fc(lstm_out[:, -1, :])

        return linear_out

class LSTMCharPredictor(nn.Module):
    def __init__(self):
        super(LSTMCharPredictor, self).__init__()
        self.lstm = nn.LSTM(input_size=INPUT_SIZE, hidden_size=256, num_layers=3, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(256 * 2, INPUT_SIZE)  

    def forward(self, x):
        lstm_out, _ = self.lstm(x)

        linear_out = self.fc(lstm_out[:, -1, :])

        return linear_out

# Initialize the model
model = LSTMCharPredictor().to(device)

print(model(X[0].unsqueeze(0).to(device)).size())

torch.Size([1, 76])


In [46]:
# Loss and optimizer
criterion = nn.CrossEntropyLoss()  # Suitable for classification
optimizer = optim.Adam(model.parameters(), lr=0.0005)

In [47]:
num_epochs = 30
batch_size = 100

# Example training loop
for epoch in range(num_epochs):
    model.train()
    permutation = torch.randperm(X.size(0))
    
    for i in range(0, X.size(0), batch_size):
        indices = permutation[i:i+batch_size]
        batch_X, batch_y = X[indices].to(device), y[indices].to(device)
        
        # Forward pass
        outputs = model(batch_X)
        labels = torch.argmax(batch_y, dim=1)  # Convert one-hot to class indices
        
        # Loss calculation
        loss = criterion(outputs, labels)
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')

Epoch [1/30], Loss: 2.0310022830963135
Epoch [2/30], Loss: 1.4168167114257812
Epoch [3/30], Loss: 1.241541862487793
Epoch [4/30], Loss: 1.3444234132766724
Epoch [5/30], Loss: 1.438924789428711
Epoch [6/30], Loss: 0.9437075257301331
Epoch [7/30], Loss: 1.2759875059127808
Epoch [8/30], Loss: 1.067692756652832
Epoch [9/30], Loss: 1.0428367853164673
Epoch [10/30], Loss: 1.51625657081604
Epoch [11/30], Loss: 0.721851646900177
Epoch [12/30], Loss: 0.9103653430938721
Epoch [13/30], Loss: 0.8300895690917969
Epoch [14/30], Loss: 1.3749632835388184
Epoch [15/30], Loss: 0.8174614906311035
Epoch [16/30], Loss: 0.9576235413551331
Epoch [17/30], Loss: 0.7835237383842468
Epoch [18/30], Loss: 0.8132526278495789
Epoch [19/30], Loss: 0.7795225381851196
Epoch [20/30], Loss: 0.8173518180847168
Epoch [21/30], Loss: 0.8097052574157715
Epoch [22/30], Loss: 0.4254661798477173
Epoch [23/30], Loss: 0.36994320154190063
Epoch [24/30], Loss: 0.6652113795280457
Epoch [25/30], Loss: 0.5546661615371704
Epoch [26/30],

In [55]:
INDEX_ENCODING = {}

for char, i in CHARACTER_ENCODING.items():
    INDEX_ENCODING[i] = char

print(INDEX_ENCODING)
print(CHARACTER_ENCODING)

{0: '\n', 1: ' ', 2: '!', 3: '(', 4: ')', 5: ',', 6: '-', 7: '.', 8: '0', 9: '1', 10: '2', 11: '3', 12: '4', 13: '6', 14: '7', 15: ':', 16: ';', 17: '?', 18: 'A', 19: 'B', 20: 'C', 21: 'D', 22: 'E', 23: 'F', 24: 'G', 25: 'H', 26: 'I', 27: 'J', 28: 'K', 29: 'L', 30: 'M', 31: 'N', 32: 'O', 33: 'P', 34: 'Q', 35: 'R', 36: 'S', 37: 'T', 38: 'U', 39: 'V', 40: 'W', 41: 'X', 42: 'Y', 43: 'Z', 44: 'a', 45: 'b', 46: 'c', 47: 'd', 48: 'e', 49: 'f', 50: 'g', 51: 'h', 52: 'i', 53: 'j', 54: 'k', 55: 'l', 56: 'm', 57: 'n', 58: 'o', 59: 'p', 60: 'q', 61: 'r', 62: 's', 63: 't', 64: 'u', 65: 'v', 66: 'w', 67: 'x', 68: 'y', 69: 'z', 70: 'é', 71: '—', 72: '‘', 73: '’', 74: '“', 75: '”'}
{'\n': 0, ' ': 1, '!': 2, '(': 3, ')': 4, ',': 5, '-': 6, '.': 7, '0': 8, '1': 9, '2': 10, '3': 11, '4': 12, '6': 13, '7': 14, ':': 15, ';': 16, '?': 17, 'A': 18, 'B': 19, 'C': 20, 'D': 21, 'E': 22, 'F': 23, 'G': 24, 'H': 25, 'I': 26, 'J': 27, 'K': 28, 'L': 29, 'M': 30, 'N': 31, 'O': 32, 'P': 33, 'Q': 34, 'R': 35, 'S': 36,

In [56]:
def get_next_sequence(sequence, deterministic=False):
    logits = model(
        encode_string(sequence[-INPUT_SEQUENCE_LENGTH:]).unsqueeze(0).to(device)
    ) # returns in size (1, INPUT_SEQUENCE_LENGTH, 78)

    probabilities = torch.softmax(logits[0], dim=1)

    character_indexes = torch.argmax(probabilities) if deterministic else torch.multinomial(probabilities, num_samples=1)

    next_sequence = [INDEX_ENCODING[int(index)] for index in character_indexes]
    
    return next_sequence

def get_next_char(sequence, deterministic=False):
    logits = model(
        encode_string(sequence[-INPUT_SEQUENCE_LENGTH:]).unsqueeze(0).to(device)
    ) # returns in size (1, 78)

    probabilities = torch.softmax(logits[0], 0)

    character_index = torch.argmax(probabilities) if deterministic else torch.multinomial(probabilities, num_samples=1)
    
    return INDEX_ENCODING[int(character_index)]

In [50]:
sequence = "P"
print(get_next_char(sequence))

e


In [1]:
def generate_text(starting_character, num_generated_characters):
    generated_text = starting_character

    with torch.no_grad():
        for _ in range(num_generated_characters):
            generated_text += get_next_char(generated_text)

    return generated_text

In [51]:
generated_text = generate_text("P", 1000)

Peter. In the chemselves. He must be Peter to Smee an understand ever felt for the Neverland had gone from the catle over the lagoon. He thinned by Peter.

In his dagger upraised him one of the recent cast from his crew. A man, rat, and Tink noted, alone seen them.

Wendy immensely. “I think, you know. But he reserted, holding Nana’s abrack and fetch meronitain, as soon as so bin that they sired side her up in the drawing-room; and yes, he who had lost faithed him on the ground by this probably the trees. They said “How ripping,” but did not.

“Oh, the cleverness; but standing erect, but it sounding ones are surescented with such an arms of scortly.

O Wendy?”

“I forget dell from their trangles table, who had gone into the office again, we fround only one of us had a mother.”

They are the children call never heard. The others would not let her go.”

“It is just thinking,” he said, almost blew he thought the sound of the ticks he bid.”

“What kind of afraid of the greatest danger, onl

In [52]:
import os

def get_avaialable_file_name(file_name, extension):
    available_file_name = file_name
    i = 1
    while os.path.isfile(available_file_name + extension):
        available_file_name = file_name + f"_{i}"
        i += 1

    return available_file_name + extension

In [53]:
with open(
    get_avaialable_file_name(
        f"generated_texts/{text_name}_gru_{INPUT_SEQUENCE_LENGTH}chars", ".txt"
    ),
    "w",
) as file:
    file.write(generated_text)

In [54]:
torch.save(
    model.state_dict(),
    get_avaialable_file_name(
        f"models/{text_name}_lstm_{INPUT_SEQUENCE_LENGTH}chars", ".pth"
    ),
)