# RIIID - SAINT Model (Encoder only) Inference

In [1]:
import riiideducation
import math
import torch
import time

import pandas as pd
import numpy as np
import torch.nn as nn 
import torch.nn.functional as F

from tqdm.notebook import tqdm
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from sklearn.metrics import roc_auc_score
from collections import deque, defaultdict

In [2]:
env = riiideducation.make_env()
iter_test = env.iter_test()

In [3]:
AMOUNT = 100
PAD = 0
BATCH_SIZE = 100

In [4]:
QUESTIONS_PATH = '../input/riiid-test-answer-prediction/questions.csv'
df_questions = pd.read_csv(QUESTIONS_PATH)

In [5]:
part_ids_map = dict(zip(df_questions.question_id, df_questions.part))

## Data Processing

In [6]:
def dataset_transform(dataset):
  final_dataset = {}
  user_id_to_idx = {}
  grp = dataset.groupby('user_id').tail(AMOUNT)
  
  for idx, row in tqdm(grp.groupby("user_id").agg({"content_id":list, 
                "answered_correctly":list, 
                "task_container_id":list, 
                "part_id":list, 
                "prior_question_elapsed_time":list}).reset_index().iterrows()):

    # pad the required rows to have AMOUNT values
    if (len(row['content_id']) >= AMOUNT):
      final_dataset[idx] = {
            "user_id": row["user_id"],
            "content_id" : deque(row["content_id"], maxlen=AMOUNT),
            "answered_correctly" : deque(row["answered_correctly"], maxlen=AMOUNT),
            "task_container_id" : deque(row["task_container_id"], maxlen=AMOUNT),
            "prior_question_elapsed_time" : deque(row["prior_question_elapsed_time"], maxlen=AMOUNT),
            "part_id": deque(row["part_id"], maxlen=AMOUNT),
            "padded" : deque([False]*100, maxlen=AMOUNT)
        }
    else: # need to pad
        final_dataset[idx] = {
            "user_id": row["user_id"],
            "content_id" : deque(row["content_id"] + [PAD]*(AMOUNT-len(row["content_id"])), maxlen=AMOUNT),
            "answered_correctly" : deque(row["answered_correctly"] + [PAD]*(AMOUNT-len(row["content_id"])), maxlen=AMOUNT),
            "task_container_id" : deque(row["task_container_id"] + [PAD]*(AMOUNT-len(row["content_id"])), maxlen=AMOUNT),
            "prior_question_elapsed_time" : deque(row["prior_question_elapsed_time"] + [PAD]*(AMOUNT-len(row["content_id"])), maxlen=AMOUNT),
            "part_id": deque(row["part_id"] + [PAD]*(AMOUNT-len(row["content_id"])), maxlen=AMOUNT),
            "padded" : deque([False]*len(row["content_id"]) + [True]*(AMOUNT-len(row["content_id"])), maxlen=AMOUNT)
        }

    user_id_to_idx[row['user_id']] = idx
  # FIXME new users? 
  return final_dataset, user_id_to_idx 

In [7]:
import pickle
with open('../input/saint-model/user_prepared_dataframe.pickle', 'rb') as f:
    user_data = pickle.load(f)

with open('../input/saint-model/user_index_dataframe.pickle', 'rb') as f:
    user_idx = pickle.load(f)

In [8]:
class Riiid(torch.utils.data.Dataset):
    
    def generate(idx):
        return {
            "user_id": idx,
            "content_id" : deque([PAD]*AMOUNT, maxlen=AMOUNT),
            "answered_correctly" : deque([PAD]*AMOUNT, maxlen=AMOUNT),
            "task_container_id" : deque([PAD]*AMOUNT, maxlen=AMOUNT),
            "prior_question_elapsed_time" : deque([PAD]*AMOUNT, maxlen=AMOUNT),
            "part_id": deque([PAD]*AMOUNT, maxlen=AMOUNT),
            "padded" : deque([True]*AMOUNT, maxlen=AMOUNT),
        }
    
    def __init__(self, d, idxs):
        self.d = d
        self.idxs = idxs 
    
    def __len__(self):
        return len(self.d)
    
    def __getitem__(self, idx):
        if idx not in self.idxs:
                self.idxs[idx] = max(self.d.keys()) + 1
                self.d[self.idxs[idx]] = Riiid.generate(idx)
        idx = self.idxs[idx]
        return idx, self.d[idx]["content_id"], self.d[idx]["task_container_id"], \
    self.d[idx]["part_id"], self.d[idx]["prior_question_elapsed_time"], self.d[idx]["padded"], \
    self.d[idx]["answered_correctly"]
            
    
    def update(self, df):
        user_id = np.array(df['user_id'])
        df_content_id = np.array(df['content_id'])
        df_task_container_id = np.array(df['task_container_id'])
        df_prior_question_elapsed_time = np.array(df['prior_question_elapsed_time'])
        df_part_id = np.array(df['part_id'])
        size = len(user_id)

        for i in range(size):
          _, content_id, task_container_id, part_id, prior_question_elapsed_time, padded, labels = self[user_id[i]]
          content_id.popleft()
          task_container_id.popleft()
          prior_question_elapsed_time.popleft()
          part_id.popleft()
          padded.popleft()
          
          content_id.append(df_content_id[i])
          task_container_id.append(df_task_container_id[i])
          prior_question_elapsed_time.append(df_prior_question_elapsed_time[i])
          part_id.append(df_part_id[i])
          padded.append(False)
            

def collate_fn(batch):
    _, content_id, task_id, part_id, prior_question_elapsed_time, padded, labels = zip(*batch)
    content_id = torch.Tensor(content_id).long()
    task_id = torch.Tensor(task_id).long()
    part_id = torch.Tensor(part_id).long()
    prior_question_elapsed_time = torch.Tensor(prior_question_elapsed_time).long()
    padded = torch.Tensor(padded).bool()
    labels = torch.Tensor(labels)
    return content_id, task_id, part_id, prior_question_elapsed_time, padded, labels

## SAINT  Model (Encoder only)

In [9]:
class SAINT(nn.Module):
  def __init__(self, ninp=32, nhead=2, nhid=64, nlayers=2, dropout=0.1):
    super(SAINT, self).__init__()
    # Note: there is no positional encoding for SAINT
    self.src_mask = None
    encoder_layers = TransformerEncoderLayer(d_model=ninp, nhead=nhead, dim_feedforward=nhid, dropout=dropout, activation='relu')
    self.transformer_encoder = TransformerEncoder(encoder_layer=encoder_layers, num_layers=nlayers)
    self.exercise_embeddings = nn.Embedding(num_embeddings=13523, embedding_dim=ninp) # exercise_id
    self.pos_embedding = nn.Embedding(ninp, ninp) # positional embeddings
    self.part_embeddings = nn.Embedding(num_embeddings=7+1, embedding_dim=ninp) # part_id_embeddings
    self.prior_question_elapsed_time = nn.Embedding(num_embeddings=301, embedding_dim=ninp) # prior_question_elapsed_time
    self.device = "cpu" if not torch.cuda.is_available() else torch.device('cuda')
    self.ninp = ninp
    self.decoder = nn.Linear(ninp, 2)
    self.init_weights()
  
  def init_weights(self):
    initrange = 0.1
    # init embeddings
    # FIXME should be Xavier uniform acording to paper
    self.exercise_embeddings.weight.data.uniform_(-initrange, initrange)
    self.part_embeddings.weight.data.uniform_(-initrange, initrange)
    self.prior_question_elapsed_time.weight.data.uniform_(-initrange, initrange)
    self.decoder.bias.data.zero_()
    self.decoder.weight.data.uniform_(-initrange, initrange)

  def generate_square_subsequent_mask(self, sz):
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

  def forward(self, content_id, part_id, prior_question_elapsed_time=None, mask_src=None):
    # Generate embeddings according to paper
    content_id = content_id.to(self.device)
    part_id = part_id.to(self.device)
    prior_question_elapsed_time = prior_question_elapsed_time.to(self.device)
    mask_src = mask_src.to(self.device)

    embedded_src = self.exercise_embeddings(content_id) + \
        self.pos_embedding(torch.arange(0, content_id.shape[1]).to(self.device).unsqueeze(0).repeat(content_id.shape[0], 1)) + \
        self.part_embeddings(part_id) + self.prior_question_elapsed_time(prior_question_elapsed_time) # (N, S, E)
    embedded_src = embedded_src.transpose(0, 1) # (S, N, E)
    
    # Standard transformer
    _src = embedded_src * np.sqrt(self.ninp)
    
    output = self.transformer_encoder(src=_src, src_key_padding_mask=mask_src)
    output = self.decoder(output)
    output = output.transpose(1, 0)
    return output

## Load Pretrained Models

In [10]:
MODEL_PATH = '../input/saint-model/saint_model.pt'

model = SAINT(ninp=100)
device = "cpu" if not torch.cuda.is_available() else torch.device('cuda')
model = model.to(device)
model.device = device
model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device(device)))


<All keys matched successfully>

## Inference

In [11]:
dataset = Riiid(user_data, user_idx)
start_time = time.time()

def scale(X, x_min, x_max):
    nom = (X-X.min(axis=0))*(x_max-x_min)
    denom = X.max(axis=0) - X.min(axis=0)
    denom[denom==0] = 1
    return x_min + nom/denom 

for index, (test_df, sample_prediction_df) in enumerate(iter_test):
    test_df = test_df[test_df.content_type_id == 0]

    test_df["prior_question_elapsed_time"].fillna(26000, inplace=True) # FIXME some random value fill in should it be like this?
    test_df["prior_question_elapsed_time"] = test_df["prior_question_elapsed_time"] // 1000
    test_df["prior_question_elapsed_time"].clip(upper=300)
    test_df['prior_question_had_explanation'] = test_df['prior_question_had_explanation'].astype(np.float16).fillna(-1).astype(np.int8)
    test_df['part_id'] = np.array(list(map(lambda x: part_ids_map[x], test_df['content_id'])))

    dataset.update(test_df)

    preds = []
    idxs = test_df['user_id']

    batch = collate_fn([dataset[idx] for idx in idxs])
    # print(batch)
    content_id, task_id, part_id, prior_question_elapsed_time, mask, labels = batch
    content_id = content_id.to(device)
    task_id = task_id.to(device)
    part_id = part_id.to(device)
    prior_question_elapsed_time = prior_question_elapsed_time.to(device)
    mask = mask.to(device)
    labels = labels.to(device)
    output = model.forward(content_id, part_id, prior_question_elapsed_time, mask)
    # print(output)
    #   output_prob = output[:,-1,1]
    #   pred = (pred.cpu().numpy()).astype(np.int32)#[-1]  # we only care about the last one
    preds = torch.sigmoid(output)[:,-1, 1].cpu().detach().numpy()

    # Normalised [0,1]
    sample_prediction_df['answered_correctly'] =  preds
    print(f'{index} has a prediction as the following: {sample_prediction_df}')
    env.predict(sample_prediction_df[['row_id', 'answered_correctly']])

    print(f'Ran in {time.time() - start_time}')


0 has a prediction as the following:            row_id  answered_correctly
group_num                            
0               0            0.378761
0               1            0.817393
0               2            0.840980
0               3            0.762317
0               4            0.354948
0               5            0.477288
0               6            0.630987
0               7            0.775457
0               8            0.678727
0               9            0.653153
0              10            0.692447
0              11            0.364155
0              12            0.364121
0              13            0.372460
0              14            0.377758
0              15            0.735187
0              16            0.605621
0              17            0.874092
Ran in 0.9022457599639893
1 has a prediction as the following:            row_id  answered_correctly
group_num                            
1              18            0.649533
1              19         