### Train on the CDR dataset

In [1]:
import argparse
import sys
import os
import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer
from model_bio import Model
from utils import set_seed
from prepro_bio import read_bio
from save_result import Logger
from evaluation import train, evaluate

# Command-line interface, declared as a table so every option can be scanned at
# a glance. Each row is (flag, type, default, help); a help of None means the
# option carries no help text. Rows are registered in the order listed.
_CLI_OPTIONS = (
    # Task, data files and checkpoint locations.
    ("--task", str, "cdr", None),
    ("--data_dir", str, "./dataset/cdr", None),
    ("--transformer_type", str, "bert", None),
    ("--model_name_or_path", str, "bert-base-cased", None),
    ("--train_file", str, "Train.BioC.JSON", None),
    ("--dev_file", str, "Dev.BioC.JSON", None),
    ("--test_file", str, "Test.BioC.JSON", None),
    ("--save_path", str, "out", None),
    ("--load_path", str, "", None),
    ("--config_name", str, "",
     "Pretrained config name or path if not the same as model_name"),
    ("--tokenizer_name", str, "",
     "Pretrained tokenizer name or path if not the same as model_name"),
    # Input size, batching and optimisation.
    ("--max_seq_length", int, 1024,
     "The maximum total input sequence length after tokenization. Sequences longer "
     "than this will be truncated, sequences shorter will be padded."),
    ("--train_batch_size", int, 4, "Batch size for training."),
    ("--test_batch_size", int, 8, "Batch size for testing."),
    ("--gradient_accumulation_steps", int, 1,
     "Number of updates steps to accumulate before performing a backward/update pass."),
    ("--learning_rate", float, 5e-5, "The initial learning rate for Adam."),
    ("--adam_epsilon", float, 1e-6, "Epsilon for Adam optimizer."),
    ("--max_grad_norm", float, 1.0, "Max gradient norm."),
    ("--warmup_ratio", float, 0.06, "Warm up ratio for Adam."),
    ("--num_train_epochs", float, 30.0, "Total number of training epochs to perform."),
    ("--evaluation_steps", int, -1, "Number of training steps between evaluations."),
    ("--seed", int, 66, "random seed for initialization"),
    ("--num_class", int, 97, "Number of relation types in dataset."),
    # Model architecture and loss selection.
    ("--gnn", str, "GCN", "GCN/GAT"),
    ("--use_gcn", str, "tree", "use gcn, both/mentions/tree/false"),
    ("--dropout", float, 0.5, "0.0/0.2/0.5"),
    ("--loss", str, "BSCELoss",
     "use BSCELoss/BalancedLoss/ATLoss/AsymmetricLoss/APLLoss"),
    ("--s0", float, 0.3, None),
    ("--demo", str, "false", "use a few data to test. default true/false"),
    ("--unet_in_dim", int, 3, "unet_in_dim."),
    ("--unet_out_dim", int, 256, "unet_out_dim."),
    ("--down_dim", int, 256, "down_dim."),
    ("--bert_lr", float, 3e-5, "The initial learning rate for Adam."),
    ("--max_height", int, 64, "max_height."),
    ("--rel2", int, 0, ""),
    ("--save_result", str, "", "save predict result."),
)

parser = argparse.ArgumentParser()
for _flag, _kind, _default, _help in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_kind, default=_default, help=_help)

# parse_known_args() returns unrecognised arguments separately instead of
# raising, so extra argv entries do not abort the run; they are discarded here.
args, _ = parser.parse_known_args()

# CDR preset: pin the dataset files, encoder checkpoint and optimisation
# settings used for this task. These values overwrite whatever was parsed from
# the command line for the same options.
if args.task == 'cdr':
    vars(args).update(
        data_dir='./dataset/cdr',
        train_file='train_filter.data',
        dev_file='dev_filter.data',
        test_file='test_filter.data',
        model_name_or_path='/data/pretrained/scibert_scivocab_cased',
        train_batch_size=12,
        test_batch_size=12,
        learning_rate=2e-5,
        num_class=2,
        num_train_epochs=30,
    )

# Output locations.
#
# args.save_path starts as the checkpoint directory and is then rebound to a
# per-run path prefix inside that directory; the best-checkpoint load further
# down appends '_best' to it.
# makedirs(exist_ok=True) also handles nested directories and an already
# existing directory, which the previous exists()/mkdir() pair did not.
os.makedirs(args.save_path, exist_ok=True)

# NOTE(review): the literal "seed_" sits in front of the loss name while the
# actual seed is the last field (e.g. "..._seed_BSCELoss_tree_0.3_0.5_66").
# The layout is kept unchanged so existing checkpoint names keep resolving.
file_name = (
    f"{args.train_file.split('.')[0]}_{args.transformer_type}_"
    f"{args.data_dir.split('/')[-1]}_seed_{args.loss}_{args.use_gcn}_"
    f"{args.s0}_{args.dropout}_{args.seed}"
)
args.save_path = os.path.join(args.save_path, file_name)

# Path prefix for this run's log file under ./result/<task>/.
args.save_pubtator = (
    f"./result/{args.task}/{args.task}_{args.loss}_{args.use_gcn}"
    f"_s0={args.s0}_dropout={args.dropout}_{args.seed}"
)
# Make sure the result directory exists before anything is written below it.
# NOTE(review): Logger may already create it; this call is harmless if so.
os.makedirs(os.path.dirname(args.save_pubtator), exist_ok=True)

# Only when no checkpoint path was supplied: route stdout through the project
# Logger, which is given the '<prefix>_test.log' filename.
if args.load_path == "":
    sys.stdout = Logger(stream=sys.stdout, filename=args.save_pubtator + '_test.log')

# Dataset reader used for every split below.
read = read_bio
print(args)

# Device selection: use CUDA when available, otherwise fall back to the CPU.
# The choice is stored on args so downstream code can read it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args.n_gpu = torch.cuda.device_count()
args.device = device

# Load the pretrained config / tokenizer / encoder. config_name and
# tokenizer_name override model_name_or_path only when they are non-empty.
config = AutoConfig.from_pretrained(
    args.config_name if args.config_name else args.model_name_or_path, num_labels=args.num_class, )
tokenizer = AutoTokenizer.from_pretrained(
    args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, )
# A path containing ".ckpt" is treated as a TensorFlow checkpoint.
model = AutoModel.from_pretrained(
    args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config, )

# Expose the tokenizer's special-token ids and the encoder family on the config
# so they travel with it into the task model.
config.cls_token_id = tokenizer.cls_token_id
config.sep_token_id = tokenizer.sep_token_id
config.transformer_type = args.transformer_type

# Seed before building the task model.
# NOTE(review): presumably this makes the new layers' initialisation
# reproducible - confirm against utils.set_seed.
set_seed(args)
# NOTE(review): num_labels=1 is passed here while the config was built with
# num_labels=args.num_class - confirm this is intended by model_bio.Model.
model = Model(args, config, model, num_labels=1)
# Move the model to the selected device. The previous `model.to(0)` always
# targeted CUDA device 0 and therefore failed on machines without a GPU.
model.to(args.device)

# Training
# Resolve the three split files inside the data directory (train, dev, test).
train_file, dev_file, test_file = [
    os.path.join(args.data_dir, split_name)
    for split_name in (args.train_file, args.dev_file, args.test_file)
]

# Convert each split into model features with the configured reader, keeping
# the train -> dev -> test order.
train_features, dev_features, test_features = [
    read(args, split_path, tokenizer, max_seq_length=args.max_seq_length)
    for split_path in (train_file, dev_file, test_file)
]

# Run the training loop; it receives dev and test features alongside train.
train(args, model, train_features, dev_features, test_features)

# Final evaluation: restore the checkpoint stored at '<save_path>_best' and
# score it on the dev and test splits.
print("BEST TEST")
# map_location places the stored tensors on the device selected for this run,
# so a checkpoint written on a GPU can still be restored on a CPU-only machine.
best_state = torch.load(args.save_path + '_best', map_location=args.device)
model.load_state_dict(best_state)

# evaluate() returns a (score, metrics) pair; only the metrics are printed.
dev_score, dev_output = evaluate(args, model, dev_features, tag="dev")
print(dev_output)
# NOTE(review): generate=True presumably also writes the predictions to disk -
# confirm against evaluation.evaluate.
test_score, test_output = evaluate(args, model, test_features, tag="test", generate=True)
print(test_output)

Namespace(adam_epsilon=1e-06, bert_lr=3e-05, config_name='', data_dir='./dataset/cdr', demo='false', dev_file='dev_filter.data', down_dim=256, dropout=0.5, evaluation_steps=-1, gnn='GCN', gradient_accumulation_steps=1, learning_rate=2e-05, load_path='', loss='BSCELoss', max_grad_norm=1.0, max_height=64, max_seq_length=1024, model_name_or_path='/data/pretrained/scibert_scivocab_cased', num_class=2, num_train_epochs=30, rel2=0, s0=0.3, save_path='out/train_filter_bert_cdr_seed_BSCELoss_tree_0.3_0.5_66', save_pubtator='./result/cdr/cdr_BSCELoss_tree_s0=0.3_dropout=0.5_66', save_result='', seed=66, task='cdr', test_batch_size=12, test_file='test_filter.data', tokenizer_name='', train_batch_size=12, train_file='train_filter.data', transformer_type='bert', unet_in_dim=3, unet_out_dim=256, use_gcn='tree', warmup_ratio=0.06)


100%|██████████| 500/500 [00:34<00:00, 14.35it/s]
100%|██████████| 500/500 [00:36<00:00, 13.67it/s]
100%|██████████| 500/500 [00:39<00:00, 12.80it/s]
epoch:   0%|          | 0/30 [00:00<?, ?it/s]

Total steps: 1230
Warmup steps: 73
training risk: 10.586490631103516    step: 40
dev_score:  0.6418147103574903 dev_output:  {'dev_F1': 64.18147103574903, 'dev_P': 88.96672348569662, 'dev_R': 50.19762796247404}
test_score:  0.6150989700815321 test_output:  {'test_F1': 61.50989700815322, 'test_P': 87.67360958899984, 'test_R': 47.373357904565125}


epoch:   3%|▎         | 1/30 [01:19<38:13, 79.07s/it]

training risk: 10.249295234680176    step: 81
dev_score:  0.8171876429947745 dev_output:  {'dev_F1': 81.71876429947746, 'dev_P': 90.92009575157269, 'dev_R': 74.2094854327126}
test_score:  0.8008402904989914 test_output:  {'test_F1': 80.08402904989914, 'test_P': 91.65658897634113, 'test_R': 71.10694117160467}


epoch:   7%|▋         | 2/30 [02:40<37:31, 80.41s/it]

training risk: 10.116876602172852    step: 122
dev_score:  0.8293971361901716 dev_output:  {'dev_F1': 82.93971361901717, 'dev_P': 93.21824792579227, 'dev_R': 74.70355657407552}
test_score:  0.8247480406328663 test_output:  {'test_F1': 82.47480406328663, 'test_P': 92.53208760172593, 'test_R': 74.39024320459434}


epoch:  10%|█         | 3/30 [04:04<36:55, 82.07s/it]

training risk: 9.820235252380371    step: 163
dev_score:  0.8402254965376281 dev_output:  {'dev_F1': 84.02254965376281, 'dev_P': 89.40914058629721, 'dev_R': 79.24901107461451}
test_score:  0.8482321289406123 test_output:  {'test_F1': 84.82321289406123, 'test_P': 93.15375877492976, 'test_R': 77.861162496612}


epoch:  13%|█▎        | 4/30 [05:25<35:21, 81.60s/it]

training risk: 8.374422073364258    step: 204
dev_score:  0.8574277701849586 dev_output:  {'dev_F1': 85.74277701849586, 'dev_P': 88.11261639090077, 'dev_R': 83.49802289033575}
test_score:  0.8531696112937529 test_output:  {'test_F1': 85.31696112937529, 'test_P': 90.52631483656512, 'test_R': 80.6754213820317}


epoch:  17%|█▋        | 5/30 [06:45<33:44, 81.00s/it]

training risk: 9.408763885498047    step: 245
dev_score:  0.8163216300085845 dev_output:  {'dev_F1': 81.63216300085845, 'dev_P': 95.74467957786331, 'dev_R': 71.14624435626241}
test_score:  0.8110102243505771 test_output:  {'test_F1': 81.10102243505771, 'test_P': 95.54707257573699, 'test_R': 70.45028076500674}


epoch:  20%|██        | 6/30 [08:04<32:11, 80.49s/it]

training risk: 9.6510591506958    step: 286
dev_score:  0.8351478842586991 dev_output:  {'dev_F1': 83.51478842586991, 'dev_P': 93.29268178911364, 'dev_R': 75.59288462852881}
test_score:  0.8320283317685377 test_output:  {'test_F1': 83.20283317685377, 'test_P': 93.34889039266172, 'test_R': 75.04690361119228}


epoch:  23%|██▎       | 7/30 [09:21<30:25, 79.36s/it]

training risk: 9.52318000793457    step: 327
dev_score:  0.8594167266587668 dev_output:  {'dev_F1': 85.94167266587668, 'dev_P': 86.7203210591517, 'dev_R': 85.17786477096972}
test_score:  0.8481574724375613 test_output:  {'test_F1': 84.81574724375614, 'test_P': 87.5249492262979, 'test_R': 82.27016808376953}


epoch:  27%|██▋       | 8/30 [10:43<29:23, 80.18s/it]

training risk: 9.483315467834473    step: 368
dev_score:  0.8716912952580538 dev_output:  {'dev_F1': 87.16912952580537, 'dev_P': 84.84564934662629, 'dev_R': 89.62450504323611}
test_score:  0.8528813983245486 test_output:  {'test_F1': 85.28813983245486, 'test_P': 84.65803988301258, 'test_R': 85.92870463481516}


epoch:  30%|███       | 9/30 [12:05<28:11, 80.55s/it]

training risk: 9.157684326171875    step: 409
dev_score:  0.8533140835796414 dev_output:  {'dev_F1': 85.33140835796415, 'dev_P': 93.10747554780987, 'dev_R': 78.75493993325159}
test_score:  0.8392715289677375 test_output:  {'test_F1': 83.92715289677375, 'test_P': 93.44073540344378, 'test_R': 76.17260716536016}


epoch:  33%|███▎      | 10/30 [13:23<26:36, 79.81s/it]

training risk: 11.858595848083496    step: 450
dev_score:  0.8572883249337049 dev_output:  {'dev_F1': 85.72883249337049, 'dev_P': 91.770010239346, 'dev_R': 80.43478181388556}
test_score:  0.8424223624503385 test_output:  {'test_F1': 84.24223624503385, 'test_P': 92.2905017621173, 'test_R': 77.48592797855602}


epoch:  37%|███▋      | 11/30 [14:42<25:10, 79.50s/it]

training risk: 9.78331470489502    step: 491
dev_score:  0.8666954507421748 dev_output:  {'dev_F1': 86.66954507421748, 'dev_P': 88.96982217513192, 'dev_R': 84.48616517306161}
test_score:  0.8582705277963791 test_output:  {'test_F1': 85.82705277963791, 'test_P': 90.96638559909259, 'test_R': 81.23827315911564}


epoch:  40%|████      | 12/30 [16:01<23:53, 79.62s/it]

training risk: 9.06671142578125    step: 532
dev_score:  0.865783181683025 dev_output:  {'dev_F1': 86.5783181683025, 'dev_P': 85.5351968607985, 'dev_R': 87.64822047778439}
test_score:  0.857544850197207 test_output:  {'test_F1': 85.7544850197207, 'test_P': 86.82692224204884, 'test_R': 84.70919245113329}


epoch:  43%|████▎     | 13/30 [17:21<22:31, 79.50s/it]

training risk: 8.35084342956543    step: 573
dev_score:  0.8681491529156229 dev_output:  {'dev_F1': 86.81491529156229, 'dev_P': 89.16666573784724, 'dev_R': 84.5849794013342}
test_score:  0.8607918553603531 test_output:  {'test_F1': 86.07918553603531, 'test_P': 90.48603836105441, 'test_R': 82.08255082474155}


epoch:  47%|████▋     | 14/30 [18:37<20:59, 78.69s/it]

training risk: 9.855271339416504    step: 614
dev_score:  0.86302280644055 dev_output:  {'dev_F1': 86.302280644055, 'dev_P': 90.10752591282231, 'dev_R': 82.80632329242763}
test_score:  0.8636766018950229 test_output:  {'test_F1': 86.36766018950229, 'test_P': 91.94915156833527, 'test_R': 81.42589041814362}


epoch:  50%|█████     | 15/30 [19:58<19:46, 79.11s/it]

training risk: 11.243638038635254    step: 655
dev_score:  0.8668588968247809 dev_output:  {'dev_F1': 86.68588968247809, 'dev_P': 86.51574717996311, 'dev_R': 86.85770665160369}
test_score:  0.8596187295049491 test_output:  {'test_F1': 85.96187295049491, 'test_P': 88.48063467248625, 'test_R': 83.58348889696539}


epoch:  53%|█████▎    | 16/30 [21:16<18:23, 78.83s/it]

training risk: 8.956011772155762    step: 696
dev_score:  0.8697305452778235 dev_output:  {'dev_F1': 86.97305452778235, 'dev_P': 86.21359139598455, 'dev_R': 87.74703470605697}
test_score:  0.8637186028376272 test_output:  {'test_F1': 86.37186028376273, 'test_P': 88.4086435323316, 'test_R': 84.42776656259132}


epoch:  57%|█████▋    | 17/30 [22:35<17:08, 79.10s/it]

training risk: 10.352824211120605    step: 737
dev_score:  0.8572894478290825 dev_output:  {'dev_F1': 85.72894478290824, 'dev_P': 92.55440901999532, 'dev_R': 79.84189644425004}
test_score:  0.8552447740065416 test_output:  {'test_F1': 85.52447740065416, 'test_P': 93.63839181207152, 'test_R': 78.70544016223789}


epoch:  60%|██████    | 18/30 [23:55<15:50, 79.25s/it]

training risk: 8.790007591247559    step: 778
dev_score:  0.8715824560029128 dev_output:  {'dev_F1': 87.15824560029128, 'dev_P': 89.23395352759883, 'dev_R': 85.17786477096972}
test_score:  0.8680675865159154 test_output:  {'test_F1': 86.80675865159154, 'test_P': 90.95580584834732, 'test_R': 83.02063711988146}


epoch:  63%|██████▎   | 19/30 [25:15<14:33, 79.44s/it]

training risk: 10.92805004119873    step: 819
dev_score:  0.8746127331089403 dev_output:  {'dev_F1': 87.46127331089403, 'dev_P': 90.3157885229917, 'dev_R': 84.78260785787937}
test_score:  0.867524890933107 test_output:  {'test_F1': 86.7524890933107, 'test_P': 92.46284402905685, 'test_R': 81.70731630668558}


epoch:  67%|██████▋   | 20/30 [26:37<13:22, 80.27s/it]

training risk: 9.54214096069336    step: 860
dev_score:  0.8651065565976978 dev_output:  {'dev_F1': 86.51065565976978, 'dev_P': 88.85416574110245, 'dev_R': 84.28853671651643}
test_score:  0.8640582783696148 test_output:  {'test_F1': 86.40582783696148, 'test_P': 91.32706278655107, 'test_R': 81.98874219522756}


epoch:  70%|███████   | 21/30 [27:57<12:00, 80.03s/it]

training risk: 9.21198844909668    step: 901
dev_score:  0.8672959436839064 dev_output:  {'dev_F1': 86.72959436839064, 'dev_P': 87.91878083331187, 'dev_R': 85.57312168406006}
test_score:  0.8655206729850626 test_output:  {'test_F1': 86.55206729850626, 'test_P': 90.39836475589004, 'test_R': 83.02063711988146}


epoch:  73%|███████▎  | 22/30 [29:15<10:36, 79.58s/it]

training risk: 9.658025741577148    step: 942
dev_score:  0.8679195198350449 dev_output:  {'dev_F1': 86.79195198350449, 'dev_P': 87.22554803168116, 'dev_R': 86.36363551024075}
test_score:  0.8703653695155436 test_output:  {'test_F1': 87.03653695155435, 'test_P': 90.56795039991937, 'test_R': 83.77110615599338}


epoch:  77%|███████▋  | 23/30 [30:34<09:15, 79.37s/it]

training risk: 9.782114028930664    step: 983
dev_score:  0.8675141789679208 dev_output:  {'dev_F1': 86.75141789679208, 'dev_P': 89.92576786929196, 'dev_R': 83.79446557515351}
test_score:  0.8636994554785806 test_output:  {'test_F1': 86.36994554785807, 'test_P': 92.31590082907256, 'test_R': 81.14446452960165}


epoch:  80%|████████  | 24/30 [31:53<07:56, 79.37s/it]

training risk: 9.308781623840332    step: 1024
dev_score:  0.8692989481705401 dev_output:  {'dev_F1': 86.929894817054, 'dev_P': 88.12182651652968, 'dev_R': 85.77075014060523}
test_score:  0.8687910716463422 test_output:  {'test_F1': 86.87910716463422, 'test_P': 91.22806923397245, 'test_R': 82.92682849036747}


epoch:  83%|████████▎ | 25/30 [33:12<06:36, 79.26s/it]

training risk: 9.75108528137207    step: 1065
dev_score:  0.8707230746435455 dev_output:  {'dev_F1': 87.07230746435455, 'dev_P': 87.28897629305882, 'dev_R': 86.85770665160369}
test_score:  0.8727400017544147 test_output:  {'test_F1': 87.27400017544147, 'test_P': 90.86294323997012, 'test_R': 83.95872341502137}


epoch:  87%|████████▋ | 26/30 [34:32<05:17, 79.39s/it]

training risk: 9.383520126342773    step: 1106
dev_score:  0.8714375824879648 dev_output:  {'dev_F1': 87.14375824879647, 'dev_P': 86.54970675877479, 'dev_R': 87.74703470605697}
test_score:  0.8730032045304789 test_output:  {'test_F1': 87.30032045304789, 'test_P': 89.95024786119157, 'test_R': 84.80300108064726}


epoch:  90%|█████████ | 27/30 [35:51<03:57, 79.27s/it]

training risk: 9.542134284973145    step: 1147
dev_score:  0.8713601960041967 dev_output:  {'dev_F1': 87.13601960041967, 'dev_P': 86.92231969594573, 'dev_R': 87.35177779296663}
test_score:  0.8716656990290245 test_output:  {'test_F1': 87.16656990290245, 'test_P': 90.09008918828741, 'test_R': 84.42776656259132}


epoch:  93%|█████████▎| 28/30 [37:08<02:37, 78.61s/it]

training risk: 9.840618133544922    step: 1188
dev_score:  0.8708036699612896 dev_output:  {'dev_F1': 87.08036699612896, 'dev_P': 86.90944796348968, 'dev_R': 87.25296356469404}
test_score:  0.8720880202610122 test_output:  {'test_F1': 87.20880202610121, 'test_P': 90.18035981783207, 'test_R': 84.42776656259132}


epoch:  97%|█████████▋| 29/30 [38:27<01:18, 78.73s/it]

training risk: 9.796359062194824    step: 1229
dev_score:  0.8711061025374628 dev_output:  {'dev_F1': 87.11061025374627, 'dev_P': 87.06811365184488, 'dev_R': 87.15414933642145}
test_score:  0.8710412272815966 test_output:  {'test_F1': 87.10412272815967, 'test_P': 90.49544903442418, 'test_R': 83.95872341502137}


epoch: 100%|██████████| 30/30 [39:46<00:00, 79.54s/it]


BEST TEST
{'dev_F1': 87.46127331089403, 'dev_P': 90.3157885229917, 'dev_R': 84.78260785787937}
generate predict result in ./result/cdr/cdr_BSCELoss_tree_s0=0.3_dropout=0.5_66
./result/cdr/cdr_BSCELoss_tree_s0=0.3_dropout=0.5_66.pubtator
{'test_F1': 86.7524890933107, 'test_P': 92.46284402905685, 'test_R': 81.70731630668558}
