In [1]:
import nltk

nltk.download('reuters')

[nltk_data] Downloading package reuters to /Users/felix/nltk_data...
[nltk_data]   Package reuters is already up-to-date!


True

In [None]:
! unzip -q /usr/share/nltk_data/corpora/reuters.zip -d /usr/share/nltk_data/corpora/

In [245]:
import os
import nltk
import numpy as np
import torch
from tqdm import tqdm
import string

from nltk.corpus import stopwords
from collections import Counter

nltk.download("stopwords")
nltk.download("punkt")
nltk.download("punkt_tab")

def lowercase_tokenizer(text):
    return [t.lower() for t in nltk.word_tokenize(text)]


def get_stopwords():
    stop = set(stopwords.words("english"))
    stop = stop.union(set(string.punctuation))
    stop = stop.union({'``', "''", "'s", "dlrs", "pct", "cts", 'lt', 'mln'})
    return stop


def create_data_set(dir="training", min_freq=10, corpus_limit=200000, number_of_topics = 15):

    path = os.path.join(os.getcwd(), "reuters", dir)
    files = os.listdir(path)
    files = sorted(files, key=lambda x: int(x))

    docs = []
    word_counter = Counter()
    total_words = 0

    for file in files:

        
        file = os.path.join(path, file)
        with open(file, 'r') as f:
            raw_file = f.readlines()

            file_words = []

            for raw_line in raw_file:
                new_words = lowercase_tokenizer(raw_line)
                file_words.extend(new_words)

            docs.append(file_words)
            word_counter.update(file_words)
            total_words += len(file_words)

        if total_words > corpus_limit:
            break

    uncommon_words = [item for item, count in word_counter.items() if count <= min_freq]

   
    stop_words =  get_stopwords()

    words_to_remove = stop_words.union(set(uncommon_words))

    topic_map = {}

    for i, doc in enumerate(docs):
        new_doc = []
        j = 0
        for word in doc:
            if word not in words_to_remove:
                new_doc.append(word)    
                topic_map[f"{i},{j}"] = np.random.randint(0, number_of_topics)
                j += 1
        docs[i] = new_doc

    return docs, topic_map, total_words




[nltk_data] Downloading package stopwords to /Users/felix/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to /Users/felix/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to /Users/felix/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


In [246]:
def create_word_mappings(docs):
    words = set()
    for doc in docs:
        words.update(doc)
    str_to_int = {}
    int_to_str = {}
    for i, word in enumerate(words):
        str_to_int[word] = i
        int_to_str[i] = word
    return str_to_int, int_to_str

def get_n_d_k(docs, topic_map, k):
    n_d_k = np.zeros((len(docs),k))
    for i, doc in enumerate(docs):
        for j in range(len(doc)):
            topic = topic_map[f"{i},{j}"]
            n_d_k[i, topic] += 1
    return n_d_k

def get_m_k_v(docs, topic_map, k, str_to_int):
    m_k_v = np.zeros((k, len(str_to_int)))
    for i, doc in enumerate(docs):
        for j, word in enumerate(doc):
            topic = topic_map[f"{i},{j}"]
            m_k_v[topic, str_to_int[word]] += 1
            
    return m_k_v

def n_dj_k(topic_map, n_d_k, d, j, k):
    
    topic = topic_map[f"{d},{j}"]

    update = 0
    if topic == k:
        update = -1

    count = n_d_k[d, k]

    if count == 0 and update == -1:
        print(topic)
        print(k)
        print(n_d_k[d, k])

    return  count + update

def m_dj_w(topic_map, m_k_v, d, j, k, w_index):
    
    topic = topic_map[f"{d},{j}"]

    update = 0
    if topic == k:
        update = -1

    count = m_k_v[k, w_index]

    return  count + update

def m_dj(topic_map, m_k, d, j, k):
    
    topic = topic_map[f"{d},{j}"]

    update = 0
    if topic == k:
        update = -1

    count = m_k[k]

    return  count + update




In [247]:
nr_iterations = 2000
k = 15
alpha = 0.1
beta = 0.1

docs, topic_map, total_words = create_data_set(number_of_topics=k)
d = len(docs)

str_to_int, int_to_str = create_word_mappings(docs)

n_d_k = get_n_d_k(docs, topic_map, k)
m_k_v = get_m_k_v(docs, topic_map, k, str_to_int)
m_k = m_k_v.sum(axis=1)



In [249]:
# Training loop
def train(nr_iterations, nr_words, m_k, n_d_k, m_k_v):
    for _ in tqdm(range(nr_iterations)):
        for _ in tqdm(range(nr_words)):
            r_d = 0    
            doc_length = 0
            while doc_length == 0:
                r_d = np.random.randint(0, d)
                doc_length = len(docs[r_d])

            r_j = np.random.randint(0, len(docs[r_d]))

            q = np.zeros(k)
            p = np.zeros(k)
            w_index = str_to_int[docs[r_d][r_j]]
            vocab_len = len(str_to_int)

            for k_i in range(k):
                temp_n_dj_k = n_dj_k(topic_map, n_d_k, r_d, r_j, k_i)
                temp_m_dj_w = m_dj_w(topic_map, m_k_v, r_d, r_j, k_i, w_index)
                temp_m_dj = m_dj(topic_map, m_k, r_d, r_j, k_i)
                
                q[k_i] = (alpha+temp_n_dj_k)*(beta+temp_m_dj_w) / (vocab_len*beta + temp_m_dj)

            q_sum = q.sum()
            for k_i in range(k):
                p[k_i] = q[k_i] / q_sum

            
            dist = torch.distributions.Categorical(torch.tensor(p))
            new_z = dist.sample().item()

            old_topic = topic_map[f"{r_d},{r_j}"]
            topic_map[f"{r_d},{r_j}"] = new_z

            n_d_k[r_d, old_topic] -= 1
            m_k_v[old_topic, w_index] -= 1

            n_d_k[r_d, new_z] += 1
            m_k_v[new_z, w_index] += 1

            m_k = m_k_v.sum(axis=1)
            

    return m_k, topic_map

In [250]:
nr_iterations = 100
m_k, topic_map = train(nr_iterations, total_words, m_k, n_d_k, m_k_v)

100%|██████████| 200172/200172 [00:12<00:00, 15513.01it/s]
100%|██████████| 200172/200172 [00:12<00:00, 15525.71it/s]
100%|██████████| 200172/200172 [00:12<00:00, 15517.21it/s]
100%|██████████| 200172/200172 [00:12<00:00, 15536.50it/s]
100%|██████████| 200172/200172 [00:12<00:00, 15575.01it/s]
100%|██████████| 200172/200172 [00:12<00:00, 15561.84it/s]
100%|██████████| 200172/200172 [00:13<00:00, 15205.97it/s]
100%|██████████| 200172/200172 [00:13<00:00, 15330.77it/s]
100%|██████████| 200172/200172 [00:12<00:00, 15490.37it/s]
100%|██████████| 200172/200172 [00:13<00:00, 15269.22it/s]
100%|██████████| 200172/200172 [00:12<00:00, 15488.84it/s]
100%|██████████| 200172/200172 [00:12<00:00, 15435.97it/s]
100%|██████████| 200172/200172 [00:12<00:00, 15496.77it/s]
100%|██████████| 200172/200172 [00:12<00:00, 15404.08it/s]
100%|██████████| 200172/200172 [00:13<00:00, 15330.00it/s]
 61%|██████    | 121713/200172 [00:08<00:05, 15053.18it/s]
 15%|█▌        | 15/100 [03:22<19:07, 13.50s/it]


KeyboardInterrupt: 

In [None]:

def get_related_document_words(docs, topic_map):

    related_words = []
    for i, doc in enumerate(docs):
        topic_splits = {}
        for j, word in enumerate(doc):
            topic = topic_map[f"{i},{j}"]
            topic_splits.setdefault(topic, []).append(word)
        related_words.append(topic_splits)

    return related_words

related_words = get_related_document_words(docs, topic_map)

def get_top_words_per_topic(related_words, n_top_words=20):
    combined_doc_topics = {}
    for doc in related_words:
        for key, array in doc.items():
            combined_doc_topics.setdefault(key, []).extend(array)

    
    most_common_words = {}
    for topic, words in combined_doc_topics.items():
        cnt = Counter()
        cnt.update(words)
        most_common_words[topic] = [(token, count) for token, count in cnt.most_common(n_top_words)]

    return most_common_words

common_words_topic = get_top_words_per_topic(related_words)

[('said', 258), ('trade', 121), ('u.s.', 66), ('would', 62), ('group', 56), ('year', 55), ('united', 54), ('market', 49), ('states', 48), ('official', 47), ('agreement', 45), ('officials', 42), ('talks', 41), ('exports', 40), ('reserves', 38), ('prices', 36), ('brazil', 35), ('francs', 35), ('told', 34), ('foreign', 33)]
[('vs', 154), ('year', 149), ('1986', 140), ('share', 136), ('net', 130), ('oper', 123), ('profit', 99), ('quarter', 98), ('1985', 83), ('loss', 82), ('said', 67), ('billion', 63), ('earnings', 61), ('per', 60), ('gain', 60), ('sales', 52), ('excludes', 52), ('operating', 48), ('fourth', 44), ('results', 42)]


In [289]:

def occurance_counts(docs):
    doc_sets = [set(doc) for doc in docs]


    unique_words = set()
    for doc_set in doc_sets:
        unique_words.update(doc_set)

    cnt = Counter()
    for doc in doc_sets:
        cnt.update(doc)


    co_occurance_count = Counter()
    for doc_words in doc_sets:
        for w1 in doc_words:
            for w2 in doc_words:
                if w1 != w2:
                    co_occurance_count[f"{w1},{w2}"] += 1
    
    return dict(co_occurance_count), dict(cnt), unique_words


co_occurance, occurance, unique = occurance_counts(docs)


def umass(co_occurance, occurance, common_words_topic):
    score = 0
    for m in range(1, len(common_words_topic)):
        w1, _ = common_words_topic[m]
        for l in range(0, m-1):
            w2, _ = common_words_topic[l]
            co_occur = co_occurance.get(f"{w1},{w2}", 0) + 1
            d_count = occurance.get(f"{w2}", 0)
            score += np.log(np.divide(co_occur, d_count))
    return score

In [288]:
scores = np.zeros(k)
for t in range(0, k):
    scores[t] = umass(co_occurance, occurance, common_words_topic[t])
    print(f"The topic umass score for topic {t} is: ", scores[t])

mean_score = scores.mean()
print(mean_score)

The topic umass score for topic 0 is:  -353.8812591629674
The topic umass score for topic 1 is:  -296.86238630568556
The topic umass score for topic 2 is:  -229.36632482638413
The topic umass score for topic 3 is:  -302.22042756174267
The topic umass score for topic 4 is:  -341.70019076495794
The topic umass score for topic 5 is:  -320.7918371021073
The topic umass score for topic 6 is:  -304.2129490152345
The topic umass score for topic 7 is:  -411.4289093326071
The topic umass score for topic 8 is:  -318.9911538696502
The topic umass score for topic 9 is:  -310.4094581609175
The topic umass score for topic 10 is:  -385.4344118186441
The topic umass score for topic 11 is:  -371.51263853035
The topic umass score for topic 12 is:  -349.7841910794869
The topic umass score for topic 13 is:  -349.2477080470174
The topic umass score for topic 14 is:  -416.23438040609324
-337.47188173225635
