<a href="https://colab.research.google.com/github/RebeccaRoberts/phd-codebites/blob/master/VMP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

VMP Code

In [15]:
import numpy as np
from scipy.special import digamma
from scipy import stats
import copy

def convert_texts_for_vmp(texts):
    # convert to format for VMP
    dictionary_vmp = []
    corpus_vmp = []
    for text in texts:
        doc = []
        for word in text:
            if word not in dictionary_vmp:
                dictionary_vmp.append(word)
            doc.append(word)
        corpus_vmp.append(doc)
    return corpus_vmp, dictionary_vmp

def create_topic_word_dirs(K=None, V=None, prior=0.1, noise_variance=0.001):
    topic_word_list = []
    for k in range(K):
        noisy_vector = np.ones(V) + np.random.normal(0, noise_variance, V)
        topic_word_list.append((noisy_vector/np.sum(noisy_vector)) * prior)

    return np.array(topic_word_list)


def create_doc_topic_list_dirs(M=None, K=None, prior=0.1):
    doc_topic_list = []
    for d in range(M):
        doc_topic_list.append(np.ones(K) * prior)
    return np.array(doc_topic_list)


def create_word_give_topic_cat(K=None, V=None):
    word_given_topic_cat = []
    for k in range(K):
        word_given_topic_cat.append(np.ones(V) * (1 / V))
    return np.array(word_given_topic_cat)


def create_topic_given_doc_cat(M=None, K=None):
    topic_given_doc_cat = []
    for n in range(M):
        topic_given_doc_cat.append(np.ones(K) * (1 / K))
    return np.array(topic_given_doc_cat)


def initialize_graph(M=None, N=None, K=None, V=None, doc_prior=None, topic_prior=None, noise_variance=None):
    topic_word_list = create_topic_word_dirs(K=K, V=V, prior=topic_prior, noise_variance=noise_variance)
    word_given_topic_cat = create_word_give_topic_cat(K=K, V=V)
    topic_given_doc_cat = create_topic_given_doc_cat(M=M, K=K)
    doc_topic_list = create_doc_topic_list_dirs(M=M, K=K, prior=doc_prior)
    return topic_word_list, word_given_topic_cat, topic_given_doc_cat, doc_topic_list


def run_vmp_lda_one_by_one(corpus=None, dictionary=None, K=7, doc_prior=0.5, topic_prior=0.5, noise_variance=0.0001,
                           epochs=100, true_message=True, words_in_topics_ndarray_original=None):
    M = len(corpus)
    N = len(corpus[0])
    V = len(set(dictionary)) + 1
    print(len(dictionary))
    print(V)
    print(M)

    topic_word_ndarray, word_given_topic_cat, topic_given_doc_cat, doc_topic_array = initialize_graph(M=M, N=N, K=K, V=V, doc_prior=doc_prior, topic_prior=topic_prior,
                                                                                                  noise_variance=noise_variance)

    topic_word_ndarray_prior = create_topic_word_dirs(K=K, V=V, prior=topic_prior, noise_variance=0)
    doc_topic_array_prior = copy.deepcopy(doc_topic_array)
    topic_given_doc_cat_mat = []
    word_given_topic_cat_mat = []
    for m in range(M):
        word_given_topic_cat_list = []
        topic_given_doc_cat_list = []
        for n in range(len(corpus[m])):
            word_given_topic_cat_list.append(word_given_topic_cat)
            topic_given_doc_cat_list.append(topic_given_doc_cat)
        word_given_topic_cat_mat.append(word_given_topic_cat_list)
        topic_given_doc_cat_mat.append(topic_given_doc_cat_list)
    # topic_given_doc_cat_mat = np.array(topic_given_doc_cat_mat)
    # word_given_topic_cat_mat = np.array(word_given_topic_cat_mat)

    previous_topic_word_ndarray = copy.deepcopy(topic_word_ndarray)
    previous_doc_topic_array = copy.deepcopy(doc_topic_array)
    kl_list = []
    previous_doc_topic_array_list = []
    previous_topic_word_ndarray_list = []

    for epoc in range(epochs):
        for m in range(M):
            # ++++++++++++++++++++++++
            # ++ * forward sweep +++++
            # +++++++++++++++++++++++
            for n in range(len(corpus[m])):
                # from topic-word dir to word-topic categorical 
                for k in range(K):
                    if true_message:
                        # adapted VMP message from dirichlet psudo-counts 
                        message_from_topic_word_dir = previous_topic_word_ndarray[k]/np.sum(previous_topic_word_ndarray[k])
                    else:
                        # proper VMP message from dirichlet psudo-counts
                        message_from_topic_word_dir = np.exp(digamma(previous_topic_word_ndarray[k]) -
                                             digamma(np.sum(previous_topic_word_ndarray[k])))
                        # normalise
                        message_from_topic_word_dir = message_from_topic_word_dir / np.sum(message_from_topic_word_dir)

                    # (1) update the word given topic matrix (parent to child) 
                    # print(len(message_from_topic_word_dir),len(corpus[m]),m,n,k)
                    word_given_topic_cat_mat[m][n][k] = message_from_topic_word_dir  # no mixing, just insert

                topic_given_doc_list_mn = []
                for k in range(K):
                    # select the word index corpus[m][n] only and make a list of word probs for that word in each topic (get the slice)
                    topic_given_doc_list_mn.append(word_given_topic_cat_mat[m][n][k][corpus[m][n]])

                # scale the word probabilities for word at index corpus[m][n] over all topics
                message_from_word_given_topic_cat_mat = topic_given_doc_list_mn / np.sum(topic_given_doc_list_mn)  # normalise over K

                # (2) calculate message to document topic dirichlet (child to parent) 
                updated_doc_topic_proportions = topic_given_doc_cat_mat[m][n][m] * message_from_word_given_topic_cat_mat  # we mix at this node

                # normalize
                approximate_update_message_to_dirichlet = updated_doc_topic_proportions/np.sum(updated_doc_topic_proportions)

                # (3) update dirichlet counts by adding new document topic proportions (child to parent)
                doc_topic_array[m] += approximate_update_message_to_dirichlet/np.sum(approximate_update_message_to_dirichlet) # mix by adding to \alpha hyperparameters
                
            # ++++++++++++++++++++++++
            # +++ backward sweep +++++
            # ++++++++++++++++++++++++

            for n in range(len(corpus[m])):
                # from topic-doc dir to topic-doc categorical
                if true_message:
                    message_from_doc_topic_dir =doc_topic_array[m]/np.sum(doc_topic_array[m])
                else:
                    message_from_doc_topic_dir = np.exp(digamma(doc_topic_array[m]) -
                                         digamma(np.sum(doc_topic_array[m])))
                    # normalise
                    message_from_doc_topic_dir = message_from_doc_topic_dir / np.sum(message_from_doc_topic_dir)

                # (4) update the doc given topic matrix (parent to child)
                # print(m,n)
                topic_given_doc_cat_mat[m][n][m] = message_from_doc_topic_dir # no mixing, just insert
 

                # (5) + (6a) from topic-doc categorical to topic-word dir 
                topic_given_word_probs = np.ones(K)
                for k in range(K):
                    # a. select the word index corpus[m][n]
                    # b. multiply by the proportion of that topic within that document
                    # This gives a list of word probs for that word in each topic so we can update the Dirichelts
                    # NB: we arent updating word given topic matrix since it will anyways get replaced
                    # in the following epoch in the oposite direction. 
                    # The W node is always a child! (parent to child)
                    topic_given_word_probs[k] = word_given_topic_cat_mat[m][n][k][corpus[m][n]] * \
                                                topic_given_doc_cat_mat[m][n][m][k]  # this looks like mixing but doing this for the dir update

                # normalise over K                              
                message_to_word_topic_dirichlet = (topic_given_word_probs / np.sum(topic_given_word_probs))

                # updating the topic-word dir (the calculated message is done in 6a)
                for k in range(K):
                    topic_word_ndarray[k][corpus[m][n]] += message_to_word_topic_dirichlet[k]

            previous_doc_topic_array_list.append(doc_topic_array)
            previous_topic_word_ndarray_list.append(topic_word_ndarray)


        previous_topic_word_ndarray = copy.deepcopy(topic_word_ndarray)
        topic_word_ndarray = copy.deepcopy(topic_word_ndarray_prior)

        previous_doc_topic_array = copy.deepcopy(doc_topic_array)
        doc_topic_array = copy.deepcopy(doc_topic_array_prior)

    return previous_doc_topic_array, previous_topic_word_ndarray, previous_doc_topic_array_list, previous_topic_word_ndarray_list



Some preprocessy stuff


In [2]:
import os
from subprocess import run
import gzip

from gensim.utils import simple_preprocess
from gensim.models import Phrases
from gensim.corpora import Dictionary
import nltk
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer



def build_texts(file_name=None, raw_texts=None):
    """
    Build tokenized texts from file name or raw_texts

    Parameters:
    ----------
    file_name: str
            File to be read
    raw_texts: list
            Raw texts with no pre-processing

    Returns:
    -------
    yields generator: if cast to list, is list of text each containing lists of words
    """

    if file_name:
        with open(file_name) as f:
            for line in f:
                yield simple_preprocess(line, deacc=True, min_len=3, max_len=50)
    elif raw_texts:
        for doc in raw_texts:
            # print(doc)
            # print()
            yield simple_preprocess(doc, deacc=True, min_len=3,  max_len=50)


def process_texts(texts, additional_stopwords=None, stop=True, lemmatize=True, collocation=False):
    """
    Function to process texts. Following are the steps we take:

    1. Stopword Removal.
    2. Lemmatization
    3. Collocation detection.

    Parameters:
    ----------
    texts: texts with very basic processing (tokenized)
    additional_stopwords: extra stopwords to remove
    stop: True or False
    lemmatize: True or False
    collocation: True or False

    Returns:
    -------
    texts: Pre-processed tokenized texts.
    """

    if collocation:
        bigram = Phrases(texts)
    if stop:
        nltk.download('stopwords')
        if not additional_stopwords:
            additional_stopwords = []
        stops = set(stopwords.words('english') + additional_stopwords)
    if lemmatize:
        nltk.download('wordnet')
        lemmatizer = WordNetLemmatizer()

    # remove stopwords
    if stop:
        texts = [[word for word in doc if word not in stops] for doc in texts]
    # convert words in different forms to same root word (like stemming)
    if lemmatize:
        texts = [[lemmatizer.lemmatize(word=word, pos='n') for word in doc] for doc in texts]
    # remove stopwords again in case more popped up after lematizing
    if stop:
        texts = [[word for word in doc if word not in stops] for doc in texts]
    if collocation:
        texts = [bigram[doc] for doc in texts]


    return texts
    
def raw_to_corpus_dict_texts(texts=None, filename=None, additional_stopwords=None):
    """
    Generate corpus (bag of words), dictionary and tokenized texts from filename or unprocessed texts

    Parameters:
    ----------
    texts: texts with very basic processing (tokenized)
    fname: name of file to get raw texts
    additional_stopwords: extra stopwords to remove

    Returns:
    -------
    texts: Pre-processed tokenized texts.
    corpus: bag of words corpus
    dictionary: gensim dictionary to link IDs to actual words

    """
    # for text in texts:
    #         print(text)

    if texts:
        texts = list(texts)
    elif filename:
        texts = list(build_texts(filename))
    else:
        raise ValueError("Requires either file_name or raw_tests.")
    texts = process_texts(texts, additional_stopwords=additional_stopwords)

    all_tokens = sum(texts, [])
    tokens_few = set(word for word in set(all_tokens) if all_tokens.count(word) < 3)
    texts = [[word for word in text if word not in tokens_few]
            for text in texts]
  
    corpus, dictionary = processed_to_corpus_dict(texts)


    return texts, corpus, dictionary

def processed_to_corpus_dict(preprocessed):
    """
    Build corpus (bag of words) and dictionary from from preprocessed texts (tokenized)

    Parameters:
    ----------
    preprocessed: tokenized texts

    Returns:
    -------
    corpus: bag of words corpus
    dictionary: gensim dictionary to link IDs to actual words
    """
    dictionary = Dictionary()
    corpus = [dictionary.doc2bow(doc, allow_update=True) for doc in preprocessed]
    dictionary.compactify() # remove gaps in id sequence after words that were removed
    return corpus, dictionary

def convert_texts_for_vmp(texts):
    # convert to format for VMP
    dictionary_vmp = []
    corpus_vmp = []
    for text in texts:
        doc = []
        for word in text:
            if int(word) not in dictionary_vmp:
                dictionary_vmp.append(int(word))
            doc.append(int(word))
        corpus_vmp.append(doc)
    return corpus_vmp, dictionary_vmp

def convert_texts_for_vmp_indices(texts, dictionary):
    # convert to format for VMP
    dictionary_vmp = []
    corpus_vmp = []
    for text in texts:
        doc = []
        for word in text:
            ind = dictionary.token2id[word]
            if int(ind) not in dictionary_vmp:
                dictionary_vmp.append(int(ind))
            doc.append(int(ind))
        corpus_vmp.append(doc)
    return corpus_vmp, dictionary_vmp

How I process a corpus: this is the covid tweets example

In [3]:
from gensim.parsing.preprocessing import strip_multiple_whitespaces, strip_numeric, split_alphanum, strip_non_alphanum

import pandas as pd
import re
import nltk
nltk.download('punkt')
nltk.download('stopwords')

import string
from nltk import word_tokenize, FreqDist
from nltk.corpus import stopwords
regex = re.compile('[^a-zA-Z]')

covid19_tweets = pd.read_csv("sample_data/covid19_tweets.csv")
tweets_list = list(covid19_tweets['text'])[0:20000]
print(tweets_list[0:10])
new_tweets_list = []

for i, tweet in enumerate(tweets_list):
    s = regex.sub(" ", str(tweet))
    ps = strip_multiple_whitespaces(strip_numeric(split_alphanum(strip_non_alphanum(s))))
    new_tweets_list.append(ps)

raw_texts = build_texts(raw_texts=new_tweets_list)

texts, _, _ = raw_to_corpus_dict_texts(texts=raw_texts,additional_stopwords=["http","cov","covid","amp", "july"])

decent_texts = []
for text in texts:
    decent_text = []
    for word in text:
        if len(word)>2:
            decent_text.append(word)

    if len(decent_text) > 3:
      decent_texts.append(decent_text)

texts = decent_texts
_, dictionary = processed_to_corpus_dict(texts)
corpus, dictionary = convert_texts_for_vmp_indices(texts, dictionary)

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
['If I smelled the scent of hand sanitizers today on someone in the past, I would think they were so intoxicated that… https://t.co/QZvYbrOgb0', "Hey @Yankees @YankeesPR and @MLB - wouldn't it have made more sense to have the players pay their respects to the A… https://t.co/1QvW0zgyPu", '@diane3443 @wdunlap @realDonaldTrump Trump never once claimed #COVID19 was a hoax. We all claim that this effort to… https://t.co/Jkk8vHWHb3', '@brookbanktv The one gift #COVID19 has give me is an appreciation for the simple things that were always around me… https://t.co/Z0pOAlFXcW', '25 July : Media Bulletin on Novel #CoronaVirusUpdates #COVID19 \n@kansalrohit69 @DrSyedSehrish @airnewsalerts @ANI… https://t.co/MN0EEcsJHh', "#coronavirus #covid19 deaths continue to rise. It'

In [None]:
K = 10
true_messge = True



print(corpus)
doc_topic_array, topic_word_ndarray, doc_topic_array_list_list_vmp, topic_word_ndarray_list_list_vmp = run_vmp_lda_one_by_one(corpus=corpus,
                                                                          dictionary=dictionary,
                                                                          K=K,
                                                                          doc_prior=0.5,
                                                                          topic_prior=0.5,
                                                                          noise_variance=0.0001,
                                                                          epochs=100,
                                                                          true_message=true_messge)


[[0, 2, 5, 3, 1, 6, 4], [7, 14, 9, 8, 13, 11, 10, 12], [20, 21, 19, 16, 18, 15, 17], [26, 24, 25, 27, 28, 22, 23], [33, 34, 32, 37, 29, 30, 35, 36, 31, 38], [39, 44, 40, 41, 42, 43, 41], [51, 47, 46, 49, 50, 48, 45], [55, 53, 54, 56, 52], [57, 59, 60, 58], [62, 65, 64, 63, 61], [70, 69, 26, 71, 68, 71, 66, 67], [73, 75, 76, 74, 72, 73], [83, 78, 5, 86, 82, 77, 84, 81, 85, 79, 80, 87], [93, 91, 89, 90, 92, 62, 88], [96, 97, 94, 98, 95], [93, 103, 105, 101, 102, 104, 99, 100], [20, 53, 106, 108, 109, 107, 110, 111], [33, 114, 112, 93, 114, 112, 113, 115], [117, 121, 118, 120, 116, 122, 119], [125, 124, 129, 131, 123, 127, 132, 131, 126, 128, 130], [135, 136, 133, 138, 137, 79, 133, 134, 136, 82], [64, 140, 139, 2], [142, 90, 141, 144, 143], [145, 147, 148, 149, 146, 137, 5, 150], [156, 156, 153, 154, 151, 152, 157, 155, 158, 156, 153, 154], [161, 160, 159, 79, 162], [163, 169, 166, 167, 92, 137, 165, 164, 168], [93, 150, 62, 56, 34, 171, 170, 112, 98, 150], [174, 45, 172, 62, 26, 173], [