In [None]:
import os

import numpy as np

import fairnmf
from utils import load_dataset, exp_standard_NMF, exp_Fairer_NMF, make_plots, exp_MU_convergence

In [None]:
# Seed for the global NumPy RNG. With None the RNG is not reseeded, which is the
# notebook's previous behaviour. Set an int (e.g. 42) to make fresh runs repeatable.
# NOTE(review): this assumes the experiments in utils draw from the global NumPy
# RNG — confirm.
SEED = None
if SEED is not None:
    np.random.seed(SEED)

# SAVE_PLOT: pass plots/*.pdf file names to make_plots (None otherwise).
# SAVE_DATA: when experiments are re-run (LOAD_DATA = False), save results as .npy under data/.
# LOAD_DATA: skip the standard / Fairer NMF experiments and load saved results from data/.
SAVE_PLOT = True
SAVE_DATA = False
LOAD_DATA = True

In [None]:
# Load the "20n" dataset used by the experiment and plotting cells below.
data_20n = load_dataset("20n")

In [None]:
if not LOAD_DATA:
    # Re-run every experiment on "20n" instead of loading saved results.

    # Standard NMF baseline: losses, errors and group errors.
    s_losses_20n, s_errors_20n, s_grperr_20n = exp_standard_NMF(data_20n)

    if SAVE_DATA:
        # np.save does not create missing parent directories, so make sure
        # data/ exists before the first write.
        os.makedirs("data", exist_ok=True)
        np.save("data/20n_s_losses.npy", s_losses_20n)
        np.save("data/20n_s_errors.npy", s_errors_20n)
        np.save("data/20n_s_grperr.npy", s_grperr_20n)

    # Fairer NMF with method 'mu'.
    f_losses_20n_mu, f_errors_20n_mu, f_times_20n_mu = exp_Fairer_NMF(data_20n, 'mu')

    if SAVE_DATA:
        np.save("data/20n_f_losses_MU.npy", f_losses_20n_mu)
        np.save("data/20n_f_errors_MU.npy", f_errors_20n_mu)
        np.save("data/20n_f_times_MU.npy",  f_times_20n_mu)

    # Fairer NMF with method 'am'.
    f_losses_20n_am, f_errors_20n_am, f_times_20n_am = exp_Fairer_NMF(data_20n, 'am')

    if SAVE_DATA:
        np.save("data/20n_f_losses_AM.npy", f_losses_20n_am)
        np.save("data/20n_f_errors_AM.npy", f_errors_20n_am)
        np.save("data/20n_f_times_AM.npy",  f_times_20n_am)

In [None]:
if LOAD_DATA:
    # Reuse the .npy results written by the SAVE_DATA branch of the experiment
    # cell instead of re-running the experiments.
    s_losses_20n, s_errors_20n, s_grperr_20n = [
        np.load(path)
        for path in (
            "data/20n_s_losses.npy",
            "data/20n_s_errors.npy",
            "data/20n_s_grperr.npy",
        )
    ]

    f_losses_20n_mu, f_errors_20n_mu, f_times_20n_mu = [
        np.load(path)
        for path in (
            "data/20n_f_losses_MU.npy",
            "data/20n_f_errors_MU.npy",
            "data/20n_f_times_MU.npy",
        )
    ]

    f_losses_20n_am, f_errors_20n_am, f_times_20n_am = [
        np.load(path)
        for path in (
            "data/20n_f_losses_AM.npy",
            "data/20n_f_errors_AM.npy",
            "data/20n_f_times_AM.npy",
        )
    ]

In [None]:
# Loss curves for "20n": standard NMF, then Fairer NMF with 'mu' and 'am'.
# One output PDF per curve when saving; otherwise make_plots gets None.
file_names = (
    ['plots/20n_snmf_loss.pdf', 'plots/20n_fnmf_loss_MU.pdf', 'plots/20n_fnmf_loss_AM.pdf']
    if SAVE_PLOT
    else None
)

make_plots(data_20n, [s_losses_20n, f_losses_20n_mu, f_losses_20n_am], 'Loss', file_names)

In [None]:
# Relative-error curves for "20n": standard NMF (errors and group errors),
# then Fairer NMF with 'mu' and 'am'.
file_names = (
    ['plots/20n_snmf_err.pdf', 'plots/20n_snmf_grperr.pdf', 'plots/20n_fnmf_err_MU.pdf', 'plots/20n_fnmf_err_AM.pdf']
    if SAVE_PLOT
    else None
)

make_plots(data_20n, [s_errors_20n, s_grperr_20n, f_errors_20n_mu, f_errors_20n_am], 'Relative Error (%)', file_names)

In [None]:
# Load the "hd" dataset used by the experiment and plotting cells below.
data_hd = load_dataset("hd")

In [None]:
if not LOAD_DATA:
    # Re-run every experiment on "hd" instead of loading saved results.

    # Standard NMF baseline: losses, errors and group errors.
    s_losses_hd, s_errors_hd, s_grperr_hd = exp_standard_NMF(data_hd)

    if SAVE_DATA:
        # np.save does not create missing parent directories, so make sure
        # data/ exists before the first write.
        os.makedirs("data", exist_ok=True)
        np.save("data/hd_s_losses.npy", s_losses_hd)
        np.save("data/hd_s_errors.npy", s_errors_hd)
        np.save("data/hd_s_grperr.npy", s_grperr_hd)

    # Fairer NMF with method 'mu'.
    f_losses_hd_mu, f_errors_hd_mu, f_times_hd_mu = exp_Fairer_NMF(data_hd, 'mu')

    if SAVE_DATA:
        np.save("data/hd_f_losses_MU.npy", f_losses_hd_mu)
        np.save("data/hd_f_errors_MU.npy", f_errors_hd_mu)
        np.save("data/hd_f_times_MU.npy",  f_times_hd_mu)

    # Fairer NMF with method 'am'.
    f_losses_hd_am, f_errors_hd_am, f_times_hd_am = exp_Fairer_NMF(data_hd, 'am')

    if SAVE_DATA:
        np.save("data/hd_f_losses_AM.npy", f_losses_hd_am)
        np.save("data/hd_f_errors_AM.npy", f_errors_hd_am)
        np.save("data/hd_f_times_AM.npy",  f_times_hd_am)

In [None]:
if LOAD_DATA:
    # Reuse the .npy results written by the SAVE_DATA branch of the experiment
    # cell instead of re-running the experiments.
    s_losses_hd, s_errors_hd, s_grperr_hd = [
        np.load(path)
        for path in (
            "data/hd_s_losses.npy",
            "data/hd_s_errors.npy",
            "data/hd_s_grperr.npy",
        )
    ]

    f_losses_hd_mu, f_errors_hd_mu, f_times_hd_mu = [
        np.load(path)
        for path in (
            "data/hd_f_losses_MU.npy",
            "data/hd_f_errors_MU.npy",
            "data/hd_f_times_MU.npy",
        )
    ]

    f_losses_hd_am, f_errors_hd_am, f_times_hd_am = [
        np.load(path)
        for path in (
            "data/hd_f_losses_AM.npy",
            "data/hd_f_errors_AM.npy",
            "data/hd_f_times_AM.npy",
        )
    ]

In [None]:
# Loss curves for "hd": standard NMF, then Fairer NMF with 'mu' and 'am'.
# One output PDF per curve when saving; otherwise make_plots gets None.
file_names = (
    ['plots/hd_snmf_loss.pdf', 'plots/hd_fnmf_loss_MU.pdf', 'plots/hd_fnmf_loss_AM.pdf']
    if SAVE_PLOT
    else None
)

make_plots(data_hd, [s_losses_hd, f_losses_hd_mu, f_losses_hd_am], 'Loss', file_names)

In [None]:
# Relative-error curves for "hd": standard NMF (errors and group errors),
# then Fairer NMF with 'mu' and 'am'.
file_names = (
    ['plots/hd_snmf_err.pdf', 'plots/hd_snmf_grperr.pdf', 'plots/hd_fnmf_err_MU.pdf', 'plots/hd_fnmf_err_AM.pdf']
    if SAVE_PLOT
    else None
)

make_plots(data_hd, [s_errors_hd, s_grperr_hd, f_errors_hd_mu, f_errors_hd_am], 'Relative Error (%)', file_names)

In [None]:
# Load the "syn" dataset used by the experiment and plotting cells below.
data_syn = load_dataset("syn")

In [None]:
if not LOAD_DATA:
    # Re-run every experiment on "syn" instead of loading saved results.

    # Standard NMF baseline: losses, errors and group errors.
    s_losses_syn, s_errors_syn, s_grperr_syn = exp_standard_NMF(data_syn)

    if SAVE_DATA:
        # np.save does not create missing parent directories, so make sure
        # data/ exists before the first write.
        os.makedirs("data", exist_ok=True)
        np.save("data/syn_s_losses.npy", s_losses_syn)
        np.save("data/syn_s_errors.npy", s_errors_syn)
        np.save("data/syn_s_grperr.npy", s_grperr_syn)

    # Fairer NMF with method 'mu'.
    f_losses_syn_mu, f_errors_syn_mu, f_times_syn_mu = exp_Fairer_NMF(data_syn, 'mu')

    if SAVE_DATA:
        np.save("data/syn_f_losses_MU.npy", f_losses_syn_mu)
        np.save("data/syn_f_errors_MU.npy", f_errors_syn_mu)
        np.save("data/syn_f_times_MU.npy",  f_times_syn_mu)

    # Fairer NMF with method 'am'.
    f_losses_syn_am, f_errors_syn_am, f_times_syn_am = exp_Fairer_NMF(data_syn, 'am')

    if SAVE_DATA:
        np.save("data/syn_f_losses_AM.npy", f_losses_syn_am)
        np.save("data/syn_f_errors_AM.npy", f_errors_syn_am)
        np.save("data/syn_f_times_AM.npy",  f_times_syn_am)

In [None]:
if LOAD_DATA:
    # Reuse the .npy results written by the SAVE_DATA branch of the experiment
    # cell instead of re-running the experiments.
    s_losses_syn, s_errors_syn, s_grperr_syn = [
        np.load(path)
        for path in (
            "data/syn_s_losses.npy",
            "data/syn_s_errors.npy",
            "data/syn_s_grperr.npy",
        )
    ]

    f_losses_syn_mu, f_errors_syn_mu, f_times_syn_mu = [
        np.load(path)
        for path in (
            "data/syn_f_losses_MU.npy",
            "data/syn_f_errors_MU.npy",
            "data/syn_f_times_MU.npy",
        )
    ]

    f_losses_syn_am, f_errors_syn_am, f_times_syn_am = [
        np.load(path)
        for path in (
            "data/syn_f_losses_AM.npy",
            "data/syn_f_errors_AM.npy",
            "data/syn_f_times_AM.npy",
        )
    ]

In [None]:
# Loss curves for "syn": standard NMF, then Fairer NMF with 'mu' and 'am'.
# One output PDF per curve when saving; otherwise make_plots gets None.
file_names = (
    ['plots/syn_snmf_loss.pdf', 'plots/syn_fnmf_loss_MU.pdf', 'plots/syn_fnmf_loss_AM.pdf']
    if SAVE_PLOT
    else None
)

make_plots(data_syn, [s_losses_syn, f_losses_syn_mu, f_losses_syn_am], 'Loss', file_names)

In [None]:
# Relative-error curves for "syn": standard NMF (errors and group errors),
# then Fairer NMF with 'mu' and 'am'.
file_names = (
    ['plots/syn_snmf_err.pdf', 'plots/syn_snmf_grperr.pdf', 'plots/syn_fnmf_err_MU.pdf', 'plots/syn_fnmf_err_AM.pdf']
    if SAVE_PLOT
    else None
)

make_plots(data_syn, [s_errors_syn, s_grperr_syn, f_errors_syn_mu, f_errors_syn_am], 'Relative Error (%)', file_names)

In [None]:
# Load the "ortho" dataset used by the experiment and plotting cells below.
data_ortho = load_dataset("ortho")

In [None]:
if not LOAD_DATA:
    # Re-run every experiment on "ortho" instead of loading saved results.

    # Standard NMF baseline: losses, errors and group errors.
    s_losses_ortho, s_errors_ortho, s_grperr_ortho = exp_standard_NMF(data_ortho)

    if SAVE_DATA:
        # np.save does not create missing parent directories, so make sure
        # data/ exists before the first write.
        os.makedirs("data", exist_ok=True)
        np.save("data/ortho_s_losses.npy", s_losses_ortho)
        np.save("data/ortho_s_errors.npy", s_errors_ortho)
        np.save("data/ortho_s_grperr.npy", s_grperr_ortho)

    # Fairer NMF with method 'mu'.
    f_losses_ortho_mu, f_errors_ortho_mu, f_times_ortho_mu = exp_Fairer_NMF(data_ortho, 'mu')

    if SAVE_DATA:
        np.save("data/ortho_f_losses_MU.npy", f_losses_ortho_mu)
        np.save("data/ortho_f_errors_MU.npy", f_errors_ortho_mu)
        np.save("data/ortho_f_times_MU.npy",  f_times_ortho_mu)

    # Fairer NMF with method 'am'.
    # NOTE(review): these results are computed and saved, but the "ortho" plot
    # cells currently leave the AM curves commented out — confirm intent.
    f_losses_ortho_am, f_errors_ortho_am, f_times_ortho_am = exp_Fairer_NMF(data_ortho, 'am')

    if SAVE_DATA:
        np.save("data/ortho_f_losses_AM.npy", f_losses_ortho_am)
        np.save("data/ortho_f_errors_AM.npy", f_errors_ortho_am)
        np.save("data/ortho_f_times_AM.npy",  f_times_ortho_am)

In [None]:
if LOAD_DATA:
    # Reuse the .npy results written by the SAVE_DATA branch of the experiment
    # cell instead of re-running the experiments.
    s_losses_ortho, s_errors_ortho, s_grperr_ortho = [
        np.load(path)
        for path in (
            "data/ortho_s_losses.npy",
            "data/ortho_s_errors.npy",
            "data/ortho_s_grperr.npy",
        )
    ]

    f_losses_ortho_mu, f_errors_ortho_mu, f_times_ortho_mu = [
        np.load(path)
        for path in (
            "data/ortho_f_losses_MU.npy",
            "data/ortho_f_errors_MU.npy",
            "data/ortho_f_times_MU.npy",
        )
    ]

    # NOTE(review): the 'am' results for "ortho" are deliberately not loaded
    # (and are left out of the "ortho" plots), even though the experiment cell
    # computes and can save them — confirm whether this is still intended.
    # f_losses_ortho_am = np.load("data/ortho_f_losses_AM.npy")
    # f_errors_ortho_am = np.load("data/ortho_f_errors_AM.npy")
    # f_times_ortho_am  = np.load("data/ortho_f_times_AM.npy")

In [None]:
# Loss curves for "ortho": standard NMF and Fairer NMF with 'mu'.
# The 'am' curve is commented out; re-enable the file name and the
# f_losses_ortho_am entry together so the two lists stay aligned.
file_names = [
    'plots/ortho_snmf_loss.pdf',
    'plots/ortho_fnmf_loss_MU.pdf',
    # 'plots/ortho_fnmf_loss_AM.pdf',
] if SAVE_PLOT else None

losses = [
    s_losses_ortho,
    f_losses_ortho_mu,
    # f_losses_ortho_am,
]

make_plots(data_ortho, losses, 'Loss', file_names)

In [None]:
# Relative-error curves for "ortho": standard NMF (errors and group errors)
# and Fairer NMF with 'mu'. The 'am' curve is commented out; re-enable the
# file name and the f_errors_ortho_am entry together so the lists stay aligned.
file_names = [
    'plots/ortho_snmf_err.pdf',
    'plots/ortho_snmf_grperr.pdf',
    'plots/ortho_fnmf_err_MU.pdf',
    # 'plots/ortho_fnmf_err_AM.pdf',
] if SAVE_PLOT else None

errors = [
    s_errors_ortho,
    s_grperr_ortho,
    f_errors_ortho_mu,
    # f_errors_ortho_am,
]

make_plots(data_ortho, errors, 'Relative Error (%)', file_names)

## MU convergence plots

In [None]:
# Rank passed to exp_MU_convergence for the "ortho1" runs in the next cell.
rank_ortho1 = 3
data_ortho1 = load_dataset("ortho1")

In [None]:
# Run the MU convergence experiment on "ortho1" once per strategy argument
# ('decaying_lr', then 'largest'). Only the loss histories are plotted below.
convergence_losses_ortho1_decaying_lr, convergence_ortho1_errors_decaying_lr = exp_MU_convergence(
    data_ortho1,
    rank_ortho1,
    'decaying_lr',
)
convergence_losses_ortho1_largest, convergence_ortho1_errors_largest = exp_MU_convergence(
    data_ortho1,
    rank_ortho1,
    'largest',
)

In [None]:
# Loss histories of the two "ortho1" MU convergence runs, drawn against
# iterations (xaxis='iter') without markers.
file_names = (
    [
        'plots/convergence_losses_ortho1_decaying_lr.pdf',
        'plots/convergence_losses_ortho1_largest.pdf',
    ]
    if SAVE_PLOT
    else None
)

losses = [convergence_losses_ortho1_decaying_lr, convergence_losses_ortho1_largest]

make_plots(data_ortho1, losses, 'Loss', file_names, xaxis='iter', use_markers=False)

In [None]:
# Rank passed to exp_MU_convergence for the "ortho2" runs in the next cell.
rank_ortho2 = 5
data_ortho2 = load_dataset("ortho2")

In [None]:
# Run the MU convergence experiment on "ortho2" once per strategy argument
# ('decaying_lr', then 'largest'). Only the loss histories are plotted below.
convergence_losses_ortho2_decaying_lr, convergence_ortho2_errors_decaying_lr = exp_MU_convergence(
    data_ortho2,
    rank_ortho2,
    'decaying_lr',
)
convergence_losses_ortho2_largest, convergence_ortho2_errors_largest = exp_MU_convergence(
    data_ortho2,
    rank_ortho2,
    'largest',
)

In [None]:
# Loss histories of the two "ortho2" MU convergence runs, drawn against
# iterations (xaxis='iter') without markers.
file_names = (
    [
        'plots/convergence_losses_ortho2_decaying_lr.pdf',
        'plots/convergence_losses_ortho2_largest.pdf',
    ]
    if SAVE_PLOT
    else None
)

losses = [convergence_losses_ortho2_decaying_lr, convergence_losses_ortho2_largest]

make_plots(data_ortho2, losses, 'Loss', file_names, xaxis='iter', use_markers=False)