In [68]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

import random
import math
import time
import pandas as pd
import torch
from torch.utils.data import Dataset

In [4]:
# Seed every random number generator used in this notebook with one shared
# value so that a restarted kernel reproduces the same results.
SEED = 1234

for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_fn(SEED)
del seed_fn  # keep the notebook namespace free of the loop variable

# Ask cuDNN to use deterministic kernels.
torch.backends.cudnn.deterministic = True

In [5]:
# Load the interaction log, reading only the columns used downstream and
# giving each one a compact dtype.
# NOTE(review): absolute local path; consider a configurable data directory.
train = pd.read_csv('/Users/hesu/Documents/KT/riiid/train_1M.csv',
                   usecols = [1,2,3,4,7,8,9],
                   dtype={'timestamp':'int64',
                         # The key was misspelled 'used_id', so this dtype was never applied.
                         'user_id':'int32',
                         'content_id':'int16',
                         'content_type_id':'int8',
                         'answered_correctly':'int8',
                         'prior_question_elapsed_time':'float32',
                         'prior_question_had_explanation':'boolean'})

# Keep only rows whose content_type_id is 0 (the column is an int8 flag).
train = train[train.content_type_id == 0]

# Order all interactions chronologically and renumber the rows.
train = train.sort_values(['timestamp'],ascending=True).reset_index(drop=True)
train.head(10)

Unnamed: 0,timestamp,user_id,content_id,content_type_id,answered_correctly,prior_question_elapsed_time,prior_question_had_explanation
0,0,115,5692,0,1,,
1,0,7022747,5558,0,0,,
2,0,7023662,4626,0,1,,
3,0,7025965,7900,0,1,,
4,0,7029547,4449,0,1,,
5,0,579346,7900,0,1,,
6,0,7039142,5458,0,1,,
7,0,581706,4565,0,1,,
8,0,7042700,7900,0,1,,
9,0,20042606,7900,0,0,,


In [7]:
# Load the question metadata table (one row per question_id) and preview it.
question = pd.read_csv('/Users/hesu/Documents/KT/riiid/questions.csv')
question.head(10)

Unnamed: 0,question_id,bundle_id,correct_answer,part,tags
0,0,0,0,1,51 131 162 38
1,1,1,1,1,131 36 81
2,2,2,0,1,131 101 162 92
3,3,3,0,1,131 149 162 29
4,4,4,3,1,131 5 162 38
5,5,5,2,1,131 149 162 81
6,6,6,2,1,10 94 162 92
7,7,7,0,1,61 110 162 29
8,8,8,3,1,131 13 162 92
9,9,9,3,1,10 164 81


In [13]:
# Attach question metadata to every interaction (left join keeps all
# interactions), then drop content_id, which duplicates question_id.
train_ques = train.merge(question, how='left', left_on='content_id', right_on='question_id')
train_ques = train_ques.drop(columns='content_id')
train_ques.head(10)

Unnamed: 0,timestamp,user_id,content_type_id,answered_correctly,prior_question_elapsed_time,prior_question_had_explanation,question_id,bundle_id,correct_answer,part,tags
0,0,115,0,1,,,5692,5692,3,5,151
1,0,7022747,0,0,,,5558,5558,1,5,125
2,0,7023662,0,1,,,4626,4626,2,5,79
3,0,7025965,0,1,,,7900,7900,0,1,131 93 81
4,0,7029547,0,1,,,4449,4449,0,5,156
5,0,579346,0,1,,,7900,7900,0,1,131 93 81
6,0,7039142,0,1,,,5458,5458,1,5,125
7,0,581706,0,1,,,4565,4565,0,5,8
8,0,7042700,0,1,,,7900,7900,0,1,131 93 81
9,0,20042606,0,0,,,7900,7900,0,1,131 93 81


In [16]:
train_ques.tail(10)

Unnamed: 0,timestamp,user_id,content_type_id,answered_correctly,prior_question_elapsed_time,prior_question_had_explanation,question_id,bundle_id,correct_answer,part,tags
980082,76809860397,4508124,0,1,35000.0,True,3016,3015,0,4,157 171 92
980083,76809860397,4508124,0,1,35000.0,True,3015,3015,2,4,136 171 92
980084,76810038254,4508124,0,1,32666.0,True,3068,3066,3,4,113 12 162 38
980085,76810038254,4508124,0,0,32666.0,True,3067,3066,0,4,74 12 162 38
980086,76810038254,4508124,0,1,32666.0,True,3066,3066,1,4,106 12 162 38
980087,78091996556,4508124,0,1,28666.0,True,7398,7396,2,7,97 160 16 35 122
980088,78091996556,4508124,0,0,28666.0,True,7399,7396,0,7,97 160 16 35 122
980089,78091996556,4508124,0,1,28666.0,True,7397,7396,1,7,18 160 16 35 122
980090,78091996556,4508124,0,0,28666.0,True,7396,7396,1,7,39 160 16 35 122
980091,78091996556,4508124,0,1,28666.0,True,7400,7396,1,7,145 160 16 35 122


In [15]:
# Mean of the elapsed-time column, computed before any filling; it is used
# below as the fill value for missing entries.
elapsed_mean = train_ques.prior_question_elapsed_time.mean()

In [17]:
# Fill missing values by assigning the result back to the column. Calling
# fillna(..., inplace=True) on a column selection mutates a temporary object
# and is not guaranteed to update the frame under pandas copy-on-write.
train_ques['prior_question_elapsed_time'] = train_ques['prior_question_elapsed_time'].fillna(elapsed_mean)
# Rows without a matching question get part 4 as a fallback.
train_ques['part'] = train_ques['part'].fillna(4)

In [19]:
train_ques.loc[:,'prior_question_elapsed_time'].value_counts()

17000.0     50744
16000.0     46949
18000.0     46550
19000.0     39580
15000.0     35889
            ...  
135200.0        1
121750.0        1
119250.0        1
150200.0        1
99333.0         1
Name: prior_question_elapsed_time, Length: 1660, dtype: int64

In [40]:
train_ques.loc[:,'part'].value_counts()

5    403239
2    190731
6    108567
3     82175
4     75997
1     69411
7     49972
Name: part, dtype: int64

In [23]:
import datetime
import time
def convert_time_to_yearMonthDay(timeStamp):
    """Print a millisecond timestamp as a ``YYYY-MM-DD HH:MM:SS`` string.

    The value is divided by 1000 to obtain seconds and interpreted with
    ``time.localtime``, so the printed text depends on the machine's time
    zone. Nothing is returned.
    """
    seconds = timeStamp / 1000.0
    print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(seconds)))

convert_time_to_yearMonthDay(78091996556)

1972-06-23 04:13:16


In [24]:
def get_elapsed_time(ela):
    """Floor-divide an elapsed time by 1000 and cap the result at 300.

    Returns the floored quotient unchanged when it is at most 300, otherwise
    the integer 300.
    """
    seconds = ela // 1000
    return 300 if seconds > 300 else seconds

In [26]:
# Replace each elapsed time with its capped whole-second bucket.
train_ques['prior_question_elapsed_time'] = train_ques['prior_question_elapsed_time'].apply(get_elapsed_time)

In [29]:
train_ques.head(10)

Unnamed: 0,timestamp,user_id,content_type_id,answered_correctly,prior_question_elapsed_time,prior_question_had_explanation,question_id,bundle_id,correct_answer,part,tags
0,0,115,0,1,25.0,,5692,5692,3,5,151
1,0,7022747,0,0,25.0,,5558,5558,1,5,125
2,0,7023662,0,1,25.0,,4626,4626,2,5,79
3,0,7025965,0,1,25.0,,7900,7900,0,1,131 93 81
4,0,7029547,0,1,25.0,,4449,4449,0,5,156
5,0,579346,0,1,25.0,,7900,7900,0,1,131 93 81
6,0,7039142,0,1,25.0,,5458,5458,1,5,125
7,0,581706,0,1,25.0,,4565,4565,0,5,8
8,0,7042700,0,1,25.0,,7900,7900,0,1,131 93 81
9,0,20042606,0,0,25.0,,7900,7900,0,1,131 93 81


In [33]:
# Cast the per-interaction fields to strings so that they can be
# comma-joined into one sequence per user.
train_ques = train_ques.astype({
    column: str
    for column in ('timestamp', 'question_id', 'part',
                   'prior_question_elapsed_time', 'answered_correctly')
})


In [None]:
# This cell originally held the incomplete expression `train_ques[]`, which
# is a SyntaxError. It now shows the dtypes of the columns that were cast to
# str in the previous cell.
train_ques[['timestamp', 'question_id', 'part',
            'prior_question_elapsed_time', 'answered_correctly']].dtypes

In [57]:
# Collapse each user's interactions into a single row: every field becomes
# one comma-separated string, in the frame's existing row order.
train_user = train_ques.groupby('user_id').agg({
    field: ','.join
    for field in ("question_id", "answered_correctly", "timestamp",
                  "part", "prior_question_elapsed_time")
})

In [58]:
train_user.head(10)

Unnamed: 0_level_0,question_id,answered_correctly,timestamp,part,prior_question_elapsed_time
user_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
115,"5692,5716,128,7860,7922,156,51,50,7896,7863,15...","1,1,1,1,1,1,1,1,1,1,0,0,0,1,1,1,1,0,1,1,0,0,0,...","0,56943,118363,131167,137965,157063,176092,194...","5,5,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,...","25.0,37.0,55.0,19.0,11.0,5.0,17.0,17.0,16.0,16..."
124,"7900,7876,175,1278,2064,2065,2063,3363,3364,33...","1,0,1,0,0,1,0,0,0,0,1,0,0,0,1,0,0,0,0,1,0,0,0,...","0,32683,62000,83632,189483,189483,189483,25879...","1,1,1,2,3,3,3,4,4,4,4,4,4,4,4,4,5,5,5,5,5,5,6,...","25.0,26.0,29.0,26.0,18.0,18.0,18.0,33.0,33.0,3..."
2746,"5273,758,5976,236,404,382,405,873,531,775,294,...",0001011110110101101,"0,21592,49069,72254,91945,111621,134341,234605...",5252222222222222222,"25.0,28.0,17.0,24.0,20.0,16.0,16.0,19.0,18.0,1..."
5382,"5000,3944,217,5844,5965,4990,5235,6050,5721,55...","1,0,1,0,1,1,1,1,0,0,0,1,0,1,1,1,1,0,1,1,0,0,0,...","0,39828,132189,153727,169080,178049,274437,348...","5,5,2,5,5,5,5,5,5,5,5,1,1,1,1,1,1,1,1,1,1,1,1,...","25.0,24.0,35.0,88.0,18.0,12.0,5.0,92.0,70.0,14..."
8623,"3915,4750,6456,3968,6104,5738,6435,5498,6102,4...","1,1,1,1,1,1,1,0,0,1,1,0,0,1,1,0,1,1,1,0,0,1,1,...","0,38769,72859,116541,155537,189115,221413,2399...","5,5,5,5,5,5,5,5,5,5,5,2,2,2,2,2,5,5,5,5,5,5,5,...","25.0,16.0,33.0,30.0,40.0,35.0,30.0,29.0,15.0,1..."
8701,"3901,6671,4963,6143,8279,3964,4002,754,1110,77...",11101000100111011,"0,17833,45872,74561,121601,141679,183773,11482...",55555552222222222,"25.0,13.0,15.0,24.0,25.0,44.0,17.0,39.0,16.0,1..."
12741,"5145,9691,9697,5202,4787,5695,7858,5653,5889,4...","0,1,0,1,1,0,1,0,1,0,0,0,1,1,1,0,1,0,0,1,1,0,0,...","0,22273,54323,92046,109716,132679,158477,18403...","5,5,5,5,5,5,1,5,5,5,5,6,6,6,6,6,6,6,6,6,6,6,6,...","25.0,13.0,18.0,29.0,35.0,15.0,21.0,23.0,23.0,3..."
13134,"3926,564,3865,4231,3684,3988,3968,5219,4447,61...","1,0,0,1,1,0,0,1,1,0,1,1,1,0,1,1,0,1,0,1,0,1,1,...","0,23840,46834,64749,113000,183369,218217,29783...","5,2,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,2,2,2,2,2,2,...","25.0,22.0,18.0,19.0,13.0,43.0,65.0,31.0,5.0,17..."
24418,"7900,7876,175,1278,2063,2065,2064,3363,3364,33...","0,0,1,1,0,0,0,0,1,1,0,0,1,0,0,0,1,0,1,0,0,1,1,...","0,24224,51020,70540,88142,88142,88142,100241,1...","1,1,1,2,3,3,3,4,4,4,4,4,4,4,4,4,5,5,5,5,5,5,6,...","25.0,30.0,20.0,24.0,17.0,17.0,17.0,4.0,4.0,4.0..."
24600,"7900,7876,175,1278,2063,2065,2064,3365,3363,33...","1,0,1,1,0,1,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,...","0,25379,50137,70181,148601,148601,148601,21935...","1,1,1,2,3,3,3,4,4,4,4,4,4,4,4,4,5,5,5,5,5,5,6,...","25.0,24.0,23.0,22.0,18.0,18.0,18.0,24.0,24.0,2..."


In [59]:
train_user.shape

(3824, 5)

In [60]:
type(train_user)

pandas.core.frame.DataFrame

In [61]:
train_user

Unnamed: 0_level_0,question_id,answered_correctly,timestamp,part,prior_question_elapsed_time
user_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
115,"5692,5716,128,7860,7922,156,51,50,7896,7863,15...","1,1,1,1,1,1,1,1,1,1,0,0,0,1,1,1,1,0,1,1,0,0,0,...","0,56943,118363,131167,137965,157063,176092,194...","5,5,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,...","25.0,37.0,55.0,19.0,11.0,5.0,17.0,17.0,16.0,16..."
124,"7900,7876,175,1278,2064,2065,2063,3363,3364,33...","1,0,1,0,0,1,0,0,0,0,1,0,0,0,1,0,0,0,0,1,0,0,0,...","0,32683,62000,83632,189483,189483,189483,25879...","1,1,1,2,3,3,3,4,4,4,4,4,4,4,4,4,5,5,5,5,5,5,6,...","25.0,26.0,29.0,26.0,18.0,18.0,18.0,33.0,33.0,3..."
2746,"5273,758,5976,236,404,382,405,873,531,775,294,...",0001011110110101101,"0,21592,49069,72254,91945,111621,134341,234605...",5252222222222222222,"25.0,28.0,17.0,24.0,20.0,16.0,16.0,19.0,18.0,1..."
5382,"5000,3944,217,5844,5965,4990,5235,6050,5721,55...","1,0,1,0,1,1,1,1,0,0,0,1,0,1,1,1,1,0,1,1,0,0,0,...","0,39828,132189,153727,169080,178049,274437,348...","5,5,2,5,5,5,5,5,5,5,5,1,1,1,1,1,1,1,1,1,1,1,1,...","25.0,24.0,35.0,88.0,18.0,12.0,5.0,92.0,70.0,14..."
8623,"3915,4750,6456,3968,6104,5738,6435,5498,6102,4...","1,1,1,1,1,1,1,0,0,1,1,0,0,1,1,0,1,1,1,0,0,1,1,...","0,38769,72859,116541,155537,189115,221413,2399...","5,5,5,5,5,5,5,5,5,5,5,2,2,2,2,2,5,5,5,5,5,5,5,...","25.0,16.0,33.0,30.0,40.0,35.0,30.0,29.0,15.0,1..."
...,...,...,...,...,...
20913319,"6659,5675,3841,5299,5254,4706,5318,6051,174,78...","0,1,0,1,0,0,0,0,1,0,1,1,1,1,1,1,1,0,1,0,0,1,1,...","0,13518,35768,64516,86907,111406,130852,367471...","5,5,5,5,5,5,5,5,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,...","25.0,27.0,10.0,20.0,26.0,20.0,22.0,16.0,12.0,2..."
20913864,"4790,4422,9200,3644,9418,9805,10405,6659,6286,...",100100001110001000100,"0,29051,50530,60217,79747,98008,121888,158226,...",555555155225555555555,"25.0,9.0,21.0,18.0,6.0,16.0,15.0,21.0,33.0,40...."
20938253,"7900,7876,175,1278,2065,2063,2064,3365,3364,33...","0,1,1,1,1,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,1,...","0,4124,115985,130714,149045,149045,149045,1608...","1,1,1,2,3,3,3,4,4,4,4,4,4,4,4,4,5,5,5,5,5,5,6,...","25.0,2.0,1.0,4.0,3.0,3.0,3.0,3.0,3.0,3.0,1.0,1..."
20948951,"6040,6444,8933,8537,10471,9236,4707,9353,8969,...","0,1,1,0,1,0,1,0,0,1,0,1,0,0,0,0,1,0,1,1,1,0,1,...","0,24764,45950,71359,95527,120065,145390,172145...","5,5,5,5,1,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,...","25.0,9.0,18.0,18.0,23.0,21.0,22.0,23.0,24.0,24..."


In [62]:
# Move user_id from the index back into a regular column. This mutates
# train_user in place, so the cell is not idempotent when re-run.
train_user.reset_index(inplace=True)

In [63]:
train_user

Unnamed: 0,user_id,question_id,answered_correctly,timestamp,part,prior_question_elapsed_time
0,115,"5692,5716,128,7860,7922,156,51,50,7896,7863,15...","1,1,1,1,1,1,1,1,1,1,0,0,0,1,1,1,1,0,1,1,0,0,0,...","0,56943,118363,131167,137965,157063,176092,194...","5,5,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,...","25.0,37.0,55.0,19.0,11.0,5.0,17.0,17.0,16.0,16..."
1,124,"7900,7876,175,1278,2064,2065,2063,3363,3364,33...","1,0,1,0,0,1,0,0,0,0,1,0,0,0,1,0,0,0,0,1,0,0,0,...","0,32683,62000,83632,189483,189483,189483,25879...","1,1,1,2,3,3,3,4,4,4,4,4,4,4,4,4,5,5,5,5,5,5,6,...","25.0,26.0,29.0,26.0,18.0,18.0,18.0,33.0,33.0,3..."
2,2746,"5273,758,5976,236,404,382,405,873,531,775,294,...",0001011110110101101,"0,21592,49069,72254,91945,111621,134341,234605...",5252222222222222222,"25.0,28.0,17.0,24.0,20.0,16.0,16.0,19.0,18.0,1..."
3,5382,"5000,3944,217,5844,5965,4990,5235,6050,5721,55...","1,0,1,0,1,1,1,1,0,0,0,1,0,1,1,1,1,0,1,1,0,0,0,...","0,39828,132189,153727,169080,178049,274437,348...","5,5,2,5,5,5,5,5,5,5,5,1,1,1,1,1,1,1,1,1,1,1,1,...","25.0,24.0,35.0,88.0,18.0,12.0,5.0,92.0,70.0,14..."
4,8623,"3915,4750,6456,3968,6104,5738,6435,5498,6102,4...","1,1,1,1,1,1,1,0,0,1,1,0,0,1,1,0,1,1,1,0,0,1,1,...","0,38769,72859,116541,155537,189115,221413,2399...","5,5,5,5,5,5,5,5,5,5,5,2,2,2,2,2,5,5,5,5,5,5,5,...","25.0,16.0,33.0,30.0,40.0,35.0,30.0,29.0,15.0,1..."
...,...,...,...,...,...,...
3819,20913319,"6659,5675,3841,5299,5254,4706,5318,6051,174,78...","0,1,0,1,0,0,0,0,1,0,1,1,1,1,1,1,1,0,1,0,0,1,1,...","0,13518,35768,64516,86907,111406,130852,367471...","5,5,5,5,5,5,5,5,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,...","25.0,27.0,10.0,20.0,26.0,20.0,22.0,16.0,12.0,2..."
3820,20913864,"4790,4422,9200,3644,9418,9805,10405,6659,6286,...",100100001110001000100,"0,29051,50530,60217,79747,98008,121888,158226,...",555555155225555555555,"25.0,9.0,21.0,18.0,6.0,16.0,15.0,21.0,33.0,40...."
3821,20938253,"7900,7876,175,1278,2065,2063,2064,3365,3364,33...","0,1,1,1,1,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,1,...","0,4124,115985,130714,149045,149045,149045,1608...","1,1,1,2,3,3,3,4,4,4,4,4,4,4,4,4,5,5,5,5,5,5,6,...","25.0,2.0,1.0,4.0,3.0,3.0,3.0,3.0,3.0,3.0,1.0,1..."
3822,20948951,"6040,6444,8933,8537,10471,9236,4707,9353,8969,...","0,1,1,0,1,0,1,0,0,1,0,1,0,0,0,0,1,0,1,1,1,0,1,...","0,24764,45950,71359,95527,120065,145390,172145...","5,5,5,5,1,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,...","25.0,9.0,18.0,18.0,23.0,21.0,22.0,23.0,24.0,24..."


In [71]:
def get_data_for_train(train_user, seq_len,
                       que_pad=13523, ans_pad=2, part_pad=8, ela_pad=301):
    """Build one fixed-length, left-padded training window per interaction.

    For every user row and every position ``t`` in that user's history, the
    window holds the last ``seq_len`` items ending at ``t`` (inclusive). A
    window shorter than ``seq_len`` is padded on the left.

    Args:
        train_user: DataFrame whose rows carry the comma-separated string
            columns ``question_id``, ``answered_correctly``, ``part`` and
            ``prior_question_elapsed_time``.
        seq_len: Length of every output window.
        que_pad: Padding id for questions (default 13523).
        ans_pad: Padding id for answers (default 2).
        part_pad: Padding id for parts (default 8).
        ela_pad: Padding id for elapsed-time buckets (default 301).

    Returns:
        Tuple ``(questions, answers, parts, elapsed)`` of int64 tensors, each
        with one row per interaction and ``seq_len`` columns.

    Raises:
        AssertionError: If the four sequences of a user differ in length.
        ValueError: If a token cannot be parsed as a number.
    """
    def _parse(field):
        # Tokens may be rendered as floats ("25.0"), hence int(float(...)).
        return [int(float(token)) for token in field.strip().split(',')]

    all_ques_seq = []
    all_ans_seq = []
    all_parts_seq = []
    all_ela_seq = []

    for row in train_user.itertuples():
        # Parse each user's sequences once, rather than once per window.
        q_ids = _parse(row.question_id)
        ans_ids = _parse(row.answered_correctly)
        part_ids = _parse(row.part)
        ela_ids = _parse(row.prior_question_elapsed_time)

        assert len(q_ids) == len(ans_ids) == len(part_ids) == len(ela_ids)

        for end in range(1, len(q_ids) + 1):
            # Slice only the last seq_len items ending at this position.
            start = max(0, end - seq_len)
            pad_counts = seq_len - (end - start)

            all_ques_seq.append([que_pad] * pad_counts + q_ids[start:end])
            all_ans_seq.append([ans_pad] * pad_counts + ans_ids[start:end])
            all_parts_seq.append([part_pad] * pad_counts + part_ids[start:end])
            all_ela_seq.append([ela_pad] * pad_counts + ela_ids[start:end])

    return (torch.tensor(all_ques_seq, dtype=torch.long),
            torch.tensor(all_ans_seq, dtype=torch.long),
            torch.tensor(all_parts_seq, dtype=torch.long),
            torch.tensor(all_ela_seq, dtype=torch.long))
            
            

In [75]:
def get_decode_mask(trg, trg_pad_idx):
    """Build a combined padding + causal attention mask.

    Args:
        trg: Integer tensor of token ids, shape [batch size, trg len].
        trg_pad_idx: Token id that marks padding.

    Returns:
        Bool tensor of shape [batch size, 1, trg len, trg len]. Entry
        ``[b, 0, i, j]`` is True when position ``j`` is not padding and
        ``j <= i``.
    """
    # [batch size, 1, 1, trg len]: True where the key position is a real token.
    trg_pad_mask = (trg != trg_pad_idx).unsqueeze(1).unsqueeze(2)

    trg_len = trg.shape[1]
    # Build the triangular mask on the input's device so that `&` also works
    # for tensors that are not on the CPU.
    trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device = trg.device)).bool()
    trg_mask = trg_pad_mask & trg_sub_mask

    return trg_mask

In [76]:
a = torch.tensor([[0,0,1,1,1],[0,1,1,1,1],[1,1,1,1,1]])

In [77]:
trg_mask = get_decode_mask(a, 0)

In [78]:
trg_mask

tensor([[[[False, False, False, False, False],
          [False, False, False, False, False],
          [False, False,  True, False, False],
          [False, False,  True,  True, False],
          [False, False,  True,  True,  True]]],


        [[[False, False, False, False, False],
          [False,  True, False, False, False],
          [False,  True,  True, False, False],
          [False,  True,  True,  True, False],
          [False,  True,  True,  True,  True]]],


        [[[ True, False, False, False, False],
          [ True,  True, False, False, False],
          [ True,  True,  True, False, False],
          [ True,  True,  True,  True, False],
          [ True,  True,  True,  True,  True]]]])

In [79]:
a

tensor([[0, 0, 1, 1, 1],
        [0, 1, 1, 1, 1],
        [1, 1, 1, 1, 1]])

In [80]:
def get_encode_mask(src,src_pad_idx):
    """Return a bool mask that is True wherever ``src`` is not ``src_pad_idx``.

    Two singleton axes are inserted at positions 1 and 2, so a
    [batch size, src len] input yields [batch size, 1, 1, src len].
    """
    is_token = src.ne(src_pad_idx)
    return torch.unsqueeze(torch.unsqueeze(is_token, 1), 2)

In [81]:
src_mask = get_encode_mask(a,0)

In [82]:
src_mask

tensor([[[[False, False,  True,  True,  True]]],


        [[[False,  True,  True,  True,  True]]],


        [[[ True,  True,  True,  True,  True]]]])

In [83]:
src_mask.shape

torch.Size([3, 1, 1, 5])

In [84]:
class Rii_dataset_train(Dataset):
    """Dataset of fixed-length training windows built from the per-user frame.

    All windows are produced once in ``__init__`` by ``get_data_for_train``
    and kept as four parallel tensors. An item is the tuple
    ``(questions, answers, parts, elapsed)`` for one window.

    Args:
        train_user: Per-user frame passed on to ``get_data_for_train``.
        seq_len: Window length. Defaults to 100, the value that was
            previously hard-coded.
    """
    def __init__(self, train_user, seq_len=100):
        self.df = train_user
        self.seq_len = seq_len
        (self.ques_seq,
         self.ans_seq,
         self.parts_seq,
         self.ela_seq) = get_data_for_train(self.df, seq_len)

    def __len__(self):
        """Number of windows."""
        return len(self.ques_seq)

    def __getitem__(self, index):
        """Return the four aligned sequences for window ``index``."""
        return self.ques_seq[index], self.ans_seq[index], self.parts_seq[index], self.ela_seq[index]

## Model

In [85]:
class Encoder(nn.Module):
    """Transformer encoder over question, part and elapsed-time ids.

    The three id sequences are embedded separately and summed, scaled by
    ``sqrt(hid_dim)``, combined with a learned positional embedding, and
    passed through ``n_layers`` ``EncoderLayer`` blocks.

    Args:
        que_num: Size of the question vocabulary.
        part_num: Size of the part vocabulary.
        ela_num: Size of the elapsed-time vocabulary.
        hid_dim: Embedding / hidden width.
        n_layers: Number of encoder layers.
        n_heads: Attention heads per layer.
        pf_dim: Inner width of the position-wise feed-forward block.
        dropout: Dropout probability.
        device: Stored as ``self.device`` and forwarded to the layers.
            Position indices and ``scale`` follow the input's and the
            module's device respectively.
        max_length: Number of positions in the positional embedding; inputs
            must not be longer than this.
    """
    def __init__(self, 
                 que_num,
                 part_num,
                 ela_num,
                 hid_dim, 
                 n_layers, 
                 n_heads, 
                 pf_dim,
                 dropout, 
                 device,
                 max_length = 100):
        super().__init__()

        self.device = device
        
        self.que_embedding = nn.Embedding(que_num, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        self.part_embedding = nn.Embedding(part_num, hid_dim)
        self.ela_embedding = nn.Embedding(ela_num, hid_dim)
        
        self.layers = nn.ModuleList([EncoderLayer(hid_dim, 
                                                  n_heads, 
                                                  pf_dim,
                                                  dropout, 
                                                  device) 
                                     for _ in range(n_layers)])
        
        self.dropout = nn.Dropout(dropout)
        
        # A non-persistent buffer moves with the module on .to()/.cuda()
        # while leaving the state_dict keys unchanged.
        self.register_buffer('scale',
                             torch.sqrt(torch.FloatTensor([hid_dim])),
                             persistent=False)
        
    def forward(self, src_que,src_part,src_ela,src_mask):
        """Encode a batch of aligned id sequences.

        Args:
            src_que, src_part, src_ela: Id tensors, each [batch size, src len].
            src_mask: Attention mask forwarded unchanged to every layer.

        Returns:
            Tensor of shape [batch size, src len, hid dim].
        """
        batch_size = src_que.shape[0]
        src_len = src_que.shape[1]
        
        # Create the position ids on the same device as the input.
        pos = torch.arange(0, src_len, device=src_que.device).unsqueeze(0).repeat(batch_size, 1)
        
        #pos = [batch size, src len]
        
        que_emb = self.que_embedding(src_que)
        part_emb = self.part_embedding(src_part)
        ela_emb = self.ela_embedding(src_ela)
        tok_emb = que_emb+part_emb+ela_emb
        
        src = self.dropout((tok_emb * self.scale) + self.pos_embedding(pos))
        
        #src = [batch size, src len, hid dim]
        
        for layer in self.layers:
            src = layer(src, src_mask)
            
        #src = [batch size, src len, hid dim]
            
        return src
        

In [86]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention, then a position-wise feed-forward.

    Each sub-block is followed by dropout, a residual connection and a
    LayerNorm applied to the sum.
    """
    def __init__(self, hid_dim, n_heads, pf_dim, dropout, device):
        super().__init__()

        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(
            hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask):
        """Apply the block.

        Args:
            src: Tensor of shape [batch size, src len, hid dim].
            src_mask: Mask handed to the self-attention layer.

        Returns:
            Tensor of shape [batch size, src len, hid dim].
        """
        # Self-attention sub-block; the attention weights are not needed here.
        attended, _ = self.self_attention(src, src, src, src_mask)
        src = self.self_attn_layer_norm(src + self.dropout(attended))

        # Feed-forward sub-block.
        transformed = self.positionwise_feedforward(src)
        src = self.ff_layer_norm(src + self.dropout(transformed))

        return src

In [87]:
class MultiHeadAttentionLayer(nn.Module):
    """Scaled dot-product attention with ``n_heads`` parallel heads.

    Args:
        hid_dim: Model width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
        device: Accepted for interface compatibility; ``scale`` is a buffer
            that follows the module's device.
    """
    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()
        
        assert hid_dim % n_heads == 0
        
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        
        self.dropout = nn.Dropout(dropout)
        
        # A non-persistent buffer moves with the module on .to()/.cuda()
        # while leaving the state_dict keys unchanged.
        self.register_buffer('scale',
                             torch.sqrt(torch.FloatTensor([self.head_dim])),
                             persistent=False)
        
    def forward(self, query, key, value, mask = None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: [batch size, query len, hid dim]
            key: [batch size, key len, hid dim]
            value: [batch size, value len, hid dim]
            mask: Optional tensor broadcastable to
                [batch size, n heads, query len, key len]; positions where it
                equals 0 receive an energy of -1e10 before the softmax.

        Returns:
            Tuple ``(x, attention)`` with ``x`` of shape
            [batch size, query len, hid dim] and ``attention`` of shape
            [batch size, n heads, query len, key len].
        """
        batch_size = query.shape[0]
                
        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)
        
        #Q = [batch size, query len, hid dim]
        #K = [batch size, key len, hid dim]
        #V = [batch size, value len, hid dim]
                
        # Split the hidden dimension into heads: [batch size, n heads, len, head dim].
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        
        #energy = [batch size, n heads, query len, key len]
        
        if mask is not None:
            energy = energy.masked_fill(mask == 0, -1e10)
        
        attention = torch.softmax(energy, dim = -1)
                
        #attention = [batch size, n heads, query len, key len]
                
        x = torch.matmul(self.dropout(attention), V)
        
        #x = [batch size, n heads, query len, head dim]
        
        # Merge the heads back: [batch size, query len, hid dim].
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.hid_dim)
        
        x = self.fc_o(x)
        
        #x = [batch size, query len, hid dim]
        
        return x, attention

In [88]:
class PositionwiseFeedforwardLayer(nn.Module):
    """Two-layer feed-forward network applied to every position independently.

    Computes ``fc_2(dropout(relu(fc_1(x))))``, expanding from ``hid_dim`` to
    ``pf_dim`` and projecting back to ``hid_dim``.
    """
    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()

        self.fc_1 = nn.Linear(hid_dim, pf_dim)
        self.fc_2 = nn.Linear(pf_dim, hid_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map [batch size, seq len, hid dim] to a tensor of the same shape."""
        # [batch size, seq len, pf dim]
        hidden = self.fc_1(x).relu()
        hidden = self.dropout(hidden)

        # Back to [batch size, seq len, hid dim].
        return self.fc_2(hidden)

In [89]:
class Decoder(nn.Module):
    """Transformer decoder over a single id sequence.

    Token ids are embedded, scaled by ``sqrt(hid_dim)``, combined with a
    learned positional embedding, passed through ``n_layers``
    ``DecoderLayer`` blocks, and projected to ``output_dim`` logits.

    Args:
        output_dim: Size of the token vocabulary and of the output logits.
        hid_dim: Embedding / hidden width.
        n_layers: Number of decoder layers.
        n_heads: Attention heads per layer.
        pf_dim: Inner width of the position-wise feed-forward block.
        dropout: Dropout probability.
        device: Stored as ``self.device`` and forwarded to the layers.
            Position indices and ``scale`` follow the input's and the
            module's device respectively.
        max_length: Number of positions in the positional embedding; inputs
            must not be longer than this.
    """
    def __init__(self, 
                 output_dim, 
                 hid_dim, 
                 n_layers, 
                 n_heads, 
                 pf_dim, 
                 dropout, 
                 device,
                 max_length = 100):
        super().__init__()
        
        self.device = device
        
        self.tok_embedding = nn.Embedding(output_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        
        self.layers = nn.ModuleList([DecoderLayer(hid_dim, 
                                                  n_heads, 
                                                  pf_dim, 
                                                  dropout, 
                                                  device)
                                     for _ in range(n_layers)])
        
        self.fc_out = nn.Linear(hid_dim, output_dim)
        
        self.dropout = nn.Dropout(dropout)
        
        # A non-persistent buffer moves with the module on .to()/.cuda()
        # while leaving the state_dict keys unchanged.
        self.register_buffer('scale',
                             torch.sqrt(torch.FloatTensor([hid_dim])),
                             persistent=False)
        
    def forward(self, trg, enc_src, trg_mask, src_mask):
        """Decode ``trg`` while attending to the encoder output.

        Args:
            trg: Id tensor, [batch size, trg len].
            enc_src: Encoder output, [batch size, src len, hid dim].
            trg_mask: Mask for the decoder self-attention.
            src_mask: Mask for the encoder-decoder attention.

        Returns:
            Tuple ``(output, attention)``. ``output`` has shape
            [batch size, trg len, output dim]. ``attention`` is the
            encoder-decoder attention of the last layer, or None when the
            decoder has no layers.
        """
        batch_size = trg.shape[0]
        trg_len = trg.shape[1]
        
        # Create the position ids on the same device as the input.
        pos = torch.arange(0, trg_len, device=trg.device).unsqueeze(0).repeat(batch_size, 1)
                            
        #pos = [batch size, trg len]
            
        trg = self.dropout((self.tok_embedding(trg) * self.scale) + self.pos_embedding(pos))
                
        #trg = [batch size, trg len, hid dim]
        
        # Defined up front so that a decoder with zero layers still returns a value.
        attention = None
        for layer in self.layers:
            trg, attention = layer(trg, enc_src, trg_mask, src_mask)
        
        #trg = [batch size, trg len, hid dim]
        #attention = [batch size, n heads, trg len, src len]
        
        output = self.fc_out(trg)
        
        #output = [batch size, trg len, output dim]
            
        return output, attention

In [90]:
class DecoderLayer(nn.Module):
    """One decoder block: self-attention, encoder attention, feed-forward.

    Each of the three sub-blocks is followed by dropout, a residual
    connection and a LayerNorm applied to the sum.
    """
    def __init__(self, hid_dim, n_heads, pf_dim, dropout, device):
        super().__init__()

        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.enc_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(
            hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        """Apply the block.

        Args:
            trg: Tensor of shape [batch size, trg len, hid dim].
            enc_src: Encoder output, [batch size, src len, hid dim].
            trg_mask: Mask handed to the self-attention layer.
            src_mask: Mask handed to the encoder-attention layer.

        Returns:
            Tuple ``(trg, attention)``: the updated tensor of shape
            [batch size, trg len, hid dim] and the encoder-attention weights
            of shape [batch size, n heads, trg len, src len].
        """
        # Self-attention over the target sequence; its weights are discarded.
        self_attended, _ = self.self_attention(trg, trg, trg, trg_mask)
        trg = self.self_attn_layer_norm(trg + self.dropout(self_attended))

        # Attention from the target to the encoder output.
        cross_attended, attention = self.encoder_attention(trg, enc_src, enc_src, src_mask)
        trg = self.enc_attn_layer_norm(trg + self.dropout(cross_attended))

        # Position-wise feed-forward.
        transformed = self.positionwise_feedforward(trg)
        trg = self.ff_layer_norm(trg + self.dropout(transformed))

        return trg, attention
        

# Seq2Seq

In [91]:
class Seq2Seq(nn.Module):
    """Tie an encoder and a decoder together and build their attention masks.

    Args:
        encoder: Module called as ``encoder(src_que, src_part, src_ela, src_mask)``.
        decoder: Module called as ``decoder(trg_ans, enc_src, trg_mask, src_mask)``.
        src_pad_idx: Padding id of the source (question) sequence.
        trg_pad_idx: Padding id of the target (answer) sequence.
        device: Stored as ``self.device``; masks are built on the device of
            the tensor they are derived from.
    """
    def __init__(self, 
                 encoder, 
                 decoder, 
                 src_pad_idx, 
                 trg_pad_idx, 
                 device):
        super().__init__()
        
        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def _make_causal_mask(self, seq, pad_idx):
        """Combine a padding mask for ``pad_idx`` with a lower-triangular mask.

        Returns a bool tensor of shape [batch size, 1, seq len, seq len].
        """
        #pad_mask = [batch size, 1, 1, seq len]
        pad_mask = (seq != pad_idx).unsqueeze(1).unsqueeze(2)

        seq_len = seq.shape[1]

        #sub_mask = [seq len, seq len]
        sub_mask = torch.tril(torch.ones((seq_len, seq_len), device = seq.device)).bool()

        return pad_mask & sub_mask
        
    def make_src_mask(self, src):
        """Padding-only source mask of shape [batch size, 1, 1, src len]."""
        #src = [batch size, src len]
        
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)

        return src_mask
    
    def make_trg_mask(self, trg):
        """Padding + causal target mask, [batch size, 1, trg len, trg len]."""
        #trg = [batch size, trg len]
        
        return self._make_causal_mask(trg, self.trg_pad_idx)

    def forward(self, src_que,src_part,src_ela, trg_ans):
        """Run the encoder and the decoder on one batch.

        Args:
            src_que, src_part, src_ela: Source id tensors, [batch size, src len].
            trg_ans: Target id tensor, [batch size, trg len].

        Returns:
            Tuple ``(output, attention)`` as returned by the decoder.
        """
        # The source also gets a causal mask, as before, but its padding is
        # now tested against src_pad_idx. Previously make_trg_mask was reused
        # here, which compared question ids with trg_pad_idx.
        # This mask is also used for encoder-decoder attention, so src len
        # and trg len must be equal.
        # NOTE(review): the causal mask includes the diagonal, so decoder
        # position t can attend to trg_ans[t]; confirm the caller shifts the
        # answers if trg_ans[t] is also the prediction target.
        src_mask = self._make_causal_mask(src_que, self.src_pad_idx)
        trg_mask = self.make_trg_mask(trg_ans)
        
        #src_mask = [batch size, 1, src len, src len]
        #trg_mask = [batch size, 1, trg len, trg len]
        
        enc_src = self.encoder(src_que,src_part,src_ela,src_mask)
        
        #enc_src = [batch size, src len, hid dim]
                
        output, attention = self.decoder(trg_ans, enc_src, trg_mask, src_mask)
        
        #output = [batch size, trg len, output dim]
        #attention = [batch size, n heads, trg len, src len]
        
        return output, attention

In [None]:
# `device` was used below without being defined anywhere in the notebook.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Vocabulary sizes. Each is one larger than the padding id that
# get_data_for_train uses for that field (13523, 2, 8 and 301).
que_num = 13524
ans_num = 3
part_num = 9
ela_num = 302

# Model hyper-parameters.
HID_DIM = 256
ENC_LAYERS = 3
DEC_LAYERS = 3
ENC_HEADS = 8
DEC_HEADS = 8
ENC_PF_DIM = 512
DEC_PF_DIM = 512
ENC_DROPOUT = 0.1
DEC_DROPOUT = 0.1



enc = Encoder(que_num,part_num,ela_num,
              HID_DIM, 
              ENC_LAYERS, 
              ENC_HEADS, 
              ENC_PF_DIM, 
              ENC_DROPOUT, 
              device)

dec = Decoder(ans_num, 
              HID_DIM, 
              DEC_LAYERS, 
              DEC_HEADS, 
              DEC_PF_DIM, 
              DEC_DROPOUT, 
              device)