<a href="https://colab.research.google.com/github/Da-Pen/CS486-twitter-bot/blob/main/LSTM/CS486_LSTM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Preprocessing Code

In [None]:
import numpy as np
from collections import defaultdict

# CONSTANTS
# Locations of the raw tweet dumps.
NEWS_ORGS_DATA_FILE_NAME = '/content/data/newsorgs_data'
TRUMP_DATA_FILE_NAME = '/content/data/donald_trump_data'

# Preprocessing switches.
SKIP_URLS = True
SKIP_ELLIPSES = True
SKIP_RETWEETS = True
# Trump often has tweets where he simply replies to another Twitter user or
# quotes them; they usually start with '@' or '"@'. When True, those tweets
# are ignored.
SKIP_REPLIES = True
# When True, all text is converted to lowercase.
ONLY_LOWERCASE = True

# The character vocabulary. The position of a character in this list is the
# index it is given in CHAR_TO_INDEX / INDEX_TO_CHAR.
VALID_CHARS = (
    list(' !"#$%&\'(),-./') + ['\n']
    + list('0123456789')
    + list(':;?@')
    + ['_']
    + list('abcdefghijklmnopqrstuvwxyz')
    + ['—', '’', '“', '”']
)
if not ONLY_LOWERCASE:
    VALID_CHARS = VALID_CHARS + list('ABCDEFGHIJKLMNOPQRSTUVWXYZ')

# Same characters as a set, for fast membership tests.
VALID_CHARS_SET = set(VALID_CHARS)

# Two-way mapping between characters and their vocabulary index.
CHAR_TO_INDEX = {char: index for index, char in enumerate(VALID_CHARS)}
INDEX_TO_CHAR = dict(enumerate(VALID_CHARS))

# At least this ratio of the characters in a tweet has to be valid (i.e. has
# to exist in the vocabulary above); otherwise the tweet is ignored.
MIN_VALID_CHAR_PERCENT = 0.9


def ignore_urls(s):
    """Return `s` with every URL-like token removed.

    A token is a whitespace-delimited word; it is dropped when it contains
    'http'. Within each line the surviving tokens are re-joined with single
    spaces, so runs of spaces collapse and leading/trailing whitespace is
    removed. Line breaks are preserved: splitting the whole string on
    arbitrary whitespace would silently turn every newline into a space.
    """
    return '\n'.join(
        ' '.join(word for word in line.split() if 'http' not in word)
        for line in s.split('\n')
    )


def filter_invalid_chars(tweet):
    """Remove every character of `tweet` that is not in VALID_CHARS_SET.

    Returns the filtered string. Returns None instead when the tweet is
    empty, or when the share of characters that survive the filter is below
    MIN_VALID_CHAR_PERCENT.
    """
    if not tweet:
        # Nothing to keep, and the ratio below would divide by zero.
        return None
    new_tweet = ''.join(char for char in tweet if char in VALID_CHARS_SET)
    if len(new_tweet) / len(tweet) < MIN_VALID_CHAR_PERCENT:
        return None
    return new_tweet


def get_tweets_list(filename, upto=None):
    """Return the cleaned tweets stored one per line in `filename`.

    Args:
        filename: path of the text file to read.
        upto: if given, only the first `upto` raw lines of the file are
            considered (before any filtering).

    Returns:
        A numpy array of strings, one per tweet that survived the filters.

    Processing steps, in order:
      1. drop lines without an inner space (probably just a link) and
         replace each 'NEWLINE' marker with a real newline character;
      2. if SKIP_RETWEETS, drop tweets starting with 'RT';
      3. if ONLY_LOWERCASE, lowercase the text;
      4. if SKIP_ELLIPSES, drop tweets containing '…' (truncated tweets);
      5. if SKIP_URLS, strip URLs via ignore_urls;
      6. if SKIP_REPLIES, drop tweets starting with '@' or '"@';
      7. drop tweets that are now empty, then remove invalid characters via
         filter_invalid_chars, dropping the tweets it rejects.
    """
    with open(filename, 'r', encoding='utf-8') as f:
        lines = f.read().split('\n')[:upto]
    lines = [line.replace('NEWLINE', '\n') for line in lines if ' ' in line.strip()]
    if SKIP_RETWEETS:
        # Must happen before lowercasing: a lowercased tweet can never start
        # with the upper-case 'RT' marker.
        lines = [line for line in lines if not line.startswith('RT')]
    if ONLY_LOWERCASE:
        lines = [line.lower() for line in lines]
    if SKIP_ELLIPSES:
        lines = [line for line in lines if '…' not in line]
    if SKIP_URLS:
        lines = [ignore_urls(line) for line in lines]
    if SKIP_REPLIES:
        # startswith is safe on the empty strings URL removal can leave behind.
        lines = [line for line in lines if not line.startswith(('@', '"@'))]
    # Skip empty tweets up front and filter each remaining tweet exactly once.
    filtered = (filter_invalid_chars(line) for line in lines if line)
    return np.array([tweet for tweet in filtered if tweet is not None])


def get_commonly_occuring_characters():
    """Return the characters that occur most often across both datasets.

    Counts every character of every tweet returned by get_tweets_list for
    the news-orgs file and the Trump file, and returns a sorted list of
    (char, occurrence_count) tuples for the characters seen more than
    `threshold` times in total.
    """
    threshold = 500  # only characters appearing more often than this are reported
    tweet_sets = [
        get_tweets_list(NEWS_ORGS_DATA_FILE_NAME),
        get_tweets_list(TRUMP_DATA_FILE_NAME),
    ]
    occurrences = defaultdict(int)
    for tweets in tweet_sets:
        for tweet in tweets:
            for char in tweet:
                occurrences[char] += 1
    # Yield `char` instead of the tuple here to report just the characters.
    return sorted(
        (char, count) for char, count in occurrences.items() if count > threshold
    )


def main():
    """Entry point of the preprocessing cell; deliberately does nothing.

    Un-comment the lines below to try something out, e.g. printing the
    cleaned Trump tweets.
    """
    # news_org_tweets = get_tweets_list(NEWS_ORGS_DATA_FILE_NAME)
    # trump_tweets = get_tweets_list(TRUMP_DATA_FILE_NAME)
    # print('\n'.join(trump_tweets))
    return None


if __name__ == '__main__':
    main()


Training Code

In [None]:
from keras.callbacks import LambdaCallback
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM, Bidirectional, BatchNormalization, Activation
from keras.layers import Dropout
from keras.optimizers import RMSprop
from keras.optimizers import Adam
from keras.utils.data_utils import get_file
import random
import sys
import io
from google.colab import files

INPUT_LENGTH = 40  # based on INPUT_LENGTH characters, our model generates the next character
GENERATED_TWEET_LENGTH = 120  # number of characters on_epoch_end generates after the seed text


def sample(preds, temperature=1.0):
    """Sample an index from a probability array, reweighted by `temperature`.

    The probabilities are turned into log-probabilities, divided by
    `temperature`, and renormalised; one index is then drawn from the
    resulting distribution. Temperatures below 1 sharpen the distribution,
    temperatures above 1 flatten it.

    Args:
        preds: array-like of probabilities.
        temperature: positive divisor applied to the log-probabilities.

    Returns:
        The sampled index (a numpy integer).
    """
    # float64 keeps the renormalised probabilities precise enough for
    # np.random.multinomial's sum-to-one check.
    preds = np.asarray(preds).astype('float64')
    # log(0) is -inf, which maps back to probability 0 below; the
    # divide-by-zero warning it triggers is just noise.
    with np.errstate(divide='ignore'):
        logits = np.log(preds) / temperature
    # Shift so the largest logit is 0. Without this, a small temperature can
    # make every exp() underflow to 0 and the normalisation become 0/0.
    logits = logits - np.max(logits)
    exp_preds = np.exp(logits)
    preds = exp_preds / np.sum(exp_preds)
    probas = np.random.multinomial(1, preds, 1)
    return np.argmax(probas)


def on_epoch_end(epoch, _, data, model):
    """Print text generated by `model`; invoked at the end of each epoch.

    Two tweets are picked at random from `data`. For each of them, and for
    each diversity (sampling temperature) value, the first INPUT_LENGTH
    characters of the tweet are used as seed and GENERATED_TWEET_LENGTH
    further characters are generated one at a time and written to stdout.

    Args:
        epoch: epoch number, only used in the printed header.
        _: unused; second argument supplied by the callback.
        data: collection of tweet strings to draw seeds from.
        model: object whose predict() returns next-character probabilities.
    """
    print()
    print('----- Generating text after Epoch: %d' % epoch)
    vocab_size = len(VALID_CHARS)
    for _seed_number in range(2):     # use 2 different tweets as samples
        tweet = np.random.choice(data)  # select random tweet
        seed = tweet[:INPUT_LENGTH]

        for diversity in [0.2, 0.4, 0.6, 1.0]:
            print('----- diversity:', diversity)
            print('----- Generating with seed: "' + seed + '"')
            sys.stdout.write('' + seed)

            # Sliding window over the most recent characters fed to the model.
            window = seed
            for _step in range(GENERATED_TWEET_LENGTH):
                # One-hot encode the current window.
                x_pred = np.zeros((1, INPUT_LENGTH, vocab_size))
                for position, char in enumerate(window):
                    x_pred[0, position, CHAR_TO_INDEX[char]] = 1.

                preds = model.predict(x_pred, verbose=0)[0]
                next_char = INDEX_TO_CHAR[sample(preds, diversity)]

                # Drop the oldest character and append the new one.
                window = window[1:] + next_char

                sys.stdout.write(next_char)
                sys.stdout.flush()
            print()


def train_from_data(data):
    """Train a character-level LSTM on `data`, then save and download it.

    Every tweet longer than INPUT_LENGTH is cut into overlapping windows of
    INPUT_LENGTH characters; the character following each window is the
    prediction target. Windows and targets are one-hot encoded over
    VALID_CHARS, the model is fitted for 10 epochs (printing generated text
    after each epoch via on_epoch_end), saved to /content/model, zipped to
    /content/model.zip and offered as a browser download.

    Args:
        data: iterable of tweet strings made only of characters present in
            CHAR_TO_INDEX.

    Raises:
        subprocess.CalledProcessError: if the `zip` command fails.
    """
    import subprocess  # local import: only needed for the archive step

    # convert the raw tweets list to input and output
    # input is equal to INPUT_LENGTH characters, output is a single character
    sentences = []
    next_chars = []
    for tweet in data:
        for i in range(0, len(tweet) - INPUT_LENGTH):
            sentences.append(tweet[i: i + INPUT_LENGTH])
            next_chars.append(tweet[i + INPUT_LENGTH])
    print('# training samples:', len(sentences))

    # vectorize the data
    print('Vectorization...')
    vocab_size = len(VALID_CHARS)
    # The builtin `bool` is used because the `np.bool` alias was removed from
    # NumPy.
    x = np.zeros((len(sentences), INPUT_LENGTH, vocab_size), dtype=bool)
    y = np.zeros((len(sentences), vocab_size), dtype=bool)
    for i, sentence in enumerate(sentences):
        for t, char in enumerate(sentence):
            x[i, t, CHAR_TO_INDEX[char]] = 1
        y[i, CHAR_TO_INDEX[next_chars[i]]] = 1

    # build the model: one LSTM layer, two dense hidden layers, softmax output
    print('Build model...')
    model = Sequential()
    model.add(LSTM(vocab_size * 7, input_shape=(INPUT_LENGTH, vocab_size)))

    model.add(BatchNormalization())
    model.add(Activation('selu'))

    model.add(Dense(vocab_size * 4))
    model.add(Activation('selu'))

    model.add(Dense(vocab_size * 4))
    model.add(BatchNormalization())
    model.add(Activation('selu'))

    # model.add(Bidirectional(LSTM(128), input_shape=(INPUT_LENGTH, len(VALID_CHARS))))
    model.add(Dense(vocab_size, activation='softmax'))

    # optimizer = RMSprop(lr=0.01)
    optimizer = Adam()
    model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['categorical_crossentropy', 'accuracy'])

    epochs = 10

    print_callback = LambdaCallback(on_epoch_end=lambda a, b: on_epoch_end(a, b, data, model))

    # train the model
    model.fit(x, y,
              epochs=epochs,
              callbacks=[print_callback]
              )

    # save and download the model
    model.save('/content/model')
    # Plain-Python equivalent of the `!zip` notebook magic; check=True makes a
    # failed archive step raise instead of passing silently.
    subprocess.run(['zip', '-r', '/content/model.zip', '/content/model'], check=True)
    files.download('/content/model.zip')

def main():
    """Entry point of the training cell: train a model on the Trump tweets."""
    # TRAIN TRUMP
    tweet_limit = 1000  # TODO remove upto
    trump_tweets = get_tweets_list(TRUMP_DATA_FILE_NAME, upto=tweet_limit)
    print("number of trump tweets:", len(trump_tweets))
    train_from_data(trump_tweets)
    # TRAIN NEWS ORGS
    # TODO


if __name__ == '__main__':
    main()

number of trump tweets: 850
# training samples: 48813
Vectorization...
Build model...
Epoch 1/10
----- Generating text after Epoch: 0
----- diversity: 0.2
----- Generating with seed: "take a look at and to see these beautifu"
take a look at and to see these beautifus frlllve ne tr mp nn tr mp nn tr mp nf n nf n ma n n walll tr mp tr mp nf n nf n tr mp nn tr mp nn nf lll fr mp ramp tr
----- diversity: 0.4
----- Generating with seed: "take a look at and to see these beautifu"
take a look at and to see these beautifurlloll be cell nf nr pprrmp nn nf llll bloll hamp of nr mall to llll fr mp rlak ffll tr mall ha nelvel to t ffrmm nat n 
----- diversity: 0.6
----- Generating with seed: "take a look at and to see these beautifu"
take a look at and to see these beautifur malllu grrl #t mpmrttrget nellolf rall llf t mrl hall be tall held nlve nn nelvert fr mas ralldebt tamp cllbright fex 
----- diversity: 1.0
----- Generating with seed: "take a look at and to see these beautifu"
take a look at 