In [2]:
import argparse
import torch
import torch.optim as optim
from model.casRel import CasRel
from model.callback import MyCallBack
from model.data import load_data, get_data_iterator
from model.config import Config
from model.evaluate import metric
import torch.nn.functional as F
from fastNLP import Trainer, LossBase

seed = 226
torch.manual_seed(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

parser = argparse.ArgumentParser(description='Model Controller')
parser.add_argument('--lr', type=float, default=1e-5, help='learning rate')
parser.add_argument('--batch_size', type=int, default=8)
parser.add_argument('--max_epoch', type=int, default=10)
parser.add_argument('--max_len', type=int, default=300)
parser.add_argument('--dataset', default='baidu', type=str, help='define your own dataset names')
parser.add_argument("--bert_name", default='./pretrained_models/bert-base-chinese/', type=str, help='choose pretrained bert name')
parser.add_argument('--bert_dim', default=768, type=int)
args = parser.parse_args(args=[])
con = Config(args)


def loss_fn(pred, gold, mask):
    pred = pred.squeeze(-1)
    loss = F.binary_cross_entropy(pred, gold, reduction='none')
    if loss.shape != mask.shape:
        mask = mask.unsqueeze(-1)
    loss = torch.sum(loss * mask) / torch.sum(mask)
    return loss

def get_loss(predict, target):  # Casrel计算loss的方程
    mask = target['mask']
    return loss_fn(predict['sub_heads'], target['sub_heads'], mask) + \
            loss_fn(predict['sub_tails'], target['sub_tails'], mask) + \
            loss_fn(predict['obj_heads'], target['obj_heads'], mask) + \
            loss_fn(predict['obj_tails'], target['obj_tails'], mask)

# if __name__ == '__main__':
model = CasRel(con).to(device)
data_bundle, rel_vocab = load_data(con.train_path, con.dev_path, con.test_path, con.rel_path)
train_dataset = get_data_iterator(con, data_bundle.get_dataset('train'), rel_vocab)
dev_dataset = get_data_iterator(con, data_bundle.get_dataset('dev'), rel_vocab, is_test=True)
test_dataset = get_data_iterator(con, data_bundle.get_dataset('test'), rel_vocab, is_test=True)
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=con.lr)
# trainer = Trainer(train_data=train_data, model=model, optimizer=optimizer, batch_size=con.batch_size,
#                     n_epochs=con.max_epoch, loss=MyLoss(), print_every=con.period, use_tqdm=True,
#                     callbacks=MyCallBack(dev_data, rel_vocab, con))
#     trainer.train()
#     print("-" * 5 + "Begin Testing" + "-" * 5)
#     metric(test_data, rel_vocab, con, model)


Some weights of the model checkpoint at ./pretrained_models/bert-base-chinese/ were not used when initializing BertModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at ./pretrained_models/bert-base-chinese/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
from tqdm import tqdm
def train_epoch(train_loader, model, optimizer, epoch):
    global train_data, batch_samples
    # set model to training mode
    model.train() # 不固定batch normalization和dropout，需要更新
    # step number in one epoch: 336
    train_losses = 0
    for idx, batch_samples in enumerate(tqdm(train_loader)):  # tqdm是显示进度条的,每次加载一个batch32个数据
        train_data = batch_samples[0]
        target_data = batch_samples[1]
        predict = model(train_data['token_ids'], train_data['mask'], 
                train_data['sub_head'], train_data['sub_tail'])
        loss = get_loss(predict, target_data)
        # compute model output and loss
        train_losses += loss.item()
        # clear previous gradients, compute gradients of all variables wrt loss
        model.zero_grad()
        loss.backward()
        # gradient clipping
        # performs updates using calculated gradients
        optimizer.step()  # 要先Loss.backward()之后再用step,先清零参数空间的梯度,用了才会更新模型
        print(idx, train_losses)
    train_loss = float(train_losses) / len(train_loader)
    print("Epoch: {}, train loss: {}".format(epoch, train_loss))

best_val_f1 = 0.0
patience_counter = 0
# start training
epoch_num = 5
for epoch in range(1, epoch_num + 1):
    train_epoch(train_dataset, model, optimizer, epoch)
    # val_metrics = evaluate(dev_loader, model, mode='dev')
    # val_f1 = val_metrics['f1']
    # val_p = val_metrics['p']
    # val_r = val_metrics['r']
#     logging.info("Epoch: {}, dev loss: {}, f1 score: {}, precision: {}, recall: {}".format(epoch, val_metrics['loss'], val_f1, val_p, val_r))
#     improve_f1 = val_f1 - best_val_f1
#     if improve_f1 > 1e-5:
#         best_val_f1 = val_f1
#         model.save_pretrained(model_dir)
#         logging.info("--------Save best model!--------")
#         if improve_f1 < config.patience:
#             patience_counter += 1
#         else:
#             patience_counter = 0
#     else:
#         patience_counter += 1
#     # Early stopping and logging best f1
#     if (patience_counter >= config.patience_num and epoch > config.min_epoch_num) or epoch == config.epoch_num:
#         logging.info("Best val f1: {}".format(best_val_f1))
#         break
# logging.info("Training Finished!")

  2%|▏         | 1/64 [00:03<04:04,  3.88s/it]

0 27.29383087158203


  3%|▎         | 2/64 [00:06<03:10,  3.08s/it]

1 52.47056198120117


  5%|▍         | 3/64 [00:08<02:40,  2.63s/it]

2 76.3529052734375


  6%|▋         | 4/64 [00:12<03:03,  3.06s/it]

3 97.56804275512695


  8%|▊         | 5/64 [00:14<02:42,  2.76s/it]

4 117.76546096801758


  9%|▉         | 6/64 [00:16<02:29,  2.57s/it]

5 135.85931968688965


 11%|█         | 7/64 [00:18<02:15,  2.38s/it]

6 152.78892707824707


 12%|█▎        | 8/64 [00:22<02:44,  2.93s/it]

7 168.62221240997314


 14%|█▍        | 9/64 [00:26<02:48,  3.07s/it]

8 183.07087993621826


 16%|█▌        | 10/64 [00:28<02:41,  2.98s/it]

9 195.86003398895264


 17%|█▋        | 11/64 [00:30<02:21,  2.68s/it]

10 207.67243480682373


 19%|█▉        | 12/64 [00:32<02:02,  2.35s/it]

11 219.03192615509033


 20%|██        | 13/64 [00:34<01:50,  2.17s/it]

12 228.7885046005249


 22%|██▏       | 14/64 [00:36<01:56,  2.33s/it]

13 237.6734676361084


 23%|██▎       | 15/64 [00:39<01:55,  2.36s/it]

14 245.82949924468994


 25%|██▌       | 16/64 [00:43<02:24,  3.00s/it]

15 253.18018341064453


 27%|██▋       | 17/64 [00:46<02:14,  2.87s/it]

16 260.08449840545654


 28%|██▊       | 18/64 [00:49<02:19,  3.03s/it]

17 266.0562572479248


 30%|██▉       | 19/64 [00:53<02:22,  3.17s/it]

18 271.76064682006836


 31%|███▏      | 20/64 [00:57<02:31,  3.43s/it]

19 276.73747730255127


 33%|███▎      | 21/64 [00:59<02:08,  2.99s/it]

20 281.51493883132935


 34%|███▍      | 22/64 [01:02<02:08,  3.06s/it]

21 285.9951992034912


 36%|███▌      | 23/64 [01:05<01:59,  2.91s/it]

22 290.17182540893555


 38%|███▊      | 24/64 [01:07<01:48,  2.71s/it]

23 293.91568636894226


 39%|███▉      | 25/64 [01:10<01:50,  2.83s/it]

24 297.5037899017334


 41%|████      | 26/64 [01:14<02:04,  3.28s/it]

25 300.86815333366394


 42%|████▏     | 27/64 [01:18<02:10,  3.52s/it]

26 304.0716335773468


 44%|████▍     | 28/64 [01:21<02:00,  3.34s/it]

27 307.11314392089844


 45%|████▌     | 29/64 [01:23<01:40,  2.88s/it]

28 310.10758113861084


 47%|████▋     | 30/64 [01:27<01:50,  3.24s/it]

29 313.02364897727966


 48%|████▊     | 31/64 [01:30<01:44,  3.18s/it]

30 315.78454327583313


 50%|█████     | 32/64 [01:33<01:38,  3.08s/it]

31 318.4409682750702


 52%|█████▏    | 33/64 [01:36<01:33,  3.02s/it]

32 320.88542222976685


 53%|█████▎    | 34/64 [01:38<01:20,  2.67s/it]

33 323.4167709350586


 55%|█████▍    | 35/64 [01:41<01:22,  2.86s/it]

34 325.80726861953735


 56%|█████▋    | 36/64 [01:44<01:21,  2.93s/it]

35 328.01715183258057


 58%|█████▊    | 37/64 [01:49<01:36,  3.57s/it]

36 330.1315336227417


 59%|█████▉    | 38/64 [01:52<01:27,  3.38s/it]

37 332.33672642707825


 61%|██████    | 39/64 [01:54<01:13,  2.94s/it]

38 334.5858156681061


 62%|██████▎   | 40/64 [01:56<01:05,  2.75s/it]

39 336.7723205089569


 64%|██████▍   | 41/64 [02:00<01:09,  3.03s/it]

40 338.744025349617


 66%|██████▌   | 42/64 [02:03<01:06,  3.02s/it]

41 340.63176143169403


 67%|██████▋   | 43/64 [02:06<01:04,  3.05s/it]

42 342.4799770116806


 69%|██████▉   | 44/64 [02:09<01:00,  3.05s/it]

43 344.48086726665497


 70%|███████   | 45/64 [02:12<00:56,  2.95s/it]

44 346.4862788915634


 72%|███████▏  | 46/64 [02:16<01:00,  3.34s/it]

45 348.1983948945999


 73%|███████▎  | 47/64 [02:19<00:53,  3.17s/it]

46 349.95953691005707


 75%|███████▌  | 48/64 [02:21<00:46,  2.88s/it]

47 351.75483679771423


 77%|███████▋  | 49/64 [02:25<00:44,  3.00s/it]

48 353.4512507915497


 78%|███████▊  | 50/64 [02:28<00:44,  3.18s/it]

49 355.0318088531494


 80%|███████▉  | 51/64 [02:31<00:41,  3.17s/it]

50 356.7401747703552


 81%|████████▏ | 52/64 [02:34<00:34,  2.90s/it]

51 358.54018235206604


 83%|████████▎ | 53/64 [02:36<00:30,  2.79s/it]

52 360.2198734283447


 84%|████████▍ | 54/64 [02:38<00:26,  2.62s/it]

53 361.94763708114624


 86%|████████▌ | 55/64 [02:41<00:22,  2.51s/it]

54 363.5488817691803


 88%|████████▊ | 56/64 [02:43<00:20,  2.58s/it]

55 365.14849042892456


 89%|████████▉ | 57/64 [02:47<00:19,  2.83s/it]

56 366.6849715709686


 91%|█████████ | 58/64 [02:49<00:15,  2.56s/it]

57 368.5080221891403


 92%|█████████▏| 59/64 [02:51<00:12,  2.45s/it]

58 370.1350291967392


 94%|█████████▍| 60/64 [02:54<00:11,  2.76s/it]

59 371.6985059976578


 95%|█████████▌| 61/64 [02:57<00:08,  2.70s/it]

60 373.2074216604233


 97%|█████████▋| 62/64 [03:00<00:05,  2.84s/it]

61 374.61226093769073


 98%|█████████▊| 63/64 [03:03<00:02,  2.92s/it]

62 376.01344752311707


100%|██████████| 64/64 [03:04<00:00,  2.89s/it]


63 377.49265480041504
Epoch: 1, train loss: 5.898322731256485


  2%|▏         | 1/64 [00:02<02:06,  2.00s/it]

0 1.5580439567565918


  3%|▎         | 2/64 [00:04<02:21,  2.29s/it]

1 3.0094090700149536


  5%|▍         | 3/64 [00:06<02:15,  2.22s/it]

2 4.566168904304504


  6%|▋         | 4/64 [00:09<02:37,  2.63s/it]

3 5.96911346912384


  8%|▊         | 5/64 [00:13<02:51,  2.91s/it]

4 7.337297677993774


  9%|▉         | 6/64 [00:16<02:46,  2.88s/it]

5 8.773299932479858


 11%|█         | 7/64 [00:19<02:49,  2.97s/it]

6 10.073745012283325


 12%|█▎        | 8/64 [00:24<03:27,  3.70s/it]

7 11.390488147735596


 14%|█▍        | 9/64 [00:26<02:54,  3.17s/it]

8 12.911989212036133


 16%|█▌        | 10/64 [00:29<02:45,  3.06s/it]

9 14.358134269714355


 17%|█▋        | 11/64 [00:33<02:54,  3.29s/it]

10 15.712691307067871


 19%|█▉        | 12/64 [00:36<02:44,  3.17s/it]

11 17.03742742538452


 20%|██        | 13/64 [00:38<02:33,  3.02s/it]

12 18.233359336853027


 22%|██▏       | 14/64 [00:40<02:18,  2.76s/it]

13 19.657650232315063


 23%|██▎       | 15/64 [00:44<02:24,  2.94s/it]

14 20.973708629608154


 25%|██▌       | 16/64 [00:46<02:13,  2.78s/it]

15 22.398618459701538


 27%|██▋       | 17/64 [00:48<02:01,  2.59s/it]

16 23.7459237575531


 28%|██▊       | 18/64 [00:53<02:23,  3.13s/it]

17 24.9407320022583


 30%|██▉       | 19/64 [00:57<02:36,  3.47s/it]

18 26.25359559059143


 31%|███▏      | 20/64 [01:01<02:42,  3.69s/it]

19 27.49490475654602


 33%|███▎      | 21/64 [01:05<02:43,  3.80s/it]

20 28.675480604171753


 34%|███▍      | 22/64 [01:08<02:26,  3.50s/it]

21 29.943108916282654


 36%|███▌      | 23/64 [01:12<02:24,  3.52s/it]

22 31.111348748207092


 38%|███▊      | 24/64 [01:14<02:11,  3.30s/it]

23 32.51300013065338


 39%|███▉      | 25/64 [01:18<02:07,  3.26s/it]

24 33.66608417034149


 41%|████      | 26/64 [01:20<01:50,  2.91s/it]

25 35.21214711666107


 42%|████▏     | 27/64 [01:22<01:41,  2.73s/it]

26 36.638514041900635


 44%|████▍     | 28/64 [01:25<01:40,  2.81s/it]

27 37.913641691207886


 45%|████▌     | 29/64 [01:28<01:45,  3.00s/it]

28 39.186408281326294


 47%|████▋     | 30/64 [01:31<01:41,  2.99s/it]

29 40.413705229759216


 48%|████▊     | 31/64 [01:35<01:42,  3.10s/it]

30 41.51995134353638


 50%|█████     | 32/64 [01:38<01:43,  3.23s/it]

31 42.783655285835266


 52%|█████▏    | 33/64 [01:42<01:42,  3.30s/it]

32 44.07882499694824


 53%|█████▎    | 34/64 [01:44<01:29,  2.97s/it]

33 45.43236207962036


 55%|█████▍    | 35/64 [01:47<01:24,  2.91s/it]

34 46.789676547050476


 56%|█████▋    | 36/64 [01:51<01:31,  3.28s/it]

35 47.99936604499817


 58%|█████▊    | 37/64 [01:54<01:24,  3.13s/it]

36 49.1823068857193


 59%|█████▉    | 38/64 [01:56<01:17,  2.98s/it]

37 50.312360882759094


 61%|██████    | 39/64 [01:58<01:08,  2.73s/it]

38 51.64632189273834


 62%|██████▎   | 40/64 [02:01<01:05,  2.72s/it]

39 52.84569585323334


 64%|██████▍   | 41/64 [02:03<00:58,  2.56s/it]

40 54.15953052043915


 66%|██████▌   | 42/64 [02:06<00:54,  2.47s/it]

41 55.27200186252594


 67%|██████▋   | 43/64 [02:09<00:58,  2.77s/it]

42 56.5702018737793


 69%|██████▉   | 44/64 [02:11<00:51,  2.59s/it]

43 57.88258147239685


 70%|███████   | 45/64 [02:14<00:52,  2.76s/it]

44 58.92141079902649


 72%|███████▏  | 46/64 [02:17<00:48,  2.67s/it]

45 59.97687768936157


 73%|███████▎  | 47/64 [02:19<00:44,  2.60s/it]

46 61.208762645721436


 75%|███████▌  | 48/64 [02:21<00:36,  2.30s/it]

47 62.55113172531128


 77%|███████▋  | 49/64 [02:23<00:34,  2.32s/it]

48 63.72300839424133


 78%|███████▊  | 50/64 [02:26<00:34,  2.49s/it]

49 64.97595191001892


 80%|███████▉  | 51/64 [02:28<00:31,  2.43s/it]

50 66.08209013938904


 81%|████████▏ | 52/64 [02:30<00:27,  2.32s/it]

51 67.30619418621063


 83%|████████▎ | 53/64 [02:34<00:29,  2.68s/it]

52 68.39991819858551


 84%|████████▍ | 54/64 [02:36<00:24,  2.47s/it]

53 69.52296876907349


 86%|████████▌ | 55/64 [02:38<00:22,  2.47s/it]

54 70.64855456352234


 88%|████████▊ | 56/64 [02:40<00:18,  2.29s/it]

55 71.87181580066681


 89%|████████▉ | 57/64 [02:43<00:16,  2.35s/it]

56 72.95076584815979


 91%|█████████ | 58/64 [02:45<00:14,  2.36s/it]

57 74.06312227249146


 92%|█████████▏| 59/64 [02:47<00:11,  2.26s/it]

58 75.29052805900574


 94%|█████████▍| 60/64 [02:49<00:08,  2.17s/it]

59 76.47011995315552


 95%|█████████▌| 61/64 [02:51<00:06,  2.15s/it]

60 77.59869694709778


 97%|█████████▋| 62/64 [02:54<00:04,  2.46s/it]

61 78.59882962703705


 98%|█████████▊| 63/64 [02:57<00:02,  2.55s/it]

62 79.73324418067932


100%|██████████| 64/64 [02:58<00:00,  2.79s/it]


63 80.78293859958649
Epoch: 2, train loss: 1.2622334156185389


  2%|▏         | 1/64 [00:02<02:53,  2.75s/it]

0 1.1573526859283447


  3%|▎         | 2/64 [00:04<02:21,  2.28s/it]

1 2.2944061756134033


  5%|▍         | 3/64 [00:06<02:11,  2.15s/it]

2 3.5320292711257935


  6%|▋         | 4/64 [00:08<02:01,  2.02s/it]

3 4.668139100074768


  8%|▊         | 5/64 [00:12<02:48,  2.85s/it]

4 5.583497762680054


  9%|▉         | 6/64 [00:16<02:59,  3.09s/it]

5 6.549239754676819


 11%|█         | 7/64 [00:18<02:33,  2.70s/it]

6 7.75945508480072


 12%|█▎        | 8/64 [00:21<02:38,  2.84s/it]

7 8.870033264160156


 14%|█▍        | 9/64 [00:23<02:15,  2.47s/it]

8 10.074916362762451


 16%|█▌        | 10/64 [00:25<02:14,  2.49s/it]

9 11.222720503807068


 17%|█▋        | 11/64 [00:27<01:58,  2.24s/it]

10 12.557731032371521


 19%|█▉        | 12/64 [00:29<02:03,  2.38s/it]

11 13.538386523723602


 20%|██        | 13/64 [00:32<02:00,  2.37s/it]

12 14.61361175775528


 22%|██▏       | 14/64 [00:34<01:48,  2.18s/it]

13 15.768455564975739


 23%|██▎       | 15/64 [00:35<01:39,  2.03s/it]

14 16.95198565721512


 25%|██▌       | 16/64 [00:38<01:53,  2.36s/it]

15 18.136392295360565


 27%|██▋       | 17/64 [00:42<02:08,  2.74s/it]

16 19.296452939510345


 28%|██▊       | 18/64 [00:46<02:16,  2.97s/it]

17 20.33767718076706


 30%|██▉       | 19/64 [00:48<02:04,  2.76s/it]

18 21.353420078754425


 31%|███▏      | 20/64 [00:51<02:11,  2.99s/it]

19 22.35784202814102


 33%|███▎      | 21/64 [00:54<02:00,  2.79s/it]

20 23.40116983652115


 34%|███▍      | 22/64 [00:56<01:52,  2.69s/it]

21 24.362449407577515


 36%|███▌      | 23/64 [00:58<01:38,  2.39s/it]

22 25.58771562576294


 38%|███▊      | 24/64 [01:01<01:43,  2.59s/it]

23 26.546633541584015


 39%|███▉      | 25/64 [01:03<01:30,  2.31s/it]

24 27.72427636384964


 41%|████      | 26/64 [01:04<01:22,  2.18s/it]

25 28.896146833896637


 42%|████▏     | 27/64 [01:07<01:23,  2.25s/it]

26 30.02939945459366


 44%|████▍     | 28/64 [01:10<01:29,  2.48s/it]

27 31.04733031988144


 45%|████▌     | 29/64 [01:12<01:22,  2.37s/it]

28 32.10566836595535


 47%|████▋     | 30/64 [01:14<01:17,  2.27s/it]

29 33.31824463605881


 48%|████▊     | 31/64 [01:17<01:27,  2.65s/it]

30 34.42282086610794


 50%|█████     | 32/64 [01:21<01:31,  2.86s/it]

31 35.54418021440506


 52%|█████▏    | 33/64 [01:23<01:20,  2.58s/it]

32 36.65636450052261


 53%|█████▎    | 34/64 [01:27<01:27,  2.92s/it]

33 37.56827932596207


 55%|█████▍    | 35/64 [01:28<01:15,  2.62s/it]

34 38.62415462732315


 56%|█████▋    | 36/64 [01:31<01:13,  2.62s/it]

35 39.62127465009689


 58%|█████▊    | 37/64 [01:34<01:13,  2.73s/it]

36 40.57473176717758


 59%|█████▉    | 38/64 [01:36<01:02,  2.41s/it]

37 41.67141968011856


 61%|██████    | 39/64 [01:38<00:57,  2.29s/it]

38 42.97092431783676


 62%|██████▎   | 40/64 [01:39<00:51,  2.13s/it]

39 44.12574225664139


 64%|██████▍   | 41/64 [01:42<00:51,  2.24s/it]

40 45.170625388622284


 66%|██████▌   | 42/64 [01:44<00:46,  2.12s/it]

41 46.28780668973923


 67%|██████▋   | 43/64 [01:47<00:48,  2.30s/it]

42 47.26845741271973


 69%|██████▉   | 44/64 [01:49<00:44,  2.25s/it]

43 48.37771987915039


 70%|███████   | 45/64 [01:50<00:39,  2.09s/it]

44 49.44722008705139


 72%|███████▏  | 46/64 [01:53<00:37,  2.11s/it]

45 50.46847128868103


 73%|███████▎  | 47/64 [01:56<00:41,  2.42s/it]

46 51.44875192642212


 75%|███████▌  | 48/64 [01:58<00:36,  2.30s/it]

47 52.43788379430771


 77%|███████▋  | 49/64 [02:00<00:32,  2.18s/it]

48 53.504755437374115


 78%|███████▊  | 50/64 [02:02<00:31,  2.27s/it]

49 54.488969683647156


 80%|███████▉  | 51/64 [02:05<00:31,  2.40s/it]

50 55.41562110185623


 81%|████████▏ | 52/64 [02:07<00:29,  2.47s/it]

51 56.38862144947052


 83%|████████▎ | 53/64 [02:09<00:25,  2.34s/it]

52 57.550063729286194


 84%|████████▍ | 54/64 [02:12<00:22,  2.29s/it]

53 58.57973086833954


 86%|████████▌ | 55/64 [02:13<00:19,  2.12s/it]

54 59.61495625972748


 88%|████████▊ | 56/64 [02:15<00:16,  2.02s/it]

55 60.54258334636688


 89%|████████▉ | 57/64 [02:18<00:15,  2.16s/it]

56 61.63653087615967


 91%|█████████ | 58/64 [02:19<00:11,  1.99s/it]

57 62.82610464096069


 92%|█████████▏| 59/64 [02:22<00:10,  2.13s/it]

58 63.785897612571716


 94%|█████████▍| 60/64 [02:23<00:08,  2.02s/it]

59 64.95950901508331


 95%|█████████▌| 61/64 [02:25<00:05,  1.83s/it]

60 66.16548943519592


 97%|█████████▋| 62/64 [02:27<00:04,  2.07s/it]

61 67.04815697669983


 98%|█████████▊| 63/64 [02:30<00:02,  2.16s/it]

62 68.01098024845123


100%|██████████| 64/64 [02:31<00:00,  2.36s/it]


63 68.9896469116211
Epoch: 3, train loss: 1.0779632329940796


  2%|▏         | 1/64 [00:01<01:54,  1.82s/it]

0 0.9517499804496765


  3%|▎         | 2/64 [00:03<02:05,  2.03s/it]

1 2.040139853954315


  5%|▍         | 3/64 [00:05<02:01,  1.99s/it]

2 2.9998642802238464


  6%|▋         | 4/64 [00:09<02:26,  2.44s/it]

3 3.8443498611450195


  8%|▊         | 5/64 [00:10<02:08,  2.17s/it]

4 4.837569057941437


  9%|▉         | 6/64 [00:12<01:58,  2.04s/it]

5 5.9691122174263


 11%|█         | 7/64 [00:14<01:47,  1.88s/it]

6 7.0866659283638


 12%|█▎        | 8/64 [00:16<01:54,  2.05s/it]

7 7.993262946605682


 14%|█▍        | 9/64 [00:18<01:58,  2.16s/it]

8 8.901896178722382


 16%|█▌        | 10/64 [00:20<01:49,  2.03s/it]

9 9.867933809757233


 17%|█▋        | 11/64 [00:23<01:57,  2.21s/it]

10 10.736327767372131


 19%|█▉        | 12/64 [00:25<01:54,  2.21s/it]

11 11.714833378791809


 20%|██        | 13/64 [00:27<01:56,  2.29s/it]

12 12.7889643907547


 22%|██▏       | 14/64 [00:30<01:53,  2.27s/it]

13 13.869576692581177


 23%|██▎       | 15/64 [00:31<01:42,  2.10s/it]

14 14.869939804077148


 25%|██▌       | 16/64 [00:34<01:50,  2.31s/it]

15 15.800019383430481


 27%|██▋       | 17/64 [00:36<01:48,  2.31s/it]

16 16.805744767189026


 28%|██▊       | 18/64 [00:39<01:43,  2.25s/it]

17 17.891396403312683


 30%|██▉       | 19/64 [00:41<01:44,  2.33s/it]

18 18.80858713388443


 31%|███▏      | 20/64 [00:43<01:36,  2.20s/it]

19 19.833737194538116


 33%|███▎      | 21/64 [00:45<01:37,  2.26s/it]

20 20.84757751226425


 34%|███▍      | 22/64 [00:49<01:51,  2.65s/it]

21 21.685361444950104


 36%|███▌      | 23/64 [00:51<01:39,  2.43s/it]

22 22.764762222766876


 38%|███▊      | 24/64 [00:53<01:33,  2.34s/it]

23 23.833130180835724


 39%|███▉      | 25/64 [00:56<01:35,  2.45s/it]

24 24.74060767889023


 41%|████      | 26/64 [00:58<01:29,  2.36s/it]

25 25.726482570171356


 42%|████▏     | 27/64 [01:00<01:23,  2.27s/it]

26 26.70174217224121


 44%|████▍     | 28/64 [01:03<01:26,  2.39s/it]

27 27.69081676006317


 45%|████▌     | 29/64 [01:06<01:29,  2.57s/it]

28 28.549103379249573


 47%|████▋     | 30/64 [01:07<01:18,  2.32s/it]

29 29.700901865959167


 48%|████▊     | 31/64 [01:09<01:13,  2.22s/it]

30 30.68601357936859


 50%|█████     | 32/64 [01:13<01:20,  2.52s/it]

31 31.713485598564148


 52%|█████▏    | 33/64 [01:16<01:25,  2.74s/it]

32 32.539435505867004


 53%|█████▎    | 34/64 [01:17<01:12,  2.41s/it]

33 33.691691160202026


 55%|█████▍    | 35/64 [01:20<01:10,  2.43s/it]

34 34.63409960269928


 56%|█████▋    | 36/64 [01:23<01:15,  2.68s/it]

35 35.407403349876404


 58%|█████▊    | 37/64 [01:25<01:04,  2.39s/it]

36 36.44278335571289


 59%|█████▉    | 38/64 [01:27<01:01,  2.36s/it]

37 37.46475064754486


 61%|██████    | 39/64 [01:29<00:58,  2.32s/it]

38 38.40419840812683


 62%|██████▎   | 40/64 [01:31<00:53,  2.22s/it]

39 39.31965517997742


 64%|██████▍   | 41/64 [01:34<00:50,  2.21s/it]

40 40.201202154159546


 66%|██████▌   | 42/64 [01:35<00:44,  2.03s/it]

41 41.289544105529785


 67%|██████▋   | 43/64 [01:37<00:38,  1.85s/it]

42 42.322200536727905


 69%|██████▉   | 44/64 [01:39<00:38,  1.90s/it]

43 43.322460412979126


 70%|███████   | 45/64 [01:40<00:33,  1.79s/it]

44 44.497424602508545


 72%|███████▏  | 46/64 [01:42<00:34,  1.90s/it]

45 45.51637411117554


 73%|███████▎  | 47/64 [01:45<00:37,  2.19s/it]

46 46.4157851934433


 75%|███████▌  | 48/64 [01:49<00:43,  2.73s/it]

47 47.224803149700165


 77%|███████▋  | 49/64 [01:53<00:43,  2.91s/it]

48 48.04203420877457


 78%|███████▊  | 50/64 [01:55<00:40,  2.89s/it]

49 48.96277165412903


 80%|███████▉  | 51/64 [01:58<00:35,  2.69s/it]

50 50.023351311683655


 81%|████████▏ | 52/64 [02:00<00:29,  2.48s/it]

51 51.12173080444336


 83%|████████▎ | 53/64 [02:02<00:27,  2.53s/it]

52 52.01001352071762


 84%|████████▍ | 54/64 [02:05<00:25,  2.53s/it]

53 53.021930038928986


 86%|████████▌ | 55/64 [02:07<00:21,  2.38s/it]

54 54.01066118478775


 88%|████████▊ | 56/64 [02:09<00:17,  2.21s/it]

55 55.00689721107483


 89%|████████▉ | 57/64 [02:11<00:16,  2.34s/it]

56 56.0370010137558


 91%|█████████ | 58/64 [02:14<00:15,  2.60s/it]

57 56.89330267906189


 92%|█████████▏| 59/64 [02:17<00:12,  2.45s/it]

58 57.89658164978027


 94%|█████████▍| 60/64 [02:18<00:08,  2.19s/it]

59 59.03364145755768


 95%|█████████▌| 61/64 [02:20<00:06,  2.16s/it]

60 60.03192323446274


 97%|█████████▋| 62/64 [02:23<00:04,  2.23s/it]

61 60.99470943212509


 98%|█████████▊| 63/64 [02:25<00:02,  2.18s/it]

62 61.89899629354477


100%|██████████| 64/64 [02:26<00:00,  2.28s/it]


63 62.98568552732468
Epoch: 4, train loss: 0.9841513363644481


  2%|▏         | 1/64 [00:02<02:29,  2.37s/it]

0 1.0102583169937134


  3%|▎         | 2/64 [00:05<02:59,  2.90s/it]

1 1.8107606172561646


  5%|▍         | 3/64 [00:07<02:32,  2.49s/it]

2 2.7026408910751343


  6%|▋         | 4/64 [00:10<02:44,  2.74s/it]

3 3.4973164796829224


  8%|▊         | 5/64 [00:14<03:09,  3.21s/it]

4 4.311240971088409


  9%|▉         | 6/64 [00:16<02:32,  2.63s/it]

5 5.264890670776367


 11%|█         | 7/64 [00:18<02:18,  2.43s/it]

6 6.247483730316162


 12%|█▎        | 8/64 [00:20<02:06,  2.26s/it]

7 7.182108163833618


 14%|█▍        | 9/64 [00:23<02:20,  2.55s/it]

8 8.050804018974304


 16%|█▌        | 10/64 [00:25<02:07,  2.35s/it]

9 8.875226438045502


 17%|█▋        | 11/64 [00:28<02:10,  2.47s/it]

10 9.76707261800766


 19%|█▉        | 12/64 [00:30<02:00,  2.32s/it]

11 10.772428333759308


 20%|██        | 13/64 [00:32<01:53,  2.23s/it]

12 11.741455972194672


 22%|██▏       | 14/64 [00:33<01:40,  2.01s/it]

13 12.825484931468964


 23%|██▎       | 15/64 [00:36<01:56,  2.38s/it]

14 13.628354847431183


 25%|██▌       | 16/64 [00:39<01:54,  2.38s/it]

15 14.440026879310608


 27%|██▋       | 17/64 [00:41<01:57,  2.49s/it]

16 15.30502724647522


 28%|██▊       | 18/64 [00:44<01:55,  2.52s/it]

17 16.241825580596924


 30%|██▉       | 19/64 [00:46<01:41,  2.26s/it]

18 17.324240565299988


 31%|███▏      | 20/64 [00:48<01:36,  2.19s/it]

19 18.284852743148804


 33%|███▎      | 21/64 [00:50<01:33,  2.19s/it]

20 19.13705265522003


 34%|███▍      | 22/64 [00:53<01:47,  2.56s/it]

21 19.949784636497498


 36%|███▌      | 23/64 [00:55<01:35,  2.34s/it]

22 20.984815001487732


 38%|███▊      | 24/64 [00:58<01:35,  2.38s/it]

23 21.834978103637695


 39%|███▉      | 25/64 [00:59<01:25,  2.19s/it]

24 22.863155007362366


 41%|████      | 26/64 [01:01<01:19,  2.10s/it]

25 23.886313319206238


 42%|████▏     | 27/64 [01:04<01:23,  2.25s/it]

26 24.70133328437805


 44%|████▍     | 28/64 [01:06<01:16,  2.14s/it]

27 25.800378918647766


 45%|████▌     | 29/64 [01:07<01:10,  2.00s/it]

28 26.69475567340851


 47%|████▋     | 30/64 [01:09<01:08,  2.01s/it]

29 27.558474898338318


 48%|████▊     | 31/64 [01:11<01:02,  1.89s/it]

30 28.666333079338074


 50%|█████     | 32/64 [01:14<01:07,  2.11s/it]

31 29.571189522743225


 52%|█████▏    | 33/64 [01:16<01:11,  2.30s/it]

32 30.481091022491455


 53%|█████▎    | 34/64 [01:18<01:02,  2.09s/it]

33 31.50442135334015


 55%|█████▍    | 35/64 [01:22<01:15,  2.59s/it]

34 32.431136786937714


 56%|█████▋    | 36/64 [01:24<01:11,  2.54s/it]

35 33.29769670963287


 58%|█████▊    | 37/64 [01:26<01:03,  2.34s/it]

36 34.10146474838257


 59%|█████▉    | 38/64 [01:28<00:56,  2.17s/it]

37 35.16708493232727


 61%|██████    | 39/64 [01:31<01:00,  2.40s/it]

38 36.04345917701721


 62%|██████▎   | 40/64 [01:33<00:58,  2.42s/it]

39 36.89807850122452


 64%|██████▍   | 41/64 [01:35<00:54,  2.37s/it]

40 37.69619232416153


 66%|██████▌   | 42/64 [01:38<00:50,  2.30s/it]

41 38.657981157302856


 67%|██████▋   | 43/64 [01:40<00:47,  2.25s/it]

42 39.493426501750946


 69%|██████▉   | 44/64 [01:42<00:47,  2.38s/it]

43 40.31835728883743


 70%|███████   | 45/64 [01:45<00:45,  2.40s/it]

44 41.100996911525726


 72%|███████▏  | 46/64 [01:47<00:43,  2.40s/it]

45 42.11043065786362


 73%|███████▎  | 47/64 [01:50<00:42,  2.47s/it]

46 42.9860822558403


 75%|███████▌  | 48/64 [01:52<00:37,  2.35s/it]

47 43.925968408584595


 77%|███████▋  | 49/64 [01:54<00:33,  2.22s/it]

48 45.02064621448517


 78%|███████▊  | 50/64 [01:57<00:33,  2.40s/it]

49 45.84380912780762


 80%|███████▉  | 51/64 [01:58<00:28,  2.18s/it]

50 46.72983229160309


 81%|████████▏ | 52/64 [02:00<00:24,  2.03s/it]

51 47.83744990825653


 83%|████████▎ | 53/64 [02:03<00:24,  2.21s/it]

52 48.542774736881256


 84%|████████▍ | 54/64 [02:04<00:20,  2.05s/it]

53 49.516280472278595


 86%|████████▌ | 55/64 [02:06<00:18,  2.05s/it]

54 50.51474076509476


 88%|████████▊ | 56/64 [02:08<00:15,  1.99s/it]

55 51.363266706466675


 89%|████████▉ | 57/64 [02:10<00:13,  1.86s/it]

56 52.45897126197815


 91%|█████████ | 58/64 [02:13<00:13,  2.32s/it]

57 53.18927466869354


 92%|█████████▏| 59/64 [02:17<00:13,  2.65s/it]

58 53.94772672653198


 94%|█████████▍| 60/64 [02:18<00:09,  2.40s/it]

59 54.9305579662323


 95%|█████████▌| 61/64 [02:20<00:06,  2.18s/it]

60 56.01670181751251


 97%|█████████▋| 62/64 [02:22<00:04,  2.17s/it]

61 56.8251610994339


 98%|█████████▊| 63/64 [02:25<00:02,  2.26s/it]

62 57.54821187257767


100%|██████████| 64/64 [02:26<00:00,  2.28s/it]

63 58.48323893547058
Epoch: 5, train loss: 0.9138006083667278





In [4]:
from model.evaluate import metric
metric(dev_dataset, rel_vocab, con, model)

100%|██████████| 36/36 [00:02<00:00, 13.10it/s]

correct_num:   0, predict_num:   0, gold_num:  63
f1: 0.00, precision: 0.00, recall: 0.00





(0.0, 0.0, 0.0)

In [20]:
from transformers import BertTokenizer
orders = ['subject', 'relation', 'object']
correct_num, predict_num, gold_num = 0, 0, 0
tokenizer = BertTokenizer.from_pretrained(con.bert_name)

for batch_x, batch_y in tqdm(dev_dataset):  # x用来测试，y是准确的数据集
    with torch.no_grad():
        token_ids = batch_x['token_ids']
        mask = batch_x['mask']
        encoded_text = model.get_encoded_text(token_ids, mask)
        pred_sub_heads, pred_sub_tails = model.get_subs(encoded_text)  # 预测
        sub_heads = torch.where(pred_sub_heads[0] > 0.05)[0]
        # if len(sub_heads)>0:
        #     print(sub_heads)
        sub_tails = torch.where(pred_sub_tails[0] > 0.5)[0]
        subjects = []
        for sub_head in sub_heads:
            sub_tail = sub_tails[sub_tails >= sub_head]
            if len(sub_tail) > 0:
                sub_tail = sub_tail[0]
                subject = ''.join(tokenizer.decode(token_ids[0][sub_head: sub_tail + 1]).split())
                subjects.append((subject, sub_head, sub_tail))

100%|██████████| 36/36 [00:03<00:00, 11.06it/s]


[]

In [94]:
optimizer.step()

In [3]:
from random import choice
from transformers import BertTokenizer
from transformers import BertModel
from collections import defaultdict
def find_head_idx(source, target):
    target_len = len(target)
    for i in range(len(source)):
        if source[i: i + target_len] == target:
            return i
    return -1

tokenizer = BertTokenizer.from_pretrained(con.bert_name)
bert = BertModel.from_pretrained(con.bert_name)

json_data = data_bundle.get_dataset('train')[223]
tokenized = tokenizer(json_data['text'])
tokens = tokenized['input_ids'] # 句子的length
masks = tokenized['attention_mask']
text_len = len(tokens)

token_ids = torch.tensor(tokens, dtype=torch.long)
masks = torch.tensor(masks, dtype=torch.bool)
"""主体和客体起始位置的记录"""
sub_heads, sub_tails = torch.zeros(text_len), torch.zeros(text_len)
sub_head, sub_tail = torch.zeros(text_len), torch.zeros(text_len)
obj_heads = torch.zeros((text_len, con.num_relations))
obj_tails = torch.zeros((text_len, con.num_relations))

s2ro_map = defaultdict(list)  # 创建一个dictionary，将键-值对更新为键-列表对，每个键可以调用list的属性
for spo in json_data['spo_list']:
    triple = (tokenizer(spo['subject'], add_special_tokens=False)['input_ids'], 
                rel_vocab.to_index(spo['predicate']),
                tokenizer(spo['object'], add_special_tokens=False)['input_ids']) # 把文本转换成id然后记录三元组,同时避免加入[CLS][SEP]这些特殊符号
    """
    - ISSUE: 如果某一个词语多次出现则只能找到第一个位置，这里有问题
    - SOLUTION: 再加一个变量记录这个词语是否出现过，如果出现过就记录他的位置，然后从这个位置开始往后找
    - WHY: 其实不改也可以，因为一个主体在文本中的意思应该是一样的。但是因为BERT模型会考虑前后文信息所以最好还是改一下？不确定，看验证结果，不是很严重的问题。
    """
    sub_head_idx = find_head_idx(tokens, triple[0])
    obj_head_idx = find_head_idx(tokens, triple[2])
    """可以试一下assert判断+终止"""
    if sub_head_idx != -1 and obj_head_idx != -1:
        sub = (sub_head_idx, sub_head_idx + len(triple[0]) - 1) # 主体位置
        s2ro_map[sub].append(
            (obj_head_idx, obj_head_idx + len(triple[2]) - 1, triple[1]))  # 用append解决一个主体对应多个客体的问题。客体位置+关系

if s2ro_map:  # 可能没有记录
    for s in s2ro_map:
        sub_heads[s[0]] = 1
        sub_tails[s[1]] = 1
    sub_head_idx, sub_tail_idx = choice(list(s2ro_map.keys()))
    sub_head[sub_head_idx] = 1
    sub_tail[sub_tail_idx] = 1
    for ro in s2ro_map.get((sub_head_idx, sub_tail_idx), []):
        obj_heads[ro[0]][ro[2]] = 1
        obj_tails[ro[1]][ro[2]] = 1

Some weights of the model checkpoint at ./pretrained_models/bert-base-chinese/ were not used when initializing BertModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at ./pretrained_models/bert-base-chinese/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
