## Source Kernel
This kernel generates and submits predictions using the model and features developed in the kernel titled [RIIID: BigQuery-XGBoost End-to-End](https://www.kaggle.com/calebeverett/riiid-bigquery-xgboost-end-to-end).

In [None]:
import gc
import json
import time

import numpy as np
import pandas as pd
from pathlib import Path
from pprint import pprint
import sqlite3
from sklearn.metrics import roc_auc_score
from tqdm.notebook import trange, tqdm
import xgboost as xgb

import riiideducation

In [None]:
# Directory holding the artifacts loaded below (models, column/dtype lists, state pickles, weights).
PATH = Path('../input/riiid-submission-private')

In [None]:
class PredictEnv:
    """Local stand-in for the competition environment, replaying folds from SQLite.

    Records of the requested folds are copied from the database at
    ``folds_path`` into an in-memory SQLite database. ``iter_test`` serves
    them in groups, ``predict`` stores the submitted predictions and
    ``get_predictions`` scores them with ROC AUC.
    """

    def __init__(self, folds_path, folds):
        """Open the in-memory database and load the requested folds.

        Args:
            folds_path: Path of the SQLite file that exposes a ``train`` table.
            folds: Iterable of fold numbers to replay.
        """
        self.conn = sqlite3.connect(':memory:')
        self.c = self.conn.cursor()
        self.setup_folds(folds_path, folds)

    def setup_folds(self, folds_path, folds):
        """Build the ``b_records`` / ``b_users`` tables and reset the replay counters."""
        fold_list = ', '.join(map(str, folds))

        self.c.executescript(f"""
            ATTACH DATABASE "{folds_path}" AS folds_db;

            DROP TABLE IF EXISTS b_records;

            CREATE TABLE b_records AS
            SELECT row_id, timestamp, user_id, content_id, content_type_id, task_container_id, prior_question_elapsed_time,
                prior_question_had_explanation, answered_correctly, user_answer
            FROM folds_db.train
            WHERE fold in ({fold_list})
            ORDER BY user_id, task_container_id, row_id;

            CREATE INDEX user_id_task_container_id_index ON b_records (user_id, task_container_id);

            DROP TABLE IF EXISTS b_users;

            CREATE TABLE b_users AS
            SELECT user_id, MIN(task_container_id) - 1 task_container_id_next, MAX(task_container_id) task_container_id_max
            FROM b_records
            GROUP BY user_id
                ORDER BY user_id, task_container_id_next;

            CREATE UNIQUE INDEX user_id_index ON b_users (user_id);

            ALTER TABLE b_users
                ADD COLUMN group_num INTEGER;

        """)

        self.group_num = 0
        self.records_remaining = self.c.execute('SELECT COUNT(*) FROM b_records').fetchone()[0]
        self.df_users = pd.read_sql('SELECT * FROM b_users', self.conn)

    def iter_test(self):
        """Yield ``(df_b, df_p)`` pairs until every record has been served.

        ``df_b`` holds the records of the current group, indexed by
        ``group_num``, with ``answered_correctly`` and ``user_answer`` removed.
        The answers and responses of the previously yielded group are placed
        on its first row as list-formatted strings. ``df_p`` holds the
        ``row_id`` values with a placeholder prediction of 0.5.
        """
        next_correct = '[]'
        next_responses = '[]'

        while self.records_remaining:
            # Re-inserting an existing user_id always hits the unique index, so
            # this statement acts as an update: it advances the selected users
            # to their next task container and stamps them with the current
            # group. The LIMIT picks 1-40 users, plus up to 999 more when the
            # third random draw is below 5.
            self.c.execute(f"""
                INSERT INTO b_users (user_id)
                SELECT user_id
                FROM b_users
                WHERE task_container_id_next <= task_container_id_max
                LIMIT 1 + ABS(RANDOM() % 40) + ABS(RANDOM() % 1000) * (ABS(RANDOM() % 100) < 5)
                ON CONFLICT (user_id) DO UPDATE SET
                    task_container_id_next = task_container_id_next + 1,
                    group_num = {self.group_num};
            """)

            self.conn.commit()

            df_b = pd.read_sql(f"""
                SELECT r.*
                FROM b_records r
                JOIN b_users u
                ON group_num = {self.group_num}
                    AND r.user_id = u.user_id
                    AND r.task_container_id = u.task_container_id_next
            """, self.conn)

            # The join can come back empty (no record carries the advanced
            # container id); such groups are skipped without yielding.
            if len(df_b):
                df_b['group_num'] = self.group_num
                df_b['prior_group_answers_correct'] = None
                df_b.at[0, 'prior_group_answers_correct'] = next_correct

                df_b['prior_group_responses'] = None
                df_b.at[0, 'prior_group_responses'] = next_responses

                # Remember this group's labels for the next yielded group, then
                # hide them from the consumer.
                next_correct = f'[{(", ").join(df_b.answered_correctly.astype(str))}]'
                next_responses = f'[{(", ").join(df_b.user_answer.astype(str))}]'
                del df_b['answered_correctly']
                del df_b['user_answer']

                df_b = df_b.set_index('group_num')

                df_p = df_b[['row_id']].copy()
                df_p['answered_correctly'] = 0.5

                self.records_remaining -= len(df_b)

                yield df_b, df_p

            self.group_num += 1

    def predict(self, df_pred):
        """Append a batch of predictions to the ``predictions`` table.

        Args:
            df_pred: Frame with ``row_id`` and ``answered_correctly`` columns.

        Raises:
            RuntimeError: If any ``answered_correctly`` value equals -1.
        """
        if (df_pred.answered_correctly == -1).any():
            # A bare ``raise`` here had no active exception to re-raise and so
            # surfaced as an unrelated-looking RuntimeError; keep the type but
            # say what went wrong.
            raise RuntimeError('predictions contain answered_correctly == -1')

        df_pred.reset_index().to_sql('predictions', self.conn, if_exists='append', index=False)

    def get_predictions(self):
        """Join stored predictions with the true labels, print and keep the ROC AUC.

        Returns:
            Frame with ``row_id``, ``y_true`` and ``y_pred`` columns.
        """
        df_preds = pd.read_sql("""
            SELECT p.row_id, b.answered_correctly y_true, p.answered_correctly y_pred
            FROM predictions p
            JOIN b_records b
            ON p.row_id = b.row_id
        """, self.conn)

        self.score = roc_auc_score(df_preds.y_true, df_preds.y_pred)

        print(f'ROC AUC Score: {self.score:0.4f}')

        return df_preds

In [None]:
# Toggle between the local replay environment and the real competition API.
mock = False

if not mock:
    env = riiideducation.make_env()
else:
    FOLDS = Path('../input/riiid-folds/riiid.db')
    env = PredictEnv(FOLDS, [0])

# Both environments expose the same batch iterator.
iter_test = env.iter_test()

## Load Models

In [None]:
# models.json maps each model name to its settings; the fitted booster stored
# next to it as <name>.xgb is attached under the 'model' key.
with open(PATH/'models.json') as models_file:
    models = json.load(models_file)

pprint(models)

for name, spec in models.items():
    spec['model'] = xgb.Booster(model_file=PATH/f'{name}.xgb')

## Load State

In [None]:
# Columns kept from a batch so that state can be updated once its answers arrive.
batch_cols_prior = ['user_id', 'content_id', 'content_type_id', 'timestamp']

# Columns of the current batch that are passed into the state query.
batch_cols = [
    'user_id', 'content_id', 'row_id', 'timestamp',
    'prior_question_had_explanation', 'prior_question_elapsed_time',
]

with open(PATH/'columns.json') as columns_file:
    test_cols = json.load(columns_file)

with open(PATH/'dtypes.json') as dtypes_file:
    dtypes = json.load(dtypes_file)

print('Test Columns:\n')
_ = list(map(print, test_cols))

# Dtypes restricted to the model columns, with fixed id dtypes layered on top.
dtypes_test = {col: dtype for col, dtype in dtypes.items() if col in test_cols}
dtypes_test.update(user_id='int32', content_id='int16')

### Load Users-Content

In [None]:
# Per-user/content state from the pickle; it is copied into the SQLite users_content table below.
df_users_content = pd.read_pickle(PATH/'df_users_content.pkl')
df_users_content.head()

### Load Users-Part

In [None]:
# Per-user/part state from the pickle; it is copied into the SQLite users_part table below.
df_users_part = pd.read_pickle(PATH/'df_users_part.pkl')
df_users_part.head()

### Load Users-Tag

In [None]:
# Per-user/tag state from the pickle; it is copied into the SQLite users_tag table below.
df_users_tag = pd.read_pickle(PATH/'df_users_tag.pkl')
df_users_tag.head()

### Load Users

In [None]:
# Per-user state from the pickle; it is copied into the SQLite users table below.
df_users = pd.read_pickle(PATH/'df_users.pkl')
df_users.head()

### Load Questions
Question-related features, joined with the batches received from the competition API before making predictions.

In [None]:
# Question metadata and features; used for the lookup maps and the SQLite questions table below.
df_questions = pd.read_pickle(PATH/'df_questions.pkl')
df_questions.head()

In [None]:
# question_id -> part lookup, used when updating per-part state.
q_part_map = df_questions.set_index('question_id')['part'].sort_index().to_dict()

In [None]:
# question_id -> list of tags, parsed from the string stored in tags_array.
# NOTE(review): eval() executes whatever text the pickle contains; if the values are
# plain list literals, ast.literal_eval would be the safer parser — confirm the format first.
q_tags_array_map = df_questions[['question_id', 'tags_array']].set_index('question_id').tags_array.map(eval).to_dict()

In [None]:
def get_qt_tuples(r):
    """Return one ``(question_id, tag)`` pair for every tag of the question in row ``r``."""
    qid = r.question_id
    pairs = []
    for tag in q_tags_array_map[qid]:
        pairs.append((qid, tag))
    return pairs

In [None]:
# Long-format (question_id, tag) table with one row per tag of each question.
# The per-question lists are flattened in a single pass: sum(lists, []) copies
# the growing accumulator for every question, which is quadratic in row count.
df_qt = pd.DataFrame(
    [pair for pairs in df_questions.apply(get_qt_tuples, axis=1) for pair in pairs],
    columns=['question_id', 'tag'],
)

### Load Lectures

In [None]:
# Lecture metadata; used to build the lecture tag and part lookup maps below.
df_lectures = pd.read_pickle(PATH/'df_lectures.pkl')
df_lectures.head()

In [None]:
# lecture_id -> first element of the parsed tags_array (a single tag is kept per lecture).
# NOTE(review): eval() executes whatever text the pickle contains; if the values are
# plain list literals, ast.literal_eval would be the safer parser — confirm the format first.
l_tags_array_map = df_lectures[['lecture_id', 'tags_array']].set_index('lecture_id').tags_array.map(eval).map(lambda x: x[0]).to_dict()

In [None]:
# lecture_id -> part lookup, used when updating per-part lecture counts.
l_part_map = df_lectures.set_index('lecture_id')['part'].sort_index().to_dict()

### Load Weights

In [None]:
# Categorical columns that have an embedding and a bias in the weights file.
cat_cols = ['user_id', 'content_id']

# The .npy file stores a pickled dict, hence allow_pickle and .item().
weights = np.load(PATH/'weights_all.npy', allow_pickle=True).item()

for col in cat_cols:
    embed_key = f'{col}_embedding'
    # One row per known category plus one extra row.
    n_rows = len(weights[col]) + 1
    weights[embed_key] = np.reshape(weights[embed_key], (n_rows, -1))

weights.keys()

In [None]:
def get_preds(df, weights, logits=True):
    """Score user/content pairs with the embedding-plus-bias weights.

    Each id is mapped to its position in ``weights[<col>]``. Ids missing from
    that list get code -1 and therefore select the last embedding row and the
    last bias entry.

    Args:
        df: Frame containing the two columns named in ``cat_cols``.
        weights: Dict with the category lists, ``*_embedding`` and ``*_bias`` arrays.
        logits: Return raw logits when true, sigmoid probabilities otherwise.

    Returns:
        One value per row of ``df``.
    """
    user_col, content_col = cat_cols
    user_idx = pd.Categorical(df[user_col], categories=weights[user_col]).codes
    content_idx = pd.Categorical(df[content_col], categories=weights[content_col]).codes

    user_vecs = weights[f'{user_col}_embedding'][user_idx]
    content_vecs = weights[f'{content_col}_embedding'][content_idx]

    score = np.sum(user_vecs * content_vecs, axis=1)
    score += weights[f'{user_col}_bias'][user_idx] + weights[f'{content_col}_bias'][content_idx]

    if not logits:
        return 1 / (1 + np.exp(-score))
    return score

## Create Database

In [None]:
# In-memory SQLite database holding the feature state that is queried and updated during prediction.
conn = sqlite3.connect(':memory:')
c = conn.cursor()

### Create Users-Content Table

In [None]:
# Per-(user, content) counters. The stored generated column is the percentage
# ac_cumsum_content_id * 100 / r_cumcnt_content_id, or -1 when that expression is NULL.
c.execute("""
    CREATE TABLE users_content (
        user_id INTEGER,
        content_id INTEGER,
        ac_cumsum_content_id INTEGER,
        r_cumcnt_content_id INTEGER,
        aic_cumsum_content_id INTEGER,
        ac_cumsum_pct_content_id INTEGER GENERATED ALWAYS AS (IFNULL(ROUND(ac_cumsum_content_id * 100 / r_cumcnt_content_id), -1)) STORED
    )
""").fetchone()

In [None]:
%%time

chunk_size = 20000
total = len(df_users_content)
n_chunks = total // chunk_size + 1

# Append the frame slice by slice so that each multi-row INSERT stays bounded.
for i in trange(n_chunks):
    start = i * chunk_size
    chunk = df_users_content.iloc[start:start + chunk_size]
    chunk.to_sql('users_content', conn, method='multi', if_exists='append', index=False)

# Unique key for the lookups and the ON CONFLICT upserts; the frame is no longer needed.
c.execute('CREATE UNIQUE INDEX users_content_index ON users_content (user_id, content_id)')
del df_users_content
gc.collect()

In [None]:
%%time
# Quick look at the loaded table, including the generated percentage column.
pd.read_sql('SELECT * from users_content LIMIT 5', conn)

### Create Users-Part Table

In [None]:
# Per-(user, part) counters. The stored generated column is the percentage
# ac_cumsum_part * 100 / r_cumcnt_part, or -1 when that expression is NULL.
c.execute("""
    CREATE TABLE users_part (
        user_id INTEGER,
        part INTEGER,
        ac_cumsum_part INTEGER,
        r_cumcnt_part INTEGER,
        aic_cumsum_part INTEGER,
        lectures_cumcnt_part INTEGER,
        ac_cumsum_pct_part INTEGER GENERATED ALWAYS AS (IFNULL(ROUND(ac_cumsum_part * 100 / r_cumcnt_part), -1)) STORED
    )
""").fetchone()

In [None]:
%%time

chunk_size = 20000
total = len(df_users_part)
n_chunks = total // chunk_size + 1

# Append the frame slice by slice so that each multi-row INSERT stays bounded.
for i in trange(n_chunks):
    start = i * chunk_size
    chunk = df_users_part.iloc[start:start + chunk_size]
    chunk.to_sql('users_part', conn, method='multi', if_exists='append', index=False)

# Unique key for the lookups and the ON CONFLICT upserts; the frame is no longer needed.
c.execute('CREATE UNIQUE INDEX users_part_index ON users_part (user_id, part)')
del df_users_part
gc.collect()

In [None]:
%%time
# Quick look at the loaded table, including the generated percentage column.
pd.read_sql('SELECT * from users_part LIMIT 5', conn)

### Create Users-Tag Table

In [None]:
# Per-(user, tag) counters. The stored generated column is the percentage
# ac_cumsum_tag * 100 / r_cumcnt_tag, or -1 when that expression is NULL.
c.executescript("""
    CREATE TABLE users_tag (
        user_id INTEGER,
        tag INTEGER,
        ac_cumsum_tag INTEGER,
        r_cumcnt_tag INTEGER,
        aic_cumsum_tag INTEGER,
        lectures_cumcnt_tag INTEGER,
        ac_cumsum_pct_tag INTEGER GENERATED ALWAYS AS (IFNULL(ROUND(ac_cumsum_tag * 100 / r_cumcnt_tag), -1)) STORED
    )
""").fetchone()

In [None]:
%%time

chunk_size = 20000
total = len(df_users_tag)
n_chunks = total // chunk_size + 1

# Append the frame slice by slice so that each multi-row INSERT stays bounded.
for i in trange(n_chunks):
    start = i * chunk_size
    chunk = df_users_tag.iloc[start:start + chunk_size]
    chunk.to_sql('users_tag', conn, method='multi', if_exists='append', index=False)

# Unique key for the lookups and the ON CONFLICT upserts; the frame is no longer needed.
c.execute('CREATE UNIQUE INDEX users_tag_index ON users_tag (user_id, tag)')
del df_users_tag
gc.collect()

In [None]:
%%time
# Quick look at the loaded table, including the generated percentage column.
pd.read_sql('SELECT * from users_tag LIMIT 5', conn)

### Create Users Table

In [None]:
# Per-user counters in three flavours: overall, "*_upto", and per-session, plus the
# session number and a timestamp. Each stored generated *_pct column is the matching
# ac * 100 / r_cumcnt ratio, or -1 when that expression is NULL.
c.executescript("""
    DROP TABLE IF EXISTS users;

    CREATE TABLE users (
        user_id INTEGER,
        ac_cumsum INTEGER,
        aic_cumsum INTEGER,
        r_cumcnt INTEGER,
        lectures_cumcnt INTEGER,
        ac_cumsum_upto INTEGER,
        aic_cumsum_upto INTEGER,
        r_cumcnt_upto INTEGER,
        ac_cumsum_session INTEGER,
        aic_cumsum_session INTEGER,
        r_cumcnt_session INTEGER,
        lectures_cumcnt_session INTEGER,
        session INTEGER,
        timestamp INTEGER,
        ac_cumsum_pct INTEGER GENERATED ALWAYS AS (IFNULL(ROUND(ac_cumsum * 100 / r_cumcnt), -1)) STORED,
        ac_cumsum_pct_upto INTEGER GENERATED ALWAYS AS (IFNULL(ROUND(ac_cumsum_upto * 100 / r_cumcnt_upto), -1)) STORED,
        ac_cumsum_pct_session INTEGER GENERATED ALWAYS AS (IFNULL(ROUND(ac_cumsum_session * 100 / r_cumcnt_session), -1)) STORED
    )
""").fetchone()

In [None]:
%%time

chunk_size = 10000
total = len(df_users)
n_chunks = total // chunk_size + 1

# Append the frame slice by slice so that each multi-row INSERT stays bounded.
for i in range(n_chunks):
    start = i * chunk_size
    chunk = df_users.iloc[start:start + chunk_size]
    chunk.to_sql('users', conn, method='multi', if_exists='append', index=False)

# Unique key for the lookups and the ON CONFLICT upserts; the frame is no longer needed.
c.execute('CREATE UNIQUE INDEX users_index ON users (user_id)').fetchone()
del df_users
gc.collect()

In [None]:
%%time
# Quick look at the loaded table, including the generated percentage columns.
pd.read_sql('SELECT * from users LIMIT 5', conn)

### Create Questions Table

In [None]:
%%time

# Question columns copied into SQLite; select_state appends the same list to its SELECT.
q_cols = [
    'question_id',
    'part',
    'tag_0',
    'part_correct_pct',
    'tag_0_correct_pct',
    'question_id_correct_pct',
    'tags_correct_pct',
    'tags_code',
    'tag_0_part_correct_pct',
    'question_id_pqet_avg',
    'tags_pqet_avg',
    'tag_0_pqet_avg',
    'part_pqet_avg'
]

# to_sql creates the table here (no explicit CREATE TABLE for questions).
df_questions[q_cols].to_sql('questions', conn, method='multi', index=False)
c.execute('CREATE UNIQUE INDEX question_id_index ON questions (question_id)').fetchone()
# The frame is not needed once the table and the lookup maps exist.
del df_questions
gc.collect()

In [None]:
%%time
# Quick look at the loaded questions table.
pd.read_sql('SELECT * from questions LIMIT 5', conn)

### Create Questions-Tags Table

In [None]:
# Long-format question/tag pairs; select_state joins them to sum per-tag user state.
df_qt.to_sql('questions_tags', conn, method='multi', index=False)
# The unique index requires that no question lists the same tag twice.
c.execute('CREATE UNIQUE INDEX question_id_tag_index ON questions_tags (question_id, tag)').fetchone()
del df_qt
gc.collect()

In [None]:
# Database size in bytes = number of pages * page size, both read from SQLite pragmas.
db_size = pd.read_sql('SELECT page_count * page_size as size FROM pragma_page_count(), pragma_page_size()', conn)['size'][0]
print(f'Total size of database is: {db_size/1e9:0.3f} GB')

## Predict

### Get State

In [None]:
# with tags_cumsum
def select_state(batch_cols, records, session_hours=72):
    """Build the SELECT that joins a batch with the stored state.

    The batch is inlined as a ``VALUES`` list and left-joined with the users,
    users_content, users_tag, users_part and questions tables. Missing state
    falls back to 0 for counters and -1 for percentages.

    Args:
        batch_cols: Column names of the batch, in the order used by ``records``.
        records: Rows of the batch; ``str()`` of each row must be a valid SQL
            row value such as ``(1, 2, 3)``. Must not be empty, because an
            empty ``VALUES`` list is not valid SQL.
        session_hours: Gap in hours after which a new session is assumed.

    Returns:
        The SQL text of the query.
    """
    cols_string = ', '.join(batch_cols)
    values_string = ', '.join(str(rec) for rec in records)
    batch_select = ', '.join(f'b.{col}' for col in batch_cols)
    question_select = ', '.join(f'q.{col}' for col in q_cols)
    # True when the batch row lies more than session_hours after the stored
    # timestamp; session-level values are then reported as freshly reset.
    new_session = f'b.timestamp - u.timestamp > (1000 * 60 * 60 * {session_hours})'

    return f"""
        
        WITH b ({cols_string}) AS (
            VALUES {values_string}
        )
        SELECT
            {batch_select},
            CAST(b.timestamp / 60000 AS INTEGER) ts_minute,
            CAST(b.prior_question_elapsed_time / 1000 AS INTEGER) pqet_sec,
            IFNULL(IIF(u.r_cumcnt > 300, 300, u.r_cumcnt), 0) r_cumcnt_clip,
            IFNULL(IIF({new_session}, u.session + 1, u.session), 0) session,
            IFNULL(IIF({new_session}, 0, u.ac_cumsum_session), 0) ac_cumsum_session,
            IFNULL(IIF({new_session}, 0, u.aic_cumsum_session), 0) aic_cumsum_session,
            IFNULL(IIF({new_session}, 0, u.r_cumcnt_session), 0) r_cumcnt_session,
            IFNULL(IIF({new_session}, 0, u.lectures_cumcnt_session), 0) lectures_cumcnt_session,
            IFNULL(IIF({new_session}, -1, u.ac_cumsum_pct_session), -1) ac_cumsum_pct_session,
            IFNULL(u.ac_cumsum, 0) ac_cumsum, 
            IFNULL(u.aic_cumsum, 0) aic_cumsum,
            IFNULL(u.ac_cumsum_pct, -1) ac_cumsum_pct,
            IFNULL(u.r_cumcnt, 0) r_cumcnt,
            IFNULL(u.lectures_cumcnt, 0) lectures_cumcnt,
            IFNULL(u.ac_cumsum_upto, 0) ac_cumsum_upto, 
            IFNULL(u.aic_cumsum_upto, 0) aic_cumsum_upto,
            IFNULL(u.ac_cumsum_pct_upto, -1) ac_cumsum_pct_upto,
            IFNULL(u.r_cumcnt_upto, 0) r_cumcnt_upto,
            IFNULL(uc.ac_cumsum_content_id, 0) ac_cumsum_content_id, 
            IFNULL(uc.aic_cumsum_content_id, 0) aic_cumsum_content_id,
            IFNULL(uc.ac_cumsum_pct_content_id, -1) ac_cumsum_pct_content_id,
            IFNULL(uc.r_cumcnt_content_id, 0) r_cumcnt_content_id,
            IFNULL(ut.ac_cumsum_tag, 0) ac_cumsum_tag_0, 
            IFNULL(ut.aic_cumsum_tag, 0) aic_cumsum_tag_0,
            IFNULL(ut.ac_cumsum_pct_tag, -1) ac_cumsum_pct_tag_0,
            IFNULL(ut.r_cumcnt_tag, 0) r_cumcnt_tag_0,
            IFNULL(ut.lectures_cumcnt_tag, 0) lectures_cumcnt_tag_0,
            IFNULL(ut3.ac_cumsum_tags, 0) ac_cumsum_tags, 
            IFNULL(ut3.aic_cumsum_tags, 0) aic_cumsum_tags,
            IFNULL(ut3.r_cumcnt_tags, 0) r_cumcnt_tags,
            IFNULL(ut3.lectures_cumcnt_tags, 0) lectures_cumcnt_tags,
            IFNULL(CAST(ut3.ac_cumsum_tags * 100 / ut3.r_cumcnt_tags AS INTEGER), -1) ac_cumsum_pct_tags,
            IFNULL(up.ac_cumsum_part, 0) ac_cumsum_part, 
            IFNULL(up.aic_cumsum_part, 0) aic_cumsum_part,
            IFNULL(up.ac_cumsum_pct_part, -1) ac_cumsum_pct_part,
            IFNULL(up.r_cumcnt_part, 0) r_cumcnt_part,
            IFNULL(up.lectures_cumcnt_part, 0) lectures_cumcnt_part,
            {question_select}
        FROM b
        LEFT JOIN (
            WITH bt ({cols_string}) AS (
            VALUES {values_string}
            )
            SELECT
                bt.user_id, bt.content_id,
                    SUM(ut.ac_cumsum_tag) ac_cumsum_tags,
                    SUM(ut.aic_cumsum_tag) aic_cumsum_tags,
                    SUM(ut.r_cumcnt_tag) r_cumcnt_tags,
                    SUM(ut.lectures_cumcnt_tag) lectures_cumcnt_tags
            FROM bt
            LEFT JOIN questions_tags qt
                ON qt.question_id = bt.content_id
            LEFT JOIN users_tag ut
                ON ut.tag = qt.tag AND ut.user_id = bt.user_id
            GROUP BY
                bt.user_id, bt.content_id
        ) ut3
            ON ut3.user_id = b.user_id AND ut3.content_id = b.content_id
        LEFT JOIN users u
            ON u.user_id = b.user_id
        LEFT JOIN users_content uc
            ON uc.user_id = b.user_id AND uc.content_id = b.content_id
        LEFT JOIN questions q
            ON q.question_id = b.content_id
        LEFT JOIN users_tag ut
            ON ut.user_id = b.user_id AND ut.tag = q.tag_0
        LEFT JOIN users_part up
            ON up.user_id = b.user_id AND up.part = q.part
    """

### Update State

In [None]:
def update_state(df, session_hours=72):
    """Build the SQL script that folds answered questions into the stored state.

    One upsert statement is produced for each of the users, users_content,
    users_tag and users_part tables. Rows for unseen keys are inserted; on a
    key conflict the existing counters are incremented.

    Args:
        df: Question rows with ``user_id``, ``content_id``, ``timestamp`` and
            ``answered_correctly`` columns. Must not be empty.
        session_hours: Gap in hours after which a user's session counters restart.

    Returns:
        The SQL script as a string, intended for ``executescript``.
    """
    # True when the incoming row lies more than session_hours after the stored timestamp.
    new_session = f'excluded.timestamp - timestamp > (1000 * 60 * 60 * {session_hours})'

    def get_update_values(r):
        # ok / miss are the increments for the correct / incorrect counters.
        ok = r.answered_correctly
        miss = 1 - ok
        values_u = f"""({r.user_id}, {ok}, {ok}, {miss}, {miss}, 1, 1, 0,
                    {ok}, {miss}, 1, 0, {r.timestamp}, 0)"""
        values_uc = f'({r.user_id}, {r.content_id}, {ok}, {miss}, 1)'
        values_ut = [f'({r.user_id}, {tag}, {ok}, {miss}, 1, 0)' for tag in q_tags_array_map[r.content_id]]
        values_up = f'({r.user_id}, {q_part_map[r.content_id]}, {ok}, {miss}, 1, 0)'
        return values_u, values_uc, values_ut, values_up

    values = df.apply(get_update_values, axis=1, result_type='expand')

    user_rows = ','.join(values[0])
    content_rows = ','.join(values[1])
    # Column 2 holds one list of row values per question; flatten before joining.
    tag_rows = ','.join(sum(values[2], []))
    part_rows = ','.join(values[3])

    # NOTE(review): lectures_cumcnt_session is not touched by the users upsert,
    # so it is not reset when a question opens a new session — confirm this
    # matches how the training features were built.
    return f"""
        INSERT INTO users (user_id, ac_cumsum, ac_cumsum_upto, aic_cumsum, aic_cumsum_upto, r_cumcnt, r_cumcnt_upto, lectures_cumcnt,
            ac_cumsum_session, aic_cumsum_session, r_cumcnt_session, lectures_cumcnt_session, timestamp, session)
        VALUES {user_rows}
        ON CONFLICT (user_id) DO UPDATE SET
            ac_cumsum = ac_cumsum + excluded.ac_cumsum,
            aic_cumsum = aic_cumsum + excluded.aic_cumsum,
            r_cumcnt = r_cumcnt + excluded.r_cumcnt,
            ac_cumsum_upto = IIF(r_cumcnt_upto < 10, ac_cumsum_upto + excluded.ac_cumsum, ac_cumsum_upto),
            aic_cumsum_upto = IIF(r_cumcnt_upto < 10, aic_cumsum_upto + excluded.aic_cumsum, aic_cumsum_upto),
            r_cumcnt_upto = IIF(r_cumcnt_upto < 10, r_cumcnt_upto + 1, r_cumcnt_upto),
            session = IIF({new_session}, session + 1, session),
            ac_cumsum_session = IIF({new_session}, excluded.ac_cumsum_session, ac_cumsum_session + excluded.ac_cumsum_session),
            aic_cumsum_session = IIF({new_session}, excluded.aic_cumsum_session, aic_cumsum_session + excluded.aic_cumsum_session),
            r_cumcnt_session = IIF({new_session}, excluded.r_cumcnt_session, r_cumcnt_session + excluded.r_cumcnt_session),
            timestamp = excluded.timestamp;
        
        INSERT INTO users_content (user_id, content_id, ac_cumsum_content_id, aic_cumsum_content_id, r_cumcnt_content_id)
        VALUES {content_rows}
        ON CONFLICT (user_id, content_id) DO UPDATE SET
            ac_cumsum_content_id = ac_cumsum_content_id + excluded.ac_cumsum_content_id,
            aic_cumsum_content_id = aic_cumsum_content_id + excluded.aic_cumsum_content_id,
            r_cumcnt_content_id = r_cumcnt_content_id + 1;
        
        INSERT INTO users_tag (user_id, tag, ac_cumsum_tag, aic_cumsum_tag, r_cumcnt_tag, lectures_cumcnt_tag)
        VALUES {tag_rows}
        ON CONFLICT (user_id, tag) DO UPDATE SET
            ac_cumsum_tag = ac_cumsum_tag + excluded.ac_cumsum_tag,
            aic_cumsum_tag = aic_cumsum_tag + excluded.aic_cumsum_tag,
            r_cumcnt_tag = r_cumcnt_tag + 1;
        
        INSERT INTO users_part (user_id, part, ac_cumsum_part, aic_cumsum_part, r_cumcnt_part, lectures_cumcnt_part)
        VALUES {part_rows}
        ON CONFLICT (user_id, part) DO UPDATE SET
            ac_cumsum_part = ac_cumsum_part + excluded.ac_cumsum_part,
            aic_cumsum_part = aic_cumsum_part + excluded.aic_cumsum_part,
            r_cumcnt_part = r_cumcnt_part + 1; 
    """

In [None]:
def update_state_l(df, session_hours=72):
    """Build the SQL script that folds watched lectures into the stored state.

    One upsert statement is produced for each of the users, users_tag and
    users_part tables. Rows for unseen keys are inserted with zeroed question
    counters; on a key conflict only the lecture counters are incremented.

    Args:
        df: Lecture rows with ``user_id``, ``content_id`` and ``timestamp``
            columns. Must not be empty.
        session_hours: Gap in hours after which the per-session lecture count restarts.

    Returns:
        The SQL script as a string, intended for ``executescript``.
    """

    def get_update_values(r):
        # The trailing 0 is the initial session number. It matches what
        # update_state inserts; leaving the column NULL would keep the session
        # at NULL forever, because NULL + 1 is still NULL.
        values_u = f'({r.user_id}, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, {r.timestamp}, 0)'
        values_ut = f'({r.user_id}, {l_tags_array_map[r.content_id]}, 1, 0, 0, 0)'
        values_up = f'({r.user_id}, {l_part_map[r.content_id]}, 1, 0, 0, 0)'
        return values_u, values_ut, values_up

    values = df.apply(get_update_values, axis=1, result_type='expand')

    # On conflict the stored timestamp and session are deliberately left as
    # they are; only the lecture counters change.
    return f"""
        INSERT INTO users (user_id, lectures_cumcnt, ac_cumsum, aic_cumsum,
            ac_cumsum_upto, aic_cumsum_upto, r_cumcnt, r_cumcnt_upto,
            ac_cumsum_session, aic_cumsum_session, r_cumcnt_session, lectures_cumcnt_session, timestamp, session)
        VALUES {(',').join(values[0])}
        ON CONFLICT (user_id) DO UPDATE SET
            lectures_cumcnt = lectures_cumcnt + excluded.lectures_cumcnt,
            lectures_cumcnt_session = IIF(excluded.timestamp - timestamp > (1000 * 60 * 60 * {session_hours}),
                                            excluded.lectures_cumcnt_session,
                                            lectures_cumcnt_session + excluded.lectures_cumcnt_session);
                        
        INSERT INTO users_tag (user_id, tag, lectures_cumcnt_tag, ac_cumsum_tag,
            aic_cumsum_tag, r_cumcnt_tag)
        VALUES {(',').join(values[1])}
        ON CONFLICT (user_id, tag) DO UPDATE SET
            lectures_cumcnt_tag = lectures_cumcnt_tag + 1;
            
        INSERT INTO users_part (user_id, part, lectures_cumcnt_part, ac_cumsum_part,
            aic_cumsum_part, r_cumcnt_part)
        VALUES {(',').join(values[2])}
        ON CONFLICT (user_id, part) DO UPDATE SET
            lectures_cumcnt_part = lectures_cumcnt_part + excluded.lectures_cumcnt_part;

    """

In [None]:
# Main inference loop. For every batch: fold the previous batch's revealed
# answers into the SQLite state, query the features for the current questions,
# average the model predictions and submit them.
if mock:
    pbar = tqdm(total=env.records_remaining)

df_batch_prior = None
counter = 0

for test_batch in iter_test:
    counter += 1
    test_df = test_batch[0]

    # update state
    if df_batch_prior is not None:
        # The first row carries the previous batch's answers as a list-formatted
        # string. It is parsed as JSON rather than passed to eval(), so text
        # coming from the API is never executed.
        answers = json.loads(test_df['prior_group_answers_correct'].iloc[0])
        df_batch_prior['answered_correctly'] = answers

        df_batch_prior_q = df_batch_prior[df_batch_prior.content_type_id == 0]
        if len(df_batch_prior_q):
            sql = update_state(df_batch_prior_q)
            c.executescript(sql)

        df_batch_prior_l = df_batch_prior[df_batch_prior.content_type_id == 1]
        if len(df_batch_prior_l):
            sql = update_state_l(df_batch_prior_l)
            c.executescript(sql)

        if not counter % 100:
            conn.commit()

    # save prior batch for state update
    df_batch_prior = test_df[batch_cols_prior].astype({k: dtypes[k] for k in batch_cols_prior})

    # get state (questions only; lectures are never scored)
    df_batch = test_df[test_df.content_type_id == 0]
    if len(df_batch):
        records = df_batch[batch_cols].fillna(0).to_records(index=False)
        df_batch = pd.read_sql(select_state(batch_cols, records), conn)
        df_batch['pred_collab_logit'] = get_preds(df_batch, weights)

        # predict: plain average over all loaded boosters
        features = xgb.DMatrix(df_batch[test_cols])
        preds_list = [models[m]['model'].predict(features) for m in models]
        df_batch['answered_correctly'] = np.average(np.stack(preds_list), axis=0)

        # submit
        env.predict(df_batch[['row_id', 'answered_correctly']])
    else:
        # NOTE(review): the competition API is understood to require one
        # predict() call per batch before it serves the next one — confirm.
        # A lecture-only batch therefore submits an empty, correctly typed frame.
        env.predict(pd.DataFrame({
            'row_id': pd.Series(dtype='int64'),
            'answered_correctly': pd.Series(dtype='float64'),
        }))

    if mock:
        pbar.update(len(test_df))

if mock:
    pbar.close()

In [None]:
# Only the local replay environment keeps the true labels needed for scoring.
if mock:
    df_pred = env.get_predictions()

## Tests

In [None]:
# Consistency checks for select_state / update_state / update_state_l. They are
# disabled by default and consume batches from iter_test, so they are meant to
# be run instead of the prediction loop.
if False:

    def get_test_batches():
        return next(iter_test)[0], next(iter_test)[0]

    df_batch_prior_test, df_batch_test = get_test_batches()

    for _ in trange(100):
        # iterate through test batches until on one that has lectures, is less than 20 records and has at least one user with multiple questions
        while df_batch_prior_test.content_type_id.sum() == 0 or len(df_batch_prior_test) > 20 or df_batch_prior_test.user_id.value_counts().max() == 1:
            df_batch_prior_test, df_batch_test = get_test_batches()

        # get state before updating
        records_test = df_batch_prior_test[df_batch_prior_test.content_type_id == 0][batch_cols].fillna(0).to_records(index=False)
        df_batch_prior_test_state = pd.read_sql(select_state(batch_cols, records_test), conn)

        # update state from questions (answers are parsed as JSON, not eval'd)
        answers_test = json.loads(df_batch_test['prior_group_answers_correct'].iloc[0])
        df_batch_prior_test['answered_correctly'] = answers_test
        df_batch_prior_q_test = df_batch_prior_test[df_batch_prior_test.content_type_id == 0]
        sql = update_state(df_batch_prior_q_test)
        c.executescript(sql).fetchone()

        # update state from lectures
        df_batch_prior_l_test = df_batch_prior_test[df_batch_prior_test.content_type_id == 1]
        sql = update_state_l(df_batch_prior_l_test)
        c.executescript(sql).fetchone()

        # get state after update
        df_batch_prior_test_state_updated = pd.read_sql(select_state(batch_cols, records_test), conn)

        # tests, yo!
        assert len(answers_test) == len(df_batch_prior_test), \
                'len of answers from current not equal to len of prior'

        assert (df_batch_prior_test_state.timestamp.values
                == df_batch_prior_test[df_batch_prior_test.content_type_id == 0].timestamp.values).all(), \
                'timestamps returned from get_state should should be from batch'

        # The timestamp would be greater when comparing users from consecutive
        # batches, but the same batch is compared here. It is unchanged because
        # the database was updated with the value from the current batch.
        assert (df_batch_prior_test_state_updated.timestamp == df_batch_prior_test_state.timestamp).all(), \
                'timestamps should all be the same after update'

        assert (df_batch_prior_test_state.pqet_sec
                == (df_batch_prior_test_state.prior_question_elapsed_time / 1000).astype(int)).all(), \
                'pqet_sec = prior_question_elapsed_time / 1000'

        assert (df_batch_prior_test_state.ts_minute
                == (df_batch_prior_test_state.timestamp / (1000 * 60)).astype(int)).all(), \
                'ts_minute = timestamp / (1000 * 60)'

        r_cnt = df_batch_prior_test_state[['user_id', 'content_id']].groupby('user_id').count()
        assert ((df_batch_prior_test_state['r_cumcnt'].add(df_batch_prior_test_state.user_id.map(r_cnt.content_id), axis='rows')
                 == df_batch_prior_test_state_updated['r_cumcnt']).all()), \
                'r_cumcnt for each user should increment by number of questions in batch for each user'

        l_cnt = df_batch_prior_test[['user_id', 'content_type_id']].groupby('user_id').sum()
        assert ((df_batch_prior_test_state['lectures_cumcnt'].add(df_batch_prior_test_state.user_id.map(l_cnt.content_type_id), axis='rows')
                 == df_batch_prior_test_state_updated['lectures_cumcnt']).all()), \
                'lectures_cumcnt for each user should increment by number of lectures in batch for each user'

        assert ((df_batch_prior_test_state['r_cumcnt_clip']
                 .add(df_batch_prior_test_state.user_id.map(r_cnt.content_id), axis='rows')
                 .apply(lambda x: min(x, 300))
                 == df_batch_prior_test_state_updated['r_cumcnt_clip']).all()), \
                'r_cumcnt_clip should increment by number of questions in batch for each user with max of 300'

        ac_sum = df_batch_prior_test[['user_id', 'answered_correctly']].groupby('user_id').sum()
        assert ((df_batch_prior_test_state['ac_cumsum'].add(df_batch_prior_test_state.user_id.map(ac_sum.answered_correctly), axis='rows')
                 == df_batch_prior_test_state_updated['ac_cumsum']).all()), \
                'ac_cumsum for each user should increment by number of questions answered_correctly in batch for each user'

        aic_sum = df_batch_prior_test[['user_id', 'answered_correctly']].copy()
        aic_sum['answered_incorrectly'] = (aic_sum.answered_correctly == 0).astype(int)
        aic_sum = aic_sum.groupby('user_id').sum()
        assert ((df_batch_prior_test_state['aic_cumsum'].add(df_batch_prior_test_state.user_id.map(aic_sum.answered_incorrectly), axis='rows')
                 == df_batch_prior_test_state_updated['aic_cumsum']).all()), \
                'aic_cumsum for each user should increment by number of questions answered_incorrectly in batch for each user'

        assert (df_batch_prior_test_state_updated.session == df_batch_prior_test_state.session).all(), \
                'session should be the same if the change in timestamp is less than session_hours'

        assert (df_batch_prior_test_state_updated.ac_cumsum_session
                == df_batch_prior_test_state.ac_cumsum_session
                    .add(df_batch_prior_test_state.user_id.map(ac_sum.answered_correctly), axis='rows')).all(), \
                'ac_cumsum_session should increment if timestamp is less than session_hours'

        # Shift the batch 73 hours ahead so that every row crosses the default
        # 72-hour session gap; one query serves all session-reset checks.
        records_test_add_time = [r.copy() for r in records_test]
        for r in records_test_add_time:
            r.timestamp += (1000 * 60 * 60 * 73)
        df_state_add_time = pd.read_sql(select_state(batch_cols, records_test_add_time), conn)

        assert (df_state_add_time.session
                == df_batch_prior_test_state.session + 1).all(), \
                'session increments if change in timestamp delta is greater than session_hours'

        # These compare against 0 explicitly: `~total` is bitwise NOT, which is
        # truthy for a zero total as well (~0 == -1), so it could never fail.
        assert df_state_add_time.ac_cumsum_session.sum() == 0, \
                'ac_cumsum_session should be zero if timestamp delta is greater than session_hours'

        assert df_state_add_time.aic_cumsum_session.sum() == 0, \
                'aic_cumsum_session should be zero if timestamp delta is greater than session_hours'

        assert df_state_add_time.r_cumcnt_session.sum() == 0, \
                'r_cumcnt_session should be zero if timestamp delta is greater than session_hours'

        assert df_state_add_time.lectures_cumcnt_session.sum() == 0, \
                'lectures_cumcnt_session should be zero if timestamp delta is greater than session_hours'

        ac_sum_part = df_batch_prior_test[['user_id', 'content_id', 'answered_correctly']][df_batch_prior_test.content_type_id == 0]
        ac_sum_part['part'] = ac_sum_part.content_id.map(q_part_map)
        ac_sum_part = ac_sum_part.groupby(['user_id', 'part'])['answered_correctly'].sum()
        assert (df_batch_prior_test_state.set_index(['user_id', 'part']).ac_cumsum_part.add(ac_sum_part).reset_index()[0]
                == df_batch_prior_test_state_updated['ac_cumsum_part']).all(), \
                'ac_cumsum_part for each user should increment by number of questions answered_correctly in batch for each user for each part'

        aic_sum_part = df_batch_prior_test[['user_id', 'content_id', 'answered_correctly']][df_batch_prior_test.content_type_id == 0].copy()
        aic_sum_part['answered_incorrectly'] = (aic_sum_part.answered_correctly == 0).astype(int)
        aic_sum_part['part'] = aic_sum_part.content_id.map(q_part_map)
        aic_sum_part = aic_sum_part.groupby(['user_id', 'part'])['answered_incorrectly'].sum()
        assert (df_batch_prior_test_state.set_index(['user_id', 'part']).aic_cumsum_part.add(aic_sum_part).reset_index()[0]
                == df_batch_prior_test_state_updated['aic_cumsum_part']).all(), \
                'aic_cumsum_part for each user should increment by number of questions answered incorrectly in batch for each user for each part'

        user_tag_tuples = sum(df_batch_prior_test[df_batch_prior_test.content_type_id == 0]
                          .apply(lambda r: [(r.user_id, t, r.answered_correctly) for
                          t in q_tags_array_map[r.content_id]], axis=1), [])
        r_cnt_tag = pd.DataFrame(user_tag_tuples, columns=['user_id', 'tag_0', 'answered_correctly'])
        r_cnt_tag = r_cnt_tag.groupby(['user_id', 'tag_0']).count()
        assert ((df_batch_prior_test_state.set_index(['user_id', 'tag_0'])
                 .join(r_cnt_tag)['answered_correctly'].sort_index()
                 + df_batch_prior_test_state.set_index(['user_id', 'tag_0']).r_cumcnt_tag_0.sort_index())
                == df_batch_prior_test_state_updated.set_index(['user_id', 'tag_0']).r_cumcnt_tag_0.sort_index()).all(), \
                'r_cumcnt_tag_0 is incremented for each user for each tag in the batch'
        
        ac_sum_tag = pd.DataFrame(user_tag_tuples, columns=['user_id', 'tag_0', 'answered_correctly'])
        ac_sum_tag = ac_sum_tag.groupby(['user_id', 'tag_0']).sum()
        assert ((df_batch_prior_test_state.set_index(['user_id', 'tag_0'])
                 .join(ac_sum_tag)['answered_correctly'].sort_index()
                 + df_batch_prior_test_state.set_index(['user_id', 'tag_0']).ac_cumsum_tag_0.sort_index())
                == df_batch_prior_test_state_updated.set_index(['user_id', 'tag_0']).ac_cumsum_tag_0.sort_index()).all(), \
                'ac_cumsum_tag_0 is incremented for each user for each tagged question answered_correctly in the batch'