In [None]:
import glob

def plot_dipole_comps_vs_Lambdas(case_dict_data, case_dict_mock,
                                 title='', fn_fig=None,
                                 dir_results_mocks=None):
    """Plot dipole amplitude against Lambda for the data and for the mock trials.

    The data curve is drawn in black, every mock trial as a thin light-red
    line, and (when at least one mock is found) the mean over the mocks plus a
    +/- 1 standard-deviation band.  A dashed horizontal line marks
    ``case_dict_mock['dipole_amp']``.

    Parameters
    ----------
    case_dict_data : dict
        Must provide 'catalog_name' and 'tag'; both are used to build the name
        of the data result file.
    case_dict_mock : dict
        Must provide 'tag' (used in the glob pattern that selects the mock
        result files) and 'dipole_amp' (drawn as the input amplitude).
    title : str, optional
        Figure title.
    fn_fig : str or None, optional
        If given, the figure is saved to this path.
    dir_results_mocks : str or None, optional
        Directory searched for mock result files; defaults to
        ``<RESULTDIR>/results/results_mocks``.

    Raises
    ------
    FileNotFoundError
        If the data result file does not exist (raised by ``np.load``).
    """
    RESULTDIR = '/scratch/aew492/lss-dipoles_results'

    dir_results_data = os.path.join(RESULTDIR, 'results/results_data')
    if dir_results_mocks is None:
        dir_results_mocks = os.path.join(RESULTDIR, 'results/results_mocks')

    # Load data.
    # NOTE: allow_pickle=True unpickles the file contents; only load trusted result files.
    fn_comps_data = os.path.join(dir_results_data, f"dipole_comps_Lambdas_{case_dict_data['catalog_name']}{case_dict_data['tag']}.npy")
    result_dict = np.load(fn_comps_data, allow_pickle=True).item()
    Lambdas_data = result_dict['Lambdas']
    dipole_amps_data = np.linalg.norm(result_dict['dipole_comps'], axis=1)

    # Find the mock result files.
    # NOTE(review): this pattern uses lowercase "lambdas" while the data file
    # name uses "Lambdas" -- confirm both match the files on disk.
    pattern = f"{dir_results_mocks}/dipole_comps_lambdas*{case_dict_mock['tag']}*.npy"
    print(f"looking for {pattern}...")
    # glob returns files in arbitrary order; sort so trial 0 is reproducible.
    fn_comps_mock = sorted(glob.glob(pattern))
    n_trials = len(fn_comps_mock)
    print(f"found {n_trials} files with this pattern")

    # Load mocks
    Lambdas_mocks = []
    dipole_amps_mocks = []
    for fn in fn_comps_mock:
        result_dict = np.load(fn, allow_pickle=True).item()
        dipole_amps_mocks.append(np.linalg.norm(result_dict['dipole_comps'], axis=-1))
        Lambdas_mocks.append(result_dict['Lambdas'])

    # Plot data
    plt.figure(figsize=(10, 6))
    plt.plot(Lambdas_data, dipole_amps_data, lw=3, color='k', label='Data', zorder=100)

    # Plot each mock trial with light red lines; only the first gets a legend entry
    for i, (Lambdas_mock, dipole_amps_mock) in enumerate(zip(Lambdas_mocks, dipole_amps_mocks)):
        label = 'Mock' if i == 0 else ''
        plt.plot(Lambdas_mock, dipole_amps_mock, color='lightcoral', linewidth=0.5,
                 label=label)

    # Plot the input dipole amplitude for the mocks
    plt.axhline(case_dict_mock['dipole_amp'], color='red', ls='--', alpha=0.8, label='Input dipole amp.')

    # Summary statistics only make sense when at least one mock was found;
    # without this guard an empty glob ends in an IndexError on Lambdas_mocks[0].
    if n_trials > 0:
        # Stacking assumes every trial was evaluated on the same Lambda grid
        # as the first one -- TODO confirm.
        dipole_amps_mocks = np.array(dipole_amps_mocks)
        dipole_amps_mock_mean = np.mean(dipole_amps_mocks, axis=0)
        dipole_amps_mock_std = np.std(dipole_amps_mocks, axis=0)

        # Mean of the mocks with a dark red line
        plt.plot(Lambdas_mocks[0], dipole_amps_mock_mean, color='red', linewidth=3,
                 label=f'Mean of {n_trials} mocks', zorder=10)
        # 1 sigma band around the mean
        plt.fill_between(Lambdas_mocks[0], dipole_amps_mock_mean - dipole_amps_mock_std,
                         dipole_amps_mock_mean + dipole_amps_mock_std,
                         color='darkorange', alpha=0.4, label=r'1$\sigma$')

    # Adding grid, labels and legend
    plt.grid(alpha=0.5, lw=0.5)
    plt.xscale('log')

    plt.xlabel(r'$\Lambda$')
    plt.ylabel(r'$\mathcal{D}$, dipole amplitude')
    plt.title(title)
    plt.legend()

    # fn_fig used to be accepted but silently ignored.
    if fn_fig is not None:
        plt.savefig(fn_fig, bbox_inches='tight')

In [None]:
import glob

def get_dipole_comps_mock(case_dict_mock, dir_results_mocks=None):
    """Return the mean mock dipole amplitude at the smallest Lambda.

    Every result file matching the mock pattern for ``case_dict_mock['tag']``
    is loaded, the norm of its dipole components is taken along the last axis,
    and the amplitudes are averaged over the files.  The value of that mean at
    the position of the smallest Lambda is returned.

    Parameters
    ----------
    case_dict_mock : dict
        Must provide 'tag', which is inserted into the glob pattern.
    dir_results_mocks : str or None, optional
        Directory searched for mock result files; defaults to
        ``<RESULTDIR>/results/results_mocks``.

    Returns
    -------
    float
        Mean amplitude at the smallest Lambda, or ``np.nan`` when no file
        matches the pattern.
    """
    if dir_results_mocks is None:
        RESULTDIR = '/scratch/aew492/lss-dipoles_results'
        dir_results_mocks = os.path.join(RESULTDIR, 'results/results_mocks')

    pattern = f"{dir_results_mocks}/dipole_comps_lambdas*{case_dict_mock['tag']}*.npy"
    print(f"looking for {pattern}...")
    # glob returns files in arbitrary order; sort so the result is reproducible.
    fn_comps_mock = sorted(glob.glob(pattern))
    n_trials = len(fn_comps_mock)
    print(f"found {n_trials} files with this pattern")

    if n_trials == 0:
        return np.nan

    # Load mocks.
    # NOTE: allow_pickle=True unpickles the file contents; only load trusted result files.
    Lambdas_mocks = []
    dipole_amps_mocks = []
    for fn in fn_comps_mock:
        result_dict = np.load(fn, allow_pickle=True).item()
        dipole_amps_mocks.append(np.linalg.norm(result_dict['dipole_comps'], axis=-1))
        Lambdas_mocks.append(result_dict['Lambdas'])

    # Averaging index-by-index assumes all trials share one Lambda grid -- TODO confirm.
    dipole_amps_mock_mean = np.mean(np.array(dipole_amps_mocks), axis=0)

    # for now, grab the smallest lambda; take the index from the first trial's
    # stored grid instead of whatever the loop variable happened to hold last
    i_minlambda = np.argmin(Lambdas_mocks[0])
    return dipole_amps_mock_mean[i_minlambda]

In [None]:
def get_dipole_comps_data(case_dict=None, dir_results_data=None):
    """Return the data dipole amplitude at the smallest Lambda.

    Loads the data result file, takes the norm of the dipole components along
    axis 1 and returns the amplitude at the position of the smallest Lambda.

    Parameters
    ----------
    case_dict : dict or None, optional
        Must provide 'catalog_name' and 'tag', which build the file name.
        When omitted, the notebook-level ``case_dict_data`` is used, which
        keeps the original zero-argument call working.
    dir_results_data : str or None, optional
        Directory holding the data result file; defaults to
        ``<RESULTDIR>/results/results_data``.

    Returns
    -------
    float
        Dipole amplitude at the smallest Lambda.

    Raises
    ------
    NameError
        If ``case_dict`` is omitted and no global ``case_dict_data`` exists.
    FileNotFoundError
        If the result file does not exist (raised by ``np.load``).
    """
    if case_dict is None:
        # Backward-compatible fallback onto the notebook global.
        case_dict = case_dict_data
    if dir_results_data is None:
        RESULTDIR = '/scratch/aew492/lss-dipoles_results'
        dir_results_data = os.path.join(RESULTDIR, 'results/results_data')

    # Load data.
    # NOTE: allow_pickle=True unpickles the file contents; only load trusted result files.
    fn_comps_data = os.path.join(dir_results_data, f"dipole_comps_Lambdas_{case_dict['catalog_name']}{case_dict['tag']}.npy")
    result_dict = np.load(fn_comps_data, allow_pickle=True).item()
    dipole_amps_data = np.linalg.norm(result_dict['dipole_comps'], axis=1)

    i_minlambda = np.argmin(result_dict['Lambdas'])
    return dipole_amps_data[i_minlambda]

In [None]:
# Mean mock amplitude at the smallest Lambda for every case on the grid;
# a case with no matching mock files contributes NaN.
recovered_dipole_amps = np.array([
    get_dipole_comps_mock(case_dict_mock, dir_results_mocks=dir_results_mocks)
    for case_dict_mock in case_dicts_grid
])

In [None]:
# Dipole amplitude of the data at its smallest Lambda (the value returned by
# get_dipole_comps_data); the bare name below displays it as the cell output.
dipole_amp_data = get_dipole_comps_data()
dipole_amp_data

In [None]:
# Input grid (dipole amplitude vs. excess power), coloured by the recovered amplitude
fig, ax = plt.subplots()
grid_points = ax.scatter(dipole_amps, excesses, c=recovered_dipole_amps, s=130)
ax.set_yscale('log')
fig.colorbar(grid_points, ax=ax, label='recovered dipole amplitude')
ax.set_xlabel('input dipole amplitude')
ax.set_ylabel('input excess power')

In [None]:
# Recovered vs. input dipole amplitude for the mocks, coloured by log10 of the
# input excess power, with the one-to-one line and the measured data amplitude.
fig, ax = plt.subplots(figsize=(6,6))
mock_points = ax.scatter(dipole_amps, recovered_dipole_amps, c=np.log10(excesses), s=130, label='Mocks')
ax.plot(dipole_amps, dipole_amps, color='grey')

ax.axhline(dipole_amp_data, label='CatWISE, measured', color='black', ls='--')

# Both axes share the same limits, set from the input amplitudes
axis_lo = 0.9*np.min(dipole_amps)
axis_hi = np.max(dipole_amps)
ax.set_xlim(axis_lo, axis_hi)
ax.set_ylim(axis_lo, axis_hi)

ax.legend()
fig.colorbar(mock_points, ax=ax, label='input excess power (log)')
ax.set_xlabel('input dipole amplitude')
ax.set_ylabel('recovered dipole amplitude')
