In [1]:
from WSBM import *
from Transformations import *
from Chernoff import *
from Helper import *

In [2]:
import numpy as np
import matplotlib.pyplot as plt

In [3]:
# Number of nodes in every sampled graph.
n = 1000

# Community proportions: each p yields the 2x2 diagonal matrix diag(p, 1 - p)
# (p = 0.1 is unbalanced, p = 0.5 is balanced).
pis = [0.1, 0.5]
PIS = [np.diag([p, 1 - p]) for p in pis]

# Values swept for the `rho` argument of the WSBM constructors.
RHOS = [0.25, 0.5]

# Parameter handed to ThresholdTransform.
τ = 0.1

# One betaWSBM is built per matrix below.
# NOTE(review): entry [a, b] presumably parameterizes edges between communities a and b — confirm in WSBM.
ALPHAS = [
	np.array([[0.1, 1.0], [1.0, 1.0]]),
	np.array([[0.5, 0.5], [0.5, 1.0]]),
]

# One lognormWSBM is built per matrix below.
SIGMAS = [
	np.array([[1, 0.5], [0.5, 1]]),
	np.array([[0.1, 0.5], [0.5, 1]]),
]

In [4]:
# One instance of each transform that is applied to every sampled matrix A;
# the list order fixes the column order of the plot grid.
Id, Opp, Log, Tresh = IdentityTransform(), OppositeTransform(), LogTransform(), ThresholdTransform(τ)

Transforms = [Id, Opp, Log, Tresh]

In [5]:
def get_transformed_graphs_separations_metrics(rho, Pi, ALPHAS, SIGMAS, n_nodes=None, transforms=None, seed=42):
	"""Sample one graph per WSBM parameter matrix and evaluate every transform on it.

	For each matrix in ``ALPHAS`` a ``betaWSBM(n_nodes, rho, Pi, Alpha)`` is built,
	then for each matrix in ``SIGMAS`` a ``lognormWSBM(n_nodes, rho, Pi, Sigma)``.
	Each model is sampled once, right after construction, as ``A, Z = model(seed)``.
	Every transform ``T`` is then applied to ``A`` and wrapped as
	``TWSBInstance(model, T, T(A), Z)``.

	Parameters
	----------
	rho : passed unchanged to both model constructors.
	Pi : passed unchanged to both model constructors (in this notebook an
		element of ``PIS``, i.e. ``np.diag([p, 1 - p])``).
	ALPHAS : iterable of parameter matrices, one ``betaWSBM`` per entry.
	SIGMAS : iterable of parameter matrices, one ``lognormWSBM`` per entry.
	n_nodes : node count passed to the constructors; ``None`` (default) uses
		the notebook-level ``n``.
	transforms : iterable of transform objects (callable on ``A``, with a
		``.name``); ``None`` (default) uses the notebook-level ``Transforms``.
	seed : value passed to every ``model(seed)`` call (default 42). All graphs
		are sampled with this same seed.

	Returns
	-------
	metrics : dict mapping ``(model.name, T.name)`` to its ``TWSBInstance``.
	Graphs_name : list of ``model.name`` in sampling order (all beta models
		first, then all log-normal models).
	"""
	if n_nodes is None:
		n_nodes = n
	if transforms is None:
		transforms = Transforms

	def _sample(model):
		# Sample a freshly built model and bundle it with its draw.
		A, Z = model(seed)
		return A, Z, model

	# Comprehensions evaluate left to right, so each model is still built and
	# sampled before the next one is constructed.
	Graphs = [_sample(betaWSBM(n_nodes, rho, Pi, Alpha)) for Alpha in ALPHAS]
	Graphs += [_sample(lognormWSBM(n_nodes, rho, Pi, Sigma)) for Sigma in SIGMAS]

	metrics = {}
	Graphs_name = []
	for A, Z, model in Graphs:
		for T in transforms:
			# NOTE(review): entries are keyed by model.name, so two models that
			# share a name would silently overwrite each other here.
			metrics[(model.name, T.name)] = TWSBInstance(model, T, T(A), Z)
		Graphs_name.append(model.name)

	return metrics, Graphs_name

In [6]:
# One entry per (rho, Pi) setting, rho varying slowest. Each entry is
# (rho, Pi, (metrics, Graphs_name)) as returned by
# get_transformed_graphs_separations_metrics.
METRICS = [
    (rho, Pi, get_transformed_graphs_separations_metrics(rho, Pi, ALPHAS, SIGMAS))
    for rho in RHOS
    for Pi in PIS
]

In [7]:
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets

def covariance_ellipse(cov, scale=6.0):
	"""Return ``(width, height, angle_deg)`` of the ellipse ``xᵀ cov⁻¹ x = scale``.

	The values are shaped for ``matplotlib.patches.Ellipse``. ``width`` is the full
	axis length along the smallest-eigenvalue eigenvector, and ``angle_deg`` is that
	eigenvector's direction (counter-clockwise from +x). ``height`` is the full
	axis length along the other eigenvector.

	``scale=6`` reproduces the constant previously hard-coded in ``plot``.
	NOTE(review): presumably chosen as ≈ 5.991, the 95% quantile of χ² with
	2 degrees of freedom — confirm.
	Assumes ``cov`` is a symmetric 2x2 matrix.
	"""
	eigenvalues, eigenvectors = np.linalg.eigh(cov)
	# eigh returns eigenvalues in ascending order and eigenvectors as COLUMNS.
	# Ellipse's `width` axis is laid along `angle`, so orient it on column 0
	# (the eigenvector of eigenvalues[0]).
	angle = np.degrees(np.arctan2(eigenvectors[1, 0], eigenvectors[0, 0]))
	# Clip round-off negatives (near-singular cov) so sqrt does not yield NaN.
	width, height = 2 * np.sqrt(scale * np.clip(eigenvalues, 0, None))
	return width, height, angle


def plot(metrics, Graphs_name, transforms=None):
	"""Show an interactive grid of embeddings: one row per graph, one column per transform.

	Parameters
	----------
	metrics : dict keyed by ``(graph name, transform name)``. Each value must expose
		``Z, X, Z_hat, M, Σ, C_true, C_graph, C_embedding, RAND``.
	Graphs_name : graph names, in row order.
	transforms : transform objects (with ``.name``) in column order; ``None``
		(default) uses the notebook-level ``Transforms``.

	Each panel scatters ``X[:, 0]`` against ``X[:, 1]``, coloured by ``Z`` (mode
	'Truth') or ``Z_hat`` (mode 'Prediction'). It draws one ellipse per
	``(mean, cov)`` in ``zip(M, Σ)``. Its title shows RAND as "RI" and
	C_true/C_graph/C_embedding as "CT"/"CG"/"CE". An ipywidgets dropdown switches
	the mode. Returns None.
	"""
	if transforms is None:
		transforms = Transforms

	def switch(mode='Truth'):
		# squeeze=False keeps `axes` 2-D even with a single row or column,
		# so axes[i, j] indexing always works.
		_, axes = plt.subplots(len(Graphs_name), len(transforms), figsize=(15, 15), squeeze=False)
		for i, G_name in enumerate(Graphs_name):
			axes[i, 0].set_ylabel(G_name + "\n")
			for j, T in enumerate(transforms):
				G = metrics[(G_name, T.name)]
				ax = axes[i, j]
				labels = G.Z if mode == 'Truth' else G.Z_hat
				ax.scatter(G.X[:, 0], G.X[:, 1], c=labels, cmap='bwr', marker='.', alpha=0.2)
				ax.set_xticks([])
				ax.set_yticks([])
				for mean, cov in zip(G.M, G.Σ):
					width, height, angle = covariance_ellipse(cov)
					ax.add_patch(plt.matplotlib.patches.Ellipse(
						mean, width, height, angle=angle, edgecolor='k', facecolor='none', linestyle='solid'
					))
				# Only the first row carries the transform name as a column header.
				transform_name = T.name + "\n" if i == 0 else ""
				ax.set_title(
					transform_name
					+ f"RI: {G.RAND:.2f}\n"
					+ f"CT: {G.C_true:.5f} "
					+ f"CG: {G.C_graph:.5f} "
					+ f"CE: {G.C_embedding:.5f}"
				)
		plt.tight_layout()
		plt.show()

	widgets.interact(switch, mode=['Truth', 'Prediction'])

In [None]:
# For each (rho, Pi) setting: print the setting, then show its interactive grid.
for setting in METRICS:
	rho, Pi, (metrics, Graphs_name) = setting
	print(f"rho = {rho}\n Pi = {Pi}")
	plot(metrics, Graphs_name)

rho = 0.25
 Pi = [[0.1 0. ]
 [0.  0.9]]


interactive(children=(Dropdown(description='mode', options=('Truth', 'Prediction'), value='Truth'), Output()),…

rho = 0.25
 Pi = [[0.5 0. ]
 [0.  0.5]]


interactive(children=(Dropdown(description='mode', options=('Truth', 'Prediction'), value='Truth'), Output()),…

rho = 0.5
 Pi = [[0.1 0. ]
 [0.  0.9]]


interactive(children=(Dropdown(description='mode', options=('Truth', 'Prediction'), value='Truth'), Output()),…

rho = 0.5
 Pi = [[0.5 0. ]
 [0.  0.5]]


interactive(children=(Dropdown(description='mode', options=('Truth', 'Prediction'), value='Truth'), Output()),…