-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
132 lines (107 loc) · 4.41 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import os
import warnings
import numpy as np
import pytorch_lightning as pl
from imblearn.over_sampling import SMOTE
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torch.utils.data import DataLoader
import models
import setting_path as PATH
from datamodules import ECGDataset, Net
def _load_split(load_dir, split, *names):
    """Load the named ``.npy`` arrays of one dataset split.

    Files are read from ``<PATH.ecg_path>/<load_dir>/<split>/<name>.npy``.

    Args:
        load_dir: Dataset variant directory under ``PATH.ecg_path``.
        split: Split sub-directory, e.g. ``"train"`` or ``"valid"``.
        *names: File stems to load, e.g. ``"X"``, ``"y"``, ``"t"``.

    Returns:
        A tuple with one loaded array per requested name, in order.
    """
    split_dir = os.path.join(PATH.ecg_path, load_dir, split)
    return tuple(np.load(os.path.join(split_dir, name + ".npy")) for name in names)


def _stack_channels(signal, aux):
    """Stack two ``(N, W)`` arrays into a single ``(N, 2, W)`` array."""
    return np.stack([signal, aux], axis=1)


def _make_loader(dataset, batch_size, *, train):
    """Build a DataLoader for ``dataset``.

    Training loaders shuffle and drop the last incomplete batch; evaluation
    loaders do neither, so every sample is scored exactly once.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=train,
        num_workers=2,
        pin_memory=True,
        drop_last=train,
    )


def main():
    """Train a model on the ECG arrays and test it on the validation split.

    Loads the train/valid ``.npy`` arrays, oversamples the training set with
    SMOTE, wraps both splits in DataLoaders, fits a ``Net`` with a PyTorch
    Lightning ``Trainer`` and finally runs ``trainer.test`` on the validation
    loader. All hyper-parameters are hard-coded at the top of the function.
    """
    warnings.simplefilter("ignore")

    load_dir = "without_time"  # "preprocessed", "lc_preprocessed" or "without_time"
    epochs = 1000
    num_gpu = 1
    lr = 1e-3
    batch_size = 512
    dropout = 0.6
    # Depends on the window size: 360 -> 2752, 120 -> 832, 80 -> 512
    # (values found by inspecting models.py).
    dence_input = 832

    # If you train the CNN-only model, switch the auxiliary data loaded below from c to t.
    # If you train the TimeEmbedding or TimeSin model, switch it from t to c.
    input_layer = models.NonTimed()  # If you train the LCANN model, set in_channels = 1.
    model = models.LCCNN(dence_input=dence_input, dropout=dropout)
    # If you train the LCCNN or CNN model, set dence_input = dence_input.
    # If you train LCCNNLight2 on data whose length is not 120, change last_kernel.

    X_train, y_train = _load_split(load_dir, "train", "X", "y")
    X_valid, y_valid = _load_split(load_dir, "valid", "X", "y")

    # If you train on the without_time data, change the paths.
    sm = SMOTE()
    if load_dir != "preprocessed":
        # Swap "t" for "c" (and subtract 1 from the loaded arrays) to use the
        # c.npy auxiliary channel instead of t.npy.
        (t_train,) = _load_split(load_dir, "train", "t")
        (t_valid,) = _load_split(load_dir, "valid", "t")

        # SMOTE only accepts 2-D input, so flatten the two channels side by
        # side, resample, then restore the (N, 2, window) layout. The window
        # length is taken from the data instead of being fixed at 120.
        window = X_train.shape[-1]
        flat_train = _stack_channels(X_train, t_train).reshape(-1, 2 * window)
        flat_train, y_train = sm.fit_resample(flat_train, y_train)
        X_train = flat_train.reshape(-1, 2, window)
        X_valid = _stack_channels(X_valid, t_valid)
    else:
        X_train, y_train = sm.fit_resample(X_train, y_train)

    print("X_train.shape = ", X_train.shape, " \t y_train.shape = ", y_train.shape)
    print("X_valid.shape = ", X_valid.shape, " \t y_valid.shape = ", y_valid.shape)
    uniq_train, counts_train = np.unique(y_train, return_counts=True)
    print("y_train count each labels: ", dict(zip(uniq_train, counts_train)))
    uniq_test, counts_test = np.unique(y_valid, return_counts=True)
    print("y_test count each labels: ", dict(zip(uniq_test, counts_test)))

    train_loader = _make_loader(ECGDataset(X_train, y_train), batch_size, train=True)
    # Evaluation must cover every sample: no shuffling and no dropped tail
    # batch, otherwise validation/test metrics are computed on a subset.
    valid_loader = _make_loader(ECGDataset(X_valid, y_valid), batch_size, train=False)

    # Print the shape of one training batch as a sanity check.
    print("input shape is")
    for x, y in train_loader:
        print(x.shape, y.shape)
        break

    net = Net(input_layer, model, lr)

    # These callbacks are prepared but intentionally NOT passed to the Trainer
    # below; pass callbacks=callbacks to enable checkpointing on "acc" and
    # early stopping on "recall" / "loss".
    # NOTE(review): the metric names must match what Net logs - confirm before
    # enabling. The checkpoint filename uses "recall" while monitoring "acc".
    callbacks = []
    checkpoint = ModelCheckpoint(
        dirpath="./check_point",
        filename="{epoch}-{recall:.2f}",
        monitor="acc",
        save_last=True,
        save_weights_only=True,
        save_top_k=1,
        mode="max",
    )
    callbacks.append(checkpoint)
    callbacks.append(
        EarlyStopping(
            "recall",
            patience=300,
            verbose=True,
            mode="max",
            check_on_train_epoch_end=False,
        )
    )
    callbacks.append(
        EarlyStopping(
            "loss",
            patience=300,
            verbose=True,
            mode="min",
            check_on_train_epoch_end=False,
        )
    )

    # `devices` replaces the deprecated `gpus` argument (removed in
    # Lightning 2.x); together with accelerator="gpu" it selects the same
    # number of GPUs.
    trainer = pl.Trainer(
        max_epochs=epochs,
        accelerator="gpu",
        devices=num_gpu,
        check_val_every_n_epoch=10,
    )
    # If you use EarlyStopping, set callbacks=callbacks.
    trainer.fit(net, train_loader, valid_loader)
    # NOTE(review): without the callbacks above, "best" presumably resolves
    # via Lightning's default checkpoint callback - confirm this is intended.
    trainer.test(dataloaders=valid_loader, ckpt_path="best")
if __name__ == "__main__":
    # Restrict CUDA to device index 1 before the training entry point runs.
    visible_devices = {"CUDA_VISIBLE_DEVICES": "1"}
    os.environ.update(visible_devices)
    main()