## pyLDAvis visualization of the ZeroShot and Mallet topic models

In [1]:
import pandas as pd
import os
import numpy as np
from scipy.sparse import csr_matrix, vstack
import pyLDAvis as vis


In [2]:
def read_ZS(path_ZS):
    '''
    Returns a dictionary for the pyLDAvis module.
    Takes the path to the parent folder ZS_results (path_ZS)
    and searches the parameters inside ZS_output.

    Files read from <path_ZS>/ZS_output and the keys they are stored under
    (the keys match the keyword arguments of pyLDAvis.prepare):
        betas.npy     -> 'topic_term_dists'
        thetas.npy    -> 'doc_topic_dists'
        doc_len.npy   -> 'doc_lengths'
        term_freq.npy -> 'term_frequency'
        vocab.txt     -> 'vocab' (list of str, one stripped term per line)

    Raises:
        FileNotFoundError: if <path_ZS>/ZS_output is not an existing
            directory, or if any of the files above is missing.
    '''
    input_path = os.path.join(path_ZS, 'ZS_output')
    if not os.path.isdir(input_path):
        # FileNotFoundError subclasses Exception, so callers that caught the
        # original generic Exception keep working.
        raise FileNotFoundError(f'Path not found in PC! ({input_path})')

    # file name inside ZS_output -> key in the returned dictionary
    npy_files = {
        'betas.npy': 'topic_term_dists',
        'thetas.npy': 'doc_topic_dists',
        'doc_len.npy': 'doc_lengths',
        'term_freq.npy': 'term_frequency',
    }
    results = {
        key: np.load(os.path.join(input_path, file_name))
        for file_name, key in npy_files.items()
    }

    vocab_path = os.path.join(input_path, 'vocab.txt')
    with open(vocab_path, 'r', encoding='utf-8') as f:
        results['vocab'] = [line.strip() for line in f]

    return results

In [7]:
def _load_csr_npz(npz_path):
    '''
    Rebuild a scipy CSR matrix from an .npz archive that stores its raw
    components under the keys 'data', 'indices', 'indptr' and 'shape'.
    Assumes the archive describes a CSR (not CSC) matrix.
    '''
    # np.load on an .npz returns an NpzFile that keeps the file open; the
    # context manager closes it once the components have been read.
    with np.load(npz_path) as parts:
        return csr_matrix((parts['data'], parts['indices'], parts['indptr']),
                          shape=tuple(parts['shape']))


def read_mallet(path_mallet):
    '''
    Returns a dictionary for the pyLDAvis module.
    Takes the path to the parent folder mallet_folder (path_mallet)
    and searches the parameters inside mallet_output.

    Files read from <path_mallet>/mallet_output:
        betas.npy                     -> 'topic_term_dists'
        thetas_EN.npz / thetas_ES.npz -> 'doc_topic_dists_en' /
                                         'doc_topic_dists_es' (dense arrays),
                                         plus 'doc_topic_dists' with the EN
                                         rows stacked above the ES rows
        vocab.txt (tab-separated term and frequency, no header)
                                      -> 'vocab' (Series of str) and
                                         'term_frequency'
    'doc_lengths_en', 'doc_lengths_es' and 'doc_lengths' are the row sums of
    the matching doc-topic arrays, rounded to int.

    Raises:
        FileNotFoundError: if <path_mallet>/mallet_output is not an existing
            directory, or if any of the files above is missing.
    '''
    import csv  # only for csv.QUOTE_NONE when parsing vocab.txt

    input_path = os.path.join(path_mallet, 'mallet_output')
    if not os.path.isdir(input_path):
        # FileNotFoundError subclasses Exception, so existing callers that
        # caught the generic Exception keep working.
        raise FileNotFoundError(f'Path not found in PC! ({input_path})')

    results = {
        'topic_term_dists': np.load(os.path.join(input_path, 'betas.npy')),
    }

    # The doc-topic matrices are stored as sparse CSR components per language.
    sparse_en = _load_csr_npz(os.path.join(input_path, 'thetas_EN.npz'))
    sparse_es = _load_csr_npz(os.path.join(input_path, 'thetas_ES.npz'))
    results['doc_topic_dists_en'] = sparse_en.toarray()
    results['doc_topic_dists_es'] = sparse_es.toarray()
    # Whole corpus: EN documents first, then ES documents.
    results['doc_topic_dists'] = vstack([sparse_en, sparse_es]).toarray()

    # vocab.txt holds "term<TAB>frequency" per line. Disable NA detection,
    # quoting and type inference on the term column so tokens such as
    # "null", "nan", "NA" or "2020" stay literal strings.
    vocab_df = pd.read_csv(
        os.path.join(input_path, 'vocab.txt'),
        sep='\t',
        header=None,
        dtype={0: str},
        na_filter=False,
        quoting=csv.QUOTE_NONE,
        encoding='utf-8',
    )
    results['vocab'] = vocab_df[0]
    results['term_frequency'] = vocab_df[1]

    # NOTE(review): pyLDAvis expects doc_lengths to be token counts per
    # document. If the stored thetas are normalized (rows sum to 1), these
    # row sums are all 1. Confirm what thetas_*.npz actually contain.
    for suffix in ('_en', '_es', ''):
        dists = results['doc_topic_dists' + suffix]
        results['doc_lengths' + suffix] = np.round(dists.sum(axis=1)).astype(int)

    return results

In [9]:
# Load the Mallet outputs and render them with pyLDAvis.
# Only the English documents ('doc_topic_dists_en' / 'doc_lengths_en') are
# passed to prepare(); read_mallet also returns the combined EN+ES arrays
# under 'doc_topic_dists' / 'doc_lengths'.
# NOTE(review): confirm that visualizing only the English subset is intended.
vis_inputs = read_mallet('/export/usuarios_ml4ds/ammesa/mallet_folder')

visuals = vis.prepare(**{
    'topic_term_dists': vis_inputs['topic_term_dists'],
    'doc_topic_dists': vis_inputs['doc_topic_dists_en'],
    'doc_lengths': vis_inputs['doc_lengths_en'],
    'vocab': vis_inputs['vocab'],
    'term_frequency': vis_inputs['term_frequency'],
})
vis.display(visuals)



  result = func(self.values, **kwargs)
  by='saliency', ascending=False).head(R).drop('saliency', 1)
  result = func(self.values, **kwargs)


In [5]:
# Load the ZeroShot outputs and render them with pyLDAvis. read_ZS already
# stores each array under the matching pyLDAvis.prepare keyword name, so the
# arguments are forwarded by name.
vis_inputs_zs = read_ZS('/export/usuarios_ml4ds/ammesa/ZS_results')
visuals = vis.prepare(**{
    arg_name: vis_inputs_zs[arg_name]
    for arg_name in ('topic_term_dists', 'doc_topic_dists', 'doc_lengths',
                     'vocab', 'term_frequency')
})
vis.display(visuals)

  by='saliency', ascending=False).head(R).drop('saliency', 1)
