main.py
import argparse
import random

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf

import inference
import train
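
# Entry point: parse CLI arguments, load the YAML config via OmegaConf,
# fix the random seeds for reproducibility, then dispatch to train / sweep /
# inference / ensemble according to --mode.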
if __name__ == "__main__":
"""
하이퍼 파라미터 등 각종 설정값을 입력받습니다
터미널 실행 예시 : python3 run.py --batch_size=64 ...
실행 시 '--batch_size=64' 같은 인자를 입력하지 않으면 default 값이 기본으로 실행됩니다
"""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", "-c", type=str, default="custom_config")
    parser.add_argument("--mode", "-m", default="train")
    parser.add_argument(
        "--saved_model",
        "-s",
        default=None,
        help="Path to a saved model file. Example: saved_models/klue/roberta-small/epoch=?-step=?.ckpt or save_models/model.pt",
    )
    args, _ = parser.parse_known_args()

    config = OmegaConf.load(f"./config/{args.config}.yaml")
    if args.saved_model:
        config.path.saved_model = args.saved_model

    # Fix every random seed so runs are reproducible.
    SEED = config.utils.seed
    pl.seed_everything(SEED, workers=True)  # covers torch, numpy, random
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)
    if args.mode == "train" or args.mode == "t":
        if config.k_fold.use_k_fold is True:
            train.train_cv(config)
        else:
            train.train(config)
    elif args.mode == "exp" or args.mode == "e":
        exp_count = int(input("Enter the number of experiments to run: "))
        train.sweep(config, exp_count)
    elif args.mode == "inference" or args.mode == "i":
        if args.saved_model is None:
            print("Please provide a saved model path with --saved_model.")
        else:
            config.path.resume_path = args.saved_model
            inference.inference(args, config)
elif args.mode == "ensemble":
import ensemble
import re
assert config.ensemble.use_ensemble is True
assert any(config.ensemble.ckpt_paths) + any(config.ensemble.csv_paths) == 1
if any(config.ensemble.ckpt_paths):
assert config.ensemble.architecture == "EnsembleVotingModel"
ensemble.inference(args, config)
elif any(config.ensemble.csv_paths):
df = ensemble.ensemble_csvs(config.ensemble.csv_paths)
df["probs"] = df["probs"].apply(list).apply(str)
if ensemble._sanity_check(df):
print(len(df))
save_name = "_".join([re.search(r".+(?=\.csv)", path.split("/")[-1]).group() for path in config.ensemble.csv_paths])
df.to_csv(f"./prediction/ensemble_{save_name}.csv", index=False)
elif args.mode == "all" or args.mode == "a":
assert args.saved_model is None, "Cannot use 'saved_model' args for 'all' mode"
if config.k_fold.use_k_fold is True:
train.train_cv(config)
else:
train.train(config)
inference.inference(args, config)
    else:
        print("Please set a valid mode:")
        print("train : t,\ttrain")
        print("exp : e,\texp")
        print("inference : i,\tinference")
        print("ensemble : \tensemble")
        print("all : a,\tall")
        print("continue train : ct,\tcontinue train")