In [1]:
import shap
import pandas as pd
import numpy as np
from collections import deque
import gc
from sklearn.metrics import roc_auc_score
from scipy.stats import kurtosis, skew
from sklearn.preprocessing import OneHotEncoder
from bitarray import bitarray
from collections import defaultdict
from tqdm.notebook import tqdm
import lightgbm as lgb
import dill as pickle
import matplotlib.pyplot as plt
import random
import seaborn as sns
# Apply seaborn's default theme globally so later matplotlib/seaborn figures share one look.
sns.set()
# Disable pandas' chained-assignment check: SettingWithCopyWarning is never emitted after this.
pd.options.mode.chained_assignment = None
# Ignore NumPy floating-point warnings for division by zero and invalid operations (e.g. 0/0).
# The trailing ';' stops the notebook from echoing the value returned by np.seterr.
np.seterr(divide='ignore', invalid='ignore');

In [2]:
########################################################################################
# PATHS
########################################################################################


# Datasets

# Raw competition CSVs from the riiid-test-answer-prediction dataset.
# Paths are relative ('../input/...'), so they resolve against the notebook's working directory.
question_file = '../input/riiid-test-answer-prediction/questions.csv'
lecture_file = '../input/riiid-test-answer-prediction/lectures.csv'
example_test_file = '../input/riiid-test-answer-prediction/example_test.csv'


# Offline preprocessed features


# Feature name -> path of the precomputed .npy array holding that feature for the
# training split. One entry per feature; insertion order is deliberate and unchanged.
# The final entry ('answered_correctly') points at the target array.
# NOTE(review): presumably every array is row-aligned with the target — confirm in the loader.
offline_features_train_paths = {
    'answered_correctly_avg_u': '../input/riiid-avg-score-u-corrected/avg_score_u_train_arr.npy',
    'answered_correctly_sum_u': '../input/riiid-avg-score-u-corrected/sum_score_u_train_arr.npy',
    'q_count_u': '../input/riiid-avg-score-u-corrected/q_count_u_train_arr.npy',
    'is_first_attempt': '../input/riiid-is-first-atpt-rows-q/is_first_attempt_rows_train_q.npy',
    'answered_correctly_avg_c': '../input/riiid-feats-rows-arr-01/avg_score_c_rows_train_questions.npy',
    'answered_correctly_std_c': '../input/riiid-feats-rows-arr-01/std_score_c_rows_train_questions.npy',
    'answered_correctly_median_c': '../input/riiid-feats-rows-arr-01/median_score_c_rows_train_questions.npy',
    'part': '../input/riiid-feats-rows-arr-01/part_rows_train_questions.npy',
    'prior_question_had_explanation': '../input/riiid-feats-rows-arr-01/prior_explanation_rows_train_questions.npy',
    'prior_question_elapsed_time': '../input/riiid-feats-rows-arr-01/prior_elapsed_time_rows_train_questions.npy',
    'l_count_u': '../input/riiid-count-u-lec-rows-q/count_u_lec_rows_train_q.npy',
    'l_part_count_u': '../input/riiid-lec-part-u-rows-q/cnt_lec_part_u_rows_train_q.npy',
    'session_num': '../input/riiid-session-basics-fe/session_num_train_arr.npy',
    'same_container_as_last': '../input/riiid-same-cont-as-last-arr/same_container_as_last_train_arr.npy',
    'last_container_sum_answ': '../input/riiid-last-cont-sum-answ-arr/last_container_sum_answ_train_arr.npy',
    'last_cont_q_count': '../input/riiid-last-cont-q-count-arr/last_container_q_count_train_arr.npy',
    'last_session_break_time': '../input/riiid-last-sess-break-time-pkl-arrs/last_session_break_time_train_arr.npy',
    'avg_break_time': '../input/riiid-avg-break-time-pkl-arr/avg_break_time_train_arr.npy',
    'avg_session_time': '../input/riiid-session-avg-time/session_avg_time_train_arr.npy',
    'avg_session_q_count': '../input/riiid-avg-session-q-count-arr/avg_session_q_count_train_arr.npy',
    'current_session_avg_score': '../input/riiid-curr-sess-avg-score-pkl-arr/current_session_avg_score_train_arr.npy',
    'same_session_as_last': '../input/riiid-session-basics-fe/same_session_as_last_train_arr.npy',
    'current_session_time': '../input/riiid-curr-session-time-arr-pkl/current_session_time_train_arr.npy',
    'current_session_q_count': '../input/riiid-curr-sess-q-count-arr-pkl/current_session_q_count_train_arr.npy',
    'avg_score_u_hist_100_75': '../input/riiid-hist-100-score-slope-clean/avg_score_u_hist_100_75_train_arr.npy',
    'avg_score_u_hist_75_50': '../input/riiid-hist-100-score-slope-clean/avg_score_u_hist_75_50_train_arr.npy',
    'avg_score_u_hist_50_25': '../input/riiid-hist-100-score-slope-clean/avg_score_u_hist_50_25_train_arr.npy',
    'avg_score_u_hist_25_0': '../input/riiid-hist-100-score-slope-clean/avg_score_u_hist_25_0_train_arr.npy',
    'hist_score_diff_u_100_50': '../input/riiid-hist-100-score-slope-clean/hist_score_diff_u_100_50_train_arr.npy',
    'hist_score_diff_u_75_25': '../input/riiid-hist-100-score-slope-clean/hist_score_diff_u_75_25_train_arr.npy',
    'hist_score_diff_u_50_0': '../input/riiid-hist-100-score-slope-clean/hist_score_diff_u_50_0_train_arr.npy',
    'hist_100_score_slope_u': '../input/riiid-hist-100-score-slope-clean/hist_100_score_slope_u_train_arr.npy',
    'current_right_answ_streak': '../input/riiid-answ-streak-u-pkl-arr/current_right_answ_streak_train_arr.npy',
    'max_right_answ_streak': '../input/riiid-answ-streak-u-pkl-arr/max_right_answ_streak_train_arr.npy',
    'hist_1_right_answ_streak': '../input/riiid-answ-streak-u-pkl-arr/hist_1_right_answ_streak_train_arr.npy',
    'hist_2_right_answ_streak': '../input/riiid-answ-streak-u-pkl-arr/hist_2_right_answ_streak_train_arr.npy',
    'hist_3_right_answ_streak': '../input/riiid-answ-streak-u-pkl-arr/hist_3_right_answ_streak_train_arr.npy',
    'avg_right_answ_streak_hist_3': '../input/riiid-answ-streak-u-pkl-arr/avg_right_answ_streak_hist_3_train_arr.npy',
    'first_try_success_count_u': '../input/riiid-first-try-success-count-u-arr-pkl/first_try_success_count_u_train_arr.npy',
    'unique_count_attempted_q_u': '../input/riiid-cnt-attempted-q-arr-pkl/count_attempted_q_u_train_arr.npy',
    'avg_score_t': '../input/riiid-fe-tags-score-stats/avg_score_t_train_arr.npy',
    'max_avg_score_t': '../input/riiid-fe-tags-score-stats/max_avg_score_t_train_arr.npy',
    'min_avg_score_t': '../input/riiid-fe-tags-score-stats/min_avg_score_t_train_arr.npy',
    'std_avg_score_t': '../input/riiid-fe-tags-score-stats/std_avg_score_t_train_arr.npy',
    'avg_tag_freq': '../input/riiid-fe-tags-freq-stats/tag_freq_feats_arrs/avg_tag_frequency_train_arr.npy',
    'max_tag_freq': '../input/riiid-fe-tags-freq-stats/tag_freq_feats_arrs/max_tag_frequency_train_arr.npy',
    'min_tag_freq': '../input/riiid-fe-tags-freq-stats/tag_freq_feats_arrs/min_tag_frequency_train_arr.npy',
    'std_tag_freq': '../input/riiid-fe-tags-freq-stats/tag_freq_feats_arrs/std_tag_frequency_train_arr.npy',
    'tag_count_q': '../input/riiid-q-tag-count-dict/tag_count_q_train_arr.npy',
    'avg_question_elapsed_time_c': '../input/riiid-question-time-stats/avg_question_elapsed_time_c_train_arr.npy',
    'std_question_elapsed_time_c': '../input/riiid-question-time-stats/std_question_elapsed_time_c_train_arr.npy',
    'avg_time_since_last_c': '../input/riiid-question-time-stats/avg_time_since_last_c_train_arr.npy',
    'std_time_since_last_c': '../input/riiid-question-time-stats/std_time_since_last_c_train_arr.npy',
    'diff_avg_question_elapsed_time_c': '../input/riiid-question-time-stats/diff_avg_question_elapsed_time_c_train_arr.npy',
    'diff_avg_time_since_last_c': '../input/riiid-question-time-stats/diff_avg_time_since_last_c_train_arr.npy',
    'ratio_diff_avg_question_elapsed_time_c': '../input/riiid-question-time-stats/ratio_diff_avg_question_elapsed_time_c_train_arr.npy',
    'ratio_diff_avg_time_since_last_c': '../input/riiid-question-time-stats/ratio_diff_avg_time_since_last_c_train_arr.npy',
    'q_explanation_avg_u': '../input/riiid-q-explanation-sum-avg/q_explanation_avg_u_train_arr.npy',
    'q_explanation_sum_u': '../input/riiid-q-explanation-sum-avg/q_explanation_sum_u_train_arr.npy',
    'avg_part_score_c': '../input/riiid-part-score-stats-avg-std/avg_part_score_c_train_arr.npy',
    'std_part_score_c': '../input/riiid-part-score-stats-avg-std/std_part_score_c_train_arr.npy',
    'time_since_last_sum_u': '../input/riiid-time-since-last-sum-avg/time_since_last_sum_u_train_arr.npy',
    'time_since_last_avg_u': '../input/riiid-time-since-last-sum-avg/time_since_last_avg_u_train_arr.npy',
    'curr_cont_score_avg_u': '../input/riiid-curr-cont-stats-arr-pkl/curr_cont_score_avg_u_train_arr.npy',
    'curr_cont_score_sum_u': '../input/riiid-curr-cont-stats-arr-pkl/curr_cont_score_sum_u_train_arr.npy',
    'curr_cont_tackled_q_count_u': '../input/riiid-curr-cont-stats-arr-pkl/curr_cont_tackled_q_count_u_train_arr.npy',
    'avg_question_elapsed_time_c_right': '../input/riiid-q-elps-time-score-avg-arrs/avg_question_elapsed_time_c_right_train_arr.npy',
    'avg_question_elapsed_time_c_wrong': '../input/riiid-q-elps-time-score-avg-arrs/avg_question_elapsed_time_c_wrong_train_arr.npy',
    'avg_time_since_last_c_right': '../input/riiid-q-time-lag-score-avg-arrs/avg_time_since_last_c_right_train_arr.npy',
    'avg_time_since_last_c_wrong': '../input/riiid-q-time-lag-score-avg-arrs/avg_time_since_last_c_wrong_train_arr.npy',
    'diff_avg_question_elapsed_time_c_right': '../input/riiid-q-elps-time-score-diff-arrs/diff_avg_question_elapsed_time_c_right_train_arr.npy',
    'diff_avg_question_elapsed_time_c_wrong': '../input/riiid-q-elps-time-score-diff-arrs/diff_avg_question_elapsed_time_c_wrong_train_arr.npy',
    'diff_avg_time_since_last_c_right': '../input/riiid-q-time-lag-score-diff-arrs/diff_avg_time_since_last_c_right_train_arr.npy',
    'diff_avg_time_since_last_c_wrong': '../input/riiid-q-time-lag-score-diff-arrs/diff_avg_time_since_last_c_wrong_train_arr.npy',
    # NOTE(review): in the four ratio entries below the dataset folders look swapped compared
    # with the avg/diff entries above ("time_since_last" lives under "q-elps-time-*", while
    # "question_elapsed_time" lives under "q-time-lag-*"). Paths are kept exactly as they were;
    # confirm against the published datasets before renaming anything.
    'ratio_avg_time_since_last_c_right': '../input/riiid-q-elps-time-right-ratios-arrs/ratio_avg_time_since_last_c_right_train_arr.npy',
    'ratio_avg_time_since_last_c_wrong': '../input/riiid-q-elps-time-wrong-ratios-arrs/ratio_avg_time_since_last_c_wrong_train_arr.npy',
    'ratio_diff_avg_question_elapsed_time_c_right': '../input/riiid-q-time-lag-right-ratios-arrs/ratio_diff_avg_question_elapsed_time_c_right_train_arr.npy',
    'ratio_diff_avg_question_elapsed_time_c_wrong': '../input/riiid-q-time-lag-wrong-ratios-arrs/ratio_diff_avg_question_elapsed_time_c_wrong_train_arr.npy',
    'difficulty_level': '../input/riiid-difficulty-lvl-q-arr-dicts/difficulty_q_level_train_arr.npy',
    'difficulty_level_avg_score_c': '../input/riiid-difficulty-lvl-q-arr-dicts/difficulty_q_level_c_avg_train_arr.npy',
    'difficulty_level_std_score_c': '../input/riiid-difficulty-lvl-q-arr-dicts/difficulty_q_level_c_std_train_arr.npy',
    'sum_score_q_level': '../input/riiid-u-score-difficulty-lvl-corr/sum_score_q_level_train_arr.npy',
    'q_count_level_u': '../input/riiid-u-score-difficulty-lvl-corr/q_count_level_train_arr.npy',
    'avg_score_q_level_u': '../input/riiid-u-score-difficulty-lvl-corr/avg_score_q_level_train_arr.npy',
    'avg_q_elapsed_time_lvl': '../input/riiid-q-elpsd-time-dffclty-lvl-arr-dicts/avg_q_elapsed_time_level_train_arr.npy',
    'diff_avg_q_elapsed_time_lvl': '../input/riiid-q-elpsd-time-dffclty-lvl-arr-dicts/diff_avg_q_elapsed_time_per_lvl_train_arr.npy',
    # NOTE(review): the key says "ratio_diff_..." but the file is named "ratio_avg_..." — verify.
    'ratio_diff_avg_q_elapsed_time_lvl': '../input/riiid-q-elpsd-time-dffclty-lvl-arr-dicts/ratio_avg_q_elapsed_time_per_lvl_train_arr.npy',
    'diff_avg_score_q_lvl': '../input/riiid-u-score-difficulty-lvl/diff_avg_score_q_lvl_train_arr.npy',
    'ratio_diff_avg_score_q_lvl': '../input/riiid-u-score-difficulty-lvl/ratio_diff_avg_score_q_lvl_train_arr.npy',
    'ratio_diff_avg_score_u_c': '../input/riiid-ratio-diff-score-u-c/ratio_diff_avg_score_u_c_train_arr.npy',
    'num_answers_q': '../input/riiid-question-meta-stats/num_answers_q_train_arr.npy',
    'q_count_trainset_c': '../input/riiid-question-meta-stats/q_count_trainset_c_train_arr.npy',
    'tag_cluster': '../input/riiid-tags-cluster-dict-arrs/q_tags_clusters_train_arr.npy',
    'avg_score_u_cluster_start': '../input/riiid-user-start-cluster/avg_score_u_cluster_start_train_arr.npy',
    'first_answ': '../input/riiid-user-start-cluster/first_answ_train_arr.npy',
    'u_cluster_start': '../input/riiid-user-start-cluster/u_cluster_start_train_arr.npy',
    'avg_score_u_curr_day': '../input/riiid-fe-days-weeks-months/avg_score_u_curr_day_train_arr.npy',
    'avg_score_u_curr_week': '../input/riiid-fe-days-weeks-months/avg_score_u_curr_week_train_arr.npy',
    'avg_score_u_curr_month': '../input/riiid-fe-days-weeks-months/avg_score_u_curr_month_train_arr.npy',
    'avg_score_u_last_day': '../input/riiid-fe-days-weeks-months/avg_score_u_last_day_train_arr.npy',
    'avg_score_u_last_week': '../input/riiid-fe-days-weeks-months/avg_score_u_last_week_train_arr.npy',
    'avg_score_u_last_month': '../input/riiid-fe-days-weeks-months/avg_score_u_last_month_train_arr.npy',
    'curr_day_count_u': '../input/riiid-fe-days-weeks-months/curr_day_count_u_train_arr.npy',
    'curr_week_count_u': '../input/riiid-fe-days-weeks-months/curr_week_count_u_train_arr.npy',
    'curr_month_count_u': '../input/riiid-fe-days-weeks-months/curr_month_count_u_train_arr.npy',
    'q_count_u_curr_day': '../input/riiid-fe-days-weeks-months/q_count_u_curr_day_train_arr.npy',
    'q_count_u_curr_week': '../input/riiid-fe-days-weeks-months/q_count_u_curr_week_train_arr.npy',
    'q_count_u_curr_month': '../input/riiid-fe-days-weeks-months/q_count_u_curr_month_train_arr.npy',
    'q_count_u_last_day': '../input/riiid-fe-days-weeks-months/q_count_u_last_day_train_arr.npy',
    'q_count_u_last_week': '../input/riiid-fe-days-weeks-months/q_count_u_last_week_train_arr.npy',
    'q_count_u_last_month': '../input/riiid-fe-days-weeks-months/q_count_u_last_month_train_arr.npy',
    'sum_score_u_curr_day': '../input/riiid-fe-days-weeks-months/sum_score_u_curr_day_train_arr.npy',
    'sum_score_u_curr_week': '../input/riiid-fe-days-weeks-months/sum_score_u_curr_week_train_arr.npy',
    'sum_score_u_curr_month': '../input/riiid-fe-days-weeks-months/sum_score_u_curr_month_train_arr.npy',
    'time_since_last': '../input/riiid-time-since-last-hist-3/time_since_last_train_arr.npy',
    'time_since_last_2': '../input/riiid-time-since-last-hist-3/time_since_last_2_train_arr.npy',
    'time_since_last_3': '../input/riiid-time-since-last-hist-3/time_since_last_3_train_arr.npy',
    'time_since_last_right': '../input/riiid-time-since-last-right-wrong-corrected/time_since_last_right_train_arr.npy',
    'time_since_last_wrong': '../input/riiid-time-since-last-right-wrong-corrected/time_since_last_wrong_train_arr.npy',
    'curr_avg_q_elapsed_time': '../input/riiid-avg-q-elapsed-time-c/curr_avg_q_elapsed_time_train_arr.npy',
    'curr_avg_q_is_explained': '../input/riiid-curr-avg-q-is-explained/curr_avg_q_is_explained_train_arr.npy',
    'avg_score_u_part': '../input/riiid-u-score-part/avg_score_u_part_train_arr.npy',
    'sum_score_u_part': '../input/riiid-u-score-part/sum_score_u_part_train_arr.npy',
    'q_count_u_part': '../input/riiid-u-score-part/q_count_u_part_train_arr.npy',
    'avg_prior_cont_time': '../input/riiid-prior-q-cont-elapsed-time/avg_prior_cont_time_train_arr.npy',
    'avg_prior_q_elapsed_time': '../input/riiid-prior-q-cont-elapsed-time/avg_prior_q_elapsed_time_train_arr.npy',
    'avg_prior_cont_time_part': '../input/riiid-prior-q-cont-elasped-time-part/avg_prior_cont_time_part_train_arr.npy',
    'avg_prior_q_elapsed_time_part': '../input/riiid-prior-q-cont-elasped-time-part/avg_prior_q_elapsed_time_part_train_arr.npy',
    'ratio_curr_by_avg_session_time': '../input/riiid-ratio-curr-by-avg-sess-time/ratio_curr_by_avg_session_time_train_arr.npy',
    'avg_u_time_since_last': '../input/riiid-t-diff-u-avg-and-ratios/avg_u_time_since_last_train_arr.npy',
    'ratio_tdiff_by_avg_tdiff_u': '../input/riiid-t-diff-u-avg-and-ratios/ratio_tdiff_by_avg_tdiff_u_train_arr.npy',
    'answered_correctly': '../input/riiid-target-question-rows-arr/target_train_question_rows_arr.npy',
}



offline_features_valid_paths = { 'answered_correctly_avg_u': '../input/riiid-avg-score-u-corrected/avg_score_u_valid_arr.npy',
                                 'answered_correctly_sum_u': '../input/riiid-avg-score-u-corrected/sum_score_u_valid_arr.npy',
                                 'q_count_u': '../input/riiid-avg-score-u-corrected/q_count_u_valid_arr.npy',
                                 'is_first_attempt': '../input/riiid-is-first-atpt-rows-q/is_first_attempt_rows_valid_q_full.npy',
                                 'answered_correctly_avg_c': '../input/riiid-feats-rows-arr-01/avg_score_c_rows_valid_questions.npy',
                                 'answered_correctly_std_c': '../input/riiid-feats-rows-arr-01/std_score_c_rows_valid_questions.npy',
                                 'answered_correctly_median_c': '../input/riiid-feats-rows-arr-01/median_score_c_rows_valid_questions.npy',
                                 'part': '../input/riiid-feats-rows-arr-01/part_rows_valid_questions.npy',
                                 'prior_question_had_explanation': '../input/riiid-feats-rows-arr-01/prior_explanation_rows_valid_questions.npy',
                                 'prior_question_elapsed_time': '../input/riiid-feats-rows-arr-01/prior_elapsed_time_rows_valid_questions.npy',
                                 'l_count_u': '../input/riiid-count-u-lec-rows-q/count_u_lec_rows_valid_q_full.npy',
                                 'l_part_count_u': '../input/riiid-lec-part-u-rows-q/cnt_lec_part_u_rows_valid_q_full.npy',
                                 'session_num': '../input/riiid-session-basics-fe/session_num_valid_arr.npy' ,
                                 'same_container_as_last': '../input/riiid-same-cont-as-last-arr/same_container_as_last_valid_arr.npy',
                                 'last_container_sum_answ': '../input/riiid-last-cont-sum-answ-arr/last_container_sum_answ_valid_arr.npy',
                                 'last_cont_q_count': '../input/riiid-last-cont-q-count-arr/last_container_q_count_valid_arr.npy' ,
                                 'last_session_break_time': '../input/riiid-last-sess-break-time-pkl-arrs/last_session_break_time_valid_arr.npy',
                                 'avg_break_time': '../input/riiid-avg-break-time-pkl-arr/avg_break_time_valid_arr.npy',
                                 'avg_session_time': '../input/riiid-session-avg-time/session_avg_time_valid_arr.npy',
                                 'avg_session_q_count': '../input/riiid-avg-session-q-count-arr/avg_session_q_count_valid_arr.npy',
                                 'current_session_avg_score': '../input/riiid-curr-sess-avg-score-pkl-arr/current_session_avg_score_valid_arr.npy',
                                 'same_session_as_last': '../input/riiid-session-basics-fe/same_session_as_last_valid_arr.npy',
                                 'current_session_time': '../input/riiid-curr-session-time-arr-pkl/current_session_time_valid_arr.npy',
                                 'current_session_q_count': '../input/riiid-curr-sess-q-count-arr-pkl/current_session_q_count_valid_arr.npy',
                                 'avg_score_u_hist_100_75': '../input/riiid-hist-100-score-slope-clean/avg_score_u_hist_100_75_valid_arr.npy',
                                 'avg_score_u_hist_75_50': '../input/riiid-hist-100-score-slope-clean/avg_score_u_hist_75_50_valid_arr.npy',
                                 'avg_score_u_hist_50_25': '../input/riiid-hist-100-score-slope-clean/avg_score_u_hist_50_25_valid_arr.npy',
                                 'avg_score_u_hist_25_0': '../input/riiid-hist-100-score-slope-clean/avg_score_u_hist_25_0_valid_arr.npy',
                                 'hist_score_diff_u_100_50': '../input/riiid-hist-100-score-slope-clean/hist_score_diff_u_100_50_valid_arr.npy',
                                 'hist_score_diff_u_75_25': '../input/riiid-hist-100-score-slope-clean/hist_score_diff_u_75_25_valid_arr.npy',
                                 'hist_score_diff_u_50_0': '../input/riiid-hist-100-score-slope-clean/hist_score_diff_u_50_0_valid_arr.npy',
                                 'hist_100_score_slope_u': '../input/riiid-hist-100-score-slope-clean/hist_100_score_slope_u_valid_arr.npy',
                                 'current_right_answ_streak': '../input/riiid-answ-streak-u-pkl-arr/current_right_answ_streak_valid_arr.npy',
                                 'max_right_answ_streak': '../input/riiid-answ-streak-u-pkl-arr/max_right_answ_streak_valid_arr.npy',
                                 'hist_1_right_answ_streak': '../input/riiid-answ-streak-u-pkl-arr/hist_1_right_answ_streak_valid_arr.npy',
                                 'hist_2_right_answ_streak': '../input/riiid-answ-streak-u-pkl-arr/hist_2_right_answ_streak_valid_arr.npy',
                                 'hist_3_right_answ_streak': '../input/riiid-answ-streak-u-pkl-arr/hist_3_right_answ_streak_valid_arr.npy',
                                 'avg_right_answ_streak_hist_3': '../input/riiid-answ-streak-u-pkl-arr/avg_right_answ_streak_hist_3_valid_arr.npy',
                                 'first_try_success_count_u': '../input/riiid-first-try-success-count-u-arr-pkl/first_try_success_count_u_valid_arr.npy',
                                 'unique_count_attempted_q_u': '../input/riiid-cnt-attempted-q-arr-pkl/count_attempted_q_u_val_arr.npy',
                                 'avg_score_t': '../input/riiid-fe-tags-score-stats/avg_score_t_valid_arr.npy',
                                 'max_avg_score_t': '../input/riiid-fe-tags-score-stats/max_avg_score_t_valid_arr.npy',
                                 'min_avg_score_t': '../input/riiid-fe-tags-score-stats/min_avg_score_t_valid_arr.npy',
                                 'std_avg_score_t': '../input/riiid-fe-tags-score-stats/std_avg_score_t_valid_arr.npy',
                                 'avg_tag_freq': '../input/riiid-fe-tags-freq-stats/tag_freq_feats_arrs/avg_tag_frequency_valid_arr.npy',
                                 'max_tag_freq': '../input/riiid-fe-tags-freq-stats/tag_freq_feats_arrs/max_tag_frequency_valid_arr.npy',
                                 'min_tag_freq': '../input/riiid-fe-tags-freq-stats/tag_freq_feats_arrs/min_tag_frequency_valid_arr.npy',
                                 'std_tag_freq': '../input/riiid-fe-tags-freq-stats/tag_freq_feats_arrs/std_tag_frequency_valid_arr.npy',
                                 'tag_count_q': '../input/riiid-q-tag-count-dict/tag_count_q_valid_arr.npy',
                                 'avg_question_elapsed_time_c': '../input/riiid-question-time-stats/avg_question_elapsed_time_c_valid_arr.npy',
                                 'std_question_elapsed_time_c': '../input/riiid-question-time-stats/std_question_elapsed_time_c_valid_arr.npy',
                                 'avg_time_since_last_c': '../input/riiid-question-time-stats/avg_time_since_last_c_valid_arr.npy',
                                 'std_time_since_last_c': '../input/riiid-question-time-stats/std_time_since_last_c_valid_arr.npy',
                                 'diff_avg_question_elapsed_time_c': '../input/riiid-question-time-stats/diff_avg_question_elapsed_time_c_valid_arr.npy',
                                 'diff_avg_time_since_last_c': '../input/riiid-question-time-stats/diff_avg_time_since_last_c_valid_arr.npy',
                                 'ratio_diff_avg_question_elapsed_time_c': '../input/riiid-question-time-stats/ratio_diff_avg_question_elapsed_time_c_valid_arr.npy',
                                 'ratio_diff_avg_time_since_last_c': '../input/riiid-question-time-stats/ratio_diff_avg_time_since_last_c_valid_arr.npy',
                                 'q_explanation_avg_u': '../input/riiid-q-explanation-sum-avg/q_explanation_avg_u_valid_arr.npy',
                                 'q_explanation_sum_u': '../input/riiid-q-explanation-sum-avg/q_explanation_sum_u_valid_arr.npy',
                                 'avg_part_score_c': '../input/riiid-part-score-stats-avg-std/avg_part_score_c_valid_arr.npy',
                                 'std_part_score_c': '../input/riiid-part-score-stats-avg-std/std_part_score_c_valid_arr.npy',
                                 'time_since_last_sum_u': '../input/riiid-time-since-last-sum-avg/time_since_last_sum_u_valid_arr.npy',
                                 'time_since_last_avg_u': '../input/riiid-time-since-last-sum-avg/time_since_last_avg_u_valid_arr.npy',
                                 'curr_cont_score_avg_u': '../input/riiid-curr-cont-stats-arr-pkl/curr_cont_score_avg_u_valid_arr.npy',
                                 'curr_cont_score_sum_u': '../input/riiid-curr-cont-stats-arr-pkl/curr_cont_score_sum_u_valid_arr.npy',
                                 'curr_cont_tackled_q_count_u': '../input/riiid-curr-cont-stats-arr-pkl/curr_cont_tackled_q_count_u_valid_arr.npy',
                                 'avg_question_elapsed_time_c_right': '../input/riiid-q-elps-time-score-avg-arrs/avg_question_elapsed_time_c_right_valid_arr.npy',
                                 'avg_question_elapsed_time_c_wrong': '../input/riiid-q-elps-time-score-avg-arrs/avg_question_elapsed_time_c_wrong_valid_arr.npy',
                                 'avg_time_since_last_c_right': '../input/riiid-q-time-lag-score-avg-arrs/avg_time_since_last_c_right_valid_arr.npy',
                                 'avg_time_since_last_c_wrong': '../input/riiid-q-time-lag-score-avg-arrs/avg_time_since_last_c_wrong_valid_arr.npy',
                                 'diff_avg_question_elapsed_time_c_right': '../input/riiid-q-elps-time-score-diff-arrs/diff_avg_question_elapsed_time_c_right_valid_arr.npy',
                                 'diff_avg_question_elapsed_time_c_wrong': '../input/riiid-q-elps-time-score-diff-arrs/diff_avg_question_elapsed_time_c_wrong_valid_arr.npy',
                                 'diff_avg_time_since_last_c_right': '../input/riiid-q-time-lag-score-diff-arrs/diff_avg_time_since_last_c_right_valid_arr.npy',
                                 'diff_avg_time_since_last_c_wrong': '../input/riiid-q-time-lag-score-diff-arrs/diff_avg_time_since_last_c_wrong_valid_arr.npy',
                                 'ratio_avg_time_since_last_c_right': '../input/riiid-q-elps-time-right-ratios-arrs/ratio_avg_time_since_last_c_right_valid_arr.npy',
                                 'ratio_avg_time_since_last_c_wrong': '../input/riiid-q-elps-time-wrong-ratios-arrs/ratio_avg_time_since_last_c_wrong_valid_arr.npy',
                                 'ratio_diff_avg_question_elapsed_time_c_right': '../input/riiid-q-time-lag-right-ratios-arrs/ratio_diff_avg_question_elapsed_time_c_right_valid_arr.npy',
                                 'ratio_diff_avg_question_elapsed_time_c_wrong': '../input/riiid-q-time-lag-wrong-ratios-arrs/ratio_diff_avg_question_elapsed_time_c_wrong_valid_arr.npy',
                                 'difficulty_level': '../input/riiid-difficulty-lvl-q-arr-dicts/difficulty_q_level_valid_arr.npy',
                                 'difficulty_level_avg_score_c': '../input/riiid-difficulty-lvl-q-arr-dicts/difficulty_q_level_c_avg_valid_arr.npy',
                                 'difficulty_level_std_score_c': '../input/riiid-difficulty-lvl-q-arr-dicts/difficulty_q_level_c_std_valid_arr.npy',
                                 'sum_score_q_level': '../input/riiid-u-score-difficulty-lvl-corr/sum_score_q_level_valid_arr.npy',
                                 'q_count_level_u': '../input/riiid-u-score-difficulty-lvl-corr/q_count_level_valid_arr.npy',
                                 'avg_score_q_level_u': '../input/riiid-u-score-difficulty-lvl-corr/avg_score_q_level_valid_arr.npy',
                                 'avg_q_elapsed_time_lvl': '../input/riiid-q-elpsd-time-dffclty-lvl-arr-dicts/avg_q_elapsed_time_level_valid_arr.npy',
                                 'diff_avg_q_elapsed_time_lvl': '../input/riiid-q-elpsd-time-dffclty-lvl-arr-dicts/diff_avg_q_elapsed_time_per_lvl_valid_arr.npy',
                                 'ratio_diff_avg_q_elapsed_time_lvl': '../input/riiid-q-elpsd-time-dffclty-lvl-arr-dicts/ratio_avg_q_elapsed_time_per_lvl_valid_arr.npy',
                                 'diff_avg_score_q_lvl': '../input/riiid-u-score-difficulty-lvl/diff_avg_score_q_lvl_valid_arr.npy',
                                 'ratio_diff_avg_score_q_lvl': '../input/riiid-u-score-difficulty-lvl/ratio_diff_avg_score_q_lvl_valid_arr.npy',
                                 'ratio_diff_avg_score_u_c': '../input/riiid-ratio-diff-score-u-c/ratio_diff_avg_score_u_c_valid_arr.npy', 
                                 'num_answers_q': '../input/riiid-question-meta-stats/num_answers_q_valid_arr.npy',
                                 'q_count_trainset_c': '../input/riiid-question-meta-stats/q_count_trainset_c_valid_arr.npy',
                                 'tag_cluster': '../input/riiid-tags-cluster-dict-arrs/q_tags_clusters_valid_arr.npy',
                                 'avg_score_u_cluster_start': '../input/riiid-user-start-cluster/avg_score_u_cluster_start_valid_arr.npy',
                                 'first_answ': '../input/riiid-user-start-cluster/first_answ_valid_arr.npy',
                                 'u_cluster_start': '../input/riiid-user-start-cluster/u_cluster_start_valid_arr.npy',
                                 'avg_score_u_curr_day': '../input/riiid-fe-days-weeks-months/avg_score_u_curr_day_valid_arr.npy',
                                 'avg_score_u_curr_week': '../input/riiid-fe-days-weeks-months/avg_score_u_curr_week_valid_arr.npy',
                                 'avg_score_u_curr_month': '../input/riiid-fe-days-weeks-months/avg_score_u_curr_month_valid_arr.npy',
                                 'avg_score_u_last_day': '../input/riiid-fe-days-weeks-months/avg_score_u_last_day_valid_arr.npy',
                                 'avg_score_u_last_week': '../input/riiid-fe-days-weeks-months/avg_score_u_last_week_valid_arr.npy',
                                 'avg_score_u_last_month': '../input/riiid-fe-days-weeks-months/avg_score_u_last_month_valid_arr.npy',
                                 'curr_day_count_u': '../input/riiid-fe-days-weeks-months/curr_day_count_u_valid_arr.npy',
                                 'curr_week_count_u': '../input/riiid-fe-days-weeks-months/curr_week_count_u_valid_arr.npy',
                                 'curr_month_count_u': '../input/riiid-fe-days-weeks-months/curr_month_count_u_valid_arr.npy',
                                 'q_count_u_curr_day': '../input/riiid-fe-days-weeks-months/q_count_u_curr_day_valid_arr.npy',
                                 'q_count_u_curr_week': '../input/riiid-fe-days-weeks-months/q_count_u_curr_week_valid_arr.npy',
                                 'q_count_u_curr_month': '../input/riiid-fe-days-weeks-months/q_count_u_curr_month_valid_arr.npy',
                                 'q_count_u_last_day': '../input/riiid-fe-days-weeks-months/q_count_u_last_day_valid_arr.npy',
                                 'q_count_u_last_week': '../input/riiid-fe-days-weeks-months/q_count_u_last_week_valid_arr.npy',
                                 'q_count_u_last_month': '../input/riiid-fe-days-weeks-months/q_count_u_last_month_valid_arr.npy',
                                 'sum_score_u_curr_day': '../input/riiid-fe-days-weeks-months/sum_score_u_curr_day_valid_arr.npy',
                                 'sum_score_u_curr_week': '../input/riiid-fe-days-weeks-months/sum_score_u_curr_week_valid_arr.npy',
                                 'sum_score_u_curr_month': '../input/riiid-fe-days-weeks-months/sum_score_u_curr_month_valid_arr.npy',
                                 'time_since_last': '../input/riiid-time-since-last-hist-3/time_since_last_valid_arr.npy',
                                 'time_since_last_2': '../input/riiid-time-since-last-hist-3/time_since_last_2_valid_arr.npy',
                                 'time_since_last_3': '../input/riiid-time-since-last-hist-3/time_since_last_3_valid_arr.npy',
                                 'time_since_last_right': '../input/riiid-time-since-last-right-wrong-corrected/time_since_last_right_valid_arr.npy',
                                 'time_since_last_wrong': '../input/riiid-time-since-last-right-wrong-corrected/time_since_last_wrong_valid_arr.npy',
                                 'curr_avg_q_elapsed_time': '../input/riiid-avg-q-elapsed-time-c/curr_avg_q_elapsed_time_valid_arr.npy',
                                 'curr_avg_q_is_explained': '../input/riiid-curr-avg-q-is-explained/curr_avg_q_is_explained_valid_arr.npy',
                                 'avg_score_u_part': '../input/riiid-u-score-part/avg_score_u_part_valid_arr.npy',
                                 'sum_score_u_part': '../input/riiid-u-score-part/sum_score_u_part_valid_arr.npy',
                                 'q_count_u_part': '../input/riiid-u-score-part/q_count_u_part_valid_arr.npy',
                                 'avg_prior_cont_time': '../input/riiid-prior-q-cont-elapsed-time/avg_prior_cont_time_valid_arr.npy',
                                 'avg_prior_q_elapsed_time': '../input/riiid-prior-q-cont-elapsed-time/avg_prior_q_elapsed_time_valid_arr.npy',
                                 'avg_prior_cont_time_part': '../input/riiid-prior-q-cont-elasped-time-part/avg_prior_cont_time_part_valid_arr.npy',
                                 'avg_prior_q_elapsed_time_part': '../input/riiid-prior-q-cont-elasped-time-part/avg_prior_q_elapsed_time_part_valid_arr.npy',
                                 'ratio_curr_by_avg_session_time': '../input/riiid-ratio-curr-by-avg-sess-time/ratio_curr_by_avg_session_time_valid_arr.npy',
                                 'avg_u_time_since_last': '../input/riiid-t-diff-u-avg-and-ratios/avg_u_time_since_last_valid_arr.npy',
                                 'ratio_tdiff_by_avg_tdiff_u': '../input/riiid-t-diff-u-avg-and-ratios/ratio_tdiff_by_avg_tdiff_u_valid_arr.npy',
                                 'answered_correctly': '../input/riiid-target-question-rows-arr/target_valid_question_rows_arr.npy'
                               }


# Raw dataframe paths, keyed by fold name then by 'train' / 'valid'
#------------------------------------------------------------------
# -> Q   (questions only)       : online features (in data_preprocessing)
# -> Q+L (questions + lectures) : local CV (target in iter_test API)
#
# Both fold keys ('fullset' and 'cv0') point at the same cv0 files.

MODE_10M_ROWS = True


# Only the question-only *train* file depends on the row-count mode.
if MODE_10M_ROWS:
    # 10M
    _train_q_df_path = '../input/riiid-df-q-kfold/raw_df_kfolds_10m/train_cv0_q_df_10M'
else:
    # 5M
    _train_q_df_path = '../input/riiid-df-q-kfold/raw_df_kfolds/train_cv0_q_df'


# Each fold gets its own inner dict (no aliasing between folds).
df_questions_kfold_paths = {
    fold_name: {
        'train': _train_q_df_path,
        'valid': '../input/riiid-df-q-kfold/raw_df_kfolds/valid_cv0_q_df',
    }
    for fold_name in ('fullset', 'cv0')
}


df_questions_lectures_kfold_paths = {
    fold_name: {
        'train': '../input/riiid-df-q-kfold/raw_df_kfolds/train_cv0_ql_df',
        'valid': '../input/riiid-df-q-kfold/raw_df_kfolds/valid_cv0_ql_df',
    }
    for fold_name in ('fullset', 'cv0')
}



# Offline preprocessed state dicts


# User states : K-fold dicts
#-------------------------------


# Single source of truth for the per-user state pickles.
#
# Every fold uses the same file names up to a fold marker, and two naming
# conventions coexist in the datasets:
#   * '{fold}' -> '..._fullset_dict.pkl' / '..._cv0_dict.pkl'
#   * '{tag}'  -> '..._dict_full.pkl'    / '..._dict_cv0.pkl'
# Keeping one template per feature means the folds can no longer drift apart.
# Order is preserved: it is the insertion order of the generated dicts.
_USER_STATE_PATH_TEMPLATES = (
    ('answered_correctly_sum', '../input/riiid-avg-score-u-corrected/sum_score_u_{fold}_dict.pkl'),
    ('q_count', '../input/riiid-avg-score-u-corrected/q_count_u_{fold}_dict.pkl'),
    ('l_count', '../input/riiid-l-count-u-kfolds/l_count_u_dict_{tag}.pkl'),
    ('q_attempted', '../input/riiid-attempted-q-u-kfolds/attempted_q_u_dict_{tag}.pkl'),
    ('curr_cont_id', '../input/riiid-time-sin-last-cont-kfold/curr_cont_u_dict_{tag}.pkl'),
    ('curr_timestamp', '../input/riiid-time-sin-last-cont-kfold/curr_time_u_dict_{tag}.pkl'),
    ('last_timestamp', '../input/riiid-time-sin-last-cont-kfold/last_time_u_dict_{tag}.pkl'),
    ('last_t_diff', '../input/riiid-time-related-kfolds/last_t_diff_per_u_dict_{tag}.pkl'),
    ('session_num', '../input/riiid-session-related-kfolds/session_num_u_dict_{tag}.pkl'),
    ('last_cont_sum_answ', '../input/riiid-container-related-kfolds/last_cont_sum_answ_u_dict_{tag}.pkl'),
    ('curr_cont_sum_answ', '../input/riiid-container-related-kfolds/curr_cont_sum_answ_u_dict_{tag}.pkl'),
    ('curr_cont_q_count', '../input/riiid-container-related-kfolds/curr_cont_q_count_u_dict_{tag}.pkl'),
    ('last_cont_q_count', '../input/riiid-container-related-kfolds/last_cont_q_count_u_dict_{tag}.pkl'),
    ('last_session_break_time', '../input/riiid-session-related-kfolds/last_session_break_time_u_dict_{tag}.pkl'),
    ('sum_session_break_time', '../input/riiid-session-related-kfolds/session_break_time_sum_u_dict_{tag}.pkl'),
    ('curr_session_sum_score', '../input/riiid-session-related-kfolds/curr_session_sum_scores_u_dict_{tag}.pkl'),
    ('first_cont_session', '../input/riiid-session-related-kfolds/first_cont_session_u_dict_{tag}.pkl'),
    ('time_start_session', '../input/riiid-session-related-kfolds/time_start_session_u_dict_{tag}.pkl'),
    ('q_count_start_session', '../input/riiid-session-related-kfolds/q_count_start_session_u_dict_{tag}.pkl'),
    ('hist_100_answ', '../input/riiid-hist-100-to-10-score-deque/hist_100_bin_u_{fold}_dict.pkl'),
    ('curr_answ_streak', '../input/riiid-answ-streak-kfolds/curr_answ_streak_u_dict_{tag}.pkl'),
    ('max_answ_streak', '../input/riiid-answ-streak-kfolds/max_answ_streak_u_dict_{tag}.pkl'),
    ('hist_3_answ_streak', '../input/riiid-answ-streak-kfolds/hist_3_answ_streak_u_dict_{tag}.pkl'),
    ('first_try_success_count', '../input/riiid-first-try-success-count-u-arr-pkl/first_try_success_count_u_{fold}_dict.pkl'),
    ('unique_count_attempted_q', '../input/riiid-cnt-attempted-q-arr-pkl/count_attempted_q_u_{fold}_dict.pkl'),
    ('q_explanation_sum', '../input/riiid-q-explanation-sum-avg/q_explanation_sum_u_{fold}_dict.pkl'),
    ('time_since_last_sum_u', '../input/riiid-time-since-last-sum-avg/time_since_last_sum_u_{fold}_dict.pkl'),
    ('difficulty_lvl_count_u', '../input/riiid-u-score-difficulty-lvl-corr/cv_difficulty_lvl_count_u_{fold}_dict.pkl'),
    ('difficulty_lvl_sum_u', '../input/riiid-u-score-difficulty-lvl-corr/cv_difficulty_lvl_sum_u_{fold}_dict.pkl'),
    ('u_first_answ', '../input/riiid-user-start-cluster/u_first_answ_{fold}_dict.pkl'),
    ('u_first_q_cluster', '../input/riiid-user-start-cluster/u_first_q_cluster_{fold}_dict.pkl'),
    ('avg_score_u_last_day', '../input/riiid-fe-days-weeks-months/avg_score_u_last_day_{fold}_dict.pkl'),
    ('avg_score_u_last_week', '../input/riiid-fe-days-weeks-months/avg_score_u_last_week_{fold}_dict.pkl'),
    ('avg_score_u_last_month', '../input/riiid-fe-days-weeks-months/avg_score_u_last_month_{fold}_dict.pkl'),
    ('curr_day_count_u', '../input/riiid-fe-days-weeks-months/curr_day_count_u_{fold}_dict.pkl'),
    ('curr_week_count_u', '../input/riiid-fe-days-weeks-months/curr_week_count_u_{fold}_dict.pkl'),
    ('curr_month_count_u', '../input/riiid-fe-days-weeks-months/curr_month_count_u_{fold}_dict.pkl'),
    ('q_count_u_curr_day', '../input/riiid-fe-days-weeks-months/q_count_u_curr_day_{fold}_dict.pkl'),
    ('q_count_u_curr_week', '../input/riiid-fe-days-weeks-months/q_count_u_curr_week_{fold}_dict.pkl'),
    ('q_count_u_curr_month', '../input/riiid-fe-days-weeks-months/q_count_u_curr_month_{fold}_dict.pkl'),
    ('q_count_u_last_day', '../input/riiid-fe-days-weeks-months/q_count_u_last_day_{fold}_dict.pkl'),
    ('q_count_u_last_week', '../input/riiid-fe-days-weeks-months/q_count_u_last_week_{fold}_dict.pkl'),
    ('q_count_u_last_month', '../input/riiid-fe-days-weeks-months/q_count_u_last_month_{fold}_dict.pkl'),
    ('sum_score_u_curr_day', '../input/riiid-fe-days-weeks-months/sum_score_u_curr_day_{fold}_dict.pkl'),
    ('sum_score_u_curr_week', '../input/riiid-fe-days-weeks-months/sum_score_u_curr_week_{fold}_dict.pkl'),
    ('sum_score_u_curr_month', '../input/riiid-fe-days-weeks-months/sum_score_u_curr_month_{fold}_dict.pkl'),
    ('time_since_last_hist_3', '../input/riiid-time-since-last-hist-3/time_since_last_hist_3_{fold}_dict.pkl'),
    ('last_time_right', '../input/riiid-time-since-last-right-wrong-corrected/time_last_right_{fold}_dict.pkl'),
    ('last_time_wrong', '../input/riiid-time-since-last-right-wrong-corrected/time_last_wrong_{fold}_dict.pkl'),
    ('q_count_u_part', '../input/riiid-u-score-part/q_count_u_part_{fold}_dict.pkl'),
    ('sum_score_u_part', '../input/riiid-u-score-part/sum_score_u_part_{fold}_dict.pkl'),
    ('sum_prior_q_elapsed_time', '../input/riiid-prior-q-cont-elapsed-time/sum_prior_q_elapsed_time_{fold}_dict.pkl'),
    # 'elasped' is the spelling used by the dataset directory; do not "fix" it here.
    ('cont_count_part', '../input/riiid-prior-q-cont-elasped-time-part/cont_count_part_{fold}_dict.pkl'),
    ('last_cont_part', '../input/riiid-prior-q-cont-elasped-time-part/last_cont_part_{fold}_dict.pkl'),
    ('sum_prior_q_elapsed_time_part', '../input/riiid-prior-q-cont-elasped-time-part/sum_prior_q_elapsed_time_part_{fold}_dict.pkl'),
    ('sum_time_since_last_u', '../input/riiid-t-diff-u-avg-and-ratios/sum_time_since_last_u_{fold}_dict.pkl'),
)

# (fold key, short tag) pairs: the fold key fills '{fold}', the short tag fills '{tag}'.
_USER_STATE_FOLDS = (('fullset', 'full'), ('cv0', 'cv0'))

# fold key -> {state name -> path of the pickled per-user dict}
users_feature_states_paths = {
    fold_key: {
        state_name: path_template.format(fold=fold_key, tag=short_tag)
        for state_name, path_template in _USER_STATE_PATH_TEMPLATES
    }
    for fold_key, short_tag in _USER_STATE_FOLDS
}


# Per-question statistics: stat name -> path of the pickled dict.
question_feature_stats_paths = dict(
    answered_correctly_avg='../input/riiid-q-stats-dicts-pkl/answered_correctly_avg_c_dict.pkl',
    answered_correctly_std='../input/riiid-q-stats-dicts-pkl/answered_correctly_std_c_dict.pkl',
    answered_correctly_median='../input/riiid-q-stats-dicts-pkl/answered_correctly_median_c_dict.pkl',
    part='../input/riiid-q-stats-dicts-pkl/questions_parts_dict.pkl',
    # NOTE(review): the doubled '.pkl.pkl' extension is kept as-is; confirm it matches the file on disk.
    tags='../input/riiid-q-stats-dicts-pkl/tags_per_questions_no_nan.pkl.pkl',
    # Stored separately to buy time with RAM.
    tag_count='../input/riiid-q-tag-count-dict/dict_q_tag_count.pkl',
    avg_question_elapsed_time='../input/riiid-question-time-stats/avg_question_elapsed_time_c_dict.pkl',
    std_question_elapsed_time='../input/riiid-question-time-stats/std_question_elapsed_time_c_dict.pkl',
    avg_time_since_last='../input/riiid-question-time-stats/avg_time_since_last_c_dict.pkl',
    std_time_since_last='../input/riiid-question-time-stats/std_time_since_last_c_dict.pkl',
    avg_part_score='../input/riiid-part-score-stats-avg-std/avg_part_score_c_dict.pkl',
    std_part_score='../input/riiid-part-score-stats-avg-std/std_part_score_c_dict.pkl',
    avg_question_elapsed_time_right='../input/riiid-question-time-score-stats-dicts/avg_question_elapsed_time_c_right_dict.pkl',
    avg_question_elapsed_time_wrong='../input/riiid-question-time-score-stats-dicts/avg_question_elapsed_time_c_wrong_dict.pkl',
    avg_time_since_last_c_right='../input/riiid-question-time-score-stats-dicts/avg_time_since_last_c_right_dict.pkl',
    avg_time_since_last_c_wrong='../input/riiid-question-time-score-stats-dicts/avg_time_since_last_c_wrong_dict.pkl',
    # Maps c_id -> level.
    difficulty_level='../input/riiid-difficulty-lvl-q-arr-dicts/difficulty_q_lvl_c_dict.pkl',
    num_answers_q='../input/riiid-question-meta-stats/num_answers_q_dict.pkl',
    q_count_trainset_c='../input/riiid-question-meta-stats/q_count_trainset_c_dict.pkl',
    tag_cluster='../input/riiid-tags-cluster-dict-arrs/q_tags_clusters_dict.pkl',
    curr_avg_q_elapsed_time_c='../input/riiid-avg-q-elapsed-time-c/curr_avg_q_elapsed_time_c_dict.pkl',
    curr_avg_q_is_explained_c='../input/riiid-curr-avg-q-is-explained/curr_avg_is_explained_c_dict.pkl',
)



# Single pickled dict (not fold-dependent).
user_cluster_avg_score_start_200_path = '../input/riiid-user-start-cluster/u_cluster_avg_start_200_dict.pkl'



# Statistics keyed per difficulty level: stat name -> pickled dict path.
question_level_feature_stats_paths = dict(
    difficulty_level_avg_score_c='../input/riiid-difficulty-lvl-q-arr-dicts/difficulty_q_lvl_c_avg_score_dict.pkl',
    difficulty_level_std_score_c='../input/riiid-difficulty-lvl-q-arr-dicts/difficulty_q_lvl_c_std_score_dict.pkl',
    avg_q_elapsed_time_level='../input/riiid-q-elpsd-time-dffclty-lvl-arr-dicts/avg_q_elapsed_time_per_lvl_dict.pkl',
)




# Statistics keyed per tag t: stat name -> pickled dict path.
tags_feature_stats_paths = dict(
    answered_correctly_sum='../input/riiid-t-pckls-of-dicts/answered_correctly_sum_t_dict.pkl',
    # key t : how many questions in the train set carry t
    q_count='../input/riiid-t-pckls-of-dicts/q_count_t_dict.pkl',
    # key t : how many unique questions carry t
    ncount='../input/riiid-fe-tags-freq-stats/ncount_tags_dict.pkl',
)


# Lecture metadata: stat name -> pickled dict path.
lecture_feature_stats_paths = dict(
    part='../input/riiid-part-per-lecture-pkl/part_per_lecture_dict.pkl',
)







# SAKT artefact paths
# NOTE(review): this section used to be titled "hyperparameters", but only file paths are defined here.
# Judging by names and extensions: a pickled "group" object, a model state dict (.pth)
# and a skills array (.npy) — confirm against the code that loads them.


SAKT_GROUP_PICKLE_FULLSET_PATH = '../input/riiid-sakt-02/group.pkl'
SAKT_MODEL_STATE_DICT_FULLSET_PATH = '../input/riiid-sakt-enc1/sakt_enc1.pth'

SAKT_SKILL_NP_PATH = '../input/riiid-sakt-train-metadata-1/sakt_skills.npy'



########################################################################################
# GLOBAL VARIABLES
########################################################################################


# CV config

# Index of the cross-validation slice to run; copied into `i_cv_slice` in the CV-split cell.
INDEX_CV_SLICE = 0

# Constants

# NOTE(review): presumably the largest `timestamp` value of the train set — confirm source and unit.
TIMESTAMP_MAX_TRAINSET = 82106074887

# NOTE(review): hard-coded dataset statistics; where they were computed is not visible here —
# confirm they still match the data being loaded.
LOG_TIME_RATIO_MEDIAN = -3.304619
PRIOR_QUESTION_ELAPSED_TIME_TRAIN_MEAN = 25439.41

TIME_SINCE_LAST_MEDIAN = 67233
# Session limit = ten times the median above (same unit as TIME_SINCE_LAST_MEDIAN).
LIMIT_TIME_SESSION = TIME_SINCE_LAST_MEDIAN*10

# NOTE(review): looks like the question ids used to cluster users by their first answers — confirm.
CLUSTERING_U_START_Q_IDS = [7900, 128, 5692]



# Length of one train slice; follows the MODE_10M_ROWS switch (10M rows vs 5M rows).
NUM_TRAIN_ROWS = 10000000 if MODE_10M_ROWS else 5000000 
NUM_VALID_ROWS = 2500000


# Row bounds used by get_train_valid_slices to position the train / valid slices.
MAX_ROWS_TRAIN = 96817414
MAX_ROWS_VALID = 2453886

NUM_UNIQUE_QUESTIONS = 13523


# NOTE(review): presumably the LightGBM early-stopping patience — confirm at the training call.
EARLY_STOPPING_ROUNDS_LGBM = 10




# Run switches

debug = False

# True: do training; False: load pretrained weights.
training_flg = True
# True: apply CV; False: submit.
validaten_flg = False
unit_test_inference_flg = False

smaller_local_cv_validset = False

# Always the opposite of `debug`.
from_train_flg = not debug

feature_importance_flg = True

verbose = True


# Utils

def log(msg, end='\n'):
    """Print ``msg`` when the module-level ``verbose`` flag is truthy.

    Args:
        msg: object forwarded to ``print``.
        end: line terminator forwarded to ``print`` (defaults to a newline).

    Returns:
        None. Nothing is printed when ``verbose`` is falsy.
    """
    if not verbose:
        return
    print(msg, end=end)
    # TODO: add txt export
    
# Seeds
# Seeds the standard-library `random` generator only.
# NOTE(review): no numpy seed is set here — confirm whether np.random also needs seeding for reproducible runs.
random.seed(1)

In [3]:
# Train / Valid CV split :

# Working copy of the configured CV slice index (INDEX_CV_SLICE).
i_cv_slice = INDEX_CV_SLICE

def get_train_valid_slices(idx_cv_slice):
    """Return the ``[start, stop]`` row bounds of the train and valid slices of one CV fold.

    The train window is ``NUM_TRAIN_ROWS`` long and is counted backwards from
    ``MAX_ROWS_TRAIN``: fold 0 ends at ``MAX_ROWS_TRAIN``, fold 1 ends where
    fold 0 starts, and so on.

    Fold 0 keeps the former train/valid split (end of the train set, full
    valid set ``[0, MAX_ROWS_VALID]``). Any other fold is validated on the
    ``MAX_ROWS_VALID`` train-set rows that directly follow its train window.

    Args:
        idx_cv_slice: index of the CV fold (0 = most recent train rows).

    Returns:
        ``(train_slice, valid_slice)``, each a two-element list ``[start, stop]``.
    """
    train_start = MAX_ROWS_TRAIN - (idx_cv_slice + 1) * NUM_TRAIN_ROWS
    train_stop = MAX_ROWS_TRAIN - idx_cv_slice * NUM_TRAIN_ROWS

    # Fold 0 starts at the beginning of the valid set; other folds start
    # right after their own train window.
    valid_start = 0 if idx_cv_slice == 0 else train_stop

    return [train_start, train_stop], [valid_start, valid_start + MAX_ROWS_VALID]
    

# Row bounds ([start, stop] lists) of the train / valid slices for the configured fold.
train_slice_rows, valid_slice_rows = get_train_valid_slices(i_cv_slice)

# Echo the resolved configuration (printed unconditionally, i.e. not routed through `log`).
print('(Train/Valid slices config) ind slice :', i_cv_slice, '| train slice :', train_slice_rows, '| valid slice :', valid_slice_rows)


def get_kfold_name(i_cv_slice, valid_flag):
    """Return the fold key matching ``i_cv_slice`` and the valid / submission config.

    Args:
        i_cv_slice: CV slice index; must lie in [0, 2].
        valid_flag: falsy selects the 'fullset' key, truthy the per-slice key.

    Returns:
        'fullset' when ``valid_flag`` is falsy, otherwise 'cv' followed by
        the slice index (e.g. 'cv0').

    Raises:
        AssertionError: if ``i_cv_slice`` is outside [0, 2].
    """
    assert 0 <= i_cv_slice <= 2
    return 'cv' + str(i_cv_slice) if valid_flag else 'fullset'
    
    

(Train/Valid slices config) ind slice : 0 | train slice : [86817414, 96817414] | valid slice : [0, 2453886]


# Data preparation

In [4]:
class DataPreprocessor(object):
    """ Object prepare data for training. """
    
    
    def __init__(self, offline_feats_train_paths, offline_feats_valid_paths,
                 train_raw_q_path, valid_raw_q_path, X_feats, y_feats,
                 train_slice_rows_i, valid_slice_rows_i, valid_flag=False):
        """Load, enrich, select and downcast the train (and validation) feature frames.

        Args:
            offline_feats_train_paths / offline_feats_valid_paths: dicts passed to `init_df_offline`.
            train_raw_q_path / valid_raw_q_path: feather paths passed to `get_feats_online`.
            X_feats: columns to keep as model inputs; y_feats: target column selector.
            train_slice_rows_i / valid_slice_rows_i: (start, end) row bounds of each slice.
            valid_flag: when falsy (default) the validation frame is built too; when truthy
                only the train frame is prepared and `self.valid` / `self.y_va` are never set.
        """
        build_valid = not valid_flag
        self.feats_user_dicts = {}

        log('Offline features loading ...')
        self.train = self.init_df_offline(offline_feats_train_paths, slice_rows=train_slice_rows_i)
        if build_valid:
            # A valid slice starting at row 0 is read from the valid arrays,
            # any other slice is cut out of the train arrays.
            starts_at_zero = valid_slice_rows_i[0] == 0
            valid_paths = offline_feats_valid_paths if starts_at_zero else offline_feats_train_paths
            self.valid = self.init_df_offline(valid_paths, slice_rows=valid_slice_rows_i)

        log('Online features loading ...')
        self.train = self.get_feats_online(self.train, train_raw_q_path)
        if build_valid:
            self.valid = self.get_feats_online(self.valid, valid_raw_q_path)

        log('Features selection ...')
        # The target is extracted before the non-model columns are dropped.
        unused_cols = list(set(self.train.columns).difference(X_feats))
        self.y_tr = self.train[y_feats]
        self.train.drop(unused_cols, axis=1, inplace=True)
        if build_valid:
            self.y_va = self.valid[y_feats]
            self.valid.drop(unused_cols, axis=1, inplace=True)
        gc.collect()

        log('Features type casting ...')
        self.train = self.init_types_df(self.train)
        if build_valid:
            self.valid = self.init_types_df(self.valid)

        log('Data preprocessor init complete.')

        
    ########################################################################################
    # PREPROCESSING MAIN FUNCTIONS
    ########################################################################################

        
    def init_df_offline(self, dict_features_paths, slice_rows = None):
        """ Initialize a dataframe from a dict mapping each feature name to the path of its .npy array.

        Args:
            dict_features_paths: dict {feature name: path of the .npy file holding that feature}.
            slice_rows: optional (start_row, end_row) pair; only rows start_row:end_row of every
                array are kept. None (the default) keeps all rows.

        Returns:
            A dataframe with one column per feature, in the dict's iteration order.
        """

        # The default used to crash on unpacking; None now means "no slicing".
        start_row, end_row = (None, None) if slice_rows is None else slice_rows

        df = pd.DataFrame()

        for feat in tqdm(dict_features_paths, desc= 'Loading offline features'):
            # Memory-map the file so only the requested rows are read from disk,
            # then copy the slice into a regular in-memory ndarray.
            feat_arr = np.load(dict_features_paths[feat], mmap_mode='r')
            df[feat] = np.array(feat_arr[start_row:end_row])
            del feat_arr

        return df

    
    def get_feats_online(self, df, raw_df_q_feather_path):
        """Add the raw and derived ("online") features to df, in place, and return df.

        The raw frame is read from the feather file at raw_df_q_feather_path and must have
        exactly as many rows as df (asserted); its columns are copied positionally.
        The derived columns are computed in a fixed order because later ones read earlier
        ones (e.g. 'question_elapsed_time' needs 'cont_q_count', the ratio features need
        'time_ratio'). df must already hold the columns read by the helper methods
        (e.g. q_count_u, l_count_u, time_since_last, answered_correctly_avg_u)."""
        
        
        # Columns selected from the raw feather file.
        field_needed = ['row_id','user_id', 'timestamp','content_type_id', 'answered_correctly', 'task_container_id', 'content_id', 'prior_question_had_explanation']
        
        raw_df = pd.read_feather(raw_df_q_feather_path)[field_needed]
        # Values are copied with .values (no index alignment), so lengths must match.
        assert(len(df) == len(raw_df))

        # Raw feats
        log('...raw feats', ' ')
        df['timestamp'] = raw_df['timestamp'].values
        df['task_container_id'] = raw_df['task_container_id'].values
        df['content_id'] = raw_df['content_id'].values
        # Missing explanation flags are treated as False before the int8 cast.
        df['prior_question_had_explanation'] = raw_df['prior_question_had_explanation'].fillna(False).astype('int8').values

        # Straightforward computations
        log('...straightforward computations', ' ')
        df['time_btw_cont_mean'] = self.get_time_btw_containers_mean(df)
        df['time_per_action_mean'] = self.get_time_per_action_mean(df)
        df['is_first_question'] = self.get_is_first_question(df)
        df['last_cont_score_mean'] = self.get_last_cont_score_mean(df)

        # Computed from raw_df (needs user_id / row_id, which are not copied into df).
        df['cont_q_count'] = self.get_cont_q_count_preproc(raw_df)
        df['question_elapsed_time'] = self.get_question_elapsed_time(df)
        df['time_ratio'] = self.get_time_ratio(df)
        df['log_time_ratio'] = self.get_log_time_ratio(df)
        
        df['ratio_score_u_time_ratio'] = self.get_ratio_score_u_time_ratio(df)
        df['ratio_q_count_u_time_ratio'] = self.get_ratio_q_count_u_time_ratio(df)
        
        del raw_df
        gc.collect()

        log('...done')
        return df

 
    def init_types_df(self, df):
        conversion_table = {"avg_score_t": np.float16,
                          "answered_correctly_avg_u": np.float16,
                          "answered_correctly_avg_c": np.float16,
                          "answered_correctly_std_c": np.float16,
                          "prior_question_elapsed_time": np.float16,
                          "avg_session_q_count": np.float16,
                          "avg_score_u_hist_100_75": np.float16,
                          "avg_score_u_hist_75_50": np.float16,
                          "avg_score_u_hist_50_25": np.float16,
                          "avg_score_u_hist_25_0": np.float16,
                          "hist_score_diff_u_100_50": np.float16,
                          "hist_score_diff_u_75_25": np.float16,
                          "hist_score_diff_u_50_0": np.float16,
                          "hist_100_score_slope_u": np.float16,
                          "max_avg_score_t": np.float16,
                          "min_avg_score_t": np.float16,
                          "std_avg_score_t": np.float16,
                          "avg_tag_freq": np.float16,
                          "max_tag_freq": np.float16,
                          "min_tag_freq": np.float16,
                          "std_tag_freq" : np.float16, 
                          "ratio_diff_avg_question_elapsed_time_c": np.float16,
                          "ratio_diff_avg_time_since_last_c": np.float16,
                          "avg_q_elapsed_time_lvl": np.float16,
                          "diff_avg_q_elapsed_time_lvl": np.float16,
                          "last_cont_score_mean": np.float16,
                          "question_elapsed_time": np.float16,
                          "ratio_score_u_time_ratio": np.float16,
                          "q_count_u_curr_day": np.int16,
                          "q_count_u_curr_week": np.int16,
                          "q_count_u_last_day": np.int16,
                          "q_count_u_last_week": np.int16,}
        
        
        conversion_table_feats = set(conversion_table)
        train_feats = set(df.columns)
        
        required_conversion = train_feats.intersection(conversion_table_feats)
        
        conversion_output_dict={}
        for feat in required_conversion:
            conversion_output_dict[feat] = conversion_table[feat]
            
        return df.astype(conversion_output_dict)


    

    ########################################################################################
    # ONLINE FEATURES COMPUTATION FUNCTIONS
    ########################################################################################
    
    
    def get_time_btw_containers_mean(self, df):
        '''Mean time between containers. Return the array of features to store as "time_btw_cont_mean".'''
        return (df.timestamp/df.task_container_id).replace([np.inf, -np.inf], np.nan).fillna(0).astype(np.int64) 

    def get_time_per_action_mean(self, df):
        return (df.timestamp/(df.q_count_u + df.l_count_u)).fillna(0).astype(np.float32)

    def get_is_first_question(self, df):
        '''Boolean (stored in int8) feature which describe whether a row corresponds to the first question of the user
        or not. count_u must have been already initialize. Should be the case if df is "train" or "valid", or even
        "test_df" (in inference loop after having added user feats).'''
        return (df.q_count_u==0).values.astype(np.int8) 


    def get_last_cont_score_mean(self, df):
        '''Mean score (answered_correctly) obtained in the last container. 
        Return the array of features to store as "last_cont_score_mean"
        'last_container_sum_answ' and 'last_cont_num_q' must have already been initialized (should be
        the case for "train" and "valid", even "test_df" (end of the inference loop).'''
        return (df.last_container_sum_answ/df.last_cont_q_count).fillna(0).astype(np.float32)


    def get_cont_q_count_preproc(self, df):
        '''Get the number of question per container for df.
        Requires user_id',task_container_id,row_id.
        Return an ndarray contianint the values of "num_q_cont". (values : from 1 to 6). '''

        g = df[['user_id','task_container_id','row_id']].groupby(['user_id','task_container_id']).count()
        g.reset_index(inplace=True)
        g.rename(columns={'row_id':'num_q_cont'},inplace=True)
        g['num_q_cont'] = g['num_q_cont'].astype(np.int8)

        df = df.merge(g, on=['user_id','task_container_id'])
        df.loc[df[df.num_q_cont >= 6].index,'num_q_cont'] = 6

        del g
        gc.collect()

        return df.num_q_cont.values


    def get_question_elapsed_time(self, df):
        '''Get the mean time the user spent on each question of the current container.
        Requires : time_since_last, num_q_cont.'''
        return (df.time_since_last/df.cont_q_count).astype(np.float32).fillna(0)

    
    def get_time_ratio(self, df):
        '''Get the current time ratio: timestamp / TIMESTAMP_MAX_TRAINSET, as float32, NaN stored as 0.
        The train set max is used as the reference to be more robust to very large
        timestamps of the testset.'''
        ratio = df['timestamp'].div(TIMESTAMP_MAX_TRAINSET)
        ratio = ratio.astype(np.float32)
        return ratio.fillna(0)
    
    
    def get_log_time_ratio(self, df):
        '''Get the log of the current time ratio (column time_ratio), to emphasize differences
        between small ratios, which are more common than large ones.
        NaN results are replaced by LOG_TIME_RATIO_MEDIAN; note that log(0) is -inf,
        not NaN, so it is left as is.'''
        log_ratio = np.log(df['time_ratio'])
        return log_ratio.fillna(LOG_TIME_RATIO_MEDIAN)


    def get_ratio_score_u_time_ratio(self, df):
        '''Get the ratio between user score and its time ratio. Type : float32'''
        return (df.answered_correctly_avg_u/df.time_ratio).fillna(0)
        
    def get_ratio_q_count_u_time_ratio(self, df):
        '''Get "ratio_q_count_time_ratio" : q_count_u/time_ratio_u. Type : int32 '''
        ratio_q_count_u_time_ratio_arr = (df.q_count_u.to_numpy()/df.time_ratio.to_numpy()).astype(np.int32)
        ratio_q_count_u_time_ratio_arr = np.nan_to_num(ratio_q_count_u_time_ratio_arr, nan=0, posinf=2e7, neginf=0)
        np.clip(ratio_q_count_u_time_ratio_arr, 0, 2e7, out=ratio_q_count_u_time_ratio_arr)
        return ratio_q_count_u_time_ratio_arr
        
    
    
 
    ########################################################################################
    # UTILS
    ########################################################################################
        

    def train_info(self):
        self.train.info(memory_usage='deep')
        
    def get_train(self, FEATS):
        return self.train[FEATS], self.y_tr
    
    def get_valid(self, FEATS):
        return self.valid[FEATS], self.y_va
    
    
    def free_train(self):
        del self.train, self.y_tr # raw_data are free. Keep valid for ROC CV
        _=gc.collect()
    

In [5]:

# Columns kept as model inputs (everything else is dropped by DataPreprocessor).
FEATS = ['answered_correctly_avg_c', 'avg_score_q_level_u',
       'avg_score_u_part', 'is_first_attempt', 'time_since_last',
       'time_since_last_2', 'answered_correctly_std_c',
       'current_session_q_count', 'current_session_time',
       'question_elapsed_time', 'ratio_diff_avg_question_elapsed_time_c',
       'current_right_answ_streak', 'ratio_diff_avg_time_since_last_c',
       'curr_avg_q_is_explained', 'avg_score_u_hist_25_0',
       'time_since_last_right', 'first_try_success_count_u',
       'answered_correctly_avg_u', 'avg_prior_q_elapsed_time_part',
       'difficulty_level', 'curr_avg_q_elapsed_time',
       'unique_count_attempted_q_u', 'answered_correctly_median_c',
       'q_count_level_u', 'last_cont_score_mean', 'sum_score_q_level',
       'task_container_id',
       'prior_question_elapsed_time', 'avg_score_u_cluster_start',
       'min_tag_freq',
       'ratio_diff_avg_q_elapsed_time_lvl', 'sum_score_u_part',
       'time_since_last_wrong', 'avg_session_q_count',
       'answered_correctly_sum_u', 'time_since_last_3',
       'std_time_since_last_c',
       'current_session_avg_score', 'avg_tag_freq',
       'hist_1_right_answ_streak', 'part', 'q_explanation_sum_u',
       'ratio_score_u_time_ratio',
       'last_session_break_time', 'q_count_trainset_c',
       'avg_right_answ_streak_hist_3', 'q_explanation_avg_u',
       'avg_time_since_last_c', 'l_count_u']


# Name of the target column.
TARGET = 'answered_correctly'


# 'fullset' when validaten_flg is falsy, otherwise 'cv<i_cv_slice>'.
kfold_name = get_kfold_name(i_cv_slice, validaten_flg)

# Keyword arguments of DataPreprocessor.__init__.
data_preproc_params = {'offline_feats_train_paths': offline_features_train_paths, 
                       'offline_feats_valid_paths': offline_features_valid_paths, 
                       'train_raw_q_path': df_questions_kfold_paths[kfold_name]['train'], 
                       'valid_raw_q_path': df_questions_kfold_paths[kfold_name]['valid'],
                       'X_feats': FEATS,
                       'y_feats': TARGET,
                       'train_slice_rows_i': train_slice_rows,
                       'valid_slice_rows_i': valid_slice_rows,
                       # With valid_flag False, DataPreprocessor also builds the validation frame.
                       'valid_flag': False} # always False for training
                       
data_preprocessor = DataPreprocessor(**data_preproc_params)


# Print the train frame summary (dtypes and deep memory usage) in verbose mode.
if verbose :
    data_preprocessor.train_info()


Offline features loading ...


HBox(children=(FloatProgress(value=0.0, description='Loading offline features', max=132.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=0.0, description='Loading offline features', max=132.0, style=ProgressStyle…


Online features loading ...
...raw feats ...straightforward computations ...done
...raw feats ...straightforward computations ...done
Features selection ...
Features type casting ...
Data preprocessor init complete.
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 10000000 entries, 0 to 9999999
Data columns (total 49 columns):
 #   Column                                  Dtype  
---  ------                                  -----  
 0   answered_correctly_avg_u                float16
 1   answered_correctly_sum_u                int32  
 2   is_first_attempt                        int8   
 3   answered_correctly_avg_c                float16
 4   answered_correctly_std_c                float16
 5   answered_correctly_median_c             int8   
 6   part                                    int8   
 7   prior_question_elapsed_time             float16
 8   l_count_u                               int32  
 9   last_session_break_time                 int64  
 10  avg_session_q_count         

# Training

In [6]:
# Feature frames restricted to the FEATS columns, with their targets.
X_train, y_train = data_preprocessor.get_train(FEATS)
X_valid, y_valid = data_preprocessor.get_valid(FEATS)

# LGB datasets
lgb_train = lgb.Dataset(X_train, y_train)
lgb_valid = lgb.Dataset(X_valid, y_valid)


# Drop the preprocessor's own train frame / target; its validation data is kept.
# NOTE(review): X_train / y_train stay bound in this cell - confirm the memory is actually released.
data_preprocessor.free_train()


In [7]:
# Announce whether the next cell trains a new model or loads a pretrained one.
print(f"Ready for {'training' if training_flg else 'loading'}")

Ready for training


In [8]:

# Training: fit a new LightGBM binary classifier, or load a pretrained booster.
if training_flg:
    model = lgb.train( 
                        {'objective': 'binary',
                        'lambda_l1': 1.0,
                        'lambda_l2': 1.0,
                        # NOTE(review): per the LightGBM parameter docs, 'subsample' (bagging_fraction)
                        # only takes effect when bagging_freq is non-zero; none is set here - confirm intent.
                        'subsample': 0.5,
                        'feature_fraction': 0.7,
                        }, 
                        lgb_train,
                        # Both sets are evaluated and logged every 100 rounds (see verbose_eval).
                        valid_sets=[lgb_train, lgb_valid],
                        verbose_eval=100,
                        num_boost_round=10000,
                        early_stopping_rounds=EARLY_STOPPING_ROUNDS_LGBM
                    )


    # Feature importance
    #_ = lgb.plot_importance(model) # Rough approach : rely on SHAP instead.


    
else:
    # Load the pretrained booster from its text dump and plot LightGBM's built-in importance.
    PRETRAINED_MODEL_PATH = '../input/riiid-lgb-fi-no-overfeat-49f/lgb_FI_no_overfeat_49feats.txt'
    model = lgb.Booster(model_file=PRETRAINED_MODEL_PATH)
    
    _ = lgb.plot_importance(model)
    

Training until validation scores don't improve for 10 rounds
[100]	training's binary_logloss: 0.528484	valid_1's binary_logloss: 0.53143
[200]	training's binary_logloss: 0.525318	valid_1's binary_logloss: 0.528426
[300]	training's binary_logloss: 0.523797	valid_1's binary_logloss: 0.527125
[400]	training's binary_logloss: 0.522781	valid_1's binary_logloss: 0.526369
[500]	training's binary_logloss: 0.521962	valid_1's binary_logloss: 0.5258
[600]	training's binary_logloss: 0.521309	valid_1's binary_logloss: 0.525391
[700]	training's binary_logloss: 0.520703	valid_1's binary_logloss: 0.525044
[800]	training's binary_logloss: 0.52021	valid_1's binary_logloss: 0.524818
[900]	training's binary_logloss: 0.519709	valid_1's binary_logloss: 0.524569
[1000]	training's binary_logloss: 0.519206	valid_1's binary_logloss: 0.524301
[1100]	training's binary_logloss: 0.518781	valid_1's binary_logloss: 0.524126
[1200]	training's binary_logloss: 0.51834	valid_1's binary_logloss: 0.523927
[1300]	training's

In [9]:
# Export
if training_flg:
    # Save the freshly trained booster; the file name embeds the number of features.
    model.save_model('lgb_sakt_' + str(len(FEATS)) + 'feats.txt')

In [10]:
# Feature importance
if feature_importance_flg:
    
    def get_feature_importance(shap_values, num_feat_max, X_valid):
        """Rank the columns of X_valid by their summed absolute SHAP value.

        shap_values: either a list with one (n_rows, n_feats) array per class (their
            absolute values are averaged across classes), or a single (n_rows, n_feats) array.
        num_feat_max: number of top features to return.
        Returns a dataframe with columns 'col_name' and 'feature_importance_vals',
        sorted by decreasing importance and truncated to num_feat_max rows.
        """
        if isinstance(shap_values, list):
            abs_vals = np.abs(shap_values).mean(0)
        else:
            abs_vals = np.abs(np.asarray(shap_values))
        # Vectorized sum over the rows (the Python built-in sum loops over every row).
        importance_vals = abs_vals.sum(axis=0)

        feature_importance = pd.DataFrame({'col_name': list(X_valid.columns),
                                           'feature_importance_vals': importance_vals})
        feature_importance = feature_importance.sort_values(by=['feature_importance_vals'], ascending=False)
        return feature_importance[:num_feat_max]

    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X_valid)
    
    top_features_df_1 = get_feature_importance(shap_values, len(FEATS), X_valid)
    display(top_features_df_1)
    

LightGBM binary classifier with TreeExplainer shap values output has changed to a list of ndarray


Unnamed: 0,col_name,feature_importance_vals
0,answered_correctly_avg_c,1122389.0
1,avg_score_q_level_u,1061189.0
2,avg_score_u_part,399786.6
3,is_first_attempt,299584.9
4,time_since_last,254003.3
5,time_since_last_2,253863.8
16,first_try_success_count_u,211633.0
7,current_session_q_count,174690.3
8,current_session_time,172640.5
26,task_container_id,171782.3


# Inference

Credit: the local CV API was written by https://www.kaggle.com/its7171/time-series-api-iter-test-emulator

### Setup

In [11]:
########################################################################################
# LOCAL-CV API
########################################################################################


class Iter_Valid(object):
    """Local-CV iterator yielding (batch_df, sample_df) pairs from a validation frame.

    It is used in place of the submission environment's `iter_test()`. A batch ends when
    a user already present in the batch shows up again with a different question
    container, or when `max_user` distinct users are in the batch. The first row of each
    batch carries the previous batch's answers in 'prior_group_responses' and
    'prior_group_answers_correct' (string-encoded lists).

    The input frame must provide the columns: row_id, user_id, task_container_id,
    content_type_id, user_answer, answered_correctly.
    """

    def __init__(self, df, max_user=1000):
        # Work on a re-indexed copy so row labels match row positions.
        df = df.reset_index(drop=True)
        self.df = df
        self.user_answer = df['user_answer'].astype(str).values
        self.answered_correctly = df['answered_correctly'].astype(str).values
        df['prior_group_responses'] = "[]"
        df['prior_group_answers_correct'] = "[]"
        # Prediction template of the whole frame: question rows only (content_type_id == 0).
        self.sample_df = df[df['content_type_id'] == 0][['row_id']].copy()
        self.sample_df['answered_correctly'] = 0
        self.len = len(df)
        self.user_id = df.user_id.values
        self.task_container_id = df.task_container_id.values
        self.content_type_id = df.content_type_id.values
        self.max_user = max_user
        self.current = 0
        self.pre_user_answer_list = []
        self.pre_answered_correctly_list = []

    def __iter__(self):
        return self
    
    def fix_df(self, user_answer_list, answered_correctly_list, pre_start):
        """Build the batch covering rows [pre_start, self.current) and its prediction template.

        The previous batch's answers are written on the first row of the batch, then the
        given lists are remembered for the next batch.
        """
        df = self.df[pre_start:self.current].copy()
        # Derive the template from the batch itself: self.sample_df only holds question
        # rows, so slicing it by position is misaligned as soon as lectures are present.
        sample_df = df.loc[df['content_type_id'] == 0, ['row_id']].copy()
        sample_df['answered_correctly'] = 0
        df.loc[pre_start,'prior_group_responses'] = '[' + ",".join(self.pre_user_answer_list) + ']'
        df.loc[pre_start,'prior_group_answers_correct'] = '[' + ",".join(self.pre_answered_correctly_list) + ']'
        self.pre_user_answer_list = user_answer_list
        self.pre_answered_correctly_list = answered_correctly_list
        return df, sample_df

    def __next__(self):
        added_user = set()
        pre_start = self.current
        pre_added_user = -1
        pre_task_container_id = -1

        user_answer_list = []
        answered_correctly_list = []
        while self.current < self.len:
            crr_user_id = self.user_id[self.current]
            crr_task_container_id = self.task_container_id[self.current]
            crr_content_type_id = self.content_type_id[self.current]
            if crr_content_type_id == 1:
                # no more than one task_container_id of "questions" from any single user
                # so we only care for content_type_id == 0 to break loop
                user_answer_list.append(self.user_answer[self.current])
                answered_correctly_list.append(self.answered_correctly[self.current])
                self.current += 1
                continue
            if crr_user_id in added_user and ((crr_user_id != pre_added_user) or (crr_task_container_id != pre_task_container_id)):
                # known user (not the previous user, or a different task container)
                return self.fix_df(user_answer_list, answered_correctly_list, pre_start)
            if len(added_user) == self.max_user:
                if  crr_user_id == pre_added_user and crr_task_container_id == pre_task_container_id:
                    # Batch is full, but this row continues the last user's current container.
                    user_answer_list.append(self.user_answer[self.current])
                    # Was `selfa.current`, which raised NameError as soon as this branch ran.
                    answered_correctly_list.append(self.answered_correctly[self.current])
                    self.current += 1
                    continue
                else:
                    return self.fix_df(user_answer_list, answered_correctly_list, pre_start)
            added_user.add(crr_user_id)
            pre_added_user = crr_user_id
            pre_task_container_id = crr_task_container_id
            user_answer_list.append(self.user_answer[self.current])
            answered_correctly_list.append(self.answered_correctly[self.current])
            self.current += 1
        if pre_start < self.current:
            # Remaining rows form the last batch.
            return self.fix_df(user_answer_list, answered_correctly_list, pre_start)
        else:
            raise StopIteration()

In [12]:
########################################################################################
# TRIGGER LOCAL CV / SUBMISSION MODE
########################################################################################


# Local CV: iterate over the validation feather file with the Iter_Valid emulator and
# collect the predictions in a list. Otherwise: use the competition environment.
if validaten_flg:

    target_feather_path = df_questions_lectures_kfold_paths[kfold_name]['valid']

    # Optionally keep only the first 100 rows for a quick local run.
    if smaller_local_cv_validset:
        target_df = pd.read_feather(target_feather_path)[:100]
    else:
        target_df = pd.read_feather(target_feather_path)

    if debug:
        target_df = target_df[:10000]

    iter_test = Iter_Valid(target_df,max_user=1000)

    # Local stand-in for env.predict: accumulates every predicted frame in `predicted`.
    predicted = []
    def set_predict(df):
        predicted.append(df)

    log('(Local CV) Environmnent initialization : OK')

else:
    # Imported here because it is only needed in submission mode.
    import riiideducation
    env = riiideducation.make_env()
    iter_test = env.iter_test() 
    set_predict = env.predict

    log('(Submission mode) Environmnent initialization : OK')



(Submission mode) Environmnent initialization : OK


### State variables 

In [13]:
def load_obj(name):
    """Return the object unpickled from the binary file at path `name`.

    NOTE: unpickling can execute arbitrary code; only load trusted files.
    """
    with open(name, mode='rb') as handle:
        loaded = pickle.load(handle)
    return loaded

In [14]:

class UsersState(object):
    """ Object which keeps track of user's variables state during inference. """
    
    def __init__(self, users_feature_states_paths, users_cluster_avg_score_start_paths):
        """Build the per-user state from pickled feature dicts.

        users_feature_states_paths: dict {feature name: path of the pickled per-user dict}.
        users_cluster_avg_score_start_paths: path of the pickled object stored as
            `self.u_cluster_avg_start_200_dict`.
        """
        self.data = {}

        log('Loading offline features dict to user state ...')
        for feat_name, feat_path in users_feature_states_paths.items():
            self.add_feat_to_data(feat_name, feat_path)

        log('Loading user cluster avg start hist 200 dict.')
        self.u_cluster_avg_start_200_dict = load_obj(users_cluster_avg_score_start_paths)

        log('User state init complete.')
    
    
    def add_feat_to_data(self, feat_name, feat_path):
        """ Load a per-user feature dict from the pickle file at `feat_path` and store each
        value under `feat_name` in that user's state (unknown users are initialized first). """
        # Same per-user loop as add_online_feats_to_data, fed with the unpickled dict.
        self.add_online_feats_to_data(feat_name, load_obj(feat_path))
            
            
    def add_new_user(self, u_id):
        """Initialize data state dicts for a new user (overwrites any existing entry for u_id)."""
        # One bit per question id, all cleared.
        q_attempted = bitarray(13550, endian='little')
        q_attempted.setall(0)

        self.data[u_id] = {
            'answered_correctly_sum': 0,
            'q_count': 0,
            'l_count': 0,
            'q_attempted': q_attempted,
            'curr_cont_id': -1,
            'curr_timestamp': 0,
            'last_timestamp': 0,
            'last_t_diff': -1,
            'last_cont_sum_answ': 0,
            'curr_cont_sum_answ': 0,
            'curr_cont_q_count': 0,
            'last_cont_q_count': 0,
            'session_num': 0,
            'first_cont_session': 0,
            'time_start_session': 0,
            'q_count_start_session': 0,
            'last_session_break_time': 0,
            'sum_session_break_time': 0,
            'curr_session_sum_score': 0,
            'hist_100_answ': deque(maxlen=100),
            'curr_answ_streak': 0,
            'max_answ_streak': 0,
            'hist_3_answ_streak': deque(maxlen=3),
            'first_try_success_count': 0,
            'unique_count_attempted_q': 0,
            'q_explanation_sum': 0,
            'time_since_last_sum_u': 0,
            # Per difficulty level (0..4) counters.
            'difficulty_lvl_count_u': dict.fromkeys(range(5), 0),
            'difficulty_lvl_sum_u': dict.fromkeys(range(5), 0),
            'u_first_answ': np.int8(-1),
            'u_first_q_cluster': np.int16(0), # convention : 0 is undefined, -1 is 'others'

            'avg_score_u_last_day': 0,
            'avg_score_u_last_week': 0,
            'avg_score_u_last_month': 0,
            'curr_day_count_u': 0,
            'curr_week_count_u': 0,
            'curr_month_count_u': 0,
            'q_count_u_curr_day': 0,
            'q_count_u_curr_week': 0,
            'q_count_u_curr_month': 0,
            'q_count_u_last_day': 0,
            'q_count_u_last_week': 0,
            'q_count_u_last_month': 0,
            'sum_score_u_curr_day': 0,
            'sum_score_u_curr_week': 0,
            'sum_score_u_curr_month': 0,

            'time_since_last_hist_3': deque(maxlen=3),
            'last_time_right': 0,
            'last_time_wrong': 0,

            # Per part (1..7) counters.
            'q_count_u_part': dict.fromkeys(range(1, 8), 0),
            'sum_score_u_part': dict.fromkeys(range(1, 8), 0),

            'sum_prior_q_elapsed_time': 0,

            'sum_prior_q_elapsed_time_part': dict.fromkeys(range(1, 8), 0),
            'cont_count_part': dict.fromkeys(range(1, 8), 0),
            'last_cont_part': -1,

            'sum_time_since_last_u': 0,
        }
    
    
    def add_online_feats_to_data(self, feat_name, feat_dict):
        """ Load a feature state stored as a dict (pickle file) and set it to data for each user. """ 

        for u_id in feat_dict:
            
            if u_id not in self.data:
                self.add_new_user(u_id)
            
            self.data[u_id][feat_name] = feat_dict[u_id]
        

        
    ########################################################################################
    # DICTS UPDATE
    ########################################################################################
      

    def update_dicts_from_prev_state(self, df_prev):
        """ Update data using information contained by df for feats features. """
        
        feats = ['user_id','answered_correctly','content_type_id', 'task_container_id', 'content_id', 'timestamp', 'prior_question_had_explanation', 'difficulty_level', 'part', 'prior_question_elapsed_time']
        
        for (u_id, answ_correctly, c_type_id, cont_id, c_id, u_time, pre_is_expl, lvl, part, prior_q_elpsd_t) in df_prev[feats].values:
            
            # Type casts
            u_id = np.int32(u_id)
            answ_correctly = np.int8(answ_correctly)
            c_type_id = np.int8(c_type_id)
            cont_id = np.int16(cont_id)
            c_id = np.int16(c_id)
            u_time = np.int64(u_time)
            pre_is_expl = np.int8(pre_is_expl)
            lvl = np.int8(lvl)
            part = np.int8(part)
            prior_q_elpsd_t = np.float32(prior_q_elpsd_t)
            
            
            is_lecture = (c_type_id == 1)
            if is_lecture:
                self.data[u_id]['l_count'] += 1
                continue
                
                
            if u_id not in self.data:
                self.add_new_user(u_id)

                
            self.data[u_id]['answered_correctly_sum'] += np.int8(answ_correctly)
            self.data[u_id]['q_count'] += 1
            
            if self.data[u_id]['q_attempted'][c_id] == 0:
                self.data[u_id]['q_attempted'][c_id] = 1
                self.data[u_id]['first_try_success_count'] += np.int8(answ_correctly)
                self.data[u_id]['unique_count_attempted_q'] += 1


            is_new_cont_first_row = cont_id != self.data[u_id]['curr_cont_id'] 
            if is_new_cont_first_row :
                self.data[u_id]['last_timestamp'] = self.data[u_id]['curr_timestamp']
                self.data[u_id]['last_cont_q_count'] = self.data[u_id]['curr_cont_q_count']
                self.data[u_id]['last_cont_sum_answ'] = self.data[u_id]['curr_cont_sum_answ']
                
                self.data[u_id]['curr_cont_id'] = np.int16(cont_id)
                
                self.data[u_id]['curr_timestamp'] = np.int64(u_time)
                self.data[u_id]['curr_cont_q_count'] = 0
                self.data[u_id]['curr_cont_sum_answ'] = 0
            
            
            self.data[u_id]['curr_cont_sum_answ'] += np.int8(answ_correctly)
            self.data[u_id]['curr_cont_q_count'] += 1
            
            
            t_diff = u_time - self.data[u_id]['last_timestamp']

            
            
            len_time_since_hist = len(self.data[u_id]['time_since_last_hist_3'])

            if t_diff > 0:
                if len_time_since_hist > 0:
                    if t_diff != self.data[u_id]['time_since_last_hist_3'][-1] :
                        self.data[u_id]['time_since_last_hist_3'].append(t_diff)
                else:
                    self.data[u_id]['time_since_last_hist_3'].append(t_diff)
            
            
            if answ_correctly :
                self.data[u_id]['last_time_right'] = np.int64(u_time)
            else :
                self.data[u_id]['last_time_wrong'] = np.int64(u_time)
    
            
            
            
            is_new_sess = (self.data[u_id]['last_t_diff'] != -1) & \
                              ((t_diff > 100*self.data[u_id]['last_t_diff']) or (t_diff > LIMIT_TIME_SESSION)) & \
                                  (t_diff != 0)
            
            
            if is_new_sess: 
                self.data[u_id]['session_num'] += 1
                self.data[u_id]['first_cont_session'] = np.int16(cont_id)
                self.data[u_id]['time_start_session'] = np.int64(u_time)
                self.data[u_id]['q_count_start_session'] = self.data[u_id]['q_count'] - 1 # discard current q (next sess)
                self.data[u_id]['last_session_break_time'] = t_diff
                self.data[u_id]['sum_session_break_time'] += t_diff
                self.data[u_id]['curr_session_sum_score'] = 0
                
            if t_diff != 0:
                self.data[u_id]['last_t_diff'] = t_diff
                
            self.data[u_id]['hist_100_answ'].append(np.int8(answ_correctly))
            

            if answ_correctly == 1 :
                self.data[u_id]['curr_answ_streak'] += 1
            else:
                if self.data[u_id]['curr_answ_streak'] > self.data[u_id]['max_answ_streak']:
                    self.data[u_id]['max_answ_streak'] = self.data[u_id]['curr_answ_streak']
                
                self.data[u_id]['hist_3_answ_streak'].append(self.data[u_id]['curr_answ_streak'])
                self.data[u_id]['curr_answ_streak'] = 0            
        
        
            self.data[u_id]['q_explanation_sum'] += np.int8(pre_is_expl)
            
            if is_new_cont_first_row:
                self.data[u_id]['time_since_last_sum_u'] += t_diff
                                                       
        
            self.data[u_id]['difficulty_lvl_sum_u'][lvl] += np.int8(answ_correctly)
            self.data[u_id]['difficulty_lvl_count_u'][lvl] += 1
            

            
            is_first_answ_undefined = self.data[u_id]['u_first_answ'] == -1
            if is_first_answ_undefined :
                self.data[u_id]['u_first_answ'] = np.int8(answ_correctly)
            
            is_first_q_cluster_undefined = self.data[u_id]['u_first_q_cluster'] == 0
            if is_first_q_cluster_undefined :
                self.data[u_id]['u_first_q_cluster'] = np.int16(c_id) if (c_id in CLUSTERING_U_START_Q_IDS) else np.int16(-1)
            
            
            
            self.data[u_id]['q_count_u_part'][part] += 1
            self.data[u_id]['sum_score_u_part'][part] += np.int8(answ_correctly)
            
            if not np.isnan(prior_q_elpsd_t):
                self.data[u_id]['sum_prior_q_elapsed_time'] += int(prior_q_elpsd_t)*self.data[u_id]['last_cont_q_count']
            
            if is_new_cont_first_row :
                
                is_last_cont_part_first_row = self.data[u_id]['last_cont_part'] == -1        
                if not is_last_cont_part_first_row :
                    part_last_cont = self.data[u_id]['last_cont_part']

                    if not np.isnan(prior_q_elpsd_t):
                        self.data[u_id]['sum_prior_q_elapsed_time_part'][part_last_cont] += int(prior_q_elpsd_t)*self.data[u_id]['last_cont_q_count']
                        
                self.data[u_id]['cont_count_part'][part] += 1
                
                
            self.data[u_id]['last_cont_part'] = part
            
            if is_new_cont_first_row and (t_diff > 0):
                self.data[u_id]['sum_time_since_last_u'] += t_diff 
            
        
        
        
    ########################################################################################
    # FEATURES GETTERS
    ########################################################################################
            
    
    def get_q_feats_from_curr_state(self, df_curr):
        """ Get question features data using current user's states for each row of df_curr. Return a dataframe with 
        the new columns concatenated to df_curr.
        Warning : df_curr must only contains questions. 
        
        df_curr must provide the columns 'user_id', 'content_id', 'task_container_id', 'timestamp', 
        'prior_question_had_explanation', 'difficulty_level', 'part' and 'prior_question_elapsed_time'.
        
        The per-user state in self.data is only read here (unknown users are first created with add_new_user()) : 
        each feature describes the user as stored in the state, completed with what the current row already tells 
        (timestamp, container id, prior question explanation / elapsed time).
        
        The feature columns are built on df_curr's own index, so the returned dataframe has exactly one row per 
        row of df_curr whatever its index is (it does not have to be 0..n-1). """
        
        n_rows = len(df_curr)
        
        answ_corr_sum_arr = np.zeros(n_rows, dtype=np.int32)
        q_count_arr = np.zeros(n_rows, dtype=np.int32)
        is_first_attempt_arr = np.zeros(n_rows, dtype=np.int8)
        
        sess_num_arr = np.zeros(n_rows, dtype=np.int16)
        same_container_as_last_arr = np.zeros(n_rows, dtype=np.int8)
        last_cont_sum_answ_u_arr = np.zeros(n_rows, dtype=np.int8)
        last_cont_q_count_u_arr = np.zeros(n_rows, dtype=np.int8)
        same_sess_as_last_u_arr = np.full(n_rows, 1, dtype=np.int8)
        current_session_time_arr = np.zeros(n_rows, dtype=np.int64)
        current_session_q_count_arr = np.zeros(n_rows, dtype=np.int32)
        last_session_break_time_arr = np.zeros(n_rows, dtype=np.int64)
        avg_break_time_arr = np.zeros(n_rows, dtype=np.int64)
        session_avg_time_arr = np.zeros(n_rows, dtype=np.int64)
        avg_session_q_count_arr = np.zeros(n_rows, dtype=np.float32)
        current_session_avg_score_arr = np.zeros(n_rows, dtype=np.float16)
        
        avg_score_u_hist_25_0_arr = np.zeros(n_rows, dtype=np.float32)
        avg_score_u_hist_50_25_arr = np.zeros(n_rows, dtype=np.float32)
        avg_score_u_hist_75_50_arr = np.zeros(n_rows, dtype=np.float32)
        avg_score_u_hist_100_75_arr = np.zeros(n_rows, dtype=np.float32)
        hist_score_diff_u_50_0_arr = np.zeros(n_rows, dtype=np.float32)
        hist_score_diff_u_75_25_arr = np.zeros(n_rows, dtype=np.float32)
        hist_score_diff_u_100_50_arr = np.zeros(n_rows, dtype=np.float32)
        hist_100_score_slope_u_arr = np.zeros(n_rows, dtype=np.float32)
        
        current_right_answ_streak_arr = np.zeros(n_rows, dtype=np.int16)
        max_right_answ_streak_arr = np.zeros(n_rows, dtype=np.int16)
        hist_1_right_answ_streak_arr = np.zeros(n_rows, dtype=np.int16)
        hist_2_right_answ_streak_arr = np.zeros(n_rows, dtype=np.int16)
        hist_3_right_answ_streak_arr = np.zeros(n_rows, dtype=np.int16)
        avg_right_answ_streak_hist_3_arr = np.zeros(n_rows, dtype=np.float32)
        
        first_try_success_count_u_arr = np.zeros(n_rows, dtype=np.int16)
        unique_count_attempted_q_u_arr = np.zeros(n_rows, dtype=np.int16)
        
        q_explanation_sum_u_arr = np.zeros(n_rows, dtype=np.int16)
        q_explanation_avg_u_arr = np.zeros(n_rows, dtype=np.float16)
        
        time_since_last_sum_u_arr = np.zeros(n_rows, dtype=np.int64)
        # NOTE(review): int32 here while the similar 'avg_u_time_since_last' is int64 ; an average gap above
        # ~2.1e9 would not fit. Left unchanged because widening it changes the output schema - TODO confirm.
        time_since_last_avg_u_arr = np.zeros(n_rows, dtype=np.int32)
        
        curr_cont_score_sum_u_arr = np.zeros(n_rows, dtype=np.int8)
        curr_cont_tackled_q_count_u_arr = np.zeros(n_rows, dtype=np.int8)
        curr_cont_score_avg_u_arr = np.zeros(n_rows, dtype=np.float16)
        
        sum_score_q_level_arr = np.zeros(n_rows, dtype=np.int32)
        q_count_level_arr = np.zeros(n_rows, dtype=np.int32)
        avg_score_q_level_arr = np.zeros(n_rows, dtype=np.float16)
        
        first_answ_arr = np.zeros(n_rows, dtype=np.int8)
        u_cluster_start_arr = np.zeros(n_rows, dtype=np.int16)
        avg_score_u_cluster_start_arr = np.zeros(n_rows, dtype=np.float16)
        
        time_since_last_arr = np.zeros(n_rows, dtype=np.int64)
        time_since_last_2_arr = np.zeros(n_rows, dtype=np.int64)
        time_since_last_3_arr = np.zeros(n_rows, dtype=np.int64)
        
        time_since_last_right_arr = np.zeros(n_rows, dtype=np.int64)
        time_since_last_wrong_arr = np.zeros(n_rows, dtype=np.int64)
        
        sum_score_u_part_arr = np.zeros(n_rows, dtype=np.int32)
        q_count_u_part_arr = np.zeros(n_rows, dtype=np.int32)
        avg_score_u_part_arr = np.zeros(n_rows, dtype=np.float16)
        
        avg_prior_cont_time_arr = np.zeros(n_rows, dtype=np.float32)
        avg_prior_q_elapsed_time_arr = np.zeros(n_rows, dtype=np.float32)
    
        avg_prior_cont_time_part_arr = np.zeros(n_rows, dtype=np.float32)
        avg_prior_q_elapsed_time_part_arr = np.zeros(n_rows, dtype=np.float32)
        
        ratio_curr_by_avg_session_time_arr = np.zeros(n_rows, dtype=np.float16)
    
        avg_u_time_since_last_arr = np.zeros(n_rows, dtype=np.int64)
        ratio_tdiff_by_avg_tdiff_u_arr = np.zeros(n_rows, dtype=np.float16)
    
    
        df_curr_feats = df_curr[['user_id','content_id','task_container_id','timestamp', 'prior_question_had_explanation', 'difficulty_level', 'part', 'prior_question_elapsed_time']].values
    
        for i_row, (u_id, c_id, cont_id, u_time, pre_is_expl, lvl, part, prior_q_elpsd_t) in enumerate(df_curr_feats):
            
            # Type casts
            u_id = np.int32(u_id)
            c_id = np.int16(c_id)
            cont_id = np.int16(cont_id)
            u_time = np.int64(u_time)
            pre_is_expl = np.int8(pre_is_expl)
            lvl = np.int8(lvl)
            part = np.int8(part)
            prior_q_elpsd_t = np.float32(prior_q_elpsd_t)
            
            
            if u_id not in self.data:
                self.add_new_user(u_id)
            
            # State of the current user (only read below)
            u_state = self.data[u_id]
            
            # Plain python int used as divisor : dividing a large python int by the np.int16 cont_id is 
            # evaluated in int16 under NumPy >= 2 promotion rules and raises OverflowError.
            cont_num = int(cont_id)
            has_prior_elapsed_t = not np.isnan(prior_q_elpsd_t)
                    
                    
            answ_corr_sum_arr[i_row] = u_state['answered_correctly_sum']
            q_count_arr[i_row] = u_state['q_count']
            
            if u_state['q_attempted'][c_id] == 0:
                is_first_attempt_arr[i_row] = 1
            
            
            # Container features : when the row opens a new container, the state's "current" container is the 
            # previous one ; otherwise the state's "last" container is.
            is_new_container = cont_id != u_state['curr_cont_id']
            if is_new_container:
                t_diff = u_time - u_state['curr_timestamp']
                last_cont_q_count_u_arr[i_row] = u_state['curr_cont_q_count']
                last_cont_sum_answ_u_arr[i_row] = u_state['curr_cont_sum_answ']
                # curr_cont_* features keep their initial value 0 : nothing answered yet in this container
                
            else:                
                t_diff = u_time - u_state['last_timestamp']
                last_cont_q_count_u_arr[i_row] = u_state['last_cont_q_count']
                last_cont_sum_answ_u_arr[i_row] = u_state['last_cont_sum_answ']
                same_container_as_last_arr[i_row] = 1
                
                curr_cont_score_avg_u_arr[i_row] = u_state['curr_cont_sum_answ'] / u_state['curr_cont_q_count'] if (u_state['curr_cont_q_count'] != 0) else 0
                curr_cont_score_sum_u_arr[i_row] = u_state['curr_cont_sum_answ']
                curr_cont_tackled_q_count_u_arr[i_row] = u_state['curr_cont_q_count']
                
            
            # Last three time gaps : for a new container the current gap is t_diff and is not in the history 
            # yet ; otherwise the three values are read from the history alone.
            time_since_hist = u_state['time_since_last_hist_3']
            len_time_since_hist = len(time_since_hist)
        
            if is_new_container:
                time_since_last_arr[i_row] = t_diff # == u_time - curr_timestamp of the user
                time_since_last_2_arr[i_row] = time_since_hist[-1] if len_time_since_hist > 0 else 0
                time_since_last_3_arr[i_row] = time_since_hist[-2] if len_time_since_hist > 1 else 0
            else:                
                time_since_last_arr[i_row] = time_since_hist[-1] if len_time_since_hist > 0 else 0
                time_since_last_2_arr[i_row] = time_since_hist[-2] if len_time_since_hist > 1 else 0
                time_since_last_3_arr[i_row] = time_since_hist[-3] if len_time_since_hist > 2 else 0

                
            time_since_last_right_arr[i_row] = u_time - u_state['last_time_right']
            time_since_last_wrong_arr[i_row] = u_time - u_state['last_time_wrong']
            
    
            # Session features : a new session starts when a previous gap exists (last_t_diff != -1) and the 
            # current non-zero gap is either more than 100 times that gap or above LIMIT_TIME_SESSION.
            is_new_sess = (u_state['last_t_diff'] != -1) & \
                              ((t_diff > 100*u_state['last_t_diff']) or (t_diff > LIMIT_TIME_SESSION)) & \
                                  (t_diff != 0) # returns 1 until session_num is updated
            
            sess_num = u_state['session_num'] + int(is_new_sess)
            sess_num_arr[i_row] = sess_num 
            
            if is_new_sess or (cont_id == u_state['first_cont_session']):
                same_sess_as_last_u_arr[i_row] = 0
            
            current_session_time_arr[i_row] = 0 if is_new_sess else (u_time - u_state['time_start_session'])
            
            curr_sess_q_count = 0 if is_new_sess else (u_state['q_count'] - u_state['q_count_start_session']) # noisy effect on q spread across batches
            current_session_q_count_arr[i_row] = curr_sess_q_count
            
            if is_new_sess :
                last_session_break_time_arr[i_row] = t_diff
            else: 
                last_session_break_time_arr[i_row] = u_state['last_session_break_time']
            
            # The break opening a new session is not in the state yet : add it on the fly
            curr_session_break_time = u_state['sum_session_break_time']
            if is_new_sess:
                curr_session_break_time += t_diff
            
            if sess_num != 0:
                avg_break_time_arr[i_row] = curr_session_break_time / sess_num    # sess_num == number of breaks

            session_avg_time_arr[i_row] = (u_time - curr_session_break_time) / (sess_num + 1) # sess_num == (number of run session + 1)
            
            avg_session_q_count_arr[i_row] = u_state['q_count'] / (sess_num + 1)
            
            if curr_sess_q_count != 0 :
                current_session_avg_score_arr[i_row] = u_state['curr_session_sum_score'] / curr_sess_q_count
                

            # Average score over consecutive windows of 25 answers (most recent first) of the answer history, 
            # their successive differences and the mean of those differences.
            if len(u_state['hist_100_answ']) != 0 :
                hist_100 = np.array(u_state['hist_100_answ'])
                avg_score_u_hist_25_0_arr[i_row] = hist_100[-25:].mean() # history is not empty here
                avg_score_u_hist_50_25_arr[i_row] = hist_100[-50:-25].mean() if (len(hist_100) > 25) else 0
                avg_score_u_hist_75_50_arr[i_row] = hist_100[-75:-50].mean() if (len(hist_100) > 50) else 0
                avg_score_u_hist_100_75_arr[i_row] = hist_100[:-75].mean() if (len(hist_100) > 75) else 0
                
                hist_score_diff_u_50_0_arr[i_row] = avg_score_u_hist_25_0_arr[i_row] - avg_score_u_hist_50_25_arr[i_row]
                hist_score_diff_u_75_25_arr[i_row] = avg_score_u_hist_50_25_arr[i_row] - avg_score_u_hist_75_50_arr[i_row]
                hist_score_diff_u_100_50_arr[i_row] = avg_score_u_hist_75_50_arr[i_row] - avg_score_u_hist_100_75_arr[i_row]
                
                hist_100_score_slope_u_arr[i_row] = np.mean([hist_score_diff_u_50_0_arr[i_row],
                                                              hist_score_diff_u_75_25_arr[i_row],
                                                              hist_score_diff_u_100_50_arr[i_row]])
                
                
            # Right answer streaks
            current_right_answ_streak_arr[i_row] = u_state['curr_answ_streak']
            max_right_answ_streak_arr[i_row] = u_state['max_answ_streak']
            
            streak_hist = u_state['hist_3_answ_streak']
            len_streak_hist = len(streak_hist)
            
            if len_streak_hist > 0:
                hist_1_right_answ_streak_arr[i_row] = streak_hist[-1]
                avg_right_answ_streak_hist_3_arr[i_row] = sum(streak_hist)/len_streak_hist
            if len_streak_hist > 1:
                hist_2_right_answ_streak_arr[i_row] = streak_hist[-2]
            if len_streak_hist > 2:
                hist_3_right_answ_streak_arr[i_row] = streak_hist[-3]

                
            first_try_success_count_u_arr[i_row] = u_state['first_try_success_count']
            unique_count_attempted_q_u_arr[i_row] = u_state['unique_count_attempted_q']
            
            
            # The explanation flag of the current row is added to the stored sum
            q_explanation_sum = u_state['q_explanation_sum'] + pre_is_expl
            q_explanation_sum_u_arr[i_row] = q_explanation_sum
            q_explanation_avg_u_arr[i_row] = q_explanation_sum/u_state['q_count'] if (u_state['q_count'] != 0) else 0
        
            time_since_last_sum_u_arr[i_row] = u_state['time_since_last_sum_u']
            time_since_last_avg_u_arr[i_row] = u_state['time_since_last_sum_u']/cont_num if (cont_id != 0) else 0
                

            # Score of the user on the difficulty level of the question
            sum_score_q_level_arr[i_row] = u_state['difficulty_lvl_sum_u'][lvl]
            q_count_level_arr[i_row] = u_state['difficulty_lvl_count_u'][lvl]
            avg_score_q_level_arr[i_row] = u_state['difficulty_lvl_sum_u'][lvl] / u_state['difficulty_lvl_count_u'][lvl] if (u_state['difficulty_lvl_count_u'][lvl] != 0) else 0
                           
        
            first_answ = u_state['u_first_answ']
            first_answ_arr[i_row] = first_answ
            
            u_clstr_start = u_state['u_first_q_cluster']            
            u_cluster_start_arr[i_row] = u_clstr_start
        
        
            # Cluster average looked up by (cluster, first answer, container id), container ids from 200 on 
            # share entry 199 ; -1 when the cluster is undefined (0) or on container 0.
            is_clustr_start_defined = (u_clstr_start != 0) & (cont_id != 0)
            if is_clustr_start_defined:
                if cont_id < 200:
                    avg_score_u_cluster_start_arr[i_row] = self.u_cluster_avg_start_200_dict[u_clstr_start][first_answ][cont_id]
                else:
                    avg_score_u_cluster_start_arr[i_row] = self.u_cluster_avg_start_200_dict[u_clstr_start][first_answ][199]
            else:
                avg_score_u_cluster_start_arr[i_row] = -1


            # Score of the user on the part of the question
            sum_score_u_part_arr[i_row] = u_state['sum_score_u_part'][part]
            q_count_u_part_arr[i_row] = u_state['q_count_u_part'][part]
            avg_score_u_part_arr[i_row] = u_state['sum_score_u_part'][part] / u_state['q_count_u_part'][part] if (u_state['q_count_u_part'][part] != 0) else 0 

            
            # Time spent on the previous container : prior elapsed time * number of questions of the last 
            # container. Both operands are plain python ints : multiplying a python int by the int8 value read 
            # back from the array is evaluated in int8 under NumPy >= 2 and raises OverflowError.
            if has_prior_elapsed_t:
                prior_cont_time = int(prior_q_elpsd_t)*int(last_cont_q_count_u_arr[i_row])
                
                updated_sum_prior_q_elapsed_time = prior_cont_time + u_state['sum_prior_q_elapsed_time']
                avg_prior_cont_time_arr[i_row] = updated_sum_prior_q_elapsed_time/cont_num if (cont_id != 0) else 0
                avg_prior_q_elapsed_time_arr[i_row] = updated_sum_prior_q_elapsed_time/u_state['q_count'] if (u_state['q_count'] != 0) else 0
        
            
            is_same_part_as_last_cont = u_state['last_cont_part'] == part
            if is_new_container and is_same_part_as_last_cont and has_prior_elapsed_t:
                # Last cont is related to the same part => dicts are not updated with the current prior_question_elapsed_time
                # value yet, but we can already use it to compute current features.
                sum_prior_q_elapsed_time_part = prior_cont_time + u_state['sum_prior_q_elapsed_time_part'][part]
            else:
                # first row, or dicts related to this part have already been updated
                sum_prior_q_elapsed_time_part = u_state['sum_prior_q_elapsed_time_part'][part]
                
            avg_prior_cont_time_part_arr[i_row] = sum_prior_q_elapsed_time_part/u_state['cont_count_part'][part] if (u_state['cont_count_part'][part] != 0) else 0
            avg_prior_q_elapsed_time_part_arr[i_row] = sum_prior_q_elapsed_time_part/u_state['q_count_u_part'][part] if (u_state['q_count_u_part'][part] != 0) else 0
        
        
            ratio_curr_by_avg_session_time_arr[i_row] = current_session_time_arr[i_row] / session_avg_time_arr[i_row] if (session_avg_time_arr[i_row] != 0) else 0
        
            
            sum_time_since_last = u_state['sum_time_since_last_u']
            if is_new_container:
                # sum of tdiff dict not updated yet
                sum_time_since_last = sum_time_since_last + t_diff
                
            avg_u_time_since_last_arr[i_row] = sum_time_since_last/cont_num if (cont_id != 0) else 0
            ratio_tdiff_by_avg_tdiff_u_arr[i_row] = t_diff/avg_u_time_since_last_arr[i_row] if (avg_u_time_since_last_arr[i_row] != 0) else 0
        
        
        # 0/0 (user without any answered question) gives NaN, replaced by 0 below
        with np.errstate(divide='ignore', invalid='ignore'):
            answ_corr_avg_arr = answ_corr_sum_arr/q_count_arr
        
        # Built on df_curr's index so that the axis=1 concatenation below pairs the rows one to one instead of 
        # aligning a fresh 0..n-1 index against df_curr's labels.
        user_feats_df = pd.DataFrame({'answered_correctly_avg_u': answ_corr_avg_arr,
                                      'answered_correctly_sum_u': answ_corr_sum_arr,
                                      'q_count_u': q_count_arr,
                                      'is_first_attempt': is_first_attempt_arr,
                                      'session_num': sess_num_arr,
                                      'same_container_as_last': same_container_as_last_arr,
                                      'last_container_sum_answ': last_cont_sum_answ_u_arr,
                                      'last_cont_q_count': last_cont_q_count_u_arr,
                                      'same_session_as_last': same_sess_as_last_u_arr,
                                      'current_session_time': current_session_time_arr,
                                      'current_session_q_count': current_session_q_count_arr,
                                      'last_session_break_time': last_session_break_time_arr,
                                      'avg_break_time': avg_break_time_arr,
                                      'avg_session_time': session_avg_time_arr,
                                      'avg_session_q_count': avg_session_q_count_arr,
                                      'current_session_avg_score': current_session_avg_score_arr,
                                      'current_right_answ_streak': current_right_answ_streak_arr,
                                      'max_right_answ_streak': max_right_answ_streak_arr,
                                      'hist_1_right_answ_streak': hist_1_right_answ_streak_arr,
                                      'hist_2_right_answ_streak': hist_2_right_answ_streak_arr,
                                      'hist_3_right_answ_streak': hist_3_right_answ_streak_arr,
                                      'avg_right_answ_streak_hist_3': avg_right_answ_streak_hist_3_arr,
                                      'first_try_success_count_u': first_try_success_count_u_arr,
                                      'unique_count_attempted_q_u': unique_count_attempted_q_u_arr,
                                      'avg_score_u_hist_100_75': avg_score_u_hist_100_75_arr,
                                      'avg_score_u_hist_75_50': avg_score_u_hist_75_50_arr,
                                      'avg_score_u_hist_50_25': avg_score_u_hist_50_25_arr,
                                      'avg_score_u_hist_25_0': avg_score_u_hist_25_0_arr,
                                      'hist_score_diff_u_100_50': hist_score_diff_u_100_50_arr,
                                      'hist_score_diff_u_75_25': hist_score_diff_u_75_25_arr,
                                      'hist_score_diff_u_50_0': hist_score_diff_u_50_0_arr,
                                      'hist_100_score_slope_u': hist_100_score_slope_u_arr,
                                      'q_explanation_avg_u': q_explanation_avg_u_arr,
                                      'q_explanation_sum_u': q_explanation_sum_u_arr,
                                      'time_since_last_sum_u': time_since_last_sum_u_arr,
                                      'time_since_last_avg_u': time_since_last_avg_u_arr,
                                      'curr_cont_score_avg_u': curr_cont_score_avg_u_arr,
                                      'curr_cont_score_sum_u': curr_cont_score_sum_u_arr,
                                      'curr_cont_tackled_q_count_u': curr_cont_tackled_q_count_u_arr,
                                      'sum_score_q_level': sum_score_q_level_arr,
                                      'q_count_level_u': q_count_level_arr,
                                      'avg_score_q_level_u': avg_score_q_level_arr,
                                      'first_answ': first_answ_arr,
                                      'u_cluster_start': u_cluster_start_arr,
                                      'avg_score_u_cluster_start': avg_score_u_cluster_start_arr,
                                      'time_since_last': time_since_last_arr,
                                      'time_since_last_2': time_since_last_2_arr,
                                      'time_since_last_3': time_since_last_3_arr,
                                      'time_since_last_right': time_since_last_right_arr,
                                      'time_since_last_wrong': time_since_last_wrong_arr,
                                      'avg_score_u_part': avg_score_u_part_arr, 
                                      'sum_score_u_part': sum_score_u_part_arr,
                                      'q_count_u_part': q_count_u_part_arr,
                                      'avg_prior_cont_time': avg_prior_cont_time_arr,
                                      'avg_prior_q_elapsed_time': avg_prior_q_elapsed_time_arr,
                                      'avg_prior_cont_time_part': avg_prior_cont_time_part_arr,
                                      'avg_prior_q_elapsed_time_part': avg_prior_q_elapsed_time_part_arr,
                                      'ratio_curr_by_avg_session_time': ratio_curr_by_avg_session_time_arr,
                                      'avg_u_time_since_last': avg_u_time_since_last_arr,
                                      'ratio_tdiff_by_avg_tdiff_u': ratio_tdiff_by_avg_tdiff_u_arr,
                                       }, index=df_curr.index)

        
        user_feats_df['answered_correctly_avg_u'] = user_feats_df['answered_correctly_avg_u'].fillna(0)

        return pd.concat([df_curr, user_feats_df], axis=1) 
    


    def get_l_feats_df_from_curr_state(self, df_curr):
        """ Get lecture features data using current user's states for each row of df_curr. 
        
        Rows are processed in order : a lecture row (content_type_id == 1) increments the 'l_count' state of its 
        user, any other row is handled as a question and records the current 'l_count' of its user. Users missing 
        from self.data are first created with add_new_user().
        
        Must be called on test_df while it still contains both the lectures and the questions. Only the 'user_id' 
        and 'content_type_id' columns are read.
             
        Returns a pd dataframe with a single 'l_count_u' column : one row per question of df_curr, in the same 
        order, with a fresh 0..n-1 index. This df should be concatenated to test_df as soon as the lectures have 
        been dropped in the inference loop. """
        
        # Same question test as in the loop below (everything which is not a lecture), so that the array 
        # always has exactly one slot per recorded row.
        n_questions = int((df_curr['content_type_id'] != 1).sum())
        l_count_arr = np.zeros(n_questions, dtype=np.int32)
        i_row_arr = 0 # position of the next question in l_count_arr

        for u_id, c_type_id in df_curr[['user_id','content_type_id']].values:

            if u_id not in self.data:
                self.add_new_user(u_id)

            if(c_type_id == 1): # lecture
                self.data[u_id]['l_count'] += 1
                
            else: # question
                l_count_arr[i_row_arr] = self.data[u_id]['l_count']
                i_row_arr += 1

        return pd.DataFrame({'l_count_u': l_count_arr})
    

    
    
    def get_lq_feats_from_curr_state(self, df_curr, lecture_stats, question_stats):
        """ Add the 'part' column to df_curr, for questions as well as lectures.
        
        The part of a row is read from question_stats.data when its content_type_id is 0 and from 
        lecture_stats.data otherwise, using content_id as key. Users missing from self.data are created with 
        add_new_user() along the way.
        
        Must be called on test_df while it still holds both lectures and questions. The index of df_curr is 
        reset in place (0..n-1) ; the returned dataframe is df_curr with the new 'part' column appended. """

        rows = df_curr[['user_id','content_id','content_type_id']].values
        part_arr = np.zeros(len(df_curr), dtype=np.int8)

        for pos, row in enumerate(rows):
            user, content, content_type = row

            if user not in self.data:
                self.add_new_user(user)

            # content_type_id 0 => question, anything else => lecture
            stats = question_stats if content_type == 0 else lecture_stats
            part_arr[pos] = stats.data[content]['part']

        # Align df_curr on the 0..n-1 index of the new column before concatenating
        df_curr.reset_index(drop=True, inplace=True)

        return pd.concat([df_curr, pd.DataFrame({'part': part_arr})], axis=1) 
    


            

In [15]:

class QuestionStats(object):
    """ Object which contains stats related to questions for inference.

    Attributes:
        data: dict, data[key][feat_name] = value. Keys come from the loaded feature
            dicts (question ids for the per-question features).
        data_level: dict, data_level[level][feat_name] = value, one entry per
            difficulty level.
    """
    
    def __init__(self, question_feat_stats_paths, question_level_feat_stats_paths):
        """ Load every feature file listed in the two {feat_name: path} mappings. """

        # [c_id] = { ... }
        self.data = {}
        for feat in question_feat_stats_paths:
            self.add_feat_to_data_path(feat, question_feat_stats_paths[feat])

        # [level] = { ... }
        self.data_level = {}
        for feat in question_level_feat_stats_paths:
            self.add_feat_to_data_level_path(feat, question_level_feat_stats_paths[feat])
   

    
    def add_feat_to_data_path(self, feat_name, feat_path):
        """ Load a feature state stored as a dict (pickle file) and set it to data for each question. """ 
        feat_dict = load_obj(feat_path)
        
        for q_id in feat_dict:
            
            if q_id not in self.data:
                self.add_new_question(q_id)
                
            self.data[q_id][feat_name] = feat_dict[q_id]
    
    
    def add_feat_to_data_level_path(self, feat_name, feat_path):
        """ Load a feature state stored as a dict (pickle file) and set it to data level for each difficulty level. """ 
        feat_dict = load_obj(feat_path)
        
        for lvl in feat_dict:
            
            if lvl not in self.data_level:
                self.data_level[lvl] = {}
                
            self.data_level[lvl][feat_name] = feat_dict[lvl]
                    

    
    def add_new_question(self, q_id):
        """Initialize data stats dicts for a new question."""
        self.data[q_id] = {}
    
    
        
    def get_q_feats_from_curr_state(self, df_curr):
        """ Look up the static question features for each row of df_curr.

        df_curr must only contain questions and must provide the 'content_id' and
        'difficulty_level' columns.

        Return a new dataframe made of df_curr (with its index reset to 0..n-1, so
        that rows stay aligned with the generated columns) and the new feature
        columns. df_curr itself is not modified. """
        
        n_rows = len(df_curr)

        answ_corr_avg_arr = np.zeros(n_rows, dtype=np.float32)
        answ_corr_std_arr = np.zeros(n_rows, dtype=np.float32)
        answ_corr_median_arr = np.zeros(n_rows, dtype=np.int8)
        answ_tag_count_arr = np.zeros(n_rows, dtype=np.int8)
        
        avg_question_elapsed_time_c_arr = np.zeros(n_rows, dtype=np.int64)
        std_question_elapsed_time_c_arr = np.zeros(n_rows, dtype=np.int64)
        avg_time_since_last_c_arr = np.zeros(n_rows, dtype=np.int64)
        std_time_since_last_c_arr = np.zeros(n_rows, dtype=np.int64)

        avg_part_score_c_arr = np.zeros(n_rows, dtype=np.float16)
        std_part_score_c_arr = np.zeros(n_rows, dtype=np.float16)
        
        avg_score_q_level_arr = np.zeros(n_rows, dtype=np.float16)
        std_score_q_level_arr = np.zeros(n_rows, dtype=np.float16)
        
        avg_q_elapsed_time_per_lvl_arr = np.zeros(n_rows, dtype=np.float32)
        
        num_answers_q_arr = np.zeros(n_rows, dtype=np.int8)
        q_count_trainset_c_arr = np.zeros(n_rows, dtype=np.int16)
        
        q_tags_clusters_arr = np.zeros(n_rows, dtype=np.int8)
        
        curr_avg_q_elapsed_time_arr = np.zeros(n_rows, dtype=np.float32)
        curr_avg_q_is_explained_arr = np.zeros(n_rows, dtype=np.float16)
        
        for i_row, (q_id, lvl) in enumerate(df_curr[['content_id', 'difficulty_level']].values):
            
            q_data = self.data[q_id]
            lvl_data = self.data_level[lvl]

            answ_corr_avg_arr[i_row] = q_data['answered_correctly_avg']
            answ_corr_std_arr[i_row] = q_data['answered_correctly_std']
            answ_corr_median_arr[i_row] = q_data['answered_correctly_median']
            answ_tag_count_arr[i_row] = q_data['tag_count']
            
            avg_question_elapsed_time_c_arr[i_row] = q_data['avg_question_elapsed_time']
            std_question_elapsed_time_c_arr[i_row] = q_data['std_question_elapsed_time']
            avg_time_since_last_c_arr[i_row] = q_data['avg_time_since_last']
            std_time_since_last_c_arr[i_row] = q_data['std_time_since_last']
            
            # NOTE(review): part-level stats are read from self.data under the part
            # number itself, so those feature dicts are presumably keyed by part
            # rather than by question id -- confirm against the stored files.
            part = q_data['part']
            avg_part_score_c_arr[i_row] = self.data[part]['avg_part_score']
            std_part_score_c_arr[i_row] = self.data[part]['std_part_score']
            
            avg_score_q_level_arr[i_row] = lvl_data['difficulty_level_avg_score_c']
            std_score_q_level_arr[i_row] = lvl_data['difficulty_level_std_score_c']
            
            avg_q_elapsed_time_per_lvl_arr[i_row] = lvl_data['avg_q_elapsed_time_level']
            
            num_answers_q_arr[i_row] = q_data['num_answers_q']
            q_count_trainset_c_arr[i_row] = q_data['q_count_trainset_c']
        
            q_tags_clusters_arr[i_row] = q_data['tag_cluster']
            
            curr_avg_q_elapsed_time_arr[i_row] = q_data['curr_avg_q_elapsed_time_c']
            curr_avg_q_is_explained_arr[i_row] = q_data['curr_avg_q_is_explained_c']
        
        
        question_feats_df = pd.DataFrame({'answered_correctly_avg_c': answ_corr_avg_arr,
                                          'answered_correctly_std_c': answ_corr_std_arr,
                                          'answered_correctly_median_c': answ_corr_median_arr,
                                          'tag_count_q': answ_tag_count_arr,
                                          'avg_question_elapsed_time_c': avg_question_elapsed_time_c_arr,
                                          'std_question_elapsed_time_c': std_question_elapsed_time_c_arr,
                                          'avg_time_since_last_c': avg_time_since_last_c_arr,
                                          'std_time_since_last_c': std_time_since_last_c_arr,
                                          'avg_part_score_c': avg_part_score_c_arr,
                                          'std_part_score_c': std_part_score_c_arr,
                                          'difficulty_level_avg_score_c': avg_score_q_level_arr,
                                          'difficulty_level_std_score_c': std_score_q_level_arr,
                                          'avg_q_elapsed_time_lvl': avg_q_elapsed_time_per_lvl_arr,
                                          'num_answers_q': num_answers_q_arr,
                                          'q_count_trainset_c': q_count_trainset_c_arr,
                                          'tag_cluster': q_tags_clusters_arr,
                                          'curr_avg_q_elapsed_time': curr_avg_q_elapsed_time_arr,
                                          'curr_avg_q_is_explained': curr_avg_q_is_explained_arr, })

        # concat(axis=1) aligns on index labels: the generated frame is indexed
        # 0..n-1, so df_curr is aligned positionally through an index-reset copy.
        return pd.concat([df_curr.reset_index(drop=True), question_feats_df], axis=1)
    
    
    
    def get_lq_feats_from_curr_state(self, df_curr):
        """ Get features which requires both questions & lectures from the current state. 
        Must be called on test_df before dropping the questions.
        Return a dataframe with the new columns concatenated to df_curr. 
        Used to get 'difficulty_level' in previous_test_df.

        Lecture rows (content_type_id == 1) get a difficulty_level of -1.
        Side effect: the index of df_curr is reset in place. """
        
        q_level_arr = np.zeros(len(df_curr), dtype=np.int8)

        for i_row, (c_id, c_type_id) in enumerate(df_curr[['content_id','content_type_id']].values):

            is_lecture = c_type_id == 1
            if is_lecture:
                q_level_arr[i_row] = -1
            else:
                q_level_arr[i_row] = self.data[c_id]['difficulty_level']

        lq_feats_df = pd.DataFrame({'difficulty_level': q_level_arr})
        
        df_curr.reset_index(drop=True, inplace=True)

        return pd.concat([df_curr, lq_feats_df], axis=1)



In [16]:
class LectureStats(object):
    """ Holds the static per-lecture features needed at inference time.

    self.data maps a lecture id to a dict {feature_name: value}. """

    def __init__(self, lecture_feat_stats_paths):
        """ Load every feature listed in the {feature_name: path} mapping. """
        self.data = {}
        for feat_name, feat_path in lecture_feat_stats_paths.items():
            self.add_feat_to_data(feat_name, feat_path)

    def add_feat_to_data(self, feat_name, feat_path):
        """ Read a {lecture_id: value} dict from feat_path (pickle file) and store each
        value under feat_name in the entry of the matching lecture. """
        for l_id, value in load_obj(feat_path).items():
            # Lectures met for the first time get an empty entry first
            if l_id not in self.data:
                self.add_new_lecture(l_id)
            self.data[l_id][feat_name] = value

    def add_new_lecture(self, l_id):
        """ Create (or reset) the empty feature entry of a lecture. """
        self.data[l_id] = {}



In [17]:
class TagsStats(object):
    """ Object which contains stats related to tags for inference.

    self.data maps a tag id to a dict {feature_name: value}. """
    
    def __init__(self, tags_feat_stats_paths):
        """ Load every feature listed in the {feature_name: path} mapping. """

        self.data = {}
        for feat in tags_feat_stats_paths:
            self.add_feat_to_data(feat, tags_feat_stats_paths[feat])

    
    def add_feat_to_data(self, feat_name, feat_path):
        """ Load a feature state stored as a dict (pickle file) and set it to data for each tag. """ 
        feat_dict = load_obj(feat_path)
        
        for t_id in feat_dict:
            
            if t_id not in self.data:
                self.add_new_tag(t_id)
                
            self.data[t_id][feat_name] = feat_dict[t_id]
            
            
    def add_new_tag(self, t_id):
        """Initialize data stats dicts for a new tag."""
        self.data[t_id] = {}
        self.data[t_id]['answered_correctly_sum'] = 0
        self.data[t_id]['q_count'] = 0
        
    
        
    def get_q_feats_from_curr_state(self, df_curr, question_stats, num_unique_questions=None):
        """ Compute tag-based features for each row of df_curr.

        For every question the space-separated 'tags' string stored in
        question_stats.data is split into tag ids; the mean / max / min / std of the
        per-tag average score (answered_correctly_sum / q_count) and of the per-tag
        frequency (ncount / num_unique_questions) are then computed.

        Args:
            df_curr: dataframe which must only contain questions ('content_id' column).
            question_stats: object whose .data[q_id]['tags'] holds the tags string.
            num_unique_questions: denominator of the tag frequency; defaults to the
                module-level NUM_UNIQUE_QUESTIONS constant.

        Return a new dataframe made of df_curr (with its index reset to 0..n-1, so
        that rows stay aligned with the generated columns) and the new feature
        columns. df_curr itself is not modified. """
        
        if num_unique_questions is None:
            num_unique_questions = NUM_UNIQUE_QUESTIONS

        n_rows = len(df_curr)

        avg_score_t_arr = np.zeros(n_rows, dtype=np.float32)
        max_avg_score_t_arr = np.zeros(n_rows, dtype=np.float32)
        min_avg_score_t_arr = np.zeros(n_rows, dtype=np.float32)
        std_avg_score_t_arr = np.zeros(n_rows, dtype=np.float32)
        
        avg_tag_frequency_arr = np.zeros(n_rows, dtype=np.float32)
        max_tag_frequency_arr = np.zeros(n_rows, dtype=np.float32)
        min_tag_frequency_arr = np.zeros(n_rows, dtype=np.float32)
        std_tag_frequency_arr = np.zeros(n_rows, dtype=np.float32)
        
        
        for i_row, q_id in enumerate(df_curr['content_id'].values):
            
            tags_str = question_stats.data[q_id]['tags']
            tags = np.array(tags_str.split(' '), dtype=np.int32)
            
            avg_score_tags_l = []
            ncount_tags_l = []
            for t_id in tags:
                tag_data = self.data[t_id]
                # NOTE(review): q_count is initialised to 0 in add_new_tag(); a tag
                # that never received a loaded q_count would divide by zero here.
                avg_score_tags_l.append(tag_data['answered_correctly_sum']/tag_data['q_count'])
                ncount_tags_l.append(tag_data['ncount'])
            
            avg_score_tags_l = np.array(avg_score_tags_l)
            tag_frequencies_arr = np.array(ncount_tags_l)/num_unique_questions
            
            avg_score_t_arr[i_row] = avg_score_tags_l.mean()
            max_avg_score_t_arr[i_row] = avg_score_tags_l.max()
            min_avg_score_t_arr[i_row] = avg_score_tags_l.min()
            std_avg_score_t_arr[i_row] = avg_score_tags_l.std()
            
            avg_tag_frequency_arr[i_row] = tag_frequencies_arr.mean()
            max_tag_frequency_arr[i_row] = tag_frequencies_arr.max()
            min_tag_frequency_arr[i_row] = tag_frequencies_arr.min()
            std_tag_frequency_arr[i_row] = tag_frequencies_arr.std()
            


        tag_feats_df = pd.DataFrame({'avg_score_t': avg_score_t_arr,
                                     'max_avg_score_t': max_avg_score_t_arr,
                                     'min_avg_score_t': min_avg_score_t_arr,
                                     'std_avg_score_t': std_avg_score_t_arr,
                                     'avg_tag_freq': avg_tag_frequency_arr,
                                     'max_tag_freq': max_tag_frequency_arr,
                                     'min_tag_freq': min_tag_frequency_arr,
                                     'std_tag_freq': std_tag_frequency_arr,})
        

        # concat(axis=1) aligns on index labels: the generated frame is indexed
        # 0..n-1, so df_curr is aligned positionally through an index-reset copy.
        return pd.concat([df_curr.reset_index(drop=True), tag_feats_df], axis=1)


In [18]:
# [CV] Rebuild the data preprocessor from scratch when running in validation mode
if validaten_flg:

    # Release the previous instance before creating its replacement
    del data_preprocessor
    gc.collect()

    data_preproc_params = dict(
        offline_feats_train_paths=offline_features_train_paths,
        offline_feats_valid_paths=offline_features_valid_paths,
        train_raw_q_path=df_questions_kfold_paths[kfold_name]['train'],
        valid_raw_q_path=df_questions_kfold_paths[kfold_name]['valid'],
        X_feats=FEATS,
        y_feats=TARGET,
        train_slice_rows_i=train_slice_rows,
        valid_slice_rows_i=valid_slice_rows,
        valid_flag=validaten_flg,
    )

    data_preprocessor = DataPreprocessor(**data_preproc_params)


# Per-user state, built from the feature states of the selected fold
u_feat_states_paths = users_feature_states_paths[kfold_name]
users_state = UsersState(u_feat_states_paths, user_cluster_avg_score_start_200_path)


# Static statistics objects: questions, lectures, then tags
question_stats = QuestionStats(question_feature_stats_paths, question_level_feature_stats_paths)
lecture_stats = LectureStats(lecture_feature_stats_paths)
tags_stats = TagsStats(tags_feature_stats_paths)


Loading offline features dict to user state ...
Loading user cluster avg start hist 200 dict.
User state init complete.


In [19]:
class InferenceUnitTester(object):
    """ Builds small debugging dataframes used to unit-test the inference pipeline.

    Attributes:
        field_needed: columns kept by get_valid_only_q().
        valid_df_q: first 3 question rows of the validation set.
        valid_df_l: rows 45..54 of the validation set, without 'answered_correctly'.
        batch_test_df: first 6 rows of the example test file.
    """
    
    def __init__(self, valid_pkl, example_test_file_path):
        """ Read the validation pickle and the example test csv, and build the three debug frames. """
        
        self.field_needed = ['row_id','user_id', 'timestamp','content_type_id', 'answered_correctly', 'task_container_id', 'content_id']
        
        self.valid_df_q = self.get_valid_only_q(valid_pkl)
        self.valid_df_l = self.get_valid_with_l(valid_pkl)
        self.batch_test_df = self.get_test_batch(example_test_file_path)
        
    def get_valid_only_q(self, valid_pkl):
        '''For debug : only questions. Return the first 3 question rows (columns: field_needed).'''
        valid_df = pd.read_pickle(valid_pkl)[self.field_needed]
        valid_df.drop(valid_df[valid_df.content_type_id == 1].index, inplace=True)
        valid_df = valid_df.reset_index(drop=True)
        subset_valid_df = valid_df[:3]
        return subset_valid_df
    
    def get_valid_with_l(self, valid_pkl):
        '''For debug : include some lectures rows. Return rows 45..54 without the target column.'''
        valid_df = pd.read_pickle(valid_pkl)
        valid_df = valid_df.reset_index(drop=True)
        # Non in-place drop: returns a new frame instead of mutating a slice of valid_df
        return valid_df.iloc[45:55,:].drop(columns=['answered_correctly'])
    
    def get_raw_valid(self, valid_pkl):
        '''Return the whole validation dataframe with its index reset.'''
        valid_df = pd.read_pickle(valid_pkl)
        valid_df = valid_df.reset_index(drop=True)
        return valid_df
    
    def get_test_batch(self, example_test_file_path):
        '''For debug : return the first 6 rows of the example test file.'''
        test_df = pd.read_csv(example_test_file_path)
        batch_test_df = test_df[:6].reset_index(drop=True)
        # Debug recipe (disabled) to inject lecture rows into the batch:
        #batch_test_df = batch_test_df.append(batch_test_df.iloc[-1,:])
        #batch_test_df = batch_test_df.append(batch_test_df.iloc[-1,:])
        #batch_test_df.loc[6,'content_type_id'] = 1
        #batch_test_df.loc[5:7,'content_id'] = 89
        #batch_test_df = users_state.get_lq_feats_from_curr_state(batch_test_df, lecture_stats, question_stats)
        #batch_test_df.loc[5:7,'part'] = 1
        return batch_test_df

    def get_valid_subset_change_session(self, X_valid, valid_pkl, user_id=2147470777):
        '''For debug : return rows 25..34 of the features of a single user.

        X_valid must hold one row per question of the validation set, in the same
        order as the question rows of the pickle; rows are matched by position.

        Args:
            X_valid: feature dataframe aligned (by position) with the validation questions.
            valid_pkl: path of the raw validation pickle.
            user_id: user whose rows are extracted.

        Return X_valid restricted to that user, with 'user_id' and 'content_id'
        added, index reset, sliced to [25:35].'''
        raw_valid_df = pd.read_pickle(valid_pkl)
        raw_valid_df = raw_valid_df[raw_valid_df.content_type_id == 0]
        raw_valid_df = raw_valid_df.reset_index(drop=True)

        user_rows = raw_valid_df[raw_valid_df.user_id == user_id]

        # After the reset, index labels of user_rows are row positions in X_valid
        change_sess_u_subset = X_valid.iloc[user_rows.index.values, :].copy()
        # Rows were picked by position, so the id columns are attached by position too
        change_sess_u_subset['user_id'] = user_rows.user_id.values
        change_sess_u_subset['content_id'] = user_rows.content_id.values
        change_sess_u_subset = change_sess_u_subset.reset_index(drop=True)
        return change_sess_u_subset[25:35]



In [20]:
######################
# UNIT TEST INFERENCE
######################

# Expose the debug dataframes at notebook level when the unit-test mode is enabled
if unit_test_inference_flg:
    inference_unit_tester = InferenceUnitTester(valid_pickle, example_test_file)

    valid_df_q, valid_df_l, batch_test_df = (
        inference_unit_tester.valid_df_q,
        inference_unit_tester.valid_df_l,
        inference_unit_tester.batch_test_df,
    )

In [21]:
class InferenceProcessor(object):
    """ Computes, at inference time, the online features that combine raw test columns
    with the user / question features already attached to the batch.

    Attributes:
        data_preprocessor: preprocessor object whose get_* feature helpers are reused
            by apply_online_feats_inference().
        prior_question_elapsed_time_mean: value used to fill missing
            'prior_question_elapsed_time'.
    """
    
    def __init__(self, data_preproc):
        self.data_preprocessor = data_preproc
        self.prior_question_elapsed_time_mean = PRIOR_QUESTION_ELAPSED_TIME_TRAIN_MEAN
        
        
    def get_cont_q_count_inference(self, df):
        '''Get the number of question per container for df.
        Requires user_id',task_container_id,row_id.
        Return an int8 ndarray (one value per row of df, in row order) containing the
        values of "num_q_cont", capped at 6 (values : from 1 to 6).

        Rows are counted per (user_id, task_container_id) pair, so a user may have
        several containers in the same batch. '''

        counts = df.groupby(['user_id','task_container_id'])['row_id'].transform('count')

        return counts.clip(upper=6).astype(np.int8).values

    
    def get_diff_avg_question_elapsed_time_c(self, df):
        '''Req: avg_question_elapsed_time_c (in question_stats), question_elapsed_time (inference). '''
        return (df.question_elapsed_time-df.avg_question_elapsed_time_c).fillna(0).astype(np.int64) 

    def get_diff_avg_time_since_last_c(self, df):
        '''Req: avg_time_since_last_c (in question_stats), time_since_last (inference). '''
        return (df.time_since_last-df.avg_time_since_last_c).fillna(0).astype(np.int64)

    def get_ratio_diff_avg_question_elapsed_time_c(self, df):
        '''Req: avg_question_elapsed_time_c (in question_stats), diff_avg_question_elapsed_time_c (inference). '''
        return (df.diff_avg_question_elapsed_time_c/df.avg_question_elapsed_time_c).fillna(0).astype(np.float32) 
              
   
    def get_ratio_diff_avg_time_since_last_c(self, df):
        '''Req: avg_time_since_last_c (in question_stats), diff_avg_time_since_last_c (inference). '''
        return (df.diff_avg_time_since_last_c/df.avg_time_since_last_c).fillna(0).astype(np.float32)
        
        
    def get_diff_avg_q_elapsed_time_lvl(self, df):
        '''Req: avg_q_elapsed_time_lvl (in question_stats), question_elapsed_time (inference). '''
        return (df.question_elapsed_time-df.avg_q_elapsed_time_lvl).fillna(0).astype(np.float32)
    
    
    def get_ratio_diff_avg_q_elapsed_time_lvl(self, df):
        '''Req: avg_q_elapsed_time_lvl (in question_stats), diff_avg_q_elapsed_time_lvl (inference). '''
        return (df.diff_avg_q_elapsed_time_lvl/df.avg_q_elapsed_time_lvl).fillna(0).astype(np.float16)
    
    
    
    def get_diff_avg_score_q_lvl(self, df):
        '''Req:  avg_score_q_level_u (user), difficulty_level_avg_score_c (in question_stats). '''
        return (df.avg_score_q_level_u-df.difficulty_level_avg_score_c).fillna(0).astype(np.float16)
    
    
    def get_ratio_diff_avg_score_q_lvl(self, df):
        '''Req: diff_avg_score_q_lvl (inference), difficulty_level_avg_score_c (in question_stats). '''
        return (df.diff_avg_score_q_lvl/df.difficulty_level_avg_score_c).fillna(0).astype(np.float16)
    

    def get_ratio_diff_avg_score_u_c(self, df):
        '''Req: answered_correctly_avg_u (user), answered_correctly_avg_c (in question_stats).
        Return (avg_u - avg_c) / avg_c as float16, NaNs replaced by 0. '''
        diff = df.answered_correctly_avg_u - df.answered_correctly_avg_c
        return (diff/df.answered_correctly_avg_c).fillna(0).astype(np.float16)


      
    
    def apply_online_feats_inference(self, df):
        '''Apply online feats which combine raws and user/question/lecture features. Return df with the additionnal features
        concatenated (df is modified in place and returned).
        Must be called after user states and static statistics features update, and before feeding df into the model.'''

        # Use the preprocessor injected at construction time rather than a notebook global
        preproc = self.data_preprocessor

        # Add features (order matters: later features read columns created by earlier ones)
        df['time_btw_cont_mean'] = preproc.get_time_btw_containers_mean(df)
        df['time_per_action_mean'] = preproc.get_time_per_action_mean(df)
        df['is_first_question'] = preproc.get_is_first_question(df)
        df['last_cont_score_mean'] = preproc.get_last_cont_score_mean(df)
        df['cont_q_count'] = self.get_cont_q_count_inference(df)
        df['question_elapsed_time'] = preproc.get_question_elapsed_time(df)
        df['time_ratio'] = preproc.get_time_ratio(df)
        df['log_time_ratio'] = preproc.get_log_time_ratio(df)
        
        df['ratio_score_u_time_ratio'] = preproc.get_ratio_score_u_time_ratio(df)
        df['ratio_q_count_u_time_ratio'] = preproc.get_ratio_q_count_u_time_ratio(df)
        
        df['diff_avg_question_elapsed_time_c'] = self.get_diff_avg_question_elapsed_time_c(df)
        df['diff_avg_time_since_last_c'] = self.get_diff_avg_time_since_last_c(df)
        df['ratio_diff_avg_question_elapsed_time_c'] = self.get_ratio_diff_avg_question_elapsed_time_c(df)
        df['ratio_diff_avg_time_since_last_c'] = self.get_ratio_diff_avg_time_since_last_c(df)
        
        
        df['diff_avg_q_elapsed_time_lvl'] = self.get_diff_avg_q_elapsed_time_lvl(df)
        df['ratio_diff_avg_q_elapsed_time_lvl'] = self.get_ratio_diff_avg_q_elapsed_time_lvl(df)
        
        # NaNs & casts
        df['prior_question_had_explanation'] = df.prior_question_had_explanation.fillna(False).astype('int8')
        df['prior_question_elapsed_time'] = df.prior_question_elapsed_time.fillna(self.prior_question_elapsed_time_mean)

        return df
    
    

# model 2: SAKT

Credit: this model relies heavily on this public kernel: https://www.kaggle.com/wangsg/a-self-attentive-model-for-knowledge-tracing

In [22]:
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader

In [23]:
# Default dropout rate of the SAKT building blocks (TransformerBlock / Encoder / SAKTModel)
DROPOUT_SAKT = 0.1
# Sequence length given to SAKTModel as `max_seq`; also the cap applied to each user's
# stored history in the inference loop
MAX_SEQ = 180
# Embedding dimension given to SAKTModel as `embed_dim`
EMBED_SIZE_SAKT = 128
# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [24]:
class FFN(nn.Module):
    """ Feed-forward sub-layer: Linear -> ReLU -> BatchNorm1d -> Linear -> Dropout.

    Args:
        state_size: size of the last dimension of the input and of the output.
        forward_expansion: multiplier applied to state_size for the hidden width.
        bn_size: number of features given to BatchNorm1d; for a 3-D input this is
            the size of dimension 1 of the tensor.
        dropout: dropout probability applied to the output.
    """

    def __init__(self, state_size = 200, forward_expansion = 1, bn_size=MAX_SEQ - 1, dropout=0.2):
        super().__init__()
        hidden_size = forward_expansion * state_size
        self.state_size = state_size

        # Attribute names are part of the saved state dict: keep them unchanged
        self.lr1 = nn.Linear(state_size, hidden_size)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm1d(bn_size)
        self.lr2 = nn.Linear(hidden_size, state_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """ Apply the sub-layer to x and return a tensor of the same shape. """
        hidden = self.bn(self.relu(self.lr1(x)))
        return self.dropout(self.lr2(hidden))
    

def future_mask(seq_length):
    """ Build the causal attention mask of a sequence.

    Return a boolean torch tensor of shape [seq_length, seq_length] whose entry
    (i, j) is True when j > i (strictly above the main diagonal), i.e. when
    position j lies in the future of position i. """
    positions = np.arange(seq_length)
    is_future = positions[np.newaxis, :] > positions[:, np.newaxis]
    return torch.from_numpy(is_future)



class TransformerBlock(nn.Module):
    """ Multi-head attention followed by a feed-forward network, each wrapped in a
    residual connection, LayerNorm and dropout. """

    def __init__(self, embed_dim, heads = 8, dropout = DROPOUT_SAKT, forward_expansion = 1):
        super().__init__()
        # Attribute names are part of the saved state dict: keep them unchanged
        self.multi_att = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=heads, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.layer_normal = nn.LayerNorm(embed_dim)
        self.ffn = FFN(embed_dim, forward_expansion = forward_expansion, dropout=dropout)
        self.layer_normal_2 = nn.LayerNorm(embed_dim)

    def forward(self, value, key, query, att_mask):
        """ Run the block on sequence-first tensors ([s_len, bs, embed]).

        The arguments are forwarded positionally to nn.MultiheadAttention, so `value`
        lands in its first (query) slot and `query` in its third (value) slot; the
        residual connection is taken on `value`.

        Return (output, attention weights), the output being batch-first
        ([bs, s_len, embed]). """
        attended, att_weight = self.multi_att(value, key, query, attn_mask=att_mask)

        # Residual + norm, then switch to batch-first: [s_len, bs, embed] => [bs, s_len, embed]
        attended = self.layer_normal(attended + value)
        attended = self.dropout(attended).permute(1, 0, 2)

        out = self.ffn(attended)
        out = self.dropout(self.layer_normal_2(out + attended))
        return out.squeeze(-1), att_weight
    
class Encoder(nn.Module):
    """ SAKT encoder: embeds the interaction history and the target questions, then
    runs them through a stack of TransformerBlock layers with a causal mask.

    Args:
        n_skill: number of distinct question ids.
        max_seq: sequence length; max_seq - 1 positions are embedded.
        embed_dim: embedding dimension.
        dropout: dropout probability (embedding dropout and transformer blocks).
        forward_expansion: hidden-width multiplier of the feed-forward layers.
        num_layers: number of stacked transformer blocks.
        heads: number of attention heads of each block.
    """

    def __init__(self, n_skill, max_seq=100, embed_dim=128, dropout = DROPOUT_SAKT, forward_expansion = 1, num_layers=1, heads = 8):
        super(Encoder, self).__init__()
        self.n_skill, self.embed_dim = n_skill, embed_dim
        # Interaction ids are content_id + n_skill * answered_correctly, hence 2 * n_skill (+1) entries
        self.embedding = nn.Embedding(2 * n_skill + 1, embed_dim)
        self.pos_embedding = nn.Embedding(max_seq - 1, embed_dim)
        self.e_embedding = nn.Embedding(n_skill+1, embed_dim)
        # `heads` and `dropout` are forwarded so the constructor arguments are honoured
        self.layers = nn.ModuleList([
            TransformerBlock(embed_dim, heads=heads, dropout=dropout, forward_expansion=forward_expansion)
            for _ in range(num_layers)
        ])
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, question_ids):
        """ Encode a batch.

        Args:
            x: interaction ids, [bs, s_len].
            question_ids: target question ids, [bs, s_len].

        Return (encoded tensor [bs, s_len, embed], attention weights of the last
        layer, or None when the encoder has no layer). """
        device = x.device
        x = self.embedding(x)
        pos_id = torch.arange(x.size(1)).unsqueeze(0).to(device)
        pos_x = self.pos_embedding(pos_id)
        x = self.dropout(x + pos_x)
        x = x.permute(1, 0, 2) # x: [bs, s_len, embed] => [s_len, bs, embed]
        e = self.e_embedding(question_ids)
        e = e.permute(1, 0, 2)

        # The causal mask only depends on the sequence length: build it once
        att_mask = future_mask(e.size(0)).to(device)
        att_weight = None
        for layer in self.layers:
            # The question embeddings go into the attention's query slot
            x, att_weight = layer(e, x, x, att_mask=att_mask)
            x = x.permute(1, 0, 2) # back to [s_len, bs, embed] for the next layer
        x = x.permute(1, 0, 2)
        return x, att_weight

class SAKTModel(nn.Module):
    """ Self-attentive knowledge tracing model: an Encoder followed by a linear layer
    producing one score per sequence position.

    Args:
        n_skill: number of distinct question ids.
        max_seq: sequence length given to the encoder.
        embed_dim: embedding dimension.
        dropout: dropout probability.
        forward_expansion: hidden-width multiplier of the feed-forward layers.
        enc_layers: number of transformer blocks in the encoder.
        heads: number of attention heads.
    """

    def __init__(self, n_skill, max_seq=100, embed_dim=128, dropout = DROPOUT_SAKT, forward_expansion = 1, enc_layers=1, heads = 8):
        super(SAKTModel, self).__init__()
        # `heads` is forwarded so the constructor argument reaches the encoder
        self.encoder = Encoder(n_skill, max_seq, embed_dim, dropout, forward_expansion, num_layers=enc_layers, heads=heads)
        self.pred = nn.Linear(embed_dim, 1)
        
    def forward(self, x, question_ids):
        """ Return (scores, attention weights); the trailing size-1 dimension produced
        by the prediction layer is squeezed out of the scores. """
        x, att_weight = self.encoder(x, question_ids)
        x = self.pred(x)
        return x.squeeze(-1), att_weight

In [25]:
def create_model():
    """ Instantiate a SAKTModel for the notebook-level n_skill, with the
    hyper-parameters used by this notebook. """
    hyper_params = dict(
        max_seq=MAX_SEQ,
        embed_dim=EMBED_SIZE_SAKT,
        forward_expansion=1,
        enc_layers=1,
        heads=8,
        dropout=0.1,
    )
    return SAKTModel(n_skill, **hyper_params)

# Paths of the SAKT artifacts (per-user history pickle and model weights)
sakt_group_pickle_path = SAKT_GROUP_PICKLE_FULLSET_PATH
sakt_model_state_dict_path = SAKT_MODEL_STATE_DICT_FULLSET_PATH
# NOTE(review): `device` is already defined with the same expression in the SAKT constants cell
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Skills: n_skill is read by create_model() and must be defined before it is called
skills = np.load(SAKT_SKILL_NP_PATH)
n_skill = len(skills)

# Loading group: indexed by user id, each entry is a (content_ids, answered_correctly)
# pair (see how the inference loop updates it)
group = load_obj(sakt_group_pickle_path)

# Loading SAKT model: build the architecture, then restore the weights on `device`
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
sakt_model = create_model()
sakt_model.load_state_dict(torch.load(sakt_model_state_dict_path, map_location=device))
sakt_model.to(device)


log('SAKT model successfully loaded.')


SAKT model successfully loaded.


In [26]:
class TestDataset(Dataset):
    """ Dataset yielding, for each row of test_df, the SAKT inputs of that row.

    Args:
        samples: per-user history, indexed by user id; each entry is a
            (content_ids, answered_correctly) pair of arrays.
        test_df: rows to predict ('user_id' and 'content_id' columns are read).
        n_skill: number of distinct question ids.
        max_seq: length of the padded / truncated history.
    """

    def __init__(self, samples, test_df, n_skill, max_seq=100):
        super().__init__()
        self.samples = samples
        self.user_ids = list(test_df["user_id"].unique())
        self.test_df = test_df
        self.n_skill = n_skill
        self.max_seq = max_seq

    def __len__(self):
        """ One item per row of test_df. """
        return self.test_df.shape[0]

    def __getitem__(self, index):
        """ Return (x, questions) for row `index`, both of length max_seq - 1.

        x holds the past interactions encoded as content_id + n_skill when the
        answer was correct; questions holds the past content ids shifted by one
        position, ending with the content id to predict. """
        row = self.test_df.iloc[index]
        user_id, target_id = row['user_id'], row['content_id']

        # Left-padded (with zeros) history of the user
        history_q = np.zeros(self.max_seq, dtype=int)
        history_a = np.zeros(self.max_seq, dtype=int)

        if user_id in self.samples.index:
            past_q, past_a = self.samples[user_id]
            n_past = len(past_q)

            if n_past < self.max_seq:
                # Short history: right-align it inside the zero padding
                history_q[-n_past:] = past_q
                history_a[-n_past:] = past_a
            else:
                # Long history: keep the most recent max_seq interactions
                history_q = past_q[-self.max_seq:]
                history_a = past_a[-self.max_seq:]

        x = history_q[1:].copy()
        x += (history_a[1:] == 1) * self.n_skill

        questions = np.append(history_q[2:], [target_id])

        return x, questions

# Inference : bagging

In [27]:
# Emit a progress message marking the start of the inference stage
log('Inference ...')

Inference ...


In [28]:
########################################################################################
# INFERENCE LOOP
########################################################################################

# For every batch yielded by `iter_test`, this cell:
#   1. folds the answers revealed for the previous batch into the LGBM and SAKT states,
#   2. builds the feature frame for the current batch (lectures are dropped),
#   3. predicts with the LGBM model and with the SAKT model,
#   4. submits a weighted average of both predictions.
# NOTE(review): this import would conventionally live in the notebook's top import cell.
import psutil
# Put the SAKT network in evaluation mode before inference.
sakt_model.eval()


inference_preprocessor = InferenceProcessor(data_preprocessor)

# Feature frame of the previous batch; its labels only arrive with the next batch.
previous_test_df = None

for (test_df, sample_prediction_df) in iter_test:
    
    #--------------------------------------------------------------
    # State updating
    #--------------------------------------------------------------
    
    if previous_test_df is not None:
        # The first row of the current batch carries the previous batch's answers as text.
        # NOTE(review): eval() executes whatever the string contains; if the field is
        # always a plain list literal, ast.literal_eval would be the safer parser — confirm.
        answers = eval(test_df["prior_group_answers_correct"].iloc[0])
        previous_test_df[TARGET] = answers
        
        #---------
        # LGBM states
        #---------
        users_state.update_dicts_from_prev_state(previous_test_df) # lectures are not updated
        
        #---------
        # SAKT states
        #---------
        # The history update is skipped while system memory usage is at 90% or more,
        # so SAKT histories are not extended with this batch in that case.
        if psutil.virtual_memory().percent < 90:
 

            # Keep question rows only, then collect per user the
            # (content ids, answers) observed in the previous batch.
            previous_test_df = previous_test_df[previous_test_df.content_type_id == False]
            prev_group = previous_test_df[['user_id', 'content_id', 'answered_correctly']].groupby('user_id').apply(lambda r: (
                r['content_id'].values,
                r['answered_correctly'].values))
            for prev_user_id in prev_group.index:
                prev_group_content = prev_group[prev_user_id][0]
                prev_group_ac = prev_group[prev_user_id][1]
                if prev_user_id in group.index:
                    # Known user: append the new interactions to the stored history.
                    group[prev_user_id] = (np.append(group[prev_user_id][0],prev_group_content), 
                                           np.append(group[prev_user_id][1],prev_group_ac))

                else:
                    # New user: start a history from this batch.
                    group[prev_user_id] = (prev_group_content,prev_group_ac)
                # Cap each stored history at the MAX_SEQ most recent interactions.
                if len(group[prev_user_id][0])>MAX_SEQ:
                    new_group_content = group[prev_user_id][0][-MAX_SEQ:]
                    new_group_ac = group[prev_user_id][1][-MAX_SEQ:]
                    group[prev_user_id] = (new_group_content,new_group_ac)    
    
    
    #--------------------------------------------------------------
    # Features preparation
    #--------------------------------------------------------------

    
    # Missing values are treated as "no explanation" and the flag is stored as int8.
    test_df['prior_question_had_explanation'] = test_df.prior_question_had_explanation.fillna(False).astype('int8')         
        
    #-- 1) Add 'difficulty level' and 'part', required for dicts update
    test_df = question_stats.get_lq_feats_from_curr_state(test_df)
    test_df = users_state.get_lq_feats_from_curr_state(test_df, lecture_stats, question_stats)
    
    # Save test_df for the next iteration
    # (copied before lectures are dropped, so it still has one row per revealed answer)
    previous_test_df = test_df.copy()
    
    #-- 2) Add feats from q&l to concate after lecture dropping
    test_l_df = users_state.get_l_feats_df_from_curr_state(test_df)

    #-- 3) Drop lectures
    test_df = test_df[test_df['content_type_id'] == 0].reset_index(drop=True)
    
    #-- 4) Add feats on questions only
    
    # Column-wise concat aligns on the index.
    # NOTE(review): assumes test_l_df shares the reset 0..n-1 index — confirm.
    test_df = pd.concat([test_df, test_l_df], axis=1)
    
    test_df = question_stats.get_q_feats_from_curr_state(test_df) 
    test_df = users_state.get_q_feats_from_curr_state(test_df)
    test_df = tags_stats.get_q_feats_from_curr_state(test_df, question_stats)
    
    test_df = inference_preprocessor.apply_online_feats_inference(test_df)
    
    #--------------------------------------------------------------
    # LGBM predictions
    #--------------------------------------------------------------
    
    lgbm_preds = model.predict(test_df[FEATS])
    
    
    #--------------------------------------------------------------
    # SAKT predictions
    #--------------------------------------------------------------
    # One dataset item per question row; shuffle=False keeps predictions in row order.
    test_dataset = TestDataset(group, test_df, n_skill, max_seq=MAX_SEQ)
    test_dataloader = DataLoader(test_dataset, batch_size=51200, shuffle=False)
    
    
    sakt_preds = []

    for item in test_dataloader:
        x = item[0].to(device).long()
        target_id = item[1].to(device).long()

        with torch.no_grad():
            output, att_weight = sakt_model(x, target_id)
        
        
        # Keep only the last position along axis 1 after the sigmoid.
        # NOTE(review): presumably that position is the prediction for the target question — confirm.
        output = torch.sigmoid(output)
        output = output[:, -1]

        sakt_preds.extend(output.view(-1).data.cpu().numpy()) 
    
    
    
    # LGBM + SAKT averaging
    # Fixed blend: 70% LGBM, 30% SAKT.
    sakt_preds = np.array(sakt_preds)
    ensemble_preds = (0.7*lgbm_preds+0.3*sakt_preds)
    
    
    #-- 5) Predict
    test_df[TARGET] = ensemble_preds
    set_predict(test_df[['row_id', TARGET]])




In [29]:
########################################################################################
# CROSS VALIDATION
########################################################################################


if validaten_flg:
    # Ground truth: labels of the question rows only (lecture rows are dropped again).
    is_question = target_df['content_type_id'] == 0
    y_true = target_df.loc[is_question, 'answered_correctly']
    # Predictions: stack the collected frames and keep the predicted column.
    y_pred = pd.concat(predicted)['answered_correctly']
    print(roc_auc_score(y_true, y_pred))



In [30]:
# Marker printed once the cells above have run through.
print('Run OK : Ready for submission')

Run OK : Ready for submission


In [31]:
# Inference features sample: first rows of the FEATS columns of `test_df`,
# which at this point holds the last batch processed by the inference loop.
test_df[FEATS].head()

Unnamed: 0,answered_correctly_avg_c,avg_score_q_level_u,avg_score_u_part,is_first_attempt,time_since_last,time_since_last_2,answered_correctly_std_c,current_session_q_count,current_session_time,question_elapsed_time,...,hist_1_right_answ_streak,part,q_explanation_sum_u,ratio_score_u_time_ratio,last_session_break_time,q_count_trainset_c,avg_right_answ_streak_hist_3,q_explanation_avg_u,avg_time_since_last_c,l_count_u
0,0.62789,0.0,0.666504,1,32799,18903,0.483368,3,75311,32799.0,...,0,5,0,726817.898074,0,5468,0.0,0.0,41233,0
1,0.733491,0.782715,0.671875,0,255228,97394,0.442133,14,1427667,255228.0,...,0,2,4069,1.696466,115960204,32767,2.333333,0.998047,37889,123
2,0.415218,0.571289,0.845703,1,35781230,92999,0.49276,0,0,7156246.0,...,0,7,2070,1.381715,35781230,3205,0.333333,0.98584,524954,27
3,0.864874,0.974609,0.845703,1,35781230,92999,0.341858,0,0,7156246.0,...,0,7,2070,1.381715,35781230,3205,0.333333,0.98584,524954,27
4,0.456215,0.571289,0.845703,1,35781230,92999,0.498079,0,0,7156246.0,...,0,7,2070,1.381715,35781230,3205,0.333333,0.98584,524954,27


In [32]:
# Ensemble predictions written to the TARGET column for the last processed batch
# (the whole Series is displayed, not a subset).
test_df[TARGET]

0     0.538380
1     0.663291
2     0.590066
3     0.973922
4     0.643600
5     0.534273
6     0.940147
7     0.814437
8     0.542967
9     0.853229
10    0.599507
11    0.692839
12    0.980183
13    0.942471
14    0.892039
15    0.857754
16    0.612839
17    0.769468
18    0.743086
19    0.735842
20    0.292633
21    0.483123
22    0.904868
23    0.776741
24    0.686324
25    0.760144
26    0.652089
27    0.784797
28    0.434220
29    0.820995
30    0.669553
31    0.651687
32    0.975113
Name: answered_correctly, dtype: float64