In [429]:
import topmost.utils
import topmost.utils._utils
from topmost import eva
import numpy as np

class Beta_env():
    def __init__(self, dataset, max_steps, num_topics, embed_size, reference_corpus, num_top_words = 5, random = False):

        self.num_topics = num_topics
        self.embed_size = embed_size
        self.random_init = random
        self.step_counter = 0
        self.max_steps = max_steps
        self.num_top_words = num_top_words
        self.reference_corpus = reference_corpus
        self.dataset = dataset
        self.vocab = dataset.vocab
        self.vocab_size = len(self.vocab)
        # self.vocab_size = 100

        self.str_to_embeds = {
            dataset.vocab[i]:dataset.pretrained_WE[i]
            for i in range(0, self.vocab_size)
        }
        
        self.beta, self.topic_embeds = self.get_starting_state()


    def get_starting_state(self):
        if self.random_init:
            beta = np.random.normal(loc = 0, scale=1, size = (self.num_topics, self.vocab_size))
            topic_embeds = np.random.normal(loc = 0, scale = 1, size = (self.num_topics, self.embed_size))
        else:
            beta, topic_embeds = None, None
        return beta, topic_embeds

    def reset(self):
        self.step_counter = 0
        self.beta, self.topic_embeds = self.get_starting_state()
        return np.concatenate((self.beta, self.topic_embeds), axis = 1)
    
    def cosine_similarity(self, a, b):
        a = np.asarray(a)
        b = np.asarray(b)
        if np.linalg.norm(a) == 0 or np.linalg.norm(b) == 0:
            return np.asarray([0])
        return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    
    def compute_cosine_similarity(self, list_txt):
        result = 0
        for i in range(self.num_topics):
            for word in list_txt[i].split():
                we = self.str_to_embeds[word]
                result += self.cosine_similarity(we, self.topic_embeds[i, :]).item()
        return result
        
    def calculate_reward(self):
        top_words = self.get_top_words(False)
        total_TD =  eva.topic_diversity._diversity(top_words)

        total_cosine_similarity = self.compute_cosine_similarity(top_words)
        return (total_TD + total_cosine_similarity / (self.num_topics * self.num_top_words)) / 2
        

    def get_top_words(self, verbose = True):
        top_word_list = topmost.utils._utils.get_top_words(self.beta, self.vocab, num_top_words=self.num_top_words, verbose=verbose)
        return top_word_list

    def step(self, action):
        self.beta += action
        new_state = np.concatenate((self.beta, self.topic_embeds), axis = 1)
        reward = self.calculate_reward()
        self.step_counter += 1
        done = False
        if self.step_counter == self.max_steps:
            done = True
            self.step_counter = 0
        
        self.state = new_state
        return self.state, reward, done

In [430]:
from topmost.data import download_dataset

# device = "cuda" if torch.cuda.is_available() else "cpu"
dataset_dir = "./datasets/20NG"
download_dataset('20NG', cache_path='./datasets')

dataset = topmost.BasicDataset(dataset_dir)

100%|██████████| 11.9M/11.9M [00:00<00:00, 18.2MB/s]


train_size:  11314
test_size:  7532
vocab_size:  5000
average length: 110.543


In [431]:
env = Beta_env(
    dataset = dataset,
    max_steps=3,
    num_topics=5,
    embed_size=200,
    reference_corpus=dataset.train_texts,
    random = True
)
env.calculate_reward()

0.497744489060825

In [432]:
env.get_top_words()

Topic 0: tiff tcp returned either releases
Topic 1: collapse particularly coast stops activities
Topic 2: fathers christopher members cutting possibilities
Topic 3: friends owns terrorist mormons funding
Topic 4: people examined unlike line husband


['tiff tcp returned either releases',
 'collapse particularly coast stops activities',
 'fathers christopher members cutting possibilities',
 'friends owns terrorist mormons funding',
 'people examined unlike line husband']

In [433]:
action = np.random.normal(loc = 0, scale = 1, size = env.beta.shape)
env.step(action)

(array([[ 2.14733779, -0.05049328,  1.67646017, ..., -1.28830018,
          0.40203184, -0.12363293],
        [-2.81581479, -0.16239032,  1.78195767, ..., -2.09482724,
          0.76477559,  0.58207478],
        [-3.19513548, -0.0234362 ,  0.16951657, ...,  1.69616629,
         -1.16462347,  1.39557918],
        [ 2.21971839, -0.08290676,  1.97690823, ..., -0.57450664,
          0.93220443,  1.50757706],
        [-0.54415543, -0.35779767, -1.82209127, ..., -0.72354627,
          0.32849414, -1.44888498]]),
 0.5123433505875354,
 False)

In [435]:
env.reset().shape

(5, 5200)