In [None]:
import numpy as np
import pandas as pd
import math
import json
from tqdm import tqdm

import os
pjoin = os.path.join

from matplotlib import pyplot as plt
plt.ioff()

In [None]:
import sys
sys.path.append('../')

from utils import (
    create_if_no_exist, 
    load_configs, 
    plot_hypnogram,
    intersect,
)

from orp_utils.const import *

In [None]:
# Run configuration. These module-level names are read by every later cell.
# ds_name: root folder of the dataset; also shown in the histogram titles.
ds_name = 'MESA'
# channel: sub-folder of the extracted data that is read (see `global_path`).
channel = 'EEG EEG3'
# results_version: name of the output sub-folder that receives the CSV / JSON / figures.
results_version = 'lookup-table-v1'
# orp_version: version folder of the extracted data (see `global_path`).
orp_version = 'v1'
# subjects_version: names the `.npz` subject list and the `subjects-<version>` results folder.
subjects_version = 'v1FOLD0'
# dset: key of the subject list to process; only TRAIN triggers lookup-table construction.
dset = VAL # TRAIN (construct new table) / VAL (only ranking)

# To rank the power using all subjects in dset (ALL_SUBJ) or one subject (BY_SUBJ)
running_type = ALL_SUBJ # ALL_SUBJ / BY_SUBJ

In [None]:
# Folder holding the extracted data for the chosen dataset / ORP version / channel.
global_path = f'{ds_name}/extracted/{orp_version}/{channel}'
# Per-subject CSV files are read from here (one `<subject>.csv` per subject).
csv_path = pjoin(global_path, 'csv')

# `.npz` archive indexed by `dset` to obtain the list of subjects to process.
subjects_path = f'{ds_name}/subjects/{subjects_version}.npz'
# Bare last expression: echoes both paths as the cell output for a quick visual check.
global_path, subjects_path

In [None]:
# Safety prompt: building a separate lookup table for every subject is only done
# after an explicit "y" from the user; any other answer aborts the run.
if dset == TRAIN and running_type == BY_SUBJ:
    ans = input('Are you sure you want to construct lookup table for each subject separately? (y/n) : ')
    if ans.lower() != 'y':
        raise Exception('''Change `dset` to VAL for running validation set (without constructing new table) 
        OR Change `running_type` to `ALL_SUBJ` to construct lookup table from all subjects.''')
        

In [None]:
def ranking(
        df, 
        col,
        current_rank_df = None,
        fig_path = '',
    ): 
    """Rank the rows of `df` by `col`, assign each row a decile bin, and plot a histogram.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain `col` plus the 'subj_id' and 'sample_id' columns (used to
        restore the row order after ranking).
    col : str
        Column whose values are histogrammed and ranked.
    current_rank_df : pandas.DataFrame, optional
        Result of a previous `ranking` call on the same `df`. When given (and
        non-empty) the two new columns are added to it IN PLACE and it is
        returned; otherwise a new sorted copy of `df` is returned.
    fig_path : str
        Folder for the histogram image `<dset>_<col>.jpg`. When empty the
        figure is shown instead of saved.

    Returns
    -------
    pandas.DataFrame
        Rows ordered by ('subj_id', 'sample_id') with two extra columns:
        `rank_<col>` (0-based position in ascending order of `col`) and
        `<BIN_NUMBER_COL>_<col>` (bin index in 0..9).

    Raises
    ------
    ValueError
        If `df` has no rows.

    Notes
    -----
    Reads the module-level names `ds_name`, `dset` and `BIN_NUMBER_COL`, and
    changes the global matplotlib style / rcParams as a side effect.
    """
    SUBJ_ID_COL = 'subj_id'
    SAMPLE_ID_COL = 'sample_id'
    ngroups = 10

    nepochs = len(df)
    if nepochs == 0:
        raise ValueError(f'cannot rank `{col}`: the input dataframe is empty.')

    plt.style.use('ggplot')
    plt.rcParams['axes.facecolor']='white'
    plt.rcParams.update({'font.size': 22})

    fig, ax = plt.subplots(figsize=(10, 7))

    values = df[col].values
    nsamples = len(values)

    # Ten equal-width bins over [0, 1]; values outside that range are not counted.
    bins = np.arange(0.0, 1.1, 0.1)
    freq, bins, _ = ax.hist(values, bins, rwidth=0.75, color='#5278B7')
    bin_centers = np.diff(bins)*0.5 + bins[:-1]

    # Label every bar with its share of all samples, in percent.
    # NOTE(review): the y axis itself is in raw counts although its label says '%'.
    for count, x in zip(freq, bin_centers):
        height = int(count)
        ax.annotate("{:.1f}".format(height / nsamples * 100),
                    xy = (x, height),             # top centre of the histogram bar
                    xytext = (0, 0.2),            # offsetting label position above its bar
                    textcoords = "offset points", # Offset (in points) from the *xy* value
                    ha = 'center', va = 'bottom'
                    )

    ax.set_ylabel('% of 3s samples')
    ax.set_xlabel('Relative power spectrum')

    ax.set_title(f'{ds_name} | {col} ({nsamples} samples)', y=1.1)
    ax.set_xlim(-0.01, 1.01)
    ax.set_ylim(-0.1, nsamples + 1)

    if len(fig_path) > 0:
        fig.savefig(pjoin(fig_path, f'{dset}_{col}.jpg'))
    else:
        plt.show()

    # Close (not just clear) the figure: this function is called four times per
    # ranking run and clearing alone leaves every figure alive in memory.
    plt.close(fig)

    RANK_COL = f'rank_{col}'
    BIN_RANK_COL = f'{BIN_NUMBER_COL}_{col}'

    rank_df = df.sort_values(by=[col])
    ranks = np.arange(nepochs, dtype=np.int64)
    rank_df[RANK_COL] = ranks
    # rank * ngroups // nepochs is always in 0..ngroups-1, also when nepochs is
    # not a multiple of ngroups or is smaller than ngroups. Dividing by
    # int(nepochs / ngroups) instead would yield an extra bin `ngroups` for the
    # leftover rows (breaking the one-digit-per-band bin code) and divides by
    # zero for nepochs < ngroups.
    rank_df[BIN_RANK_COL] = ranks * ngroups // nepochs

    # Back to the canonical row order so results of successive calls line up.
    rank_df = rank_df.sort_values(by=[SUBJ_ID_COL, SAMPLE_ID_COL])

    if current_rank_df is None or len(current_rank_df) == 0:
        return rank_df

    # append the new columns to the existing frame
    for c in [RANK_COL, BIN_RANK_COL]:
        current_rank_df[c] = rank_df[c]

    return current_rank_df

In [None]:
# Read the subject list for the selected split. The archive is opened in a `with`
# block so the underlying .npz file handle is closed as soon as the array is loaded;
# `files` is a regular in-memory array and stays valid afterwards.
with np.load(subjects_path) as data:
    files = data[dset]

# Bare last expression: shows how many subjects were found and which ones.
len(files), files

In [None]:
# Rank the band powers of every subject in `files` and, for the TRAIN split,
# build (or reload) the lookup table { 4-digit bin number: P(awake) }.
#   ALL_SUBJ: all subjects are pooled and ranked once, after the last CSV is read.
#   BY_SUBJ : every subject is ranked on its own and written to its own folder.
df = pd.DataFrame()
# Per-subject frames are collected here and concatenated once. `DataFrame.append`
# was removed in pandas 2.0 and re-copied the whole frame on every iteration.
subject_frames = []
bands = ('delta', 'theta', 'alpha-sigma', 'beta')

for idx, subject in enumerate(files):

    print('subject:', subject)
    outpath = pjoin(global_path, 'results', f'subjects-{subjects_version}')
    
    if running_type == BY_SUBJ:
        outpath = pjoin(outpath, subject, results_version)
    else:
        outpath = pjoin(outpath, results_version)
        
    out_csv = pjoin(outpath, f'rank_powers_{dset}.csv')
    out_json = pjoin(outpath, 'lookup-table.json')
    create_if_no_exist(outpath)
    
    # read power and labels from csv
    subj_df = pd.read_csv(pjoin(csv_path, subject+'.csv'))

    is_last_subject = idx == len(files) - 1
    if running_type == ALL_SUBJ:
        subject_frames.append(subj_df)
        if is_last_subject:
            # Keep each subject's own row index (no ignore_index), exactly as the
            # previous append-based accumulation did, so the saved CSV is unchanged.
            df = pd.concat(subject_frames)
    else:
        df = subj_df
        
    if running_type == BY_SUBJ or (running_type == ALL_SUBJ and is_last_subject):
        
        if running_type == BY_SUBJ:
            assert len(df) == len(subj_df)
    
        # The first call creates the ranked frame; later calls add their columns to it.
        rank_df = ranking(df, bands[0], fig_path = outpath)
        for col in bands[1:]:
            rank_df = ranking(df, col, rank_df, fig_path = outpath)
        
        # combine 4 digits (one decile digit per band) into a single bin code
        rank_df[BIN_NUMBER_COL] = (
            rank_df[f'{BIN_NUMBER_COL}_delta'].map(str)
            + rank_df[f'{BIN_NUMBER_COL}_theta'].map(str)
            + rank_df[f'{BIN_NUMBER_COL}_alpha-sigma'].map(str)
            + rank_df[f'{BIN_NUMBER_COL}_beta'].map(str)
        )
        bin_numbers_rank, n_per_bin = np.unique(rank_df[BIN_NUMBER_COL], return_counts=True)

        bin_numbers = ['{:04d}'.format(b) for b in range(10000)]
        print(len(bin_numbers), 'bin numbers:', bin_numbers[:10], '...')

        display(rank_df)
        # Every observed code must be one of the 10000 valid 4-digit codes.
        valid_bins = set(bin_numbers)
        assert all(b in valid_bins for b in bin_numbers_rank)
        
        rank_df.to_csv(out_csv)
        print(f'saved ranking into {out_csv}')
        

        # construct look-up dict { bin_number: ORP }

        if dset == TRAIN:
            ans = 'c'
            if os.path.exists(out_json):
                ans = str(input(f'{out_json} already exists. Type `l` to load or `c` to construct the new table.'))
                # Tolerate surrounding whitespace / upper case in the typed answer.
                ans = ans.strip().lower()

            if ans == 'l':
                print(f'loading lookup table from {out_json}...')

                with open(out_json, 'r') as j:
                    lookup_dict = json.load(j)

            elif ans == 'c':
                print(f'constructing new lookup table...')
                # Count rows per bin once instead of re-scanning the whole frame
                # twice for each of the 10000 bin codes.
                total_per_bin = dict(zip(bin_numbers_rank, n_per_bin))
                awake_per_bin = rank_df.loc[
                    rank_df[WAKE_SLEEP_COL] == AWAKE, BIN_NUMBER_COL
                ].value_counts()

                lookup_dict = {}
                for b_num in bin_numbers:
                    n_total = int(total_per_bin.get(b_num, 0))
                    n_awake = int(awake_per_bin.get(b_num, 0))

                    # Bins with fewer than 10 rows fall back to a neutral 0.5.
                    if n_total < 10:
                        prob = 0.5
                    else:
                        prob = n_awake / n_total

                    print(b_num, 'total:', n_total, 'awake:', n_awake, 'prob:', prob)
                    lookup_dict[b_num] = prob

                with open(out_json, "w") as fout:
                    json.dump(lookup_dict, fout, indent=4)

            else:
                raise Exception('invalid answer.')

            print(len(lookup_dict.keys()))
            print(lookup_dict)
            print(out_json)

        else:
            print('ranking was done without constructing new table.')

    
    del subj_df



In [None]:
# Completion marker: printed only when every cell above ran without raising.
print('done')