In [None]:
import nltk
nltk.download('stopwords')
from nltk.corpus import stopwords
import numpy as np
from collections import defaultdict
import torch

import spacy


In [None]:
# Fix PyTorch's global RNG so the random topic initialisation and the Gibbs
# draws below are reproducible across Restart & Run All.
torch.random.manual_seed(2024)

In [None]:
# Toy corpus: sentences mentioning football and sentences mentioning
# countries / economies (the last one mentions both).
documents = [
    'Cristiano Ronaldo and Lionel Messi are both great player of football',
    'People also admire Neymar and Ramos for their football skills',
    'USA and China both are powerful countries',
    'China is building a strong economy',
    'India is also emerging as one of most developing country by promoting football at global scale',
]

In [None]:
# Show the raw corpus, one sentence per line.
for sentence in documents:
    print(sentence)

In [None]:
# Whitespace tokenisation: each sentence becomes a list of word strings.
documents = [sentence.split() for sentence in documents]

In [None]:
# NLTK's English stop-word list, extended with two filler words that would
# otherwise show up as topic words in this small corpus.
stop_words = set(stopwords.words('english')) | {'also', 'one'}
# Keep only the tokens that are not stop words (comparison is case-sensitive).
documents = [
    [token for token in tokens if token not in stop_words]
    for tokens in documents
]



In [None]:
# spaCy's small English pipeline; the model must be installed beforehand,
# e.g. `python -m spacy download en_core_web_sm`.
rules = spacy.load('en_core_web_sm')

# Lemmatise every token. Each word is passed to spaCy on its own (the corpus is
# already tokenised and stop-word filtered), and the lemma of the first spaCy
# token for that word is kept, so every document keeps its length.
# NOTE(review): words are tagged without sentence context, so some inflected
# forms may not reduce to their base lemma — confirm this is acceptable.
lemmatized_docs = [[rules(word)[0].lemma_ for word in doc] for doc in documents]

# Rebind once, after all documents are processed (previously this assignment
# sat inside the loop and rebound the name being iterated over).
documents = lemmatized_docs


In [None]:
# Inspect the cleaned, lemmatised token lists.
for tokens in documents:
    print(tokens)

In [None]:
def assign_word_id(docs):
    """Build the vocabulary index for a tokenised corpus.

    Args:
        docs: iterable of documents, each an iterable of word strings.

    Returns:
        dict mapping every distinct word to a consecutive integer id, with ids
        following the words' sorted order and starting at 0.
    """
    vocabulary = sorted({token for doc in docs for token in doc})
    return dict(zip(vocabulary, range(len(vocabulary))))

In [None]:
# Global vocabulary: word -> column index used by the topic-word count tables.
word2id = assign_word_id(docs=documents)

In [None]:
# Flatten the corpus into one token list (document order preserved).
ls_of_words = [token for tokens in documents for token in tokens]

In [None]:
# Show every token in the corpus as a single list.
print(str(ls_of_words))

In [None]:
# Token frequencies; iterating the sorted tokens makes the keys appear in
# alphabetical order when the dict is printed.
frequency_dict = defaultdict(int)
for token in sorted(ls_of_words):
    frequency_dict[token] = frequency_dict.get(token, 0) + 1


In [None]:
# Show the word -> count mapping.
print(str(frequency_dict))

In [None]:
def topic_word_calculate(docs, z, num_of_topics, vocab=None):
    """Count how often each vocabulary word is assigned to each topic.

    Args:
        docs: list of documents, each a list of word strings.
        z: 2-D integer topic assignments; ``z[d][n]`` is the topic of the n-th
            word of document d. Entries beyond ``len(docs[d])`` are ignored.
        num_of_topics: number of topics K (rows of the result).
        vocab: mapping word -> column id. Defaults to the notebook-level
            ``word2id`` so existing three-argument calls keep working.

    Returns:
        Float tensor of shape (K, len(vocab)) holding assignment counts.

    Raises:
        KeyError: if a word in ``docs`` is missing from ``vocab``. A ``.get``
            lookup would return None, and indexing a row with None increments
            every column of that topic, silently corrupting the counts.
    """
    if vocab is None:
        vocab = word2id
    topic_word = torch.zeros((num_of_topics, len(vocab)))
    for d, doc in enumerate(docs):
        for n, word in enumerate(doc):
            topic_word[int(z[d][n]), vocab[word]] += 1
    return topic_word

In [None]:
def collapsed_Gibbs(docs, num_of_topics=4, passes=5, alpha=0.1, beta=0.01, vocab=None):
    """Fit an LDA topic model to ``docs`` with collapsed Gibbs sampling.

    Every token starts with a uniformly random topic. Each pass then resamples
    every token's topic from its full conditional

        p(z = k | rest) ∝ (n_kw + beta) * (n_dk + alpha) / (n_k + beta * V)

    where all counts exclude the token currently being resampled.

    Args:
        docs: list of documents, each a list of word strings.
        num_of_topics: number of topics K.
        passes: number of full sweeps over the corpus.
        alpha: symmetric Dirichlet prior on per-document topic proportions.
        beta: symmetric Dirichlet prior on per-topic word distributions.
        vocab: mapping word -> column id in ``[0, V)``. Defaults to the
            notebook-level ``word2id``.

    Returns:
        (document_topic, topic_word, words_per_topic):
          document_topic  -- int64 tensor (D, K): tokens of doc d in topic k
          topic_word      -- float tensor (K, V): times word v is in topic k
          words_per_topic -- int64 tensor (K,): total tokens in each topic

    Raises:
        KeyError: if a word in ``docs`` is missing from ``vocab``.
    """
    if vocab is None:
        vocab = word2id
    vocab_size = len(vocab)
    max_len = max((len(doc) for doc in docs), default=0)  # empty corpus -> 0

    # Random initial topic per token; columns past len(doc) are never read.
    z = torch.randint(0, num_of_topics, (len(docs), max_len))

    # Build counts directly with a fixed K columns. one_hot(z) inferred the
    # class count from z.max() and could return fewer than K columns.
    document_topic, topic_word = _initial_counts(docs, z, num_of_topics, vocab)
    words_per_topic = document_topic.sum(dim=0)
    smoothing = beta * vocab_size  # loop-invariant part of the denominator

    for _ in range(passes):
        for d, doc in enumerate(docs):
            for n, word in enumerate(doc):
                word_id = vocab[word]
                old_topic = int(z[d, n])

                # Remove the current token from every count before resampling.
                document_topic[d, old_topic] -= 1
                topic_word[old_topic, word_id] -= 1
                words_per_topic[old_topic] -= 1

                # Full conditional for all K topics at once (unnormalised;
                # torch.multinomial normalises the weights itself).
                probs = (
                    (topic_word[:, word_id] + beta)
                    * (document_topic[d] + alpha)
                    / (words_per_topic + smoothing)
                )
                new_topic = torch.multinomial(probs, 1).item()

                # Add the token back under its newly sampled topic.
                z[d, n] = new_topic
                document_topic[d, new_topic] += 1
                topic_word[new_topic, word_id] += 1
                words_per_topic[new_topic] += 1

    return document_topic, topic_word, words_per_topic


def _initial_counts(docs, z, num_of_topics, vocab):
    """Build the (D, K) int64 document-topic and (K, V) float topic-word count tables from z."""
    document_topic = torch.zeros((len(docs), num_of_topics), dtype=torch.long)
    topic_word = torch.zeros((num_of_topics, len(vocab)))
    for d, doc in enumerate(docs):
        for n, word in enumerate(doc):
            topic = int(z[d, n])
            document_topic[d, topic] += 1
            topic_word[topic, vocab[word]] += 1
    return document_topic, topic_word

In [None]:
# Fit a 2-topic model with 1000 Gibbs sweeps over the corpus.
document_topic, topic_word, words_per_topic = collapsed_Gibbs(
    documents, num_of_topics=2, passes=1000
)

In [None]:
# Final count tables: per-document topic counts, then per-topic word counts.
print(document_topic, topic_word, sep=' ')