In [83]:
import sys
sys.path.append('../code')
sys.path.append('../simulation')
import FOT
# data_generation is star-imported; it presumably supplies sample_dc_sbm,
# blockwise_permutation, perm_matrix, make_block_features,
# heat_kernel_from_adj and all_pairs_geodesic used below.
# TODO: replace with explicit imports.
from data_generation import *
# Third-party imports come after the star import so none of these names can be
# shadowed by something data_generation re-exports.
import numpy as np
from scipy.spatial.distance import cdist
import ot
from tqdm import trange
import pandas as pd

In [86]:
# -------------------------------
# Simulation config
# -------------------------------
# NOTE(review): rng=None is forwarded to every data-generation helper, so each
# run draws fresh randomness and the results table is not reproducible.
# TODO: pass a seeded generator once the helpers' accepted rng type is confirmed.
rng = None

# B blocks; N = B**(2B-1) nodes (243 for B = 3).
B = 3
# Cyclic block relabelling k -> (k + 1) mod B, handed to blockwise_permutation.
block_perm = dict(zip(range(B), [(k + 1) % B for k in range(B)]))
N = B ** (2 * B - 1)
# Uniform marginals for both graphs; `a` and `b` are the same array object.
a = b = np.full(N, 1.0 / N)
pout = 0.02       # passed as `pout` to sample_dc_sbm (presumably the between-block edge probability)
T = 100           # fw_max_iter for the proposed Frank-Wolfe solver
rep = 5           # independent repetitions per configuration
t_heat = 1        # `t` argument of heat_kernel_from_adj

# Hyperparameters
alphas = [0.0, 0.5, 1.0]
sigmas = [0.0, 1.0]              # only sigma^2 in {0, 1} is used
dists = [
    ("geodesic", "gaussian"),
    ("heat", "Id"),
]

# One result dict is appended here per (sigma, alpha, dist_type) configuration.
records = []

# -------------------------------
# Main simulation loop
# -------------------------------
def _structural_kernels(A_X, A_Y, dist_type):
    """Build the structural matrices (KX, KY) for a source/target graph pair.

    Parameters
    ----------
    A_X, A_Y : adjacency matrices of the source and target graphs.
    dist_type : (distance, kernel) tuple, one of
        ("heat", "Id")           -> heat_kernel_from_adj(A, t=t_heat, normalize=True, order=2)
        ("geodesic", "gaussian") -> FOT.kappa_decreasing_exp(unweighted geodesics, p=2)
        ("geodesic", "Id")       -> raw unweighted all-pairs geodesic distances

    Reads ``t_heat`` from the config cell.

    Raises
    ------
    ValueError
        For any other ``dist_type``. Previously every unknown value silently
        fell back to raw geodesic distances and was recorded under
        ``dist_type[0]``, so a typo produced mislabelled results.
    """
    if dist_type == ("heat", "Id"):
        return (heat_kernel_from_adj(A_X, t=t_heat, normalize=True, order=2),
                heat_kernel_from_adj(A_Y, t=t_heat, normalize=True, order=2))
    if dist_type in (("geodesic", "gaussian"), ("geodesic", "Id")):
        DX = all_pairs_geodesic(A_X, weighted=False)
        DY = all_pairs_geodesic(A_Y, weighted=False)
        if dist_type[1] == "gaussian":
            return FOT.kappa_decreasing_exp(DX, p=2), FOT.kappa_decreasing_exp(DY, p=2)
        return DX, DY
    raise ValueError(f"unknown dist_type: {dist_type!r}")


def _match_accuracy(coupling, true_idx, labels):
    """Node- and block-level accuracy of the row-wise argmax of ``coupling``.

    Returns ``(node_acc, block_acc)``:
      node_acc  -- fraction of rows whose argmax equals ``true_idx``;
      block_acc -- fraction of rows whose argmax carries the same label
                   (``labels[...]``) as the true partner.
    """
    pred_idx = np.argmax(coupling, axis=1)
    node_acc = np.mean(pred_idx == true_idx)
    # NOTE(review): `labels` is z (X-node labels) indexed with Y-node indices.
    # Presumably this is fine because the permutation is blockwise with a
    # bijective block map -- confirm against blockwise_permutation.
    block_acc = np.mean(labels[pred_idx] == labels[true_idx])
    return node_acc, block_acc


for sigma in sigmas:
    for alpha in alphas:
        # In POT's FGW, alpha weights the structure term, so alpha = 0 would be
        # pure feature OT; the baseline is only run for alpha > 0.
        run_fgw = alpha > 0

        for dist_type in dists:
            prop_node_acc_list, prop_block_acc_list = [], []
            fgw_node_acc_list, fgw_block_acc_list = [], []

            for _ in trange(rep, desc=f"sigma={sigma}, alpha={alpha}, dist={dist_type}"):

                # --- Graph generation ---
                A_X, z, theta, pin_b = sample_dc_sbm(
                    N, B, pin=None, pout=pout, similar_pair=None, delta=None,
                    rng=rng, within_mode='grid'
                )
                perm = blockwise_permutation(
                    z, block_perm=block_perm, rng=rng,
                    shuffle_within_source=True, shuffle_within_target=True
                )
                P_true = perm_matrix(perm)
                A_Y = P_true.T @ A_X @ P_true
                # Index of each X node's true partner in Y (row-wise argmax of P_true).
                true_idx = np.argmax(P_true, axis=1)

                # --- Features ---
                FX, mu = make_block_features(z, d=B, margin=1, noise=sigma, rng=rng)
                FY = P_true.T @ FX
                C_f = cdist(FX, FY, metric="sqeuclidean")

                # --- Structural kernels ---
                KX, KY = _structural_kernels(A_X, A_Y, dist_type)

                # --- Proposed method ---
                # NOTE(review): X and Y are both built from z (X's labels);
                # presumably only their shape matters because pre_Cf/pre_DX/pre_DY
                # are supplied -- confirm in FOT.ConvexFusedTransport.fit.
                model = FOT.ConvexFusedTransport(
                    alpha=alpha,
                    fw_max_iter=T,
                    fw_stepsize='classic',
                    tol=1e-40,
                    lmo_method='emd',
                    pre_Cf=C_f,
                    pre_DX=KX,
                    pre_DY=KY
                ).fit(X=np.array(z).reshape(-1, 1),
                      Y=np.array(z).reshape(-1, 1),
                      FX=None, FY=None,
                      return_hard_assignment=True)

                node_acc, block_acc = _match_accuracy(model.P_, true_idx, z)
                prop_node_acc_list.append(node_acc)
                prop_block_acc_list.append(block_acc)

                # --- FGW baseline ---
                if run_fgw:
                    T_FGW = ot.gromov.fused_gromov_wasserstein(
                        C_f, KX, KY, a, b, alpha=alpha, loss_fun='square_loss'
                    )
                    node_acc, block_acc = _match_accuracy(T_FGW, true_idx, z)
                    fgw_node_acc_list.append(node_acc)
                    fgw_block_acc_list.append(block_acc)

            # --- Aggregate results: means over `rep` runs; NaN when FGW was skipped ---
            records.append({
                "sigma2": sigma,
                "alpha": alpha,
                "distance": dist_type[0],
                "ours_node": np.mean(prop_node_acc_list),
                "ours_block": np.mean(prop_block_acc_list),
                "fgw_node": np.mean(fgw_node_acc_list) if run_fgw else np.nan,
                "fgw_block": np.mean(fgw_block_acc_list) if run_fgw else np.nan,
            })

sigma=0.0, alpha=0.0, dist=('geodesic', 'gaussian'): 100%|█| 5/5 [00:04<00:00,  
sigma=0.0, alpha=0.0, dist=('heat', 'Id'): 100%|██| 5/5 [00:03<00:00,  1.29it/s]
sigma=0.0, alpha=0.5, dist=('geodesic', 'gaussian'): 100%|█| 5/5 [00:04<00:00,  
sigma=0.0, alpha=0.5, dist=('heat', 'Id'): 100%|██| 5/5 [00:05<00:00,  1.07s/it]
sigma=0.0, alpha=1.0, dist=('geodesic', 'gaussian'): 100%|█| 5/5 [00:05<00:00,  
sigma=0.0, alpha=1.0, dist=('heat', 'Id'): 100%|██| 5/5 [00:03<00:00,  1.44it/s]
sigma=1.0, alpha=0.0, dist=('geodesic', 'gaussian'): 100%|█| 5/5 [00:02<00:00,  
sigma=1.0, alpha=0.0, dist=('heat', 'Id'): 100%|██| 5/5 [00:02<00:00,  1.71it/s]
sigma=1.0, alpha=0.5, dist=('geodesic', 'gaussian'): 100%|█| 5/5 [00:03<00:00,  
sigma=1.0, alpha=0.5, dist=('heat', 'Id'): 100%|██| 5/5 [00:04<00:00,  1.13it/s]
sigma=1.0, alpha=1.0, dist=('geodesic', 'gaussian'): 100%|█| 5/5 [00:06<00:00,  
sigma=1.0, alpha=1.0, dist=('heat', 'Id'): 100%|██| 5/5 [00:09<00:00,  1.89s/it]


In [87]:
# -------------------------------
# Convert to DataFrame and view
# -------------------------------
# One row per (sigma2, alpha, distance) configuration; accuracy columns are
# means over `rep` runs, and the fgw_* columns are NaN where alpha == 0.
df = pd.DataFrame(records)
# Last expression -> Jupyter rich (HTML) display instead of a plain-text print.
df.round(3)

    sigma2  alpha  distance  ours_node  ours_block  fgw_node  fgw_block
0      0.0    0.0  geodesic      0.012       1.000       NaN        NaN
1      0.0    0.0      heat      0.009       1.000       NaN        NaN
2      0.0    0.5  geodesic      1.000       1.000     0.995      1.000
3      0.0    0.5      heat      1.000       1.000     0.016      1.000
4      0.0    1.0  geodesic      1.000       1.000     0.974      0.988
5      0.0    1.0      heat      0.002       0.205     0.006      0.360
6      1.0    0.0  geodesic      1.000       1.000       NaN        NaN
7      1.0    0.0      heat      1.000       1.000       NaN        NaN
8      1.0    0.5  geodesic      1.000       1.000     1.000      1.000
9      1.0    0.5      heat      1.000       1.000     1.000      1.000
10     1.0    1.0  geodesic      1.000       1.000     0.987      0.990
11     1.0    1.0      heat      0.002       0.257     0.003      0.360
