### Train on the BioRED dataset

In [1]:
import argparse
import sys
import os

import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer
from model_bio import Model
from utils import set_seed
from prepro_bio import read_bio
from save_result import Logger
from evaluation import train, evaluate

# Command-line options for the BioRED relation-extraction run.
# Each row is (flag, type, default, help); a help of None means the option is
# registered without a help text.  Rows are registered in the order listed.
_CLI_OPTIONS = (
    # Dataset, backbone and checkpoint locations.
    ("--task", str, "biored_cd", None),
    ("--data_dir", str, "./dataset/biored_cd", None),
    ("--transformer_type", str, "bert", None),
    ("--model_name_or_path", str, "bert-base-cased", None),
    ("--train_file", str, "Train.BioC.JSON", None),
    ("--dev_file", str, "Dev.BioC.JSON", None),
    ("--test_file", str, "Test.BioC.JSON", None),
    ("--save_path", str, "out/rel2/", None),
    ("--load_path", str, "", None),
    ("--config_name", str, "",
     "Pretrained config name or path if not the same as model_name"),
    ("--tokenizer_name", str, "",
     "Pretrained tokenizer name or path if not the same as model_name"),
    ("--max_seq_length", int, 1024,
     "The maximum total input sequence length after tokenization. Sequences longer "
     "than this will be truncated, sequences shorter will be padded."),
    # Batching and optimisation.
    ("--train_batch_size", int, 4, "Batch size for training."),
    ("--test_batch_size", int, 8, "Batch size for testing."),
    ("--gradient_accumulation_steps", int, 1,
     "Number of updates steps to accumulate before performing a backward/update pass."),
    ("--learning_rate", float, 5e-5, "The initial learning rate for Adam."),
    ("--adam_epsilon", float, 1e-6, "Epsilon for Adam optimizer."),
    ("--max_grad_norm", float, 1.0, "Max gradient norm."),
    ("--warmup_ratio", float, 0.06, "Warm up ratio for Adam."),
    ("--num_train_epochs", float, 30.0, "Total number of training epochs to perform."),
    ("--evaluation_steps", int, -1, "Number of training steps between evaluations."),
    ("--seed", int, 68, "random seed for initialization"),
    ("--num_class", int, 4, "Number of relation types in dataset."),
    # Model-specific switches.
    ("--gnn", str, "GCN", "GCN/GAT"),
    ("--use_gcn", str, "tree", "use gcn, both/mentions/tree/false"),
    ("--dropout", float, 0.5, "0.0/0.2/0.5"),
    ("--loss", str, "BSCELoss",
     "use BSCELoss/BalancedLoss/ATLoss/AsymmetricLoss/APLLoss"),
    ("--s0", float, 0.3, None),
    ("--demo", str, "false", "use a few data to test. default true/false"),
    ("--unet_in_dim", int, 3, "unet_in_dim."),
    ("--unet_out_dim", int, 256, "unet_out_dim."),
    ("--down_dim", int, 256, "down_dim."),
    ("--bert_lr", float, 3e-5, "The initial learning rate for Adam."),
    ("--max_height", int, 64, "max_height."),
    ("--rel2", int, 1, ""),
    ("--save_result", str, "", "save predict result."),
)

parser = argparse.ArgumentParser()
for _flag, _kind, _default, _help in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_kind, default=_default, help=_help)

# parse_known_args() keeps going when argv holds entries the parser does not
# declare; the unrecognised leftovers are discarded.
# NOTE(review): presumably needed because a notebook kernel passes its own
# argv entries -- confirm before switching to parse_args().
args, _ = parser.parse_known_args()

# Preset for the 'biored_cd' task.  These values are assigned unconditionally,
# so they replace whatever the command line supplied for the same options.
if args.task == 'biored_cd':
    _biored_cd_preset = {
        'data_dir': './dataset/biored_cd',
        # Train on the merged train+dev split.
        'train_file': 'train+dev.data',
        # NOTE(review): dev_file and test_file are the same file, so dev and
        # test metrics coincide; if training selects its checkpoint on the dev
        # score, that selection is made on the test set -- confirm intent.
        'dev_file': 'test.data',
        'test_file': 'test.data',
        'model_name_or_path': '/data/pretrained/BiomedNLP-PubMedBERT-base-uncased-abstract',
        'train_batch_size': 12,
        'test_batch_size': 12,
        'learning_rate': 2e-5,
        # A truthy --rel2 switches to the 2-class setting, otherwise 4 classes.
        'num_class': 2 if args.rel2 else 4,
        'num_train_epochs': 30,
    }
    for _option, _value in _biored_cd_preset.items():
        setattr(args, _option, _value)

# Make sure the checkpoint directory exists.  makedirs() also creates missing
# parent directories (the default "out/rel2/" is nested, which os.mkdir cannot
# create when "out/" is absent), and exist_ok=True avoids the check-then-create
# race of a separate os.path.exists() test.
os.makedirs(args.save_path, exist_ok=True)

# Checkpoint file name assembled from the run configuration:
#   <train file stem>_<transformer type>_<dataset dir name>_seed_<loss>_<use_gcn>_<s0>_<dropout>_<seed>
# NOTE(review): the literal "seed_" sits in front of the loss name while the
# seed value is the last field.  The layout is kept as-is so checkpoints
# written by earlier runs still resolve to the same path.
file_name = "{}_{}_{}_seed_{}_{}_{}_{}_{}".format(
    args.train_file.split('.')[0],
    args.transformer_type,
    args.data_dir.split('/')[-1],
    args.loss,
    args.use_gcn,
    args.s0,
    args.dropout,
    str(args.seed),
)
args.save_path = os.path.join(args.save_path, file_name)

# Path prefix (no extension) shared by the log file below and by any result
# files derived from it.
args.save_pubtator = './result/{0}/{0}_{1}_{2}_s0={3}_dropout={4}'.format(
    args.task, args.loss, args.use_gcn, args.s0, args.dropout)

# When no checkpoint is being loaded, replace sys.stdout with the project
# Logger.
# NOTE(review): Logger presumably mirrors stdout into the given file and needs
# the ./result/<task>/ directory to exist -- confirm in save_result.
if args.load_path == "":
    sys.stdout = Logger(stream=sys.stdout, filename=args.save_pubtator + '_test.log')

# Reader used further down to turn a data file into model features.
read = read_bio
print(args)

# Pick the compute device once and record it on args so downstream code can
# reuse the same choice.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args.n_gpu = torch.cuda.device_count()
args.device = device

# Config and tokenizer fall back to the model path when no dedicated
# --config_name / --tokenizer_name is given.
config = AutoConfig.from_pretrained(
    args.config_name if args.config_name else args.model_name_or_path,
    num_labels=args.num_class,
)
tokenizer = AutoTokenizer.from_pretrained(
    args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
)
# from_tf is enabled only when the path contains ".ckpt".
model = AutoModel.from_pretrained(
    args.model_name_or_path,
    from_tf=bool(".ckpt" in args.model_name_or_path),
    config=config,
)

# Expose the tokenizer's special-token ids and the backbone type on the config
# handed to the project Model.
config.cls_token_id = tokenizer.cls_token_id
config.sep_token_id = tokenizer.sep_token_id
config.transformer_type = args.transformer_type

# Seed after the pretrained backbone is loaded and before the project Model is
# constructed (same ordering as before).
set_seed(args)
model = Model(args, config, model, num_labels=1)

# Move the model to the device selected above.  The previous hard-coded
# `model.to(0)` always targeted CUDA device 0, ignoring `device` and failing
# on machines without a GPU.
model.to(args.device)

# Training: build features for the three splits, then run the training loop.
train_file = os.path.join(args.data_dir, args.train_file)
dev_file = os.path.join(args.data_dir, args.dev_file)
test_file = os.path.join(args.data_dir, args.test_file)
train_features = read(args, train_file, tokenizer, max_seq_length=args.max_seq_length)
dev_features = read(args, dev_file, tokenizer, max_seq_length=args.max_seq_length)
test_features = read(args, test_file, tokenizer, max_seq_length=args.max_seq_length)
train(args, model, train_features, dev_features, test_features)

print("TEST")
# Final result: restore the checkpoint stored at args.save_path and evaluate
# it on the test split.
# NOTE(review): this assumes train() wrote its checkpoint to args.save_path --
# confirm in evaluation.train.
# map_location remaps the saved tensors onto the device chosen for this run,
# so a checkpoint written on a GPU can also be restored on a CPU-only machine.
model.load_state_dict(torch.load(args.save_path, map_location=args.device))
test_score, test_output = evaluate(args, model, test_features, tag="test", generate=True)
print(test_output)

Namespace(adam_epsilon=1e-06, bert_lr=3e-05, config_name='', data_dir='./dataset/biored_cd', demo='false', dev_file='test.data', down_dim=256, dropout=0.5, evaluation_steps=-1, gnn='GCN', gradient_accumulation_steps=1, learning_rate=2e-05, load_path='', loss='BSCELoss', max_grad_norm=1.0, max_height=64, max_seq_length=1024, model_name_or_path='/data/pretrained/BiomedNLP-PubMedBERT-base-uncased-abstract', num_class=2, num_train_epochs=30, rel2=1, s0=0.3, save_path='out/rel2/train+dev_bert_biored_cd_seed_BSCELoss_tree_0.3_0.5_68', save_pubtator='./result/biored_cd/biored_cd_BSCELoss_tree_s0=0.3_dropout=0.5', save_result='', seed=68, task='biored_cd', test_batch_size=12, test_file='test.data', tokenizer_name='', train_batch_size=12, train_file='train+dev.data', transformer_type='bert', unet_in_dim=3, unet_out_dim=256, use_gcn='tree', warmup_ratio=0.06)


100%|██████████| 490/490 [00:29<00:00, 16.41it/s]
100%|██████████| 100/100 [00:05<00:00, 17.85it/s]
100%|██████████| 100/100 [00:06<00:00, 14.95it/s]
epoch:   0%|          | 0/30 [00:00<?, ?it/s]

Total steps: 810
Warmup steps: 48
training risk: 10.286075592041016    step: 26
dev_score:  0.437271583913276 dev_output:  {'dev_F1': 43.7271583913276, 'dev_P': 67.77777024691441, 'dev_R': 32.27513056745341}
test_score:  0.437271583913276 test_output:  {'test_F1': 43.7271583913276, 'test_P': 67.77777024691441, 'test_R': 32.27513056745341}


epoch:   3%|▎         | 1/30 [00:41<19:54, 41.18s/it]

training risk: 11.302806854248047    step: 53
dev_score:  0.48201000081557827 dev_output:  {'dev_F1': 48.20100008155783, 'dev_P': 75.28089041787747, 'dev_R': 35.44973357408817}
test_score:  0.48201000081557827 test_output:  {'test_F1': 48.20100008155783, 'test_P': 75.28089041787747, 'test_R': 35.44973357408817}


epoch:   7%|▋         | 2/30 [01:21<18:55, 40.57s/it]

training risk: 12.045235633850098    step: 80
dev_score:  0.5964862486302929 dev_output:  {'dev_F1': 59.64862486302928, 'dev_P': 66.66666230936848, 'dev_R': 53.968251112790945}
test_score:  0.5964862486302929 test_output:  {'test_F1': 59.64862486302928, 'test_P': 66.66666230936848, 'test_R': 53.968251112790945}


epoch:  10%|█         | 3/30 [02:00<18:01, 40.07s/it]

training risk: 10.101417541503906    step: 107
dev_score:  0.6473938493145766 dev_output:  {'dev_F1': 64.73938493145765, 'dev_P': 71.3375750740398, 'dev_R': 59.25925612384888}
test_score:  0.6473938493145766 test_output:  {'test_F1': 64.73938493145765, 'test_P': 71.3375750740398, 'test_R': 59.25925612384888}


epoch:  13%|█▎        | 4/30 [02:38<16:52, 38.95s/it]

training risk: 11.741877555847168    step: 134
dev_score:  0.6585316614356718 dev_output:  {'dev_F1': 65.85316614356718, 'dev_P': 77.69783613684632, 'dev_R': 57.142854119425714}
test_score:  0.6585316614356718 test_output:  {'test_F1': 65.85316614356718, 'test_P': 77.69783613684632, 'test_R': 57.142854119425714}


epoch:  17%|█▋        | 5/30 [03:14<15:54, 38.20s/it]

training risk: 10.182096481323242    step: 161
dev_score:  0.6745512431647662 dev_output:  {'dev_F1': 67.45512431647663, 'dev_P': 76.51006197919047, 'dev_R': 60.31745712606047}
test_score:  0.6745512431647662 test_output:  {'test_F1': 67.45512431647663, 'test_P': 76.51006197919047, 'test_R': 60.31745712606047}


epoch:  20%|██        | 6/30 [03:55<15:32, 38.87s/it]

training risk: 10.241240501403809    step: 188
dev_score:  0.7337228372937191 dev_output:  {'dev_F1': 73.37228372937192, 'dev_P': 83.2214709247335, 'dev_R': 65.60846213711841}
test_score:  0.7337228372937191 test_output:  {'test_F1': 73.37228372937192, 'test_P': 83.2214709247335, 'test_R': 65.60846213711841}


epoch:  23%|██▎       | 7/30 [04:30<14:30, 37.86s/it]

training risk: 10.387145042419434    step: 215
dev_score:  0.6900910182123646 dev_output:  {'dev_F1': 69.00910182123646, 'dev_P': 87.09676716961555, 'dev_R': 57.142854119425714}
test_score:  0.6900910182123646 test_output:  {'test_F1': 69.00910182123646, 'test_P': 87.09676716961555, 'test_R': 57.142854119425714}


epoch:  27%|██▋       | 8/30 [05:07<13:43, 37.44s/it]

training risk: 9.91720199584961    step: 242
dev_score:  0.729880052715735 dev_output:  {'dev_F1': 72.9880052715735, 'dev_P': 79.87420881294284, 'dev_R': 67.19576364043579}
test_score:  0.729880052715735 test_output:  {'test_F1': 72.9880052715735, 'test_P': 79.87420881294284, 'test_R': 67.19576364043579}


epoch:  30%|███       | 9/30 [05:42<12:50, 36.71s/it]

training risk: 9.683487892150879    step: 269
dev_score:  0.7306452545635835 dev_output:  {'dev_F1': 73.06452545635835, 'dev_P': 88.05969492091829, 'dev_R': 62.43385913048364}
test_score:  0.7306452545635835 test_output:  {'test_F1': 73.06452545635835, 'test_P': 88.05969492091829, 'test_R': 62.43385913048364}


epoch:  33%|███▎      | 10/30 [06:18<12:09, 36.48s/it]

training risk: 9.77787971496582    step: 296
dev_score:  0.7692257343941973 dev_output:  {'dev_F1': 76.92257343941972, 'dev_P': 79.9999954285717, 'dev_R': 74.0740701548111}
test_score:  0.7692257343941973 test_output:  {'test_F1': 76.92257343941972, 'test_P': 79.9999954285717, 'test_R': 74.0740701548111}


epoch:  37%|███▋      | 11/30 [06:56<11:39, 36.82s/it]

training risk: 10.347344398498535    step: 323
dev_score:  0.7650222860071088 dev_output:  {'dev_F1': 76.50222860071088, 'dev_P': 79.09604072903724, 'dev_R': 74.0740701548111}
test_score:  0.7650222860071088 test_output:  {'test_F1': 76.50222860071088, 'test_P': 79.09604072903724, 'test_R': 74.0740701548111}


epoch:  40%|████      | 12/30 [07:32<10:58, 36.61s/it]

training risk: 9.497493743896484    step: 350
dev_score:  0.8054003642384638 dev_output:  {'dev_F1': 80.54003642384639, 'dev_P': 82.3204374408598, 'dev_R': 78.83597466476324}
test_score:  0.8054003642384638 test_output:  {'test_F1': 80.54003642384639, 'test_P': 82.3204374408598, 'test_R': 78.83597466476324}


epoch:  43%|████▎     | 13/30 [08:08<10:21, 36.57s/it]

training risk: 9.565771102905273    step: 377
dev_score:  0.7603255451913423 dev_output:  {'dev_F1': 76.03255451913424, 'dev_P': 79.31034026952068, 'dev_R': 73.0158691525995}
test_score:  0.7603255451913423 test_output:  {'test_F1': 76.03255451913424, 'test_P': 79.31034026952068, 'test_R': 73.0158691525995}


epoch:  47%|████▋     | 14/30 [08:46<09:51, 36.96s/it]

training risk: 10.11739444732666    step: 404
dev_score:  0.8195825899709468 dev_output:  {'dev_F1': 81.95825899709467, 'dev_P': 79.89949347238726, 'dev_R': 84.12697967582118}
test_score:  0.8195825899709468 test_output:  {'test_F1': 81.95825899709467, 'test_P': 79.89949347238726, 'test_R': 84.12697967582118}


epoch:  50%|█████     | 15/30 [09:23<09:15, 37.00s/it]

training risk: 11.12531566619873    step: 431
dev_score:  0.795964744684172 dev_output:  {'dev_F1': 79.5964744684172, 'dev_P': 75.9615348095416, 'dev_R': 83.59787917471539}
test_score:  0.795964744684172 test_output:  {'test_F1': 79.5964744684172, 'test_P': 75.9615348095416, 'test_R': 83.59787917471539}


epoch:  53%|█████▎    | 16/30 [09:59<08:34, 36.74s/it]

training risk: 10.848336219787598    step: 458
dev_score:  0.785510290609779 dev_output:  {'dev_F1': 78.5510290609779, 'dev_P': 82.94117159169578, 'dev_R': 74.6031706559169}
test_score:  0.785510290609779 test_output:  {'test_F1': 78.5510290609779, 'test_P': 82.94117159169578, 'test_R': 74.6031706559169}


epoch:  57%|█████▋    | 17/30 [10:34<07:49, 36.10s/it]

training risk: 9.927074432373047    step: 485
dev_score:  0.7899109394660279 dev_output:  {'dev_F1': 78.99109394660279, 'dev_P': 83.92856643282343, 'dev_R': 74.6031706559169}
test_score:  0.7899109394660279 test_output:  {'test_F1': 78.99109394660279, 'test_P': 83.92856643282343, 'test_R': 74.6031706559169}


epoch:  60%|██████    | 18/30 [11:10<07:12, 36.08s/it]

training risk: 9.334772109985352    step: 512
dev_score:  0.7921298060848485 dev_output:  {'dev_F1': 79.21298060848486, 'dev_P': 84.43113266879445, 'dev_R': 74.6031706559169}
test_score:  0.7921298060848485 test_output:  {'test_F1': 79.21298060848486, 'test_P': 84.43113266879445, 'test_R': 74.6031706559169}


epoch:  63%|██████▎   | 19/30 [11:45<06:34, 35.85s/it]

training risk: 9.703449249267578    step: 539
dev_score:  0.8063610050335027 dev_output:  {'dev_F1': 80.63610050335028, 'dev_P': 80.85105952919896, 'dev_R': 80.42327616808063}
test_score:  0.8063610050335027 test_output:  {'test_F1': 80.63610050335028, 'test_P': 80.85105952919896, 'test_R': 80.42327616808063}


epoch:  67%|██████▋   | 20/30 [12:21<05:59, 35.98s/it]

training risk: 10.270545959472656    step: 566
dev_score:  0.8159949568306355 dev_output:  {'dev_F1': 81.59949568306355, 'dev_P': 82.25806009365269, 'dev_R': 80.95237666918642}
test_score:  0.8159949568306355 test_output:  {'test_F1': 81.59949568306355, 'test_P': 82.25806009365269, 'test_R': 80.95237666918642}


epoch:  70%|███████   | 21/30 [12:57<05:22, 35.80s/it]

training risk: 9.38011360168457    step: 593
dev_score:  0.7880384391562064 dev_output:  {'dev_F1': 78.80384391562065, 'dev_P': 81.00558206672726, 'dev_R': 76.71957266034008}
test_score:  0.7880384391562064 test_output:  {'test_F1': 78.80384391562065, 'test_P': 81.00558206672726, 'test_R': 76.71957266034008}


epoch:  73%|███████▎  | 22/30 [13:32<04:44, 35.62s/it]

training risk: 9.747406959533691    step: 620
dev_score:  0.7923446888845417 dev_output:  {'dev_F1': 79.23446888845417, 'dev_P': 81.92089932650286, 'dev_R': 76.71957266034008}
test_score:  0.7923446888845417 test_output:  {'test_F1': 79.23446888845417, 'test_P': 81.92089932650286, 'test_R': 76.71957266034008}


epoch:  77%|███████▋  | 23/30 [14:07<04:07, 35.35s/it]

training risk: 10.042274475097656    step: 647
dev_score:  0.795750926021188 dev_output:  {'dev_F1': 79.5750926021188, 'dev_P': 79.7872297985516, 'dev_R': 79.36507516586903}
test_score:  0.795750926021188 test_output:  {'test_F1': 79.5750926021188, 'test_P': 79.7872297985516, 'test_R': 79.36507516586903}


epoch:  80%|████████  | 24/30 [14:43<03:32, 35.46s/it]

training risk: 9.912673950195312    step: 674
dev_score:  0.8165324282366897 dev_output:  {'dev_F1': 81.65324282366898, 'dev_P': 79.79797576777901, 'dev_R': 83.59787917471539}
test_score:  0.8165324282366897 test_output:  {'test_F1': 81.65324282366898, 'test_P': 79.79797576777901, 'test_R': 83.59787917471539}


epoch:  83%|████████▎ | 25/30 [15:18<02:57, 35.55s/it]

training risk: 8.740376472473145    step: 701
dev_score:  0.8032294598577544 dev_output:  {'dev_F1': 80.32294598577543, 'dev_P': 81.86812736988311, 'dev_R': 78.83597466476324}
test_score:  0.8032294598577544 test_output:  {'test_F1': 80.32294598577543, 'test_P': 81.86812736988311, 'test_R': 78.83597466476324}


epoch:  87%|████████▋ | 26/30 [15:54<02:22, 35.70s/it]

training risk: 8.527180671691895    step: 728
dev_score:  0.8074815884378298 dev_output:  {'dev_F1': 80.74815884378297, 'dev_P': 81.62161720964232, 'dev_R': 79.89417566697483}
test_score:  0.8074815884378298 test_output:  {'test_F1': 80.74815884378297, 'test_P': 81.62161720964232, 'test_R': 79.89417566697483}


epoch:  90%|█████████ | 27/30 [16:29<01:46, 35.39s/it]

training risk: 9.348267555236816    step: 755
dev_score:  0.8085055954643755 dev_output:  {'dev_F1': 80.85055954643755, 'dev_P': 81.2834181131862, 'dev_R': 80.42327616808063}
test_score:  0.8085055954643755 test_output:  {'test_F1': 80.85055954643755, 'test_P': 81.2834181131862, 'test_R': 80.42327616808063}


epoch:  93%|█████████▎| 28/30 [17:05<01:11, 35.56s/it]

training risk: 9.108833312988281    step: 782
dev_score:  0.811822914674092 dev_output:  {'dev_F1': 81.1822914674092, 'dev_P': 82.5136566932428, 'dev_R': 79.89417566697483}
test_score:  0.811822914674092 test_output:  {'test_F1': 81.1822914674092, 'test_P': 82.5136566932428, 'test_R': 79.89417566697483}


epoch:  97%|█████████▋| 29/30 [17:39<00:35, 35.23s/it]

training risk: 9.066167831420898    step: 809
dev_score:  0.811822914674092 dev_output:  {'dev_F1': 81.1822914674092, 'dev_P': 82.5136566932428, 'dev_R': 79.89417566697483}
test_score:  0.811822914674092 test_output:  {'test_F1': 81.1822914674092, 'test_P': 82.5136566932428, 'test_R': 79.89417566697483}


epoch: 100%|██████████| 30/30 [18:15<00:00, 36.53s/it]


TEST
generate predict result in ./result/biored_cd/biored_cd_BSCELoss_tree_s0=0.3_dropout=0.5
./result/biored_cd/biored_cd_BSCELoss_tree_s0=0.3_dropout=0.5.pubtator
{'test_F1': 81.1822914674092, 'test_P': 82.5136566932428, 'test_R': 79.89417566697483}
