In [28]:
import json
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from glob import glob
from pathlib import Path
from plotly.subplots import make_subplots

# Shared Plotly look-and-feel used by every figure below: the white theme and the
# qualitative "Vivid" palette (indexed by class label / trace role in later cells).
template, colors = "plotly_white", px.colors.qualitative.Vivid

In [29]:
results = glob("results*.npz")
results = sorted(results, key=lambda p: int(p.split("-")[1].replace(".npz", "")))

In [80]:
# Each results-N.npz checkpoint is written by the sampler roughly as:
#     np.savez("results-{}".format(self.sampler.labeled_mask.sum()),
#              Z=self.sampler.Z,
#              labeled_mask=self.labeled_mask,
#              X=self.train_x)
# NOTE(review): that writer snippet does not save ``Y``, yet ``Y`` is read below —
# confirm the current writer also stores the labels.

# Which checkpoint to visualise, in ascending labeled-count order (see `results`).
RESULT_INDEX = 3
r = results[RESULT_INDEX]

# np.load on an .npz returns an NpzFile that keeps the archive open; reading the
# arrays inside the context manager loads them into memory and then closes it.
with np.load(r) as data:
    Z = data["Z"]
    labeled_mask = data["labeled_mask"]
    X = data["X"]
    Y = data["Y"]

In [81]:
# Heatmap of the loaded data-association matrix Z, saved to plots/z_matrix.pdf.
fig = go.Figure()
fig.add_trace(go.Heatmap(z=Z))
fig.update_xaxes(showline=True, linewidth=1.5, linecolor='Black', mirror=True)
fig.update_yaxes(showline=True, linewidth=1.5, linecolor='Black', mirror=True)

# The only trace is a heatmap (shown with a colorbar, not a legend), so no legend
# settings are needed here.
fig.update_layout(width=500,
                  height=400,
                  font=dict(size=15),
                  margin=dict(t=40, b=40, l=0, r=0),
                  title="Data Association Matrix (Z)",
                  title_x=0.5,
                  template=template)

# write_image does not create parent directories; make sure plots/ exists so the
# cell also works on a fresh checkout.
Path("plots").mkdir(exist_ok=True)
fig.write_image("plots/z_matrix.pdf")
fig.show()

In [82]:
# Two panels, saved to plots/data_association.pdf:
#   left  - every point of X coloured by its label in Y;
#   right - labeled vs unlabeled points, the representatives selected through Z,
#           and one line per column of Z joining a point to the row Z assigns it to.
fig = make_subplots(rows=1, cols=2, subplot_titles=["Data Distribution", "Data Association"])

# Wrap the label -> colour lookup so more classes than palette entries cannot raise
# IndexError (negative labels resolve exactly as plain list indexing did).
class_colors = [colors[int(label) % len(colors)] for label in Y]
fig.add_trace(go.Scatter(x=X[:, 0],
                         y=X[:, 1],
                         showlegend=False,
                         mode="markers",
                         marker=dict(size=10, color=class_colors)),
              row=1, col=1)

unlabeled_idx = np.where(~labeled_mask)[0]

# For each column of Z, the row it is most strongly associated with (its argmax),
# and the distinct rows that were picked at all.
assoc_rows = Z.argmax(0)
labeled_data = np.unique(assoc_rows).tolist()

# The association lines treat both Z's columns and its argmax rows as positions in
# the unlabeled subset (both ends go through ``unlabeled_idx``). The representatives
# must be mapped the same way before indexing X. Indexing X with the raw Z rows put
# their markers on unrelated points, not at the line endpoints.
# NOTE(review): assumes Z is indexed over the unlabeled subset on both axes —
# confirm against the sampler that produces Z.
representative_idx = unlabeled_idx[labeled_data]

fig.add_trace(go.Scatter(x=X[labeled_mask, 0],
                         y=X[labeled_mask, 1],
                         showlegend=True,
                         mode="markers",
                         name="labeled",
                         marker=dict(size=10, color=colors[0])),
              row=1, col=2)

fig.add_trace(go.Scatter(x=X[~labeled_mask, 0],
                         y=X[~labeled_mask, 1],
                         showlegend=True,
                         mode="markers",
                         name="unlabeled",
                         marker=dict(size=10, color=colors[1])),
              row=1, col=2)

fig.add_trace(go.Scatter(x=X[representative_idx, 0],
                         y=X[representative_idx, 1],
                         showlegend=True,
                         mode="markers",
                         name="representatives",
                         marker=dict(size=10, color=colors[2])),
              row=1, col=2)

# Draw every association segment in a single trace, separated by NaN gaps (Plotly
# does not connect across missing values by default). The picture is the same as
# one trace per point, but there are far fewer traces to serialise and render.
src_idx = unlabeled_idx[np.arange(Z.shape[1])]
dst_idx = unlabeled_idx[assoc_rows]
gap = np.full(src_idx.size, np.nan)
seg_x = np.column_stack([X[src_idx, 0], X[dst_idx, 0], gap]).ravel()
seg_y = np.column_stack([X[src_idx, 1], X[dst_idx, 1], gap]).ravel()
fig.add_trace(go.Scatter(x=seg_x,
                         y=seg_y,
                         showlegend=False,
                         mode="lines",
                         line=dict(width=1.5, color=colors[4])),
              row=1, col=2)

fig.update_xaxes(showline=True, linewidth=1.5, linecolor='Black', mirror=True)
fig.update_yaxes(showline=True, linewidth=1.5, linecolor='Black', mirror=True)

fig.update_layout(width=700,
                  height=350,
                  font=dict(size=15),
                  margin=dict(t=40, b=40, l=0, r=0),
                  legend=dict(x=0.995,
                              y=0.01,
                              font=dict(size=11),
                              yanchor="bottom",
                              xanchor="right",
                              bgcolor='rgba(0,0,0,0)',
                              bordercolor="Black",
                              borderwidth=1),
                  title_x=0.5,
                  template=template)

# write_image does not create parent directories.
Path("plots").mkdir(exist_ok=True)
fig.write_image("plots/data_association.pdf")
fig.show()