In [105]:
import numpy as np
import matplotlib.pyplot as plt
# from gplearn.genetic import SymbolicClassifier
import pandas as pd
import random
# import graphviz
import scipy
from sklearn.ensemble import RandomForestClassifier
from sklearn.utils import check_random_state
from sklearn.metrics import roc_curve, roc_auc_score, precision_score, recall_score
from sklearn.model_selection import StratifiedKFold, cross_val_score
import time


def prep_labels(label_df):
    ### Prepare labels
    # Find unique labels and counts
    class_df = pd.DataFrame(label_df.value_counts()).reset_index()
    class_df = class_df.rename(columns={0:'population'})
    class_df['label'] = class_df.index
    # Convert to numerical labels
    proto2label = {proto:label for proto,label in class_df[['Prototype','label']].to_numpy()}
#     label2proto = {label:proto for proto,label in class_df[['Prototype','label']].to_numpy()}
    label_df['label'] = [proto2label[proto] for proto in label_df.Prototype]
    return label_df, class_df, proto2label

## Prep Data

In [136]:
### CONSTANTS ###########################################################
label_path = '../3_generate_features/final_label_array.csv'
feat_path = '../3_generate_features/dimensionless_cropped_final_feature_array.csv'
SEED = 0
N_SPLITS = 5
select_prototypes = ['Laves(cub)#MgCu2', 'Laves(2H)#MgZn2']
#########################################################################

### Import data
label_df, class_df, proto2label = prep_labels(pd.read_csv(label_path).drop(columns='Unnamed: 0',errors='ignore'))
feat_df = pd.read_csv(feat_path).drop(columns='Unnamed: 0',errors='ignore')

### Convert to binary (1 vs rest), where 1 => select_prototype
label_binary = np.zeros(len(label_df))
for select_proto in select_prototypes:
    label_binary = label_binary | (label_df.Prototype == select_proto)
label_df['label_binary'] = label_binary.astype(int)

### Subsample to attain balance
n_pos = np.where(label_df.label_binary == 1)[0].size
pos_idxs = rng.permutation(np.where(label_df.label_binary == 1)[0])
neg_idxs = rng.permutation(np.where(label_df.label_binary == 0)[0])
rng = np.random.default_rng(seed=SEED)
select_neg_idxs = neg_idxs[:n_pos]
extra_neg_idxs = neg_idxs[n_pos:]
select_idxs = rng.permutation(np.concatenate([select_neg_idxs, pos_idxs]))
bal_feat_df = feat_df.iloc[select_idxs]
bal_label_df = label_df.iloc[select_idxs]

### Train-test-split
skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED) 
train_idx, test_idx = next(skf.split(bal_feat_df, bal_label_df.label_binary))
feat_df_train = bal_feat_df.iloc[train_idx]#.copy()
feat_df_test =  bal_feat_df.iloc[test_idx]#.copy()
label_df_train = bal_label_df.iloc[train_idx]#.copy()
label_df_test =  bal_label_df.iloc[test_idx]#.copy()
print(f'0: {len(label_df_train[label_df_train.label_binary == 0])}')
print(f'1: {len(label_df_train[label_df_train.label_binary == 1])}')

0: 1605
1: 1606


## Instantiate learner

In [147]:
clf = RandomForestClassifier(n_estimators=50, class_weight='balanced', criterion='log_loss', 
                             random_state=SEED, ccp_alpha=0.000, n_jobs=-1)

## Active learning loop

In [174]:
### Constants
INITIAL_FRAC = 0.05 # fraction of training data to train on initially
AL_BATCH_SIZE = 100 # number of samples to train
SCORE_THRESH = 0.99
MAX_SUPER_EPOCHS = 50

### Start active learning loop
# t_idxs = list(
score = 0
current_batch_idxs = feat_df_train.index[:int(INITIAL_FRAC*len(feat_df_train))]
oob_idxs = feat_df_train.index[int(INITIAL_FRAC*len(feat_df_train)):]
i = 0

while True:
    
    ### Check for base cases
    if len(oob_idxs) == 0:
        print('Exhausted OOB samples')
        break
    if score >= SCORE_THRESH:
        print('Reached thershold score')
        break
    if i >= MAX_SUPER_EPOCHS:
        print('Reached max number of generations')
        break
    
    ### Train on current batch
    start = time.time()
    batch_feats = feat_df_train.loc[current_batch_idxs].to_numpy()
    batch_labels = label_df_train.loc[current_batch_idxs].label_binary.to_numpy()
    clf.fit(batch_feats, batch_labels)
    
    ### Calc acc on batch, out-of-batch -- entire set? (or remaining set?) (TODO should i allow retraining on the same example? --> No, defeats the purpose of mo
    batch_predict_labels = clf.predict(batch_feats)
#     batch_prec = precision_score(batch_labels, batch_predict_labels)
#     batch_rec = recall_score(batch_labels, batch_predict_labels)
    oob_feats = feat_df_train.loc[oob_idxs].to_numpy()
    oob_labels = label_df_train.loc[oob_idxs].label_binary.to_numpy()
    oob_predict_labels = clf.predict(oob_feats)
#     batch_prec = precision_score(batch_labels, batch_predict_labels)
#     batch_rec = recall_score(batch_labels, batch_predict_labels)
    prec = precision_score(np.concatenate([batch_labels,oob_labels]),
                           np.concatenate([batch_predict_labels, oob_predict_labels]))
    rec = recall_score(np.concatenate([batch_labels,oob_labels]),
                           np.concatenate([batch_predict_labels, oob_predict_labels]))
    
    ### Calc score on other negative examples
    extra_feats = feat_df.iloc[extra_neg_idxs].to_numpy()
    extra_labels = label_df.iloc[extra_neg_idxs].label_binary.to_numpy()
    extra_predict_labels = clf.predict(extra_feats)
    extra_score = np.mean(extra_labels == extra_predict_labels)

    ### Calculate aggregate score
    score = np.mean([prec,rec,extra_score])

    ### Update current batch -- separate N worst examples as new batch
    oob_losses = (oob_labels == oob_predict_labels)
    sorted_idxs = list(list(zip(*sorted(zip(oob_losses, oob_idxs))))[1])
    if len(sorted_idxs) < AL_BATCH_SIZE:
        batch_idxs = sorted_idxs
        oob_idxs = []
        print('Exhausted OOB samples')
    else:
        batch_idxs = sorted_idxs[:AL_BATCH_SIZE]
        oob_idxs = sorted_idxs[AL_BATCH_SIZE:]
    
    ### Report superepoch details
#     print(f'Elapsed Time: {np.round(time.time()-start,1)}s')
    print(f'Super epoch {i}: Score = {np.round(score,3)} -- (Precision | Recall | AccNeg) = ({np.round(prec,2)} | {np.round(rec,2)} | {np.round(extra_score,2)})')
#     print(f'Super epoch {i}: Score on Neg: {np.round(extra_score,2)}')#  |  Recall: {np.round(extra_rec,2)}')
    i += 1

Super epoch 0: Score = 0.914 -- (Precision | Recall | AccNeg) = (0.91 | 0.93 | 0.9)
Super epoch 1: Score = 0.934 -- (Precision | Recall | AccNeg) = (0.95 | 0.95 | 0.9)
Super epoch 2: Score = 0.955 -- (Precision | Recall | AccNeg) = (0.98 | 0.98 | 0.9)
Super epoch 3: Score = 0.968 -- (Precision | Recall | AccNeg) = (1.0 | 1.0 | 0.9)
Super epoch 4: Score = 0.968 -- (Precision | Recall | AccNeg) = (1.0 | 1.0 | 0.9)
Super epoch 5: Score = 0.968 -- (Precision | Recall | AccNeg) = (1.0 | 1.0 | 0.9)
Super epoch 6: Score = 0.968 -- (Precision | Recall | AccNeg) = (1.0 | 1.0 | 0.9)
Super epoch 7: Score = 0.968 -- (Precision | Recall | AccNeg) = (1.0 | 1.0 | 0.9)
Super epoch 8: Score = 0.968 -- (Precision | Recall | AccNeg) = (1.0 | 1.0 | 0.9)
Super epoch 9: Score = 0.968 -- (Precision | Recall | AccNeg) = (1.0 | 1.0 | 0.9)
Super epoch 10: Score = 0.968 -- (Precision | Recall | AccNeg) = (1.0 | 1.0 | 0.9)
Super epoch 11: Score = 0.968 -- (Precision | Recall | AccNeg) = (1.0 | 1.0 | 0.9)
Super ep

## Examine Failures

In [None]:
#Test