## Synthetic dataset generation
Author: Lin Lee Cheong <br>
Date: 12/12/2020 <br>

The goal of this notebook is to create synthetic datasets that help us understand how different relationships between tokens affect attention, SHAP and other interpretability measures. The factors to vary are:
- sequence length, i.e. number of events per sequence (30, 300, 900)
- spacing between 2+ coupled events, i.e. order of sequence matters
- amount of noise, i.e. performance vs interpretability
- vocabulary space

In [1]:
import yaml
import string
import random
import os
import pandas as pd

In [2]:
# Load the token vocabulary from YAML. Downstream code indexes the result by
# category name (e.g. tokens['noise_tokens']) to get a sequence of tokens.
token_names_fp = './tokens.yaml'
with open(token_names_fp, 'r') as token_file:
    tokens = yaml.safe_load(token_file)

# Report how many tokens each category contains.
for key in tokens:
    print(f"{key}: {len(tokens[key])} tokens")

adverse_tokens: 4 tokens
adverse_helper_tokens: 6 tokens
adverse_unhelper_tokens: 5 tokens
noise_tokens: 15 tokens


In [3]:
# --- Configuration ---------------------------------------------------------
# NOTE(review): output_file is not referenced by any cell shown in this
# notebook (the save cell writes './train.csv' and './test.csv') - confirm
# whether it is still needed.
output_file = './train_ae_prob1.csv'
uid_colname = 'patient_id'  # name of the unique-id column in generated frames
uid_len = 10                # number of characters in each generated id

seq_len = 30                # number of token columns per generated sequence

# Seed the stdlib RNG so that Restart Kernel -> Run All regenerates the same
# datasets; every generator below draws from the `random` module.
SEED = 42
random.seed(SEED)

Base functions

In [29]:
def get_uid(uid_len):
    '''Return a random id of `uid_len` characters.

    Characters are drawn with replacement from uppercase ASCII letters and
    digits, so uniqueness across calls is not enforced here.
    '''
    alphabet = string.ascii_uppercase + string.digits
    picked = random.choices(alphabet, k=uid_len)
    return ''.join(picked)

def get_idx_tok(seq_len, token_dict, token_key, n_pairs):
    '''Draw `n_pairs` random (position, token) pairs.

    Args:
        seq_len: positions are drawn from range(seq_len).
        token_dict: mapping from category name to a sequence of tokens.
        token_key: category in `token_dict` to draw tokens from.
        n_pairs: number of pairs to produce.

    Returns:
        A list of (position, token) tuples. Positions and tokens are drawn
        independently, so both may repeat.
    '''
    pairs = []
    for _ in range(n_pairs):
        # Draw the position first, then the token, one pair at a time.
        position = random.choices(range(seq_len), k=1)[0]
        token = random.choices(token_dict[token_key], k=1)[0]
        pairs.append((position, token))
    return pairs

def get_tokens(seq_len, token_dict, token_key, n_tokens):
    '''Draw `n_tokens` tokens (with replacement) from one token category.

    Args:
        seq_len: accepted but not used by this function.
        token_dict: mapping from category name to a sequence of tokens.
        token_key: category in `token_dict` to draw from.
        n_tokens: number of tokens to draw.

    Returns:
        A new list of `n_tokens` tokens.
    '''
    vocabulary = token_dict[token_key]
    sampled = random.choices(vocabulary, k=n_tokens)
    return sampled

def get_label(prob_label, target):
    '''Return `target`, or the flipped label `1 - target` as label noise.

    A single random.random() draw is compared against `prob_label`; the
    target is kept when the draw is <= `prob_label`.
    '''
    if random.random() <= prob_label:
        return target
    return 1 - target

In [30]:
def save_csv(df, fp):
    '''Write `df` to `fp` as CSV without the index, creating parent directories.

    Args:
        df: object exposing a pandas-style to_csv(path, index=...) method.
        fp: destination file path.
    '''
    parent_dir = os.path.dirname(fp)
    # A bare filename such as 'train.csv' has an empty directory component;
    # os.makedirs('') raises, so only create directories when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    df.to_csv(fp, index=False)

Sequence generation functions

In [31]:
def get_a_sequence(adverse, helper, unhelper, seq_len, label, prob_label=0.9):
    '''Build one left-padded token sequence with its (possibly flipped) label appended.

    Token categories are read from the module-level `tokens` dict.

    Args:
        adverse: number of tokens drawn from tokens['adverse_tokens'].
        helper: number of tokens drawn from tokens['adverse_helper_tokens'].
        unhelper: number of tokens drawn from tokens['adverse_unhelper_tokens'].
        seq_len: length of the padded token sequence. Must be at least
            adverse + helper + unhelper + 2, otherwise the length draw below
            has an empty range to choose from.
        label: target label before label noise is applied; the flipped value
            is computed as 1 - label.
        prob_label: probability of keeping `label` rather than flipping it.
            Defaults to 0.9, the value that was previously hard-coded.

    Returns:
        A flat list of `seq_len` tokens (left-padded with "<PAD>") followed by
        the label, i.e. seq_len + 1 items in total.
    '''
    n_events = adverse + helper + unhelper
    # The unpadded length (noise + events) is drawn from
    # [n_events + 1, seq_len - 1], so there is always at least one noise token
    # and at least one "<PAD>".
    unpadded_len = random.choices(range(n_events + 1, seq_len), k=1)[0]
    n_noise = unpadded_len - n_events

    sel_adverse, sel_helper, sel_unhelper = [], [], []

    if adverse:
        sel_adverse = get_idx_tok(seq_len, tokens, 'adverse_tokens', adverse)
    if helper:
        sel_helper = get_idx_tok(seq_len, tokens, 'adverse_helper_tokens', helper)
    if unhelper:
        sel_unhelper = get_idx_tok(seq_len, tokens, 'adverse_unhelper_tokens', unhelper)

    sel_noise = get_tokens(seq_len, tokens, 'noise_tokens', n_noise)

    # Positions were drawn from range(seq_len) but index the unpadded list;
    # list.insert with an index past the end appends, so large positions land
    # at the tail of the sequence.
    for idx, event in sel_adverse + sel_helper + sel_unhelper:
        sel_noise.insert(idx, event)

    # Left-pad so the most recent tokens sit at the end of the sequence.
    sel_noise = ["<PAD>"] * (seq_len - len(sel_noise)) + sel_noise

    sim_lab = get_label(prob_label, target=label)
    return sel_noise + [sim_lab]

In [32]:
def get_sequences(adverse, helper, unhelper, seq_len, label, uid_len, uid_colname, n_seq):
    '''Generate `n_seq` labelled sequences of one kind as a DataFrame.

    Args:
        adverse, helper, unhelper, seq_len, label: forwarded unchanged to
            get_a_sequence for every row.
        uid_len: length of the random id appended to each row.
        uid_colname: name of the id column.
        n_seq: number of rows to generate; 0 yields an empty frame that still
            carries the full set of columns.

    Returns:
        A DataFrame whose token columns are named seq_len-1 ... 0 (as
        strings), followed by 'label' and `uid_colname`.
    '''
    rows = [
        get_a_sequence(
            adverse=adverse,
            helper=helper,
            unhelper=unhelper,
            seq_len=seq_len,
            label=label
        ) + [get_uid(uid_len)]
        for _ in range(n_seq)
    ]

    columns = [str(pos) for pos in range(seq_len - 1, -1, -1)] + ['label', uid_colname]

    # Passing the names to the constructor (rather than assigning .columns
    # afterwards) keeps n_seq == 0 valid: an empty frame built without names
    # has zero columns, and assigning names to it raises a length mismatch.
    return pd.DataFrame(rows, columns=columns)

### Simple dataset

Get simple dataset:
- positive set: (+++, 1 adverse + 1 helper), (++, 1 adverse), (+, 3 helpers)
- negative set: (---, 3 unhelpers), (--, 1 helper + 2 unhelpers), (-, 2 helpers + 1 unhelper)


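**NOTES** (group sizes are passed in via `count_dict`; the counts below are illustrative and may differ from the values used in the code cells)<br>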
n_ppp_adverse = 2000 # 1 adverse event + 1 helper event <br>
n_pp_adverse = 2000 # 1 adverse event <br>
n_p_adverse = 2000 # 3 helper events <br><br>
n_nnn_adverse = 2000 # 3 unhelper events <br>
n_nn_adverse = 2000 # 1 helper + 2 unhelper <br>
n_n_adverse = 2000 # 2 helper + 1 unhelper <br>

In [33]:
def get_simple_dataset(seq_len, uid_len, uid_colname, count_dict):
    '''Build the shuffled "simple" dataset from six groups of sequences.

    Positive groups (label=1): ppp (1 adverse + 1 helper), pp (1 adverse),
    p (3 helper). Negative groups (label=0): nnn (3 unhelper),
    nn (1 helper + 2 unhelper), n (2 helper + 1 unhelper).

    Args:
        seq_len: number of token columns per sequence.
        uid_len: length of each generated id.
        uid_colname: name of the id column.
        count_dict: number of sequences per group, keyed 'n_<group>_adverse'.

    Returns:
        The concatenated frame with rows in shuffled order. reset_index is
        called without drop, so each row's position within its group is kept
        as an 'index' column.
    '''
    # (group, count key, adverse, helper, unhelper, label), in generation order.
    group_specs = [
        ('ppp', 'n_ppp_adverse', 1, 1, 0, 1),
        ('pp', 'n_pp_adverse', 1, 0, 0, 1),
        ('p', 'n_p_adverse', 0, 3, 0, 1),
        ('nnn', 'n_nnn_adverse', 0, 0, 3, 0),
        ('nn', 'n_nn_adverse', 0, 1, 2, 0),
        ('n', 'n_n_adverse', 0, 2, 1, 0),
    ]

    frames = {}
    for group, count_key, n_adverse, n_helper, n_unhelper, target in group_specs:
        frames[group] = get_sequences(
            adverse=n_adverse,
            helper=n_helper,
            unhelper=n_unhelper,
            seq_len=seq_len,
            label=target,
            uid_len=uid_len,
            uid_colname=uid_colname,
            n_seq=count_dict[count_key],
        )

    # Concatenation order differs from generation order: positives first,
    # then negatives from weakest to strongest.
    stacked = [frames[group] for group in ('ppp', 'pp', 'p', 'n', 'nn', 'nnn')]
    dataset = pd.concat(stacked, axis=0)
    dataset = dataset.reset_index()

    # Shuffle rows by permuting their positions with the stdlib RNG.
    row_order = list(range(dataset.shape[0]))
    random.shuffle(row_order)
    dataset = dataset.iloc[row_order]

    print(f"dataset: {dataset.shape}")
    print(f"ratio:\n{dataset.label.value_counts(normalize=True)}\n")

    return dataset

In [36]:
nrows = 1000

# Every group gets the same number of sequences.
train_count_dict = dict.fromkeys(
    [
        'n_ppp_adverse',
        'n_pp_adverse',
        'n_p_adverse',
        'n_nnn_adverse',
        'n_nn_adverse',
        'n_n_adverse',
    ],
    nrows,
)

# The test split uses the same per-group counts, held in its own dict.
test_count_dict = dict(train_count_dict)


In [37]:
# Generate the two splits with independent calls; each call prints the
# resulting frame shape and label ratio.
train_simple_data = get_simple_dataset(
    seq_len=seq_len,
    uid_len=uid_len,
    uid_colname=uid_colname,
    count_dict=train_count_dict,
)

test_simple_data = get_simple_dataset(
    seq_len=seq_len,
    uid_len=uid_len,
    uid_colname=uid_colname,
    count_dict=test_count_dict,
)

dataset: (18000, 33)
ratio:
0    0.503278
1    0.496722
Name: label, dtype: float64

dataset: (18000, 33)
ratio:
1    0.502
0    0.498
Name: label, dtype: float64



In [38]:
# Persist both splits as CSV next to the notebook; save_csv writes without the
# DataFrame index.
save_csv(train_simple_data, './train.csv')
save_csv(test_simple_data, './test.csv')

In [1]:
import pandas as pd

In [2]:
# Read the saved test split back in to spot-check what was written to disk.
df = pd.read_csv("./test.csv")

In [3]:
# Preview the first rows of the reloaded frame.
df.head()

Unnamed: 0,index,29,28,27,26,25,24,23,22,21,...,7,6,5,4,3,2,1,0,label,patient_id
0,359,<PAD>,<PAD>,<PAD>,<PAD>,<PAD>,<PAD>,<PAD>,<PAD>,<PAD>,...,cut_finger,quad_injury,dental_exam,myopia,quad_injury,foot_pain,ingrown_nail,pneumonia,1,4EP65V0W0P
1,2099,<PAD>,<PAD>,<PAD>,<PAD>,<PAD>,<PAD>,<PAD>,<PAD>,ingrown_nail,...,ingrown_nail,backache,cold_sore,annual_physical,annual_physical,cut_finger,cold_sore,resistent_hyp,1,2R08SO4PSC
2,1999,<PAD>,<PAD>,<PAD>,<PAD>,<PAD>,<PAD>,<PAD>,<PAD>,<PAD>,...,<PAD>,<PAD>,<PAD>,cut_finger,ACL_tear,pneumonia,pneumonia,ACE_inhibitors,0,74RRLFB67Y
3,1091,<PAD>,<PAD>,<PAD>,<PAD>,<PAD>,<PAD>,<PAD>,<PAD>,<PAD>,...,<PAD>,<PAD>,headache,myopia,hay_fever,apnea,PCI,low_salt_diet,0,98Y975BL1S
4,2683,<PAD>,<PAD>,<PAD>,<PAD>,<PAD>,annual_physical,ankle_sprain,cold_sore,myopia,...,myopia,ACL_tear,annual_physical,ankle_sprain,ACL_tear,high_creatinine,high_creatinine,normal_bmi,0,2FYWDMZI6R
