# Cria grafos de palavras expandidas

In [None]:
#### Imports
import pandas as pd
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from matplotlib.backends.backend_pdf import PdfPages
import ast

## Funções úteis

In [None]:
def encontrar_palavras_distintas(nodes):
    """Report the IDF of original-query words that are missing from expansions.

    The first node whose ``type`` attribute is ``'original'`` provides the
    reference phrase and its per-word IDF values. Every node whose ``type`` is
    ``'expanded'`` is then compared against it: each word of the original
    phrase that does not appear in that expanded phrase is reported once,
    together with the IDF found at the word's first position in the original
    phrase.

    Args:
        nodes: Iterable of ``(phrase, attributes)`` pairs, such as
            ``G.nodes(data=True)``. ``attributes`` must have a ``'type'`` key;
            the original node must also have an ``'idf'`` key holding either a
            sequence of numbers or its string representation.

    Returns:
        One newline-terminated line per reported word, concatenated into a
        single string. Words are listed in the order the expansions are
        visited and, within one expansion, in their order in the original
        phrase. Returns an empty string when there is no original node.
    """
    # Materialize once: the input is traversed twice and may be a one-shot iterator.
    nodes = list(nodes)

    original_phrase = None
    original_idf = None
    for phrase, attributes in nodes:
        if attributes['type'] == 'original':
            original_phrase = phrase
            original_idf = attributes['idf']
            # The IDF values may arrive serialized as text; parse them only then.
            if isinstance(original_idf, str):
                original_idf = ast.literal_eval(original_idf)
            break

    # Without a reference phrase there is nothing to compare against.
    if original_phrase is None:
        return ""

    # Map each distinct original word to its first position (same index that
    # list.index would return), keeping the phrase order for stable output.
    first_index = {}
    for idx, palavra in enumerate(original_phrase.split()):
        first_index.setdefault(palavra, idx)

    linhas = []
    palavras_identificadas = set()
    for phrase, attributes in nodes:
        if attributes['type'] != 'expanded':
            continue
        expanded_words = set(phrase.split())
        for palavra, idx in first_index.items():
            if palavra in expanded_words or palavra in palavras_identificadas:
                continue
            # Each word is reported only the first time it is found missing.
            palavras_identificadas.add(palavra)
            linhas.append(f"IDF da palavra '{palavra}': {original_idf[idx]:.4f}\n")

    return "".join(linhas)

def format_label(query_original, query_expandida, precision):
    """Build the node label for an expanded query.

    Words of ``query_expandida`` that do not occur in ``query_original`` are
    wrapped in mathtext bold markup; the remaining words are kept as they are.
    The precision, formatted with three decimals and enclosed in parentheses,
    is placed on a second line.

    Args:
        query_original: Whitespace-separated original query.
        query_expandida: Whitespace-separated expanded query.
        precision: Number shown below the query text.

    Returns:
        The two-line label string.
    """
    known_words = frozenset(query_original.split())
    pieces = []
    for token in query_expandida.split():
        if token in known_words:
            pieces.append(token)
        else:
            # Word introduced by the expansion: emphasize it.
            pieces.append(r"$\mathbf{" + token + "}$")
    return " ".join(pieces) + f"\n({precision:.3f})"



#### Recupera o dataset com as queries, expansões e semelhança entre elas

In [None]:
# Load the tab-separated table of queries and expansions, then drop the
# 'Unnamed: 0' column (presumably a leftover index column — confirm).
queries_df = pd.read_csv(
    "../1_enrich_results/queries_train_judged_expanded_enriched.csv",
    sep="\t",
).drop(columns=['Unnamed: 0'])
len(queries_df)
queries_df.head(2)


In [None]:
# Keep only the rows whose 'relevant_count' is at least 5
# (a minimum number of judged documents).
filtered_df = queries_df.loc[queries_df['relevant_count'].ge(5)]


In [None]:
# Group the original and expanded records that share the same 'query_idx'
grouped_data = filtered_df.groupby(by='query_idx')
grouped_data

## Plota o grafo

In [None]:
# Draws one graph per query group and appends each figure to a multi-page PDF.
# In every graph the original query (blue) is linked to each of its expanded
# queries (green); edges are coloured and labelled with the Spearman value and
# node labels show the average precision.
num_plots = 4  #   <------------------------------------- CHOOSE HERE THE NUMBER OF PLOTS TO GENERATE
pdf_path = 'expansion_plots.pdf'  # the plots are saved here


def adjust_labels(pos, x_shift=0, y_shift=0.05):
    """Return a copy of *pos* with every coordinate moved by the given offsets."""
    return {node: (coord[0] + x_shift, coord[1] + y_shift) for node, coord in pos.items()}


def _expanded_words_with_index(row):
    """List (word, first index in the expanded query) for words absent from the original query."""
    original_words = set(row['query_original'].split())
    expanded = row['query_expandida'].split()
    return [(word, expanded.index(word)) for word in expanded if word not in original_words]


# Iterate over the groups of rows that come from the same original query ('query_idx')
count_plots = 1
# The context manager finalizes the PDF even if an iteration raises.
with PdfPages(pdf_path) as pdf:
    for idx_query_original, group in grouped_data:
        # Work on a copy so that adding a column does not write into the group itself.
        data = group.copy()
        data['expanded_words'] = data.apply(_expanded_words_with_index, axis=1)

        edges = []
        for _, row in data.iterrows():
            # Rows without a Spearman value cannot be drawn as an edge.
            if pd.isna(row['spearman']):
                continue

            edges.append((row['query_original'], row['query_expandida'], {
                'spearman': row['spearman'],
                'avg_precision_query_original': row['avg_precision_query_original'],
                'avg_precision_query_expansao': row['avg_precision_query_expansao'],
                'idf_original_words': row['idf_original_values'],
                'num_passagens': row['relevant_count'],
            }))

        if not edges:
            continue

        # Create the graph
        G = nx.Graph()
        # Remember which original query each expanded query was derived from,
        # so the labels below do not depend on a leftover loop variable.
        original_of = {}
        for query_original, query_expandida, attrs in edges:
            G.add_node(
                query_original,
                avg_precision=attrs['avg_precision_query_original'],
                type='original',
                idf=attrs['idf_original_words'],
                num_passagens=attrs['num_passagens'],
            )
            G.add_node(
                query_expandida,
                avg_precision=attrs['avg_precision_query_expansao'],
                type='expanded',
            )
            G.add_edge(query_original, query_expandida, spearman=attrs['spearman'])
            original_of[query_expandida] = query_original

        pos = nx.spring_layout(G)

        # NOTE(review): figures are intentionally left open, as before, so a
        # notebook can still display them; close them if many plots are generated.
        fig, ax = plt.subplots(figsize=(15, 15))

        original_nodes = [node for node, attr in G.nodes(data=True) if attr['type'] == 'original']
        expanded_nodes = [node for node, attr in G.nodes(data=True) if attr['type'] == 'expanded']

        nx.draw_networkx_nodes(G, pos, nodelist=original_nodes, node_color='blue', node_size=700, label='Original', ax=ax)
        nx.draw_networkx_nodes(G, pos, nodelist=expanded_nodes, node_color='green', node_size=700, label='Expanded', ax=ax)

        edge_colors = [G[u][v]['spearman'] for u, v in G.edges()]
        if edge_colors:
            # Compute the bounds once instead of once per edge.
            lowest = min(edge_colors)
            highest = max(edge_colors)
            range_edge_colors = highest - lowest
            if range_edge_colors == 0:
                # All edges share one value: use the middle of the colormap.
                edge_colors_normalized = [0.5 for _ in edge_colors]
            else:
                edge_colors_normalized = [(value - lowest) / range_edge_colors for value in edge_colors]

            cmap = plt.cm.viridis
            colors = [cmap(color) for color in edge_colors_normalized]

            nx.draw_networkx_edges(G, pos, edge_color=colors, width=2, ax=ax)

            edge_labels = {(u, v): f"{G[u][v]['spearman']:.3f}" for u, v in G.edges()}
            nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_color='red', ax=ax)

            sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=lowest, vmax=highest))
            sm.set_array([])
            fig.colorbar(sm, ax=ax, label='Correlação Spearman')

        original_labels = {
            node: f"{node}\n({attr['avg_precision']:.3f})"
            for node, attr in G.nodes(data=True)
            if attr['type'] == 'original'
        }
        expanded_labels = {
            node: format_label(original_of[node], node, attr['avg_precision'])
            for node, attr in G.nodes(data=True)
            if attr['type'] == 'expanded'
        }

        # Draw the labels slightly above the nodes so they do not overlap the markers.
        label_pos = adjust_labels(pos, y_shift=0.13)

        nx.draw_networkx_labels(G, label_pos, labels=original_labels, font_color='black', ax=ax)
        nx.draw_networkx_labels(G, label_pos, labels=expanded_labels, font_color='black', ax=ax)

        # Retrieve the number of judgements (and of passages) used
        num_passagens_text = ""
        for node, attr in G.nodes(data=True):
            if attr['type'] == 'original':
                num_passagens_text = f"K: {attr['num_passagens']}\n"

        # IDF of the original-query words that are missing from the expansions
        idf_original_words = encontrar_palavras_distintas(G.nodes(data=True))

        box_text = " INFORMAÇÕES ADICIONAIS \n" + num_passagens_text + idf_original_words
        fig.text(0.132, 0.14, box_text, fontsize=12, verticalalignment='bottom', bbox=dict(facecolor='white', alpha=0.5))

        ax.set_title(f"Performance e correlação da expansão da query {idx_query_original}")

        handles = [
            plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='blue', markersize=10, label='Query Original'),
            plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='green', markersize=8, label='Query Expandida'),
        ]

        ax.legend(handles=handles, loc='upper right')

        ax.set_xlim(-1.8, 1.8)
        ax.set_ylim(-1.8, 1.8)

        pdf.savefig(fig)
        # Stop once the requested number of plots has been written.
        if count_plots >= num_plots:
            break
        count_plots += 1
