In [None]:
import numpy as np
import pandas as pd
from scipy.optimize import curve_fit

# Raw measurements, nested as dataset name -> model name -> list of tuples.
# Each tuple is unpacked downstream as (n, err, std): the sample size, the
# error measurement at that sample size, and the standard deviation of that
# error measurement.
# Cora and Facebook have three sample sizes per model (100, 500, 1000);
# Reddit and QM9 have six (100 up to 50000).
datasets = {
    'Cora': {
        'GCN': [(100, 0.45331318974494933, 0.024756455654897046),
                (500, 0.23856463730335237, 0.014240488795972095),
                (1000, 0.16623241305351258, 0.007939122345212673)],
        'GAT': [(100, 0.4247232496738434, 0.03526640259761619),
                (500, 0.2265682637691498, 0.01985088532073069),
                (1000, 0.17245431542396544, 0.009573755414000502)],
        'GraphSAGE': [(100, 0.5282434463500977, 0.02662578778883726),
                      (500, 0.2611788988113403, 0.011081628608493049),
                      (1000, 0.17424136102199556, 0.01182381363000765)]
    },
    'Reddit': {
        'GCN': [(100, 0.6638607740402221, 0.04893704996814368),
                (500, 0.39872702956199646, 0.031492316429588926),
                (1000, 0.25180703699588775, 0.027561228403015604),
                (5000, 0.11390651613473893, 0.005034993087847129),
                (10000, 0.09257184416055679, 0.0022101293663993942),
                (50000, 0.07923970818519592, 0.0012036254204641356)],
        'GAT': [(100, 0.74683119058609, 0.07390229955504421),
                (500, 0.7327924728393554, 0.032305080962742136),
                (1000, 0.6507743716239929, 0.03551110132146038),
                (5000, 0.41393967866897585, 0.03768536853473231),
                (10000, 0.3230024456977844, 0.04470648755372255),
                (50000, 0.1848956435918808, 0.019780347600618014)],
        'GraphSAGE': [(100, 0.7809877634048462, 0.03528722600583161),
                      (500, 0.5848131775856018, 0.019288915605109015),
                      (1000, 0.452128005027771, 0.02889391719691232),
                      (5000, 0.20689032971858978, 0.0011023282183319496),
                      (10000, 0.15035432875156401, 0.002545760565421577),
                      (50000, 0.07758742719888687, 0.00033193818017408463)]
    },
    'QM9': {
        'GCN': [(100, 2.010820424542159, 0.10295508270121849),
                (500, 1.8500684509941885, 0.05619516912623903),
                (1000, 1.7505244721264597, 0.05373075402225386),
                (5000, 1.5587026009932707, 0.04819121632473715),
                (10000, 1.5350490051916001, 0.03816020820016992),
                (50000, 1.5022996501628227, 0.018802712780239785)],
        'GAT': [(100, 1.1261, 0.2953),
                (500, 1.0342, 0.1615),
                (1000, 0.9506, 0.0545),
                (5000, 0.9287, 0.0138),
                (10000, 0.8948, 0.0240),
                (50000, 0.9009, 0.0140)],
        'GraphSAGE': [(100, 2.166804970126863, 0.2843937350060123),
                      (500, 1.7563000858777307, 0.12584696522988958),
                      (1000, 1.624144569206937, 0.0746565879268105),
                      (5000, 1.4869281925242805, 0.03545359297761349),
                      (10000, 1.457223077500945, 0.023084938119480234),
                      (50000, 1.430181429788361, 0.04495243871141515)]
    },
    'Facebook': {
        'GCN': [(100, 0.1146, 0.0103),
                (500, 0.0719, 0.0032),
                (1000, 0.0555, 0.0080)],
        'GAT': [(100, 0.1671, 0.0249),
                (500, 0.0966, 0.0088),
                (1000, 0.0777, 0.0086)],
        'GraphSAGE': [(100, 0.14081752453760424, 0.018460682369960364),
                      (500, 0.05991271242726692, 0.01133794980311604),
                      (1000, 0.0428319653077297, 0.005503937151818257)]
    }
}

# Fitting functions
def fit_linear_model(x, err, std):
    """
    Fit the weighted linear model ``err ~ intercept + slope * x``.

    Each observation is weighted by its inverse variance (``1 / std**2``).
    The fit is computed with ``np.linalg.lstsq`` on rows scaled by the square
    root of the weights, which is algebraically the same weighted least
    squares problem as the normal equations but avoids explicitly inverting
    ``A.T @ W @ A`` and building a dense diagonal weight matrix.

    Args:
        x: Independent variable (transformation of n); 1-D array-like
        err: Error measurements; 1-D array-like, same length as ``x``
        std: Standard deviations of error measurements; 1-D array-like,
            same length as ``x``, entries must be non-zero

    Returns:
        tuple: (intercept, slope, weighted RSS, weighted MSE, weighted R^2)
    """
    # Coerce so plain lists/tuples work as well as numpy arrays
    x = np.asarray(x, dtype=float)
    err = np.asarray(err, dtype=float)
    std = np.asarray(std, dtype=float)

    weights = 1 / std**2  # Using inverse variance as weights
    A = np.column_stack([np.ones_like(x), x])

    # Weighted least squares: scaling every row (and target) by sqrt(weight)
    # turns the weighted problem into an ordinary least squares problem.
    root_weights = np.sqrt(weights)
    theta, *_ = np.linalg.lstsq(A * root_weights[:, None], err * root_weights, rcond=None)
    fit = A @ theta

    # Weighted metrics
    weighted_residuals = weights * (err - fit)**2
    weighted_rss = np.sum(weighted_residuals)
    weighted_mse = np.mean(weighted_residuals)

    # Weighted R^2 calculation (relative to the inverse-variance weighted mean)
    weighted_mean = np.sum(weights * err) / np.sum(weights)
    weighted_total_ss = np.sum(weights * (err - weighted_mean)**2)
    weighted_r2 = 1 - weighted_rss / weighted_total_ss

    return theta[0], theta[1], weighted_rss, weighted_mse, weighted_r2

def fit_power_law(n, err, std):
    """
    Fit the power law model ``err ~ c + 1 / n**gamma`` using weighted curve_fit.

    The amplitude of the power term is fixed at 1, so only ``gamma`` and ``c``
    are estimated. Both are constrained to be non-negative.

    ``curve_fit`` is started from a data-driven initial guess instead of its
    default start: for large ``gamma`` the term ``1 / n**gamma`` is
    numerically zero, the objective becomes flat in ``gamma``, and an
    optimiser that wanders there returns a degenerate "constant only" fit.
    The initial guess is chosen by scanning a grid of ``gamma`` values and,
    for each, using the closed-form weighted optimum of ``c`` (clipped at 0).

    Args:
        n: Sample sizes; 1-D array-like of positive values
        err: Error measurements; 1-D array-like, same length as ``n``
        std: Standard deviations of error measurements; 1-D array-like,
            same length as ``n``, entries must be non-zero

    Returns:
        tuple: (gamma, c, weighted RSS, weighted MSE, weighted R^2)
    """
    # Coerce so plain lists/tuples work as well as numpy arrays
    n = np.asarray(n, dtype=float)
    err = np.asarray(err, dtype=float)
    std = np.asarray(std, dtype=float)

    def func(n, gamma, c):
        return c + 1 / n**gamma

    # Inverse variance weights, consistent with sigma=std passed to curve_fit
    weights = 1 / std**2
    weight_total = np.sum(weights)

    # Seed search: for a fixed gamma the model is linear in c, so the best
    # non-negative c is the weighted mean of (err - 1/n**gamma), clipped at 0.
    seed_gamma, seed_c, seed_rss = None, None, None
    for gamma_try in np.logspace(-2, 1, 61):
        power_term = 1 / n**gamma_try
        c_try = max(np.sum(weights * (err - power_term)) / weight_total, 0.0)
        rss_try = np.sum(weights * (err - (c_try + power_term))**2)
        if seed_rss is None or rss_try < seed_rss:
            seed_gamma, seed_c, seed_rss = gamma_try, c_try, rss_try

    # Refine the seed; bounds keep both gamma and c non-negative
    popt, _ = curve_fit(
        func, n, err,
        p0=[seed_gamma, seed_c],
        sigma=std, absolute_sigma=True,
        bounds=(0, np.inf),
    )
    gamma, c = popt

    # Calculate fit
    fit = func(n, gamma, c)

    # Weighted metrics
    weighted_residuals = weights * (err - fit)**2
    weighted_rss = np.sum(weighted_residuals)
    weighted_mse = np.mean(weighted_residuals)

    # Weighted R^2 (relative to the inverse-variance weighted mean)
    weighted_mean = np.sum(weights * err) / weight_total
    weighted_total_ss = np.sum(weights * (err - weighted_mean)**2)
    weighted_r2 = 1 - weighted_rss / weighted_total_ss

    return gamma, c, weighted_rss, weighted_mse, weighted_r2

# Gather all results: one record per (dataset, model) pair holding the fit
# statistics of the three linear-in-transform models and the power law model.
linear_transforms = (
    ('1/sqrt(n)', lambda sizes: 1 / np.sqrt(sizes)),
    ('1/n', lambda sizes: 1 / sizes),
    ('1/log(n)', lambda sizes: 1 / np.log(sizes)),
)
linear_stat_prefixes = ('c', 'b', 'RSS', 'MSE', 'R2')

rows = []
for dataset, models in datasets.items():
    for model, data in models.items():
        # Split the (n, err, std) tuples into three parallel arrays
        n, err, std = (np.array(column) for column in zip(*data))

        record = {'Dataset': dataset, 'Model': model}

        # Linear fits against each transformation of n, all with the same weighting
        for label, transform in linear_transforms:
            stats = fit_linear_model(transform(n), err, std)
            for prefix, value in zip(linear_stat_prefixes, stats):
                record[f'{prefix}_{label}'] = value

        # Power law fit; gamma is stored last to keep the column order
        gamma, c_gamma, rss_gamma, mse_gamma, r2_gamma = fit_power_law(n, err, std)
        record['c_1/n^gamma'] = c_gamma
        record['RSS_1/n^gamma'] = rss_gamma
        record['MSE_1/n^gamma'] = mse_gamma
        record['R2_1/n^gamma'] = r2_gamma
        record['gamma'] = gamma

        rows.append(record)

df = pd.DataFrame(rows)

# Best fit per row = the model whose weighted R^2 column is largest,
# reported without the 'R2_' column prefix
r2_columns = ['R2_1/sqrt(n)', 'R2_1/n', 'R2_1/log(n)', 'R2_1/n^gamma']
df['Best_Fit'] = df[r2_columns].idxmax(axis=1).str.replace('R2_', '')

# Generate LaTeX table
# Best-fit labels are written in math mode: the raw label '1/n^gamma' contains
# a bare '^', which is an error in LaTeX text mode and would stop the table
# from compiling whenever the power law is the best fit.
best_fit_latex = {
    '1/sqrt(n)': r"$1/\sqrt{n}$",
    '1/n': r"$1/n$",
    '1/log(n)': r"$1/\log n$",
    '1/n^gamma': r"$1/n^{\gamma}$",
}
# Order of the four metric groups; must match the \multicolumn header below
fit_labels = ['1/sqrt(n)', '1/n', '1/log(n)', '1/n^gamma']

latex_lines = [
    r"\begin{table}[ht]",
    r"\centering",
    r"\caption{Comparison of Fit Metrics Across All Models and Datasets (Weighted Analysis)}",
    # 2 identifier columns + 4 groups of (RSS, MSE, R^2) + gamma + best fit = 16 columns
    r"\begin{tabular}{ll|ccc|ccc|ccc|ccc|c|c}",
    r"\hline",
    (r"\multirow{2}{*}{Dataset} & \multirow{2}{*}{Model} "
     r"& \multicolumn{3}{c|}{$c_1 + \frac{\alpha}{\sqrt{n}}$} "
     r"& \multicolumn{3}{c|}{$c_2 + \frac{\beta}{n}$} "
     r"& \multicolumn{3}{c|}{$c_3 + \frac{\delta}{\log n}$} "
     r"& \multicolumn{3}{c|}{$c_4 + \frac{1}{n^{\gamma}}$} & $\gamma$ & Best Fit \\"),
    (r" & & RSS & MSE & R$^2$"
     r" & RSS & MSE & R$^2$"
     r" & RSS & MSE & R$^2$"
     r" & RSS & MSE & R$^2$ & & \\"),
    r"\hline",
]

for _, row in df.iterrows():
    cells = [f"{row['Dataset']}", f"{row['Model']}"]
    for label in fit_labels:
        rss_value = row['RSS_' + label]
        mse_value = row['MSE_' + label]
        r2_value = row['R2_' + label]
        cells.extend([f"{rss_value:.2e}", f"{mse_value:.2e}", f"{r2_value:.3f}"])
    cells.append(f"{row['gamma']:.3f}")
    # Fall back to the raw label if an unexpected one ever appears
    cells.append(best_fit_latex.get(row['Best_Fit'], row['Best_Fit']))
    latex_lines.append(" & ".join(cells) + r" \\")

latex_lines.extend([
    r"\hline",
    r"\end{tabular}",
    r"\label{tab:comparison_all_models_weighted}",
    r"\end{table}",
])

latex_code = "\n".join(latex_lines) + "\n"

# Save to .tex file
with open("final_comparison_table_weighted.tex", "w") as f:
    f.write(latex_code)

print("✅ LaTeX table saved as final_comparison_table_weighted.tex")

# Condensed view of the results: only the identifying columns, the winning
# model label and the fitted power law exponent
summary_columns = ['Dataset', 'Model', 'Best_Fit', 'gamma']
summary_df = df[summary_columns]
print("\nSummary of best fits:", summary_df, sep="\n")

# Create a scatter plot to visually check fits for each dataset and model
import matplotlib.pyplot as plt

def create_fit_plots():
    """
    Save one figure per dataset comparing the four fitted curves to the data.

    Reads the module-level ``datasets`` mapping for the measurements and the
    module-level ``df`` for the fitted parameters. For every dataset a row of
    subplots (one per model) is drawn with the data points and error bars plus
    the four fitted curves, and written to ``"<dataset>_fits.png"`` at 300 dpi.

    The curves are sampled on a geometrically spaced grid because the x-axis
    is logarithmic: evenly spaced samples would put almost all points in the
    last decade and leave the small-n part of each curve as straight segments.
    """
    for dataset, models in datasets.items():
        # squeeze=False always yields a 2-D axes array, so a dataset with a
        # single model needs no special casing
        fig, axs = plt.subplots(1, len(models), figsize=(15, 5), sharey=True, squeeze=False)
        fig.suptitle(f'Model Fits for {dataset} Dataset')

        for i, (model_name, data) in enumerate(models.items()):
            ax = axs[0, i]
            n, err, std = zip(*data)
            n, err, std = np.array(n), np.array(err), np.array(std)

            # Get fit parameters from dataframe
            model_row = df[(df['Dataset'] == dataset) & (df['Model'] == model_name)].iloc[0]

            # Dense x range, evenly spaced on the log axis, for smooth curves
            x_dense = np.geomspace(n.min(), n.max(), 100)

            # Plot data points with error bars
            ax.errorbar(n, err, yerr=std, fmt='o', label='Data')

            # Plot fits
            c1, b1 = model_row['c_1/sqrt(n)'], model_row['b_1/sqrt(n)']
            ax.plot(x_dense, c1 + b1/np.sqrt(x_dense), '--', label=r'$c + \frac{\alpha}{\sqrt{n}}$')

            c2, b2 = model_row['c_1/n'], model_row['b_1/n']
            ax.plot(x_dense, c2 + b2/x_dense, '-.', label=r'$c + \frac{\beta}{n}$')

            c3, b3 = model_row['c_1/log(n)'], model_row['b_1/log(n)']
            ax.plot(x_dense, c3 + b3/np.log(x_dense), ':', label=r'$c + \frac{\delta}{\log{n}}$')

            gamma, c4 = model_row['gamma'], model_row['c_1/n^gamma']
            ax.plot(x_dense, c4 + 1/x_dense**gamma, '-', label=r'$c + \frac{1}{n^{\gamma}}$')

            ax.set_title(f'{model_name}')
            ax.set_xlabel('Sample Size (n)')
            if i == 0:
                # The y-axis is shared, so only the leftmost panel is labelled
                ax.set_ylabel('Error')
            ax.set_xscale('log')
            ax.grid(True)
            ax.legend()

        fig.tight_layout()
        fig.savefig(f"{dataset}_fits.png", dpi=300)
        # Close this specific figure so figures do not accumulate across datasets
        plt.close(fig)

# Plotting is treated as best-effort: any exception raised while creating the
# figures is reported on stdout instead of propagating.
try:
    create_fit_plots()
    print("✅ Created visualization plots for all datasets")
except Exception as e:
    # Report the failure message only; no traceback is printed
    print(f"Error creating plots: {e}")

✅ LaTeX table saved as final_comparison_table_weighted.tex

Summary of best fits:
     Dataset      Model   Best_Fit      gamma
0       Cora        GCN   1/log(n)   0.242669
1       Cora        GAT   1/log(n)   0.245456
2       Cora  GraphSAGE   1/log(n)   0.220236
3     Reddit        GCN  1/sqrt(n)   0.304053
4     Reddit        GAT   1/log(n)   0.120367
5     Reddit  GraphSAGE   1/log(n)   0.226867
6        QM9        GCN   1/log(n)   0.129481
7        QM9        GAT  1/sqrt(n)   0.360203
8        QM9  GraphSAGE  1/sqrt(n)   0.145765
9   Facebook        GCN   1/log(n)  10.794176
10  Facebook        GAT   1/log(n)  10.796341
11  Facebook  GraphSAGE  1/sqrt(n)   0.449356
✅ Created visualization plots for all datasets
