In [1]:
import sys
import os
import copy
import math
import pandas as pd
import numpy as np
from sympy import *

In [2]:
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import ipywidgets as widgets
from IPython.display import clear_output
#from google.colab import output
#output.enable_custom_widget_manager()
from sklearn.cluster import AgglomerativeClustering

In [3]:
# TRAINING DATA

# Every directory below train/ (os.walk yields the root itself first, hence the [1:]).
paths = [dirpath for dirpath, _dirnames, _filenames in os.walk('train/')][1:]

In [4]:
# VALIDATION DATA

# Every directory below validation/ (os.walk yields the root itself first, hence the [1:]).
validation_paths = [dirpath for dirpath, _dirnames, _filenames in os.walk('validation/')][1:]

In [5]:
# TESTING DATA

# Every directory below test/ (os.walk yields the root itself first, hence the [1:]).
testing_paths = [dirpath for dirpath, _dirnames, _filenames in os.walk('test/')][1:]

# Cluster Analysis 
In the following cells we perform a cluster analysis to compare the hotspots produced by the MD simulation with those predicted by the CNN.

In [6]:
def get_clusters(temp_grid, temp_cutoff=1.8, distance_threshold=1.8):
  """Collect the hot voxels of a 3D temperature grid and group them into clusters.

  Args:
    temp_grid: 3D array of temperatures (any shape; extra singleton axes are
      squeezed away). Values are compared directly against `temp_cutoff`, so
      they must be in the same units as the cut-off.
    temp_cutoff: voxels with temperature >= this value are kept.
    distance_threshold: single-linkage merge distance, in voxel units, passed
      to AgglomerativeClustering.

  Returns:
    DataFrame with one row per hot voxel and columns 'Temperature', 'X', 'Y',
    'Z' (1-based voxel indices) and 'Cluster ID'. Rows are in row-major
    (X, then Y, then Z) order.

  Raises:
    ValueError: if the grid is not three-dimensional after squeezing.
  """
  grid = np.asarray(temp_grid)
  if grid.ndim > 3:
    grid = np.squeeze(grid)  # tolerate e.g. a trailing channel axis of size 1
  if grid.ndim != 3:
    raise ValueError('temp_grid must be a 3D array, got shape %s' % (np.shape(temp_grid),))

  # argwhere scans in row-major order, i.e. the same order as nested i/j/k loops.
  hot = np.argwhere(grid >= temp_cutoff)

  df = pd.DataFrame({'Temperature': grid[hot[:, 0], hot[:, 1], hot[:, 2]]})
  df['X'] = hot[:, 0] + 1
  df['Y'] = hot[:, 1] + 1
  df['Z'] = hot[:, 2] + 1

  # The clusterer needs at least two samples; zero or one hot voxel is
  # trivially zero or one cluster.
  if len(df) < 2:
    df['Cluster ID'] = np.zeros(len(df), dtype=int)
    return df

  pos = df[['X','Y','Z']].values
  agglo = AgglomerativeClustering(n_clusters=None, distance_threshold=distance_threshold, linkage='single').fit(pos)
  df['Cluster ID'] = agglo.labels_

  return df

In [7]:
def analyze_clusters(df, voxel_volume=4, temp_scale=1000):
  """Summarise each cluster: volume, temperature statistics, centre of mass, inertia.

  Args:
    df: per-voxel frame with columns 'Temperature', 'X', 'Y', 'Z' and
      'Cluster ID' (non-negative integer labels).
    voxel_volume: factor applied to the voxel count to obtain 'volume'
      (units are not stated in this notebook -- TODO confirm).
    temp_scale: factor applied to 'Temperature' before taking mean/std
      (the result columns are labelled in K).

  Returns:
    DataFrame with one row per cluster id and columns 'cluster id', 'volume',
    'mean T (K)', 'std T (K)', 'R center of mass' (length-3 array) and
    'moment of inertia' (diagonalised 3x3 sympy Matrix of the unit-mass
    inertia tensor about the centre of mass). An empty input yields an empty
    frame with the same columns.
  """
  vols = []
  mean_temps = []
  std_temps = []
  cms = []
  mis = []
  ids = []

  # Build the membership set once; an empty frame simply produces no clusters
  # instead of failing on int(NaN).
  present_ids = set(df['Cluster ID'].tolist())
  max_id = int(max(present_ids)) if present_ids else -1

  for i in range(max_id+1):
    if i not in present_ids:
      continue
    tmp = df[df['Cluster ID'] == i]
    m = len(tmp)
    scaled_temp = tmp['Temperature']*temp_scale

    cm = np.array([tmp['X'].sum()/m, tmp['Y'].sum()/m, tmp['Z'].sum()/m])

    # Offsets of every voxel from the centre of mass; each voxel has unit mass.
    d = tmp[['X','Y','Z']].values - cm
    sq = d**2
    # Rounded to 3 decimals before the symbolic diagonalisation, as before.
    i_xx = np.around(np.sum(sq[:, 1] + sq[:, 2]), decimals=3)
    i_yy = np.around(np.sum(sq[:, 0] + sq[:, 2]), decimals=3)
    i_zz = np.around(np.sum(sq[:, 0] + sq[:, 1]), decimals=3)
    i_xy = np.around(np.sum(-d[:, 0]*d[:, 1]), decimals=3)
    i_xz = np.around(np.sum(-d[:, 0]*d[:, 2]), decimals=3)
    i_yz = np.around(np.sum(-d[:, 1]*d[:, 2]), decimals=3)

    i_tensor = Matrix([[i_xx,i_xy,i_xz],[i_xy,i_yy,i_yz],[i_xz,i_yz,i_zz]])
    i_diag = i_tensor.diagonalize()[1]  # [1] is the diagonal matrix D of (P, D)

    vols.append(m*voxel_volume)
    mean_temps.append(np.mean(scaled_temp))
    std_temps.append(np.std(scaled_temp))
    cms.append(cm)
    mis.append(i_diag)
    ids.append(i)

  df_clust = pd.DataFrame(ids, columns=['cluster id'])
  df_clust['volume'] = vols
  df_clust['mean T (K)'] = mean_temps
  df_clust['std T (K)'] = std_temps
  df_clust['R center of mass'] = cms
  df_clust['moment of inertia'] = mis
  return df_clust

In [8]:
def plot_clusters(md,pred):
  """Draw MD and CNN cluster voxels side by side as 3D scatters coloured by cluster id.

  Args:
    md: per-voxel frame for the MD simulation ('X', 'Y', 'Z', 'Cluster ID').
    pred: per-voxel frame for the CNN prediction, same columns.

  Returns:
    The plotly figure holding both 3D panels.
  """
  figure = make_subplots(
      rows=1, cols=2,
      specs=[[{"type": "scatter3d"}, {"type": "scatter3d"}]],
      subplot_titles=["MD Clusters", "CNN Clusters"],
      horizontal_spacing=0.1, vertical_spacing=0.1)
  figure.update_layout(autosize=False, width=1200, height=800)

  # Column 1 shows the MD voxels, column 2 the CNN voxels; both use the same styling.
  for column, frame in enumerate((md, pred), start=1):
    voxels = go.Scatter3d(
        x=frame['X'], y=frame['Y'], z=frame['Z'],
        hovertemplate='Cluster ID: %{marker.color:.2f}<extra></extra>',
        mode='markers',
        marker=dict(symbol='square', colorscale='rainbow', color=frame['Cluster ID']),
        showlegend=False)
    figure.add_trace(voxels, row=1, col=column)

  return figure

## Example clusters

In [36]:
# Directory of the example case used by the cells below; the MD output is read
# from here and the CNN prediction from the same path under results/.
ex_path = 'test/smallPBX_2_mirror_btm'

In [37]:
# Load both temperature grids for the example case. Only the MD output is
# divided by 1000 -- presumably the prediction is already stored in those
# scaled units (TODO confirm); get_clusters compares both against one cut-off.
md_temp = np.load(ex_path+'/output.npy')/1000
pred_temp = np.load('results/'+ex_path+'/prediction.npy')

# Per-voxel frames of hot voxels with their cluster labels.
md_1 = get_clusters(md_temp)
pred_1 = get_clusters(pred_temp)

# Per-cluster summary tables (volume, temperature statistics, inertia).
md_clust = analyze_clusters(md_1)
pred_clust = analyze_clusters(pred_1)

In [38]:
# Side-by-side 3D view of the MD and CNN clusters for the example case.
plot_clusters(md_1,pred_1)

In [39]:
def visualize(md_1, pred_1, md_clust, pred_clust):
  """Three-panel overview: MD clusters, CNN clusters, and temperature vs volume.

  Args:
    md_1: per-voxel MD frame ('X', 'Y', 'Z', 'Cluster ID').
    pred_1: per-voxel CNN frame, same columns.
    md_clust: per-cluster MD summary ('cluster id', 'volume', 'mean T (K)',
      'std T (K)').
    pred_clust: per-cluster CNN summary, same columns.

  Returns:
    The plotly figure with two 3D scatter panels and one 2D scatter panel.
  """
  fig = make_subplots(rows=1, cols=3, specs=[[{"type": "scatter3d"},{"type": "scatter3d"},{"type": "scatter"}]],
                      subplot_titles=["MD Clusters","CNN Clusters","Temp vs Volume"], horizontal_spacing = 0.1, vertical_spacing = 0.1)
  fig.update_layout(autosize=False, width=1800, height=600)

  def _voxel_trace(frame, scene):
    # 3D voxel scatter coloured (and hover-labelled) by cluster id.
    return go.Scatter3d(x = frame['X'], y = frame['Y'], z = frame['Z'], text=frame['Cluster ID'],
                        mode='markers',
                        marker=dict(size=8, line=dict(width=0), symbol='square', colorscale='rainbow', color = frame['Cluster ID']),
                        showlegend=False, scene=scene)

  def _summary_trace(clust, symbol, color, name):
    # Mean temperature vs volume with the per-cluster std as error bar. The
    # hover text comes from the same table as the points, so every marker is
    # labelled with its own cluster id.
    return go.Scatter(x=clust['volume'], y=clust['mean T (K)'],
                      error_y=dict(type='data', array=clust['std T (K)'], visible=True),
                      text=clust['cluster id'],
                      mode='markers', marker=dict(symbol=symbol, color = color, size=18), name=name)

  ## MD SIMULATION CLUSTERS ##
  fig.add_trace(_voxel_trace(md_1, 'scene1'), row=1, col=1)

  ## CNN PREDICTED CLUSTERS ##
  fig.add_trace(_voxel_trace(pred_1, 'scene2'), row=1, col=2)

  ### Hotspot Temperature vs Volume ###
  fig.add_trace(_summary_trace(md_clust, 'square', 'green', 'MD'), row=1, col=3)
  fig.add_trace(_summary_trace(pred_clust, 'circle', 'red', 'CNN'), row=1, col=3)
  fig.update_xaxes(title="Volume")
  fig.update_yaxes(title="Temperature (K)")

  return fig

In [40]:
# Overview figure for the example case: both cluster views plus temperature vs volume.
visualize(md_1,pred_1,md_clust,pred_clust)

## Overlap analysis
Calculate the overlap between the clusters found in the MD simulation and the clusters found in the CNN prediction.

In [41]:
def get_overlap(df1, df2, ratio_threshold=0.333):
  """Match MD clusters to CNN clusters by shared voxels and assign common ids.

  For every (MD cluster, CNN cluster) pair the number of shared (X, Y, Z)
  voxels is counted and expressed as a fraction of each cluster. Pairs that
  overlap are then given a common 'New id':
    * one-to-one matches keep the MD cluster id;
    * an MD cluster overlapping several CNN clusters keeps its MD id when the
      summed MD ratio exceeds `ratio_threshold`, otherwise gets no id;
    * a CNN cluster overlapping several MD clusters inherits the id already
      given to it through the previous rule, or receives a fresh id above all
      ids in use when its summed CNN ratio exceeds `ratio_threshold`.

  Args:
    df1: per-voxel MD frame with columns 'X', 'Y', 'Z', 'Cluster ID'.
    df2: per-voxel CNN frame with the same columns.
    ratio_threshold: minimum combined overlap ratio for a one-to-many match.

  Returns:
    (md, pred, df_over): copies of the two inputs with an object-dtype
    'New ID' column (None where a voxel's cluster has no match), and the
    table of all pairwise overlaps with columns 'MD id', 'Pred id',
    'MD ratio', 'CNN ratio'. The inputs are not modified.
  """
  md = df1.copy()
  pred = df2.copy()

  def _voxels_by_cluster(frame):
    # cluster id -> list of (x, y, z) tuples, one per row.
    groups = {}
    for cid, x, y, z in zip(frame['Cluster ID'], frame['X'], frame['Y'], frame['Z']):
      groups.setdefault(cid, []).append((x, y, z))
    return groups

  md_voxels = _voxels_by_cluster(md)
  pred_voxels = _voxels_by_cluster(pred)
  pred_sets = {cid: set(vox) for cid, vox in pred_voxels.items()}

  # Pairwise overlap using set lookups instead of list scans.
  overlap_data = []
  for i in sorted(md_voxels):
    md_list = md_voxels[i]
    md_set = set(md_list)
    for j in sorted(pred_voxels):
      pred_list = pred_voxels[j]
      n_md_in_pred = sum(1 for v in md_list if v in pred_sets[j])
      n_pred_in_md = sum(1 for v in pred_list if v in md_set)
      cnn_ratio = n_md_in_pred/len(pred_list)
      md_ratio = n_pred_in_md/len(md_list)
      overlap_data.append([i,j,md_ratio,cnn_ratio])
  df_over = pd.DataFrame(overlap_data,columns=['MD id','Pred id','MD ratio','CNN ratio'])

  # Keep overlapping pairs only, then split them into one-to-one matches and
  # the two kinds of one-to-many matches.
  df_over_ = df_over[df_over['MD ratio']>0.0].reset_index(drop=True)
  df_md_over = df_over_[df_over_.duplicated(subset=['MD id'],keep=False)].copy()
  df_cnn_over = df_over_[df_over_.duplicated(subset=['Pred id'],keep=False)].copy()
  # Sorted so that rows of one Pred id are contiguous; the id list built below
  # is assigned positionally and relies on that grouping.
  df_cnn_over = df_cnn_over.sort_values(by='Pred id')
  multi_index = list(set(df_md_over.index) | set(df_cnn_over.index))
  df_single_over = df_over_.drop(multi_index)
  df_single_over['New id'] = df_single_over['MD id']

  # One MD cluster spread over several CNN clusters. Rows of df_md_over are
  # already grouped by MD id because df_over was built MD-major.
  md_new_ids = []
  for md_id in df_md_over['MD id'].unique():
    group = df_md_over[df_md_over['MD id']==md_id]
    combined_ratio = sum(group['MD ratio'].tolist())
    new_id = md_id if combined_ratio > ratio_threshold else None
    md_new_ids.extend([new_id]*len(group))
  df_md_over['New id'] = md_new_ids

  # Fresh ids for merged CNN clusters start above every id already in use.
  # max() is NaN for an empty / all-missing column; with nothing to build on,
  # numbering starts at 0 instead of propagating NaN.
  used_maxima = [v for v in (df_md_over['New id'].max(), df_single_over['New id'].max())
                 if not pd.isna(v)]
  base_id = max(used_maxima) if used_maxima else -1

  # One CNN cluster covering several MD clusters.
  cnn_new_ids = []
  for n, pred_id in enumerate(df_cnn_over['Pred id'].unique()):
    group = df_cnn_over[df_cnn_over['Pred id']==pred_id]
    combined_ratio = sum(group['CNN ratio'].tolist())
    inherited = df_md_over.loc[df_md_over['Pred id']==pred_id, 'New id']
    if len(inherited) > 0:
      new_id = inherited.iloc[0]
    elif combined_ratio > ratio_threshold:
      new_id = base_id+n+1
    else:
      new_id = None
    cnn_new_ids.extend([new_id]*len(group))
  df_cnn_over['New id'] = cnn_new_ids

  df_over_new = pd.concat([df_single_over,df_md_over,df_cnn_over]).drop_duplicates(subset=['MD id','Pred id']).reset_index(drop=True)

  def _first_new_id(key_column):
    # original cluster id -> 'New id' of its first row in df_over_new.
    lookup = {}
    for key, new_id in zip(df_over_new[key_column], df_over_new['New id']):
      lookup.setdefault(key, new_id)
    return lookup

  # Build each column in one assignment (no chained indexing). dtype=object
  # keeps None for unmatched voxels instead of coercing the column to float NaN.
  md_lookup = _first_new_id('MD id')
  md['New ID'] = pd.Series([md_lookup.get(cid) for cid in md['Cluster ID']],
                           index=md.index, dtype=object)
  pred_lookup = _first_new_id('Pred id')
  pred['New ID'] = pd.Series([pred_lookup.get(cid) for cid in pred['Cluster ID']],
                             index=pred.index, dtype=object)

  return md, pred, df_over

In [42]:
def analyze_clusters_overlap(df, voxel_volume=4, temp_scale=1000):
  """Summarise clusters grouped by their matched 'New ID'.

  Rows containing any missing value (in particular voxels without a
  'New ID') are dropped first.

  Args:
    df: per-voxel frame with columns 'Temperature', 'X', 'Y', 'Z' and
      'New ID' (non-negative integer ids, None/NaN where unmatched).
    voxel_volume: factor applied to the voxel count to obtain 'volume'
      (units are not stated in this notebook -- TODO confirm).
    temp_scale: factor applied to 'Temperature' before taking mean/std
      (the result columns are labelled in K).

  Returns:
    DataFrame with one row per id and columns 'cluster id', 'volume',
    'mean T (K)', 'std T (K)', 'R center of mass', 'moment of inertia'
    (diagonalised 3x3 sympy Matrix), and the 'min' / 'max' / 'mean principal
    MI' taken from that diagonal. If no voxel has an id the frame is empty
    but keeps all columns.
  """
  vols = []
  mean_temps = []
  std_temps = []
  cms = []
  mis = []
  ids = []
  min_mi = []
  max_mi = []
  mean_mi = []

  filt = df.dropna()

  # Build the membership set once; with no matched voxels there is nothing to
  # iterate over (previously this failed on int(NaN)).
  present_ids = set(filt['New ID'].tolist())
  max_id = int(max(present_ids)) if present_ids else -1

  for i in range(max_id+1):
    if i not in present_ids:
      continue
    tmp = filt[filt['New ID'] == i]
    m = len(tmp)
    scaled_temp = tmp['Temperature']*temp_scale

    cm = np.array([tmp['X'].sum()/m, tmp['Y'].sum()/m, tmp['Z'].sum()/m])

    # Offsets of every voxel from the centre of mass; each voxel has unit mass.
    d = tmp[['X','Y','Z']].values - cm
    sq = d**2
    # Rounded to 3 decimals before the symbolic diagonalisation, as before.
    i_xx = np.around(np.sum(sq[:, 1] + sq[:, 2]), decimals=3)
    i_yy = np.around(np.sum(sq[:, 0] + sq[:, 2]), decimals=3)
    i_zz = np.around(np.sum(sq[:, 0] + sq[:, 1]), decimals=3)
    i_xy = np.around(np.sum(-d[:, 0]*d[:, 1]), decimals=3)
    i_xz = np.around(np.sum(-d[:, 0]*d[:, 2]), decimals=3)
    i_yz = np.around(np.sum(-d[:, 1]*d[:, 2]), decimals=3)

    i_tensor = Matrix([[i_xx,i_xy,i_xz],[i_xy,i_yy,i_yz],[i_xz,i_yz,i_zz]])
    i_diag = i_tensor.diagonalize()[1]  # [1] is the diagonal matrix D of (P, D)
    # Flat indices 0, 4, 8 are the diagonal entries of the 3x3 matrix.
    principal = [i_diag[0], i_diag[4], i_diag[8]]

    vols.append(m*voxel_volume)
    mean_temps.append(np.mean(scaled_temp))
    std_temps.append(np.std(scaled_temp))
    cms.append(cm)
    mis.append(i_diag)
    min_mi.append(float(min(*principal)))
    max_mi.append(float(max(*principal)))
    mean_mi.append(float(np.mean(principal)))
    ids.append(i)

  df_clust = pd.DataFrame(ids, columns=['cluster id'])
  df_clust['volume'] = vols
  df_clust['mean T (K)'] = mean_temps
  df_clust['std T (K)'] = std_temps
  df_clust['R center of mass'] = cms
  df_clust['moment of inertia'] = mis
  df_clust['min principal MI'] = min_mi
  df_clust['max principal MI'] = max_mi
  df_clust['mean principal MI'] = mean_mi
  return df_clust

In [43]:
def visualize_overlap(md_corr, pred_corr, md_clust, pred_clust):
  """Four-panel comparison of matched MD and CNN clusters.

  Top row: 3D voxel views (MD left, CNN right) coloured by the shared
  'New ID'; voxels without a 'New ID' are drawn in black. Bottom row: mean
  temperature vs volume, and mean temperature vs the min / max / mean
  principal moment of inertia.

  Args:
    md_corr: per-voxel MD frame with 'X', 'Y', 'Z', 'Cluster ID', 'New ID'.
    pred_corr: per-voxel CNN frame with the same columns.
    md_clust: per-cluster MD summary ('cluster id', 'volume', 'mean T (K)',
      'std T (K)', 'min/max/mean principal MI').
    pred_clust: per-cluster CNN summary with the same columns.

  Returns:
    The plotly figure.
  """
  fig = make_subplots(rows=2, cols=2, specs=[[{"type": "scatter3d"},{"type": "scatter3d"}],
                                            [{"type": "scatter"},{"type": "scatter"}]],
                      subplot_titles=["MD Clusters","CNN Clusters","Temp vs Volume","Temp vs Principal"], horizontal_spacing = 0.1, vertical_spacing = 0.1)
  fig.update_layout(autosize=False, width=1200, height=800)

  def _matched_trace(frame, scene):
    # Voxels that received a shared id, coloured by that id.
    matched = frame.dropna()
    return go.Scatter3d(x = matched['X'], y = matched['Y'], z = matched['Z'], text=matched['New ID'],
                        mode='markers',
                        marker=dict(size=8, line=dict(width=0), symbol='square', colorscale='rainbow', color = matched['New ID']),
                        showlegend=False, scene=scene)

  def _unmatched_trace(frame, scene):
    # Voxels without a shared id, in black, labelled by their original cluster id.
    unmatched = frame[frame['New ID'].isnull()]
    return go.Scatter3d(x = unmatched['X'], y = unmatched['Y'], z = unmatched['Z'], text=unmatched['Cluster ID'],
                        mode='markers',
                        marker=dict(size=8, line=dict(width=0), symbol='square', color = 'black'),
                        showlegend=False, scene=scene)

  def _temperature_trace(clust, x_column, marker, **extra):
    # Mean temperature (with std error bars) against one per-cluster quantity;
    # hover text always comes from the same table as the points.
    return go.Scatter(x=clust[x_column], y=clust['mean T (K)'],
                      error_y=dict(type='data', array=clust['std T (K)'], visible=True),
                      text=clust['cluster id'], mode='markers', marker=marker, **extra)

  ## MD SIMULATION CLUSTERS ##
  fig.add_trace(_matched_trace(md_corr, 'scene1'), row=1, col=1)
  fig.add_trace(_unmatched_trace(md_corr, 'scene1'), row=1, col=1)

  ## CNN PREDICTED CLUSTERS ##
  fig.add_trace(_matched_trace(pred_corr, 'scene2'), row=1, col=2)
  fig.add_trace(_unmatched_trace(pred_corr, 'scene2'), row=1, col=2)

  ### Hotspot Temperature vs Volume ###
  # Each trace is coloured and labelled by its OWN cluster ids; the CNN trace
  # previously borrowed md_clust's ids, which mislabels points and breaks when
  # the two tables differ in length.
  for clust, symbol, name in ((md_clust, 'square', 'MD'), (pred_clust, 'circle', 'CNN')):
    fig.add_trace(
        _temperature_trace(clust, 'volume',
                           dict(symbol=symbol, colorscale='rainbow', color = clust['cluster id'], size=18),
                           name=name),
        row=2, col=1)
  fig.update_xaxes(title="Volume", row=2, col=1)
  fig.update_yaxes(title="Temperature (K)", row=2, col=1)

  ### Hotspot Moment Inertia vs Temperature ###
  # min -> green, max -> red, mean -> blue; squares are MD, circles are CNN.
  for stat, color in (('min', 'green'), ('max', 'red'), ('mean', 'blue')):
    for clust, symbol, source in ((md_clust, 'square', 'MD'), (pred_clust, 'circle', 'CNN')):
      fig.add_trace(
          _temperature_trace(clust, stat+' principal MI',
                             dict(symbol=symbol, color=color, size=14),
                             showlegend=False, name=source+' '+stat+' principal axis'),
          row=2, col=2)
  fig.update_xaxes(title="Principal Moment of Inertia", row=2, col=2)
  fig.update_yaxes(title="Temperature (K)", row=2, col=2)

  return fig

## Example overlap

In [44]:
# Reload the example grids and rebuild the per-voxel / per-cluster tables
# (same steps as the earlier "Example clusters" cell, repeated so this cell
# only depends on ex_path). Only the MD output is divided by 1000.
md_temp = np.load(ex_path+'/output.npy')/1000
pred_temp = np.load('results/'+ex_path+'/prediction.npy')

md_1 = get_clusters(md_temp)
pred_1 = get_clusters(pred_temp)

md_clust = analyze_clusters(md_1)
pred_clust = analyze_clusters(pred_1)

# Match MD and CNN clusters; the *_over frames carry the shared 'New ID'
# column and df_1_over holds every pairwise overlap ratio.
md_1_over, pred_1_over, df_1_over = get_overlap(md_1,pred_1)

# Per-cluster summaries keyed by the shared id instead of the raw cluster id.
md_clust_over = analyze_clusters_overlap(md_1_over)
pred_clust_over = analyze_clusters_overlap(pred_1_over)

test_fig = visualize_overlap(md_1_over,pred_1_over,md_clust_over,pred_clust_over)



A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/i

In [45]:
# Display the overlap figure built in the previous cell.
test_fig

In [46]:
# Attach to every MD voxel the volume of the matched cluster it belongs to.
# The summary 'volume' is multiplied by 3.553 (conversion factor whose units
# are not stated in this notebook -- TODO confirm).
md_volume_by_id = dict(zip(md_clust_over['cluster id'], md_clust_over['volume']*3.553))

# dict.get returns None for voxels whose 'New ID' is missing (None or NaN) or
# has no summary row, so unmatched voxels end up with an empty volume.
md_1_over['volume'] = [md_volume_by_id.get(cluster_id) for cluster_id in md_1_over['New ID']]
md_1_over

Unnamed: 0,Temperature,X,Y,Z,Cluster ID,New ID,volume
0,2.09664,1,1,10,13,13,284.24
1,2.60155,1,2,10,13,13,284.24
2,2.31233,1,2,11,13,13,284.24
3,1.90297,1,3,10,13,13,284.24
4,2.61665,1,3,11,13,13,284.24
...,...,...,...,...,...,...,...
319,2.09986,16,16,8,11,11,142.12
320,2.01144,16,24,18,22,22,71.06
321,2.09275,16,24,19,22,22,71.06
322,1.91828,16,25,18,22,22,71.06


In [47]:
vols = []
for i in range(len(pred_1_over)):
    clus_id = pred_1_over['New ID'][i]
    if clus_id is None:
        vols.append(None)
    else:
        tmp = pred_clust_over[pred_clust_over['cluster id']==clus_id]
        vols.append(tmp['volume'].iloc[0]*3.553)

pred_1_over['volume'] = vols
pred_1_over

Unnamed: 0,Temperature,X,Y,Z,Cluster ID,New ID,volume
0,1.820790,1,2,10,13,13,156.332
1,2.026889,1,2,11,13,13,156.332
2,2.302288,1,3,11,13,13,156.332
3,2.037051,1,9,12,15,,
4,2.100370,1,14,8,0,24,213.180
...,...,...,...,...,...,...,...
262,2.135527,16,24,19,4,22,113.696
263,1.839056,16,24,20,4,22,113.696
264,2.225986,16,25,18,4,22,113.696
265,2.424878,16,25,19,4,22,113.696


In [48]:
# Keep only voxels that belong to a matched cluster (rows with any missing
# value, e.g. no 'New ID' / volume, are dropped).
md_ = md_1_over.dropna()
pred_ = pred_1_over.dropna()

# Per-voxel plotting inputs: cube root of the cluster volume (plotted below as
# the characteristic length), temperature scaled by 1000, and the shared id.
# The list() conversions matter: they let column_stack build a plain numeric
# array from the object-dtype 'New ID' column.
pred_x = list(np.cbrt(pred_['volume']))
pred_y = list(pred_['Temperature']*1000)
pred_id = list(pred_['New ID'])

md_x = list(np.cbrt(md_['volume']))
md_y = list(md_['Temperature']*1000)
md_id = list(md_['New ID'])

# Resulting frames have integer column labels: 0 = length, 1 = temperature, 2 = id.
pred_arr = np.column_stack([pred_x, pred_y, pred_id])
pred_df = pd.DataFrame(pred_arr)

md_arr = np.column_stack([md_x, md_y, md_id])
md_df = pd.DataFrame(md_arr)

# One group per shared cluster id (column 2).
md_groups = md_df.groupby(by=2)
pred_groups = pred_df.groupby(by=2)

In [49]:
# Line colour per shared cluster id (list index == id). Some entries are empty
# strings -- presumably ids that never needed a colour here; those, and any id
# beyond the list, fall back to the last entry (black).
colors = ['#aa007f','#7300ad','#00007f','','#5500ff','#0000ff','#0055ff','#5555ff','','','#52cbff','#59ffde','#50ffbf',
          '#66ff33','','#aaff7f','','#ffff00','#ffe205','','#ffaa00','#ff8000','#ff5500','#ff4102','#ff0000','#000000']
# Clusters are drawn only when both the MD and the CNN characteristic length
# (column 0) exceed this value.
min_char_length = 3.5

fig = go.Figure()
fig.update_layout(autosize=False, width=900, height=600) 

# Pair MD and CNN groups by cluster id rather than by position, so a cluster
# present in only one of the two frames cannot shift the pairing.
md_by_id = {int(key): group for key, group in md_groups}
pred_by_id = {int(key): group for key, group in pred_groups}
shown_ids = [
    cluster_id for cluster_id in sorted(set(md_by_id) & set(pred_by_id))
    if md_by_id[cluster_id][0].iloc[0] > min_char_length
    and pred_by_id[cluster_id][0].iloc[0] > min_char_length
]

def _violin_color(cluster_id):
    # Colour assigned to this id, or black when none is defined for it.
    if 0 <= cluster_id < len(colors) and colors[cluster_id]:
        return colors[cluster_id]
    return colors[-1]

# MD distributions form the left ('negative') half of each violin ...
for cluster_id in shown_ids:
    group = md_by_id[cluster_id]
    fig.add_trace(go.Violin(x=group[0],
                    y=group[1],
                    legendgroup='No', scalegroup='No', width=0.33,
                    side='negative',
                    line_color=_violin_color(cluster_id), showlegend=False, text=cluster_id))
# ... and CNN distributions the right ('positive') half, added after all MD traces.
for cluster_id in shown_ids:
    group = pred_by_id[cluster_id]
    fig.add_trace(go.Violin(x=group[0],
                    y=group[1],
                    legendgroup='Yes', scalegroup='Yes', width=0.33,
                    side='positive',
                    line_color=_violin_color(cluster_id), showlegend=False, hovertext=cluster_id))
fig.update_traces(meanline_visible=True, spanmode='hard', points=False)
fig.update_xaxes(title="Characteristic Length (nm)",title_font = {'size':24},tickfont = {'size':24},color='black')
fig.update_yaxes(title="Temperature (K)",title_font = {'size':24},tickfont = {'size':24},color='black')
fig.update_xaxes(showline=True, linewidth=2, linecolor='black', mirror=True) #, gridcolor='grey'
fig.update_yaxes(showline=True, linewidth=2, linecolor='black', mirror=True) #, gridcolor='grey'
fig.update_layout(legend=dict(yanchor="top",y=0.99,xanchor="left",x=0.01,font=dict({'size':24})),plot_bgcolor='rgba(0,0,0,0)', template='simple_white')
fig.show()