In [2]:
import sys

# Project-local modules live outside this notebook's directory.
sys.path.append('../code')
sys.path.append('../simulation')

# Local imports come first so the explicit third-party names below are bound last
# and cannot be shadowed by the star import (same precedence as the original order).
import FOT
from data_generation import *  # NOTE(review): star import; supplies the graph/feature helpers used below

import numpy as np  # np is used throughout; do not rely on the star import to provide it
import ot
import pandas as pd
from scipy.spatial.distance import cdist
from tqdm import trange

In [10]:
def random_coupling_with_marginals(p, q, max_iter=1000, tol=1e-9, seed=None):
    """Draw a random coupling matrix whose marginals match ``p`` and ``q``.

    A random positive matrix ``K`` (entries uniform in [0, 1)) is rescaled with
    Sinkhorn iterations to ``diag(u) @ K @ diag(v)``. The row sums approach ``p``
    and the column sums approach ``q``.

    Parameters
    ----------
    p : array_like, shape (n,)
        Target row marginal.
    q : array_like, shape (m,)
        Target column marginal.
    max_iter : int
        Maximum number of Sinkhorn sweeps.
    tol : float
        Stop early once the L1 change of the row scaling ``u`` between two
        consecutive sweeps drops below this value.
    seed : None, int or numpy.random.Generator
        Forwarded to ``np.random.default_rng``. Pass a value (or a Generator)
        for reproducible draws.

    Returns
    -------
    P : ndarray, shape (n, m)
        Coupling matrix, rescaled to total mass 1.
    """
    rng = np.random.default_rng(seed)
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    K = rng.random((len(p), len(q)))
    u = np.ones_like(p)
    v = np.ones_like(q)
    for _ in range(max_iter):
        u_prev = u  # u is rebound below, never mutated, so no copy is needed
        u = p / (K @ v)
        v = q / (K.T @ u)
        if np.linalg.norm(u - u_prev, 1) < tol:
            break
    # Same values as diag(u) @ K @ diag(v). Broadcasting avoids two dense
    # diagonal matrices and two O(n^3) matrix products.
    P = u[:, None] * K * v[None, :]
    # After the final v-update the column sums equal q, so this divides by sum(q).
    # It is a no-op when q already sums to 1.
    P /= P.sum()
    return P

# -------------------------------
# Simulation config
# -------------------------------
SEED = 0
# Dedicated generator for the random initial couplings, so the "random" init is
# reproducible under Restart & Run All.
init_rng = np.random.default_rng(SEED)
# NOTE(review): graph/feature sampling still receives rng=None, so those draws are
# not reproducible. Pass a seeded generator here once the data_generation helpers
# are confirmed to accept one.
rng = None
B = 3                                            # number of blocks
block_perm = {b: (b + 1) % B for b in range(B)}  # cyclic block map b -> (b + 1) mod B
N = B**(2*B-1)                                   # number of nodes (243 for B = 3)
a = b = np.ones((N,)) / N                        # uniform marginals on both graphs
pout = 0.01                                      # passed to sample_dc_sbm as `pout`
T = 100                                          # Frank-Wolfe iterations (fw_max_iter)
rep = 10                                         # random graphs per configuration
t_heat = 1                                       # time parameter of the heat kernels

# Hyperparameters
alphas = [0.0, 0.5, 1.0]
sigmas = [0.0, 1.0]              # only σ² = 0 and 1 are used
dists = ["geodesic", "heat", "diffusion"]
kernels = ["Id"]
inits = ['random', 'independent']


def _structural_distance(A, dist):
    """Structural node-by-node matrix of adjacency ``A`` for the method named ``dist``.

    Raises ValueError for an unknown name. Otherwise the loop would silently
    reuse a matrix left over from a previous iteration.
    """
    if dist == "diffusion":
        K_rw = heat_kernel_from_adj(A, t=t_heat, lap='rw', method='expm')
        return diffusion_distance_matrix(K_rw)
    if dist == "heat":
        K_sym = heat_kernel_from_adj(A, t=t_heat, lap='sym', method='taylor', order=2)
        return rkhs_distance_matrix_from_kernel(K_sym)
    if dist == "geodesic":
        return all_pairs_geodesic(A, weighted=False)
    raise ValueError(f"unknown dist: {dist!r}")


def _apply_kernel(D, kernel):
    """Map a structural matrix ``D`` through the kernel named ``kernel`` ("Id" returns ``D``)."""
    if kernel == "gaussian":
        return FOT.kappa_decreasing_exp(D, p=2)
    if kernel == "Id":
        return D
    if kernel == "linear_cutoff":
        return FOT.kappa_linear_cutoff(D, k=1)
    raise ValueError(f"unknown kernel: {kernel!r}")


def _initial_coupling(init):
    """Initial transport plan with marginals ``a`` and ``b`` for the scheme named ``init``."""
    if init == "independent":
        return np.outer(a, b)
    if init == "random":
        return random_coupling_with_marginals(a, b, seed=init_rng)
    raise ValueError(f"unknown init: {init!r}")


def _match_accuracy(P, true_idx, labels):
    """Node- and block-level accuracy of the row-wise argmax of a coupling ``P``.

    Node accuracy is the fraction of rows whose argmax equals ``true_idx``.
    Block accuracy is the fraction of rows whose argmax carries the same
    ``labels`` value as the true target index.
    """
    pred_idx = np.argmax(P, axis=1)
    node_acc = np.mean(pred_idx == true_idx)
    block_acc = np.mean(labels[pred_idx] == labels[true_idx])
    return node_acc, block_acc


# One summary dict per hyperparameter configuration
records = []

# -------------------------------
# Main simulation loop
# -------------------------------
for sigma in sigmas:
    for alpha in alphas:
        for dist in dists:
            for kernel in kernels:
                for init in inits:
                    # The FGW baseline is only run for alpha > 0.
                    run_fgw = alpha > 0

                    prop_node_acc_list, prop_block_acc_list = [], []
                    fgw_node_acc_list, fgw_block_acc_list = [], []

                    desc = f"sigma={sigma}, alpha={alpha}, dist={dist}, kernel={kernel}, init={init}"
                    for _ in trange(rep, desc=desc):

                        # --- Graph generation ---
                        A_X, z, theta, pin_b = sample_dc_sbm(
                            N, B, pin=None, pout=pout, similar_pair=None, delta=None,
                            rng=rng, within_mode='grid'
                        )
                        perm = blockwise_permutation(
                            z, block_perm=block_perm, rng=rng,
                            shuffle_within_source=True, shuffle_within_target=True
                        )
                        P_true = perm_matrix(perm)
                        A_Y = P_true.T @ A_X @ P_true

                        # --- Features ---
                        FX, mu = make_block_features(z, d=B, margin=1, noise=sigma, rng=rng)
                        FY = P_true.T @ FX
                        C_f = cdist(FX, FY, metric="sqeuclidean")

                        # --- Structural matrices ---
                        DX = _structural_distance(A_X, dist)
                        DY = _structural_distance(A_Y, dist)
                        KX = _apply_kernel(DX, kernel)
                        KY = _apply_kernel(DY, kernel)

                        pi = _initial_coupling(init)

                        # --- Proposed method ---
                        # NOTE(review): X and Y both receive the source labels z; presumably
                        # unused because pre_Cf / pre_DX / pre_DY are supplied. Confirm.
                        model = FOT.ConvexFusedTransport(
                            alpha=alpha,
                            fw_max_iter=T,
                            fw_stepsize='classic',
                            tol=1e-40,  # presumably disables early stopping so all T iterations run
                            lmo_method='emd',
                            pre_Cf=C_f,
                            pre_DX=KX,
                            pre_DY=KY
                        ).fit(X=np.array(z).reshape(-1, 1),
                              Y=np.array(z).reshape(-1, 1),
                              FX=None, FY=None, init=pi,
                              return_hard_assignment=True)

                        # Ground-truth target index of each source node
                        true_idx = np.argmax(P_true, axis=1)

                        node_acc, block_acc = _match_accuracy(model.P_, true_idx, z)
                        prop_node_acc_list.append(node_acc)
                        prop_block_acc_list.append(block_acc)

                        # --- FGW baseline (only if alpha > 0) ---
                        if run_fgw:
                            # Marginals are passed explicitly so they match the ones used by `pi`.
                            T_FGW = ot.gromov.fused_gromov_wasserstein(
                                C_f, KX, KY, p=a, q=b, G0=pi, alpha=alpha,
                                loss_fun='square_loss'
                            )
                            node_acc, block_acc = _match_accuracy(T_FGW, true_idx, z)
                            fgw_node_acc_list.append(node_acc)
                            fgw_block_acc_list.append(block_acc)

                    # --- Aggregate results ---
                    # NOTE(review): `sigma` is stored under "sigma2"; this equals σ² only for
                    # the values 0 and 1 used here. Confirm whether `noise` is a std or a variance.
                    records.append({
                        "sigma2": sigma,
                        "alpha": alpha,
                        "distance": dist,
                        "kernel": kernel,
                        "init": init,
                        "ours_node": np.mean(prop_node_acc_list),
                        "ours_block": np.mean(prop_block_acc_list),
                        "fgw_node": np.mean(fgw_node_acc_list) if run_fgw else np.nan,
                        "fgw_block": np.mean(fgw_block_acc_list) if run_fgw else np.nan,
                    })

sigma=0.0, alpha=0.0, dist=geodesic, kernel=Id, init=random: 100%|█████████████████████| 10/10 [00:09<00:00,  1.04it/s]
sigma=0.0, alpha=0.0, dist=geodesic, kernel=Id, init=independent: 100%|████████████████| 10/10 [00:09<00:00,  1.02it/s]
sigma=0.0, alpha=0.0, dist=heat, kernel=Id, init=random: 100%|█████████████████████████| 10/10 [00:09<00:00,  1.02it/s]
sigma=0.0, alpha=0.0, dist=heat, kernel=Id, init=independent: 100%|████████████████████| 10/10 [00:09<00:00,  1.02it/s]
sigma=0.0, alpha=0.0, dist=diffusion, kernel=Id, init=random: 100%|████████████████████| 10/10 [00:16<00:00,  1.62s/it]
sigma=0.0, alpha=0.0, dist=diffusion, kernel=Id, init=independent: 100%|███████████████| 10/10 [00:16<00:00,  1.65s/it]
sigma=0.0, alpha=0.5, dist=geodesic, kernel=Id, init=random: 100%|█████████████████████| 10/10 [00:15<00:00,  1.50s/it]
sigma=0.0, alpha=0.5, dist=geodesic, kernel=Id, init=independent: 100%|████████████████| 10/10 [00:14<00:00,  1.48s/it]
sigma=0.0, alpha=0.5, dist=heat, kernel=

In [11]:
# -------------------------------
# Convert to DataFrame and view
# -------------------------------
df = pd.DataFrame(records)
# Every hyperparameter configuration should have produced exactly one row.
n_configs = len(sigmas) * len(alphas) * len(dists) * len(kernels) * len(inits)
assert len(df) == n_configs, f"expected {n_configs} result rows, got {len(df)} (simulation incomplete?)"
df

    sigma2  alpha   distance kernel         init  ours_node  ours_block  \
0      0.0    0.0   geodesic     Id       random   0.011111         1.0   
1      0.0    0.0   geodesic     Id  independent   0.009465         1.0   
2      0.0    0.0       heat     Id       random   0.013169         1.0   
3      0.0    0.0       heat     Id  independent   0.012757         1.0   
4      0.0    0.0  diffusion     Id       random   0.010288         1.0   
5      0.0    0.0  diffusion     Id  independent   0.011934         1.0   
6      0.0    0.5   geodesic     Id       random   1.000000         1.0   
7      0.0    0.5   geodesic     Id  independent   1.000000         1.0   
8      0.0    0.5       heat     Id       random   1.000000         1.0   
9      0.0    0.5       heat     Id  independent   1.000000         1.0   
10     0.0    0.5  diffusion     Id       random   1.000000         1.0   
11     0.0    0.5  diffusion     Id  independent   1.000000         1.0   
12     0.0    1.0   geode