### Library loading

In [1]:
# system imports
import sys, os, json, pickle
from pathlib import Path

# utility imports
from tqdm import tqdm
from itertools import product

# data science imports
import numpy as np, pandas as pd, scipy.sparse as sp, scipy.stats as stats

# plotting imports
import matplotlib.pyplot as plt, matplotlib.colors as mcolors, plotly.express as px, plotly.graph_objects as go
from matplotlib.colors import LinearSegmentedColormap

# Put the parent of the current working directory at the front of the import
# search path so that the `src` package imported below can be resolved.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if sys.path.count(project_root) == 0:
	sys.path.insert(0, project_root)

# imports from src
from src.twsbm import TWSBM
from src.tdcsbm import TDCSBM
from src.metrics import *
from src.transformations import *
from src.models.beta import *
from src.models.lognorm import *
from src.models.wdcsbm import *

### Transformations selection

In [2]:
# Parameter grids shared by the families below. The "_05" grids are the coarse
# ones and the "_10" grids the fine ones (5/10 quantiles, 7/13 exponents).
_Q_COARSE = np.linspace(0, 0.5, 6)[1:]
_Q_FINE   = np.linspace(0, 0.5, 11)[1:]
_G_COARSE = np.linspace(0.5, 2, 7)
_G_FINE   = np.linspace(0.5, 2, 13)

# Named families of transformations; the insertion order of this dict is relied
# upon by FAMILIES_IDX and by the family indices f0, f1, ... used downstream.
FAMILIES = {
	# hand-picked mix of the four transformation types
	'classical': (
		[QuantileTransform(q=q) for q in (0.1, 0.25)]
		+ [QuantileThresholding(q=q) for q in (0.1, 0.25)]
		+ [PowerTransform(γ=γ) for γ in (0.5, 1, 2)]       # γ=1 is the identity
		+ [LogPowerTransform(γ=γ) for γ in (0.5, 1, 2)]    # γ=1 is the plain log
	),
	'qtl_05': [QuantileTransform(q) for q in _Q_COARSE],
	'qtl_10': [QuantileTransform(q) for q in _Q_FINE],
	'qtr_05': [QuantileThresholding(q) for q in _Q_COARSE],
	'qtr_10': [QuantileThresholding(q) for q in _Q_FINE],
	'log_05': [LogPowerTransform(γ) for γ in _G_COARSE],
	'log_10': [LogPowerTransform(γ) for γ in _G_FINE],
	'pow_05': [PowerTransform(p) for p in _G_COARSE],
	'pow_10': [PowerTransform(p) for p in _G_FINE],
}

# Flat list of every transformation that is actually fitted: the four fine grids.
TRANSFORMS = [
	t
	for family in ('qtl_10', 'qtr_10', 'log_10', 'pow_10')
	for t in FAMILIES[family]
]
TRANSFORMS_MAP_COLOR = dict((t.id, t.color) for t in TRANSFORMS)

# For each family, the positions of its members inside TRANSFORMS.
# NOTE(review): list.index compares with ==, so this presumably relies on the
# transformation classes defining value-based equality (members of 'classical'
# and of the "_05" families are distinct instances) — confirm in src.transformations.
FAMILIES_IDX = {}
for family, members in FAMILIES.items():
	FAMILIES_IDX[family] = [TRANSFORMS.index(t) for t in members]

# One colour per family: the coarse family takes the colour of the first member
# of the matching fine grid, the fine family the colour of its last member.
FAMILIES_MAP_COLOR = {
	'classical': 'brown',
	'qtl_05': FAMILIES['qtl_10'][0].color,
	'qtl_10': FAMILIES['qtl_10'][-1].color,
	'qtr_05': FAMILIES['qtr_10'][0].color,
	'qtr_10': FAMILIES['qtr_10'][-1].color,
	'log_05': FAMILIES['log_10'][0].color,
	'log_10': FAMILIES['log_10'][-1].color,
	'pow_05': FAMILIES['pow_10'][0].color,
	'pow_10': FAMILIES['pow_10'][-1].color,
}

### Utility

In [3]:
# Compute metrics for a list of graphs, transformations and families
def get_metrics(graphs, 
				transforms, 
				families,
				mode = 'TWSBM',
				q_outliers = 0.01,
				large_graph = False):
	"""Fit every transformation on every graph and aggregate the ARI scores.

	Parameters
	----------
	graphs : iterable of (A, Z, K)
		Adjacency matrix, reference labels and number of communities per graph.
		Must contain at least one graph.
	transforms : iterable of callables
		Each is applied as ``t(A)`` and must expose an ``id`` attribute; the
		identity transform with id ``'P-1.00'`` must be present (it provides
		the baseline ARI).
	families : iterable of iterables of int
		Each family lists positions into ``transforms``.
	mode : {'TWSBM', 'TDCSBM'}
		Which model class is fitted.
	q_outliers : float
		Forwarded to ``TWSBM`` for the per-transform fits (unused in 'TDCSBM' mode).
	large_graph : bool
		If False the progress bar runs over graphs; if True it runs over the
		transforms and the families instead.

	Returns
	-------
	df_transforms : pandas.Series
		Mean ARI over graphs, keyed ``'t{j}_ARI'``.
	df_families : pandas.Series
		Means over graphs, keyed ``'f{i}_max'``, ``'f{i}_median'``, ``'f{i}_mean'``,
		``'f{i}_ARI_max_ĉ'`` (NaN in 'TDCSBM' mode) and ``'f{i}_ARI_stacked'``.
	metadata : dict
		``n`` and ``K`` of the first graph, the number of graphs and the mean
		baseline ARI.

	Raises
	------
	ValueError
		If ``mode`` is unknown, the identity transform is missing, or ``graphs``
		is empty.
	"""
	# Materialise the inputs: they are iterated several times and indexed/len()'d
	# below, which would silently break on one-shot iterables.
	graphs     = list(graphs)
	transforms = list(transforms)
	families   = [list(fam) for fam in families]

	if mode not in ['TWSBM', 'TDCSBM']:
		raise ValueError("Mode must be either 'TWSBM' or 'TDCSBM'")
	
	if 'P-1.00' not in [t.id for t in transforms]:
		raise ValueError("Transforms must include the identity transform 'P-1.00'")

	# Fail early with a clear message instead of an opaque KeyError/IndexError
	# once the (empty) results frame is indexed further down.
	if not graphs:
		raise ValueError("graphs must contain at least one (A, Z, K) triple")
	
	X_transformation = 'theta' # can be None, 'normalised', 'theta' or 'score'

	loop_graphs 	= tqdm(graphs, desc="Graphs") if not large_graph else graphs
	loop_transforms = tqdm(transforms, desc="Transforms") if large_graph else transforms
	loop_families 	= tqdm(families, desc="Families") if large_graph else families

	records = []
	for A, Z, K in loop_graphs:
		row = {}

		# --- one fit per transformation ---
		for j, t in enumerate(loop_transforms):
			if mode == 'TWSBM':
				I = TWSBM(A = t(A), Z = Z, K = K, transformation = t, q_outliers = q_outliers)
			else: # mode == 'TDCSBM'
				I = TDCSBM(A = t(A), Z = Z, K = K, X_transformation = X_transformation)

			row[f't{j}_ARI'] = I.ARI
			row[f't{j}_gĈ']  = I.gĈ_embed if mode == 'TWSBM' else np.nan
			# kept only until the family step below; presumably the node embedding
			# (it is concatenated along axis 1) — confirm against the model classes
			row[f't{j}_X_A'] = I.X_A

			if t.id == 'P-1.00':
				row['baseline_ARI'] = I.ARI

		# --- one clustering per family on the column-wise stacked embeddings ---
		for i, fam in enumerate(loop_families):
			X_stacked = np.concatenate([row[f't{tj}_X_A'] for tj in fam], axis=1)
			if mode == 'TWSBM':
				# NOTE(review): q_outliers is fixed to 0 here, independently of the
				# q_outliers argument — confirm this is intentional.
				GMM = TWSBM.fit_gmm(X_stacked, K=K, q_outliers=0)
				Z_hat, *_ = TWSBM.get_gmm_estimates(GMM, X_stacked)
			else: # mode == 'TDCSBM'
				Z_hat = TDCSBM.fit_predict_egmm(X_stacked, K=K, X_transformation=X_transformation)
			row[f'f{i}_ARI'] = ARI(Z, Z_hat)

		# drop the embeddings so that only scalars end up in the records
		for j in range(len(transforms)):
			del row[f't{j}_X_A']

		records.append(row)

	df = pd.DataFrame.from_records(records)
	df_transforms = df[[f't{j}_ARI' for j in range(len(transforms))]].copy()
	df_transforms = df_transforms.mean()
	
	df_families = pd.DataFrame()
	for i, fam in enumerate(families):
		fam_ARI = df[[f"t{tj}_ARI" for tj in fam]]
	
		df_families[f"f{i}_max"]    = fam_ARI.max(axis=1)
		df_families[f"f{i}_median"] = fam_ARI.median(axis=1)
		df_families[f"f{i}_mean"]   = fam_ARI.mean(axis=1)
		if mode == 'TWSBM':
			# ARI of the family member with the largest gĈ, graph by graph
			argmax_gĈ = df[[f"t{tj}_gĈ" for tj in fam]].to_numpy().argmax(axis=1)
			df_families[f"f{i}_ARI_max_ĉ"] = fam_ARI.to_numpy()[np.arange(len(df)), argmax_gĈ]
		else: # mode == 'TDCSBM' (no gĈ is recorded in this mode)
			df_families[f"f{i}_ARI_max_ĉ"] = np.nan
		df_families[f'f{i}_ARI_stacked'] = df[f'f{i}_ARI']

	df_families = df_families.mean()

	# n and K are read from the first graph only
	A, Z, K = graphs[0]
	n = A.shape[0]

	metadata = {
		'n': n,
		'K': K,
		'samples': len(graphs),
		'baseline_ARI': df['baseline_ARI'].mean()
	}

	return df_transforms, df_families, metadata

# Reorganize df_families to have family names as index
def reorganize_df_families(df_families, family_names=None):
	"""Turn the flat ``'f{i}_{metric}'`` Series into a family × metric table.

	Parameters
	----------
	df_families : pandas.Series
		Values indexed by keys of the form ``'f{i}_{metric}'``.
	family_names : sequence of str, optional
		Name of family ``i`` at position ``i``. Defaults to the keys of the
		module-level ``FAMILIES`` dict.

	Returns
	-------
	pandas.DataFrame
		One row per family (index named ``'family'``) and one column per metric,
		with the known metrics renamed to 'Max', 'Mean', 'Median', 'Max Ĉ' and
		'Stacked'.

	Raises
	------
	ValueError
		If a key does not follow the ``'f{i}_{metric}'`` pattern or the number
		of families differs from the number of names.
	"""
	if family_names is None:
		family_names = list(FAMILIES.keys())
	else:
		family_names = list(family_names)

	df = df_families.copy()
	df = df.reset_index()
	df.columns = ["key", "value"]

	parts = df["key"].str.extract(r"^f(\d+)_(.*)$")
	if parts.isna().any().any():
		raise ValueError("Every key must look like 'f<index>_<metric>'")

	# Keep the family index numeric: as strings, 'f10' would sort before 'f2'
	# and the positional relabelling below would then mislabel the rows.
	df["family"] = parts[0].astype(int)
	df["metric"] = parts[1]

	df_pivot = df.pivot(index="family", columns="metric", values="value").sort_index()

	df_pivot = df_pivot.rename(columns={
		"max": "Max",
		"mean": "Mean",
		"median": "Median",
		"ARI_max_ĉ": "Max Ĉ",
		"ARI_stacked": "Stacked"
	})

	if len(family_names) != len(df_pivot):
		raise ValueError(
			f"Expected {len(df_pivot)} family names, got {len(family_names)}"
		)
	df_pivot.index = pd.Index(family_names, name="family")

	return df_pivot

# Reorganize df_transforms to have transform ids as index
def reorganize_df_transforms(df_transforms, transform_ids=None):
	"""Turn the flat ``'t{j}_ARI'`` Series into a one-column table indexed by transform id.

	Parameters
	----------
	df_transforms : pandas.Series
		ARI values indexed by keys of the form ``'t{j}_ARI'``.
	transform_ids : sequence, optional
		Id of transform ``j`` at position ``j``. Defaults to the ids of the
		module-level ``TRANSFORMS`` list.

	Returns
	-------
	pandas.DataFrame
		Single column ``'ARI'``; index named ``'transform'``; rows in the same
		order as the input.

	Raises
	------
	ValueError
		If a key does not follow the ``'t{j}_...'`` pattern or the parsed
		indices are not exactly ``0 .. len(transform_ids) - 1``.
	"""
	if transform_ids is None:
		transform_ids = [t.id for t in TRANSFORMS]
	else:
		transform_ids = list(transform_ids)

	df = df_transforms.copy().reset_index()
	df.columns = ["transform_metric", "ARI"]

	parsed = df["transform_metric"].str.extract(r"^t(\d+)_", expand=False)
	if parsed.isna().any():
		raise ValueError("Every key must look like 't<index>_<metric>'")
	positions = parsed.astype(int).tolist()

	if sorted(positions) != list(range(len(transform_ids))):
		raise ValueError(
			f"Keys must cover t0..t{len(transform_ids) - 1} exactly once"
		)

	# Label each row through its parsed index rather than its position, so the
	# result stays correct even if the input is not ordered t0, t1, ...
	labels = [transform_ids[j] for j in positions]
	return pd.DataFrame(
		{"ARI": df["ARI"].to_numpy()},
		index=pd.Index(labels, name="transform"),
	)

### Synthetic graphs generation

In [None]:
# Number of graphs sampled from each model (seeds 0 .. N_sample-1)
N_sample = 2
# Points per axis of the (p[0], p[1]) parameter grid, i.e. N_grid**2 pairs per (ρ, π)
N_grid   = 3

# Create a flattened N x N grid of points (x, y) in [xmin, xmax] x [ymin, ymax]
def flattened_grid(xmin, xmax, ymin, ymax, N, order='C'):
	"""Return the N*N points of a regular grid as a list of (x, y) tuples.

	Both axes are sampled with N equally spaced values (endpoints included).
	`order` is passed to ``ndarray.ravel`` and controls the traversal: with the
	default 'C', x varies fastest.
	"""
	grid_x, grid_y = np.meshgrid(
		np.linspace(xmin, xmax, N),
		np.linspace(ymin, ymax, N),
		indexing='xy',
	)
	flat_x = grid_x.ravel(order=order)
	flat_y = grid_y.ravel(order=order)
	return [(x, y) for x, y in zip(flat_x, flat_y)] #type: ignore

# Containers for the four synthetic model families (plain and degree-corrected)
beta_models, lognorm_models = [], []
beta_dc_models, lognorm_dc_models = [], []
μ = LognormWSBM.mu_for_quantile_at_zero(σ=1, quantile=0.99)

# One model of each kind per (ρ, π, p) combination. itertools.product iterates
# with ρ outermost and p innermost, i.e. in the same order as three nested loops.
# (The original cell carried an unfinished `for ρ, π, p in ...` stub in front of
# the nested loops, which made the whole cell a SyntaxError.)
for ρ, π, p in product([0.1, 0.25, 0.5],
					   [0.1, 0.25, 0.5],
					   flattened_grid(0.1, 0.9, 0.1, 0.9, N_grid)):

	beta_models.append(BetaWSBM(K=2,
								ρ=ρ,
								π=np.array([π, 1-π]),
								n=1000,
								α=np.array([[p[0], p[1]], [p[1], 1]])))

	lognorm_models.append(LognormWSBM(K=2,
									  ρ=ρ,
									  π=np.array([π, 1-π]),
									  n=1000,
									  Σ=np.array([[p[0], p[1]], [p[1], 1]]),
									  μ=μ))

	# NOTE(review): the two degree-corrected models below do not take ρ, so the
	# same (π, p) configuration is appended once per ρ value — confirm this
	# repetition is intended.
	beta_dc_models.append(WDCSBM(K=2,
								 H=np.array([[stats.beta(a=p[0], b=1), stats.beta(a=p[1], b=1)],
											 [stats.beta(a=p[1], b=1), stats.beta(a=1, b=1)]]),
								 G=np.array([stats.beta(a=2, b=1), stats.beta(a=2, b=1)]),
								 π=np.array([π, 1-π]),
								 n=1000))

	lognorm_dc_models.append(WDCSBM(K=2,
									H=np.array([[stats.lognorm(s=p[0], scale=np.exp(μ)),
												 stats.lognorm(s=p[1], scale=np.exp(μ))],
												[stats.lognorm(s=p[1], scale=np.exp(μ)),
												 stats.lognorm(s=1,    scale=np.exp(μ))]]),
									G=np.array([stats.beta(a=2, b=1), stats.beta(a=2, b=1)]),
									π=np.array([π, 1-π]),
									n=1000))

# Sample N_sample graphs per model (seeds 0 .. N_sample-1). Each entry is the
# tuple get_metrics expects, with K = 2 appended; for the degree-corrected
# models only the first two values returned by sample() are kept.
beta_graphs       = [(*model.sample(seed), 2) 		for model in beta_models 		for seed in range(N_sample)]
lognorm_graphs    = [(*model.sample(seed), 2) 		for model in lognorm_models 	for seed in range(N_sample)]
beta_dc_graphs    = [(*model.sample(seed)[:2], 2) 	for model in beta_dc_models 	for seed in range(N_sample)]
lognorm_dc_graphs = [(*model.sample(seed)[:2], 2) 	for model in lognorm_dc_models 	for seed in range(N_sample)]

### Real graphs loading

In [4]:
# Load real-world dataset
def load_dataset(name, base_dir="data"):
	"""Load one dataset stored as a folder ``base_dir / name``.

	The folder must contain ``metadata.json``, ``labels.npy`` and the adjacency
	matrix, stored as ``adjacency.npz`` (SciPy sparse) when the metadata has a
	truthy ``"sparse"`` entry and as ``adjacency.npy`` (dense) otherwise.

	Returns
	-------
	(A, Z, meta)
		Adjacency matrix, label array and the parsed metadata dict.
	"""
	folder = Path(base_dir) / name

	# Use a context manager so the metadata file handle is closed right away
	# instead of being left to the garbage collector.
	with open(folder / "metadata.json", encoding="utf-8") as fh:
		meta = json.load(fh)

	if meta.get("sparse", False):
		A = sp.load_npz(folder / "adjacency.npz")
	else:
		A = np.load(folder / "adjacency.npy")
	Z = np.load(folder / "labels.npy")
	return A, Z, meta


# main loop
base_dir = Path(project_root) / 'data' / 'raw' / 'real_world_graphs'

real_world_graphs = []
real_world_names  = []

# Path.iterdir() yields entries in arbitrary order; sorting the folders makes
# the loading order (and the tie-break of the size sort below) reproducible.
for ds_folder in sorted(base_dir.iterdir()):
	if not ds_folder.is_dir():
		continue
	name = ds_folder.name
	print(f"Processing dataset: {name}")
	A, Z, meta = load_dataset(name, base_dir=base_dir) #type: ignore
	K = len(np.unique(Z))	# number of communities = number of distinct labels
	n = A.shape[0]
	print(f"n = {n}, K = {K}\n")
	real_world_names.append(name)
	real_world_graphs.append((A, Z, K))

# Order the datasets by increasing number of nodes (stable sort). Building the
# two lists with comprehensions also works when no dataset folder was found,
# where unpacking `zip(*combined)` would raise.
combined = sorted(zip(real_world_graphs, real_world_names),
				  key=lambda pair: pair[0][0].shape[0])
real_world_graphs = [graph for graph, _ in combined]
real_world_names  = [ds_name for _, ds_name in combined]

Processing dataset: cifar10
n = 50000, K = 10

Processing dataset: fashionmnist
n = 70000, K = 10

Processing dataset: high_school_2011
n = 118, K = 3

Processing dataset: high_school_2012
n = 180, K = 5

Processing dataset: high_school_2013
n = 327, K = 9

Processing dataset: mnist
n = 70000, K = 10

Processing dataset: primary_school
n = 232, K = 10

Processing dataset: workplace_2013
n = 88, K = 4

Processing dataset: workplace_2015
n = 161, K = 5



### Computation

In [None]:
# (label, graphs, fitting mode) of every synthetic benchmark, in processing order
synthetic_runs = [
	("Beta WSBM",  beta_graphs,       'TWSBM'),
	("LogN WSBM",  lognorm_graphs,    'TWSBM'),
	("Beta DCSBM", beta_dc_graphs,    'TDCSBM'),
	("LogN DCSBM", lognorm_dc_graphs, 'TDCSBM'),
]

# Results per benchmark: family table, transform table and run metadata
df_dict_synthetic = {label: {} for label, _, _ in synthetic_runs}

for name, graphs, mode in synthetic_runs:

	print(f"Processing {name} graphs...")
	df_transforms, df_families, metadata = get_metrics(
		graphs, TRANSFORMS, list(FAMILIES_IDX.values()), mode=mode
	)

	df_dict_synthetic[name].update({
		'df_families':   reorganize_df_families(df_families),
		'df_transforms': reorganize_df_transforms(df_transforms),
		'metadata':      metadata,
	})

# Persist all results in one pickle under <project_root>/data/processed/graphs
data_dir = os.path.join(project_root, 'data/processed/graphs')
os.makedirs(data_dir, exist_ok=True)
fn = os.path.join(data_dir, 'df_dict_synthetic.pkl')

with open(fn, 'wb') as f:
	pickle.dump(df_dict_synthetic, f)


Processing Beta WSBM graphs...


Graphs: 100%|██████████| 162/162 [27:59<00:00, 10.37s/it]


Processing LogN WSBM graphs...


Graphs: 100%|██████████| 162/162 [28:14<00:00, 10.46s/it]


Processing Beta DCSBM graphs...


  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.ne

Processing LogN DCSBM graphs...


  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.ne

In [5]:
# Both fitting modes are run on every real-world graph
real_modes = ['TWSBM', 'TDCSBM']

# Results per "<dataset> <mode>" key: family table, transform table and run metadata
df_dict_real = {}
for ds_name in real_world_names:
	for mode in real_modes:
		df_dict_real[f'{ds_name} {mode}'] = {}

for ds_name, graph in zip(real_world_names, real_world_graphs):
	for mode in real_modes:

		name = f'{ds_name} {mode}'
		print(f"Processing {name} graph...")
		# single-graph run: progress bars over transforms/families, no outlier trimming
		df_transforms, df_families, metadata = get_metrics(
			[graph],
			TRANSFORMS,
			list(FAMILIES_IDX.values()),
			mode=mode,
			q_outliers=0,
			large_graph=True,
		)

		df_dict_real[name].update({
			'df_families':   reorganize_df_families(df_families),
			'df_transforms': reorganize_df_transforms(df_transforms),
			'metadata':      metadata,
		})

# Persist all results in one pickle under <project_root>/data/processed/graphs
data_dir = os.path.join(project_root, 'data/processed/graphs')
os.makedirs(data_dir, exist_ok=True)
fn = os.path.join(data_dir, 'df_dict_real.pkl')

with open(fn, 'wb') as f:
	pickle.dump(df_dict_real, f)

Processing workplace_2013 TWSBM graph...


Transforms: 100%|██████████| 46/46 [00:03<00:00, 14.18it/s]
Families: 100%|██████████| 9/9 [00:03<00:00,  2.44it/s]


Processing workplace_2013 TDCSBM graph...


Transforms: 100%|██████████| 46/46 [00:10<00:00,  4.38it/s]
Families: 100%|██████████| 9/9 [00:12<00:00,  1.34s/it]


Processing high_school_2011 TWSBM graph...


Transforms: 100%|██████████| 46/46 [00:01<00:00, 23.19it/s]
Families: 100%|██████████| 9/9 [00:02<00:00,  4.09it/s]


Processing high_school_2011 TDCSBM graph...


Transforms: 100%|██████████| 46/46 [00:06<00:00,  7.04it/s]
Families: 100%|██████████| 9/9 [00:07<00:00,  1.16it/s]


Processing workplace_2015 TWSBM graph...


Transforms: 100%|██████████| 46/46 [00:03<00:00, 15.26it/s]
Families: 100%|██████████| 9/9 [00:03<00:00,  2.72it/s]


Processing workplace_2015 TDCSBM graph...


Transforms: 100%|██████████| 46/46 [00:08<00:00,  5.15it/s]
Families: 100%|██████████| 9/9 [00:15<00:00,  1.69s/it]


Processing high_school_2012 TWSBM graph...


Transforms: 100%|██████████| 46/46 [00:07<00:00,  6.33it/s]
Families: 100%|██████████| 9/9 [00:07<00:00,  1.15it/s]


Processing high_school_2012 TDCSBM graph...


Transforms: 100%|██████████| 46/46 [00:17<00:00,  2.69it/s]
  responsibilities = np.divide(num, den[:, np.newaxis])
Families: 100%|██████████| 9/9 [00:19<00:00,  2.21s/it]


Processing primary_school TWSBM graph...


Transforms: 100%|██████████| 46/46 [00:17<00:00,  2.59it/s]
Families: 100%|██████████| 9/9 [00:18<00:00,  2.06s/it]


Processing primary_school TDCSBM graph...


Transforms: 100%|██████████| 46/46 [00:35<00:00,  1.28it/s]
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
Families: 100%|██████████| 9/9 [00:45<00:00,  5.11s/it]


Processing high_school_2013 TWSBM graph...


Transforms: 100%|██████████| 46/46 [00:19<00:00,  2.32it/s]
Families: 100%|██████████| 9/9 [00:20<00:00,  2.32s/it]


Processing high_school_2013 TDCSBM graph...


  responsibilities = np.divide(num, den[:, np.newaxis])
Transforms: 100%|██████████| 46/46 [00:22<00:00,  2.03it/s]
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
Families: 100%|██████████| 9/9 [00:29<00:00,  3.29s/it]


Processing cifar10 TWSBM graph...


Transforms: 100%|██████████| 46/46 [23:16<00:00, 30.36s/it]
Families: 100%|██████████| 9/9 [56:56<00:00, 379.61s/it]


Processing cifar10 TDCSBM graph...


Transforms: 100%|██████████| 46/46 [1:08:07<00:00, 88.85s/it] 
  responsibilities = np.divide(num, den[:, np.newaxis])
Families: 100%|██████████| 9/9 [1:37:13<00:00, 648.13s/it]


Processing fashionmnist TWSBM graph...


Transforms: 100%|██████████| 46/46 [20:30<00:00, 26.75s/it]
Families: 100%|██████████| 9/9 [40:04<00:00, 267.21s/it]


Processing fashionmnist TDCSBM graph...


Transforms: 100%|██████████| 46/46 [1:19:01<00:00, 103.07s/it]
  responsibilities = np.divide(num, den[:, np.newaxis])
  responsibilities = np.divide(num, den[:, np.newaxis])
Families: 100%|██████████| 9/9 [1:47:14<00:00, 714.90s/it]


Processing mnist TWSBM graph...


Transforms: 100%|██████████| 46/46 [07:10<00:00,  9.35s/it]
Families: 100%|██████████| 9/9 [18:33<00:00, 123.74s/it]


Processing mnist TDCSBM graph...


Transforms: 100%|██████████| 46/46 [1:40:26<00:00, 131.00s/it]  
  responsibilities = np.divide(num, den[:, np.newaxis])
Families: 100%|██████████| 9/9 [2:13:55<00:00, 892.83s/it]
