# This Python Notebook visualizes the patent embeddings in 2D space using scatter plots.

In [1]:
import pandas as pd
# Import random
import random

from sentence_transformers import SentenceTransformer

from umap import UMAP

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

  @numba.jit()
  @numba.jit()
  @numba.jit()
  @numba.jit()


In [12]:
# Notebook configuration: run identifiers, input path and the two encoders to compare.

# Tags used when naming the output files of this run
SAMPLE = "100K"
EPOCHS = 10

# When True, the plotting cell works on a random subset of the rows
PRUNE_DF = True

# CSV holding the training triplets (anchor / positive / negative patents)
triplet_csv = "data/triplet_training.csv"

# Baseline encoder, referenced by its model identifier
patentSBERTa_model = SentenceTransformer('AI-Growth-Lab/PatentSBERTa')
# Encoder loaded from a local model directory
own_model = SentenceTransformer('own_models/patentCL_epochs_10_sample_100K_2023-06-22_17-38/')

In [4]:
# Read the triplet CSV, discard the positive/negative patent columns, and add
# `a_title_abstract`: the anchor title and abstract joined by " [SEP] ".
triplet_df = (
    pd.read_csv(triplet_csv)
    .drop(columns=['p_id', 'p_title', 'p_abstract', 'n_id', 'n_title', 'n_abstract'])
    .assign(a_title_abstract=lambda frame: frame['a_title'] + " [SEP] " + frame['a_abstract'])
)

In [5]:
# Encode the anchor texts with both models, project each embedding set to 2D
# with UMAP, and write the enriched dataframe to a CSV file.
print("Generate embeddings for a_title_abstract with the patentSBERTa model and with the own model and add them as columns to the dataframe")
# Build the list of input texts once and feed the same list to both encoders
anchor_texts = triplet_df['a_title_abstract'].tolist()
# `.tolist()` on the encoder output lets each row hold its own embedding as a Python list
triplet_df['patentSBERTa_embedding'] = patentSBERTa_model.encode(anchor_texts, show_progress_bar=True).tolist()
triplet_df['own_model_embedding'] = own_model.encode(anchor_texts, show_progress_bar=True).tolist()

# Perform UMAP on the embeddings
print("Perform UMAP on the embeddings")
# UMAP is stochastic: without a fixed random_state every run yields a different
# layout, so the saved coordinates and the plots could not be reproduced.
# (Per the umap-learn documentation a fixed seed trades some speed for repeatability.)
umap = UMAP(n_neighbors=5, min_dist=0.3, metric='cosine', random_state=42)
# fit_transform fits from scratch on each call, so the two projections are independent
triplet_df['patentSBERTa_umap'] = umap.fit_transform(triplet_df['patentSBERTa_embedding'].tolist()).tolist()
triplet_df['own_model_umap'] = umap.fit_transform(triplet_df['own_model_embedding'].tolist()).tolist()

# Save the dataframe to a csv file
print("Save the dataframe to a csv file")
df_safename = f"triplet_training_{SAMPLE}_E{EPOCHS}_embeddings.csv"
triplet_df.to_csv(df_safename, index=False)

Generate embeddings for a_title_abstract with the patentSBERTa model and with the own model and add them as columns to the dataframe


Batches:   0%|          | 0/2573 [00:00<?, ?it/s]

Batches:   0%|          | 0/2573 [00:00<?, ?it/s]

Perform UMAP on the embeddings


In [None]:
if PRUNE_DF:
    # Keep only a random subset of rows for plotting. The seed is fixed so the
    # same subset is drawn on every run.
    max_plot_rows = 20000
    # Never request more rows than exist: DataFrame.sample raises a ValueError
    # when n exceeds the frame length (and replace is False).
    prune_n = min(max_plot_rows, len(triplet_df))
    # Report the number of rows actually kept (the old message said 10000 while 20000 were sampled)
    print(f"[PRUNE] Get a subset of the dataframe with only {prune_n} rows")
    triplet_df = triplet_df.sample(n=prune_n, random_state=1)

# Map every distinct cpc_id to an "rgb(r,g,b)" string whose three channels are
# drawn at random from 0..255 (red, then green, then blue).
print("Create a color_dict by assigning a color to each individual cpc_id by random")
color_dict = {
    code: f"rgb({random.randint(0, 255)},{random.randint(0, 255)},{random.randint(0, 255)})"
    for code in triplet_df['cpc_id'].unique()
}

# Number of rows per cpc_id; value_counts() returns the counts in descending order
cpc_counts = triplet_df['cpc_id'].value_counts()

# The index of cpc_counts therefore lists the cpc_ids from most to least frequent
sorted_cpc_ids = cpc_counts.index.tolist()

# One row, two panels: PatentSBERTa projection on the left, own model on the right
fig = make_subplots(rows=1, cols=2, subplot_titles=("patentSBERTa", "own model"))

print("Create a plotly figure with two columns for two scatter plots")
# For each panel, add one scatter trace per cpc_id (most frequent first). All
# PatentSBERTa traces are added before the own-model traces. Both panels use
# the same colour, name and legendgroup for a given cpc_id.
for subplot_col, umap_column in enumerate(('patentSBERTa_umap', 'own_model_umap'), start=1):
    for cpc_id in sorted_cpc_ids:
        # Rows of this class, selected once and reused for x, y and hover text
        class_rows = triplet_df.loc[triplet_df['cpc_id'] == cpc_id]
        points_2d = class_rows[umap_column]
        marker_style = dict(
            size=3,
            color=color_dict[cpc_id],
            opacity=0.9,
        )
        scatter = go.Scatter(
            # Each cell of the umap column holds an [x, y] pair
            x=points_2d.apply(lambda point: point[0]),
            y=points_2d.apply(lambda point: point[1]),
            # Hover text: "[cpc_id] title"
            text="[" + class_rows['cpc_id'] + "] " + class_rows['a_title'],
            hoverinfo='text',
            mode='markers',
            marker=marker_style,
            name=cpc_id,
            legendgroup=cpc_id,
        )
        fig.add_trace(scatter, row=1, col=subplot_col)

# Calculate one x range and one y range that cover the UMAP coordinates of BOTH
# models, so the two scatter plots can be drawn on identical axes.
def _axis_extent(points, axis):
    """Return (min, max) of coordinate `axis` (0 = x, 1 = y) over a Series of 2D points."""
    coords = points.apply(lambda point: point[axis])
    return coords.min(), coords.max()


def _padded_range(low, high, margin=0.1):
    """Return [low, high] widened on each side by `margin` times its span.

    The padding is added and subtracted instead of scaling the bounds: scaling
    (low * 1.1, high * 1.1) moves a positive lower bound up and a negative
    upper bound down, which shrinks the range and cuts off data points.
    """
    pad = (high - low) * margin
    return [low - pad, high + pad]


x_min_patentSBERTa, x_max_patentSBERTa = _axis_extent(triplet_df['patentSBERTa_umap'], 0)
x_min_own_model, x_max_own_model = _axis_extent(triplet_df['own_model_umap'], 0)
x_range = _padded_range(min(x_min_patentSBERTa, x_min_own_model), max(x_max_patentSBERTa, x_max_own_model))

y_min_patentSBERTa, y_max_patentSBERTa = _axis_extent(triplet_df['patentSBERTa_umap'], 1)
y_min_own_model, y_max_own_model = _axis_extent(triplet_df['own_model_umap'], 1)
y_range = _padded_range(min(y_min_patentSBERTa, y_min_own_model), max(y_max_patentSBERTa, y_max_own_model))



# Style the figure and give both subplots the same axis ranges.
def _axis_style(title, axis_range):
    """Return the layout settings shared by all four axes, with the given title and [min, max] range."""
    return dict(
        title=title,
        gridcolor='white',
        gridwidth=2,
        # An axis range is a [min, max] pair. Copy it so every axis owns its own list.
        range=list(axis_range),
    )


fig.update_layout(
    title='UMAP embeddings of patentSBERTa and own model',
    # x axes use x_range and y axes use y_range; the previous nested
    # [x_range, y_range] value was not a valid [min, max] range.
    xaxis=_axis_style('x', x_range),
    yaxis=_axis_style('y', y_range),
    xaxis2=_axis_style('x', x_range),
    yaxis2=_axis_style('y', y_range),
    paper_bgcolor='rgb(243, 243, 243)',
    plot_bgcolor='rgb(243, 243, 243)',
    legend=dict(
        # Vertical legend anchored to the top-right corner
        orientation='v',
        yanchor='top',
        y=1.0,
        xanchor='right',
        x=1.0,
        # 'normal' keeps the legend entries in the order the traces were added
        traceorder='normal',
        tracegroupgap=20
    )
)

# Click callback: zoom the subplot of the clicked trace in around the clicked point
def update_legend(trace, points, state):
    """Zoom the axes of the clicked trace's subplot in around a single clicked point.

    Intended to be registered with ``trace.on_click``. When exactly one point is
    reported, the new range keeps the point's position within the view and
    shrinks the distance to each current bound to 10%. It is written to the
    axes the trace is drawn on, so a click in the second subplot updates
    ``xaxis2``/``yaxis2`` rather than the first subplot.

    Args:
        trace: The clicked trace; read via ``trace['x']``, ``trace['y']``,
            ``trace['xaxis']`` and ``trace['yaxis']``.
        points: Click information; only ``points.point_inds`` is used.
        state: Unused; accepted because the callback is invoked with three arguments.

    Returns:
        None. The module-level ``fig`` is updated in place. Nothing happens when
        the click does not report exactly one point or when the axes have no
        explicit range to zoom from.
    """
    # Only a click that reports exactly one point triggers a zoom
    if len(points.point_inds) != 1:
        return

    # Trace axis ids ('x', 'x2', 'y', 'y2') map to layout keys ('xaxis', 'xaxis2', ...).
    # An unset id is treated as the first axis pair.
    x_key = 'xaxis' + (trace['xaxis'] or 'x')[1:]
    y_key = 'yaxis' + (trace['yaxis'] or 'y')[1:]

    # Current ranges of the axes this trace is drawn on
    x_range = fig['layout'][x_key]['range']
    y_range = fig['layout'][y_key]['range']
    if x_range is None or y_range is None:
        # No explicit range set, so there are no bounds to zoom relative to
        return

    # Coordinates of the clicked point
    point_index = points.point_inds[0]
    x_data = trace['x'][point_index]
    y_data = trace['y'][point_index]

    # Shrink the distance from the point to each bound to 10% of its current value
    new_x_range = [x_data - (x_data - x_range[0]) * 0.1, x_data + (x_range[1] - x_data) * 0.1]
    new_y_range = [y_data - (y_data - y_range[0]) * 0.1, y_data + (y_range[1] - y_data) * 0.1]

    fig.update_layout(**{x_key: dict(range=new_x_range), y_key: dict(range=new_y_range)})

# Register the click handler on every trace of the figure.
# NOTE(review): according to the plotly documentation, trace callbacks are only
# triggered for traces that belong to a FigureWidget displayed in an ipywidgets
# context. `fig` is created with make_subplots here, so confirm whether this
# handler can actually fire; it is not part of the exported HTML either way.
for scatter_trace in fig.data:
    scatter_trace.on_click(update_legend)

# Export the figure as a standalone HTML file named after the sample tag
save_name = f"patent_embedding_scatter_sample_{SAMPLE}.html"
fig.write_html(save_name)

# Render the figure in the notebook
fig.show()