In [14]:
import pickle
import pandas as pd
import re
import torch
import torch.nn as nn
import string
import random
import sys
import unidecode
from datetime import datetime


In [15]:
corpus_size = 2000
# device config
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [16]:
df = pd.read_csv('./data/True.csv')
df.head()


Unnamed: 0,title,text,subject,date
0,"As U.S. budget fight looms, Republicans flip t...",WASHINGTON (Reuters) - The head of a conservat...,politicsNews,"December 31, 2017"
1,U.S. military to accept transgender recruits o...,WASHINGTON (Reuters) - Transgender people will...,politicsNews,"December 29, 2017"
2,Senior U.S. Republican senator: 'Let Mr. Muell...,WASHINGTON (Reuters) - The special counsel inv...,politicsNews,"December 31, 2017"
3,FBI Russia probe helped by Australian diplomat...,WASHINGTON (Reuters) - Trump campaign adviser ...,politicsNews,"December 30, 2017"
4,Trump wants Postal Service to charge 'much mor...,SEATTLE/WASHINGTON (Reuters) - President Donal...,politicsNews,"December 29, 2017"


In [17]:



# functions to clean data

# filter out first part containing CITY (news agency) and separator "-"

# to be replaced: “ ”

# search and replace regex
double_quotes = r'“|”'
single_quotes = r'’|‘'
backslashes = r'\\'
multiple_whitespace = r'\t|\v|\f| '
double_quotes = re.compile(double_quotes)
single_quotes = re.compile(single_quotes)
backslashes = re.compile(backslashes)
multiple_whitespace = re.compile(multiple_whitespace)


def clean_data(row):
    txt = row[1].lower()
    txt = txt[txt.find('-')+1:].lstrip()
    txt = double_quotes.sub('"', txt)
    txt = txt.replace("’", "")
    txt = txt.replace("‘", "")
    txt = multiple_whitespace.sub(' ', txt)
    # remove everything before the first dash (news agency and city)
    txt = txt[txt.find('-')+1:]

    return txt


In [18]:
# function to tokenize the text

def token_lookup():
    """
    Generate a dict to turn punctuation into a token.
    :return: Tokenized dictionary where the key is the punctuation and the value is the token
    """
    # TODO: Implement Function
    token = dict()
    token['.'] = ' <PERIOD> '
    token[','] = ' <COMMA> '
    token['"'] = ' <QUOTATION_MARK> '
    token[':'] = ' <COLON>'
    token[';'] = ' <SEMICOLON> '
    token['!'] = ' <EXCLAIMATION_MARK> '
    token['?'] = ' <QUESTION_MARK> '
    token['('] = ' <LEFT_PAREN> '
    token[')'] = ' <RIGHT_PAREN> '
    token['-'] = ' <QUESTION_MARK> '
    token['\n'] = ' <NEW_LINE> '
    return token


In [19]:
def pad_to_max(tokenized, max):
    padding_length = max - len(tokenized)
    if padding_length == 0:
        return tokenized
    padding = ['<pad>' for i in range(padding_length)]
    tokenized.extend(padding)
    return tokenized


In [20]:
df = df[:corpus_size]
df = df.astype({'text': 'string'})

df['text'] = df.apply(clean_data, axis=1)
print(df['text'][0])
articles = df['text'].values.tolist()

longest_article = 0

token_dict = token_lookup()

tokenized_articles = []

for article in articles:
    for key, token in token_dict.items():
        article = article.replace(key, token)
    article = article.lower()
    article = article.split()
    if len(article) > longest_article:
        longest_article = len(article)
    tokenized_articles.append(article)
# for key, token in token_dict.items():
#     articles[0] = article[0].replace(key, token)

print(f'longest article contains {longest_article} tokens')


unique_tokens = set()

for tokens in tokenized_articles:
    tokens = pad_to_max(tokens, longest_article)
    for token in tokens:
        unique_tokens.add(token)

unique_tokens = list(unique_tokens)

print(f'there are {len(unique_tokens)} unique tokens')

print(
    f'articles equal length: {len(tokenized_articles[0])==len(tokenized_articles[1])}')

articles = [' '.join(art) for art in tokenized_articles]

print(articles[0])
print(tokenized_articles[0])


defense "discretionary" spending on programs that support education, scientific research, infrastructure, public health and environmental protection. "the (trump) administration has already been willing to say: were going to increase non-defense discretionary spending ... by about 7 percent," meadows, chairman of the small but influential house freedom caucus, said on the program. "now, democrats are saying thats not enough, we need to give the government a pay raise of 10 to 11 percent. for a fiscal conservative, i dont see where the rationale is. ... eventually you run out of other peoples money," he said. meadows was among republicans who voted in late december for their partys debt-financed tax overhaul, which is expected to balloon the federal budget deficit and add about $1.5 trillion over 10 years to the $20 trillion national debt. "its interesting to hear mark talk about fiscal responsibility," democratic u.s. representative joseph crowley said on cbs. crowley said the republic

In [21]:
with open('./data/tokenized_2k_articles.dat', 'wb') as file:
    pickle.dump(tokenized_articles, file)

with open('./data/unique_words_2k_articles.dat', 'wb') as file:
    pickle.dump(unique_tokens, file)


In [22]:
# get words
all_words = []


with open('./data/unique_words_2k_articles.dat', 'rb') as file:
    all_words = pickle.load(file)

all_words.append(' ')
vocab_length = len(all_words)
print(f'number of words: {vocab_length}')

number of words: 20449


In [23]:
# read file
corpus = []
file_content = []
with open('./data/tokenized_articles.dat', 'rb') as file:
    file_content = pickle.load(file)

# this approach uses embedding, and therefore doesn't need padding so we remove it from the prepared data
print(len(file_content))
for article in file_content:
    corpus.extend([word.strip() for word in article if word not in ['<pad>', ' ']])

print(len(corpus))
# print(corpus[-30:])


2000
722324


In [24]:
# module

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        
        self.embed = nn.Embedding(input_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(self.hidden_size * num_layers, output_size)

    def forward(self, x, hidden, cell):
        out = self.embed(x)
        out, (hidden, cell) = self.lstm(out.unsqueeze(1), (hidden, cell))
        out = self.fc(out.reshape(out.shape[0], -1))
        return out, (hidden, cell)
    
    def init_hidden(self, batch_size):
        hidden = torch.zeros(self.num_layers * 2, batch_size, self.hidden_size).to(device)
        cell = torch.zeros(self.num_layers * 2, batch_size, self.hidden_size).to(device)
        return hidden, cell
    
    def save(f_path):
        pass

    


In [25]:
class Generator():
    def __init__(self, chunk_length=200, num_epochs=4000, batch_size=1, hidden_size=256, num_layers=2, learning_rate=0.002):
        self.chunk_len = chunk_length
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.print_every = self.num_epochs // 20 or 1
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lr = learning_rate


    def word_tensor(self, string):
        tensor = torch.zeros(len(string)).long()
        for c in range(len(string)):
            tensor[c] = all_words.index(string[c])
        return tensor


    def get_random_batch(self):
        start_idx = random.randint(0, len(corpus) - self.chunk_len)
        end_idx = start_idx + self.chunk_len + 1
        text_str = corpus[start_idx:end_idx]
        text_input = torch.zeros(self.batch_size, self.chunk_len)
        text_target = torch.zeros(self.batch_size, self.chunk_len)
        for i in range(self.batch_size):
            text_input[i,:] = self.word_tensor(text_str[:-1])
            text_target[i,:] = self.word_tensor(text_str[1:])
        return text_input.long(), text_target.long()


    def generate(self, initial_str='the president is dead', predict_len=200, temperature=0.85):
        initial_words = initial_str.split(' ')
        hidden, cell = self.rnn.init_hidden(batch_size=self.batch_size)
        initial_input = self.word_tensor(initial_words)
        predicted = initial_words
        
        for p in range(len(initial_words) - 1):
            _, (hidden, cell) = self.rnn(initial_input[p].view(1).to(device), hidden, cell)

        last_word = initial_input[-1]
        for p in range(predict_len):
            output, (hidden, cell) = self.rnn(last_word.view(1).to(device), hidden, cell)
            output_dist = output.data.view(-1).div(temperature).exp()
            top_word = torch.multinomial(output_dist, 1)[0]
            predicted_word = [all_words[top_word]]
            predicted.extend(predicted_word)
            last_word = self.word_tensor(predicted_word)

        return predicted


    def train(self):
        self.rnn = RNN(vocab_length, self.hidden_size, self.num_layers, vocab_length).to(device)
        optimizer = torch.optim.Adam(self.rnn.parameters(), lr=self.lr)
        criterion = nn.CrossEntropyLoss()
        print(f'<{datetime.now()}>starting training')
        lowest_loss = 100.0 # just a high value, should not be lower than 10
        for epoch in range(1, self.num_epochs + 1):
            input, target = self.get_random_batch()
            hidden, cell = self.rnn.init_hidden(batch_size=self.batch_size)

            self.rnn.zero_grad()
            loss = 0
            input = input.to(device)
            target = target.to(device)

            for c in range(self.chunk_len):
                output, (hidden, cell) = self.rnn(input[:, c], hidden, cell)
                loss += criterion(output, target[:, c])

            loss.backward()
            optimizer.step()
            loss = loss.item() / self.chunk_len
            if loss < lowest_loss:
                self.best_model = self.rnn.state_dict()
                print(f'<{datetime.now()}> better model found after {epoch}/{self.num_epochs} epochs with loss: {loss}')
                lowest_loss = loss

            if epoch % self.print_every == 0:
                print(f'\n\n<{datetime.now()}> | epoch: {epoch}/{self.num_epochs} | loss: {loss}')
                # print(self.generate())
        file_path = f'./models/bidir_lstm_chunk_{self.chunk_len}_words_{vocab_length}_loss_{lowest_loss}.pt'
        print(f'saving model at {file_path}')
        torch.save(self.best_model, file_path)

In [26]:
# default parameters for generator: 
# chunk_length=200, num_epochs=4000, batch_size=1, hidden_size=256, num_layers=2, learning_rate=0.002

chunk_lengths = [50, 100, 150, 200, 300]
epoch_numbers = [500, 1000, 2000, 3000, 5000]
batch_sizes = [1, 2, 5, 10]
hidden_sizes = [64, 128, 256, 512]
layer_numbers = [2, 4]
learning_rates = [0.001, 0.002, 0.003, 0.005]


gentext = Generator()
gentext.train()


<2023-05-29 17:16:40.936561>starting training
<2023-05-29 17:16:41.839686> better model found after 1/10000 epochs with loss: 9.929571940104166
<2023-05-29 17:16:42.746757> better model found after 2/10000 epochs with loss: 9.911407063802084
<2023-05-29 17:16:43.600788> better model found after 3/10000 epochs with loss: 9.866660970052083
<2023-05-29 17:16:44.472761> better model found after 4/10000 epochs with loss: 9.818041178385416
<2023-05-29 17:16:45.345787> better model found after 5/10000 epochs with loss: 9.772354329427083
<2023-05-29 17:16:46.208759> better model found after 6/10000 epochs with loss: 9.625013020833334
<2023-05-29 17:16:47.080785> better model found after 7/10000 epochs with loss: 9.422939453125
<2023-05-29 17:16:47.941787> better model found after 8/10000 epochs with loss: 8.801774088541666
<2023-05-29 17:16:48.852787> better model found after 9/10000 epochs with loss: 8.1427197265625
<2023-05-29 17:16:49.717760> better model found after 10/10000 epochs with lo

In [None]:
temperatures = [0.2, 0.4, 0.6, 0.8]

for temperature in temperatures:
    stmt = ' '.join(gentext.generate(initial_str='i would like', predict_len=150, temperature=temperature)).replace('<quotation_mark>', '"').replace(' <question_mark>','?').replace(' <comma>', ',').replace(' <period>', '.')
    print(f'temperature: {temperature}\nstatement:\n{stmt}')
