In [None]:
import torch
import numpy as np
import math
import os
import random
from torch import nn, optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import zipfile
import io

In [None]:
class Vocabulary:
    """Bidirectional mapping between token strings and dense integer IDs.

    ID 0 is reserved for the padding token and ID 1 for the unknown token;
    every other token receives the next free ID in insertion order.
    """

    def __init__(self, pad_token="<pad>", unk_token='<unk>'):
        self.id_to_string = {}
        self.string_to_id = {}

        # reserved padding token
        self.id_to_string[0] = pad_token
        self.string_to_id[pad_token] = 0

        # reserved unknown token
        self.id_to_string[1] = unk_token
        self.string_to_id[unk_token] = 1

        # shortcut access to the reserved IDs
        self.pad_id = 0
        self.unk_id = 1

    def __len__(self):
        return len(self.id_to_string)

    def add_new_word(self, string):
        """Register ``string`` and return its ID.

        Idempotent: a string that is already known keeps (and returns) its
        existing ID. Re-inserting it would overwrite its ID and grow only one of
        the two maps, after which later words would be assigned colliding IDs.
        """
        if string in self.string_to_id:
            return self.string_to_id[string]
        # id_to_string is the authoritative count (it is what __len__ reports)
        new_id = len(self.id_to_string)
        self.string_to_id[string] = new_id
        self.id_to_string[new_id] = string
        return new_id

    def get_idx(self, string, extend_vocab=False):
        """Return the ID of ``string``.

        Unknown strings are added to the vocabulary when ``extend_vocab`` is
        True; otherwise they map to ``unk_id``.
        """
        if string in self.string_to_id:
            return self.string_to_id[string]
        if extend_vocab:
            return self.add_new_word(string)
        return self.unk_id

In [None]:
# Read a raw text file and turn it into a 1-D PyTorch tensor of character
# token IDs, plus the vocabulary used for the mapping.
class LongTextData:
    """Character-level encoding of a whole text file.

    Every character of the file (newlines included) becomes one token ID.

    Attributes:
        data: 1-D ``torch.int64`` tensor of token IDs on ``device``.
        vocab: the vocabulary used for encoding (a fresh ``Vocabulary`` when
            none is passed in).
    """

    def __init__(self, file_path, vocab=None, extend_vocab=True, device='cuda', encoding='utf-8'):
        self.data, self.vocab = self.text_to_data(file_path, vocab, extend_vocab, device, encoding)

    def __len__(self):
        return len(self.data)

    def text_to_data(self, text_file, vocab, extend_vocab, device, encoding='utf-8'):
        """Encode ``text_file`` character by character and print corpus statistics.

        Args:
            text_file: path of the text file to read.
            vocab: vocabulary exposing ``get_idx`` and ``__len__``; a new
                ``Vocabulary`` is created when None.
            extend_vocab: add unseen characters to ``vocab`` instead of mapping
                them to its unknown ID.
            device: device for the returned tensor.
            encoding: text encoding of the file. Explicit (UTF-8 by default) so
                decoding does not depend on the platform's locale default.

        Returns:
            ``(data, vocab)`` where ``data`` is a 1-D int64 tensor.
        """
        assert os.path.exists(text_file), f"text file not found: {text_file}"
        if vocab is None:
            vocab = Vocabulary()

        full_text = []
        l_num = 0
        cap_l = 0
        contr = 0
        print(f"Reading text file from: {text_file}")
        with open(text_file, 'r', encoding=encoding) as text:
            for line in text:
                l_num += 1
                for token in line:
                    if token.isupper():
                        cap_l += 1
                    # both typographic and straight apostrophes are counted
                    if token == "’" or token == "'":
                        contr += 1
                    full_text.append(vocab.get_idx(token, extend_vocab=extend_vocab))
        print("Numbers of lines: ", l_num)
        print("Vocabulary size: ", len(vocab))
        print("Capital charachters: ", cap_l)
        print("Number of contractions:", contr)

        data = torch.tensor(full_text, device=device, dtype=torch.int64)
        # an empty file has zero lines; avoid dividing by zero
        avg_chars = math.trunc(len(data) / l_num) if l_num else 0
        print("Number of characters : ", len(data))
        print("Avarage number of characters for line (int): ", avg_chars)
        print("Done.")

        return data, vocab

    def string_to_data(self, text, device, pad_id):
        """Encode ``text`` with the existing vocabulary as a ``(len(text), 1)`` tensor.

        The vocabulary is never extended: unseen characters map to its unknown
        ID. ``pad_id`` is accepted for API compatibility; a single sequence
        needs no padding.
        """
        ids = [self.vocab.get_idx(ch, extend_vocab=False) for ch in text]
        return torch.tensor(ids, device=device, dtype=torch.int64).view(-1, 1)

In [None]:
class ChunkedTextData:
    """Split a long 1-D token sequence into overlapping ``(len, bsz)`` BPTT batches.

    The sequence is padded with ``pad_id`` to ``segment_len * bsz`` tokens and
    laid out as ``bsz`` parallel columns of ``segment_len`` rows. Those rows are
    cut into windows of ``bptt_len`` rows. Each window carries one extra leading
    row: a pad row for the first window, otherwise the last row of the previous
    window. As a result ``batch[:-1]`` / ``batch[1:]`` form aligned input/target
    pairs, and every row is a target exactly once.
    """

    def __init__(self, data, bsz, bptt_len, pad_id):
        self.batches = self.create_batch(data, bsz, bptt_len, pad_id)

    def __len__(self):
        return len(self.batches)

    def __getitem__(self, idx):
        return self.batches[idx]

    def create_batch(self, input_data, bsz, bptt_len, pad_id):
        """Build the list of batches.

        ``input_data`` only needs ``len()`` and a ``.data`` tensor, so either a
        1-D tensor or a ``LongTextData``-like object works.
        """
        batches = []  # each element is a (rows, bsz) tensor
        text_len = len(input_data)
        segment_len = text_len // bsz + 1

        padded = input_data.data.new_full((segment_len * bsz,), pad_id)
        padded[:text_len] = input_data.data
        padded = padded.view(bsz, segment_len).t()
        # Ceil division: `segment_len // bptt_len + 1` produced a trailing batch
        # holding only the overlap row (empty input/target) whenever segment_len
        # was an exact multiple of bptt_len.
        num_batches = -(-segment_len // bptt_len)

        for i in range(num_batches):
            start = i * bptt_len
            stop = start + bptt_len
            if i == 0:
                batch = torch.cat([padded.new_full((1, bsz), pad_id), padded[:stop]], dim=0)
            else:
                batch = padded[start - 1:stop]
            batches.append(batch)

        return batches

In [None]:
class Net(nn.Module):
    """Character-level language model: Embedding -> LSTM -> Linear over the vocabulary.

    Args:
        voc_size: number of token IDs (embedding rows and output logits).
        emb_dim: embedding width (default 64).
        hidden_dim: LSTM hidden size (default 2048).
        num_layers: number of stacked LSTM layers (default 1).
        padding_idx: embedding index treated as padding (default 0).
    """

    def __init__(self, voc_size, emb_dim=64, hidden_dim=2048, num_layers=1, padding_idx=0):
        super().__init__()
        # attribute names are unchanged so existing state_dict checkpoints still load
        self.embedded_train = nn.Embedding(voc_size, emb_dim, padding_idx=padding_idx)
        self.lstm = nn.LSTM(emb_dim, hidden_dim, num_layers)
        self.fc = nn.Linear(hidden_dim, voc_size)

    def forward(self, x, state=None):
        """Run the model on a batch of token IDs.

        Args:
            x: LongTensor of token IDs, laid out (seq_len, batch) since the LSTM
                uses the default ``batch_first=False``.
            state: optional ``(h, c)`` carried over from a previous call.

        Returns:
            ``(logits, (h, c))`` with logits of shape (seq_len, batch, voc_size).
        """
        emb = self.embedded_train(x)
        # nn.LSTM treats hx=None as a zero initial state, same as omitting it
        hidden, (h, c) = self.lstm(emb, state)
        return self.fc(hidden), (h, c)


def _decoding_device(model, data):
    """Device of the model's first parameter, falling back to the device of ``data.data``."""
    param = next(model.parameters(), None)
    if param is not None:
        return param.device
    return data.data.device


def _as_step_input(token_id, device):
    """Wrap a single token ID as a (1, 1) int64 tensor: one step, batch of one."""
    return torch.full((1, 1), token_id, dtype=torch.int64, device=device)


def _pick_next(logits, sample):
    """Choose the next token ID from the logits of the last step of batch item 0."""
    probs = torch.softmax(logits[-1, 0], dim=-1)
    if sample:
        return torch.multinomial(probs, num_samples=1).item()
    return torch.argmax(probs).item()


def decoding_algo(model, text, leng, data, sample):
    """Print ``text`` followed by a continuation generated by ``model``.

    The prompt is fed one character at a time to build up the recurrent state.
    The model then predicts one character, followed by ``leng`` more, each
    prediction being fed back as the next input.

    Args:
        model: module called as ``model(x, state) -> (logits, state)`` with
            logits laid out (seq, batch, vocab).
        text: non-empty prompt string. Characters missing from the vocabulary
            are mapped to its unknown ID (and echoed as the unknown token).
        leng: number of characters generated after the first prediction.
        data: object with ``.vocab`` (providing ``get_idx`` and
            ``id_to_string``) and a ``.data`` tensor, which is used only as a
            device fallback when the model has no parameters.
        sample: if True, sample from the softmax distribution; otherwise greedy argmax.

    Returns:
        None; the text is written to stdout. The model's train/eval mode is
        restored afterwards.

    Raises:
        ValueError: if ``text`` is empty (there would be no logits to decode from).
    """
    if not text:
        raise ValueError("decoding_algo needs a non-empty prompt")

    vocab = data.vocab
    device = _decoding_device(model, data)
    was_training = model.training
    model.eval()
    try:
        # inference only: don't build autograd graphs across the generated sequence
        with torch.no_grad():
            state = None
            logits = None
            for ch in text:
                token_id = vocab.get_idx(ch)
                print(vocab.id_to_string[token_id], end="")
                logits, state = model(_as_step_input(token_id, device), state)

            pred = _pick_next(logits, sample)
            print(vocab.id_to_string[pred], end="")

            for _ in range(leng):
                logits, state = model(_as_step_input(pred, device), state)
                pred = _pick_next(logits, sample)
                print(vocab.id_to_string[pred], end="")

        print()
    finally:
        model.train(was_training)
    return None

In [None]:
# Unpack books.zip once and concatenate every .txt file in it into a single
# training corpus, ./books/books/all_books.txt.
if os.path.exists("books.zip") and not os.path.exists("books"):
    with zipfile.ZipFile("books.zip", 'r') as zip_ref:
        zip_ref.extractall("books")
    print(f"Contents of 'books.zip' extracted to 'books' folder.")

    folder_path = "./books/books/"
    output_file = folder_path + "all_books.txt"

    # Sorted so the corpus order is reproducible (os.listdir order is arbitrary).
    # The output file is skipped in case the archive already ships one.
    book_files = sorted(
        name for name in os.listdir(folder_path)
        if name.endswith(".txt") and name != "all_books.txt"
    )

    # collect parts and join once instead of repeated string concatenation
    parts = []
    for filename in book_files:
        file_path = os.path.join(folder_path, filename)
        with open(file_path, "r", encoding="utf-8") as file:
            parts.append(file.read() + "\n")
    combined_text = "".join(parts)

    with open(output_file, "w", encoding="utf-8") as output:
        output.write(combined_text)

    print(f"All .txt files in the folder have been combined into '{output_file}'.")

else:
    print(f"'books.zip' does not exist or already unzipped.")

'books.zip' does not exist or already unzipped.


In [None]:
text_path = "./books/books/all_books.txt"

# Use the GPU when present, otherwise fall back to CPU instead of crashing.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seed every RNG in play (weight init, multinomial sampling) for reproducible runs.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

batch_size = 32
bptt_len = 64

my_data = LongTextData(text_path, device=DEVICE)

batches = ChunkedTextData(my_data, batch_size, bptt_len, pad_id=my_data.vocab.pad_id)
print(len(batches))

In [None]:
model = Net(len(my_data.vocab))
model.to(DEVICE)

lr = 0.001
# Pad targets (the filler appended to the last column by ChunkedTextData) must
# not be learned. Token IDs are non-negative, so ignore_index=-1 would ignore
# nothing; use the vocabulary's pad ID instead.
loss_function = nn.CrossEntropyLoss(ignore_index=my_data.vocab.pad_id)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

model.train();

In [None]:
# Train until the epoch perplexity drops below TARGET_PERPLEXITY. MAX_EPOCHS
# guarantees the cell terminates even if that target is never reached.
TARGET_PERPLEXITY = 1.05
MAX_EPOCHS = 100
SAMPLE_EACH_EPOCH = False  # set True to print a greedy sample after every epoch

assert len(batches) > 0, "no training batches"

perplexs = []  # per-batch perplexity, exp(batch loss)
perp = float("inf")
epoch = 0
while perp >= TARGET_PERPLEXITY and epoch < MAX_EPOCHS:
    # batch 0 starts at the beginning of the text, so don't carry the state
    # left over from the end of the previous epoch into it
    state = None
    total_loss = 0.0
    for k in range(len(batches)):
        batch = batches[k].to(DEVICE)
        out, state = model(batch[:-1], state)

        optimizer.zero_grad()
        loss = loss_function(out.transpose(1, 2), batch[1:])
        # truncated BPTT: keep the hidden values but cut the graph between batches
        state = (state[0].detach(), state[1].detach())

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        perplexs.append(math.exp(batch_loss))

    # epoch perplexity = exp(mean cross-entropy over the epoch's batches)
    perp = math.exp(total_loss / len(batches))
    print("perplexity epoch ", epoch, ": ", perp)
    if SAMPLE_EACH_EPOCH:
        decoding_algo(model, 'The meaning of life is ', 14, my_data, False)
        model.train()

    epoch += 1


In [None]:
text = "Fox and the Goat "
decoding_algo(model, text, leng=100, data=my_data, sample=False)

In [None]:
text = "The lion and the Crocodile "
decoding_algo(model, text, leng=150, data=my_data, sample=False)

In [None]:
text = "The meaning of the life is "
decoding_algo(model, text, leng=200, data=my_data, sample=False)