-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
65 lines (50 loc) · 2.03 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import argparse
import os
import pickle
from time import time
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.svm import LinearSVC
from util.load_data import load_dataset
from util.model_evaluation import get_metrics
# Command-line interface. "train-test" is the only mode handled below, so it
# is enforced via `choices`: an unknown (or misspelled) mode now fails with a
# usage error instead of silently doing nothing.
parser = argparse.ArgumentParser("train.py")
parser.add_argument(
    "--mode",
    choices=["train-test"],
    help="available modes: train-test",
    required=True,
)
parser.add_argument("--train", help="train folder")
parser.add_argument("--test", help="test folder")
# The model files are written *inside* this path, so it names a directory.
parser.add_argument("--s", help="directory to save the model files into")
args = parser.parse_args()
def save_model(filename, clf):
    """Pickle ``clf`` to ``filename``, creating parent directories if needed.

    Args:
        filename: Destination file path.
        clf: Any picklable object (in this script, a fitted vectorizer or
            classifier).

    Note:
        Only unpickle the resulting file from a trusted location --
        ``pickle.load`` can execute arbitrary code.
    """
    parent = os.path.dirname(filename)
    if parent:
        # Creating the directory here keeps a freshly trained model from
        # being lost to a FileNotFoundError at the very end of a run.
        os.makedirs(parent, exist_ok=True)
    with open(filename, 'wb') as f:
        pickle.dump(clf, f)
if args.mode == "train-test":
    # Train a TF-IDF (uni+bigram) + LinearSVC classifier on --train, evaluate
    # it on --test, and pickle the fitted vectorizer and classifier into --s.
    if not (args.train and args.test):
        parser.error("Mode train-test requires --train and --test")
    if not args.s:
        parser.error("Mode train-test requires --s")
    train_path = os.path.abspath(args.train)
    test_path = os.path.abspath(args.test)
    print("train model")
    model_path = os.path.abspath(args.s)
    # Create/validate the output directory before the (potentially long)
    # training run, so a bad --s fails immediately rather than at save time.
    os.makedirs(model_path, exist_ok=True)
    print("loading data")
    X_train, y_train = load_dataset(train_path)
    X_test, y_test = load_dataset(test_path)
    # Only the first element of each y_train entry is used as the training
    # label. NOTE(review): presumably each entry is a sequence of labels per
    # document -- confirm against util.load_data.load_dataset.
    train_labels = [labels[0] for labels in y_train]
    target_names = sorted(set(train_labels))
    print("%d documents (training set)" % len(X_train))
    print("%d documents (test set)" % len(X_test))
    print("%d categories" % len(target_names))
    print()
    print("training model")
    t0 = time()
    transformer = TfidfVectorizer(ngram_range=(1, 2), max_df=0.5)
    # Fit the vocabulary/IDF on training data only; the test set is merely
    # transformed, so no test information leaks into the features.
    X_train = transformer.fit_transform(X_train)
    X_test = transformer.transform(X_test)
    model = LinearSVC()
    estimator = model.fit(X_train, train_labels)
    train_time = time() - t0
    minutes, seconds = divmod(train_time, 60)
    print("train time: %dm %0.3fs" % (minutes, seconds))
    t0 = time()
    y_pred = estimator.predict(X_test)
    test_time = time() - t0
    minutes, seconds = divmod(test_time, 60)
    print("test time: %dm %0.3fs" % (minutes, seconds))
    # y_test is passed through unchanged (not reduced to first labels);
    # how it is compared against y_pred is up to util.model_evaluation.
    get_metrics(y_test, y_pred)
    save_model(os.path.join(model_path, "x_transformer.pkl"), transformer)
    save_model(os.path.join(model_path, "model.pkl"), estimator)