In [195]:
import torch
import torch.nn as nn
import string
import random
import sys
import unidecode
from datetime import datetime
# from torch.utils.tensorboard import SummaryWriter


In [196]:
# Compute device: use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [197]:
# Vocabulary: the alphabet of symbols the model reads and emits (string.printable).
n_characters = len(string.printable)
all_characters = string.printable
print(f'number of characters: {n_characters}')

number of characters: 100


In [198]:
# Read the training corpus and transliterate it with unidecode.
# The context manager closes the file handle as soon as the text has been read
# (the bare open(...).read() left it open until garbage collection).
with open('./data/alice_in_wonderland.txt') as text_file:
    file = unidecode.unidecode(text_file.read())

In [199]:
# module

class RNN(nn.Module):
    """Character-level model: embedding -> bidirectional LSTM -> linear projection.

    The network is stepped one time step at a time: ``forward`` receives a batch
    of symbol indices for a single step plus the recurrent state, and returns
    the logits for that step together with the updated state.

    Args:
        input_size: number of distinct input symbols (rows of the embedding).
        hidden_size: width of the embedding and of each LSTM direction.
        num_layers: number of stacked LSTM layers.
        output_size: number of logits produced per step.
    """

    # The LSTM is built with bidirectional=True, so it has two directions.
    NUM_DIRECTIONS = 2

    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.embed = nn.Embedding(input_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
        # The LSTM output feature size is hidden_size * NUM_DIRECTIONS and does
        # not depend on the layer count. Sizing this layer with num_layers only
        # worked by coincidence when num_layers == 2.
        self.fc = nn.Linear(hidden_size * self.NUM_DIRECTIONS, output_size)

    def forward(self, x, hidden, cell):
        """Advance the model by one time step.

        Args:
            x: 1-D tensor of symbol indices, one per batch element.
            hidden: LSTM hidden state as returned by ``init_hidden`` / a previous call.
            cell: LSTM cell state as returned by ``init_hidden`` / a previous call.

        Returns:
            ``(logits, (hidden, cell))`` where ``logits`` has one row per batch
            element and ``output_size`` columns.
        """
        out = self.embed(x)
        # unsqueeze(1) adds the length-1 sequence axis expected with batch_first=True.
        out, (hidden, cell) = self.lstm(out.unsqueeze(1), (hidden, cell))
        # Flatten (batch, 1, features) to (batch, features) for the projection.
        out = self.fc(out.reshape(out.shape[0], -1))
        return out, (hidden, cell)

    def init_hidden(self, batch_size):
        """Return zeroed ``(hidden, cell)`` states for ``batch_size`` sequences.

        The states are created on the same device as the model's parameters, so
        they always match wherever the model has been moved.
        """
        state_shape = (self.num_layers * self.NUM_DIRECTIONS, batch_size, self.hidden_size)
        model_device = self.embed.weight.device
        hidden = torch.zeros(*state_shape, device=model_device)
        cell = torch.zeros(*state_shape, device=model_device)
        return hidden, cell



In [200]:
class Generator():
    """Trains a character-level RNN on the loaded text and samples text from it.

    Relies on notebook-level names defined in earlier cells: ``file`` (the
    training text), ``all_characters`` / ``n_characters`` (the vocabulary),
    ``device`` and the ``RNN`` model class. ``train`` must be called before
    ``generate`` because it is what creates ``self.rnn``.
    """

    def __init__(self):
        self.chunk_len = 300      # characters per training sequence
        self.num_epochs = 4000    # training iterations; each one uses a single random batch
        self.batch_size = 1
        # Report roughly 25 times per run; `or 1` keeps the modulus non-zero
        # when num_epochs is smaller than 25.
        self.print_every = self.num_epochs // 25 or 1
        self.hidden_size = 256
        self.num_layers = 2
        self.lr = 0.003


    def char_tensor(self, string):
        """Encode ``string`` as a 1-D long tensor of indices into ``all_characters``.

        Raises:
            ValueError: if a character is not present in ``all_characters``.
        """
        return torch.tensor([all_characters.index(ch) for ch in string], dtype=torch.long)


    def get_random_batch(self):
        """Draw ``batch_size`` random chunks of the text as (input, target) pairs.

        Returns:
            Two long tensors of shape ``(batch_size, chunk_len)``; the target is
            the input shifted one character to the right.

        Raises:
            ValueError: if the text is shorter than ``chunk_len + 1`` characters.
        """
        # Each sample needs chunk_len + 1 characters: the inputs plus one extra
        # character so the targets can be the inputs shifted by one.
        span = self.chunk_len + 1
        last_start = len(file) - span
        if last_start < 0:
            raise ValueError(
                f'text has {len(file)} characters but chunk_len={self.chunk_len} '
                f'requires at least {span}'
            )

        text_input = torch.zeros(self.batch_size, self.chunk_len, dtype=torch.long)
        text_target = torch.zeros(self.batch_size, self.chunk_len, dtype=torch.long)
        for i in range(self.batch_size):
            # random.randint includes both endpoints, so last_start is the
            # largest start that still yields a full span. Drawing inside the
            # loop gives every batch row its own chunk.
            start_idx = random.randint(0, last_start)
            text_str = file[start_idx:start_idx + span]
            text_input[i] = self.char_tensor(text_str[:-1])
            text_target[i] = self.char_tensor(text_str[1:])
        return text_input, text_target


    def generate(self, initial_str='i would like', predict_len=200, temperature=0.85):
        """Sample ``predict_len`` characters from the model, seeded with ``initial_str``.

        Args:
            initial_str: non-empty priming text; it is included in the result.
            predict_len: number of characters to sample after the priming text.
            temperature: divisor applied to the logits before sampling; lower
                values concentrate probability on the most likely characters.

        Returns:
            ``initial_str`` followed by the sampled characters.

        Raises:
            ValueError: if ``initial_str`` is empty.
        """
        if not initial_str:
            raise ValueError('initial_str must contain at least one character')

        initial_input = self.char_tensor(initial_str)
        predicted = initial_str

        # Sampling needs no gradients; no_grad avoids building an autograd graph.
        with torch.no_grad():
            # Characters are fed one at a time, so the state is for a batch of
            # one regardless of the training batch size.
            hidden, cell = self.rnn.init_hidden(batch_size=1)

            # Warm up the recurrent state on all but the last priming character.
            for p in range(len(initial_str) - 1):
                _, (hidden, cell) = self.rnn(initial_input[p].view(1).to(device), hidden, cell)

            last_char = initial_input[-1]
            for _ in range(predict_len):
                output, (hidden, cell) = self.rnn(last_char.view(1).to(device), hidden, cell)
                # softmax(logits / T) is the same distribution as normalised
                # exp(logits / T), but it cannot overflow to inf.
                output_dist = torch.softmax(output.view(-1) / temperature, dim=0)
                top_char = torch.multinomial(output_dist, 1).item()
                predicted_char = all_characters[top_char]
                predicted += predicted_char
                last_char = self.char_tensor(predicted_char)

        return predicted


    def train(self):
        """Create the model and train it for ``num_epochs`` random batches.

        Every ``print_every`` iterations the average per-character loss and a
        generated sample are printed. The trained model is stored on ``self.rnn``.
        """
        self.rnn = RNN(n_characters, self.hidden_size, self.num_layers, n_characters).to(device)
        optimizer = torch.optim.Adam(self.rnn.parameters(), lr=self.lr)
        criterion = nn.CrossEntropyLoss()
        print(f'<{datetime.now()}>starting training')

        for epoch in range(1, self.num_epochs + 1):
            inputs, targets = self.get_random_batch()
            inputs = inputs.to(device)
            targets = targets.to(device)
            hidden, cell = self.rnn.init_hidden(batch_size=self.batch_size)

            self.rnn.zero_grad()
            loss = 0

            # Step through the chunk one character at a time, accumulating the
            # loss so a single backward pass covers the whole sequence.
            for c in range(self.chunk_len):
                output, (hidden, cell) = self.rnn(inputs[:, c], hidden, cell)
                loss += criterion(output, targets[:, c])

            loss.backward()
            optimizer.step()
            loss = loss.item() / self.chunk_len  # average loss per character

            if epoch % self.print_every == 0:
                print(f'\n\n<{datetime.now()}> | epoch: {epoch}/{self.num_epochs} | loss: {loss}')
                print(self.generate())


In [201]:
# Build the generator with its default hyperparameters and train it.
# train() prints the loss and a generated sample every `print_every` epochs.
gentext = Generator()
gentext.train()


<2023-05-28 20:24:58.018391>starting training


<2023-05-28 20:26:13.248390> | epoch: 160/4000 | loss: 1.8421734619140624
i would like ta say tart opent it reegt?' but the Quell thable bas to say
said
the bot!  And be such, thing it's heer mirs the said hat to thing `I't be she
ras im, and the Mething her those bregcexat a some crep


<2023-05-28 20:27:32.211362> | epoch: 320/4000 | loss: 1.5943794759114582
i would like the word the be, was go it was to bened of the groked to sound her she souse about was foll herself I pit to
bines of their she gone, and they rest wren explen to the-ver the prised
out in Alice of t


<2023-05-28 20:28:52.131362> | epoch: 480/4000 | loss: 0.9546227010091146
i would like a growled your she
pight at as the was as first, `I've to she garmes, and as the
Queen had it was said for little.

  `Peraming Master was
shouted to say
wish she was to now on growify and the next,'


<2023-05-28 20:30:11.417361> | epoch: 640/4000 | loss: 1.323508504231771
i would like

In [202]:
# Sample 400 characters from the trained model at temperature 0.65, lower than
# the 0.85 default used for the samples printed during training.
gentext.generate(initial_str='i would like', predict_len=400, temperature=0.65)


"i would like the Dormouse close to of their faces.  `I don't take song to go!  The Duchess!  The Footman said to the Dormouse,' said the Hatter.  `Oh my tea with the moon,\n    And they're mulaully fetching their slates--and I've got to be an old panch is as you could not.'\n\n  `The lobsters!' said the Duchess; `I\nought to be a book to try the hedgehos, or the jury.\n\n                                            "

In [203]:
# m = nn.Linear(512, 100)
# input = torch.randn(1, 512)
# output = m(input)
# print(output)
