In [17]:
from config import *
import pandas as pd
from numpy import array
import numpy as np
from main_package.utils import data_path_to_abs_path, truncate_interaction_sequences
from main_package.bkt_pyKT_per_skill import bkt_all_interactions
from main_package.bkt_pyKT import get_question_level_prediction, convert_df_strings_to_arrays
from sklearn.metrics import roc_auc_score, accuracy_score
from statsmodels.stats.contingency_tables import mcnemar

In [29]:
df_sample = pd.read_csv(data_path_to_abs_path('mc_nemar/test_predictions_sample.txt'), sep='	')
df_akt = pd.read_csv(data_path_to_abs_path('mc_nemar/test_predictions_akt.txt'), sep='	')
df_saint = pd.read_csv(data_path_to_abs_path('mc_nemar/test_predictions_saint.txt'), sep='	')
df_dkt_plus = pd.read_csv(data_path_to_abs_path('mc_nemar/test_predictions_dkt_plus.txt'), sep='	')
df_test = pd.read_csv(data_path_to_abs_path('isaac/pyKT_processed/test.csv'))
convert_df_strings_to_arrays(df_test)
df_test = truncate_interaction_sequences(df_test)

In [3]:
df_akt.head()
len(df_akt)

87813

In [31]:
df_saint.head()
len(df_saint)

87813

In [30]:
len(df_dkt_plus)

87813

## validate accuracy and auc

In [5]:
print(f"auc: {roc_auc_score(df_akt['late_trues'], df_akt['late_mean'])}")
print(f"accuracy: {accuracy_score(df_akt['late_trues'], np.round(df_akt['late_mean']))}")
print(f'expected AKT auc: 0.6607, accuracy: 0.6212')

auc: 0.6607566602628776
accuracy: 0.6212292029653924
expected AKT auc: 0.6607, accuracy: 0.6212


In [6]:
print(f"auc: {roc_auc_score(df_saint['late_trues'], df_saint['late_mean'])}")
print(f"accuracy: {accuracy_score(df_saint['late_trues'], np.round(df_saint['late_mean']))}")
print(f'expected SAINT auc: 0.6293, accuracy: 0.5916')

auc: 0.6293131028623756
accuracy: 0.591575279286666
expected SAINT auc: 0.6293, accuracy: 0.5916


In [31]:
print(f"auc: {roc_auc_score(df_dkt_plus['late_trues'], df_dkt_plus['late_mean'])}")
print(f"accuracy: {accuracy_score(df_dkt_plus['late_trues'], np.round(df_dkt_plus['late_mean']))}")
print(f'expected dkt_plus auc: 0.6240, accuracy: 0.6077')

auc: 0.6239644116417812
accuracy: 0.6077232300456652
expected dkt_plus auc: 0.6240, accuracy: 0.6077


# BKT standard

## BKT parameters

In [10]:
bkt_params_trained = {'0': array([0.28129558, 0.24881518, 0.49458763, 0.0063549 ]),
 '1': array([0.14936286, 0.23534751, 0.44862168, 0.01379714]),
 '10': array([0.16112846, 0.26505133, 0.5       , 0.04426123]),
 '11': array([5.09428246e-01, 3.40125173e-01, 4.24580414e-01, 1.00000000e-04]),
 '12': array([3.92833020e-01, 3.90782612e-01, 3.69523163e-01, 1.00000000e-04]),
 '13': array([0.32029063, 0.40848163, 0.34818767, 0.02613545]),
 '14': array([0.29072324, 0.32506488, 0.42817709, 0.01917217]),
 '15': array([0.16254359, 0.2481415 , 0.4536599 , 0.04288294]),
 '16': array([0.15174812, 0.14987439, 0.5       , 0.03561117]),
 '17': array([0.39467776, 0.24946422, 0.24677171, 0.18138744]),
 '18': array([9.99900000e-01, 5.00000000e-01, 4.15135869e-01, 1.00000000e-04]),
 '19': array([0.30502508, 0.45028004, 0.4157057 , 0.50778527]),
 '2': array([2.62653357e-01, 3.17270555e-01, 4.22556386e-01, 1.00000000e-04]),
 '20': array([0.6045351 , 0.24355443, 0.49080087, 0.2653337 ]),
 '21': array([3.50898131e-01, 3.61705639e-01, 4.01563911e-01, 1.00000000e-04]),
 '22': array([3.95504777e-01, 3.97034789e-01, 3.34519599e-01, 1.00000000e-04]),
 '23': array([4.31780597e-01, 3.22927564e-01, 4.25169643e-01, 1.00000000e-04]),
 '24': array([4.77127950e-01, 3.04878827e-01, 4.52711371e-01, 1.00000000e-04]),
 '25': array([6.44572931e-01, 2.42763208e-01, 5.00000000e-01, 1.00000000e-04]),
 '26': array([4.18812882e-01, 4.11537219e-01, 3.36719976e-01, 1.00000000e-04]),
 '27': array([0.09706815, 0.2695279 , 0.37426847, 0.05144905]),
 '28': array([9.99900000e-01, 5.00000000e-01, 3.62890149e-01, 1.00000000e-04]),
 '29': array([0.79288396, 0.19301955, 0.35890213, 0.23855013]),
 '3': array([3.49813165e-01, 3.24063830e-01, 4.19718238e-01, 1.00000000e-04]),
 '30': array([0.57201242, 0.41162285, 0.27942031, 0.01170983]),
 '31': array([4.70988184e-01, 3.41473891e-01, 4.28991421e-01, 1.00000000e-04]),
 '32': array([3.90698423e-01, 2.30638302e-01, 4.25289259e-01, 1.00000000e-04]),
 '33': array([0.10451393, 0.284025  , 0.14398427, 0.02465856]),
 '34': array([0.06184114, 0.3756075 , 0.10976529, 0.02438049]),
 '35': array([0.05146758, 0.32406619, 0.21908568, 0.06055558]),
 '36': array([5.61137360e-01, 3.23517752e-01, 4.28042619e-01, 1.00000000e-04]),
 '37': array([3.50414391e-01, 3.66956007e-01, 3.98206128e-01, 1.00000000e-04]),
 '38': array([0.07414367, 0.08196713, 0.35003144, 0.1       ]),
 '39': array([5.42434526e-01, 3.63991130e-01, 4.56687084e-01, 1.00000000e-04]),
 '4': array([0.14876279, 0.2283111 , 0.47657939, 0.10349698]),
 '40': array([0.65220994, 0.27818728, 0.49760459, 0.03408201]),
 '41': array([0.26987761, 0.26562336, 0.5       , 0.03922695]),
 '42': array([0.3235502 , 0.43151115, 0.28569876, 0.9999    ]),
 '43': array([0.0921825 , 0.44208458, 0.05930973, 0.05337252]),
 '44': array([0.48530608, 0.18117751, 0.39234073, 0.1380495 ]),
 '45': array([0.12119925, 0.30869052, 0.26860871, 0.03836575]),
 '46': array([1.27976279e-01, 3.20574861e-01, 1.00000000e-04, 1.00000000e-04]),
 '47': array([2.97239207e-01, 3.63932511e-01, 3.51169337e-01, 1.00000000e-04]),
 '48': array([3.32510794e-01, 3.60568142e-01, 3.63350030e-01, 1.00000000e-04]),
 '49': array([0.12034199, 0.38136935, 0.21121789, 0.01209636]),
 '5': array([1.71257928e-01, 3.34125976e-01, 3.29719578e-01, 1.00000000e-04]),
 '50': array([3.94905031e-01, 4.99799236e-01, 4.97897973e-01, 1.00000000e-04]),
 '51': array([0.63411038, 0.49833556, 0.37575692, 0.9999    ]),
 '52': array([0.18700605, 0.21031422, 0.36760114, 0.1493367 ]),
 '53': array([4.83177259e-01, 2.70223040e-01, 4.82545038e-01, 1.00000000e-04]),
 '54': array([0.03406625, 0.36295471, 0.20831683, 0.00487054]),
 '55': array([1.00000000e-04, 3.10915672e-01, 5.00000000e-01, 7.07955180e-02]),
 '56': array([0.4870775 , 0.12395041, 0.38058129, 0.06650506]),
 '57': array([0.11922182, 0.27840597, 0.5       , 0.09147252]),
 '58': array([0.87507544, 0.01952756, 0.33523152, 0.70080023]),
 '59': array([0.51924336, 0.18840161, 0.49157484, 0.03458294]),
 '6': array([9.32762174e-01, 1.00000000e-04, 4.45190982e-01, 1.00000000e-04]),
 '60': array([1.00000000e-04, 3.53093630e-01, 5.00000000e-01, 5.45625177e-01]),
 '61': array([1.00000000e-04, 3.91670455e-01, 5.00000000e-01, 2.48575514e-01]),
 '62': array([0.06666077, 0.07682419, 0.36016077, 0.1       ]),
 '63': array([1.00000000e-04, 2.25073995e-01, 3.54459904e-01, 5.91601935e-02]),
 '64': array([3.57216692e-01, 3.41858007e-01, 2.57583160e-01, 1.00000000e-04]),
 '65': array([0.50695185, 0.29806438, 0.05508052, 0.1       ]),
 '66': array([0.13954238, 0.24305967, 0.39323411, 0.20383768]),
 '67': array([6.50229016e-02, 2.96041974e-01, 4.81202671e-01, 1.00000000e-04]),
 '68': array([0.36741658, 0.17284294, 0.09809557, 0.1       ]),
 '69': array([1.40616156e-01, 5.00000000e-01, 1.00000000e-04, 1.00000000e-04]),
 '7': array([5.34564696e-01, 3.41723708e-01, 4.24681415e-01, 1.00000000e-04]),
 '70': array([5.86929490e-01, 3.55604347e-01, 2.52894012e-01, 1.00000000e-04]),
 '71': array([8.05766504e-01, 2.73608331e-01, 2.92454873e-01, 1.00000000e-04]),
 '72': array([0.29404374, 0.1115613 , 0.13865167, 0.1       ]),
 '73': array([0.20501716, 0.11606682, 0.2208817 , 0.1       ]),
 '74': array([0.08299564, 0.16438398, 0.2493097 , 0.10472005]),
 '75': array([0.02372674, 0.2404437 , 0.26190836, 0.02336577]),
 '76': array([7.67497346e-01, 5.00000000e-01, 4.53706907e-01, 1.00000000e-04]),
 '77': array([0.07156827, 0.12813943, 0.2454355 , 0.16628182]),
 '78': array([0.01985974, 0.04364419, 0.38226229, 0.05034317]),
 '79': array([1.05448070e-01, 2.16372277e-01, 2.72887086e-01, 1.00000000e-04]),
 '8': array([0.25455385, 0.32872037, 0.37027929, 0.01132733]),
 '80': array([0.46989344, 0.25824454, 0.04117031, 0.1       ]),
 '81': array([0.66610827, 0.5       , 0.16684029, 0.39784953]),
 '82': array([0.31952782, 0.13287757, 0.12461323, 0.1       ]),
 '83': array([0.3464583 , 0.5       , 0.49499908, 0.0971283 ]),
 '84': array([2.68879196e-01, 3.24281600e-01, 3.69762330e-01, 1.00000000e-04]),
 '85': array([8.28465347e-02, 3.75856280e-01, 5.00000000e-01, 1.00000000e-04]),
 '86': array([0.25664039, 0.09948753, 0.15944038, 0.1       ]),
 '87': array([0.12019834, 0.10691127, 0.29247184, 0.1       ]),
 '88': array([0.27369134, 0.10278341, 0.14993026, 0.1       ]),
 '89': array([0.33968643, 0.14971245, 0.11346814, 0.1       ]),
 '9': array([0.39350828, 0.26397287, 0.5       , 0.04122105]),
 '90': array([4.31286606e-01, 1.00000000e-04, 3.18169849e-01, 9.99900000e-01]),
 '91': array([9.999e-01, 5.000e-01, 1.000e-04, 1.000e-01]),
 '92': array([9.999e-01, 5.000e-01, 1.000e-04, 1.000e-01]),
 '93': array([9.999e-01, 5.000e-01, 1.000e-04, 1.000e-01])}

## BKT get predictions

In [22]:
df_test['predictions'] = bkt_all_interactions(bkt_params_trained, df_test, None)
predictions, responses = get_question_level_prediction(df_test)
rounded_prediction = [round(p) for p in predictions]

In [23]:
print(f"number of predicted questions: {len(predictions)}")
auc = roc_auc_score(responses, predictions)
print(f"auc: {auc}")
accuracy = accuracy_score(responses, rounded_prediction)
print(f"accuracy: {accuracy}")

number of predicted questions: 87967
auc: 0.605325208412591
accuracy: 0.5978946650448463


In [24]:
len(df_akt)

27926

# AKT vs. SAINT McNemar

In [14]:
print(sum(df_saint['late_trues'] != df_akt['late_trues']))
saint_predictions = np.round(df_saint['late_mean'])
akt_predictions = np.round(df_akt['late_mean'])
trues = np.array(df_saint['late_trues'])
saint_correct = saint_predictions==trues
akt_correct = akt_predictions==trues
num_both_right = sum(saint_correct & akt_correct)
num_both_wrong = sum(~(saint_correct & akt_correct))
num_akt_right_saint_wrong = sum(~saint_correct & akt_correct)
num_akt_wrong_saint_right = sum(saint_correct & ~akt_correct)
print(f'num_both_right: {num_both_right}')
print(f'num_both_wrong: {num_both_wrong}')
print(f'num_akt_right_saint_wrong: {num_akt_right_saint_wrong}')
print(f'num_akt_wrong_saint_right: {num_akt_wrong_saint_right}')


0
num_both_right: 43808
num_both_wrong: 44005
num_akt_right_saint_wrong: 10744
num_akt_wrong_saint_right: 8140


In [16]:
mc_nemar_statistic = (abs(num_akt_right_saint_wrong - num_akt_wrong_saint_right) - 1)**2 / (num_akt_right_saint_wrong + num_akt_wrong_saint_right)
mc_nemar_statistic

358.8015780554967

In [22]:
cont_table = [[num_both_right, num_akt_right_saint_wrong], [num_akt_wrong_saint_right, num_both_wrong]]
print(mcnemar(table=cont_table, exact=False, correction=True))

pvalue      5.134972030908005e-80
statistic   358.8015780554967


# AKT vs. DKT+

In [33]:
print(sum(df_saint['late_trues'] != df_akt['late_trues']))
akt_predictions = np.round(df_akt['late_mean'])
dkt_plus_predictions = np.round(df_dkt_plus['late_mean'])
trues = np.array(df_akt['late_trues'])
akt_correct = akt_predictions==trues
dkt_plus_correct = dkt_plus_predictions==trues
num_both_right = sum(akt_correct & dkt_plus_correct)
num_both_wrong = sum(~(akt_correct & dkt_plus_correct))
num_dkt_plus_right_akt_wrong = sum(~akt_correct & dkt_plus_correct)
num_dkt_plus_wrong_akt_right = sum(akt_correct & ~dkt_plus_correct)
print(f'num_both_right: {num_both_right}')
print(f'num_both_wrong: {num_both_wrong}')
print(f'num_dkt_plus_right_akt_wrong: {num_dkt_plus_right_akt_wrong}')
print(f'num_dkt_plus_wrong_akt_right: {num_dkt_plus_wrong_akt_right}')

0
num_both_right: 44150
num_both_wrong: 43663
num_dkt_plus_right_akt_wrong: 9216
num_dkt_plus_wrong_akt_right: 10402


In [35]:
cont_table = [[num_both_right, num_dkt_plus_wrong_akt_right], [num_dkt_plus_right_akt_wrong, num_both_wrong]]
print(mcnemar(table=cont_table, exact=False, correction=True))

pvalue      2.6645801027849876e-17
statistic   71.5783973901519
