In [1]:
import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.colors as mcolors
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Global matplotlib/seaborn styling: "ticks" theme, fonts scaled to 80%,
# and the right/top axis spines hidden.
# NOTE(review): these settings only affect matplotlib/seaborn figures; the
# plots below are built with Plotly, so they appear to have no visible effect
# in this notebook — consider removing them.
custom_params = {"axes.spines.right": False, "axes.spines.top": False}
sns.set_theme(style="ticks", font_scale=0.8, rc=custom_params)
# IPython magic: render inline matplotlib figures as high-DPI ("retina") PNGs.
%config InlineBackend.figure_format='retina'


In [2]:
# Load the per-gene table: Ensembl ID, gene symbol, cluster labels from each
# method (leiden-nt, leiden-esm3, mmidas-joint) and 2-D UMAP coordinates for
# the nucleotide ("nt") and protein ("aa") embeddings.
# NOTE(review): the "Unnamed: 0" column in the preview looks like a saved
# DataFrame index; consider read_csv(..., index_col=0) once downstream cells
# no longer rely on label-based row lookup.
MOUSE_CSV = "mouse_df.csv"
df = pd.read_csv(MOUSE_CSV)
display(df.iloc[:3])

Unnamed: 0,ensg_id,gene_symbol,leiden-nt,leiden-esm3,mmidas-joint,nt-umap-0,nt-umap-1,aa-umap-0,aa-umap-1
0,ENSMUSG00000018566.15,Slc2a4,8,3,9,10.830853,0.346162,14.341636,4.835434
1,ENSMUSG00000040888.12,Gfer,14,24,-1,11.152143,-4.010281,7.149468,4.560211
2,ENSMUSG00000078235.4,Fam43b,2,23,1,11.687284,1.119307,6.526957,0.824544


In [3]:
# Per-point marker styles for the UMAP scatter plots: one color per cluster
# label, for each of the three label sets.
msize = 1
NOISE_COLOR = '#808080'  # grey for negative labels (-1: MMIDAS views disagree)


def label_colors(labels, palette_name="Spectral", noise_color=NOISE_COLOR):
    """Map cluster labels to one hex color per point.

    Every distinct non-negative label gets its own color from the seaborn
    palette ``palette_name``, assigned in ascending label order. Negative
    labels (and missing values, since ``NaN >= 0`` is False) get
    ``noise_color``.

    Unlike indexing the palette with the label value itself, this works for
    non-contiguous labels, and the palette is sized by the number of real
    clusters only, so a -1 "noise" label does not consume a palette slot.

    Parameters
    ----------
    labels : iterable of numbers
        One cluster label per point.
    palette_name : str
        Any palette name accepted by ``sns.color_palette``.
    noise_color : str
        Color used for negative / missing labels.

    Returns
    -------
    list of str
        Hex color strings, same length and order as ``labels``.
    """
    labels = pd.Series(labels)
    clusters = np.sort(labels[labels >= 0].unique())
    # max(..., 1): avoid asking seaborn for an empty palette when every label is noise.
    palette = sns.color_palette(palette_name, n_colors=max(len(clusters), 1))
    color_of = {c: mcolors.rgb2hex(rgb) for c, rgb in zip(clusters, palette)}
    return [color_of[l] if l >= 0 else noise_color for l in labels]


mmidas_marker = dict(color=label_colors(df['mmidas-joint']), size=msize)
esm3_marker = dict(color=label_colors(df['leiden-esm3']), size=msize)
nt_marker = dict(color=label_colors(df['leiden-nt']), size=msize)

In [4]:
# Hover label for each gene: its symbol plus its cluster label under each
# method. Plotly hover text is HTML-like, so line breaks must be "<br>"
# (a literal "\n" does not render as a line break in the hover box).
# Columns are zipped positionally, so this does not depend on df's index.
hovertext = [
    f"{gene}<br>leiden-nt: {nt}<br>leiden-esm3: {esm3}<br>mmidas-joint: {mmidas}"
    for gene, nt, esm3, mmidas in zip(
        df['gene_symbol'], df['leiden-nt'], df['leiden-esm3'], df['mmidas-joint']
    )
]

In [5]:


# Side-by-side UMAPs of the same genes: NT (DNA) embeddings on the left, ESM3
# (protein) embeddings on the right. A dropdown recolors both panels by one of
# the label sets so the two views can be compared; the result is saved as HTML.
fig = make_subplots(rows=1, cols=2, subplot_titles=("NT embeddings", "ESM3 embeddings"))

# Panel 1 - NT embedding, initially colored by Leiden-NT labels
fig.add_trace(go.Scatter(
    x=df['nt-umap-0'],
    y=df['nt-umap-1'],
    mode='markers',
    marker=nt_marker,
    hovertext=hovertext,
    hoverinfo='text',
    name='DNA view'
), row=1, col=1)

# Panel 2 - ESM3 embedding, initially colored by Leiden-ESM3 labels
fig.add_trace(go.Scatter(
    x=df['aa-umap-0'],
    y=df['aa-umap-1'],
    mode='markers',
    marker=esm3_marker,
    hovertext=hovertext,
    hoverinfo='text',
    name='Protein view'
), row=1, col=2)

# Dropdown entries: (button label, colors for panel 1, colors for panel 2).
# The first entry reproduces the initial trace colors set above.
color_options = [
    ('Leiden NT and ESM3', nt_marker['color'], esm3_marker['color']),
    ('Leiden NT', nt_marker['color'], nt_marker['color']),
    ('Leiden ESM3', esm3_marker['color'], esm3_marker['color']),
    ('MMIDAS Joint', mmidas_marker['color'], mmidas_marker['color']),
]
buttons = [
    # 'marker.color' takes one entry per trace, in trace order (panel 1, panel 2).
    dict(label=label, method='update', args=[{'marker.color': [left, right]}])
    for label, left, right in color_options
]

fig.update_layout(
    updatemenus=[
        dict(
            buttons=buttons,
            direction="down",
            pad={"r": 10, "t": 10},
            showactive=True,
            x=0.5,
            xanchor="center",
            y=1.3,
            yanchor="top",
        )
    ],
    font=dict(size=10),
    height=325, width=650, showlegend=False,
    margin=dict(l=20, r=20, t=20, b=20),
)

# Subplot titles are layout annotations.
fig.update_annotations(font_size=10)

# Title and de-clutter the axes of BOTH panels (layout.xaxis_title/yaxis_title
# would only reach the first subplot). UMAP coordinates have no meaningful
# scale, so tick labels are hidden.
fig.update_xaxes(title_text='UMAP 0', title_font=dict(size=8), showticklabels=False)
fig.update_yaxes(title_text='UMAP 1', title_font=dict(size=8), showticklabels=False)

fig.write_html('umap_plot.html')


In [6]:
# Number of distinct labels in each clustering (NaN counted, as with
# .unique().size). MMIDAS collects every gene whose labels disagree between the
# two views into label -1, so the real MMIDAS cluster count is one less than
# the number printed for 'mmidas-joint'.
for label_col in ['leiden-nt', 'leiden-esm3', 'mmidas-joint']:
    print(df[label_col].nunique(dropna=False))

20
55
28
