## 导入库

In [1]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from PIL import ImageFont
from sklearn.metrics.pairwise import cosine_similarity

## 函数

In [2]:
def get_bg_color(hex_color):
    red, green, blue = hex_color[0], hex_color[1], hex_color[2]
    red_half = int(red) // 10 + (255 - 25)
    green_half = int(green) // 10 + (255 - 25)
    blue_half = int(blue) // 10 + (255 - 25)

    half_hex_color = "#{:02x}{:02x}{:02x}".format(red_half, green_half, blue_half)
    return half_hex_color

def show_text(text):
    return text[:10] + '...' + text[-10:] if len(text) > 20 else text

# Load the specific font and font size.
font_default = ImageFont.truetype("arial.ttf", 14)

def get_text_width(font, text):
    return font.getsize(text)[0]

def get_color_value(index, colorscale, min_value, max_value):
    value_range = max_value - min_value
    scaled_value = (index - min_value) / value_range
    color_index = int(scaled_value * (len(colorscale) - 1))
    return colorscale[color_index][1]

def create_heatmap(df, positions):
    # check if there's no repeated rows in positions, which is the second column of the positions list
    rows = [pos[1] for pos in positions]
    assert len(rows) == len(set(rows)), "There are repeated rows in positions"

    # compute similarity matrix
    sim_matrix = cosine_similarity(df['first_embed'].tolist(), df['second_embed'].tolist())
    
    jet_colorscale = [
        [0.0, "rgb(0, 0, 255)"],
        [0.1, "rgb(0, 100, 255)"],
        [0.2, "rgb(0, 200, 255)"],
        [0.3, "rgb(50, 255, 255)"],
        [0.4, "rgb(150, 255, 255)"],
        [0.5, "rgb(255, 255, 0)"],
        [0.6, "rgb(255, 150, 0)"],
        [0.7, "rgb(255, 100, 0)"],
        [0.8, "rgb(255, 50, 0)"],
        [0.9, "rgb(255, 0, 0)"],
        [1.0, "rgb(150, 0, 0)"],
    ]
    
    fig = go.Figure(data=go.Heatmap(z=sim_matrix, colorscale=jet_colorscale))
    
    # Add annotations for captions
    annotations = []
    shapes = []
    
    # Retrieve colors for the given positions
    position_colors = {}
    min_value, max_value = np.amin(sim_matrix), np.amax(sim_matrix)
    for col, row in positions:
        color_index = sim_matrix[row][col]
        color_value = get_color_value(color_index, jet_colorscale, min_value, max_value)
        position_colors[(col, row)] = color_value
    

    title_row = go.layout.Annotation(
        text="Row",
        font=dict(size=16, color="black"),
        x = -24,
        y = -1,
        showarrow=False,
        xref="x",
        yref="y",
        xanchor="center",
        yanchor="bottom",
    )

    title_col = go.layout.Annotation(
        text="Col",
        font=dict(size=16, color="black"),
        x = -9,
        y = -1,
        showarrow=False,
        xref="x",
        yref="y",
        xanchor="center",
        yanchor="bottom",
    )
    annotations.append(title_col)
    annotations.append(title_row)


    for pos in positions:
        col, row = pos
        for j, position in enumerate([(-15, 1), (-30, 0)]):
            text = show_text(df['first'][col] if j == 0 else df['second'][row])
            color = position_colors[(col, row)]
            hex_color = color.lstrip('rgb(').rstrip(')').split(", ")
            bgcolor = get_bg_color(hex_color)
            annotations.append(
                go.layout.Annotation(
                    x=position[0],
                    y=row,
                    text=text,
                    showarrow=False,
                    font=dict(size=14, color="black"),
                    bordercolor=color,
                    xref="x",
                    yref="y",
                    yshift=0,
                    xshift=-10,
                    xanchor="left",
                    yanchor="middle",
                    bgcolor=bgcolor,
                )
            )
            if j == 0:
                text_width = get_text_width(font_default, text) / 90
                shapes.append(
                    go.layout.Shape(
                        type='line',
                        x0= - text_width-2,
                        x1=col,
                        y0=row,
                        y1=row,
                        yref='y',
                        xref='x',
                        line=dict(
                            color=color,
                            width=1.5,
                        )
                    )
                )
    
    custom_width = 70 * len(sim_matrix)  # Increase the multiplier (40) for larger cells
    custom_height = 35 * len(sim_matrix)

    fig.update_layout(
        width=custom_width,
        height=custom_height,
        margin=dict(l=120),
        xaxis=dict(tickmode="array", tickvals=list(range(len(sim_matrix))), ticktext=list(range(len(sim_matrix))), title='X Axis'),
        yaxis=dict(tickmode="array", tickvals=list(range(len(sim_matrix))), ticktext=list(range(len(sim_matrix))), autorange="reversed", title='Y Axis'),
        annotations=annotations,
        shapes=shapes
    )

    # display plot
    fig.show()


## 读取数据
此单元格的输出为一个pandas Dataframe df，包含四列 ['first', 'second', 'first_embed', 'second_embed']

In [3]:
df = pd.read_csv("../data/sentspair_embed.csv", encoding="utf-8")

# convert the string representation of the embedding to a list of floats
df['first_embed'] = df['first_embed'].apply(lambda x: [float(i) for i in x[1:-1].split(',')])
df['second_embed'] = df['second_embed'].apply(lambda x: [float(i) for i in x[1:-1].split(',')])
# convert the list of floats to a numpy array
df['first_embed'] = df['first_embed'].apply(lambda x: np.array(x))
df['second_embed'] = df['second_embed'].apply(lambda x: np.array(x))

## 绘图

In [4]:
positions = [(i, i) for i in range(0, 20, 2)] + [(1, 5), (2, 3), (15, 9), (5, 13), (17, 7)]
create_heatmap(df, positions)


getsize is deprecated and will be removed in Pillow 10 (2023-07-01). Use getbbox or getlength instead.

