In [9]:
# Hyper-parameters and file locations consumed by the data-loading, training
# and evaluation cells of this notebook.
params = dict(
    # --- optimisation ---
    max_iter=50,                # number of iterations
    show=True,                  # print progress
    init_std=0.1,               # weight initialization std
    init_lr=0.01,               # initial learning rate
    lr_decay=0.75,              # learning rate decay
    final_lr=1e-5,              # learning rate will not decrease after hitting this threshold
    momentum=0.9,               # momentum rate
    maxgradnorm=50.0,           # maximum gradient norm
    # --- model / batch dimensions ---
    final_fc_dim=50,            # hidden state dim for final fc layer
    key_embedding_dim=50,       # question embedding dimensions
    batch_size=64,              # the batch size
    value_embedding_dim=200,    # answer and question embedding dimensions
    memory_size=20,             # memory size
    # --- dataset ---
    n_question=123,             # the number of unique questions in the dataset
    seqlen=200,                 # the allowed maximum length of a sequence
    data_dir='../dkt',          # data directory
    data_name='',               # data set name (used as a file-name prefix)
    # --- checkpoints ---
    load='dkvmn.params',        # model file to load
    save='dkvmn.params',        # path to save model
)

# Derived entries: the working learning rate starts at `init_lr`, and each
# memory state dimension mirrors the corresponding embedding dimension.
params.update(
    lr=params['init_lr'],
    key_memory_state_dim=params['key_embedding_dim'],
    value_memory_state_dim=params['value_embedding_dim'],
)

In [10]:
from load_data import Data

# Loader configured from the question count and maximum sequence length held
# in `params`. NOTE(review): separate_char=',' is presumably the delimiter
# used inside each data line -- confirm against load_data.Data.
dat = Data(
    n_question=params['n_question'],
    seqlen=params['seqlen'],
    separate_char=',',
)

# Both splits share the prefix "<data_dir>/<data_name>" and differ only in
# the trailing "train.txt" / "test.txt".
data_prefix = params['data_dir'] + "/" + params['data_name']
train_data_path = data_prefix + "train.txt"
test_data_path = data_prefix + "test.txt"

# Load the training split first, then the test split.
train_data = dat.load_data(train_data_path)
test_data = dat.load_data(test_data_path)


In [11]:
from EduKTM import DKVMN

# Every constructor argument is read from the `params` entry of the same
# name, so collect them by key and pass them on as keyword arguments.
dkvmn_kwargs = {
    name: params[name]
    for name in (
        'n_question',
        'batch_size',
        'key_embedding_dim',
        'value_embedding_dim',
        'memory_size',
        'key_memory_state_dim',
        'value_memory_state_dim',
        'final_fc_dim',
    )
}
dkvmn = DKVMN(**dkvmn_kwargs)

# Train on the training split, then write the model to params['save'].
# NOTE(review): save() only runs if train() returns normally; if training is
# interrupted or raises, nothing is written by this cell and any existing
# file at params['save'] is left as it was.
dkvmn.train(params, train_data)
dkvmn.save(params['save'])

Epoch 0: 100%|██████████| 1363/1363 [03:15<00:00,  6.97it/s]


Epoch 1/50, loss : 0.31091, auc : 0.92866, accuracy : 0.86597


Epoch 1: 100%|██████████| 1363/1363 [04:03<00:00,  5.60it/s]


Epoch 2/50, loss : 0.30109, auc : 0.93303, accuracy : 0.87125


Epoch 2: 100%|██████████| 1363/1363 [04:14<00:00,  5.36it/s]


Epoch 3/50, loss : 0.30059, auc : 0.93343, accuracy : 0.87144


Epoch 3: 100%|██████████| 1363/1363 [03:57<00:00,  5.74it/s]


Epoch 4/50, loss : 0.29891, auc : 0.93417, accuracy : 0.87278


Epoch 4: 100%|██████████| 1363/1363 [03:51<00:00,  5.89it/s]


Epoch 5/50, loss : 0.29792, auc : 0.93458, accuracy : 0.87329


Epoch 5: 100%|██████████| 1363/1363 [03:50<00:00,  5.92it/s]


Epoch 6/50, loss : 0.29715, auc : 0.93494, accuracy : 0.87387


Epoch 6: 100%|██████████| 1363/1363 [14:32<00:00,  1.56it/s]   


Epoch 7/50, loss : 0.29644, auc : 0.93522, accuracy : 0.87427


Epoch 7: 100%|██████████| 1363/1363 [03:14<00:00,  7.02it/s]


Epoch 8/50, loss : 0.29578, auc : 0.93552, accuracy : 0.87480


Epoch 8: 100%|██████████| 1363/1363 [03:18<00:00,  6.85it/s]


Epoch 9/50, loss : 0.29475, auc : 0.93594, accuracy : 0.87551


Epoch 9: 100%|██████████| 1363/1363 [19:08<00:00,  1.19it/s]   


Epoch 10/50, loss : 0.29399, auc : 0.93621, accuracy : 0.87588


Epoch 10: 100%|██████████| 1363/1363 [03:18<00:00,  6.86it/s]


Epoch 11/50, loss : 0.29357, auc : 0.93635, accuracy : 0.87607


Epoch 11: 100%|██████████| 1363/1363 [03:10<00:00,  7.14it/s]


Epoch 12/50, loss : 0.29383, auc : 0.93621, accuracy : 0.87604


Epoch 12: 100%|██████████| 1363/1363 [02:56<00:00,  7.72it/s]


Epoch 13/50, loss : 0.29355, auc : 0.93637, accuracy : 0.87613


Epoch 13: 100%|██████████| 1363/1363 [03:08<00:00,  7.24it/s]


Epoch 14/50, loss : 0.29358, auc : 0.93640, accuracy : 0.87611


Epoch 14:  35%|███▌      | 479/1363 [01:23<02:34,  5.71it/s]


KeyboardInterrupt: 

In [None]:
# Restore the model from the configured checkpoint file before scoring.
# NOTE(review): params['load'] and params['save'] are the same path, so this
# reads whatever the training cell last wrote there; if that cell stopped
# before its save() call, the file may be missing or left over from an
# earlier run.
checkpoint_path = params['load']
dkvmn.load(checkpoint_path)

# Evaluate on the held-out test split; kept as the final expression so the
# call's return value is what the cell displays.
dkvmn.eval(params, test_data)