# Lab5: LDA Topic Model with Gibbs Sampling

## Glimpse of Dataset

In [1]:
import numpy as np

# Load the raw corpus; each entry of the array is indexed as one document's text.
raw_text_list = np.load("text.npy")
# Show the first document as a sample of what the data looks like.
print(raw_text_list[0])
# Entry 51 is printed to demonstrate that some documents are (nearly) empty.
print('There even exists null text: ', raw_text_list[51])

Well i'm not sure about the story nad it did seem biased. What
I disagree with is your statement that the U.S. Media is out to
ruin Israels reputation. That is rediculous. The U.S. media is
the most pro-israeli media in the world. Having lived in Europe
I realize that incidences such as the one described in the
letter have occured. The U.S. media as a whole seem to try to
ignore them. The U.S. is subsidizing Israels existance and the
Europeans are not (at least not to the same degree). So I think
that might be a reason they report more clearly on the
atrocities.
	What is a shame is that in Austria, daily reports of
the inhuman acts commited by Israeli soldiers and the blessing
received from the Government makes some of the Holocaust guilt
go away. After all, look how the Jews are treating other races
when they got power. It is unfortunate.

There even exists null text:   


## Hyper-Parameters and Global Variables

In [2]:
# Per-dataset settings: corpus file, number of topics, and keywords shown per topic.
_DATASET_CONFIGS = {
    'text': {'path': "text.npy", 'num_topics': 20, 'num_keywords': 10},
    'computer-cat': {'path': "computer-cat.npy", 'num_topics': 2, 'num_keywords': 5},
}

# Pick the corpus to work on; the alternative is 'computer-cat'.
dataset_name = 'text'
if dataset_name not in _DATASET_CONFIGS:
    # Fail loudly instead of silently leaving every hyper-parameter undefined.
    raise ValueError(
        f"Unknown dataset_name {dataset_name!r}; expected one of {sorted(_DATASET_CONFIGS)}"
    )

_dataset_config = _DATASET_CONFIGS[dataset_name]
raw_text_list = np.load(_dataset_config['path'])
num_topics = _dataset_config['num_topics']
num_keywords = _dataset_config['num_keywords']
# alpha holds one prior pseudo-count per topic (all ones).
alpha = np.ones(num_topics)
# eta needs one entry per dictionary word, so it is filled in once the dataset is built.
eta = None  # TO_BE_ASSIGNED

## Data Preprocess

1. Filter out noisy tokens; tokenization uses `sklearn.feature_extraction.text.CountVectorizer` or `nltk.*`.
2. Build the dictionary.
3. Transform each text into a `Doc` object, which contains the `freq_dist` and the set of `words`.

In [3]:
import re
import nltk
import numpy as np
from nltk.corpus import stopwords
from collections import Counter
from sklearn.feature_extraction.text import CountVectorizer

# Single-character tokens that are dropped outright.
punctuations = {'(', ')', '{', '}', '[', ']', '"', "'",
                ',', ';', '.', '!', '?'}
# English stop words from NLTK, as a set for O(1) membership tests.
stopwords_set = set(stopwords.words('english'))
# Rejects a word that contains, anywhere: a run of '^', '=', '~' or '_',
# a '-' ... '-' pair with only '|' characters between, or a '|' ... '|' pair
# with only '_' characters between.
exception_search_pattern = re.compile(r'\^+|=+|~+|-\|*-|_+|\|_*\|')
# Rejects a word that starts with a number (optional leading apostrophe and
# sign, integer or decimal) or with one or more '/\' pairs.
exception_match_pattern = re.compile(r'^[\']?[-+]?([0-9]+(\.[0-9]+)?|\.[0-9]+)|(/\\)+')
# A word must contain at least one word character to be kept.
should_search_pattern = re.compile(r'[\w]+')
# Word -> tag lookup built from the Brown corpus. dict() keeps a single tag
# per word (the last pair seen for that word).
brown_taged = dict(nltk.corpus.brown.tagged_words())

# Module-level counter kept for any other cell that reads it; `pretransform`
# never modifies it.
count = 0


def pretransform(word):
    """Return a cleaned version of `word`, or `''` when the word is filtered out.

    A word is dropped when it is a stop word or punctuation mark, when it hits
    one of the exception patterns, when it has no word character at all, or
    when the Brown lookup knows it under a tag other than 'NN'. A leading '-'
    or '|' is stripped before the Brown lookup.
    """
    if word in stopwords_set or word in punctuations:
        return ''
    if exception_search_pattern.search(word) or exception_match_pattern.match(word):
        return ''
    if not should_search_pattern.search(word):
        return ''

    # Drop one leading separator character left over from tokenization.
    if word[0] in ('-', '|'):
        word = word[1:]

    # Words the Brown lookup tags as anything but a plain noun are discarded.
    if word in brown_taged and brown_taged[word] != 'NN':
        return ''
    return word

class Doc:
    """One preprocessed document.

    Attributes set here:
        text: the raw text the document was built from.
        freq_dist: dict mapping each kept word to its number of occurrences.
        words: sorted array of the distinct kept words.
        num_words: number of distinct kept words.
        assigned_topics / last_assigned_topics: one integer topic id per
            distinct word (all zeros until a model assigns them).
        word_ids / wid_freq_dist: left as None here; `Dataset` fills them in
            once the global dictionary is known.
    """

    def __init__(self, text, count_vectorizer: CountVectorizer, tokenizer='sklearn'):
        """Tokenize `text`, filter the tokens and record their frequencies.

        Args:
            text: raw document text.
            count_vectorizer: a fitted `CountVectorizer`; only used when
                `tokenizer == 'sklearn'`.
            tokenizer: `'sklearn'` or `'nltk'`.

        Raises:
            ValueError: if `tokenizer` is neither `'sklearn'` nor `'nltk'`.
        """
        # 1. get tokens by `sklearn.feature_extraction.text.CountVectorizer` or by `nltk.word_tokenize()`
        if tokenizer == 'sklearn':
            # Same tokens `transform` would count (analyzer output restricted
            # to the fitted vocabulary), without building a dense
            # vocabulary-sized vector or the feature-name list per document.
            analyze = count_vectorizer.build_analyzer()
            vocabulary = count_vectorizer.vocabulary_
            # Sorted so that `freq_dist` keeps an alphabetical insertion order.
            words = sorted(token for token in analyze(text) if token in vocabulary)
        elif tokenizer == 'nltk':
            words = [word.lower() for sent in nltk.sent_tokenize(text) for word in nltk.word_tokenize(sent)]
        else:
            raise ValueError('Please use `sklearn` or `nltk` as tokenizer.')
        # 2. pretransform words with `pretransform` to filter useless words or turn words into better forms
        # 3. filter all `''` (null string)
        words = [word for word in map(pretransform, words) if word != '']
        # 4. set corresponding variables
        self.text = text
        self.freq_dist = dict(nltk.FreqDist(words))
        self.words = np.array(sorted(self.freq_dist))
        self.num_words = len(self.words)
        # Topic ids are used as array indices later on, so they must be integers.
        self.last_assigned_topics = np.zeros(self.words.shape, dtype=int)
        self.assigned_topics = np.zeros(self.words.shape, dtype=int)
        self.word_ids = None
        self.wid_freq_dist = None
        

class Dataset:
    """A corpus of `Doc` objects plus the dictionary shared by all of them."""

    def __init__(self, text_list, stopwords_set, num_topics):
        """Preprocess `text_list` and index every distinct word.

        Args:
            text_list: iterable of raw document texts.
            stopwords_set: stop words to remember on the dataset; when None the
                NLTK English list is used.
            num_topics: number of topics the count matrices are sized for.
        """
        self.text_list = text_list
        # Honour the caller's stop words instead of silently replacing them.
        if stopwords_set is None:
            stopwords_set = stopwords.words('english')
        self.stopwords_set = set(stopwords_set)
        self.doc_list = self.preprocess()
        self.dictionary = self.get_dictionary()
        self.num_topics = num_topics
        self.num_docs = len(self.doc_list)
        self.num_words = len(self.dictionary)
        # wid here means word id
        self.word_to_wid = {word: wid for wid, word in enumerate(self.dictionary)}
        self.set_word_ids_and_freq()

    def set_word_ids_and_freq(self):
        """Attach dictionary word ids (and their counter) to every document."""
        for doc in self.doc_list:
            doc.word_ids = np.array([self.word_to_wid[word] for word in doc.words], dtype=int)
            doc.wid_freq_dist = Counter(doc.word_ids)

    def get_dictionary(self):
        """Return the sorted array of all distinct words across the documents."""
        vocabulary = set()
        for doc in self.doc_list:
            vocabulary.update(doc.words)
        return np.array(sorted(vocabulary))

    def preprocess(self):
        """Build a `Doc` per text, dropping documents with no word left after filtering."""
        count_vectorizer = CountVectorizer()
        # Fit on this dataset's own texts rather than on a module-level global.
        count_vectorizer.fit(self.text_list)
        doc_list = []
        for text in self.text_list:
            doc = Doc(text, count_vectorizer)
            if doc.num_words > 0:
                doc_list.append(doc)
        return doc_list

    def count(self):
        """Count the current topic assignments.

        Each distinct word of a document contributes 1 (not its frequency).

        Returns:
            (n_word, n_topic): `n_word[k, v]` is how many documents assign
            word `v` to topic `k`; `n_topic[d, k]` is how many distinct words
            of document `d` are assigned to topic `k`.
        """
        n_word = np.zeros((self.num_topics, self.num_words))
        n_topic = np.zeros((self.num_docs, self.num_topics))
        for doc_id, doc in enumerate(self.doc_list):
            # Cast so that float-typed assignments can still be used as indices.
            topics = np.asarray(doc.assigned_topics).astype(int)
            # add.at accumulates correctly even if an index pair repeats.
            np.add.at(n_word, (topics, doc.word_ids), 1)
            n_topic[doc_id] = np.bincount(topics, minlength=self.num_topics)
        return n_word, n_topic
        
        
# Build the preprocessed corpus once; the later cells inspect it and train on it.
dataset = Dataset(raw_text_list, stopwords_set=stopwords_set, num_topics=num_topics)


In [4]:
# Quick sanity check of the preprocessing: the first document's distinct words
# and their frequencies, then the corpus size and the dictionary shape.
first_doc = dataset.doc_list[0]
print(first_doc.words)
print(first_doc.freq_dist)
print(dataset.num_docs)
print(dataset.dictionary.shape)

['austria' 'biased' 'blessing' 'commited' 'degree' 'europe' 'europeans'
 'existance' 'government' 'guilt' 'holocaust' 'incidences' 'israeli'
 'israels' 'jews' 'letter' 'nad' 'occured' 'power' 'reason' 'rediculous'
 'reputation' 'statement' 'story' 'subsidizing' 'world']
{'austria': 1, 'biased': 1, 'blessing': 1, 'commited': 1, 'degree': 1, 'europe': 1, 'europeans': 1, 'existance': 1, 'government': 1, 'guilt': 1, 'holocaust': 1, 'incidences': 1, 'israeli': 2, 'israels': 2, 'jews': 1, 'letter': 1, 'nad': 1, 'occured': 1, 'power': 1, 'reason': 1, 'rediculous': 1, 'reputation': 1, 'statement': 1, 'story': 1, 'subsidizing': 1, 'world': 1}
970
(11455,)


## LDA and Gibbs Sampling

### Gibbs Sampling in LDA

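Here is the corresponding conditional probability, where the counts $n_{k'v'}$ and $n_{d'k'}$ exclude the word $i$ that is currently being resampled:
$$P(z_i=k'|z_{-i},\alpha,\eta)\propto\frac{\eta_{v'}+n_{k'v'}}{\sum_v(\eta_v+n_{k'v})}\cdot \frac{\alpha_{k'}+n_{d'k'}}{\sum_k(\alpha_k+n_{d'k})}$$

By repeating the sampling sweep below $T$ times, we obtain all topic assignments $\textbf{z}_i$: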

```python
for i in range(T):
    do gibbs sample
```

### Get $\theta_d$ and $\beta_k$

$\theta_d$ and $\beta_k$ can then be estimated from the counts as:
$$\theta_{dk}=\frac{n_{dk}+\alpha_k}{\sum_k(n_{dk}+\alpha_k)},\qquad\beta_{kv}=\frac{n_{kv}+\eta_v}{\sum_v(n_{kv}+\eta_v)}$$

In [5]:
import numpy as np
from tqdm import tqdm


class LDA:
    """LDA topic model fitted with collapsed Gibbs sampling.

    The model works on the dataset's documents, which carry one topic
    assignment per entry of `doc.word_ids`.
    """

    def __init__(self, num_topics, num_keywords, alpha, eta):
        """Store the hyper-parameters.

        Args:
            num_topics: number of topics to sample from.
            num_keywords: number of keywords per topic (stored only).
            alpha: per-topic prior pseudo-counts.
            eta: per-word prior pseudo-counts (a scalar is also accepted).
        """
        self.num_topics = num_topics
        self.num_keywords = num_keywords
        self.alpha = np.array(alpha)
        self.eta = np.array(eta)

    @staticmethod
    def _sample_index(weights):
        """Draw an index with probability proportional to `weights` (unnormalised)."""
        cdf = np.cumsum(weights)
        draw = np.random.random() * cdf[-1]
        # side='right' skips zero-weight entries; the clamp guards the degenerate all-zero case.
        return min(int(np.searchsorted(cdf, draw, side='right')), len(cdf) - 1)

    def train(self, dataset: "Dataset", max_epoches=1):
        """Run `max_epoches` Gibbs sweeps over every word of every document.

        Topic assignments are written back to `doc.assigned_topics`, and
        `doc.last_assigned_topics` holds the assignments of the previous sweep.
        After each sweep one line is printed: the epoch id, the fraction of
        assignments that changed, and the L1 change of theta and of beta.
        """
        # Assign topic to each word randomly. `high` is exclusive, so it must be
        # `num_topics` for the last topic to be reachable.
        for doc in dataset.doc_list:
            doc.assigned_topics = np.random.randint(low=0, high=self.num_topics, size=doc.assigned_topics.shape)

        n_word, n_topic = dataset.count()
        # Number of words assigned to each topic over the whole vocabulary.
        topic_totals = n_word.sum(axis=1)
        eta = np.broadcast_to(self.eta, (dataset.num_words,))
        eta_total = eta.sum()

        last_theta = np.zeros((dataset.num_docs, dataset.num_topics))
        last_beta = np.zeros((dataset.num_topics, dataset.num_words))
        print('=====================================================================================')
        for epoch_id in range(max_epoches):
            for doc_id, doc in enumerate(dataset.doc_list):
                doc_topic_counts = n_topic[doc_id]  # view: updates go straight into n_topic
                for doc_wid, wid in enumerate(doc.word_ids):
                    # Take the current word out of the counts (the "-i" in z_{-i}).
                    old_topic = doc.assigned_topics[doc_wid]
                    n_word[old_topic, wid] -= 1
                    doc_topic_counts[old_topic] -= 1
                    topic_totals[old_topic] -= 1

                    # Conditional over topics. The document-side denominator is
                    # the same for every topic, so it cancels when sampling.
                    pz = (eta[wid] + n_word[:, wid]) / (eta_total + topic_totals) \
                        * (self.alpha + doc_topic_counts)
                    new_topic = self._sample_index(pz)

                    # Put the word back under its new topic so that the next
                    # draw already sees this update.
                    doc.assigned_topics[doc_wid] = new_topic
                    n_word[new_topic, wid] += 1
                    doc_topic_counts[new_topic] += 1
                    topic_totals[new_topic] += 1

            # calculate theta and beta, to see whether they are convergent
            theta = self.get_theta(dataset, n_topic)
            beta = self.get_beta(dataset, n_word)

            changed = 0
            total = 0
            for doc in dataset.doc_list:
                changed += np.sum((doc.assigned_topics - doc.last_assigned_topics) != 0)
                total += doc.assigned_topics.shape[0]
                doc.last_assigned_topics = doc.assigned_topics.copy()
            error = changed / total if total else 0.0

            print(f'epoch {epoch_id}', error, np.abs(last_theta - theta).sum(), np.abs(last_beta - beta).sum())
            last_theta = theta
            last_beta = beta

    def get_theta(self, dataset, n_topic):
        """Return the per-document topic distribution; each row sums to 1.

        `dataset` is kept in the signature for compatibility; only `n_topic`
        (documents x topics counts) is needed.
        """
        smoothed = np.asarray(n_topic) + self.alpha
        return smoothed / smoothed.sum(axis=1, keepdims=True)

    def get_beta(self, dataset, n_word):
        """Return the per-topic word distribution; each row sums to 1.

        `dataset` is kept in the signature for compatibility; only `n_word`
        (topics x words counts) is needed.
        """
        smoothed = np.asarray(n_word) + self.eta
        return smoothed / smoothed.sum(axis=1, keepdims=True)

    def get_topics(self, dataset, topk=10):
        """Return, for each topic, its `topk` most probable dictionary words."""
        n_word, _ = dataset.count()
        beta = self.get_beta(dataset, n_word)
        top_wid = np.argsort(-beta)[:, :topk]
        return dataset.dictionary[top_wid]
    
from datetime import datetime

np.set_printoptions(linewidth=120)
# eta needs one entry per dictionary word, so it can only be built now.
eta = np.ones(dataset.num_words)
lda = LDA(num_topics=num_topics, num_keywords=num_keywords, alpha=alpha, eta=eta)
# Time the whole training run.
start = datetime.now()
lda.train(dataset, max_epoches=1000)
print(datetime.now() - start)
# Show the configured number of keywords per topic rather than get_topics' default of 10.
print(lda.get_topics(dataset, topk=num_keywords))

epoch 0 0.9517414601473543 970.0 19.999999999999996
epoch 1 0.882786336235767 325.2585051199026 2.8584646290876643
epoch 2 0.8782652377762894 328.49268231225307 2.8348924645752906
epoch 3 0.8749162759544541 322.23937240852734 2.8298618186197064
epoch 4 0.8798727394507703 322.86276699324105 2.844087969684036
epoch 5 0.87357669122572 323.52736159271797 2.8290981567617597
epoch 6 0.8767916945746819 325.61418716205395 2.8334027202826233
epoch 7 0.8725050234427327 324.5473170140963 2.8159619801274474
epoch 8 0.872906898861353 326.9642061281044 2.832406756978396
epoch 9 0.8790689886135298 322.6309011650695 2.8263298708158358
epoch 10 0.875753516409913 326.40144605067326 2.8378613717245735
epoch 11 0.8752176825184192 327.25252094245474 2.8427030043318715
epoch 12 0.871868720696584 322.2806366946076 2.830711749596029
epoch 13 0.8764567983924983 328.97194910658385 2.816756689041937
epoch 14 0.8721701272605492 324.687502057449 2.830820316174429
epoch 15 0.8730073677160081 323.44869627472536 2.85

KeyboardInterrupt: 