# main.py — training entry-point script (125 lines, 97 LOC)
# Import packages
import os
import sys
import pickle
from collections import OrderedDict
# NOTE(review): numpy, nn, dist, DistributedDataParallel, and gc are not used
# in the visible code below — presumably kept for models/experiments imported
# at runtime; confirm before removing.
import numpy as np
import torch
import torch.nn as nn
import torch.multiprocessing as multiprocessing
import utils.Constant as CONSTANT
from dataloader import UIRTDatset
from evaluation import Evaluator
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import warnings
import gc
# Silence all warnings for cleaner training logs.
warnings.filterwarnings("ignore")
# Make CUDA device numbering match the physical PCI bus order, so the
# CUDA_VISIBLE_DEVICES index set later refers to predictable GPUs.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
if __name__ == '__main__':
    # 'spawn' avoids forked workers inheriting a live CUDA context.
    multiprocessing.set_start_method('spawn')
    from experiment import EarlyStop, train_model
    from utils import Config, Logger, ResultTable, make_log_dir#, set_random_seed

    # Read base configuration files, then overlay "--name value" CLI pairs.
    config = Config(main_conf_path='./', model_conf_path='model_config')
    argv = sys.argv[1:]
    if argv:
        if len(argv) % 2 != 0:
            # Fail fast with a clear message instead of an IndexError below.
            raise ValueError('Command-line arguments must come in "--name value" pairs.')
        cmd_arg = OrderedDict()
        for i in range(0, len(argv), 2):
            arg_name, arg_value = argv[i], argv[i + 1]
            # strip('-') removes all leading/trailing dashes ("--lr" -> "lr").
            cmd_arg[arg_name.strip('-')] = arg_value
        config.update_params(cmd_arg)

    # Restrict visible GPUs before any CUDA initialization happens.
    gpu = str(config.get_param('Experiment', 'gpu'))
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_name = config.get_param('Experiment', 'model_name')

    # Logger: one fresh directory per run under saves/<model_name>.
    log_dir = make_log_dir(os.path.join('saves', model_name))
    logger = Logger(log_dir)
    config.save(log_dir)

    # Dataset (user-item-rating-time interactions).
    dataset_name = config.get_param('Dataset', 'dataset')
    dataset_type = CONSTANT.DATASET_TO_TYPE[dataset_name]
    dataset = UIRTDatset(**config['Dataset'])

    # Evaluator over held-out test/validation targets.
    num_users, num_items = dataset.num_users, dataset.num_items
    test_eval_pos, test_eval_target, vali_eval_target, eval_neg_candidates = dataset.test_data()
    test_evaluator = Evaluator(test_eval_pos, test_eval_target, vali_eval_target,
                               eval_neg_candidates, **config['Evaluator'],
                               num_users=num_users, num_items=num_items, item_id=None)

    # Early stopping policy.
    early_stop = EarlyStop(**config['EarlyStop'])

    # Record the effective config and dataset summary in the run log.
    logger.info(config)
    logger.info(dataset)

    # Resolve the model class by name from the local `model` package and build it.
    import model
    MODEL_CLASS = getattr(model, model_name)
    # seed = config.get_param('Experiment', 'seed')
    # set_random_seed(seed)
    model = MODEL_CLASS(dataset, config['Model'], device)

    # Train, then report total wall-clock time as h:m:s.
    test_score, train_time = train_model(model, dataset, test_evaluator, early_stop, logger, config)
    m, s = divmod(train_time, 60)
    h, m = divmod(m, 60)
    logger.info('\nTotal training time - %d:%d:%d(=%.1f sec)' % (h, m, s, train_time))

    # Show the best evaluation result as a table.
    evaluation_table = ResultTable(table_name='Best Result', header=list(test_score.keys()))
    evaluation_table.add_row('Score', test_score)
    logger.info(evaluation_table.to_string())
    logger.info("Saved to %s" % (log_dir))

    # Extract the global model's output, except for local/ensemble-style models
    # (LOCA variants, MF, MOE, WL) which do not expose a usable global output here.
    if 'LOCA' not in model_name and model_name not in ('MF', 'MOE', 'WL'):
        output = model.get_output(dataset)
        output_dir = os.path.join(dataset.data_dir, dataset.data_name, 'output')
        # makedirs(..., exist_ok=True) is race-free, unlike exists()+mkdir().
        os.makedirs(output_dir, exist_ok=True)
        output_file = os.path.join(output_dir, model_name + '_output.p')
        with open(output_file, 'wb') as f:
            pickle.dump(output, f, protocol=4)
        config.save(output_dir)
        print(f"{model_name} output extracted!")

    # Extract the learned user embedding (MultVAE only).
    if model_name == 'MultVAE':
        user_embedding = model.user_embedding(test_eval_pos)
        emb_dir = os.path.join(dataset.data_dir, dataset.data_name, 'embedding')
        os.makedirs(emb_dir, exist_ok=True)
        emb_file = os.path.join(emb_dir, model_name + '_user.p')
        with open(emb_file, 'wb') as f:
            pickle.dump(user_embedding, f, protocol=4)
        config.save(emb_dir)
        print(f"{model_name} embedding extracted!")