In [34]:
import copy
import os
import pickle
import random
import string
import sys
from datetime import datetime

import torch
import torch.nn as nn
import unidecode
# from torch.utils.tensorboard import SummaryWriter


In [35]:
# device config
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed every RNG this notebook draws from (`random` for batch sampling, torch for
# weight init and `torch.multinomial` text sampling) so Restart & Run All is repeatable.
# NOTE(review): cuDNN LSTM kernels may still be nondeterministic on GPU.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

In [36]:
# get words: load the vocabulary list, then add ' ' as one extra token at the end.
VOCAB_PATH = './data/unique_words_2k_articles.dat'

# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open(VOCAB_PATH, 'rb') as vocab_file:
    all_words = pickle.load(vocab_file)

all_words += [' ']
vocab_length = len(all_words)
print(f'number of words: {vocab_length}')


number of words: 20449


In [37]:
# read file: load the tokenized articles and flatten them into one token stream.
CORPUS_PATH = './data/tokenized_articles.dat'

# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open(CORPUS_PATH, 'rb') as articles_file:
    file_content = pickle.load(articles_file)

# This approach uses an embedding and feeds one token at a time, so padding is not
# needed: drop '<pad>' and whitespace-only tokens. Strip *before* filtering so tokens
# such as '  ' or '\n' cannot survive as '' (presumably not a vocabulary entry,
# which would make the training-time word lookup fail).
print(len(file_content))
corpus = [
    token
    for article in file_content
    for token in (word.strip() for word in article)
    if token and token != '<pad>'
]

print(len(corpus))
print(corpus[-30:])


2000
722324
['<period>', 'both', 'mexico', 'and', 'canada', 'have', 'dismissed', 'his', 'musing', 'in', 'a', 'tuesday', 'speech', 'that', '<quotation_mark>', 'well', 'end', 'up', 'probably', 'terminating', 'nafta', 'at', 'some', 'point', '<quotation_mark>', 'as', 'a', 'negotiating', 'tactic', '<period>']


In [38]:
# module

class RNN(nn.Module):
    """Embedding -> bidirectional LSTM -> linear head, processing one token per forward call.

    Args:
        input_size: number of rows in the embedding table (vocabulary size).
        hidden_size: embedding width and per-direction LSTM hidden width.
        num_layers: number of stacked LSTM layers.
        output_size: number of output logits per token.

    NOTE(review): because every call feeds a sequence of length 1, the "backward"
    LSTM direction never sees future tokens; it behaves as a second recurrent stream
    whose state is carried forward between calls like the forward one.
    """

    # bidirectional=True below => the LSTM emits forward and backward features.
    num_directions = 2

    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.embed = nn.Embedding(input_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
        # The LSTM output of the last layer is hidden_size * num_directions wide,
        # independent of num_layers.
        self.fc = nn.Linear(self.hidden_size * self.num_directions, output_size)

    def forward(self, x, hidden, cell):
        """Run a single time step.

        Args:
            x: token ids of shape (batch,).
            hidden, cell: LSTM state, each (num_layers * 2, batch, hidden_size).

        Returns:
            (logits of shape (batch, output_size), (hidden, cell)) — the updated state.
        """
        out = self.embed(x)
        # batch_first LSTM expects (batch, seq, features); seq length is 1 here.
        out, (hidden, cell) = self.lstm(out.unsqueeze(1), (hidden, cell))
        out = self.fc(out.reshape(out.shape[0], -1))
        return out, (hidden, cell)

    def init_hidden(self, batch_size):
        """Return zero (hidden, cell) tensors for `batch_size` sequences.

        The tensors are created on the same device/dtype as the model's parameters,
        so they match wherever the module has been moved with `.to(...)`.
        """
        weight = next(self.parameters())
        shape = (self.num_layers * self.num_directions, batch_size, self.hidden_size)
        hidden = torch.zeros(shape, device=weight.device, dtype=weight.dtype)
        cell = torch.zeros(shape, device=weight.device, dtype=weight.dtype)
        return hidden, cell

    def save(self, f_path):
        """Write this module's state_dict to `f_path` (a path or file-like) via torch.save."""
        torch.save(self.state_dict(), f_path)

    


In [39]:
class Generator:
    """Trains the RNN on random chunks of a token stream and samples text from it.

    By default the vocabulary and token stream are the notebook globals `all_words`
    and `corpus` (read once, at construction time); pass `vocab` / `tokens` to use
    other data. `train()` creates the model as `self.rnn`, so `generate()` can only be
    called after training (or after `self.rnn` has been assigned some other way).
    """

    def __init__(self, vocab=None, tokens=None):
        """Set hyper-parameters and build the word -> index lookup.

        Args:
            vocab: list of words; index i is output class i of the model.
                Defaults to the global `all_words`.
            tokens: flat list of words to train on. Defaults to the global `corpus`.
        """
        self.chunk_len = 120
        self.num_epochs = 2000
        self.batch_size = 1
        self.print_every = self.num_epochs // 20 or 1
        self.hidden_size = 256
        self.num_layers = 2
        self.lr = 0.003

        self.vocab = all_words if vocab is None else vocab
        self.tokens = corpus if tokens is None else tokens
        # Dict lookup instead of list.index (O(len(vocab)) per word). setdefault keeps
        # the first occurrence of a duplicated word, matching list.index semantics.
        self._word_to_idx = {}
        for idx, word in enumerate(self.vocab):
            self._word_to_idx.setdefault(word, idx)

    def word_tensor(self, string):
        """Map a sequence of words to a 1-D LongTensor of vocabulary indices.

        Raises:
            ValueError: if a word is not in the vocabulary.
        """
        try:
            ids = [self._word_to_idx[word] for word in string]
        except KeyError as err:
            raise ValueError(f'{err.args[0]!r} is not in the vocabulary') from None
        return torch.tensor(ids, dtype=torch.long)

    def get_random_batch(self):
        """Sample `batch_size` random windows of `chunk_len + 1` consecutive tokens.

        Returns:
            (inputs, targets): LongTensors of shape (batch_size, chunk_len), where
            targets[:, j] is the token that follows inputs[:, j].

        Raises:
            ValueError: if the token stream is shorter than chunk_len + 1.
        """
        # random.randint is inclusive on both ends; the last valid start must still
        # leave room for chunk_len + 1 tokens.
        max_start = len(self.tokens) - self.chunk_len - 1
        if max_start < 0:
            raise ValueError(
                f'need at least {self.chunk_len + 1} tokens, got {len(self.tokens)}')
        inputs = torch.empty(self.batch_size, self.chunk_len, dtype=torch.long)
        targets = torch.empty(self.batch_size, self.chunk_len, dtype=torch.long)
        for row in range(self.batch_size):
            # A separate window per row, so batch rows are not identical copies.
            start = random.randint(0, max_start)
            ids = self.word_tensor(self.tokens[start:start + self.chunk_len + 1])
            inputs[row] = ids[:-1]
            targets[row] = ids[1:]
        return inputs, targets

    def generate(self, initial_str='the president is dead', predict_len=200, temperature=0.85):
        """Sample `predict_len` words after the space-separated prompt `initial_str`.

        Logits are divided by `temperature` before the softmax, so lower values
        sharpen the distribution and higher values flatten it.

        Returns:
            list of words: the prompt words followed by the sampled words.

        Raises:
            ValueError: if `temperature` is not positive or a prompt word is unknown.
        """
        if temperature <= 0:
            raise ValueError(f'temperature must be positive, got {temperature}')
        initial_words = initial_str.split(' ')
        initial_input = self.word_tensor(initial_words)
        predicted = list(initial_words)

        # Inference only: avoid building an autograd graph over hundreds of steps.
        with torch.no_grad():
            # Sampling feeds one token at a time, so the state always has batch size 1,
            # independent of the training batch_size.
            hidden, cell = self.rnn.init_hidden(batch_size=1)
            dev = hidden.device
            # Prime the state on every prompt word except the last ...
            for word_id in initial_input[:-1]:
                _, (hidden, cell) = self.rnn(word_id.view(1).to(dev), hidden, cell)

            # ... which becomes the first input of the sampling loop.
            last_word = initial_input[-1]
            for _ in range(predict_len):
                output, (hidden, cell) = self.rnn(last_word.view(1).to(dev), hidden, cell)
                # softmax is computed stably (max-subtracted), unlike a raw
                # logits.div(t).exp(), which overflows to inf for small temperatures.
                probs = torch.softmax(output.view(-1).float() / temperature, dim=0)
                next_id = torch.multinomial(probs, 1).item()
                predicted.append(self.vocab[next_id])
                last_word = torch.tensor(next_id, dtype=torch.long)

        return predicted

    def train(self):
        """Train a fresh RNN for `num_epochs` random batches and save the best weights.

        Each epoch draws one batch, feeds it one token at a time, and back-propagates
        the summed cross-entropy. "Best" is judged by the mean per-token loss on that
        single random batch, so it is a noisy criterion. The best weights are saved
        under ./models/ when training finishes.
        """
        vocab_size = len(self.vocab)
        self.rnn = RNN(vocab_size, self.hidden_size, self.num_layers, vocab_size).to(device)
        optimizer = torch.optim.Adam(self.rnn.parameters(), lr=self.lr)
        criterion = nn.CrossEntropyLoss()
        print(f'<{datetime.now()}>starting training')
        lowest_loss = float('inf')
        for epoch in range(1, self.num_epochs + 1):
            inputs, targets = self.get_random_batch()
            inputs = inputs.to(device)
            targets = targets.to(device)
            hidden, cell = self.rnn.init_hidden(batch_size=self.batch_size)

            optimizer.zero_grad()
            loss = 0
            for c in range(self.chunk_len):
                output, (hidden, cell) = self.rnn(inputs[:, c], hidden, cell)
                loss += criterion(output, targets[:, c])

            loss.backward()
            optimizer.step()
            avg_loss = loss.item() / self.chunk_len
            if avg_loss < lowest_loss:
                # state_dict() returns references to the live parameters, which later
                # optimizer steps update in place — deep-copy to keep a real snapshot.
                self.best_model = copy.deepcopy(self.rnn.state_dict())
                print(f'better model found after {epoch} epochs with loss: {avg_loss}')
                lowest_loss = avg_loss

            if epoch % self.print_every == 0:
                print(f'\n\n<{datetime.now()}> | epoch: {epoch}/{self.num_epochs} | loss: {avg_loss}')
                print(self.generate())
        # NOTE: 'ckunk' (sic) is kept as-is in case other code expects this file name.
        file_path = f'./models/bidir_lstm_ckunk_{self.chunk_len}_words_{vocab_size}_loss_{lowest_loss}.model'
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        torch.save(self.best_model, file_path)

In [40]:
# Build a Generator with its default hyper-parameters and run the full training loop;
# the trained model stays available as `gentext.rnn` for sampling below.
gentext = Generator()
gentext.train()


<2023-05-29 16:37:46.034563>starting training
better model found with loss: 9.942793782552084
better model found with loss: 9.892582194010417
better model found with loss: 9.825705973307292
better model found with loss: 9.390010579427083
better model found with loss: 8.53220723470052
better model found with loss: 8.011594645182292
better model found with loss: 7.662193806966146
better model found with loss: 7.288881429036459
better model found with loss: 7.066395568847656
better model found with loss: 6.803366088867188
better model found with loss: 6.639908854166666
better model found with loss: 6.361934407552083
better model found with loss: 6.24465077718099
better model found with loss: 5.987163289388021


<2023-05-29 16:38:26.945564> | epoch: 100/2000 | loss: 7.183804321289062
['the', 'president', 'is', 'dead', 'of', 'wanted', 'place', 'election', '<period>', 'thursday', 'of', 'along', '<period>', 'to', '[2146', '<period>', 'sunk', 'joyces', 'the', 'research', '<comma>', 'he', 'shoc

In [41]:
# Sample continuations of one prompt at increasing temperatures and turn the
# punctuation tokens back into readable text.
PROMPT = 'i would like'
SAMPLE_LEN = 400
TEMPERATURES = [0.2, 0.4, 0.6, 0.8]


def detokenize(words):
    """Join generated tokens with spaces and map punctuation tokens back to characters.

    Replacements are applied in order: '<quotation_mark>' -> '"', then
    ' <question_mark>' -> '?', ' <comma>' -> ',', ' <period>' -> '.'. The leading
    space in the last three is consumed so the mark attaches to the previous word.
    """
    text = ' '.join(words)
    for token, mark in (('<quotation_mark>', '"'),
                        (' <question_mark>', '?'),
                        (' <comma>', ','),
                        (' <period>', '.')):
        text = text.replace(token, mark)
    return text


# One sample per temperature, generated in list order.
statements = [
    detokenize(gentext.generate(initial_str=PROMPT, predict_len=SAMPLE_LEN, temperature=t))
    for t in TEMPERATURES
]
s1, s2, s3, s4 = statements
for statement in statements:
    print(statement)


i would like the u. s. constitution. " the president is going to be a " rocket man " on the issue, " he said. " we have a lot of the house and senate, " he said. " the president is going to be a " great " for " the president and the u. s. president donald trump and trump has denied any collusion by the u. s. election, " he said. " the president is a " great " " " rocket man " on the house of representatives and senate republicans are not going to be a " rocket man " on the senate floor. " the president is going to be a " great " " and " that " we have a " great " and " the president of the u. s. constitution, " he said. " the president is going to be a " rocket man " as " the white house, " he said. " the president is going to be a " great " and " that " i think we dont get the tax bill, " he said. " the senate is not going to be a " rocket man " in the u. s. election, " he said. " the president is going to be a " rocket man " on the same time. " the president is talking about the same

In [42]:
# m = nn.Linear(512, 100)
# input = torch.randn(1, 512)
# output = m(input)
# print(output)
