In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torch.nn.functional as F


import random
import math
import time
import pandas as pd
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from torch.utils.data import DataLoader

In [2]:
SEED = 1234

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

In [5]:
train = pd.read_csv('/Users/hesu/Documents/KT/riiid/train_1M.csv',
                   usecols = [1,2,3,4,7,8,9],
                   dtype={'timestamp':'int64',
                         'used_id':'int32',
                         'content_id':'int16',
                         'content_type_id':'int8',
                         'answered_correctly':'int8',
                         'prior_question_elapsed_time':'float32',
                         'prior_question_had_explanation':'boolean'})

train = train[train.content_type_id == False]

train = train.sort_values(['timestamp'],ascending=True).reset_index(drop=True)
train.head(10)

Unnamed: 0,user_id,content_id_seq,answered_correctly
0,115,"5693,5717,129,7861,7923,157,52,51,7897,7864,13...","1,1,1,1,1,1,1,1,1,1,0,0,0,1,1,1,1,0,1,1,0,0,0,..."
1,124,"7901,21400,176,14802,2066,15588,15587,16888,16...",101010000010010000010000
2,2746,"18797,14282,19500,237,13928,383,406,874,532,14...",0001011110110
3,5382,"997,902,262,447,380,829,452,6439,18863,5413,91...","1,1,1,1,1,1,1,1,0,1,1,0,1,0,1,1,1,0,0,1,0,1,0,..."
4,8623,"13585,19,10596,21488,7914,195,23929,7975,131,7...","0,1,1,0,1,1,0,1,1,1,0,1,1,0,1,1,0,0,0,0,1,0,1,..."
5,8701,"3902,6672,4964,19667,8280,17488,17526,14278,11...",11101000100
6,12741,"24224,10699,10700,16152,2627,2628,24250,10726,...","0,1,1,0,1,1,0,1,0,0,1,1,0,1,0,1,1,0,1,1,1,1,1,..."
7,13134,"3116,3350,3349,3351,13063,13064,13062,3178,167...","1,1,1,1,1,1,1,1,0,0,1,1,1,0,1,1,1,1,1,1,1,1,0,..."
8,24418,"5045,8247,5736,17300,3777,6136,5620,4128,9128,...","1,1,1,0,1,1,1,1,1,1,1,0,1,0,1,1,1,1,1,1,0,1,1,..."
9,24600,"15587,2066,15588,16888,16887,16889,16471,16472...","0,1,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1,..."


In [4]:
question = pd.read_csv('/Users/hesu/Documents/KT/riiid/questions.csv')
question.head(10)

Unnamed: 0,question_id,bundle_id,correct_answer,part,tags
0,0,0,0,1,51 131 162 38
1,1,1,1,1,131 36 81
2,2,2,0,1,131 101 162 92
3,3,3,0,1,131 149 162 29
4,4,4,3,1,131 5 162 38
5,5,5,2,1,131 149 162 81
6,6,6,2,1,10 94 162 92
7,7,7,0,1,61 110 162 29
8,8,8,3,1,131 13 162 92
9,9,9,3,1,10 164 81


In [5]:
train_ques = pd.merge(train, question, left_on='content_id',right_on='question_id', how='left')
train_ques.drop('content_id',axis=1,inplace=True)
train_ques.head(10)

Unnamed: 0,timestamp,user_id,content_type_id,answered_correctly,prior_question_elapsed_time,prior_question_had_explanation,question_id,bundle_id,correct_answer,part,tags
0,0,115,0,1,,,5692,5692,3,5,151
1,0,7022747,0,0,,,5558,5558,1,5,125
2,0,7023662,0,1,,,4626,4626,2,5,79
3,0,7025965,0,1,,,7900,7900,0,1,131 93 81
4,0,7029547,0,1,,,4449,4449,0,5,156
5,0,579346,0,1,,,7900,7900,0,1,131 93 81
6,0,7039142,0,1,,,5458,5458,1,5,125
7,0,581706,0,1,,,4565,4565,0,5,8
8,0,7042700,0,1,,,7900,7900,0,1,131 93 81
9,0,20042606,0,0,,,7900,7900,0,1,131 93 81


In [6]:
train_ques.tail(10)

Unnamed: 0,timestamp,user_id,content_type_id,answered_correctly,prior_question_elapsed_time,prior_question_had_explanation,question_id,bundle_id,correct_answer,part,tags
980082,76809860397,4508124,0,1,35000.0,True,3016,3015,0,4,157 171 92
980083,76809860397,4508124,0,1,35000.0,True,3015,3015,2,4,136 171 92
980084,76810038254,4508124,0,1,32666.0,True,3068,3066,3,4,113 12 162 38
980085,76810038254,4508124,0,0,32666.0,True,3067,3066,0,4,74 12 162 38
980086,76810038254,4508124,0,1,32666.0,True,3066,3066,1,4,106 12 162 38
980087,78091996556,4508124,0,1,28666.0,True,7398,7396,2,7,97 160 16 35 122
980088,78091996556,4508124,0,0,28666.0,True,7399,7396,0,7,97 160 16 35 122
980089,78091996556,4508124,0,1,28666.0,True,7397,7396,1,7,18 160 16 35 122
980090,78091996556,4508124,0,0,28666.0,True,7396,7396,1,7,39 160 16 35 122
980091,78091996556,4508124,0,1,28666.0,True,7400,7396,1,7,145 160 16 35 122


In [7]:
elapsed_mean = train_ques.prior_question_elapsed_time.mean()

In [8]:
train_ques['prior_question_elapsed_time'].fillna(elapsed_mean, inplace=True)
train_ques['part'].fillna(4, inplace=True)

In [9]:
train_ques.loc[:,'prior_question_elapsed_time'].value_counts()

17000.0     50744
16000.0     46949
18000.0     46550
19000.0     39580
15000.0     35889
            ...  
135200.0        1
121750.0        1
119250.0        1
150200.0        1
99333.0         1
Name: prior_question_elapsed_time, Length: 1660, dtype: int64

In [10]:
train_ques.loc[:,'part'].value_counts()

5    403239
2    190731
6    108567
3     82175
4     75997
1     69411
7     49972
Name: part, dtype: int64

In [11]:
import datetime
import time
def convert_time_to_yearMonthDay(timeStamp):
    timeStamp = timeStamp /1000.0
    timearr = time.localtime(timeStamp)
    otherStyleTime = time.strftime("%Y-%m-%d %H:%M:%S", timearr)
    print(otherStyleTime)

convert_time_to_yearMonthDay(78091996556)

1972-06-23 04:13:16


In [12]:
def get_elapsed_time(ela):
    ela = ela // 1000
    if ela > 300:
        return 300
    else:
        return ela

In [13]:
train_ques['prior_question_elapsed_time'] = train_ques['prior_question_elapsed_time'].apply(lambda x: get_elapsed_time(x))

In [14]:
train_ques.head(10)

Unnamed: 0,timestamp,user_id,content_type_id,answered_correctly,prior_question_elapsed_time,prior_question_had_explanation,question_id,bundle_id,correct_answer,part,tags
0,0,115,0,1,25.0,,5692,5692,3,5,151
1,0,7022747,0,0,25.0,,5558,5558,1,5,125
2,0,7023662,0,1,25.0,,4626,4626,2,5,79
3,0,7025965,0,1,25.0,,7900,7900,0,1,131 93 81
4,0,7029547,0,1,25.0,,4449,4449,0,5,156
5,0,579346,0,1,25.0,,7900,7900,0,1,131 93 81
6,0,7039142,0,1,25.0,,5458,5458,1,5,125
7,0,581706,0,1,25.0,,4565,4565,0,5,8
8,0,7042700,0,1,25.0,,7900,7900,0,1,131 93 81
9,0,20042606,0,0,25.0,,7900,7900,0,1,131 93 81


In [15]:
train_ques['timestamp'] = train_ques['timestamp'].astype(str)
train_ques['question_id'] = train_ques['question_id'].astype(str)
train_ques['part'] = train_ques['part'].astype(str)
train_ques['prior_question_elapsed_time'] = train_ques['prior_question_elapsed_time'].astype(str)
train_ques['answered_correctly'] = train_ques['answered_correctly'].astype(str)


In [16]:
train_user = train_ques.groupby('user_id').agg({"question_id": ','.join, 
                                                "answered_correctly":','.join,
                                                "timestamp":','.join,
                                                "part":','.join,
                                                "prior_question_elapsed_time":','.join})

In [17]:
train_user.head(10)

Unnamed: 0_level_0,question_id,answered_correctly,timestamp,part,prior_question_elapsed_time
user_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
115,"5692,5716,128,7860,7922,156,51,50,7896,7863,15...","1,1,1,1,1,1,1,1,1,1,0,0,0,1,1,1,1,0,1,1,0,0,0,...","0,56943,118363,131167,137965,157063,176092,194...","5,5,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,...","25.0,37.0,55.0,19.0,11.0,5.0,17.0,17.0,16.0,16..."
124,"7900,7876,175,1278,2064,2065,2063,3363,3364,33...","1,0,1,0,0,1,0,0,0,0,1,0,0,0,1,0,0,0,0,1,0,0,0,...","0,32683,62000,83632,189483,189483,189483,25879...","1,1,1,2,3,3,3,4,4,4,4,4,4,4,4,4,5,5,5,5,5,5,6,...","25.0,26.0,29.0,26.0,18.0,18.0,18.0,33.0,33.0,3..."
2746,"5273,758,5976,236,404,382,405,873,531,775,294,...",0001011110110101101,"0,21592,49069,72254,91945,111621,134341,234605...",5252222222222222222,"25.0,28.0,17.0,24.0,20.0,16.0,16.0,19.0,18.0,1..."
5382,"5000,3944,217,5844,5965,4990,5235,6050,5721,55...","1,0,1,0,1,1,1,1,0,0,0,1,0,1,1,1,1,0,1,1,0,0,0,...","0,39828,132189,153727,169080,178049,274437,348...","5,5,2,5,5,5,5,5,5,5,5,1,1,1,1,1,1,1,1,1,1,1,1,...","25.0,24.0,35.0,88.0,18.0,12.0,5.0,92.0,70.0,14..."
8623,"3915,4750,6456,3968,6104,5738,6435,5498,6102,4...","1,1,1,1,1,1,1,0,0,1,1,0,0,1,1,0,1,1,1,0,0,1,1,...","0,38769,72859,116541,155537,189115,221413,2399...","5,5,5,5,5,5,5,5,5,5,5,2,2,2,2,2,5,5,5,5,5,5,5,...","25.0,16.0,33.0,30.0,40.0,35.0,30.0,29.0,15.0,1..."
8701,"3901,6671,4963,6143,8279,3964,4002,754,1110,77...",11101000100111011,"0,17833,45872,74561,121601,141679,183773,11482...",55555552222222222,"25.0,13.0,15.0,24.0,25.0,44.0,17.0,39.0,16.0,1..."
12741,"5145,9691,9697,5202,4787,5695,7858,5653,5889,4...","0,1,0,1,1,0,1,0,1,0,0,0,1,1,1,0,1,0,0,1,1,0,0,...","0,22273,54323,92046,109716,132679,158477,18403...","5,5,5,5,5,5,1,5,5,5,5,6,6,6,6,6,6,6,6,6,6,6,6,...","25.0,13.0,18.0,29.0,35.0,15.0,21.0,23.0,23.0,3..."
13134,"3926,564,3865,4231,3684,3988,3968,5219,4447,61...","1,0,0,1,1,0,0,1,1,0,1,1,1,0,1,1,0,1,0,1,0,1,1,...","0,23840,46834,64749,113000,183369,218217,29783...","5,2,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,2,2,2,2,2,2,...","25.0,22.0,18.0,19.0,13.0,43.0,65.0,31.0,5.0,17..."
24418,"7900,7876,175,1278,2063,2065,2064,3363,3364,33...","0,0,1,1,0,0,0,0,1,1,0,0,1,0,0,0,1,0,1,0,0,1,1,...","0,24224,51020,70540,88142,88142,88142,100241,1...","1,1,1,2,3,3,3,4,4,4,4,4,4,4,4,4,5,5,5,5,5,5,6,...","25.0,30.0,20.0,24.0,17.0,17.0,17.0,4.0,4.0,4.0..."
24600,"7900,7876,175,1278,2063,2065,2064,3365,3363,33...","1,0,1,1,0,1,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,...","0,25379,50137,70181,148601,148601,148601,21935...","1,1,1,2,3,3,3,4,4,4,4,4,4,4,4,4,5,5,5,5,5,5,6,...","25.0,24.0,23.0,22.0,18.0,18.0,18.0,24.0,24.0,2..."


In [18]:
train_user.shape

(3824, 5)

In [19]:
type(train_user)

pandas.core.frame.DataFrame

In [20]:
train_user

Unnamed: 0_level_0,question_id,answered_correctly,timestamp,part,prior_question_elapsed_time
user_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
115,"5692,5716,128,7860,7922,156,51,50,7896,7863,15...","1,1,1,1,1,1,1,1,1,1,0,0,0,1,1,1,1,0,1,1,0,0,0,...","0,56943,118363,131167,137965,157063,176092,194...","5,5,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,...","25.0,37.0,55.0,19.0,11.0,5.0,17.0,17.0,16.0,16..."
124,"7900,7876,175,1278,2064,2065,2063,3363,3364,33...","1,0,1,0,0,1,0,0,0,0,1,0,0,0,1,0,0,0,0,1,0,0,0,...","0,32683,62000,83632,189483,189483,189483,25879...","1,1,1,2,3,3,3,4,4,4,4,4,4,4,4,4,5,5,5,5,5,5,6,...","25.0,26.0,29.0,26.0,18.0,18.0,18.0,33.0,33.0,3..."
2746,"5273,758,5976,236,404,382,405,873,531,775,294,...",0001011110110101101,"0,21592,49069,72254,91945,111621,134341,234605...",5252222222222222222,"25.0,28.0,17.0,24.0,20.0,16.0,16.0,19.0,18.0,1..."
5382,"5000,3944,217,5844,5965,4990,5235,6050,5721,55...","1,0,1,0,1,1,1,1,0,0,0,1,0,1,1,1,1,0,1,1,0,0,0,...","0,39828,132189,153727,169080,178049,274437,348...","5,5,2,5,5,5,5,5,5,5,5,1,1,1,1,1,1,1,1,1,1,1,1,...","25.0,24.0,35.0,88.0,18.0,12.0,5.0,92.0,70.0,14..."
8623,"3915,4750,6456,3968,6104,5738,6435,5498,6102,4...","1,1,1,1,1,1,1,0,0,1,1,0,0,1,1,0,1,1,1,0,0,1,1,...","0,38769,72859,116541,155537,189115,221413,2399...","5,5,5,5,5,5,5,5,5,5,5,2,2,2,2,2,5,5,5,5,5,5,5,...","25.0,16.0,33.0,30.0,40.0,35.0,30.0,29.0,15.0,1..."
...,...,...,...,...,...
20913319,"6659,5675,3841,5299,5254,4706,5318,6051,174,78...","0,1,0,1,0,0,0,0,1,0,1,1,1,1,1,1,1,0,1,0,0,1,1,...","0,13518,35768,64516,86907,111406,130852,367471...","5,5,5,5,5,5,5,5,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,...","25.0,27.0,10.0,20.0,26.0,20.0,22.0,16.0,12.0,2..."
20913864,"4790,4422,9200,3644,9418,9805,10405,6659,6286,...",100100001110001000100,"0,29051,50530,60217,79747,98008,121888,158226,...",555555155225555555555,"25.0,9.0,21.0,18.0,6.0,16.0,15.0,21.0,33.0,40...."
20938253,"7900,7876,175,1278,2065,2063,2064,3365,3364,33...","0,1,1,1,1,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,1,...","0,4124,115985,130714,149045,149045,149045,1608...","1,1,1,2,3,3,3,4,4,4,4,4,4,4,4,4,5,5,5,5,5,5,6,...","25.0,2.0,1.0,4.0,3.0,3.0,3.0,3.0,3.0,3.0,1.0,1..."
20948951,"6040,6444,8933,8537,10471,9236,4707,9353,8969,...","0,1,1,0,1,0,1,0,0,1,0,1,0,0,0,0,1,0,1,1,1,0,1,...","0,24764,45950,71359,95527,120065,145390,172145...","5,5,5,5,1,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,...","25.0,9.0,18.0,18.0,23.0,21.0,22.0,23.0,24.0,24..."


In [21]:
train_user.reset_index(inplace=True)

In [22]:
train_user

Unnamed: 0,user_id,question_id,answered_correctly,timestamp,part,prior_question_elapsed_time
0,115,"5692,5716,128,7860,7922,156,51,50,7896,7863,15...","1,1,1,1,1,1,1,1,1,1,0,0,0,1,1,1,1,0,1,1,0,0,0,...","0,56943,118363,131167,137965,157063,176092,194...","5,5,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,...","25.0,37.0,55.0,19.0,11.0,5.0,17.0,17.0,16.0,16..."
1,124,"7900,7876,175,1278,2064,2065,2063,3363,3364,33...","1,0,1,0,0,1,0,0,0,0,1,0,0,0,1,0,0,0,0,1,0,0,0,...","0,32683,62000,83632,189483,189483,189483,25879...","1,1,1,2,3,3,3,4,4,4,4,4,4,4,4,4,5,5,5,5,5,5,6,...","25.0,26.0,29.0,26.0,18.0,18.0,18.0,33.0,33.0,3..."
2,2746,"5273,758,5976,236,404,382,405,873,531,775,294,...",0001011110110101101,"0,21592,49069,72254,91945,111621,134341,234605...",5252222222222222222,"25.0,28.0,17.0,24.0,20.0,16.0,16.0,19.0,18.0,1..."
3,5382,"5000,3944,217,5844,5965,4990,5235,6050,5721,55...","1,0,1,0,1,1,1,1,0,0,0,1,0,1,1,1,1,0,1,1,0,0,0,...","0,39828,132189,153727,169080,178049,274437,348...","5,5,2,5,5,5,5,5,5,5,5,1,1,1,1,1,1,1,1,1,1,1,1,...","25.0,24.0,35.0,88.0,18.0,12.0,5.0,92.0,70.0,14..."
4,8623,"3915,4750,6456,3968,6104,5738,6435,5498,6102,4...","1,1,1,1,1,1,1,0,0,1,1,0,0,1,1,0,1,1,1,0,0,1,1,...","0,38769,72859,116541,155537,189115,221413,2399...","5,5,5,5,5,5,5,5,5,5,5,2,2,2,2,2,5,5,5,5,5,5,5,...","25.0,16.0,33.0,30.0,40.0,35.0,30.0,29.0,15.0,1..."
...,...,...,...,...,...,...
3819,20913319,"6659,5675,3841,5299,5254,4706,5318,6051,174,78...","0,1,0,1,0,0,0,0,1,0,1,1,1,1,1,1,1,0,1,0,0,1,1,...","0,13518,35768,64516,86907,111406,130852,367471...","5,5,5,5,5,5,5,5,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,...","25.0,27.0,10.0,20.0,26.0,20.0,22.0,16.0,12.0,2..."
3820,20913864,"4790,4422,9200,3644,9418,9805,10405,6659,6286,...",100100001110001000100,"0,29051,50530,60217,79747,98008,121888,158226,...",555555155225555555555,"25.0,9.0,21.0,18.0,6.0,16.0,15.0,21.0,33.0,40...."
3821,20938253,"7900,7876,175,1278,2065,2063,2064,3365,3364,33...","0,1,1,1,1,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,1,...","0,4124,115985,130714,149045,149045,149045,1608...","1,1,1,2,3,3,3,4,4,4,4,4,4,4,4,4,5,5,5,5,5,5,6,...","25.0,2.0,1.0,4.0,3.0,3.0,3.0,3.0,3.0,3.0,1.0,1..."
3822,20948951,"6040,6444,8933,8537,10471,9236,4707,9353,8969,...","0,1,1,0,1,0,1,0,0,1,0,1,0,0,0,0,1,0,1,1,1,0,1,...","0,24764,45950,71359,95527,120065,145390,172145...","5,5,5,5,1,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,...","25.0,9.0,18.0,18.0,23.0,21.0,22.0,23.0,24.0,24..."


In [276]:
train_user = train_user.rename(columns={'question_id':'question_id_seq',
                            'answered_correctly':'answered_correctly_seq',
                             'timestamp':'timestamp_seq',
                             'part':'part_seq',
                             'prior_question_elapsed_time':'prior_question_elapsed_time_seq'})

In [277]:
train_user.head(10)

Unnamed: 0,user_id,question_id_seq,answered_correctly_seq,timestamp_seq,part_seq,prior_question_elapsed_time_seq
0,115,"5692,5716,128,7860,7922,156,51,50,7896,7863,15...","1,1,1,1,1,1,1,1,1,1,0,0,0,1,1,1,1,0,1,1,0,0,0,...","0,56943,118363,131167,137965,157063,176092,194...","5,5,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,...","25.0,37.0,55.0,19.0,11.0,5.0,17.0,17.0,16.0,16..."
1,124,"7900,7876,175,1278,2064,2065,2063,3363,3364,33...","1,0,1,0,0,1,0,0,0,0,1,0,0,0,1,0,0,0,0,1,0,0,0,...","0,32683,62000,83632,189483,189483,189483,25879...","1,1,1,2,3,3,3,4,4,4,4,4,4,4,4,4,5,5,5,5,5,5,6,...","25.0,26.0,29.0,26.0,18.0,18.0,18.0,33.0,33.0,3..."
2,2746,"5273,758,5976,236,404,382,405,873,531,775,294,...",0001011110110101101,"0,21592,49069,72254,91945,111621,134341,234605...",5252222222222222222,"25.0,28.0,17.0,24.0,20.0,16.0,16.0,19.0,18.0,1..."
3,5382,"5000,3944,217,5844,5965,4990,5235,6050,5721,55...","1,0,1,0,1,1,1,1,0,0,0,1,0,1,1,1,1,0,1,1,0,0,0,...","0,39828,132189,153727,169080,178049,274437,348...","5,5,2,5,5,5,5,5,5,5,5,1,1,1,1,1,1,1,1,1,1,1,1,...","25.0,24.0,35.0,88.0,18.0,12.0,5.0,92.0,70.0,14..."
4,8623,"3915,4750,6456,3968,6104,5738,6435,5498,6102,4...","1,1,1,1,1,1,1,0,0,1,1,0,0,1,1,0,1,1,1,0,0,1,1,...","0,38769,72859,116541,155537,189115,221413,2399...","5,5,5,5,5,5,5,5,5,5,5,2,2,2,2,2,5,5,5,5,5,5,5,...","25.0,16.0,33.0,30.0,40.0,35.0,30.0,29.0,15.0,1..."
5,8701,"3901,6671,4963,6143,8279,3964,4002,754,1110,77...",11101000100111011,"0,17833,45872,74561,121601,141679,183773,11482...",55555552222222222,"25.0,13.0,15.0,24.0,25.0,44.0,17.0,39.0,16.0,1..."
6,12741,"5145,9691,9697,5202,4787,5695,7858,5653,5889,4...","0,1,0,1,1,0,1,0,1,0,0,0,1,1,1,0,1,0,0,1,1,0,0,...","0,22273,54323,92046,109716,132679,158477,18403...","5,5,5,5,5,5,1,5,5,5,5,6,6,6,6,6,6,6,6,6,6,6,6,...","25.0,13.0,18.0,29.0,35.0,15.0,21.0,23.0,23.0,3..."
7,13134,"3926,564,3865,4231,3684,3988,3968,5219,4447,61...","1,0,0,1,1,0,0,1,1,0,1,1,1,0,1,1,0,1,0,1,0,1,1,...","0,23840,46834,64749,113000,183369,218217,29783...","5,2,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,2,2,2,2,2,2,...","25.0,22.0,18.0,19.0,13.0,43.0,65.0,31.0,5.0,17..."
8,24418,"7900,7876,175,1278,2063,2065,2064,3363,3364,33...","0,0,1,1,0,0,0,0,1,1,0,0,1,0,0,0,1,0,1,0,0,1,1,...","0,24224,51020,70540,88142,88142,88142,100241,1...","1,1,1,2,3,3,3,4,4,4,4,4,4,4,4,4,5,5,5,5,5,5,6,...","25.0,30.0,20.0,24.0,17.0,17.0,17.0,4.0,4.0,4.0..."
9,24600,"7900,7876,175,1278,2063,2065,2064,3365,3363,33...","1,0,1,1,0,1,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,...","0,25379,50137,70181,148601,148601,148601,21935...","1,1,1,2,3,3,3,4,4,4,4,4,4,4,4,4,5,5,5,5,5,5,6,...","25.0,24.0,23.0,22.0,18.0,18.0,18.0,24.0,24.0,2..."


In [266]:
def get_data_for_train_encode(train_user, seq_len):
    all_ques_seq = []
    all_ans_seq = []
    all_parts_seq = []
    all_ela_seq = []
    
    target_ques = []
    target_anss = []
    target_parts = []
    target_elas = []
    
    for row in train_user.itertuples():
        q_ids = getattr(row, 'question_id_seq').strip().split(',')
        ans_ids = getattr(row, 'answered_correctly_seq').strip().split(',')
        part_ids = getattr(row, 'part_seq').strip().split(',')
        ela_ids = getattr(row, 'prior_question_elapsed_time_seq').strip().split(',')
        
        assert len(q_ids) == len(ans_ids) == len(part_ids) == len(ela_ids)
        
        target_index = len(q_ids) - 1
        q_ids_seq = q_ids[:target_index+1]
        ans_ids_seq = ans_ids[:target_index+1]
        part_ids_seq = part_ids[:target_index+1]
        ela_ids_seq = ela_ids[:target_index+1]
        
        length = len(q_ids_seq)
        if length >= seq_len:
            q_ids_seq = q_ids_seq[-seq_len:]
            ans_ids_seq = ans_ids_seq[-seq_len:]
            part_ids_seq = part_ids_seq[-seq_len:]
            ela_ids_seq = ela_ids_seq[-seq_len:]  
                
            pad_counts = 0
        else:
            pad_counts = seq_len - length
            
        q_ids_seq = [int(float(e)) for e in q_ids_seq]
        ans_ids_seq = [int(float(e)) for e in ans_ids_seq]
        part_ids_seq = [int(float(e)) for e in part_ids_seq]
        ela_ids_seq = [int(float(e)) for e in ela_ids_seq]
            
        q_ids_seq = [13523]*pad_counts + q_ids_seq
        # question用13523表示padding位
        ans_ids_seq = [2]*pad_counts  + ans_ids_seq
        # ans用2表示padding位
        # ans因为是输入到decoder中，所以需要一个起始符号，这里选择3作为其实符号，也就是句子序列中的bos的作用
        part_ids_seq = [8]*pad_counts + part_ids_seq
        # part用8来表示padding位
        ela_ids_seq = [301]*pad_counts + ela_ids_seq
        # ela用301来表示padding位
#             print("q_ids length is:{}\n ans_ids length is:{}\n part length is:{}\n ela_ids length is:{}".format(len(q_ids_seq),len(ans_ids_seq),len(part_ids_seq),len(ela_ids_seq)))
        all_ques_seq.append(q_ids_seq)
        all_ans_seq.append(ans_ids_seq)
        all_parts_seq.append(part_ids_seq)
        all_ela_seq.append(ela_ids_seq)        
        
        target_ques.append([int(float(q_ids[-1]))])
        target_anss.append([int(float(ans_ids[-1]))])
        target_parts.append([int(float(part_ids[-1]))])
        target_elas.append([int(float(ela_ids[-1]))])


    return torch.LongTensor(all_ques_seq),\
        torch.LongTensor(all_ans_seq),\
        torch.LongTensor(all_parts_seq),\
        torch.LongTensor(all_ela_seq),\
        torch.LongTensor(target_ques),\
        torch.LongTensor(target_anss),\
        torch.LongTensor(target_parts),\
        torch.LongTensor(target_elas)
            
            

In [267]:
class Rii_dataset_train(Dataset):
    def __init__(self,train_user):
        self.df = train_user
        self.ques_seq, self.ans_seq, self.parts_seq, self.ela_seq,\
        self.trg_que, self.trg_ans, self.trg_part, self.trg_ela = get_data_for_train_encode(self.df, 100)
    def __len__(self):
        return len(self.ques_seq)
    def __getitem__(self, index):
        return self.ques_seq[index], self.ans_seq[index], self.parts_seq[index], self.ela_seq[index],\
        self.trg_que[index], self.trg_ans[index], self.trg_part[index], self.trg_ela[index]

In [278]:
test_df = pd.read_csv('/Users/hesu/Documents/KT/riiid/valid.csv')

In [279]:
test_df = test_df.loc[test_df['content_type_id'] == 0].reset_index(drop=True)
test_df['prior_question_elapsed_time'].fillna(elapsed_mean, inplace=True)
test_df['prior_question_elapsed_time'] = test_df['prior_question_elapsed_time'].apply(lambda x: get_elapsed_time(x))

In [280]:
test_df.head(10)

Unnamed: 0,row_id,timestamp,user_id,content_id,content_type_id,task_container_id,user_answer,answered_correctly,prior_question_elapsed_time,prior_question_had_explanation
0,10000,2868613211,91216,1219,0,780,1,1,17.0,True
1,10001,2868700426,91216,1172,0,781,0,1,17.0,True
2,10002,2868748313,91216,230,0,782,3,1,16.0,True
3,10003,2874335350,91216,6469,0,764,3,0,15.0,True
4,10004,2912644354,91216,5250,0,783,2,0,20.0,True
5,10005,2912756715,91216,8191,0,784,2,0,11.0,True
6,10006,2912855281,91216,5156,0,785,3,0,30.0,True
7,10007,2912982177,91216,3641,0,786,2,1,14.0,True
8,10008,2913096884,91216,4409,0,787,1,0,20.0,True
9,10010,2913266013,91216,4292,0,790,3,1,14.0,True


In [281]:
question.head(10)

Unnamed: 0,question_id,bundle_id,correct_answer,part,tags
0,0,0,0,1,51 131 162 38
1,1,1,1,1,131 36 81
2,2,2,0,1,131 101 162 92
3,3,3,0,1,131 149 162 29
4,4,4,3,1,131 5 162 38
5,5,5,2,1,131 149 162 81
6,6,6,2,1,10 94 162 92
7,7,7,0,1,61 110 162 29
8,8,8,3,1,131 13 162 92
9,9,9,3,1,10 164 81


In [282]:
test_df = pd.merge(test_df, question, left_on='content_id',right_on='question_id', how='left')


In [283]:
test_df.head(10)

Unnamed: 0,row_id,timestamp,user_id,content_id,content_type_id,task_container_id,user_answer,answered_correctly,prior_question_elapsed_time,prior_question_had_explanation,question_id,bundle_id,correct_answer,part,tags
0,10000,2868613211,91216,1219,0,780,1,1,17.0,True,1219,1219,1,2,155 119 92 102
1,10001,2868700426,91216,1172,0,781,0,1,17.0,True,1172,1172,0,2,155 163 81 29
2,10002,2868748313,91216,230,0,782,3,1,16.0,True,230,230,3,2,143 176 29 102
3,10003,2874335350,91216,6469,0,764,3,0,15.0,True,6469,6469,0,5,64
4,10004,2912644354,91216,5250,0,783,2,0,20.0,True,5250,5250,3,5,170
5,10005,2912756715,91216,8191,0,784,2,0,11.0,True,8191,8191,0,5,1
6,10006,2912855281,91216,5156,0,785,3,0,30.0,True,5156,5156,1,5,108
7,10007,2912982177,91216,3641,0,786,2,1,14.0,True,3641,3641,2,5,180
8,10008,2913096884,91216,4409,0,787,1,0,20.0,True,4409,4409,2,5,168
9,10010,2913266013,91216,4292,0,790,3,1,14.0,True,4292,4292,3,5,168


In [284]:
test_df = pd.merge(test_df, train_user, on='user_id',how='left')


In [285]:
test_df.head(10)

Unnamed: 0,row_id,timestamp,user_id,content_id,content_type_id,task_container_id,user_answer,answered_correctly,prior_question_elapsed_time,prior_question_had_explanation,question_id,bundle_id,correct_answer,part,tags,question_id_seq,answered_correctly_seq,timestamp_seq,part_seq,prior_question_elapsed_time_seq
0,10000,2868613211,91216,1219,0,780,1,1,17.0,True,1219,1219,1,2,155 119 92 102,"7900,7876,175,1278,2065,2064,2063,3364,3363,33...","1,0,0,1,1,1,0,1,0,0,0,0,0,0,1,1,0,0,1,0,0,1,0,...","0,23540,48745,67665,161655,161655,161655,26479...","1,1,1,2,3,3,3,4,4,4,4,4,4,4,4,4,5,5,5,5,5,5,6,...","25.0,20.0,20.0,23.0,17.0,17.0,17.0,29.0,29.0,2..."
1,10001,2868700426,91216,1172,0,781,0,1,17.0,True,1172,1172,0,2,155 163 81 29,"7900,7876,175,1278,2065,2064,2063,3364,3363,33...","1,0,0,1,1,1,0,1,0,0,0,0,0,0,1,1,0,0,1,0,0,1,0,...","0,23540,48745,67665,161655,161655,161655,26479...","1,1,1,2,3,3,3,4,4,4,4,4,4,4,4,4,5,5,5,5,5,5,6,...","25.0,20.0,20.0,23.0,17.0,17.0,17.0,29.0,29.0,2..."
2,10002,2868748313,91216,230,0,782,3,1,16.0,True,230,230,3,2,143 176 29 102,"7900,7876,175,1278,2065,2064,2063,3364,3363,33...","1,0,0,1,1,1,0,1,0,0,0,0,0,0,1,1,0,0,1,0,0,1,0,...","0,23540,48745,67665,161655,161655,161655,26479...","1,1,1,2,3,3,3,4,4,4,4,4,4,4,4,4,5,5,5,5,5,5,6,...","25.0,20.0,20.0,23.0,17.0,17.0,17.0,29.0,29.0,2..."
3,10003,2874335350,91216,6469,0,764,3,0,15.0,True,6469,6469,0,5,64,"7900,7876,175,1278,2065,2064,2063,3364,3363,33...","1,0,0,1,1,1,0,1,0,0,0,0,0,0,1,1,0,0,1,0,0,1,0,...","0,23540,48745,67665,161655,161655,161655,26479...","1,1,1,2,3,3,3,4,4,4,4,4,4,4,4,4,5,5,5,5,5,5,6,...","25.0,20.0,20.0,23.0,17.0,17.0,17.0,29.0,29.0,2..."
4,10004,2912644354,91216,5250,0,783,2,0,20.0,True,5250,5250,3,5,170,"7900,7876,175,1278,2065,2064,2063,3364,3363,33...","1,0,0,1,1,1,0,1,0,0,0,0,0,0,1,1,0,0,1,0,0,1,0,...","0,23540,48745,67665,161655,161655,161655,26479...","1,1,1,2,3,3,3,4,4,4,4,4,4,4,4,4,5,5,5,5,5,5,6,...","25.0,20.0,20.0,23.0,17.0,17.0,17.0,29.0,29.0,2..."
5,10005,2912756715,91216,8191,0,784,2,0,11.0,True,8191,8191,0,5,1,"7900,7876,175,1278,2065,2064,2063,3364,3363,33...","1,0,0,1,1,1,0,1,0,0,0,0,0,0,1,1,0,0,1,0,0,1,0,...","0,23540,48745,67665,161655,161655,161655,26479...","1,1,1,2,3,3,3,4,4,4,4,4,4,4,4,4,5,5,5,5,5,5,6,...","25.0,20.0,20.0,23.0,17.0,17.0,17.0,29.0,29.0,2..."
6,10006,2912855281,91216,5156,0,785,3,0,30.0,True,5156,5156,1,5,108,"7900,7876,175,1278,2065,2064,2063,3364,3363,33...","1,0,0,1,1,1,0,1,0,0,0,0,0,0,1,1,0,0,1,0,0,1,0,...","0,23540,48745,67665,161655,161655,161655,26479...","1,1,1,2,3,3,3,4,4,4,4,4,4,4,4,4,5,5,5,5,5,5,6,...","25.0,20.0,20.0,23.0,17.0,17.0,17.0,29.0,29.0,2..."
7,10007,2912982177,91216,3641,0,786,2,1,14.0,True,3641,3641,2,5,180,"7900,7876,175,1278,2065,2064,2063,3364,3363,33...","1,0,0,1,1,1,0,1,0,0,0,0,0,0,1,1,0,0,1,0,0,1,0,...","0,23540,48745,67665,161655,161655,161655,26479...","1,1,1,2,3,3,3,4,4,4,4,4,4,4,4,4,5,5,5,5,5,5,6,...","25.0,20.0,20.0,23.0,17.0,17.0,17.0,29.0,29.0,2..."
8,10008,2913096884,91216,4409,0,787,1,0,20.0,True,4409,4409,2,5,168,"7900,7876,175,1278,2065,2064,2063,3364,3363,33...","1,0,0,1,1,1,0,1,0,0,0,0,0,0,1,1,0,0,1,0,0,1,0,...","0,23540,48745,67665,161655,161655,161655,26479...","1,1,1,2,3,3,3,4,4,4,4,4,4,4,4,4,5,5,5,5,5,5,6,...","25.0,20.0,20.0,23.0,17.0,17.0,17.0,29.0,29.0,2..."
9,10010,2913266013,91216,4292,0,790,3,1,14.0,True,4292,4292,3,5,168,"7900,7876,175,1278,2065,2064,2063,3364,3363,33...","1,0,0,1,1,1,0,1,0,0,0,0,0,0,1,1,0,0,1,0,0,1,0,...","0,23540,48745,67665,161655,161655,161655,26479...","1,1,1,2,3,3,3,4,4,4,4,4,4,4,4,4,5,5,5,5,5,5,6,...","25.0,20.0,20.0,23.0,17.0,17.0,17.0,29.0,29.0,2..."


In [290]:
test_df.dtypes

row_id                               int64
timestamp                            int64
user_id                              int64
content_id                           int64
content_type_id                      int64
task_container_id                    int64
user_answer                          int64
answered_correctly                   int64
prior_question_elapsed_time        float64
prior_question_had_explanation      object
question_id                          int64
bundle_id                            int64
correct_answer                       int64
part                                 int64
tags                                object
question_id_seq                     object
answered_correctly_seq              object
timestamp_seq                       object
part_seq                            object
prior_question_elapsed_time_seq     object
dtype: object

In [288]:
def pad_np(nums, pad_index):
    Args.seq_size = 100
    
    if nums.size == 0:
        return np.array([0]*Args.seq_size)

    if nums.size > Args.seq_size:
        nums = nums[-Args.seq_size:]
    else:
        pad_counts = Args.seq_size - len(nums)
        nums = np.pad(nums,(pad_counts,0),'constant',constant_values=(pad_index,0))
        # (pad_counts, 0 )表示在左边填充pad_counts个数字，右边填充0个数字;
        # constant_values=(0,0)表示左边填充0， 右边也填充0
    return nums



def pad_seq(df):
    df['content_id'] = np.array(df['content_id'])

    
#     df['question_id_seq'] = df['question_id_seq'].apply(lambda x: np.array(x.split(',')).astype(np.int16))
    df['question_id_seq'] = df['question_id_seq'].apply(lambda x: np.array(x.split(',')).astype(np.int16))
    df['question_id_seq_input'] = df.apply(lambda x: pad_np(x.question_id_seq, 13523), axis=1)
    
    df['answered_correctly_seq'] = df['answered_correctly_seq'].apply(lambda x: np.array(x.split(',')).astype(np.int16))
    df['answered_correctly_input'] = df.apply(lambda x: pad_np(x.answered_correctly_seq, 2), axis=1)

    df['part_seq'] = df['part_seq'].apply(lambda x: np.array(x.split(',')).astype(np.int16))
    df['part_seq_input'] = df.apply(lambda x: pad_np(x.part_seq, 8), axis=1)
    
    df['prior_question_elapsed_time_seq'] = df['prior_question_elapsed_time_seq'].apply(lambda x: np.array(x.split(',')).astype(np.int16))
    df['prior_question_elapsed_time_seq_input'] = df.apply(lambda x: pad_np(x.prior_question_elapsed_time_seq, 301), axis=1)
    
    return df
    

In [289]:
test_df = pad_seq(test_df)

AttributeError: 'float' object has no attribute 'split'

## Model

In [255]:
class Encoder(nn.Module):
    def __init__(self, 
                 que_num,
                 part_num,
                 ela_num,
                 hid_dim, 
                 n_layers, 
                 n_heads, 
                 pf_dim,
                 dropout, 
                 device,
                 max_length = 100):
        super().__init__()

        self.device = device
        
        self.que_embedding = nn.Embedding(que_num, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        self.part_embedding = nn.Embedding(part_num, hid_dim)
        self.ela_embedding = nn.Embedding(ela_num, hid_dim)
        self.ans_embedding = nn.Embedding(3, hid_dim)
        
        self.layers = nn.ModuleList([EncoderLayer(hid_dim, 
                                                  n_heads, 
                                                  pf_dim,
                                                  dropout, 
                                                  device) 
                                     for _ in range(n_layers)])
        
        self.dropout = nn.Dropout(dropout)
        self.output_layer = nn.Linear(hid_dim, 2)
        self.trg_linear = nn.Linear(hid_dim, hid_dim)
        
        self.output_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.avgpool = nn.AvgPool1d(max_length)
        
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)
        
        
    def forward(self, src_que,src_ans,src_part,src_ela,src_mask, trg_que, trg_part, trg_ela,trg_src_mask):
        
        #src = [batch size, src len]
        #src_mask = [batch size, src len]
        
        batch_size = src_que.shape[0]
        src_len = src_que.shape[1]
        
        pos = torch.arange(0, src_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)
        # pos的维度是[batch_size, src_len]，其中每个一维的都是都是[1,100]，
        # 其中unsqueeze(0)的作用是将tensor由[seq_len]维度变成[batch_size, seq_len]维
        
        que_emb = self.que_embedding(src_que)
        part_emb = self.part_embedding(src_part)
        ela_emb = self.ela_embedding(src_ela)
        ans_emb = self.ans_embedding(src_ans)
        tok_emb = que_emb+part_emb+ela_emb+ans_emb
        
        src = self.dropout((tok_emb * self.scale) + self.pos_embedding(pos))
        
        #src = [batch size, src len, hid dim]
        
        for layer in self.layers:
            src = layer(src, src_mask)
            
        encoder_output = src

        trg_que_emb = self.que_embedding(trg_que)
        trg_part_emb = self.part_embedding(trg_part)
        trg_ela_emb = self.ela_embedding(trg_ela)
        
        trg_emb = trg_que_emb+trg_part_emb+trg_ela_emb
        trg_linear = self.trg_linear(trg_emb)
        
        print("encoder_output shape is:{}\ntrg_linear shape is:{}".format(encoder_output.shape, trg_linear.shape))
        attention_output, _ = self.output_attention(trg_linear, encoder_output, encoder_output, trg_src_mask)
#         print("src shape:{}\ntrg_que_emb shape:{}\ntrg_part_emb shape:{}\n".format(src.shape, trg_que_emb.shape, trg_part_emb.shape))
        #src = [batch size, src len, hid dim]
        print("attention output shape is:{}".format(attention_output.shape))
        
        output_pool = self.avgpool(attention_output.permute(0,2,1)).permute(0,2,1)
        
        output = self.output_layer(output_pool)
        
        return output
        

In [227]:
class EncoderLayer(nn.Module):
    def __init__(self, 
                 hid_dim, 
                 n_heads, 
                 pf_dim,  
                 dropout, 
                 device):
        super().__init__()
        
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, 
                                                                     pf_dim, 
                                                                     dropout)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, src, src_mask):
        
        #src = [batch size, src len, hid dim]
        #src_mask = [batch size, src len]
                
        #self attention
#         print("In encoder Q shape is:{}\t K shape is:{}\t V shape is:{}\t mask shape is:{}".format(src.shape,\
#                                                                                                    src.shape,src.shape,src_mask.shape))
        _src, _ = self.self_attention(src, src, src, src_mask)
        
        #dropout, residual connection and layer norm
        src = self.self_attn_layer_norm(src + self.dropout(_src))
        
        #src = [batch size, src len, hid dim]
        
        #positionwise feedforward
        _src = self.positionwise_feedforward(src)
        
        #dropout, residual and layer norm
        src = self.ff_layer_norm(src + self.dropout(_src))
        
        #src = [batch size, src len, hid dim]
        
        return src

In [228]:
class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()
        
        assert hid_dim % n_heads == 0
        
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        
        self.dropout = nn.Dropout(dropout)
        
        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)
        
    def forward(self, query, key, value, mask = None):
        
        batch_size = query.shape[0]
        
        #query = [batch size, query len, hid dim]
        #key = [batch size, key len, hid dim]
        #value = [batch size, value len, hid dim]
                
        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)
        
        #Q = [batch size, query len, hid dim]
        #K = [batch size, key len, hid dim]
        #V = [batch size, value len, hid dim]
                
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        
        
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
#         print("Q shape is:{}\t K shape is:{}\t V shape is:{}\t energy shape is:{}\tmask shapeis:{}".format(Q.shape,\
#                                                                                 K.shape, V.shape, energy.shape,mask.shape))
        #energy = [batch size, n heads, query len, key len]
        
        if mask is not None:
            energy = energy.masked_fill(mask == 0, -1e10)
        
        attention = torch.softmax(energy, dim = -1)
                
        #attention = [batch size, n heads, query len, key len]
                
        x = torch.matmul(self.dropout(attention), V)
        
        #x = [batch size, n heads, query len, head dim]
        
        x = x.permute(0, 2, 1, 3).contiguous()
        
        #x = [batch size, query len, n heads, head dim]
        
        x = x.view(batch_size, -1, self.hid_dim)
        
        #x = [batch size, query len, hid dim]
        
        x = self.fc_o(x)
        
        #x = [batch size, query len, hid dim]
        
        return x, attention        

In [229]:
class PositionwiseFeedforwardLayer(nn.Module):
    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()
        
        self.fc_1 = nn.Linear(hid_dim, pf_dim)
        self.fc_2 = nn.Linear(pf_dim, hid_dim)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        
        #x = [batch size, seq len, hid dim]
        
        x = self.dropout(torch.relu(self.fc_1(x)))
        
        #x = [batch size, seq len, pf dim]
        
        x = self.fc_2(x)
        
        #x = [batch size, seq len, hid dim]
        
        return x

# Seq2Seq

In [230]:
class Seq2Seq(nn.Module):
    def __init__(self, 
                 encoder, 
                 src_pad_idx, 
                 device):
        super().__init__()
        
        self.encoder = encoder
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = src_pad_idx
        self.device = device
        
    def make_trg_src_mask(self, src):
        # 这个是trg和src中的每一个计算attention分布时用的mask
        #src = [batch size, src len]
        
        trg_src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)

        #src_mask = [batch size, 1, 1, src len]

        return trg_src_mask
    
    def make_src_mask(self, src):
        # 这个是encoder部分，只能看见当前que前面que信息的mask矩阵，上三角mask矩阵
        #src = [batch size, trg len]
        
        src_pad_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        
        #src_pad_mask = [batch size, 1, 1, src len]
        
        src_len = src.shape[1]
        
        src_sub_mask = torch.tril(torch.ones((src_len, src_len), device = self.device)).bool()
        
        #src_sub_mask = [src len, src len]
            
        src_mask = src_pad_mask & src_sub_mask
        
        #src_mask = [batch size, 1, src len, src len]
        
        return src_mask

    def forward(self, src_que,src_ans,src_part,src_ela,trg_que, trg_part,trg_ela):
        
        #src = [batch size, src len]
        #trg = [batch size, trg len]
            
        src_mask = self.make_src_mask(src_que)
        trg_src_mask = self.make_src_mask(src_que)
        
        #src_mask = [batch size, 1, 1, src len]
        #trg_src_mask = [batch size, 1, trg len, trg len]
        
        output = self.encoder(src_que,src_ans,src_part,src_ela,src_mask, trg_que, trg_part, trg_ela,trg_src_mask)
        
        #enc_src = [batch size, src len, hid dim]
                
        
        #output = [batch size, trg len, output dim]
        #attention = [batch size, n heads, trg len, src len]
        print("The model output shape is:{}".format(output.shape))
        return output       

In [231]:
que_num = 13524
ans_num = 3
part_num = 9
ela_num = 302

HID_DIM = 256
ENC_LAYERS = 3
ENC_HEADS = 8
ENC_PF_DIM = 512
ENC_DROPOUT = 0.1



device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

enc = Encoder(que_num,part_num,ela_num,
              HID_DIM, 
              ENC_LAYERS, 
              ENC_HEADS, 
              ENC_PF_DIM, 
              ENC_DROPOUT, 
              device)

In [232]:
src_pad_que_idx = 13523
trg_pad_ans_idx = 2

In [233]:
model = Seq2Seq(enc, src_pad_que_idx, device).to(device)

In [234]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 5,478,914 trainable parameters


In [235]:
def initialize_weights(m):
    if hasattr(m, 'weight') and m.weight.dim() > 1:
        nn.init.xavier_uniform_(m.weight.data)

In [236]:
model.apply(initialize_weights)

Seq2Seq(
  (encoder): Encoder(
    (que_embedding): Embedding(13524, 256)
    (pos_embedding): Embedding(100, 256)
    (part_embedding): Embedding(9, 256)
    (ela_embedding): Embedding(302, 256)
    (ans_embedding): Embedding(3, 256)
    (layers): ModuleList(
      (0): EncoderLayer(
        (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (ff_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (self_attention): MultiHeadAttentionLayer(
          (fc_q): Linear(in_features=256, out_features=256, bias=True)
          (fc_k): Linear(in_features=256, out_features=256, bias=True)
          (fc_v): Linear(in_features=256, out_features=256, bias=True)
          (fc_o): Linear(in_features=256, out_features=256, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (positionwise_feedforward): PositionwiseFeedforwardLayer(
          (fc_1): Linear(in_features=256, out_features=512, bias=True)
          (fc_2)

In [237]:
LEARNING_RATE = 5e-4

optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)

In [238]:
criterion = nn.CrossEntropyLoss(ignore_index = trg_pad_ans_idx)


## Train

In [244]:
def train(model, iterator, optimizer, criterion, clip):
    
    model.train()
    
    epoch_loss = 0
    total_num = 0
    right_num = 0
    
    for i, batch in tqdm(enumerate(iterator)):
        
        src_que, src_ans, src_part, src_ela, trg_que, trg_ans, trg_part, trg_ela = batch
        
        optimizer.zero_grad()
        print("src_que is:{}\nsrc_ans is:{}\ntrg_que is:{}\ntrg_ans is:{}".format(src_que.shape, src_ans.shape,trg_que.shape,trg_ans.shape))
        
        output = model(src_que, src_ans, src_part, src_ela, trg_que, trg_part, trg_ela)
        # 由于decoder预测时是错位预测，也就是用trg[t-1]去预测trg[t]，所以输入到decoder模型中的trg缺少最后一个样本的结果 
        
        #output = [batch size, trg len - 1, output dim]
        #trg = [batch size, trg len]
        
        output_dim = output.shape[-1]
            
        output = output.contiguous().view(-1, output_dim)
        trg_ans = trg_ans.contiguous().view(-1)
        # contiguous()用于判定tensor是否是连续的
        
        #output = [batch size * trg len - 1, output dim]
        #trg = [batch size * trg len - 1]
            
        loss = criterion(output, trg_ans)
        
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
                
        optimizer.step()
        
        epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)

In [245]:
def evaluate(model, iterator, criterion):
    
    model.eval()
    
    epoch_loss = 0
    
    with torch.no_grad():
    
        for i, batch in enumerate(iterator):

            src = batch.src
            trg = batch.trg

            output, _ = model(src, trg[:,:-1])
            
            #output = [batch size, trg len - 1, output dim]
            #trg = [batch size, trg len]
            
            output_dim = output.shape[-1]
            
            output = output.contiguous().view(-1, output_dim)
            trg = trg[:,1:].contiguous().view(-1)
            print("output")
            #output = [batch size * trg len - 1, output dim]
            #trg = [batch size * trg len - 1]
            
            loss = criterion(output, trg)

            epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)

In [246]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [247]:

train_dataset = Rii_dataset_train(train_user.head(100))
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)

In [248]:
N_EPOCHS = 10
CLIP = 1

best_valid_loss = float('inf')


for epoch in range(N_EPOCHS):
    start_time = time.time()
    
    train_loss = train(model, train_dataloader, optimizer, criterion, CLIP)
    end_time = time.time()
    
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f}')

1it [00:00,  6.17it/s]

src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])
src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])


3it [00:00,  6.27it/s]

src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])
src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])


5it [00:00,  6.56it/s]

src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])
src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])


7it [00:01,  6.72it/s]

src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])
src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])


9it [00:01,  6.56it/s]

src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])
src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])


11it [00:01,  6.60it/s]

src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])
src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])


13it [00:01,  6.46it/s]

src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])
src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])


15it [00:02,  6.39it/s]

src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])
src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])


17it [00:02,  6.41it/s]

src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])
src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])


19it [00:02,  6.44it/s]

src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])
src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])


21it [00:03,  6.43it/s]

src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])
src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])


23it [00:03,  6.41it/s]

src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])
src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])


25it [00:03,  6.41it/s]

src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])
src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])


27it [00:04,  6.37it/s]

src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])
src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])


29it [00:04,  6.41it/s]

src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])
src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])


31it [00:04,  6.38it/s]

src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])
src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])


33it [00:05,  6.33it/s]

src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])
src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])


35it [00:05,  6.39it/s]

src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])
src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])


37it [00:05,  6.46it/s]

src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])
src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])


39it [00:06,  6.48it/s]

src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])
src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])


41it [00:06,  6.50it/s]

src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])
src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])


43it [00:06,  6.39it/s]

src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])
src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])


45it [00:06,  6.43it/s]

src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])
src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])


47it [00:07,  6.48it/s]

src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])
src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])


49it [00:07,  6.45it/s]

src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])
src_que is:torch.Size([2, 100])
src_ans is:torch.Size([2, 100])
trg_que is:torch.Size([2, 1])
trg_ans is:torch.Size([2, 1])
encoder_output shape is:torch.Size([2, 100, 256])
trg_linear shape is:torch.Size([2, 1, 256])
attention output shape is:torch.Size([2, 100, 256])
The model output shape is:torch.Size([2, 1, 2])


50it [00:07,  6.45it/s]


NameError: name 'epoch_mins' is not defined