# RIIID - SAINT+ Model Inference

In [1]:
import riiideducation
import math
import torch
import time

import pandas as pd
import numpy as np
import torch.nn as nn 
import torch.nn.functional as F

from tqdm.notebook import tqdm
from sklearn.metrics import roc_auc_score
from collections import deque, defaultdict
from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoder, TransformerDecoderLayer

In [2]:
# Create the competition environment and the iterator that yields the test
# groups; `iter_test` is consumed by the inference loop at the end of the file.
# NOTE(review): presumably make_env() may only be called once per kernel
# session - confirm against the riiideducation API.
env = riiideducation.make_env()
iter_test = env.iter_test()

In [3]:
# Length of every per-user history deque built by Riiid.generate.
AMOUNT = 100
# Filler value stored in history slots that hold no real interaction yet.
PAD = 0
# NOTE(review): rebound to 64 in the model-configuration cell further down and
# not read by any code visible in this file.
BATCH_SIZE = 100

In [4]:
# Question metadata; only the `question_id` and `part` columns are read in
# this file (to build the question -> part lookup).
QUESTIONS_PATH = '../input/riiid-test-answer-prediction/questions.csv'
df_questions = pd.read_csv(QUESTIONS_PATH)

In [5]:
# Lookup table question_id -> part, used by the inference loop to attach a
# `part_id` column to the incoming test rows.
part_ids_map = dict(zip(df_questions.question_id, df_questions.part))

## Data Processing

In [6]:
import pickle
# SECURITY: pickle.load executes arbitrary code embedded in the file; only
# load pickles that were produced by a trusted pipeline.
# `user_data` is later passed to Riiid as its per-user state store.
# NOTE(review): presumably a dict {internal index: state dict shaped like
# Riiid.generate()} - verify against the notebook that wrote the pickle.
with open('../input/saint-model/user_prepared_dataframe.pickle', 'rb') as f:
    user_data = pickle.load(f)

# `user_idx` is later passed to Riiid as its user_id -> internal index map.
with open('../input/saint-model/user_index_dataframe.pickle', 'rb') as f:
    user_idx = pickle.load(f)

In [7]:
class Riiid(torch.utils.data.Dataset):
    """Rolling per-user interaction history used to build the model inputs.

    ``d`` maps an internal integer index to a per-user state dict shaped like
    the result of :meth:`generate`; ``idxs`` maps a ``user_id`` to that
    internal index. Every history is a fixed-length deque whose right-most
    element is the most recent interaction.
    """

    @staticmethod
    def generate(idx):
        """Return a blank state for user ``idx``: every history is fully padded."""

        def blank(fill=PAD):
            return deque([fill] * AMOUNT, maxlen=AMOUNT)

        return {
            "user_id": idx,
            "content_id": blank(),
            "answered_correctly": blank(),
            "task_container_id": blank(),
            "lagtime": blank(),
            "prior_question_elapsed_time": blank(),
            "part_id": blank(),
            "padded": blank(True),
            "timestamp": blank(),  # kept only to derive lagtime
        }

    def __init__(self, d, idxs):
        self.d = d
        self.idxs = idxs  # user_id -> key of that user's state in self.d

    def __len__(self):
        return len(self.d)

    def __getitem__(self, idx):
        """Return the histories of user ``idx``, creating a blank state if unseen.

        The result is the tuple ``(internal index, content_id,
        task_container_id, part_id, prior_question_elapsed_time, padded,
        answered_correctly, lagtime, timestamp)``; ``collate_fn`` and
        :meth:`update` rely on this exact order. The deques are the live
        objects stored in ``self.d``, not copies.
        """
        if idx not in self.idxs:
            # default=-1 lets the very first user be registered in an empty store
            slot = max(self.d, default=-1) + 1
            self.idxs[idx] = slot
            self.d[slot] = Riiid.generate(idx)
        slot = self.idxs[idx]
        state = self.d[slot]
        return (
            slot,
            state["content_id"],
            state["task_container_id"],
            state["part_id"],
            state["prior_question_elapsed_time"],
            state["padded"],
            state["answered_correctly"],
            state["lagtime"],
            state["timestamp"],
        )

    def update(self, df):
        """Push every row of ``df`` onto the history of the row's user.

        ``df`` must provide the columns ``user_id``, ``content_id``,
        ``task_container_id``, ``prior_question_elapsed_time``, ``part_id``
        and ``timestamp``. The ``answered_correctly`` history is not touched
        here; it is filled in later through :meth:`add_answers`.
        """
        # Convert the columns once so the loop works on plain arrays.
        users = np.array(df['user_id'])
        contents = np.array(df['content_id'])
        tasks = np.array(df['task_container_id'])
        elapsed_times = np.array(df['prior_question_elapsed_time'])
        parts = np.array(df['part_id'])
        stamps = np.array(df['timestamp'])

        rows = zip(users, contents, tasks, elapsed_times, parts, stamps)
        for user, content, task, elapsed, part, stamp in rows:
            (_, content_hist, task_hist, part_hist, elapsed_hist,
             padded_hist, _, lag_hist, stamp_hist) = self[user]

            new_values = (
                (content_hist, content),
                (task_hist, task),
                (elapsed_hist, elapsed),
                (part_hist, part),
                (padded_hist, False),
                (stamp_hist, stamp),
            )
            for history, value in new_values:
                history.popleft()
                history.append(value)

            # Lag is the gap between the two most recent timestamps divided
            # by 1000 and capped at 300.
            lag = (stamp_hist[-1] - stamp_hist[-2]) // 1000
            lag_hist.popleft()
            lag_hist.append(lag if lag <= 300 else 300)

    def add_answers(self, users, answers):
        """Append each known answer to the ``answered_correctly`` history.

        ``users`` and ``answers`` are parallel sequences; entries whose answer
        is ``-1`` are skipped.
        """
        for user, answer in zip(users, answers):
            if answer == -1:
                continue
            # Look the history up by key: the last element of the __getitem__
            # tuple is the timestamp deque, not the answers.
            slot = self[user][0]
            history = self.d[slot]["answered_correctly"]
            history.popleft()
            history.append(answer)

def collate_fn(batch):
    """Stack per-user history tuples into batch tensors.

    Each element of ``batch`` is a 9-tuple in the order produced by
    ``Riiid.__getitem__``. The leading internal index and the trailing
    timestamp history are dropped.

    Returns the tuple ``(content_id, task_id, part_id,
    prior_question_elapsed_time, padded, labels, lagtime)``: ``padded`` is a
    bool tensor, ``labels`` keeps the float dtype of ``torch.Tensor`` and all
    other entries are long tensors.
    """
    (_slot, content_rows, task_rows, part_rows, elapsed_rows,
     padded_rows, label_rows, lag_rows, _stamps) = zip(*batch)

    def as_long(rows):
        return torch.Tensor(rows).long()

    # The order of the returned tuple is relied upon by the caller.
    return (
        as_long(content_rows),
        as_long(task_rows),
        as_long(part_rows),
        as_long(elapsed_rows),
        torch.Tensor(padded_rows).bool(),
        torch.Tensor(label_rows),
        as_long(lag_rows),
    )

## SAINT+ Model

In [8]:
class SAINT(nn.Module):
    """Encoder/decoder Transformer over exercise and response sequences.

    The encoder input is the sum of exercise, position and part embeddings;
    the decoder input is the sum of correctness, position, elapsed-time and
    lag-time embeddings. A final linear layer maps every decoder output to a
    single logit.

    Both positional tables are sized from ``d_model`` (``d_model`` rows for
    the encoder, ``d_model + 1`` for the decoder, whose last row is the
    padding index), so position ids must stay below those sizes.
    """

    def __init__(self, d_model, nhead, num_encoder_layers, num_decoder_layers, dropout, dim_feedforward, device='cpu'):
        super().__init__()
        self.model = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dropout=dropout,
            dim_feedforward=dim_feedforward,
        ).to(device)

        # Encoder-side lookup tables
        self.exercise_embeddings = nn.Embedding(num_embeddings=13523, embedding_dim=d_model)  # exercise id
        self.enc_pos_embedding = nn.Embedding(d_model, d_model)  # position
        self.part_embeddings = nn.Embedding(num_embeddings=7 + 1, embedding_dim=d_model)  # part id

        # Decoder-side lookup tables; the last row of each one is its padding index
        self.prior_question_elapsed_time = nn.Embedding(num_embeddings=302, embedding_dim=d_model, padding_idx=301)
        self.dec_pos_embedding = nn.Embedding(d_model + 1, d_model, padding_idx=d_model)
        self.correctness_embeddings = nn.Embedding(num_embeddings=3, embedding_dim=d_model, padding_idx=2)
        self.lagtime = nn.Embedding(num_embeddings=302, embedding_dim=d_model, padding_idx=301)

        self.linear = nn.Linear(d_model, 1)

        self.device = device
        self.init_weights()

    def init_weights(self):
        """Re-draw every embedding table from U(-0.1, 0.1).

        The draw covers the whole weight matrix, padding rows included.
        """
        initrange = 0.1
        # FIXME should be Xavier uniform according to paper
        tables = (
            self.exercise_embeddings,
            self.part_embeddings,
            self.prior_question_elapsed_time,
            self.lagtime,
            self.correctness_embeddings,
            self.enc_pos_embedding,
            self.dec_pos_embedding,
        )
        for table in tables:
            table.weight.data.uniform_(-initrange, initrange)

    def forward(self, encoder_exercises, encoder_position, encoder_part, encoder_padding, decoder_correctness, decoder_position, decoder_elapsed_time, decoder_padding, decoder_lagtime):
        """Run the Transformer and return one logit per sequence position.

        All index tensors are indexed as (batch, sequence). The output is the
        linear head applied to the batch-first decoder output, i.e. a trailing
        dimension of size 1. ``encoder_padding`` and ``decoder_padding`` are
        converted but not passed to the Transformer (see the FIXME below).
        """
        target = self.device
        encoder_exercises = encoder_exercises.to(target)
        encoder_position = encoder_position.to(target)
        encoder_part = encoder_part.to(target)
        decoder_correctness = decoder_correctness.to(target)
        decoder_position = decoder_position.to(target)
        decoder_elapsed_time = decoder_elapsed_time.to(target)
        decoder_lagtime = decoder_lagtime.to(target)

        seq_len = encoder_exercises.shape[1]  # S / T
        # Encoder self-attention, decoder self-attention and decoder->memory
        # attention all use the same square subsequent mask.
        mask_src, mask_tgt, mem_mask = (
            self.model.generate_square_subsequent_mask(sz=seq_len).to(target)
            for _ in range(3)
        )

        # Padded positions are True in these masks.
        encoder_padding = encoder_padding.bool().to(target)
        decoder_padding = decoder_padding.bool().to(target)
        # FIXME: key padding masks are not passed to the Transformer
        # ("Key padding mask not implemented. Error in Pytorch"); the memory
        # padding mask would be the same as the encoder one.
        # TODO add padding masks

        # Sum the embeddings according to the paper, then go to (S, N, E).
        src = self.exercise_embeddings(encoder_exercises) \
            + self.enc_pos_embedding(encoder_position) \
            + self.part_embeddings(encoder_part)
        src = src.transpose(0, 1)

        tgt = self.correctness_embeddings(decoder_correctness) \
            + self.dec_pos_embedding(decoder_position) \
            + self.prior_question_elapsed_time(decoder_elapsed_time) \
            + self.lagtime(decoder_lagtime)
        tgt = tgt.transpose(0, 1)

        decoded = self.model(
            src=src,
            tgt=tgt,
            src_mask=mask_src,
            tgt_mask=mask_tgt,
            memory_mask=mem_mask,
        )

        # Back to batch-first before the per-position linear head.
        return self.linear(decoded.transpose(1, 0))

# Start tokens placed in the first decoder slot of every sequence.
# NOTE(review): these look like the padding indices of the matching SAINT
# embedding tables when d_model == 100 - confirm if the model size changes.
CORRECTNESS_DEFAULT_TOKEN = 2
POSITION_DEFAULT_TOKEN = 100
ELAPSED_TIME_DEFAULT_TOKEN = 301
LAG_TIME_DEFAULT_TOKEN = 301


def get_batch_embeddings(content, part, correctness, elapsed_time, lagtime, padding, device="cpu"):
    """Build the encoder and decoder index tensors for one batch.

    The encoder sees all ``n`` positions. Every decoder sequence is a start
    token followed by the first ``n - 1`` values of its input, i.e. the
    input shifted right by one step.

    Returns ``(encoder_exercises, encoder_position, encoder_part,
    encoder_key_padding, decoder_correctness, decoder_position,
    decoder_elapsed_time, decoder_key_padding, decoder_lagtime)``. The two
    padding tensors are bool, all others long.
    """
    n_rows = content.shape[0]
    seq_len = content.shape[1]

    def position_grid(length):
        # 0 .. length-1 repeated for every row of the batch
        return torch.arange(0, length).to(device).unsqueeze(0).repeat(n_rows, 1)

    def shift_right(start_token, values):
        # one start-token column followed by the first seq_len - 1 columns
        start = torch.full((n_rows, 1), float(start_token), device=device)
        return torch.cat((start, values[:, :seq_len - 1]), -1)

    # Encoder side
    encoder_exercises = content.long()
    encoder_position = position_grid(seq_len).long()
    encoder_part = part.long()
    encoder_key_padding = padding.bool()

    # Decoder side
    decoder_correctness = shift_right(CORRECTNESS_DEFAULT_TOKEN, correctness).long()
    decoder_position = shift_right(POSITION_DEFAULT_TOKEN, position_grid(seq_len - 1)).long()
    decoder_elapsed_time = shift_right(ELAPSED_TIME_DEFAULT_TOKEN, elapsed_time).long()
    decoder_lagtime = shift_right(LAG_TIME_DEFAULT_TOKEN, lagtime).long()
    # the start slot itself counts as padded
    decoder_key_padding = shift_right(True, padding).bool()

    return (
        encoder_exercises,
        encoder_position,
        encoder_part,
        encoder_key_padding,
        decoder_correctness,
        decoder_position,
        decoder_elapsed_time,
        decoder_key_padding,
        decoder_lagtime,
    )

## Load Pretrained Models

In [10]:
# Checkpoint holding the trained SAINT weights (a state dict).
MODEL_PATH = '../input/saint-model/saint_model_2.pt'
# adam optimizer
# (training hyper-parameters; not read by the inference code in this file)
LEARNING_RATE = 0.001
BETA_1 = 0.9
BETA_2 = 0.999
EPSILON = 1e-8
WARMUP = 4000

#SAINT
N_LAYERS = 4
WINDOW_SIZE = 100
MODEL_DIM = 512
DROPOUT = 0
BATCH_SIZE = 64

# Always a torch.device, so every consumer receives the same type whether or
# not a GPU is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# The embedding width is WINDOW_SIZE and the feed-forward width is MODEL_DIM:
# SAINT sizes its positional tables from d_model, so d_model has to cover the
# history length.
model = SAINT(d_model=WINDOW_SIZE, nhead=5, num_encoder_layers=N_LAYERS, num_decoder_layers=N_LAYERS, dropout=DROPOUT, dim_feedforward=MODEL_DIM, device=device).to(device)
# Inference only: switch to eval mode so dropout-style layers are disabled
# regardless of the DROPOUT value.
model.eval()
# Kept as the last expression so the notebook still echoes the key-matching result.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))


<All keys matched successfully>

In [None]:
## Inference

In [11]:
import ast

dataset = Riiid(user_data, user_idx)
start_time = time.time()
# user_id of every row (questions and lectures) of the previous group, in row
# order; this is what `prior_group_answers_correct` lines up with.
prev_group_users = []
for test_df, sample_prediction_df in iter_test:
    # The first row carries the previous group's answers as a list literal;
    # literal_eval parses it without executing arbitrary code.
    prev_answers = ast.literal_eval(test_df.iloc[0]['prior_group_answers_correct'])
    # Credit those answers to the users of the previous group before the new
    # questions are pushed onto the histories.
    # NOTE(review): get_batch_embeddings shifts the labels right by one more
    # step; confirm the resulting label/content alignment matches training.
    dataset.add_answers(prev_group_users, prev_answers)
    prev_group_users = test_df['user_id'].values

    # Work on a copy so the column assignments below do not write into a
    # view of the frame handed out by the environment.
    test_df = test_df[test_df.content_type_id == 0].copy()
    if test_df.empty:
        # Group without questions: nothing to predict, but the environment
        # still expects a predict() call before it yields the next group.
        env.predict(sample_prediction_df[['row_id', 'answered_correctly']])
        continue

    # FIXME some random value fill in should it be like this?
    elapsed_seconds = test_df["prior_question_elapsed_time"].fillna(26000) // 1000
    # The clipped result has to be assigned: the elapsed-time embedding only
    # has rows 0..301.
    test_df["prior_question_elapsed_time"] = elapsed_seconds.clip(upper=300)
    test_df['prior_question_had_explanation'] = test_df['prior_question_had_explanation'].astype(np.float16).fillna(-1).astype(np.int8)
    test_df['part_id'] = np.array([part_ids_map[content] for content in test_df['content_id']])

    dataset.update(test_df)

    # One history per question row; rows of the same user share the same deques.
    batch = collate_fn([dataset[user] for user in test_df['user_id']])
    content_id, task_id, part_id, prior_question_elapsed_time, mask, labels, lagtime = batch
    content_id = content_id.to(device)
    part_id = part_id.to(device)
    prior_question_elapsed_time = prior_question_elapsed_time.to(device)
    lagtime = lagtime.to(device)
    mask = mask.to(device)
    labels = labels.to(device)

    # get embeddings
    model_inputs = get_batch_embeddings(content_id, part_id, labels, prior_question_elapsed_time, lagtime, mask, device=device)

    # run model; no gradients are needed at inference time
    with torch.no_grad():
        output = model(*model_inputs)

    # The last sequence position is the question that was just appended.
    preds = torch.sigmoid(output)[:, -1, 0].cpu().numpy()
    # NOTE(review): hard 0/1 labels are submitted. If the competition metric
    # is rank-based (roc_auc_score is imported at the top of the file),
    # submitting the raw probabilities would score better - confirm.
    preds = (preds >= 0.5).astype(np.int32)
    sample_prediction_df['answered_correctly'] = preds
    env.predict(sample_prediction_df[['row_id', 'answered_correctly']])

print(f'Ran in {time.time() - start_time}')


Ran in 1.4412074089050293
