In [None]:
import re
import numpy as np
import pandas as pd

import umap

import plotly.express as px
import plotly.graph_objs as go

import dash
from dash import dcc, html, State, Patch
from dash.dependencies import Input, Output

from sklearn.neighbors import LocalOutlierFactor
from sklearn.preprocessing import MinMaxScaler

from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim

from textdemo import utils, report

# Fix problem with showing LaTeX in plotly figures
import plotly
from IPython.display import display, HTML
# https://github.com/microsoft/vscode-jupyter/issues/8131#issuecomment-1589961116
plotly.offline.init_notebook_mode()
display(HTML(
    '<script type="text/javascript" async src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-MML-AM_SVG"></script>'
))

# Use autoreload magic so that .py files can modified without having to restart the kernel
%load_ext autoreload
%autoreload 2

In [None]:
# Load the csv with the abstracts.
# Rows without a title are dropped, and the index is then reset. The embeddings
# below are computed in row order and the scatter plot reports a row *position*
# (pointIndex), so positions must coincide with the index labels used by df.loc.
df = (
    pd.read_csv('../data/DT_review.csv', delimiter=';')
    .dropna(subset=['title'])
    .drop(columns='id')
    .reset_index(drop=True)
)
# Missing keywords become the string 'nan' here (checked for when displaying).
df['keywords'] = df['keywords'].astype(str)
df.head()

In [None]:
# Sentence-embedding model used for abstracts, words, topics and free-text queries
model = SentenceTransformer(model_name_or_path='thenlper/gte-large')

In [None]:
# --- Sentence embeddings: one vector per abstract, in df row order ---------
print("Computing gte-embeddings for the abstracts...")
EMBEDDINGS = model.encode(df['abstract'].tolist())

# --- 2-D UMAP projection for the scatter plot (seeded for reproducibility) --
print("Computing umap-embeddings for the abstracts...")
reducer = umap.UMAP(n_neighbors=10, min_dist=0.0, random_state=0)
EMBEDDINGS_UMAP = reducer.fit_transform(EMBEDDINGS)

# --- Local Outlier Factor on the full-dimensional embeddings ---------------
# scikit-learn's negative_outlier_factor_ is higher for inliers and lower for
# outliers; min-max scaling maps it onto [0, 1] for colour coding.
print("Computing lof-scores for the abstracts...")
lof = LocalOutlierFactor(n_neighbors=10).fit(EMBEDDINGS)
scaler = MinMaxScaler()
LOF_SCORES = scaler.fit_transform(lof.negative_outlier_factor_[:, np.newaxis]).ravel()

In [None]:
# Vocabulary of all abstracts, embedded word by word
UNIQUE_WORDS = utils.get_unique_words(df['abstract'].tolist())
WORD_EMBEDDINGS = model.encode(UNIQUE_WORDS)

# Cosine similarity of every abstract (rows) to every word (columns), plus the
# corpus-wide mean similarity per word (used by the default bar chart).
WORD_SIMILARITIES = cos_sim(EMBEDDINGS, WORD_EMBEDDINGS).numpy()
avg_word_similarities = np.mean(WORD_SIMILARITIES, axis=0)

In [None]:
# Pre-defined topics.
# TOPICS_BREAKS holds the display labels with <br> line breaks (spider plot);
# TOPICS is the same text on one line and is what gets embedded and offered in
# the dropdown. Deriving one from the other keeps the two lists aligned.
TOPICS_BREAKS = [
    'energy efficiency<br>and emission',
    'structural loads<br>and stresses',
    'dynamics and<br>maneuverability',
    'model-based<br>simulation',
    'digital representation,<br>manufacturing, industry',
    'network and communication,<br>internet of things',
]
TOPICS = [label.replace('<br>', ' ') for label in TOPICS_BREAKS]

# Similarity of every abstract to every topic, min-max scaled per topic (column)
TOPIC_EMBEDDINGS = model.encode(TOPICS)
scaler = MinMaxScaler()
TOPIC_SIMILARITIES = scaler.fit_transform(cos_sim(EMBEDDINGS, TOPIC_EMBEDDINGS).numpy())

In [None]:
# ---- Dash app settings ----
top_n = 20   # number of words shown in the bar chart
alpha = 0.8  # alpha channel passed to report.hex_to_rgba for the bar colours
colors = px.colors.qualitative.Plotly
rgba_colors = [report.hex_to_rgba(color, alpha) for color in colors]

# Scatter colour scale: the lower half of the colour range stays white, the
# upper half is interpolated from white to the first Plotly qualitative colour.
colorscale = [
    [0, 'white'],
    [0.5, 'white'],
    [1, colors[0]],
]

FONTSIZE = 16
HEIGHT = 500
margin = dict(l=60, r=20, t=20, b=50, pad=0)

# Shared layout settings applied to every figure of the app
fig_settings = dict(
    height=HEIGHT,
    paper_bgcolor="rgba(255, 255, 255, 0)",  # transparent background
    margin=margin,
    legend=dict(
        x=1,
        y=1,
        xanchor="right",
        yanchor="top",
        bgcolor="rgba(255, 255, 255, 0.5)",
        font=dict(size=FONTSIZE),
        itemwidth=30,
        tracegroupgap=10,
    ),
    font=dict(size=FONTSIZE),
    font_family="arial",
    title=dict(font=dict(size=FONTSIZE)),
)

In [None]:
app = dash.Dash(name=__name__)  # the Dash application the layout and callbacks below attach to

# UMAP scatter figure
def get_scatter_figure():
    """Build the lasso-selectable UMAP scatter plot with one marker per abstract.

    Markers are coloured by ``LOF_SCORES`` using ``colorscale``. The hover text
    of each marker is its row position, which is also the ``pointIndex`` that
    Dash reports for this single trace.

    Returns:
        A ``go.Figure`` with a single ``go.Scatter`` trace.
    """
    xs = EMBEDDINGS_UMAP[:, 0]
    ys = EMBEDDINGS_UMAP[:, 1]
    point_labels = [str(idx) for idx in range(EMBEDDINGS_UMAP.shape[0])]

    scatter = go.Scatter(
        x=xs,
        y=ys,
        mode='markers',
        hovertext=point_labels,
        hoverinfo="text",
        marker=dict(
            color=LOF_SCORES,
            colorscale=colorscale,
            size=10,
            opacity=0.8,
        ),
        showlegend=False,
    )
    fig = go.Figure(data=[scatter])
    fig.update_xaxes(showticklabels=False, automargin=False)
    fig.update_yaxes(showticklabels=False, automargin=False)
    fig.update_layout(fig_settings, dragmode="lasso")
    return fig

# Bar graph showing the mean similarity between words and the abstracts
def get_bar_figure(x_range=(0.77, 0.87)):
    """Build the horizontal bar chart of the ``top_n`` most similar words.

    Each bar is a word from ``UNIQUE_WORDS``. Its length is that word's cosine
    similarity averaged over all abstracts (``WORD_SIMILARITIES``). The most
    similar word is drawn at the top.

    Args:
        x_range: Fixed (min, max) of the x-axis. Callbacks only patch the bar
            data, so this range also applies to later selections.

    Returns:
        A ``go.Figure`` with a single horizontal ``go.Bar`` trace.
    """
    # Wider left margin so long words fit as y tick labels; transparent plot area.
    bar_margin = {**margin, 'l': 200}
    bar_settings = {
        **fig_settings,
        "margin": bar_margin,
        "plot_bgcolor": "rgba(255, 255, 255, 0)",
    }

    mean_similarities = WORD_SIMILARITIES.mean(axis=0)
    # Indices of the most similar words, best first. Slicing (instead of
    # indexing range(top_n)) stays safe when the vocabulary has < top_n words.
    top_idxs = mean_similarities.argsort()[::-1][:top_n]
    x = mean_similarities[top_idxs]
    y = [UNIQUE_WORDS[idx] for idx in top_idxs]

    fig = go.Figure()
    fig.add_trace(
        go.Bar(
            # Reversed so that the best word ends up at the top of the chart.
            x=np.flip(np.array(x)),
            y=np.flip(np.array(y)),
            orientation='h',
            marker=dict(color=rgba_colors[0]),
        )
    )
    fig.update_layout(bar_settings)
    fig.update_xaxes(range=list(x_range), showticklabels=False)
    return fig

# Spider plot showing how similar an abstract (or a selection) is to each topic
def get_spider_figure():
    """Build the polar ("spider") plot of per-topic similarities.

    Trace 0 ("Mean") is each topic's similarity averaged over all abstracts.
    Trace 1 ("Hover") starts as all-NaN and is filled in later by the hover and
    selection callbacks via ``Patch``. The first point is repeated at the end
    so that each polygon closes.

    Returns:
        A ``go.Figure`` with two ``go.Scatterpolar`` traces.
    """
    mean_similarity = TOPIC_SIMILARITIES.mean(axis=0)
    closed_theta = TOPICS + [TOPICS[0]]

    mean_trace = go.Scatterpolar(
        r=np.append(mean_similarity, mean_similarity[0]),
        theta=closed_theta,
        mode='markers+lines',
        marker=dict(color='gray'),
        name='Mean',
    )
    # Placeholder trace whose "r" (and "name") the callbacks overwrite.
    hover_trace = go.Scatterpolar(
        r=np.full(len(TOPICS) + 1, np.nan),
        theta=closed_theta,
        mode='markers+lines',
        marker=dict(color='black'),
        name='Hover',
    )

    fig = go.Figure(data=[mean_trace, hover_trace])
    fig.update_polars(radialaxis_range=[0, 1], radialaxis_showticklabels=False)
    # Show the <br>-broken labels while keeping the single-line topics as data.
    fig.update_polars(angularaxis_labelalias=dict(zip(TOPICS, TOPICS_BREAKS)))
    fig.update_layout(fig_settings)
    return fig

# App layout, in three rows:
#   1. controls: free-text query field, submit button, colour-coding dropdown
#   2. UMAP scatter plot | metadata of the hovered abstract
#   3. word bar chart    | topic spider plot
# Style keys use React's camelCase form throughout (e.g. verticalAlign).
app.layout = html.Div([
    html.Div([  # Row: controls
        html.Div([  # Text input field
            dcc.Input(
                id='text-input',
                type='text',
                placeholder='Enter text...',
                style={'marginLeft': '20%', 'width': '100%', 'height': '30px', 'verticalAlign': 'bottom'}),
        ], style={'width': '25%', 'display': 'inline-block'}),
        html.Div([  # Submit button
            html.Button(
                'Submit',
                id='submit-button',
                n_clicks=0,
                style={'marginLeft': '70%', 'height': '30px', 'width': '100%', 'verticalAlign': 'bottom'}),
        ], style={'width': '10%', 'display': 'inline-block'}),
        html.Div([  # Dropdown menu: colour the scatter by LOF score or by a topic
            dcc.Dropdown(
                options=["LOF scores"] + TOPICS,
                value="LOF scores",
                id='dropdown-selection',
                style={'marginLeft': '20%', 'height': '30px', 'width': '100%', 'verticalAlign': 'bottom'}),
        ], style={'width': '45%', 'display': 'inline-block', 'lineHeight': '30px'}),
    ], style={'width': '50%', 'height': '30px'}),
    html.Div([  # Row: scatter plot | abstract
        html.Div([  # Column 1: UMAP scatter plot
            dcc.Graph(
                id='scatter-plot',
                config={'displayModeBar': False},
                figure=get_scatter_figure()
            )
        ], style={'width': '50%', 'display': 'inline-block', 'height': f'{HEIGHT}px'}),
        html.Div([  # Column 2: metadata of the hovered point (scrollable)
            html.Div(
                id='abstract-text',
                children="Hover over a point to see the abstract.",
                style={'overflowY': 'scroll', 'height': f'{HEIGHT-50}px', 'verticalAlign': 'top'})
        ], style={'width': '50%', 'display': 'inline-block', 'verticalAlign': 'top'})
    ]),
    html.Div([  # Row: bar graph | spider plot
        html.Div([  # Column 1: bar graph
            dcc.Graph(
                id='bar-plot',
                config={'displayModeBar': False},
                figure=get_bar_figure()
            )
        ], style={'width': '50%', 'display': 'inline-block', 'height': f'{HEIGHT}px'}),
        html.Div([  # Column 2: spider plot
            dcc.Graph(
                id='spider-plot',
                config={'displayModeBar': False},
                figure=get_spider_figure()
            )
        ], style={'width': '50%', 'display': 'inline-block', 'height': f'{HEIGHT}px'}),
    ])
])

# Callbacks for interactivity
@app.callback(
    Output('scatter-plot', 'figure', allow_duplicate=True),
    Input('submit-button', 'n_clicks'),
    State('text-input', 'value'),
    prevent_initial_call=True,
)
def input_similarity_color_coding(n_clicks, input_text):
    """Recolour the scatter plot by similarity to a free-text query.

    Triggered by the Submit button; the text field is read as ``State``.

    Args:
        n_clicks: Click count of the submit button (only used as the trigger).
        input_text: Query text. ``None``, empty or whitespace-only text leaves
            the figure unchanged.

    Returns:
        A ``Patch`` that sets trace 0's marker colours to the cosine similarity
        between each abstract embedding and the query embedding.
    """
    patched_fig = Patch()

    if not input_text or not input_text.strip():
        return patched_fig

    input_embedding = model.encode(input_text)
    # cos_sim returns a torch tensor with one row per abstract; convert it to a
    # flat numpy array like the other similarity matrices in this notebook.
    input_similarities = cos_sim(EMBEDDINGS, input_embedding).numpy().flatten()
    patched_fig["data"][0]["marker"]["color"] = input_similarities

    return patched_fig

    
@app.callback(
    Output('scatter-plot', 'figure'),
    Input("dropdown-selection", "value"),
)
def dropdown_color_coding(dropdown_value):
    """Recolour the scatter plot by the quantity chosen in the dropdown.

    A topic from ``TOPICS`` colours the markers by that topic's scaled
    similarity. Anything else, meaning "LOF scores" or ``None`` when the user
    clears the (clearable-by-default) dropdown, falls back to the normalised
    LOF scores instead of raising in ``TOPICS.index``.

    Returns:
        A ``Patch`` that replaces only trace 0's marker colours.
    """
    if dropdown_value in TOPICS:
        color_coding = TOPIC_SIMILARITIES[:, TOPICS.index(dropdown_value)]
    else:
        color_coding = LOF_SCORES

    patched_fig = Patch()
    patched_fig["data"][0]["marker"]["color"] = color_coding
    return patched_fig


@app.callback(
    Output('bar-plot', 'figure'),
    Output('spider-plot', 'figure', allow_duplicate=True),
    Input("scatter-plot", "selectedData"),
    prevent_initial_call=True,
)
def on_selection(select_data):
    """Update the word bar chart and the spider plot for a lasso selection.

    With a non-empty selection:
      * the bar chart shows the ``top_n`` words that occur in the selected
        abstracts, ranked by their mean similarity to those abstracts;
      * spider trace 1 shows the selection's mean topic similarity.

    Without a selection, the bar chart reverts to the corpus-wide ranking
    (``avg_word_similarities``) and the spider overlay is cleared.

    Args:
        select_data: Plotly ``selectedData`` dict, or ``None`` when deselected.

    Returns:
        Tuple ``(bar_patch, spider_patch)`` of ``Patch`` objects.
    """
    patched_bar_fig = Patch()
    patched_spider_fig = Patch()

    if select_data and select_data.get('points'):
        # pointIndex is a position within the single scatter trace, so it
        # indexes the similarity arrays and df *positionally* (hence .iloc).
        selected_points = [p['pointIndex'] for p in select_data['points']]
        selected_abstracts = df['abstract'].iloc[selected_points].tolist()

        # Restrict the vocabulary to words that occur in the selection.
        abstract_words = set(utils.get_unique_words(selected_abstracts))
        word_idxs = [idx for idx, word in enumerate(UNIQUE_WORDS) if word in abstract_words]
        words = [UNIQUE_WORDS[idx] for idx in word_idxs]

        # Mean similarity of each candidate word to the selected abstracts.
        avg_similarities = WORD_SIMILARITIES[selected_points][:, word_idxs].mean(axis=0)
        # Slicing keeps this safe when fewer than top_n words are available.
        top = avg_similarities.argsort()[::-1][:top_n]
        x = avg_similarities[top]
        y = [words[idx] for idx in top]

        selection_topics = TOPIC_SIMILARITIES[selected_points, :].mean(axis=0)
        patched_spider_fig["data"][1]["r"] = np.append(selection_topics, selection_topics[0])
        patched_spider_fig["data"][1]["name"] = 'Selection'

    else:
        top = avg_word_similarities.argsort()[::-1][:top_n]
        x = avg_word_similarities[top]
        y = [UNIQUE_WORDS[idx] for idx in top]
        # Clear the selection overlay so it does not linger after deselecting.
        patched_spider_fig["data"][1]["r"] = np.full(len(TOPICS) + 1, np.nan)
        patched_spider_fig["data"][1]["name"] = 'Hover'

    # Reversed so that the best word is drawn at the top of the bar chart.
    patched_bar_fig["data"][0]["x"] = np.flip(np.array(x))
    patched_bar_fig["data"][0]["y"] = np.flip(np.array(y))

    return patched_bar_fig, patched_spider_fig


def _format_field(value, convert=str):
    """Return ``convert(value)`` as text, or 'Not Available' if the value is missing.

    Missing means ``None``, a pandas/NumPy NA (e.g. NaN from an empty CSV cell),
    an empty string, or the literal string 'nan' (what ``astype(str)`` makes of NaN).
    Unlike a plain truthiness test, NaN counts as missing and 0 does not.
    """
    if isinstance(value, str):
        if value in ('', 'nan'):
            return 'Not Available'
    elif value is None or pd.isna(value):
        return 'Not Available'
    return str(convert(value))


@app.callback(
    Output('abstract-text', 'children'),
    Output('spider-plot', 'figure'),
    Input('scatter-plot', 'hoverData')
)
def on_hover(hover_data):
    """Show the hovered paper's metadata and its topic profile.

    Args:
        hover_data: Plotly ``hoverData`` dict, or ``None`` before any hover.

    Returns:
        ``(children, spider_patch)``: the HTML children for the abstract panel,
        and a ``Patch`` that sets spider trace 1 to the hovered abstract's
        topic similarities.
    """
    patched_fig = Patch()

    if hover_data is None:
        return "Hover over a point to see the abstract.", patched_fig

    # pointIndex is a position within the scatter trace, i.e. a row position
    # in df (hence .iloc rather than label-based .loc).
    point_index = hover_data['points'][0]['pointIndex']
    row = df.iloc[point_index]

    child_html = [
        html.H3(row['title']),
        html.P([html.Strong("Year:"), f" {_format_field(row['year'], int)}"]),
        html.P([html.Strong("Abstract:"), f" {_format_field(row['abstract'])}"]),
        html.P([html.Strong("Citations:"), f" {_format_field(row['citations'], int)}"]),
        html.P([html.Strong("Authors:"), f" {_format_field(row['authors'])}"]),
        html.P([html.Strong("Journal:"), f" {_format_field(row['journal'])}"]),
        html.P([html.Strong("Keywords:"), f" {_format_field(row['keywords'])}"])
    ]

    topic_similarities = TOPIC_SIMILARITIES[point_index, :]
    # Repeat the first value so the polygon closes.
    patched_fig["data"][1]["r"] = np.append(topic_similarities, topic_similarities[0])
    patched_fig["data"][1]["name"] = 'Hover'

    return child_html, patched_fig

# Start the Dash server; jupyter_mode="external" serves it outside the notebook
# output. app.run replaces app.run_server, which is deprecated and removed in Dash 3.
app.run(debug=False, jupyter_mode="external")

In [None]:
# Static UMAP scatter plot for the report, annotated as panel 'a', shown and saved as PNG
config = report.get_plotly_config()
fig = report.add_panel_annotation(
    report.get_scatter_plot(EMBEDDINGS_UMAP, TOPICS, TOPIC_SIMILARITIES),
    'a',
)
fig.show()
fig.write_image("umap_scatter_plot.png", scale=config['png_scaling'])

In [None]:
# Static topic spider plot for the report, annotated as panel 'b', shown and saved as PNG.
# Uses `config` from the previous cell for the PNG scaling.
fig = report.add_panel_annotation(
    report.get_spider_plot(TOPICS, TOPIC_SIMILARITIES, TOPICS_BREAKS),
    'b',
    scale=6,
)
fig.show()
fig.write_image("umap_spider_plot.png", scale=config['png_scaling'])