# RIIID - SAKT Model Training

Public Leaderboard Score: **0.774**

#### If you like this kernel or forking this kernel, please consider upvoting this and the kernels I copied (acknowledgements) from. It helps them reach more people.

- **Inference Notebook**: https://www.kaggle.com/manikanthr5/riiid-sakt-model-inference-public
- **Pretrained Dataset**: https://www.kaggle.com/manikanthr5/riiid-sakt-model-dataset-public

**Acknowledgement**:

All the credits go to this popular notebook https://www.kaggle.com/leadbest/sakt-with-randomization-state-updates which is a modification of https://www.kaggle.com/wangsg/a-self-attentive-model-for-knowledge-tracing. Please show some support to these original work kernels.

**Possible Improvements**:
- All the data in this notebook is used for training, so create a train and valid dataset for cross validation. Note: For me this degraded my LB score.
- Some other text book ideas you could try:
 - Using Label Smoothing
 - Using Learning Rate Schedulers (ex: [check this kernel](https://www.kaggle.com/scaomath/riiid-sakt-train-with-a-warm-up-scheduler))
 - Increase the max sequence length and/or embedding dimension (ex: [check this kernel](https://www.kaggle.com/gilfernandes/riiid-self-attention-transformer))
 - Add more attention layers
 - Change the sampling strategy (most important)

In [1]:
# import os
# os._exit(0)

In [2]:
import gc
import psutil
import joblib
import random
from tqdm import tqdm

import numpy as np
import pandas as pd

from sklearn.metrics import roc_auc_score

import torch

import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [3]:
TRAIN_SAMPLES = 320000

MAX_SEQ = 100
MIN_SAMPLES = 5
EMBED_DIM = 128
DROPOUT_RATE = 0.2
LEARNING_RATE = 1e-3
MAX_LEARNING_RATE = 2e-3
# EPOCHS = 30
EPOCHS = 10
# TRAIN_BATCH_SIZE = 2048
TRAIN_BATCH_SIZE = 64

## Load data

In [4]:
%%time

dtypes = {'timestamp': 'int64', 'user_id': 'int32' ,'content_id': 'int16','content_type_id': 'int8','answered_correctly':'int8'}
# train_df = pd.read_feather('../input/riiid-cross-validation-dataset/train.feather')[[
#     'timestamp', 'user_id', 'content_id', 'content_type_id', 'answered_correctly'
# ]]
train_df = pd.read_csv('./input/riiid-test-answer-prediction/train_medium.csv')[['timestamp', 'user_id', 'content_id', 'content_type_id', 'answered_correctly']]
for col, dtype in dtypes.items():
    train_df[col] = train_df[col].astype(dtype)
    
#train_df have only rows with False in content_type_id (0 if the event was a question being posed to the user)
train_df = train_df[train_df.content_type_id == False]  

train_df = train_df.sort_values(['timestamp'], ascending=True)
train_df.reset_index(drop=True, inplace=True)
train_df.head(5)

CPU times: user 784 ms, sys: 101 ms, total: 885 ms
Wall time: 982 ms


Unnamed: 0,timestamp,user_id,content_id,content_type_id,answered_correctly
0,0,115,5692,0,1
1,0,7777501,7900,0,1
2,0,2966556,7900,0,1
3,0,31632191,3942,0,0
4,0,25096623,7900,0,1


## Preprocess

In [5]:
skills = train_df["content_id"].unique()
joblib.dump(skills, "skills.pkl.zip")
n_skill = len(skills)
print("number skills", n_skill)

number skills 13374


In [6]:
group = train_df[['user_id', 'content_id', 'answered_correctly']].groupby('user_id').apply(lambda r: (
            r['content_id'].values,
            r['answered_correctly'].values))

joblib.dump(group, "group.pkl.zip")  # Save models?
del train_df
gc.collect()
group

user_id
115         ([5692, 5716, 128, 7860, 7922, 156, 51, 50, 78...
124         ([7900, 7876, 175, 1278, 2065, 2064, 2063, 336...
2746        ([5273, 758, 5976, 236, 404, 382, 405, 873, 53...
5382        ([5000, 3944, 217, 5844, 5965, 4990, 5235, 605...
8623        ([3915, 4750, 6456, 3968, 6104, 5738, 6435, 54...
                                  ...                        
42189769    ([7900, 7876, 175, 1278, 2065, 2063, 2064, 336...
42198669    ([3982, 9723, 6435, 4536, 4002, 5836, 3758, 39...
42200769    ([4554, 3817, 6374, 6372, 289, 7956, 217, 5182...
42206662    ([4126, 6686, 4993, 4089, 6142, 6394, 4435, 37...
42207371    ([4527, 4995, 8737, 8445, 9141, 6368, 11496, 2...
Length: 7712, dtype: object

In [7]:
group[115]
# q:content_id; qa:answered_correctly
# user_id: ((array([content_id]), array([answered_correctly]))

(array([5692, 5716,  128, 7860, 7922,  156,   51,   50, 7896, 7863,  152,
         104,  108, 7900, 7901, 7971,   25,  183, 7926, 7927,    4, 7984,
          45,  185,   55, 7876,    6,  172, 7898,  175,  100, 7859,   57,
        7948,  151,  167, 7897, 7882, 7962, 1278, 2064, 2065, 2063, 3363,
        3364, 3365], dtype=int16),
 array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0,
        1, 0], dtype=int8))

In [9]:
len(group[115][1])

46

In [8]:
group.index[:5]

Int64Index([115, 124, 2746, 5382, 8623], dtype='int64', name='user_id')

In [9]:
TRAIN_SAMPLES = int(len(group.index)*0.8)
print('TRAIN_SAMPLES',TRAIN_SAMPLES)

train_indexes = list(group.index)[:TRAIN_SAMPLES]
valid_indexes = list(group.index)[TRAIN_SAMPLES:]
train_group = group[group.index.isin(train_indexes)]
valid_group = group[group.index.isin(valid_indexes)]
print('train_group \n', train_group[:5] )
print('valid_group \n', valid_group[:5] )

del group, train_indexes, valid_indexes
print(len(train_group), len(valid_group))

TRAIN_SAMPLES 6169
train_group 
 user_id
115     ([5692, 5716, 128, 7860, 7922, 156, 51, 50, 78...
124     ([7900, 7876, 175, 1278, 2065, 2064, 2063, 336...
2746    ([5273, 758, 5976, 236, 404, 382, 405, 873, 53...
5382    ([5000, 3944, 217, 5844, 5965, 4990, 5235, 605...
8623    ([3915, 4750, 6456, 3968, 6104, 5738, 6435, 54...
dtype: object
valid_group 
 user_id
33952399    ([5059, 6681, 5880, 378, 217, 4441, 6381, 6668...
33953997    ([5637, 5473, 4882, 6374, 5182, 217, 4466, 795...
33955846    ([4141, 4941, 5033, 4764, 6400, 4882, 5689, 78...
33956171    ([7900, 7876, 175, 1278, 2063, 2065, 2064, 336...
33959731    ([3862, 4852, 6286, 4452, 8688, 5487, 5163, 98...
dtype: object
6169 1543


In [10]:
class SAKTDataset(Dataset):
    def __init__(self, group, n_skill, min_samples=1, max_seq=128):
        super(SAKTDataset, self).__init__()
        self.max_seq = max_seq
        self.n_skill = n_skill
        self.samples = {}
        
        self.user_ids = []
        for user_id in group.index:
            q, qa = group[user_id]  # q:content_id; qa:answered_correctly
            if len(q) < min_samples:
                continue 
            
            # Main Contribution
            if len(q) > self.max_seq:
                total_questions = len(q)
                initial = total_questions % self.max_seq
                if initial >= min_samples:
                    self.user_ids.append(f"{user_id}_0")
                    self.samples[f"{user_id}_0"] = (q[:initial], qa[:initial])
                for seq in range(total_questions // self.max_seq):
                    self.user_ids.append(f"{user_id}_{seq+1}")
                    start = initial + seq * self.max_seq
                    end = start + self.max_seq
                    self.samples[f"{user_id}_{seq+1}"] = (q[start:end], qa[start:end])
            else:
                user_id = str(user_id)
                self.user_ids.append(user_id)
                self.samples[user_id] = (q, qa)
    
    def __len__(self):
        return len(self.user_ids)

    def __getitem__(self, index):
        user_id = self.user_ids[index]
        q_, qa_ = self.samples[user_id]
        seq_len = len(q_)

        q = np.zeros(self.max_seq, dtype=int)
        qa = np.zeros(self.max_seq, dtype=int)
        if seq_len == self.max_seq:
            q[:] = q_
            qa[:] = qa_
        else:
            q[-seq_len:] = q_
            qa[-seq_len:] = qa_
        
        target_id = q[1:]
        label = qa[1:]

        x = np.zeros(self.max_seq-1, dtype=int)
        x = q[:-1].copy()
        x += (qa[:-1] == 1) * self.n_skill

        return x, target_id, label

In [11]:
train_dataset = SAKTDataset(train_group, n_skill, min_samples=MIN_SAMPLES, max_seq=MAX_SEQ)
train_dataloader = DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=8)
valid_dataset = SAKTDataset(valid_group, n_skill, max_seq=MAX_SEQ)
valid_dataloader = DataLoader(valid_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=False, num_workers=8)


In [12]:
print('len(train_dataset)',len(train_dataset),'len(valid_dataset)',len(valid_dataset))

len(train_dataset) 19303 len(valid_dataset) 4864


## Define model

In [None]:
class FFN(nn.Module):
    def __init__(self, state_size=200):
        super(FFN, self).__init__()
        self.state_size = state_size

        self.lr1 = nn.Linear(state_size, state_size)
        self.relu = nn.ReLU()
        self.lr2 = nn.Linear(state_size, state_size)
        self.dropout = nn.Dropout(0.2)
    
    def forward(self, x):
        x = self.lr1(x)
        x = self.relu(x)
        x = self.lr2(x)
        return self.dropout(x)

def future_mask(shape):
    future_mask = np.triu(np.ones(shape), k=1).astype('bool')
    return torch.from_numpy(future_mask)

class SubLayer(nn.Module):
    def __init__(self,embed_dim):
        super().__init__()
        self.multi_att = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=8, dropout=0.2)

        self.dropout = nn.Dropout(0.2)
        self.layer_normal = nn.LayerNorm(embed_dim) 

        self.ffn = FFN(embed_dim)
    def forward(self, e, x):
        att_mask = future_mask(shape=(e.size(0), x.size(0))).to(device)
        att_output, att_weight = self.multi_att(e, x, x, attn_mask=att_mask)
        att_output = self.layer_normal(att_output + e)
        att_output = att_output.permute(1, 0, 2) # att_output: [s_len, bs, embed] => [bs, s_len, embed]

        x = self.ffn(att_output)
        x = self.layer_normal(x + att_output)
        return x, att_weight
class BERTModel(nn.Module):
    def __init__(self, n_skill, max_seq=MAX_SEQ, embed_dim=128):
        super().__init__()
        self.pos_embedding = nn.Embedding(max_seq, embed_dim)
        self.embedding = nn.Embedding(n_skill+1, embed_dim)
        
        self.ans_embedding = nn.Embedding(3, embed_dim)
        self.time_embedding = nn.Embedding(10000, embed_dim)
        
        self.lag_time_embedding = nn.Embedding(3600, embed_dim)
        
        self.elapsed_time_embedding = nn.Embedding(1520, embed_dim)
        
        self.part_embedding = nn.Embedding(8, embed_dim)
        
        self.sub1 = SubLayer(embed_dim)
        
        self.fc = nn.Linear(embed_dim, embed_dim*2)
        self.fc1 = nn.Linear(embed_dim*5, embed_dim)
        
        self.bacth_norm = nn.BatchNorm1d(max_seq)
        self.bacth_norm1 = nn.BatchNorm1d(max_seq)
        
        self.pred = nn.Linear(embed_dim*2, 1)
        
    def forward(self, history_question, history_answer, time, lag_time, part, elapsed_time):
        device = history_question.device
        history_answer = history_answer
        history_answer = self.ans_embedding(history_answer)
        
        x = self.embedding(history_question)
        pos_id = torch.arange(x.size(1)).unsqueeze(0).to(device)
        pos_x = self.pos_embedding(pos_id)
        time_x = self.time_embedding(time)
        lag_time = self.lag_time_embedding(lag_time)
        part = self.part_embedding(part)
        elapsed_time = self.elapsed_time_embedding(elapsed_time)
        
        history_answer += pos_x
        x += history_answer
        time_x += history_answer
        lag_time += history_answer
        part += history_answer
        elapsed_time += history_answer
        x = torch.cat([x, time_x, lag_time, part, elapsed_time], axis=-1)
        x = self.fc1(x)
        
        
        x = x.permute(1, 0, 2) # x: [bs, s_len, embed] => [s_len, bs, embed]
        x, att_weight= self.sub1(x, x)
        
        x = self.fc(x)
        x = self.bacth_norm(x)

        x = self.pred(x)
        
        return x.squeeze(-1), att_weight

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BERTModel(n_skill, max_seq=MAX_SEQ, embed_dim=EMBED_DIM, dropout_rate=DROPOUT_RATE)

In [13]:
# class FFN(nn.Module):
#     def __init__(self, state_size=200):
#         super(FFN, self).__init__()
#         self.state_size = state_size

#         self.lr1 = nn.Linear(state_size, state_size)
#         self.relu = nn.ReLU()
#         self.lr2 = nn.Linear(state_size, state_size)
#         self.dropout = nn.Dropout(0.2)
    
#     def forward(self, x):
#         x = self.lr1(x)
#         x = self.relu(x)
#         x = self.lr2(x)
#         return self.dropout(x)

# def future_mask(seq_length):
#     future_mask = np.triu(np.ones((seq_length, seq_length)), k=1).astype('bool')
#     return torch.from_numpy(future_mask)


# class SAKTModel(nn.Module):
#     def __init__(self, n_skill, max_seq=128, embed_dim=128, dropout_rate=0.2):
#         super(SAKTModel, self).__init__()
#         self.n_skill = n_skill
#         self.embed_dim = embed_dim

#         self.embedding = nn.Embedding(2*n_skill+1, embed_dim)
#         self.pos_embedding = nn.Embedding(max_seq-1, embed_dim)
#         self.e_embedding = nn.Embedding(n_skill+1, embed_dim)

#         self.multi_att = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=8, dropout=dropout_rate)

#         self.dropout = nn.Dropout(dropout_rate)
#         self.layer_normal = nn.LayerNorm(embed_dim) 

#         self.ffn = FFN(embed_dim)
#         self.pred = nn.Linear(embed_dim, 1)
    
#     def forward(self, x, question_ids):
#         device = x.device        
#         x = self.embedding(x)
#         pos_id = torch.arange(x.size(1)).unsqueeze(0).to(device)

#         pos_x = self.pos_embedding(pos_id)
#         x = x + pos_x

#         e = self.e_embedding(question_ids)

#         x = x.permute(1, 0, 2) # x: [bs, s_len, embed] => [s_len, bs, embed]
#         e = e.permute(1, 0, 2)
#         att_mask = future_mask(x.size(0)).to(device)
#         att_output, att_weight = self.multi_att(e, x, x, attn_mask=att_mask)
#         att_output = self.layer_normal(att_output + e)
#         att_output = att_output.permute(1, 0, 2) # att_output: [s_len, bs, embed] => [bs, s_len, embed]

#         x = self.ffn(att_output)
#         x = self.layer_normal(x + att_output)
#         x = self.pred(x)

#         return x.squeeze(-1), att_weight

In [14]:
# x
# item[0] tensor([[20319,  6725, 20251,  ..., 18708, 22741, 19804],
#         [21465,     2, 21487,  ..., 20119,  6594, 20301],
#         [13800, 14017,   488,  ...,  4818,  5352, 18955],
#         ...,
#         [    0,     0,     0,  ...,  9221, 13722,   834],
#         [13949, 13957, 14773,  ..., 14506, 14148,   414],
#         [ 2817,  3380, 16901,  ..., 20147, 20382, 20381]])

# target_id
# item[1] tensor([[10688,   739,  1291,  ...,  2587,  2589,  2588],
#         [ 5687,  6372,  6116,  ...,  9435,  3878,  5177],
#         [ 4526,  3685,  3989,  ...,  3573,  8271,  3580],
#         ...,
#         [ 5771,  8548,  5861,  ...,  2830,  2832,  2831],
#         [ 4260,  4259,  4177,  ...,  3759,  8870,  5671],
#         [   65,  1344,   335,  ...,  4078,  3744,  5616]])

# label (answered_correctly)
# item[2] tensor([[1, 1, 1,  ..., 1, 1, 0],
#         [1, 0, 0,  ..., 0, 0, 1],
#         [1, 0, 0,  ..., 1, 1, 0],
#         ...,
#         [1, 0, 1,  ..., 1, 1, 1],
#         [1, 1, 1,  ..., 1, 0, 1],
#         [1, 1, 1,  ..., 0, 1, 0]])

In [15]:
def train_fn(model, dataloader, optimizer, scheduler, criterion, device="cpu"):
    model.train()

    train_loss = []
    num_corrects = 0
    num_total = 0
    labels = []
    outs = []

    for item in dataloader:

        x = item[0].to(device).long()
        target_id = item[1].to(device).long()
# #         x = item[0].to(device)
#         target_id = item[1].to(device)

    
        label = item[2].to(device).float()
        target_mask = (target_id != 0)

        optimizer.zero_grad()
        output, _, = model(x, target_id)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()
        scheduler.step()
        train_loss.append(loss.item())

        output = torch.masked_select(output, target_mask)
        label = torch.masked_select(label, target_mask)
        pred = (torch.sigmoid(output) >= 0.5).long()
#         pred = (torch.sigmoid(output) >= 0.5)
    
        num_corrects += (pred == label).sum().item()
        num_total += len(label)

        labels.extend(label.view(-1).data.cpu().numpy())
        outs.extend(output.view(-1).data.cpu().numpy())

    acc = num_corrects / num_total
    auc = roc_auc_score(labels, outs)
    loss = np.mean(train_loss)

    return loss, acc, auc

In [16]:
def valid_fn(model, dataloader, criterion, device="cpu"):
    model.eval()

    valid_loss = []
    num_corrects = 0
    num_total = 0
    labels = []
    outs = []

    for item in dataloader:
        x = item[0].to(device).long()
        target_id = item[1].to(device).long()
#         x = item[0].to(device)
#         target_id = item[1].to(device)

        label = item[2].to(device).float()
        target_mask = (target_id != 0)

        output, _, = model(x, target_id)
        loss = criterion(output, label)
        valid_loss.append(loss.item())

        output = torch.masked_select(output, target_mask)
        label = torch.masked_select(label, target_mask)
        pred = (torch.sigmoid(output) >= 0.5).long()
#         pred = (torch.sigmoid(output) >= 0.5)
    
        num_corrects += (pred == label).sum().item()
        num_total += len(label)

        labels.extend(label.view(-1).data.cpu().numpy())
        outs.extend(output.view(-1).data.cpu().numpy())

    acc = num_corrects / num_total
    auc = roc_auc_score(labels, outs)
    loss = np.mean(valid_loss)

    return loss, acc, auc

In [17]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SAKTModel(n_skill, max_seq=MAX_SEQ, embed_dim=EMBED_DIM, dropout_rate=DROPOUT_RATE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.BCEWithLogitsLoss()
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=MAX_LEARNING_RATE, steps_per_epoch=len(train_dataloader), epochs=EPOCHS
)

model.to(device)
criterion.to(device)

BCEWithLogitsLoss()

In [18]:
model

SAKTModel(
  (embedding): Embedding(26749, 128)
  (pos_embedding): Embedding(99, 128)
  (e_embedding): Embedding(13375, 128)
  (multi_att): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
  )
  (dropout): Dropout(p=0.2, inplace=False)
  (layer_normal): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (ffn): FFN(
    (lr1): Linear(in_features=128, out_features=128, bias=True)
    (relu): ReLU()
    (lr2): Linear(in_features=128, out_features=128, bias=True)
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (pred): Linear(in_features=128, out_features=1, bias=True)
)

In [19]:
best_auc = 0
max_steps = 3
step = 0
for epoch in range(EPOCHS):
    loss, acc, auc = train_fn(model, train_dataloader, optimizer, scheduler, criterion, device)
    print("epoch - {}/{} train: - {:.3f} acc - {:.3f} auc - {:.3f}".format(epoch+1, EPOCHS, loss, acc, auc))
    loss, acc, auc = valid_fn(model, valid_dataloader, criterion, device)
    print("epoch - {}/{} valid: - {:.3f} acc - {:.3f} auc - {:.3f}".format(epoch+1, EPOCHS, loss, acc, auc))
    if auc > best_auc:
        best_auc = auc
        step = 0
        torch.save(model.state_dict(), "sakt_model.pt")
    else:
        step += 1
        if step >= max_steps:
            break

../aten/src/ATen/native/cuda/Indexing.cu:975: indexSelectLargeIndex: block: [391,0,0], thread: [96,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:975: indexSelectLargeIndex: block: [391,0,0], thread: [97,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:975: indexSelectLargeIndex: block: [391,0,0], thread: [98,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:975: indexSelectLargeIndex: block: [391,0,0], thread: [99,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:975: indexSelectLargeIndex: block: [391,0,0], thread: [100,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:975: indexSelectLargeIndex: block: [391,0,0], thread: [101,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:975: indexSelectLargeIndex: block: [391,0,0],

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

In [None]:
del train_dataset, valid_dataset

In [None]:
torch.save(model.state_dict(), "sakt_model_final.pt")

---