### Train on the BioRED dataset

In [1]:
import argparse
import sys
import datetime
import os
import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer
from model_bio import Model
from utils import set_seed
from prepro_bio import read_bio
from save_result import Logger
from evaluation import train, evaluate

parser = argparse.ArgumentParser()
parser.add_argument("--task", default="biored_cd", type=str)
parser.add_argument("--data_dir", default="./dataset/biored_cd", type=str)
parser.add_argument("--transformer_type", default="bert", type=str)
parser.add_argument("--model_name_or_path", default="bert-base-cased", type=str)
parser.add_argument("--train_file", default="Train.BioC.JSON", type=str)
parser.add_argument("--dev_file", default="Dev.BioC.JSON", type=str)
parser.add_argument("--test_file", default="Test.BioC.JSON", type=str)
parser.add_argument("--save_path", default="out", type=str)
parser.add_argument("--load_path", default="", type=str)
parser.add_argument("--config_name", default="", type=str,
                    help="Pretrained config name or path if not the same as model_name")
parser.add_argument("--tokenizer_name", default="", type=str,
                    help="Pretrained tokenizer name or path if not the same as model_name")
parser.add_argument("--max_seq_length", default=1024, type=int,
                    help="The maximum total input sequence length after tokenization. Sequences longer "
                         "than this will be truncated, sequences shorter will be padded.")
parser.add_argument("--train_batch_size", default=4, type=int, help="Batch size for training.")
parser.add_argument("--test_batch_size", default=8, type=int, help="Batch size for testing.")
parser.add_argument("--gradient_accumulation_steps", default=1, type=int,
                    help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
parser.add_argument("--adam_epsilon", default=1e-6, type=float, help="Epsilon for Adam optimizer.")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--warmup_ratio", default=0.06, type=float, help="Warm up ratio for Adam.")
parser.add_argument("--num_train_epochs", default=30.0, type=float, help="Total number of training epochs to perform.")
parser.add_argument("--evaluation_steps", default=-1, type=int, help="Number of training steps between evaluations.")
parser.add_argument("--seed", type=int, default=66, help="random seed for initialization")
parser.add_argument("--num_class", type=int, default=4, help="Number of relation types in dataset.")
parser.add_argument('--gnn', type=str, default='GCN', help="GCN/GAT")
parser.add_argument('--use_gcn', type=str, default='tree', help="use gcn, both/mentions/tree/false")
parser.add_argument('--dropout', type=float, default=0.5, help="0.0/0.2/0.5")
parser.add_argument('--loss', type=str, default='BSCELoss',
                    help="use BSCELoss/BalancedLoss/ATLoss/AsymmetricLoss/APLLoss")
parser.add_argument('--s0', type=float, default=0.3)
parser.add_argument("--demo", type=str, default='false', help='use a few data to test. default true/false')
parser.add_argument("--unet_in_dim", type=int, default=3, help="unet_in_dim.")
parser.add_argument("--unet_out_dim", type=int, default=256, help="unet_out_dim.")
parser.add_argument("--down_dim", type=int, default=256, help="down_dim.")
parser.add_argument("--bert_lr", default=3e-5, type=float, help="The initial learning rate for Adam.")
parser.add_argument("--max_height", type=int, default=64, help="max_height.")
parser.add_argument("--rel2", type=int, default=0, help="")
parser.add_argument("--save_result", type=str, default="", help="save predict result.")
args, _ = parser.parse_known_args()

if args.task == 'biored_cd':
    args.data_dir = './dataset/biored_cd'
    args.train_file = 'train+dev.data'
    args.dev_file = 'test.data'
    args.test_file = 'test.data'
    args.model_name_or_path = '/data/pretrained/BiomedNLP-PubMedBERT-base-uncased-abstract'
    args.train_batch_size = 12
    args.test_batch_size = 12
    args.learning_rate = 2e-5
    args.num_class = 4
    args.num_train_epochs = 30
    if args.rel2:
        args.train_file = 'train+dev.data'
        args.dev_file = 'test.data'
        args.num_class = 2

if not os.path.exists(args.save_path):
    os.mkdir(args.save_path)
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M")
file_name = "{}_{}_{}_{}_seed_{}_{}_{}_{}_{}".format(
    args.train_file.split('.')[0], timestamp,
    args.transformer_type, args.data_dir.split('/')[-1],
    args.loss, args.use_gcn, args.s0, args.dropout, str(args.seed), )
args.save_path = os.path.join(args.save_path, file_name)
args.save_pubtator = os.path.join('./result/' + args.task + '/' + args.task + '_' + timestamp + '_' + args.loss
                                  + '_' + str(args.use_gcn) + '_s0=' + str(args.s0)
                                  + '_dropout=' + str(args.dropout))
if args.load_path == "":
    sys.stdout = Logger(stream=sys.stdout, filename=args.save_pubtator + '_test.log')
read = read_bio
print(args)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args.n_gpu = torch.cuda.device_count()
args.device = device

config = AutoConfig.from_pretrained(
    args.config_name if args.config_name else args.model_name_or_path, num_labels=args.num_class, )
tokenizer = AutoTokenizer.from_pretrained(
    args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, )
model = AutoModel.from_pretrained(
    args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config, )
config.cls_token_id = tokenizer.cls_token_id
config.sep_token_id = tokenizer.sep_token_id
config.transformer_type = args.transformer_type
set_seed(args)
model = Model(args, config, model, num_labels=1)
model.to(0)

# Training
train_file = os.path.join(args.data_dir, args.train_file)
dev_file = os.path.join(args.data_dir, args.dev_file)
test_file = os.path.join(args.data_dir, args.test_file)
train_features = read(args, train_file, tokenizer, max_seq_length=args.max_seq_length)
dev_features = read(args, dev_file, tokenizer, max_seq_length=args.max_seq_length)
test_features = read(args, test_file, tokenizer, max_seq_length=args.max_seq_length)
train(args, model, train_features, dev_features, test_features)

print("TEST")
# Here is the final training result obtained
model.load_state_dict(torch.load(args.save_path))
test_score, test_output = evaluate(args, model, test_features, tag="test", generate=True)
print(test_output)

Namespace(adam_epsilon=1e-06, bert_lr=3e-05, config_name='', data_dir='./dataset/biored_cd', demo='false', dev_file='test.data', down_dim=256, dropout=0.5, evaluation_steps=-1, gnn='GCN', gradient_accumulation_steps=1, learning_rate=2e-05, load_path='', loss='BSCELoss', max_grad_norm=1.0, max_height=64, max_seq_length=1024, model_name_or_path='/data/pretrained/BiomedNLP-PubMedBERT-base-uncased-abstract', num_class=4, num_train_epochs=30, rel2=0, s0=0.3, save_path='out/train+dev_20240819-2126_bert_biored_cd_seed_BSCELoss_tree_0.3_0.5_66', save_pubtator='./result/biored_cd/biored_cd_20240819-2126_BSCELoss_tree_s0=0.3_dropout=0.5', save_result='', seed=66, task='biored_cd', test_batch_size=12, test_file='test.data', tokenizer_name='', train_batch_size=12, train_file='train+dev.data', transformer_type='bert', unet_in_dim=3, unet_out_dim=256, use_gcn='tree', warmup_ratio=0.06)


100%|██████████| 490/490 [00:30<00:00, 16.25it/s]
100%|██████████| 100/100 [00:05<00:00, 17.44it/s]
100%|██████████| 100/100 [00:06<00:00, 14.85it/s]
epoch:   0%|          | 0/30 [00:00<?, ?it/s]

Total steps: 810
Warmup steps: 48
training risk: 10.854312896728516    step: 26
dev_score:  0.16888618591891452 dev_output:  {'dev_F1': 16.88861859189145, 'dev_P': 52.777763117288025, 'dev_R': 10.05290952101008}
test_score:  0.16888618591891452 test_output:  {'test_F1': 16.88861859189145, 'test_P': 52.777763117288025, 'test_R': 10.05290952101008}


epoch:   3%|▎         | 1/30 [00:35<17:17, 35.77s/it]

training risk: 11.25239372253418    step: 53
dev_score:  0.09174080428273232 dev_output:  {'dev_F1': 9.174080428273232, 'dev_P': 34.482746730087335, 'dev_R': 5.291005011057936}
test_score:  0.09174080428273232 test_output:  {'test_F1': 9.174080428273232, 'test_P': 34.482746730087335, 'test_R': 5.291005011057936}


epoch:   7%|▋         | 2/30 [01:11<16:43, 35.83s/it]

training risk: 12.04153823852539    step: 80
dev_score:  0.12149325228360318 dev_output:  {'dev_F1': 12.149325228360318, 'dev_P': 51.99997920000832, 'dev_R': 6.878306514375317}
test_score:  0.12149325228360318 test_output:  {'test_F1': 12.149325228360318, 'test_P': 51.99997920000832, 'test_R': 6.878306514375317}


epoch:  10%|█         | 3/30 [01:47<16:08, 35.88s/it]

training risk: 10.9982328414917    step: 107
dev_score:  0.3197232776354228 dev_output:  {'dev_F1': 31.97232776354228, 'dev_P': 44.761900498866616, 'dev_R': 24.867723551972297}
test_score:  0.3197232776354228 test_output:  {'test_F1': 31.97232776354228, 'test_P': 44.761900498866616, 'test_R': 24.867723551972297}


epoch:  13%|█▎        | 4/30 [02:27<16:11, 37.38s/it]

training risk: 11.900747299194336    step: 134
dev_score:  0.5064054477356846 dev_output:  {'dev_F1': 50.640544773568465, 'dev_P': 64.22763705466366, 'dev_R': 41.798939587357694}
test_score:  0.5064054477356846 test_output:  {'test_F1': 50.640544773568465, 'test_P': 64.22763705466366, 'test_R': 41.798939587357694}


epoch:  17%|█▋        | 5/30 [03:09<16:18, 39.14s/it]

training risk: 11.241999626159668    step: 161
dev_score:  0.5844108047694245 dev_output:  {'dev_F1': 58.44108047694245, 'dev_P': 75.6302457453575, 'dev_R': 47.61904509952142}
test_score:  0.5844108047694245 test_output:  {'test_F1': 58.44108047694245, 'test_P': 75.6302457453575, 'test_R': 47.61904509952142}


epoch:  20%|██        | 6/30 [03:51<15:58, 39.95s/it]

training risk: 10.889534950256348    step: 188
dev_score:  0.6178811483868564 dev_output:  {'dev_F1': 61.78811483868564, 'dev_P': 63.33332981481501, 'dev_R': 60.31745712606047}
test_score:  0.6178811483868564 test_output:  {'test_F1': 61.78811483868564, 'test_P': 63.33332981481501, 'test_R': 60.31745712606047}


epoch:  23%|██▎       | 7/30 [04:28<15:03, 39.29s/it]

training risk: 11.229690551757812    step: 215
dev_score:  0.6234518905652457 dev_output:  {'dev_F1': 62.345189056524575, 'dev_P': 74.81480927297709, 'dev_R': 53.43915061168515}
test_score:  0.6234518905652457 test_output:  {'test_F1': 62.345189056524575, 'test_P': 74.81480927297709, 'test_R': 53.43915061168515}


epoch:  27%|██▋       | 8/30 [05:09<14:36, 39.84s/it]

training risk: 11.951997756958008    step: 242
dev_score:  0.646734099016624 dev_output:  {'dev_F1': 64.67340990166241, 'dev_P': 66.48044321338305, 'dev_R': 62.962959631589435}
test_score:  0.646734099016624 test_output:  {'test_F1': 64.67340990166241, 'test_P': 66.48044321338305, 'test_R': 62.962959631589435}


epoch:  30%|███       | 9/30 [05:48<13:50, 39.56s/it]

training risk: 11.108534812927246    step: 269
dev_score:  0.6650781559939485 dev_output:  {'dev_F1': 66.50781559939485, 'dev_P': 60.34482498513686, 'dev_R': 74.0740701548111}
test_score:  0.6650781559939485 test_output:  {'test_F1': 66.50781559939485, 'test_P': 60.34482498513686, 'test_R': 74.0740701548111}


epoch:  33%|███▎      | 10/30 [06:27<13:05, 39.26s/it]

training risk: 10.539522171020508    step: 296
dev_score:  0.5611466857642297 dev_output:  {'dev_F1': 56.114668576422964, 'dev_P': 87.64043959096183, 'dev_R': 41.2698390862519}
test_score:  0.5611466857642297 test_output:  {'test_F1': 56.114668576422964, 'test_P': 87.64043959096183, 'test_R': 41.2698390862519}


epoch:  37%|███▋      | 11/30 [07:03<12:09, 38.40s/it]

training risk: 10.412266731262207    step: 323
dev_score:  0.6285665886987628 dev_output:  {'dev_F1': 62.856658869876284, 'dev_P': 78.5714223356014, 'dev_R': 52.38094960947356}
test_score:  0.6285665886987628 test_output:  {'test_F1': 62.856658869876284, 'test_P': 78.5714223356014, 'test_R': 52.38094960947356}


epoch:  40%|████      | 12/30 [07:41<11:27, 38.19s/it]

training risk: 10.380870819091797    step: 350
dev_score:  0.6781559177599256 dev_output:  {'dev_F1': 67.81559177599256, 'dev_P': 74.21383181045083, 'dev_R': 62.43385913048364}
test_score:  0.6781559177599256 test_output:  {'test_F1': 67.81559177599256, 'test_P': 74.21383181045083, 'test_R': 62.43385913048364}


epoch:  43%|████▎     | 13/30 [08:19<10:46, 38.03s/it]

training risk: 11.604613304138184    step: 377
dev_score:  0.6605455399737313 dev_output:  {'dev_F1': 66.05455399737313, 'dev_P': 78.2608638941403, 'dev_R': 57.142854119425714}
test_score:  0.6605455399737313 test_output:  {'test_F1': 66.05455399737313, 'test_P': 78.2608638941403, 'test_R': 57.142854119425714}


epoch:  47%|████▋     | 14/30 [08:55<10:01, 37.61s/it]

training risk: 9.496519088745117    step: 404
dev_score:  0.680227567471134 dev_output:  {'dev_F1': 68.0227567471134, 'dev_P': 75.4838660978151, 'dev_R': 61.90475862937785}
test_score:  0.680227567471134 test_output:  {'test_F1': 68.0227567471134, 'test_P': 75.4838660978151, 'test_R': 61.90475862937785}


epoch:  50%|█████     | 15/30 [09:37<09:40, 38.69s/it]

training risk: 9.312358856201172    step: 431
dev_score:  0.6528140263999913 dev_output:  {'dev_F1': 65.28140263999913, 'dev_P': 74.32431930241086, 'dev_R': 58.20105512163729}
test_score:  0.6528140263999913 test_output:  {'test_F1': 65.28140263999913, 'test_P': 74.32431930241086, 'test_R': 58.20105512163729}


epoch:  53%|█████▎    | 16/30 [10:12<08:47, 37.70s/it]

training risk: 10.80964469909668    step: 458
dev_score:  0.709136246073726 dev_output:  {'dev_F1': 70.91362460737261, 'dev_P': 74.41860032449998, 'dev_R': 67.72486414154157}
test_score:  0.709136246073726 test_output:  {'test_F1': 70.91362460737261, 'test_P': 74.41860032449998, 'test_R': 67.72486414154157}


epoch:  57%|█████▋    | 17/30 [10:51<08:14, 38.03s/it]

training risk: 10.886937141418457    step: 485
dev_score:  0.6607619888813447 dev_output:  {'dev_F1': 66.07619888813447, 'dev_P': 74.66666168888922, 'dev_R': 59.25925612384888}
test_score:  0.6607619888813447 test_output:  {'test_F1': 66.07619888813447, 'test_P': 74.66666168888922, 'test_R': 59.25925612384888}


epoch:  60%|██████    | 18/30 [11:29<07:35, 37.97s/it]

training risk: 10.533876419067383    step: 512
dev_score:  0.6394935635824007 dev_output:  {'dev_F1': 63.949356358240074, 'dev_P': 78.46153242603596, 'dev_R': 53.968251112790945}
test_score:  0.6394935635824007 test_output:  {'test_F1': 63.949356358240074, 'test_P': 78.46153242603596, 'test_R': 53.968251112790945}


epoch:  63%|██████▎   | 19/30 [12:04<06:50, 37.30s/it]

training risk: 10.87967300415039    step: 539
dev_score:  0.6441668668741378 dev_output:  {'dev_F1': 64.41668668741379, 'dev_P': 76.64233017209268, 'dev_R': 55.55555261610833}
test_score:  0.6441668668741378 test_output:  {'test_F1': 64.41668668741379, 'test_P': 76.64233017209268, 'test_R': 55.55555261610833}


epoch:  67%|██████▋   | 20/30 [12:40<06:06, 36.70s/it]

training risk: 9.731158256530762    step: 566
dev_score:  0.6896501699722943 dev_output:  {'dev_F1': 68.96501699722944, 'dev_P': 75.47169336656016, 'dev_R': 63.49206013269523}
test_score:  0.6896501699722943 test_output:  {'test_F1': 68.96501699722944, 'test_P': 75.47169336656016, 'test_R': 63.49206013269523}


epoch:  70%|███████   | 21/30 [13:16<05:28, 36.49s/it]

training risk: 10.678374290466309    step: 593
dev_score:  0.6997034661077948 dev_output:  {'dev_F1': 69.97034661077947, 'dev_P': 77.92207286220307, 'dev_R': 63.49206013269523}
test_score:  0.6997034661077948 test_output:  {'test_F1': 69.97034661077947, 'test_P': 77.92207286220307, 'test_R': 63.49206013269523}


epoch:  73%|███████▎  | 22/30 [13:51<04:48, 36.09s/it]

training risk: 9.568910598754883    step: 620
dev_score:  0.6855473942050616 dev_output:  {'dev_F1': 68.55473942050617, 'dev_P': 73.78048330606809, 'dev_R': 64.02116063380102}
test_score:  0.6855473942050616 test_output:  {'test_F1': 68.55473942050617, 'test_P': 73.78048330606809, 'test_R': 64.02116063380102}


epoch:  77%|███████▋  | 23/30 [14:26<04:10, 35.79s/it]

training risk: 8.982707977294922    step: 647
dev_score:  0.7014442804802726 dev_output:  {'dev_F1': 70.14442804802727, 'dev_P': 77.56409759204502, 'dev_R': 64.02116063380102}
test_score:  0.7014442804802726 test_output:  {'test_F1': 70.14442804802727, 'test_P': 77.56409759204502, 'test_R': 64.02116063380102}


epoch:  80%|████████  | 24/30 [15:02<03:35, 35.88s/it]

training risk: 9.607145309448242    step: 674
dev_score:  0.6815592233752988 dev_output:  {'dev_F1': 68.15592233752989, 'dev_P': 72.18934484086716, 'dev_R': 64.55026113490682}
test_score:  0.6815592233752988 test_output:  {'test_F1': 68.15592233752989, 'test_P': 72.18934484086716, 'test_R': 64.55026113490682}


epoch:  83%|████████▎ | 25/30 [15:38<02:58, 35.77s/it]

training risk: 10.705179214477539    step: 701
dev_score:  0.6860415206211801 dev_output:  {'dev_F1': 68.604152062118, 'dev_P': 76.12902734651436, 'dev_R': 62.43385913048364}
test_score:  0.6860415206211801 test_output:  {'test_F1': 68.604152062118, 'test_P': 76.12902734651436, 'test_R': 62.43385913048364}


epoch:  87%|████████▋ | 26/30 [16:13<02:22, 35.53s/it]

training risk: 9.366463661193848    step: 728
dev_score:  0.705197314178317 dev_output:  {'dev_F1': 70.5197314178317, 'dev_P': 77.70700141993622, 'dev_R': 64.55026113490682}
test_score:  0.705197314178317 test_output:  {'test_F1': 70.5197314178317, 'test_P': 77.70700141993622, 'test_R': 64.55026113490682}


epoch:  90%|█████████ | 27/30 [16:48<01:46, 35.63s/it]

training risk: 9.526522636413574    step: 755
dev_score:  0.6842055417408689 dev_output:  {'dev_F1': 68.42055417408689, 'dev_P': 76.47058323721679, 'dev_R': 61.90475862937785}
test_score:  0.6842055417408689 test_output:  {'test_F1': 68.42055417408689, 'test_P': 76.47058323721679, 'test_R': 61.90475862937785}


epoch:  93%|█████████▎| 28/30 [17:23<01:10, 35.32s/it]

training risk: 10.233739852905273    step: 782
dev_score:  0.6898500782544377 dev_output:  {'dev_F1': 68.98500782544377, 'dev_P': 76.28204639217651, 'dev_R': 62.962959631589435}
test_score:  0.6898500782544377 test_output:  {'test_F1': 68.98500782544377, 'test_P': 76.28204639217651, 'test_R': 62.962959631589435}


epoch:  97%|█████████▋| 29/30 [18:00<00:35, 35.72s/it]

training risk: 9.399181365966797    step: 809
dev_score:  0.6994169676919197 dev_output:  {'dev_F1': 69.94169676919198, 'dev_P': 77.07005878534657, 'dev_R': 64.02116063380102}
test_score:  0.6994169676919197 test_output:  {'test_F1': 69.94169676919198, 'test_P': 77.07005878534657, 'test_R': 64.02116063380102}


epoch: 100%|██████████| 30/30 [18:36<00:00, 37.22s/it]


TEST
generate predict result in ./result/biored_cd/biored_cd_20240819-2126_BSCELoss_tree_s0=0.3_dropout=0.5
./result/biored_cd/biored_cd_20240819-2126_BSCELoss_tree_s0=0.3_dropout=0.5.pubtator
{'test_F1': 69.94169676919198, 'test_P': 77.07005878534657, 'test_R': 64.02116063380102}
