In [1]:
import nltk

nltk.download('reuters')

[nltk_data] Downloading package reuters to /Users/felix/nltk_data...
[nltk_data]   Package reuters is already up-to-date!


True

In [None]:
! unzip -q /usr/share/nltk_data/corpora/reuters.zip -d /usr/share/nltk_data/corpora/

In [131]:
import os
import nltk
import numpy as np
import random
import torch
from tqdm import tqdm

from nltk.corpus import stopwords
from collections import Counter

nltk.download("stopwords")
nltk.download("punkt")
nltk.download("punkt_tab")

def lowercase_tokenizer(text):
    return [t.lower() for t in nltk.word_tokenize(text)]


def create_data_set(dir="training", min_freq=10, corpus_limit=200000, number_of_topics = 15):

    path = os.path.join(os.getcwd(), "reuters", dir)
    files = os.listdir(path)
    files = sorted(files, key=lambda x: int(x))

    docs = []
    word_counter = Counter()
    total_words = 0

    for file in files:

        
        file = os.path.join(path, file)
        with open(file, 'r') as f:
            raw_file = f.readlines()

            file_words = []

            for raw_line in raw_file:
                new_words = lowercase_tokenizer(raw_line)
                file_words.extend(new_words)

            docs.append(file_words)
            word_counter.update(file_words)
            total_words += len(file_words)

        if total_words > corpus_limit:
            break

    uncommon_words = [item for item, count in word_counter.items() if count <= min_freq]

    stop_words = set(stopwords.words("english"))

    words_to_remove = stop_words.union(set(uncommon_words))

    topic_map = {}

    for i, doc in enumerate(docs):
        new_doc = []
        j = 0
        for word in doc:
            if word not in words_to_remove:
                new_doc.append(word)    
                topic_map[f"{i},{j}"] = np.random.randint(0, number_of_topics)
                j += 1
        docs[i] = new_doc

    return docs, topic_map




[nltk_data] Downloading package stopwords to /Users/felix/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to /Users/felix/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to /Users/felix/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


In [132]:
def create_word_mappings(docs):
    words = set()
    for doc in docs:
        words.update(doc)
    str_to_int = {}
    int_to_str = {}
    for i, word in enumerate(words):
        str_to_int[word] = i
        int_to_str[i] = word
    return str_to_int, int_to_str

def get_n_d_k(docs, topic_map, k):
    n_d_k = np.zeros((len(docs),k))
    for i, doc in enumerate(docs):
        for j in range(len(doc)):
            topic = topic_map[f"{i},{j}"]
            n_d_k[i, topic] += 1
    return n_d_k

def get_m_k_v(docs, topic_map, k, str_to_int):
    m_k_v = np.zeros((k, len(str_to_int)))
    for i, doc in enumerate(docs):
        for j, word in enumerate(doc):
            topic = topic_map[f"{i},{j}"]
            m_k_v[topic, str_to_int[word]] += 1
            
    return m_k_v

def n_dj_k(topic_map, n_d_k, d, j, k):
    
    topic = topic_map[f"{d},{j}"]

    update = 0
    if topic == k:
        update = -1

    count = n_d_k[d, k]

    if count == 0 and update == -1:
        print(topic)
        print(k)
        print(n_d_k[d, k])

    return  count + update

def m_dj_w(topic_map, m_k_v, d, j, k, w_index):
    
    topic = topic_map[f"{d},{j}"]

    update = 0
    if topic == k:
        update = -1

    count = m_k_v[k, w_index]

    return  count + update

def m_dj(topic_map, m_k, d, j, k):
    
    topic = topic_map[f"{d},{j}"]

    update = 0
    if topic == k:
        update = -1

    count = m_k[k]

    return  count + update




In [133]:
nr_iterations = 2000
k = 15
alpha = 0.1
beta = 0.1

docs, topic_map = create_data_set(number_of_topics=k)
d = len(docs)

str_to_int, int_to_str = create_word_mappings(docs)

n_d_k = get_n_d_k(docs, topic_map, k)
m_k_v = get_m_k_v(docs, topic_map, k, str_to_int)
m_k = m_k_v.sum(axis=1)



In [134]:
# Training loop
def train(nr_iterations, m_k, n_d_k, m_k_v):
    for iteration in tqdm(range(nr_iterations)):
        r_d = 0    
        doc_length = 0
        while doc_length == 0:
            r_d = np.random.randint(0, d)
            doc_length = len(docs[r_d])

        r_j = np.random.randint(0, len(docs[r_d]))

        q = np.zeros(k)
        p = np.zeros(k)
        w_index = str_to_int[docs[r_d][r_j]]
        vocab_len = len(str_to_int)

        for k_i in range(k):
            temp_n_dj_k = n_dj_k(topic_map, n_d_k, r_d, r_j, k_i)
            temp_m_dj_w = m_dj_w(topic_map, m_k_v, r_d, r_j, k_i, w_index)
            temp_m_dj = m_dj(topic_map, m_k, r_d, r_j, k_i)
            
            q[k_i] = (alpha+temp_n_dj_k)*(beta+temp_m_dj_w) / (vocab_len*beta + temp_m_dj)

        q_sum = q.sum()
        for k_i in range(k):
            p[k_i] = q[k_i] / q_sum

        
        dist = torch.distributions.Categorical(torch.tensor(p))
        new_z = dist.sample().item()

        old_topic = topic_map[f"{r_d},{r_j}"]
        topic_map[f"{r_d},{r_j}"] = new_z

        n_d_k[r_d, old_topic] -= 1
        m_k_v[old_topic, w_index] -= 1

        n_d_k[r_d, new_z] += 1
        m_k_v[new_z, w_index] += 1

        m_k = m_k_v.sum(axis=1)
        

    return m_k, topic_map

In [94]:
nr_iterations = 100 * 10000
m_k, topic_map = train(nr_iterations, m_k, n_d_k, m_k_v)

100%|██████████| 1000000/1000000 [01:05<00:00, 15300.39it/s]


In [None]:

def get_related_document_words(docs, topic_map):

    related_words = []
    for i, doc in enumerate(docs):
        topic_splits = {}
        for j, word in enumerate(doc):
            topic = topic_map[f"{i},{j}"]
            current_topic_words = topic_splits.get(topic, [])
            current_topic_words.append(word)
            topic_splits[topic] = current_topic_words
        related_words.append(topic_splits)

    return related_words

related_words = get_related_document_words(docs, topic_map)

print(related_words[0])


{9: ['cocoa', 'drought', 'come', 'around', 'export', 'march', 'dlrs', 'made', ',', ',', 'york', 'dec', 'crop', 'expected'], 0: ['review', ',', '60', 'still', 'cocoa', 'bags', 'lower', 'rose', 'dlrs', 'dlrs', 'may', 'dlrs', 'new', ',', 'areas', ',', '1.25', 'new', ',', 'bags', '1987/88', 'trade', '.'], 3: ['continued', 'since', '.', '.', 'making', 'included', 'smith', 'still', 'mln', 'much', 'shipment', 'york', 'convertible', 'dlrs', 'new', 'brazilian'], 1: ['week', ',', 'review', 'february', 'last', 'good', 'held', '.', 'dlrs', 'named', '.', 'light', 'tonne', 'dlrs', 'dec', 'u.s.', 'new', '.', 'york', 'said', 'february'], 13: ['cocoa', ',', 'smith', 'week', 'crop', 'sales', 'offer', 'per', '35', '45', 'per', 'new', 'dlrs', 'times', ',', '.', 'currency', 'ends'], 11: ['early', '.', 'much', 'total', 'spot', 'limited', 'crop', 'open', 'july', '.', 'dlrs', 'selling', 'dlrs', 'times', 'total'], 14: ['january', 'figures', '.', 'estimates', '.', 'new', '.', '.', 'new', 'sales', 'limited', '1.