-
Notifications
You must be signed in to change notification settings - Fork 0
/
STRec.py
79 lines (59 loc) · 2.68 KB
/
STRec.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# STRec is implemented on top of RecBole.
# https://recbole.io/
import sys
import time
import torch
from logging import getLogger
import logging
from recbole.utils import init_logger, init_seed
from recbole.trainer import Trainer
from recbole.config import Config
from recbole.data import create_dataset, data_preparation
from models import STRec_pre, STRec, print_time
def _build_model(config, dataset):
    """Instantiate the model class named by ``config['model']`` and move it to ``config['device']``.

    An explicit name -> class registry is used instead of a ``locals()``
    lookup. ``locals()`` only resolves the imported model classes at module
    scope, and it would accept any module-level name.

    Raises:
        KeyError: if ``config['model']`` is not a known STRec model name.
    """
    model_classes = {'STRec_pre': STRec_pre, 'STRec': STRec}
    model_name = config['model']
    if model_name not in model_classes:
        raise KeyError(
            f"Unknown model {model_name!r}; expected one of {sorted(model_classes)}"
        )
    return model_classes[model_name](config, dataset).to(config['device'])


def _load_checkpointed_model(config, dataset):
    """Build the configured model and restore its weights from ``config['load_dir']``.

    Restores both the ``state_dict`` and the model's extra ``other_parameter``
    entry stored in the checkpoint.
    """
    # map_location lets a checkpoint saved on one device (e.g. GPU) be loaded
    # onto the configured device (e.g. CPU).
    # NOTE(review): torch>=2.6 defaults torch.load(weights_only=True); if the
    # checkpoint pickles non-tensor objects this may need weights_only=False — confirm.
    checkpoint = torch.load(config['load_dir'], map_location=config['device'])
    model = _build_model(config, dataset)
    model.load_state_dict(checkpoint['state_dict'])
    model.load_other_parameter(checkpoint.get('other_parameter'))
    return model


def _run_test(config, train_data, test_data, logger):
    """Evaluate a saved checkpoint on the test split and log the result."""
    model = _load_checkpointed_model(config, train_data.dataset)
    trainer = Trainer(config, model)
    test_result = trainer.evaluate(test_data, load_best_model=False, show_progress=True)
    logger.info('test result: {}'.format(test_result))


def _run_speed(config, train_data, start_time):
    """Time inference of a saved checkpoint over the training batches.

    Elapsed time is reported by ``print_time`` before model loading and again
    after the pass over all batches.
    """
    print_time(start_time)
    model = _load_checkpointed_model(config, train_data.dataset)
    model.eval()
    # Inference-only benchmark: disable autograd so graph construction does not
    # inflate the measured time or peak memory.
    with torch.no_grad():
        for batch_idx, batched_data in enumerate(train_data):
            if batch_idx == 2:
                # Peak allocator memory after batches 0 and 1 have been processed.
                print(torch.cuda.max_memory_allocated())
            model.predictx(batched_data.to(config['device']))
    print_time(start_time)


def _run_training(config, train_data, valid_data, test_data, logger):
    """Train the configured model, then evaluate on the test split and log both results."""
    model = _build_model(config, train_data.dataset)
    logger.info(model)
    trainer = Trainer(config, model)
    _, best_valid_result = trainer.fit(train_data, valid_data)
    test_result = trainer.evaluate(test_data)
    logger.info('best valid result: {}'.format(best_valid_result))
    logger.info('test result: {}'.format(test_result))


def main():
    """Run STRec in the mode selected by ``config['mode']`` (read from ``STRec.yaml``).

    Modes:
        pre_train: train ``STRec_pre``.
        train:     train ``STRec``.
        test:      evaluate the checkpoint at ``config['load_dir']``.
        speed:     time inference of the checkpoint at ``config['load_dir']``.

    Raises:
        NotImplementedError: if ``config['mode']`` is not one of the modes above.
    """
    start_time = time.time()
    config = Config(model=STRec, config_file_list=['STRec.yaml'])
    init_seed(config['seed'], config['reproducibility'])
    logging.getLogger().setLevel(logging.INFO)
    init_logger(config)
    logger = getLogger()
    logger.info(config)

    mode = config['mode']
    # Reject a bad mode before the (expensive) dataset construction below.
    if mode not in ('pre_train', 'train', 'test', 'speed'):
        raise NotImplementedError("Make sure 'mode' in ['pre_train', 'train', 'test', 'speed']!")

    dataset = create_dataset(config)
    logger.info(dataset)
    train_data, valid_data, test_data = data_preparation(config, dataset)

    if mode == 'test':
        _run_test(config, train_data, test_data, logger)
        return
    if mode == 'speed':
        _run_speed(config, train_data, start_time)
        return

    config['model'] = 'STRec_pre' if mode == 'pre_train' else 'STRec'
    _run_training(config, train_data, valid_data, test_data, logger)


if __name__ == '__main__':
    sys.exit(main())