-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
105 lines (86 loc) · 2.69 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torch.utils.data import DataLoader
import lstm
import preprocessing
from datamodule import lstm_dataset, Net
# Hyperparameters and switches read by main().
# Embedding size; only passed to lstm.lstm, i.e. used when use_w2v is False.
EMBEDDING_DIM = 256
# Hidden size passed to both model constructors.
HIDDEN_DIM = 256
# Layer count passed to both model constructors.
num_layers = 3
# Upper bound on training epochs (handed to the Trainer as max_epochs).
num_epoch = 500
# Batch size for both DataLoaders; also passed to the model constructors.
batch_size = 50
# Passed to Net as its second argument (presumably the learning rate).
lr = 0.0001
# True: build the model with w2v.w2v_init_model; False: build it with lstm.lstm.
use_w2v = True
def _make_loader(split, shuffle):
    """Wrap one data split in a DataLoader that yields full batches only.

    Args:
        split: Mapping with "text", "pos" and "chunk" entries, as returned by
            preprocessing.preprocessing().
        shuffle: Whether the loader reshuffles the samples every epoch.

    Returns:
        A DataLoader over an lstm_dataset built from the split.
    """
    dataset = lstm_dataset(split["text"], split["pos"], split["chunk"])
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=2,
        pin_memory=True,
        # Incomplete final batches are dropped. The models are constructed
        # with batch_size, so presumably they need fixed-size batches.
        # NOTE(review): for the test split this also means the trailing
        # samples are never evaluated -- confirm this is intended.
        drop_last=True,
    )


def main():
    """Train the tagger with a CRF head, then test the best checkpoint.

    Steps: preprocess the data, build the model (word2vec-initialised or a
    plain LSTM depending on use_w2v), wrap it in the Lightning module Net,
    fit with checkpointing and early stopping, and finally run the test loop
    on the best checkpoint. Nothing is returned; results are whatever the
    Trainer and Net log, plus checkpoints written to ./check_point.
    """
    train_data, test_data, encode_dicts = preprocessing.preprocessing()
    word_dict = encode_dicts["word_dict"]
    pos_dict = encode_dicts["pos_dict"]
    chunk_dict = encode_dicts["chunk_dict"]

    device = "cuda"
    if use_w2v:
        # Imported lazily so the w2v module is only needed when it is used.
        import w2v

        model = w2v.w2v_init_model(encode_dicts, batch_size, HIDDEN_DIM, num_layers, device=device)
    else:
        model = lstm.lstm(
            batch_size,
            len(word_dict),
            len(pos_dict),
            chunk_dict,
            EMBEDDING_DIM,
            HIDDEN_DIM,
            num_layers=num_layers,
            device=device,
        )
    # NOTE(review): the CRF wrapper is applied to both model variants, which
    # matches crf=True being passed to Net unconditionally below -- confirm.
    model = lstm.dnn_crf(model, batch_size, len(chunk_dict), device=device)

    train_loader = _make_loader(train_data, shuffle=True)
    test_loader = _make_loader(test_data, shuffle=False)

    # Sanity check: show the tensor shapes of the first training batch.
    print("input shape is")
    for x, p, y in train_loader:
        print(x.shape, p.shape, y.shape)
        break

    net = Net(model, lr, crf=True)

    callbacks = [
        # Keep the single best checkpoint by "f1" (higher is better) plus the
        # most recent one; only the weights are stored.
        ModelCheckpoint(
            dirpath="./check_point",
            filename="{epoch}-{f1:.2f}",
            monitor="f1",
            save_last=True,
            save_weights_only=True,
            save_top_k=1,
            mode="max",
        ),
        # Stop when "loss" stops decreasing. Patience counts validation
        # checks, and validation runs every 10 epochs, so 30 checks span
        # roughly 300 epochs. An alternative is to monitor "f1" with
        # mode="max" instead.
        EarlyStopping(
            "loss",
            patience=30,
            verbose=True,
            mode="min",
            check_on_train_epoch_end=False,
        ),
    ]

    # `devices=1` replaces the `gpus=1` argument, which was deprecated in
    # Lightning 1.7 and removed in 2.0; together with accelerator="gpu" it
    # selects one GPU exactly as before.
    trainer = pl.Trainer(
        max_epochs=num_epoch,
        accelerator="gpu",
        devices=1,
        check_val_every_n_epoch=10,
        callbacks=callbacks,
    )
    # The test split doubles as the validation set during fitting.
    trainer.fit(net, train_loader, test_loader)
    trainer.test(dataloaders=test_loader, ckpt_path="best")
if __name__ == "__main__":
    # Default to the GPU with index 1, but keep any CUDA_VISIBLE_DEVICES value
    # the caller already exported instead of silently overwriting it (an
    # unconditional "1" leaves no visible GPU on a single-GPU machine).
    # This runs before main() so it is in place before training touches CUDA.
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")
    main()