In [2]:
import os
import pickle
import re
import warnings

import pandas as pd
from src.utils.additional_plotting_functions import loss_archetype_plot


In [26]:
def load_result_obj(path: str):
    """
    Unpickle and return the object stored in the file at ``path``.

    The file is opened in a ``with`` block, so it is closed even if
    unpickling raises (the previous open/close pair leaked the handle then).

    Security: ``pickle.load`` can execute arbitrary code. Only use this on
    result files produced by this project, never on files from an untrusted source.

    Parameters
    ----------
    path : str
        Path to a file written with ``pickle.dump``.

    Returns
    -------
    object
        Whatever object was pickled into the file (for the analyses in this
        notebook, a result object from ``src.utils.AA_result_class``).
    """
    with open(path, 'rb') as file:
        return pickle.load(file)

def load_analyses(analysis_dir: str, exact_keys: bool = False):
    """
    Load every pickled result object produced by one synthetic analysis.

    Files are read from ``synthetic_results/<analysis_dir>/<method>_objects/``
    (relative to the current working directory). They are returned as a nested
    dict ``results[method][f'K{K}'][rep]``, where each value is the unpickled
    result object with all matrices and parameters (e.g. ``.A``, ``.B``, ``.Z``).

    The methods loaded are RBOAA, OAA and CAA, or only TSAA when ``'OSM'``
    occurs in ``analysis_dir``. File names must contain ``_K=<int>`` and
    ``_rep=<int>`` tokens, e.g. ``RBOAA_K=1_rep=0``.

    Parameters
    ----------
    analysis_dir : str
        Sub-directory of ``synthetic_results`` that holds the analysis.
    exact_keys : bool, default False
        False keeps the historical key scheme so existing lookups such as
        ``['K0']`` keep working. In that scheme only the LAST DIGIT of K and of
        rep is used: K=50 is stored under ``'K0'``, K=10 and K=50 collide, and
        rep 12 is stored under ``2``. True uses the full integers instead
        (``'K50'``, ``12``).

    Returns
    -------
    dict
        ``{method: {f'K{K}': {rep: result_object}}}``.

    Raises
    ------
    ValueError
        If a file name has no ``_K=<int>`` or ``_rep=<int>`` token.

    Warns
    -----
    UserWarning
        When two files map to the same ``(K, rep)`` slot, meaning a later file
        overwrites an earlier result.
    """
    folder = f'synthetic_results/{analysis_dir}'
    # OSM analyses only contain TSAA results; all other analyses hold the three AA variants.
    methods = ['TSAA'] if 'OSM' in analysis_dir else ['RBOAA', 'OAA', 'CAA']
    results = {method: {} for method in methods}

    for method in methods:
        method_dir = f'{folder}/{method}_objects'
        # os.listdir order is arbitrary; sorting makes loading (and which file
        # wins a key collision) reproducible across machines.
        for file_name in sorted(os.listdir(method_dir)):
            k_match = re.search(r'_K=?(\d+)', file_name)
            rep_match = re.search(r'_rep=?(\d+)', file_name)
            if k_match is None or rep_match is None:
                raise ValueError(
                    f"Cannot parse K/rep from result file name {file_name!r} in "
                    f"{method_dir!r}; expected a name like 'RBOAA_K=1_rep=0'."
                )
            K = int(k_match.group(1))
            rep = int(rep_match.group(1))
            if not exact_keys:
                # Historical behaviour: only the last digit was kept (K=50 -> 'K0').
                K, rep = K % 10, rep % 10

            per_k = results[method].setdefault(f'K{K}', {})
            if rep in per_k:
                hint = '' if exact_keys else (
                    ' (only the last digit of K/rep is used; pass exact_keys=True to keep both)'
                )
                warnings.warn(
                    f"{method}: {file_name!r} overwrites an earlier result stored at "
                    f"['K{K}'][{rep}]{hint}",
                    stacklevel=2,
                )
            per_k[rep] = load_result_obj(f'{method_dir}/{file_name}')
    return results


In [27]:
# Load every result object under synthetic_results/complex_corrupted_results into a nested dict.
complex_corr = load_analyses(analysis_dir='complex_corrupted_results')

In [28]:
print('Example of extracting results')
# Index order is results[method][K-key][repetition].
# NOTE(review): load_analyses keys K by its last digit only, and the A shape in the
# output (50 rows) suggests 'K0' actually holds the K=50 run — confirm.
example_result = complex_corr['RBOAA']['K0'][0]
for label, attr_name in (
    ("A matrix (S): ", 'A'),
    ("B matrix (C): ", 'B'),
    ("archetype matrix: ", 'Z'),
    ("betas: ", 'b'),
):
    print(label, getattr(example_result, attr_name).shape)

Example of extracting results
A matrix (S):  (50, 1000)
B matrix (C):  (1000, 50)
archetype matrix:  (20, 50)
betas:  (1000, 4)


In [29]:
# A matrix of the first CAA repetition under key 'K0'; numpy's repr truncates the middle.
caa_result = complex_corr['CAA']['K0'][0]
caa_result.A

array([[2.92391883e-06, 9.68056270e-07, 2.98350733e-05, ...,
        1.22192898e-06, 8.57570626e-07, 1.20411607e-06],
       [5.02369119e-07, 6.59679301e-07, 9.05763898e-07, ...,
        2.95551331e-06, 1.22838856e-06, 3.95413764e-07],
       [1.35461528e-06, 4.35321954e-06, 1.96922383e-06, ...,
        6.36297613e-02, 5.02805449e-07, 6.93182374e-05],
       ...,
       [2.23997426e-06, 3.94997895e-01, 4.64582627e-06, ...,
        5.34586012e-02, 1.86042905e-06, 7.13271618e-07],
       [1.11571126e-01, 1.14502927e-06, 1.00055240e-05, ...,
        1.22106462e-06, 1.98472344e-06, 3.62871134e-07],
       [6.60232445e-06, 1.73802600e-05, 4.26272936e-06, ...,
        1.04532883e-04, 6.57081227e-06, 1.51987376e-06]], dtype=float32)

In [32]:
# NOTE(review): this RBOAA result displays as `_OAA_result`; presumably RBOAA reuses the
# OAA result class — confirm that is intended.
rboaa_result = complex_corr['RBOAA']['K0'][0]
rboaa_result

<src.utils.AA_result_class._OAA_result at 0x2606877dff0>

In [34]:
# NOTE(review): this reads an RBOAA_* file from the OAA_objects folder, whereas load_analyses
# looks for each method's files under '<method>_objects' — confirm the folder/prefix is intended.
test_result_path = 'synthetic_results/TestMHA/OAA_objects/RBOAA_K=1_rep=0'
OAA_test = load_result_obj(test_result_path)