# Visualisation simple de topic models

Projet : metromap

Auteur : Julien Velcin, Laboratoire ERIC, Université Lyon 2

Dans ce notebook, le jeu de données a déjà été prétraité et le modèle thématique entraîné. Cette version "light" ne permet pas d'accéder aux données textuelles, juste au modèle.

In [15]:
import math
import numpy as np
import importlib
import json
import ldamallet
import modules.topicmodel
import modules.dataset
import modules.models

#import plotly.express as px
from dash import Dash, dcc, callback_context, html, dash_table, Input
import dash_cytoscape as cyto
import dash_bootstrap_components as dbc
#import plotly.graph_objects as go
from dash.dependencies import Input, Output, State

import networkx as nx 
from networkx import degree_centrality, eigenvector_centrality, betweenness_centrality, pagerank, katz_centrality, clustering

## Preprocessing

Chargement du jeu de données

In [16]:
# Load the dataset configuration. JSON is UTF-8 encoded, so the encoding is
# stated explicitly instead of relying on the platform default.
with open('config/datasets.json', encoding='utf-8') as f:
    datasets = json.load(f)

In [17]:
# Per-dataset state, keyed by dataset name.
data = {}        # dataset name -> modules.dataset.dataset instance
num_topics = {}  # dataset name -> "num_topics" entry of the configuration
generate = {}    # dataset name -> "generate" entry ("no" when absent)
mult = {}        # dataset name -> "mult" entry as an int (1 when absent)

# In light mode the "light" corpus files are loaded/saved instead of the full ones.
light = True
#light = False

for conf in datasets.values():
    # Only the entries flagged with "load": "yes" are processed.
    if conf["load"] != "yes":
        continue

    dataname = conf["dataset_name"]
    ds = modules.dataset.dataset(dataname, conf["file_name"])
    data[dataname] = ds

    # Reuse an already saved corpus when one exists for the current mode.
    load_done = False
    if light:
        if ds.is_corpus_light():
            ds.load_corpus_light()
            load_done = True
            print("chargé light")
    elif ds.is_corpus():
        ds.load_corpus()
        load_done = True
        print("chargé full")

    if not load_done:
        # No saved corpus: read the source file, then clean, build and save it.
        lang = conf.get("lang", "english")
        src_file = conf.get("source_file", "txt")
        if src_file == "txt":
            ds.read_file_txt(lang=lang)
        elif src_file == "csv":
            ds.read_file_csv(lang=lang)
        else:
            # Fail loudly: carrying on would clean, build and save a corpus
            # from a dataset that was never read.
            raise ValueError(
                "Unsupported source_file {!r} for dataset {!r}; expected 'txt' or 'csv'".format(
                    src_file, dataname))
        ds.clean()
        ds.build_corpus()
        if light:
            print("save light")
            ds.save_corpus_light()  # save the newly built corpus (light version)
        else:
            print("save full")
            ds.save_corpus()  # save the newly built corpus (full version)

    num_topics[dataname] = conf["num_topics"]
    generate[dataname] = conf.get("generate", "no")
    mult[dataname] = int(conf.get("mult", 1))
    print("OK for corpus " + dataname)
    print("Size: {} sentences with a dictionary of {} words".format(ds.ndocs, len(ds.dico)))

chargé light
OK for corpus dune
Size: 28782 sentences with a dictionary of 5397 words
chargé light
OK for corpus lotr
Size: 26639 sentences with a dictionary of 4596 words
chargé light
OK for corpus legi
Size: 1725695 sentences with a dictionary of 49361 words


## Apprentissage du modèle

Initialisation / chargement de la table des modèles

In [18]:
# Create the models handler from modules.models, passing the string "tm"
# and the dict of loaded datasets.
mod = modules.models.models("tm", data)

In [19]:
# Notebook display of the table of known models
# (columns shown in the output: id, data, k, filename).
mod.get_all_models()

Unnamed: 0,id,data,k,filename
0,1,dune,100,model_dune_k100_1
1,2,lotr,100,model_lotr_k100_2
2,3,legi,100,model_legi_k100_3


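Chargement du modèle qui nous intéresse, identifié par `num_model`, en version light, c'est-à-dire sans charger les textes. Attention : le texte d'origine mentionne "legi", mais d'après la table ci-dessus l'identifiant 1 correspond à "dune".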

In [20]:
# Identifier of the model to load; it is also used further down (as a string)
# as the key into the visualisation configuration.
num_model = 1

# Load only that model, with the same light flag as the dataset loading.
mod.load_models([num_model], light=light)

### Visualisation des graphes

In [21]:
# json with additional information for the rendering
with open('config/visu.json') as f:
    visu = json.load(f)
    
max_words_to_display = visu[str(num_model)]["top_term_to_display"]

Chargement / calcul du graphe des thématiques

In [22]:
for m in mod.get_loaded_models():
    model = mod.get_model(m)
    # Keep the 20 top words of each topic (mainly for visualisation purposes).
    model.compute_topwords(num_words=20)
    if not light:
        # Full mode: compute everything from the model, then save pz and the
        # edge triplets so that light mode can load them later.
        model.compute_pz_pw()
        model.normalize_pzd()
        model_key = str(m)
        if model_key in visu:
            model.set_hidden_topics(visu[model_key]["hidden_topics"])
        model.compute_edge_triplets()
        model.save_pz()
        model.save_edge_triplets()
    else:
        # Light mode: load the edge triplets and pz directly.
        model.load_edge_triplets()
        model.load_pz()

# To save the triplets to a given file instead:
#lda_lotr.save_edge_triplets("test.txt")

Construction des graphes (on peut choisir une métrique particulière)

In [23]:
# Graph measures to compute, mapped to the name used to label them.
dico_measures_names = {
    degree_centrality: "degree_centrality",
    eigenvector_centrality: "eigenvector_centrality",
    betweenness_centrality: "betweenness_centrality",
    katz_centrality: "katz_centrality",
    pagerank: "pagerank",
    clustering: "clustering",
}
# Parallel lists, in the same order as the mapping above.
measures = list(dico_measures_names)
measures_names = list(dico_measures_names.values())

# Build the topic graph of every loaded model and compute the measures on it.
for m in mod.get_loaded_models():
    model = mod.get_model(m)
    model.build_graph(num_top_edges=250)
    model.compute_graph_measures(dico_measures_names, num_values=10)

# Measure used from here on to pick the graph to display.
m = degree_centrality

Différents prétraitements pour afficher notamment les mots distribués sur plusieurs thématiques (top z pour t)

In [24]:
# One copy of model.pz per dictionary entry, so that it lines up row by row
# with the transposed topic matrix.
pz_dup = np.tile(model.pz, (len(model.data.dico), 1))
# Topic matrix returned by the model (presumably p(w|z), one row per topic
# -- TODO confirm against modules.topicmodel).
pwz = model.get_topics()

# Element-wise product: one row per dictionary entry, one column per topic.
mmm = np.multiply(np.transpose(pwz), pz_dup)

# Normalise each row so that it sums to 1. Rows whose sum is 0 stay all-zero.
# The shape is taken from mmm instead of assuming exactly 100 topics.
row_sums = mmm.sum(axis=1, keepdims=True)
mmm_norm = np.zeros(mmm.shape)
np.divide(mmm, row_sums, out=mmm_norm, where=row_sums != 0)
    
def get_pzw(t, i):
    """Return the normalised weight of topic index ``i`` for token ``t``.

    The token is mapped to its row through ``model.data.dico.token2id`` and
    the value is read from the module-level ``mmm_norm`` matrix.
    """
    return mmm_norm[model.data.dico.token2id[t], i]

def get_pzw_all(t):
    """Return the full row of normalised topic weights for token ``t``.

    The token is mapped to its row through ``model.data.dico.token2id`` and
    the row is read from the module-level ``mmm_norm`` matrix.
    """
    token_row = model.data.dico.token2id[t]
    return mmm_norm[token_row, :]

def get_top_z_for_t(t, threshold=0.1):
    """Return the topics in which token ``t`` carries a significant weight.

    Args:
        t: token to look up (passed to ``get_pzw_all``).
        threshold: only topics whose normalised weight is strictly greater
            than this value are kept. Defaults to 0.1, the value that was
            previously hard-coded.

    Returns:
        A list of ``(topic_index, weight)`` tuples sorted by decreasing
        weight; ties keep increasing topic-index order.
    """
    kept = [(z, w) for z, w in enumerate(get_pzw_all(t)) if w > threshold]
    # list.sort is stable, so equal weights stay in topic-index order.
    kept.sort(key=lambda pair: pair[1], reverse=True)
    return kept

In [25]:
def norm_word(m, z, t, p):
    """Divide ``p`` by the normalised weight of topic ``z`` for token ``t``.

    Args:
        m: model object exposing ``k`` and ``get_node_name(i)``.
        z: node name of the topic, as returned by ``m.get_node_name``.
        t: token whose weights are looked up through ``get_pzw``.
        p: value to divide.

    Raises:
        KeyError: if ``z`` is not the node name of any of the ``m.k`` topics.
    """
    weight_by_node = {}
    for topic_index in range(m.k):
        weight_by_node[m.get_node_name(topic_index)] = get_pzw(t, topic_index)
    return p / weight_by_node[z]

# Topic node names in sorted order.
node_list = sorted(model.node_list)

# Graph associated with the selected measure `m`: laid out with Graphviz
# "neato", and source of the per-node colour and weight attributes.
measure_graph = model.topic_graph_m[m]
pos = nx.nx_pydot.pydot_layout(measure_graph, prog="neato")
node_colors = nx.get_node_attributes(measure_graph, "color")
node_weights = nx.get_node_attributes(measure_graph, "weight")

# For every laid-out node: (token, value, value normalised by norm_word)
# for each of its top words.
#node_topwords = {z:model.top_words[z] for z in pos}
node_topwords = {
    z: [(t, p, norm_word(model, z, t, p)) for (t, p) in model.top_words[z]]
    for z in pos
}

# Cytoscape elements: first one entry per node (with a fixed position) ...
data_for_graph = [
    {
        "data": {
            "id": n,
            "label": n,
            "color": node_colors[n],
            "weight": 10 + node_weights[n] * 3,
            "topwords": node_topwords[n],
        },
        "position": {"x": pos[n][0], "y": pos[n][1]},
    }
    for n in pos
]
# ... then one entry per edge of the topic graph.
data_for_graph.extend(
    {"data": {"source": n1, "target": n2}}
    for n1, n2 in model.topic_graph.edges()
)

In [26]:
def norm_prob(p1, p2, m):
    """Return ``(min(p1, m) - p2) / m`` as a whole percentage, rounded down.

    Args:
        p1: upper value; capped at ``m`` before the subtraction.
        p2: value subtracted from the capped ``p1``.
        m: reference maximum used both as cap and as divisor.
    """
    capped = min(p1, m)
    share = (capped - p2) / m
    return math.floor(share * 100)

def col_if_prob(p, m):
    """Return the colour name "danger" when ``p`` exceeds ``m``, else "warning"."""
    return "danger" if p > m else "warning"

def get_table_topic(z, max_words=max_words_to_display):
    """Build the form describing topic ``z``.

    The form starts with a navigation row ("-" button, topic label, "+"
    button, text input) followed by one entry per top word of ``z``: a button
    carrying the word and a two-segment progress bar with a tooltip.

    Args:
        z: node name of the topic; must be a key of ``node_topwords``.
        max_words: maximum number of top words to list.

    Returns:
        A ``dbc.Form`` component.
    """
    entries = node_topwords[z]
    # Both bar segments are scaled against the normalised value of the first word.
    #max_prob = np.max([pw for (t,w,pw) in node_topwords[z]])
    max_prob = entries[0][2]

    def word_row(rank, term, weight, norm_weight):
        # Button whose id encodes the rank of the word in the list.
        button = dbc.Button(term, value=term, id="but_"+str(rank))
        # First segment: raw value; second segment: what the normalised
        # value adds on top of it.
        bars = dbc.Progress(
            [
                dbc.Progress(value=norm_prob(weight, 0, max_prob),
                             color="success", bar=True),
                dbc.Progress(value=norm_prob(norm_weight, weight, max_prob),
                             color=col_if_prob(norm_weight, max_prob), bar=True),
            ],
            style={'width': 60}, className="mb-3", id="bar-"+term,
        )
        hint = dbc.Tooltip("{:.3f} ({:.3f})".format(weight, norm_weight),
                           target="bar-"+term)
        return dbc.Col([dbc.Row([dbc.Col(button), dbc.Col([bars, hint])])])

    word_rows = []
    for rank, (term, weight, norm_weight) in enumerate(entries):
        if rank >= max_words:
            break
        word_rows.append(word_row(rank, term, weight, norm_weight))

    # Navigation row: previous / current label / next / jump-to-topic input.
    header = dbc.Row(
        [
            dbc.Col(dbc.Button("-", color="primary", size="sm", id='bminus')),
            dbc.Col(html.P(z, id='topic_label')),
            dbc.Col(dbc.Button("+", color="primary", size="sm", id='bplus')),
            dbc.Col(dbc.Input(placeholder=node_list[0], type="text", id="input_z",
                              debounce=True, style={'width': 60})),
        ],
        className="mb-2",
    )

    return dbc.Form([header] + word_rows)

In [None]:
app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

# Base Cytoscape stylesheet: every node takes its colour, label and size
# from its own data fields.
node_style = {
    'background-color': 'data(color)',
    'content': 'data(label)',
    'width': 'data(weight)',
    'height': 'data(weight)'
}
default_stylesheet = [{'selector': 'node', 'style': node_style}]

# column 2 = topic graph, followed by two status lines
topic_graph_view = cyto.Cytoscape(
    id='mygraph',
    minZoom=0.5,
    maxZoom=2,
    layout={'name': 'preset'},  # positions come from the elements themselves
    style={'width': '100%', 'height': '600px'},
    stylesheet=default_stylesheet,
    elements=data_for_graph
)
col2 = [
    topic_graph_view,
    html.P("Not topic selected - ", id='my-output-clic'),
    html.P("Hover:", id='my-output-hover'),
]

# column 1 = table of the first topic of node_list
topic_panel = dbc.Col(get_table_topic(node_list[0]),
                      md=4, align="start", id="tab_top")
graph_panel = dbc.Col(col2, md=8)

# main panel
app.layout = dbc.Container(
    [
        html.H1("Graphe des thématiques"),
        html.Hr(),
        dbc.Row([topic_panel, graph_panel], align="center"),
        # Component with id "list_high_topics", referenced by the callbacks.
        # The author left type="hidden" and a "Selected topic" tooltip on
        # tab_top disabled.
        dbc.Input(id="list_high_topics", debounce=True),
    ],
    fluid=True,
)

@app.callback(
    Output(component_id='my-output-clic', component_property='children'),
    Output('mygraph', 'stylesheet'),
    Output('tab_top', 'children'),
    State(component_id='topic_label', component_property='children'),
    Input('input_z', 'value'),
    Input('mygraph', 'tapNodeData'),
    Input('mygraph', 'tapEdgeData'),
    Input('bplus', 'n_clicks'),
    Input('bminus', 'n_clicks'),
    Input("list_high_topics", "children")    
)
def displayTap(cur_z, nz, data_node, data_edge, clic_plus, clic_minus, list_high_topics):
    """Update the status line, the graph stylesheet and the topic table.

    Args:
        cur_z: topic currently shown in the ``topic_label`` element.
        nz: value typed in the ``input_z`` box, if any.
        data_node: data of the last tapped graph node, if any.
        data_edge: data of the last tapped graph edge, if any.
        clic_plus: click count of the "+" button (``None`` if never clicked).
        clic_minus: click count of the "-" button (``None`` if never clicked).
        list_high_topics: ``(topic, value)`` pairs to highlight, if any.

    Returns:
        A tuple ``(status text, stylesheet, topic table)`` matching the three
        declared outputs.
    """
    if not (data_node or data_edge or nz or clic_plus or clic_minus or cur_z):
        # Nothing selected at all: show the first topic with the base style.
        default_topic = node_list[0]
        return default_topic, default_stylesheet, get_table_topic(default_topic)

    new_z = cur_z

    # "+" / "-" step through node_list and wrap around at both ends. The
    # modulo makes index 0 reachable when stepping back from index 1.
    if clic_plus is not None:
        new_z = node_list[(node_list.index(new_z) + 1) % len(node_list)]

    if clic_minus is not None:
        new_z = node_list[(node_list.index(new_z) - 1) % len(node_list)]

    # A valid topic name typed in the box takes over.
    if (nz is not None) and (nz in node_list):
        new_z = nz

    not_list_high_topics = (list_high_topics is None) or (len(list_high_topics) == 0)

    # A tapped node is used only when no other control is active.
    if data_node and (not (clic_plus or clic_minus or nz)) and not_list_high_topics:
        new_z = data_node['id']

    s = "New topic in the box: " + new_z

    table_topic = get_table_topic(new_z)

    # Highlight of the selected topic.
    stylesheet = [{
        "selector": 'node[id = "{}"]'.format(new_z),
        "style": {
            'background-color': 'yellow',
            #'content': 'data(topwords)',
            "border-width": 2,
            "border-opacity": 1,
            "opacity": 0.5,
            "width": 100,
            "height": 100,
            'text-halign': 'center',
            'text-valign': 'center'
        }
    }]

    if not_list_high_topics:
        # Colour every node sharing an edge with the selected topic.
        neighbours = set()
        for n1, n2 in model.topic_graph.edges():
            if n2 == new_z:
                neighbours.add(n1)
            if n1 == new_z:
                neighbours.add(n2)
        for n in neighbours:
            stylesheet.append({
                "selector": 'node[id = "{}"]'.format(n),
                "style": {
                    'background-color': "#E2C41D"
                }
            })
    else:
        # Colour and enlarge the topics listed for the selected word.
        for z, _ in list_high_topics:
            stylesheet.append({
                "selector": 'node[id = "{}"]'.format(z),
                "style": {
                    'background-color': "#52c72e",
                    "width": 50,
                    "height": 50
                }
            })

    # New list: the module-level default stylesheet is never mutated.
    return s, default_stylesheet + stylesheet, table_topic

@app.callback(
        Output(component_id='my-output-hover', component_property='children'),
        Output("list_high_topics", "children"),
        State(component_id='topic_label', component_property='children'),
        [Input("but_"+str(i), "n_clicks") for i in range(max_words_to_display)]
    )
def select_word(cur_z, *val):
    """List the topics associated with the word whose button was clicked.

    Args:
        cur_z: topic currently shown in the ``topic_label`` element.
        *val: click counts of the ``but_<i>`` word buttons.

    Returns:
        ``(text, pairs)`` where ``pairs`` is the list of
        ``(topic node name, weight)`` given by ``get_top_z_for_t`` for the
        clicked word and ``text`` is its string form. Returns
        ``("vide", [])`` when no button has been clicked or when the
        triggering component cannot be identified.
    """
    if not any(v is not None for v in val):
        return "vide", []

    # The id of the triggering button is "but_<rank>"; recover the rank.
    triggered = callback_context.triggered
    prop_id = triggered[0]["prop_id"] if triggered else ""
    prefix, _, rank = prop_id.split(".")[0].partition("_")
    if prefix != "but" or not rank.isdigit():
        # Unknown trigger: do not fall back on an arbitrary word.
        return "vide", []

    num_w = int(rank)
    t, _, _ = node_topwords[cur_z][num_w]
    s = [(model.get_node_name(z), v) for z, v in get_top_z_for_t(t)]
    return str(s), s

# Start the Dash server with its default settings (the captured output below
# shows it listening on http://127.0.0.1:8050 until interrupted).
app.run_server()

Dash is running on http://127.0.0.1:8050/

Dash is running on http://127.0.0.1:8050/

 * Serving Flask app '__main__' (lazy loading)
 * Environment: production
[2m   Use a production WSGI server instead.[0m
 * Debug mode: off


 * Running on http://127.0.0.1:8050 (Press CTRL+C to quit)
127.0.0.1 - - [19/Jul/2022 17:15:49] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [19/Jul/2022 17:15:49] "GET /_dash-layout HTTP/1.1" 200 -
127.0.0.1 - - [19/Jul/2022 17:15:49] "GET /_dash-dependencies HTTP/1.1" 200 -
127.0.0.1 - - [19/Jul/2022 17:15:49] "POST /_dash-update-component HTTP/1.1" 200 -
127.0.0.1 - - [19/Jul/2022 17:15:49] "POST /_dash-update-component HTTP/1.1" 200 -
127.0.0.1 - - [19/Jul/2022 17:15:51] "POST /_dash-update-component HTTP/1.1" 200 -
127.0.0.1 - - [19/Jul/2022 17:15:51] "POST /_dash-update-component HTTP/1.1" 200 -
127.0.0.1 - - [19/Jul/2022 17:15:52] "POST /_dash-update-component HTTP/1.1" 200 -
127.0.0.1 - - [19/Jul/2022 17:15:52] "POST /_dash-update-component HTTP/1.1" 200 -
127.0.0.1 - - [19/Jul/2022 17:15:56] "POST /_dash-update-component HTTP/1.1" 200 -
127.0.0.1 - - [19/Jul/2022 17:15:56] "POST /_dash-update-component HTTP/1.1" 200 -
127.0.0.1 - - [19/Jul/2022 17:15:58] "POST /_dash-update-component 