In [1]:
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import TensorDataset, DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def encode_txt_file(filename):
    """Read a UTF-8 text file and one-hot encode it character by character.

    The text is lower-cased before encoding. The vocabulary is the sorted set
    of distinct characters in the lower-cased text, so index order follows
    character sort order.

    Args:
        filename: path of a UTF-8 encoded text file.

    Returns:
        A tuple ``(encoded_text, n_chars, n_vocab, char_to_int, int_to_char)``:
        ``encoded_text`` is a float ndarray of shape ``(n_chars, n_vocab)``
        whose row ``i`` is the one-hot vector of character ``i``;
        ``n_chars`` is the text length; ``n_vocab`` the vocabulary size;
        ``char_to_int`` / ``int_to_char`` map characters to indices and back.
    """
    # The context manager guarantees the handle is closed; a bare
    # open(...).read() leaves closing to the garbage collector.
    with open(filename, 'r', encoding='utf-8') as f:
        raw_text = f.read().lower()

    chars = sorted(set(raw_text))
    char_to_int = {c: i for i, c in enumerate(chars)}
    int_to_char = dict(enumerate(chars))

    n_chars = len(raw_text)
    n_vocab = len(chars)

    integers = [char_to_int[char] for char in raw_text]
    # Row k of the identity matrix is exactly the one-hot vector for index k.
    encoded_text = np.eye(n_vocab)[integers]

    return encoded_text, n_chars, n_vocab, char_to_int, int_to_char

def prepare_data(encoding_matrix, n_chars):
    """Build (two-character context, next-character) training pairs.

    Row ``i`` of the returned context tensor is rows ``i`` and ``i + 1`` of
    ``encoding_matrix`` concatenated side by side. Entry ``i`` of the target
    tensor is the class index (argmax) of row ``i + 2``. Both tensors are
    placed on ``device``.

    ``n_chars`` is accepted for interface compatibility and is not used.
    """
    onehots = torch.tensor(encoding_matrix[:-1]).to(device)
    prev_char, curr_char = onehots[:-1], onehots[1:]
    context = torch.cat((prev_char, curr_char), dim=1)

    # CrossEntropyLoss wants class indices, not one-hot rows, as targets.
    next_char_ids = torch.tensor(encoding_matrix[2:]).argmax(dim=1)
    return context, next_char_ids.to(device)

class Net(nn.Module):
    """Two-layer MLP: concatenated pair of one-hot characters -> next-char logits.

    The input's last dimension must be ``2 * n_vocab``. The output's last
    dimension is ``n_vocab`` and holds unnormalized scores, as consumed by
    ``CrossEntropyLoss``.
    """

    def __init__(self, n_vocab):
        super().__init__()
        self.fc1 = nn.Linear(n_vocab * 2, n_vocab * 2)
        self.fc2 = nn.Linear(n_vocab * 2, n_vocab)

    def forward(self, x):
        # Without a nonlinearity, fc2(fc1(x)) collapses to one affine map and
        # the hidden layer adds no capacity; ReLU makes it a real MLP.
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)

def evaluate_model(X_train, int_to_char, char_to_int, n_vocab, model):
    """Greedily generate 50 characters from a random seed context and print them.

    A random row of ``X_train`` is used as the seed. Each row is two one-hot
    vectors of length ``n_vocab`` concatenated. At every step the model's
    argmax prediction is appended to the output. The context window then
    slides forward by one character: the oldest character is dropped and the
    prediction is appended.

    Args:
        X_train: 2-D tensor whose rows have length ``2 * n_vocab``.
        int_to_char: maps class index -> character.
        char_to_int: maps character -> class index.
        n_vocab: vocabulary size, i.e. the length of one one-hot vector.
        model: module mapping a ``(2 * n_vocab,)`` vector to ``(n_vocab,)`` logits.

    Prints the seed characters and the generated sequence; returns None.
    """
    model.eval()
    # Run on the device the model lives on, falling back to the data's
    # device for parameter-free models.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else X_train.device

    starting_point = np.random.randint(X_train.shape[0])
    input_row = X_train[starting_point].to(run_device)

    # Split the context at n_vocab. The split was previously hard-coded
    # to 7, which only worked for a 7-symbol vocabulary.
    starting_char1 = int_to_char[torch.argmax(input_row[:n_vocab], dim=0).item()]
    starting_char2 = int_to_char[torch.argmax(input_row[n_vocab:], dim=0).item()]

    # Make a randomly selected space character visible in the printout.
    if starting_char1 == ' ':
        starting_char1 = " (space)"
    if starting_char2 == ' ':
        starting_char2 = " (space)"

    # Build the one-hot rows once, on-device and in the context's dtype, so
    # torch.cat needs no cross-device copy or type promotion.
    one_hot = torch.eye(n_vocab, dtype=input_row.dtype, device=run_device)
    generated_chars = []
    with torch.no_grad():  # inference only: don't build an autograd graph
        for _ in range(50):
            output = model(input_row.float())
            predicted_index = torch.argmax(output, dim=0).item()
            char_ = int_to_char[predicted_index]
            generated_chars.append(char_)
            # Slide the window: keep the newest character, append the prediction.
            input_row = torch.cat((input_row[-n_vocab:], one_hot[char_to_int[char_]]), dim=0)
    generated_string = "".join(generated_chars)

    print(f"\nRandomly selected starting characters: {starting_char1, starting_char2}")
    print(f"Generated sequence:\n{generated_string}")

if __name__ == "__main__":
    # Script entry point: encode the corpus, train the next-character model,
    # then print a greedy sample. Kept at top level so the names remain
    # available to later notebook cells.
    encoding_matrix, n_chars, n_vocab, char_to_int, int_to_char = encode_txt_file("abcde_edcba.txt")

    print("Total characters: ", n_chars)
    print("Vocabulary size: ", n_vocab)

    X_train, Y_train = prepare_data(encoding_matrix, n_chars)
    train_loader = DataLoader(TensorDataset(X_train, Y_train), batch_size=32)

    model = Net(n_vocab).to(device)
    num_epochs = 10
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    print("Training the model....")
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        # batch_x / batch_y avoid shadowing the builtin `input`.
        for batch_x, batch_y in train_loader:
            # prepare_data already placed the tensors on `device`; .to() is then a no-op.
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            logits = model(batch_x.float())
            loss = loss_fn(logits, batch_y.long())  # CrossEntropyLoss needs int64 class indices
            # Weight by batch size so the epoch mean is exact even when the last batch is short.
            running_loss += loss.item() * batch_x.size(0)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        epoch_loss = running_loss / len(train_loader.dataset)
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}")

    evaluate_model(X_train, int_to_char, char_to_int, n_vocab, model)


Total characters:  153600
Vocabulary size:  7
Training the model....
Epoch 1/10, Loss: 0.0531
Epoch 2/10, Loss: 0.0009
Epoch 3/10, Loss: 0.0005
Epoch 4/10, Loss: 0.0003
Epoch 5/10, Loss: 0.0002
Epoch 6/10, Loss: 0.0002
Epoch 7/10, Loss: 0.0001
Epoch 8/10, Loss: 0.0001
Epoch 9/10, Loss: 0.0001
Epoch 10/10, Loss: 0.0001

Randomly selected starting characters: ('e', 'd')
Generated sequence:
cba
abcde edcba
abcde edcba
abcde edcba
abcde edcb
