In [1]:
import json
import os
from tqdm import tqdm
import re
from IPython.core.debugger import set_trace
from pprint import pprint
import unicodedata
from transformers import AutoModel, BasicTokenizer, BertTokenizerFast
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import glob
import time
from layers import LayerNorm
import logging
from utils import Preprocessor, HandshakingTaggingScheme
import wandb

In [2]:
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device {} will be used".format(device))

device cuda:0 will be used


In [3]:
pretrained_model_home = "/data/yubowen/experiments/relextr/pretrained_model"
project_root = "/data/yubowen/experiments/relextr"
data_home = os.path.join(project_root, "data")

experiment_dir = os.path.join(project_root, "exp")
experiment_name = "nyt11"

nyt_data_dir = os.path.join(data_home, experiment_name)
nyt_train_data_path = os.path.join(nyt_data_dir, "train_triples.json")
nyt_valid_data_path = os.path.join(nyt_data_dir, "dev_triples.json")
nyt_test_data_path = os.path.join(nyt_data_dir, "test_triples.json")

In [4]:
wandb.init(project = experiment_name)

Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable


W&B Run: https://app.wandb.ai/wycheng/nyt11/runs/3in264ld

In [5]:
# hyperparameters
config = wandb.config          # Initialize config
config.batch_size = 6          # input batch size for training (default: 64)
config.test_batch_size = 32    # input batch size for testing (default: 1000)
config.epochs = 50             # number of epochs to train (default: 10)
config.lr = 5e-5               # learning rate (default: 0.01)
config.seed = 2333               # random seed (default: 42)
config.log_interval = 10  
config.max_seq_len = 100
config.sliding_len = 20
config.loss_weight_recover_steps = 10000

torch.manual_seed(config.seed) # pytorch random seed
torch.backends.cudnn.deterministic = True

model_state_dict_dir = wandb.run.dir
schedule_state_dict_dir = wandb.run.dir

Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable


# Load Data

In [6]:
nyt_train_data = json.load(open(nyt_train_data_path, "r", encoding = "utf-8"))
nyt_valid_data = json.load(open(nyt_valid_data_path, "r", encoding = "utf-8"))
nyt_test_data = json.load(open(nyt_test_data_path, "r", encoding = "utf-8"))

# Preprocess

In [7]:
# bert tokenizer
model_path = os.path.join(pretrained_model_home, "bert-base-cased")
tokenizer = BertTokenizerFast.from_pretrained(model_path, add_special_tokens = False, do_lower_case = False)

In [8]:
# @specific
def get_tok2char_span_map(text):
    return tokenizer.encode_plus(text, 
                               return_offsets_mapping = True, 
                               add_special_tokens = False)["offset_mapping"]

In [9]:
def tran2normal_samples(data):
    normal_sample_list = []
    for sample in tqdm(data, desc = "Transforming data format"):
        text = sample["text"]
        spo_list = sample["triple_list"]
        normal_sample = {
            "text": text,
            "id": sample["id"],
        }
        normal_rel_list = []
        for rel in spo_list:
            normal_rel_list.append({
                "subject": rel[0],
                "predicate": rel[1],
                "object": rel[2],
            })
        normal_sample["relation_list"] = normal_rel_list
        normal_sample_list.append(normal_sample)
    return normal_sample_list

In [10]:
preprocessor = Preprocessor(transform_func = tran2normal_samples, 
                            get_tok2char_span_map_func = get_tok2char_span_map)
train_data = preprocessor.get_normal_dataset(nyt_train_data, add_id = True, dataset_name = "train")
valid_data = preprocessor.get_normal_dataset(nyt_valid_data, add_id = True, dataset_name = "valid")
test_data = preprocessor.get_normal_dataset(nyt_test_data, add_id = True, dataset_name = "test")

62335it [00:00, 206203.44it/s]
Transforming data format: 100%|██████████| 62335/62335 [00:02<00:00, 28968.03it/s]
Adding token level spans: 100%|██████████| 62335/62335 [00:43<00:00, 1421.93it/s]
313it [00:00, 306876.38it/s]
Transforming data format: 100%|██████████| 313/313 [00:00<00:00, 140603.74it/s]
Adding token level spans: 100%|██████████| 313/313 [00:00<00:00, 1463.31it/s]
369it [00:00, 357601.24it/s]
Transforming data format: 100%|██████████| 369/369 [00:00<00:00, 178450.15it/s]
Adding token level spans: 100%|██████████| 369/369 [00:00<00:00, 1573.62it/s]


In [11]:
# # check token level span
# def extr_ent(text, tok_span, tok2char_span):
#     char_span_list = tok2char_span[tok_span[0]:tok_span[1]]
#     char_span = (char_span_list[0][0], char_span_list[-1][1])
#     decoded_ent = text[char_span[0]:char_span[1]]
#     return decoded_ent

# for sample in tqdm(train_data + valid_data + test_data):
#     text = sample["text"]
#     tok2char_span = get_tok2char_span_map(text)
#     for rel in sample["relation_list"]:
#         subj_span, obj_span = rel["subj_span"], rel["obj_span"]
#         assert extr_ent(text, subj_span, tok2char_span) == rel["subject"] and extr_ent(text, obj_span, tok2char_span) == rel["object"]

# Split

In [12]:
def split_into_short_samples(sample_list, sliding_len = 50):
    new_sample_list = []
    for sample in tqdm(sample_list, desc = "Splitting"):
        text_id = sample["id"]
        text = sample["text"]
        tokens = tokenizer.tokenize(text)
        offset_map = tokenizer.encode_plus(text, return_offsets_mapping = True, 
                                           add_special_tokens = False)["offset_mapping"]
        
        # sliding on token level
        split_sample_list = []
        for start_ind in range(0, len(tokens), sliding_len):
            while "##" in tokens[start_ind]:
                start_ind -= 1
            end_ind = start_ind + config.max_seq_len
#             while "##" in tokens[end_ind]:
#                 end_ind += 1
            char_span_list = offset_map[start_ind:end_ind]
            char_level_span = (char_span_list[0][0], char_span_list[-1][1])
            sub_text = text[char_level_span[0]:char_level_span[1]]

            new_sample = {
                "id": text_id,
                "text": sub_text,
                "relation_list": []
            }
            for rel in sample["relation_list"]:
                subj_span = rel["subj_span"]
                obj_span = rel["obj_span"]
                if subj_span[0] >= start_ind and subj_span[1] <= end_ind \
                    and obj_span[0] >= start_ind and obj_span[1] <= end_ind:
                    new_rel = copy.deepcopy(rel)
                    new_rel["subj_span"] = (subj_span[0] - start_ind, subj_span[1] - start_ind)
                    new_rel["obj_span"] = (obj_span[0] - start_ind, obj_span[1] - start_ind)
                    new_sample["relation_list"].append(new_rel)
#                 else:
#                     set_trace()
            if len(new_sample["relation_list"]) > 0:
                split_sample_list.append(new_sample)
#         if len(split_sample_list) == 0:
#             set_trace()
        new_sample_list.extend(split_sample_list)
    return new_sample_list

In [13]:
short_train_data = split_into_short_samples(train_data, sliding_len = config.sliding_len)
short_valid_data = split_into_short_samples(valid_data, sliding_len = config.sliding_len)
short_test_data = split_into_short_samples(test_data, sliding_len = config.sliding_len)

Splitting: 100%|██████████| 62335/62335 [00:41<00:00, 1495.70it/s]
Splitting: 100%|██████████| 313/313 [00:00<00:00, 1519.52it/s]
Splitting: 100%|██████████| 369/369 [00:00<00:00, 1602.94it/s]


In [14]:
print(len(short_train_data), len(short_valid_data), len(short_test_data))

89465 458 552


# Tagging

In [15]:
rel_set = set()
for sample in tqdm(train_data + valid_data + test_data):
    for rel in sample["relation_list"]:
        rel_set.add(rel["predicate"])
rel_set = sorted(rel_set)

100%|██████████| 63017/63017 [00:00<00:00, 400198.13it/s]


In [16]:
rel2id = {rel:ind for ind, rel in enumerate(rel_set)}

In [17]:
handshaking_tagger = HandshakingTaggingScheme(rel2id = rel2id, max_seq_len = config.max_seq_len)

In [18]:
def sample_equal_to(sample1, sample2):
    assert sample1["id"] == sample2["id"]
    assert sample1["text"] == sample2["text"]
    memory_set = set()
    for rel in sample2["relation_list"]:
        memory = "{}\u2E80{}\u2E80{}\u2E80{}\u2E80{}".format(rel["subject"], 
                                                             rel["predicate"], 
                                                             rel["object"], 
                                                             str(rel["subj_span"]), 
                                                             str(rel["obj_span"]))
        memory_set.add(memory)
    for rel in sample1["relation_list"]:
        memory = "{}\u2E80{}\u2E80{}\u2E80{}\u2E80{}".format(rel["subject"], 
                                                             rel["predicate"], 
                                                             rel["object"], 
                                                             str(rel["subj_span"]), 
                                                             str(rel["obj_span"]))
        if memory not in memory_set:
            set_trace()
            return False
    return True

In [19]:
# sample = train_data[1000]
# ent_matrix_spots, head_rel_matrix_spots, tail_rel_matrix_spots = handshaking_tagger.get_spots(sample)

# ent_shaking_tag = handshaking_tagger.sharing_spots2shaking_tag(ent_matrix_spots)
# head_rel_shaking_tag = handshaking_tagger.spots2shaking_tag(head_rel_matrix_spots)
# tail_rel_shaking_tag = handshaking_tagger.spots2shaking_tag(tail_rel_matrix_spots)
# # %timeit spots2shaking_tag(ent_matrix_spots)
# text = sample["text"]
# tok2char_span_map = get_tok2char_span_map(text)
# decoded_rel_list = handshaking_tagger.decode_rel_fr_shaking_tag(text, 
#                                              ent_shaking_tag, 
#                                              head_rel_shaking_tag, 
#                                              tail_rel_shaking_tag, 
#                                              tok2char_span_map)
# # %timeit decode_rel_fr_shaking_tag(text, ent_shaking_tag, head_rel_shaking_tag, tail_rel_shaking_tag, offset_map)
# pred_sample = {
#     "id": sample["id"],
#     "text": text,
#     "relation_list": decoded_rel_list,
# }
# sample_equal_to(pred_sample, sample)

In [20]:
# for sample in tqdm(train_data + valid_data + test_data):
#     ent_matrix_spots, head_rel_matrix_spots, tail_rel_matrix_spots = handshaking_tagger.get_spots(sample)

#     ent_shaking_tag = handshaking_tagger.sharing_spots2shaking_tag(ent_matrix_spots)
#     head_rel_shaking_tag = handshaking_tagger.spots2shaking_tag(head_rel_matrix_spots)
#     tail_rel_shaking_tag = handshaking_tagger.spots2shaking_tag(tail_rel_matrix_spots)
#     # %timeit spots2shaking_tag(ent_matrix_spots)
#     text = sample["text"]
#     tok2char_span_map = get_tok2char_span_map(text)
#     decoded_rel_list = handshaking_tagger.decode_rel_fr_shaking_tag(text, 
#                                                  ent_shaking_tag, 
#                                                  head_rel_shaking_tag, 
#                                                  tail_rel_shaking_tag, 
#                                                  tok2char_span_map)
#     # %timeit decode_rel_fr_shaking_tag(text, ent_shaking_tag, head_rel_shaking_tag, tail_rel_shaking_tag, offset_map)
#     pred_sample = {
#         "id": sample["id"],
#         "text": text,
#         "relation_list": decoded_rel_list,
#     }
#     if not sample_equal_to(pred_sample, sample) or not sample_equal_to(pred_sample, sample):
# #         set_trace()
#         pass

In [21]:
# # check batch tagging and decoding
# batch_samples = train_data[:100]
# batch_ent_spots, batch_head_rel_spots, batch_tail_rel_spots = [], [], []
# for sample in tqdm(batch_samples):
#     ent_matrix_spots, head_rel_matrix_spots, tail_rel_matrix_spots = handshaking_tagger.get_spots(sample)
#     batch_ent_spots.append(ent_matrix_spots)
#     batch_head_rel_spots.append(head_rel_matrix_spots)
#     batch_tail_rel_spots.append(tail_rel_matrix_spots)
# batch_ent_shaking_tag = handshaking_tagger.sharing_spots2shaking_tag4batch(batch_ent_spots)
# batch_head_rel_shaking_tag = handshaking_tagger.spots2shaking_tag4batch(batch_head_rel_spots)
# batch_tail_rel_shaking_tag = handshaking_tagger.spots2shaking_tag4batch(batch_tail_rel_spots)

# for ind, sample in tqdm(enumerate(batch_samples)):
#     text = sample["text"]
#     tok2char_span = get_tok2char_span_map(text)
#     ent_shaking_tag = batch_ent_shaking_tag[ind]
#     head_rel_shaking_tag = batch_head_rel_shaking_tag[ind]
#     tail_rel_shaking_tag = batch_tail_rel_shaking_tag[ind]
#     decoded_rel_list = handshaking_tagger.decode_rel_fr_shaking_tag(text, 
#                                                                      ent_shaking_tag,
#                                                                      head_rel_shaking_tag,
#                                                                      tail_rel_shaking_tag,
#                                                                      tok2char_span)
#     pred_sample = {
#         "id": sample["id"],
#         "text": text,
#         "relation_list": decoded_rel_list,
#     }
#     if not sample_equal_to(pred_sample, sample):
#         set_trace()

# Dataset

In [22]:
# @specific
def get_indexed_train_valid_data(data):
    indexed_samples = []
    for ind, sample in tqdm(enumerate(data), desc = "Generate indexed train or valid data"):
        text = sample["text"]
        text_id = sample["id"]
        # codes for bert input
        # @specific
        codes = tokenizer.encode_plus(text, 
                                    return_offsets_mapping = True, 
                                    add_special_tokens = False,
                                    max_length = config.max_seq_len, 
                                    pad_to_max_length = True)
        
        
        # tagging
        spots_tuple = handshaking_tagger.get_spots(sample)
        
        # get codes
        # @specific
        input_ids = torch.tensor(codes["input_ids"]).long()
        attention_mask = torch.tensor(codes["attention_mask"]).long()
        token_type_ids = torch.tensor(codes["token_type_ids"]).long()
        offset_map = codes["offset_mapping"]

        sample_tp = (text_id,
                     text, 
                     input_ids,
                     attention_mask,
                     token_type_ids,
                     offset_map,
                     spots_tuple,
                    )
        indexed_samples.append(sample_tp)       
    return indexed_samples

In [23]:
# @specific
def get_indexed_pred_data(data):
    indexed_samples = []
    for ind, sample in tqdm(enumerate(data), desc = "Generate indexed pred data"):
        text = sample["text"] 
        text_id = sample["id"]
        # @specific
        codes = tokenizer.encode_plus(text, 
                                    return_offsets_mapping = True, 
                                    add_special_tokens = False,
                                    max_length = config.max_seq_len, 
                                    pad_to_max_length = True)
        
        input_ids = torch.tensor(codes["input_ids"]).long()
        attention_mask = torch.tensor(codes["attention_mask"]).long()
        token_type_ids = torch.tensor(codes["token_type_ids"]).long()
        offset_map = codes["offset_mapping"]

        sample_tp = (text_id,
                     text, 
                     input_ids,
                     attention_mask,
                     token_type_ids,
                     offset_map,
                     )
        indexed_samples.append(sample_tp)       
    return indexed_samples

In [24]:
class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data
        
    def __getitem__(self, index):
        return self.data[index]
    
    def __len__(self):
        return len(self.data)

In [25]:
def generate_train_dev_batch(batch_data):
    text_id_list = []
    text_list = []
    input_ids_list = []
    attention_mask_list = []
    token_type_ids_list = [] 
    offset_map_list = []
    ent_spots_list = []
    head_rel_spots_list = []
    tail_rel_spots_list = []
    
    for sample in batch_data:
        text_id_list.append(sample[0])
        text_list.append(sample[1])
        input_ids_list.append(sample[2])
        attention_mask_list.append(sample[3])        
        token_type_ids_list.append(sample[4])        
        offset_map_list.append(sample[5])
        
        ent_matrix_spots, head_rel_matrix_spots, tail_rel_matrix_spots = sample[6]
        ent_spots_list.append(ent_matrix_spots)
        head_rel_spots_list.append(head_rel_matrix_spots)
        tail_rel_spots_list.append(tail_rel_matrix_spots)
    
    # @specific: codes indexed by bert tokenizer
    batch_input_ids = torch.stack(input_ids_list, dim = 0)
    batch_attention_mask = torch.stack(attention_mask_list, dim = 0)
    batch_token_type_ids = torch.stack(token_type_ids_list, dim = 0)
    
#     t1 = time.time()
    batch_ent_shaking_tag = handshaking_tagger.sharing_spots2shaking_tag4batch(ent_spots_list)
    batch_head_rel_shaking_tag = handshaking_tagger.spots2shaking_tag4batch(head_rel_spots_list)
    batch_tail_rel_shaking_tag = handshaking_tagger.spots2shaking_tag4batch(tail_rel_spots_list)
#     print(time.time() - t1)
    return text_id_list, text_list, batch_input_ids, batch_attention_mask, batch_token_type_ids, offset_map_list, batch_ent_shaking_tag, batch_head_rel_shaking_tag, batch_tail_rel_shaking_tag

In [26]:
def generate_pred_batch(batch_data):
    text_ids = []
    text_list = []
    input_ids = []
    attention_mask = []
    token_type_ids = [] 
    offset_map = []
    for sample in batch_data:
        text_ids.append(sample[0])
        text_list.append(sample[1])
        input_ids.append(sample[2])
        attention_mask.append(sample[3])        
        token_type_ids.append(sample[4])        
        offset_map.append(sample[5])
    input_ids = torch.stack(input_ids, dim = 0)
    attention_mask = torch.stack(attention_mask, dim = 0)
    token_type_ids = torch.stack(token_type_ids, dim = 0)
    return text_ids, text_list, input_ids, attention_mask, token_type_ids, offset_map

In [27]:
# @uni
def get_train_dev_dataloader_gen(indexed_train_sample_list, indexed_dev_sample_list, batch_size):
    train_dataloader = DataLoader(MyDataset(indexed_train_sample_list), 
                                      batch_size = batch_size, 
                                      shuffle = True, 
                                      num_workers = 6,
                                      drop_last = False,
                                      collate_fn = generate_train_dev_batch,
                                     )
    dev_dataloader = DataLoader(MyDataset(indexed_dev_sample_list), 
                              batch_size = batch_size, 
                              shuffle = True, 
                              num_workers = 6,
                              drop_last = False,
                              collate_fn = generate_train_dev_batch,
                             )
    return train_dataloader, dev_dataloader

In [28]:
indexed_train_data = get_indexed_train_valid_data(short_train_data)

Generate indexed train or valid data: 89465it [00:55, 1617.43it/s]


In [29]:
indexed_valid_data = get_indexed_train_valid_data(short_valid_data)

Generate indexed train or valid data: 458it [00:00, 1565.99it/s]


In [30]:
# have a look at dataloader
train_dataloader, dev_dataloader = get_train_dev_dataloader_gen(indexed_train_data, indexed_valid_data, 32)

In [31]:
# # test data loading time

# # 0 workers: 40, 41
# # 3 workers: 28
# # 4 workers: 23
# # 5 workers: 19
# # 6 workers: 16, 
# # 7 workers: 22, 28, 21
# # 10 wokers: 27, 28, 
# for ep in range(100):
#     t1 = time.time()
#     for ind, batch_data in enumerate(dev_dataloader):
#         print("\r{}/{}".format(ind + 1, len(dev_dataloader)), end = "")
#     print("ep {}: {}".format(ep + 1, time.time() - t1))

In [32]:
# train_data_iter = iter(train_dataloader)
# %timeit next(train_data_iter)

In [33]:
train_data_iter = iter(train_dataloader)
batch_data = next(train_data_iter)
text_id_list, text_list, batch_input_ids, \
batch_attention_mask, batch_token_type_ids, \
offset_map_list, batch_ent_shaking_tag, \
batch_head_rel_shaking_tag, batch_tail_rel_shaking_tag = batch_data

print(text_list[0])
print()
print(tokenizer.decode(batch_input_ids[0].tolist()))
print(batch_input_ids.size())
print(batch_attention_mask.size())
print(batch_token_type_ids.size())
print(len(offset_map_list))
print(batch_ent_shaking_tag.size())
print(batch_head_rel_shaking_tag.size())
print(batch_tail_rel_shaking_tag.size())

Gov. Mark Warner , Democrat of Virginia ; Senator Russell D. Feingold , Democrat of Wisconsin ; and Senator Bill Frist , Republican of Tennessee . ''

Gov. Mark Warner, Democrat of Virginia ; Senator Russell D. Feingold, Democrat of Wisconsin ; and Senator Bill Frist, Republican of Tennessee.'' [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
torch.Size([32, 100])
torch.Size([32, 100])
torch.Size([32, 100])
32
torch.Size([32, 5050])
torch.Size([32, 12, 5050])
torch.Size([32, 12, 5050])


# Model

In [34]:
class RelExtractor(nn.Module):
    def __init__(self, encoder, rel_size):
        super().__init__()
        # @specific
        self.encoder = encoder
        hidden_size = encoder.config.hidden_size

        self.ent_fc = nn.Linear(hidden_size, 2)
        self.head_rel_fc_list = [nn.Linear(hidden_size, 3) for _ in range(rel_size)]
        self.tail_rel_fc_list = [nn.Linear(hidden_size, 3) for _ in range(rel_size)]
        
        # register fcs
#         for ind, fc in enumerate(self.ent_fc_list):
#             self.register_parameter("weight_4_ent_in_rel{}".format(ind), fc.weight)
#             self.register_parameter("bias_4_ent_in_rel{}".format(ind), fc.bias)
        for ind, fc in enumerate(self.head_rel_fc_list):
            self.register_parameter("weight_4_head_rel{}".format(ind), fc.weight)
            self.register_parameter("bias_4_head_rel{}".format(ind), fc.bias)
        for ind, fc in enumerate(self.tail_rel_fc_list):
            self.register_parameter("weight_4_tail_rel{}".format(ind), fc.weight)
            self.register_parameter("bias_4_tail_rel{}".format(ind), fc.bias)
            
        # conditional layer normaliztion
        fake_inputs = torch.zeros([config.batch_size, config.max_seq_len, hidden_size])
        self.cond_layer_norm = LayerNorm(fake_inputs.size(), hidden_size, conditional = True)
        
    def forward(self, input_ids, attention_mask, token_type_ids):
        # @specific
        # input_ids, attention_mask, token_type_ids: (batch_size, seq_len)
        context_outputs = self.encoder(input_ids, attention_mask, token_type_ids)
        # last_hidden_state: (batch_size, seq_len, hidden_size)
        last_hidden_state = context_outputs[0]
        
        # shaking_hiddens: (batch_size, 1 + ... + seq_len, hidden_size)
        shaking_hiddens = self.shake_hands_afterwards(last_hidden_state)
        
        ent_shaking_outputs = self.ent_fc(shaking_hiddens)
            
        head_rel_shaking_outputs_list = []
        for fc in self.head_rel_fc_list:
            head_rel_shaking_outputs_list.append(fc(shaking_hiddens))
            
        tail_rel_shaking_outputs_list = []
        for fc in self.tail_rel_fc_list:
            tail_rel_shaking_outputs_list.append(fc(shaking_hiddens))
        
        head_rel_shaking_outputs = torch.stack(head_rel_shaking_outputs_list, dim = 1)
        tail_rel_shaking_outputs = torch.stack(tail_rel_shaking_outputs_list, dim = 1)
        
        return ent_shaking_outputs, head_rel_shaking_outputs, tail_rel_shaking_outputs

    def shake_hands_afterwards(self, seq_hiddens):
        '''
        seq_hiddens: (batch_size, seq_len, hidden_size) (32, 3, 5)
        return shake_hands_matrix_hiddens: (batch_size, (1 + seq_len) * seq_len / 2, hidden_size) (32, 5+4+3+2+1, 5)
        '''
        seq_len = seq_hiddens.size()[-2]
        shake_hands_hidden_list = []
        for ind in range(seq_len):
            hidden_each_step = seq_hiddens[:, ind, :]
            # seq_len - ind: only shake afterwards
            repeat_hidden_each_step = hidden_each_step[:, None, :].repeat(1, seq_len - ind, 1) 
    #         shake_hands_hidden = torch.cat([repeat_hidden_each_step, seq_hiddens[:, ind:, :]], dim = -1)
            shake_hands_hidden = self.cond_layer_norm(seq_hiddens[:, ind:, :], repeat_hidden_each_step)
            shake_hands_hidden_list.append(shake_hands_hidden)
        shake_hands_matrix_hiddens = torch.cat(shake_hands_hidden_list, dim = 1)
        return shake_hands_matrix_hiddens

In [35]:
roberta = AutoModel.from_pretrained(model_path)

In [36]:
rel_extractor = RelExtractor(roberta, len(rel2id))
rel_extractor = rel_extractor.to(device)

In [37]:
# wandb.watch(rel_extractor, log = "all")

In [38]:
# # test model
# batch_input_ids, \
# batch_attention_mask,\
# batch_token_type_ids = batch_input_ids.to(device), \
#                          batch_attention_mask.to(device), \
#                          batch_token_type_ids.to(device)
# with torch.no_grad():                 
#     ent_shaking_outputs, \
#     head_rel_shaking_outputs, \
#     tail_rel_shaking_outputs = rel_extractor(batch_input_ids, 
#                                               batch_attention_mask, 
#                                               batch_token_type_ids)

# print(ent_shaking_outputs.size())
# print(head_rel_shaking_outputs.size())
# print(tail_rel_shaking_outputs.size())

In [39]:
def bias_loss(weights = None):
    if weights is not None:
        weights = torch.FloatTensor(weights).to(device)
    cross_en = nn.CrossEntropyLoss(weight = weights)  
    return lambda pred, target: cross_en(pred.view(-1, pred.size()[-1]), target.view(-1))
loss_func = bias_loss()

In [40]:
def get_sample_accuracy(pred, truth):
    '''
    计算所有抽取字段都正确的样本比例
    即该batch的输出与truth全等的样本比例
    '''
    # (batch_size, ..., seq_len, tag_size) -> (batch_size, ..., seq_len)
    pred_id = torch.argmax(pred, dim = -1)
    # (batch_size, ..., seq_len) -> (batch_size, )，把每个sample压成一条seq
    pred_id = pred_id.view(pred_id.size()[0], -1)
    truth = truth.view(truth.size()[0], -1)
    
    # (batch_size, )，每个元素是pred与truth之间tag相同的数量
    correct_tag_num = torch.sum(torch.eq(truth, pred_id).float(), dim = 1)

    # seq维上所有tag必须正确，所以correct_tag_num必须等于seq的长度才算一个correct的sample
    sample_acc_ = torch.eq(correct_tag_num, torch.ones_like(correct_tag_num) * truth.size()[-1]).float()
    sample_acc = torch.mean(sample_acc_)
    
    return sample_acc

In [41]:
def get_rel_cpg(text_list, offset_map_list, 
                 batch_pred_ent_shaking_outputs,
                 batch_pred_head_rel_shaking_outputs,
                 batch_pred_tail_rel_shaking_outputs,
                 batch_gold_ent_shaking_tag,
                 batch_gold_head_rel_shaking_tag,
                 batch_gold_tail_rel_shaking_tag):
    batch_pred_ent_shaking_tag = torch.argmax(batch_pred_ent_shaking_outputs, dim = -1)
    batch_pred_head_rel_shaking_tag = torch.argmax(batch_pred_head_rel_shaking_outputs, dim = -1)
    batch_pred_tail_rel_shaking_tag = torch.argmax(batch_pred_tail_rel_shaking_outputs, dim = -1)
    
    correct_num, pred_num, gold_num = 0, 0, 0
    for ind in range(len(text_list)):
        text = text_list[ind]
        offset_map = offset_map_list[ind]
        gold_ent_shaking_tag, pred_ent_shaking_tag = batch_gold_ent_shaking_tag[ind], batch_pred_ent_shaking_tag[ind]
        gold_head_rel_shaking_tag, pred_head_rel_shaking_tag = batch_gold_head_rel_shaking_tag[ind], batch_pred_head_rel_shaking_tag[ind]
        gold_tail_rel_shaking_tag, pred_tail_rel_shaking_tag = batch_gold_tail_rel_shaking_tag[ind], batch_pred_tail_rel_shaking_tag[ind]
        
        pred_rel_list = handshaking_tagger.decode_rel_fr_shaking_tag(text, 
                                                  pred_ent_shaking_tag, 
                                                  pred_head_rel_shaking_tag, 
                                                  pred_tail_rel_shaking_tag, 
                                                  offset_map)
        gold_rel_list = handshaking_tagger.decode_rel_fr_shaking_tag(text, 
                                                  gold_ent_shaking_tag, 
                                                  gold_head_rel_shaking_tag, 
                                                  gold_tail_rel_shaking_tag, 
                                                  offset_map)

        gold_rel_set = set(["{}\u2E80{}\u2E80{}".format(rel["subject"], rel["predicate"], rel["object"]) for rel in gold_rel_list])
        pred_rel_set = set(["{}\u2E80{}\u2E80{}".format(rel["subject"], rel["predicate"], rel["object"]) for rel in pred_rel_list])
        
        for rel_str in pred_rel_set:
            if rel_str in gold_rel_set:
                correct_num += 1
        
        pred_num += len(gold_rel_set)
        gold_num += len(pred_rel_set)
        
    return correct_num, pred_num, gold_num

In [42]:
def get_scores(correct_num, pred_num, gold_num):
    minimini = 1e-10
    precision = correct_num / (pred_num + minimini)
    recall = correct_num / (gold_num + minimini)
    f1 = 2 * precision * recall / (precision + recall + minimini)
    return precision, recall, f1

# Train

In [43]:
# train step
def train_step(batch_train_data, optimizer, loss_weights):
    text_id_list, text_list, batch_input_ids, \
    batch_attention_mask, batch_token_type_ids, \
    offset_map_list, batch_ent_shaking_tag, \
    batch_head_rel_shaking_tag, batch_tail_rel_shaking_tag = batch_train_data
    
    batch_input_ids, \
    batch_attention_mask, \
    batch_token_type_ids, \
    batch_ent_shaking_tag, \
    batch_head_rel_shaking_tag, \
    batch_tail_rel_shaking_tag = (batch_input_ids.to(device), 
                              batch_attention_mask.to(device), 
                              batch_token_type_ids.to(device), 
                              batch_ent_shaking_tag.to(device), 
                              batch_head_rel_shaking_tag.to(device), 
                              batch_tail_rel_shaking_tag.to(device)
                             )

    # zero the parameter gradients
    optimizer.zero_grad()
    
    ent_shaking_outputs, \
    head_rel_shaking_outputs, \
    tail_rel_shaking_outputs = rel_extractor(batch_input_ids, 
                                              batch_attention_mask, 
                                              batch_token_type_ids, 
                                             )

    w_ent, w_rel = loss_weights["ent"], loss_weights["rel"]
    loss = w_ent * loss_func(ent_shaking_outputs, batch_ent_shaking_tag) + \
            w_rel * loss_func(head_rel_shaking_outputs, batch_head_rel_shaking_tag) + \
            w_rel * loss_func(tail_rel_shaking_outputs, batch_tail_rel_shaking_tag)
    
    # bp time: 2s
    loss.backward()
    optimizer.step()
    
    ent_sample_acc = get_sample_accuracy(ent_shaking_outputs, 
                                          batch_ent_shaking_tag)
    head_rel_sample_acc = get_sample_accuracy(head_rel_shaking_outputs, 
                                             batch_head_rel_shaking_tag)
    tail_rel_sample_acc = get_sample_accuracy(tail_rel_shaking_outputs, 
                                             batch_tail_rel_shaking_tag)
    
    return loss.item(), ent_sample_acc.item(), head_rel_sample_acc.item(), tail_rel_sample_acc.item()

# valid step
def valid_step(batch_valid_data):
    text_id_list, text_list, batch_input_ids, \
    batch_attention_mask, batch_token_type_ids, \
    offset_map_list, batch_ent_shaking_tag, \
    batch_head_rel_shaking_tag, batch_tail_rel_shaking_tag = batch_valid_data
    
    batch_input_ids, \
    batch_attention_mask, \
    batch_token_type_ids, \
    batch_ent_shaking_tag, \
    batch_head_rel_shaking_tag, \
    batch_tail_rel_shaking_tag = (batch_input_ids.to(device), 
                              batch_attention_mask.to(device), 
                              batch_token_type_ids.to(device), 
                              batch_ent_shaking_tag.to(device), 
                              batch_head_rel_shaking_tag.to(device), 
                              batch_tail_rel_shaking_tag.to(device)
                             )
    with torch.no_grad():
        ent_shaking_outputs, \
        head_rel_shaking_outputs, \
        tail_rel_shaking_outputs = rel_extractor(batch_input_ids, 
                                                  batch_attention_mask, 
                                                  batch_token_type_ids, 
                                                 )
    
    ent_sample_acc = get_sample_accuracy(ent_shaking_outputs, 
                                          batch_ent_shaking_tag)
    head_rel_sample_acc = get_sample_accuracy(head_rel_shaking_outputs, 
                                             batch_head_rel_shaking_tag)
    tail_rel_sample_acc = get_sample_accuracy(tail_rel_shaking_outputs, 
                                             batch_tail_rel_shaking_tag)
    
    rel_cpg = get_rel_cpg(text_list, offset_map_list, 
                            ent_shaking_outputs,
                            head_rel_shaking_outputs,
                            tail_rel_shaking_outputs,
                            batch_ent_shaking_tag,
                            batch_head_rel_shaking_tag,
                            batch_tail_rel_shaking_tag)
    
    return ent_sample_acc.item(), head_rel_sample_acc.item(), tail_rel_sample_acc.item(), rel_cpg

In [44]:
max_f1 = 0.
def train_n_valid(train_dataloader, dev_dataloader, optimizer, scheduler, num_epoch):  
    def train(dataloader, ep):
        # train
        rel_extractor.train()
        
        t_ep = time.time()
        start_lr = optimizer.param_groups[0]['lr']
        total_loss, total_ent_sample_acc, total_head_rel_sample_acc, total_tail_rel_sample_acc = 0., 0., 0., 0.
        for batch_ind, batch_train_data in enumerate(dataloader):
            t_batch = time.time()
            z = (2 * len(rel2id) + 1)
            steps_per_ep = len(dataloader)
            total_steps = config.loss_weight_recover_steps
            current_step = steps_per_ep * ep + batch_ind
            w_ent = max(1 / z + 1 - current_step / total_steps, 1 / z)
            w_rel = min((len(rel2id) / z) * current_step / total_steps, (len(rel2id) / z))
            loss, ent_sample_acc, head_rel_sample_acc, tail_rel_sample_acc = train_step(batch_train_data, optimizer, loss_weights)
            scheduler.step()
            
            total_loss += loss
            total_ent_sample_acc += ent_sample_acc
            total_head_rel_sample_acc += head_rel_sample_acc
            total_tail_rel_sample_acc += tail_rel_sample_acc
            
            avg_loss = total_loss / (batch_ind + 1)
            avg_ent_sample_acc = total_ent_sample_acc / (batch_ind + 1)
            avg_head_rel_sample_acc = total_head_rel_sample_acc / (batch_ind + 1)
            avg_tail_rel_sample_acc = total_tail_rel_sample_acc / (batch_ind + 1)
            
            batch_print_format = "\rEpoch: {}/{}, batch: {}/{}, train_loss: {}, " + \
                                "t_ent_sample_acc: {}, t_head_rel_sample_acc: {}, t_tail_rel_sample_acc: {}," + \
                                 "lr: {}, batch_time: {}, total_time: {} -------------"
                    
            print(batch_print_format.format(ep + 1, num_epoch, 
                                            batch_ind + 1, len(dataloader), 
                                            avg_loss, 
                                            avg_ent_sample_acc,
                                            avg_head_rel_sample_acc,
                                            avg_tail_rel_sample_acc,
                                            optimizer.param_groups[0]['lr'],
                                            time.time() - t_batch,
                                            time.time() - t_ep,
                                           ), end="")
            
            if batch_ind % config.log_interval == 0:
                wandb.log({
                    "train_loss": avg_loss,
                    "train_ent_seq_acc": avg_ent_sample_acc,
                    "train_head_rel_acc": avg_head_rel_sample_acc,
                    "train_tail_rel_acc": avg_tail_rel_sample_acc,
                    "learning_rate": optimizer.param_groups[0]['lr'],
                    "time": time.time() - t_ep,
                })
            
#         log_format = "Epoch: {}/{}, train_sample_acc(ent, head_rel, tail_rel): ({},{},{}), start_lr:{}, end_lr: {}, ep_time: {}"
#         logger.info(log_format.format(ep + 1, num_epoch, 
#                                       avg_ent_sample_acc, 
#                                       avg_head_rel_sample_acc, 
#                                       avg_tail_rel_sample_acc, 
#                                       start_lr, 
#                                       optimizer.param_groups[0]['lr'], 
#                                       time.time() - t_ep))
        
    def valid(dataloader, ep):
        # valid
        rel_extractor.eval()
        
        t_ep = time.time()
        total_ent_sample_acc, total_head_rel_sample_acc, total_tail_rel_sample_acc = 0., 0., 0.
        total_rel_correct_num, total_rel_pred_num, total_rel_gold_num = 0, 0, 0
        for batch_ind, batch_valid_data in enumerate(tqdm(dataloader, desc = "Validating")):
            ent_sample_acc, head_rel_sample_acc, tail_rel_sample_acc, rel_cpg = valid_step(batch_valid_data)

            total_ent_sample_acc += ent_sample_acc
            total_head_rel_sample_acc += head_rel_sample_acc
            total_tail_rel_sample_acc += tail_rel_sample_acc
            
            total_rel_correct_num += rel_cpg[0]
            total_rel_pred_num += rel_cpg[1]
            total_rel_gold_num += rel_cpg[2]

        avg_ent_sample_acc = total_ent_sample_acc / len(dataloader)
        avg_head_rel_sample_acc = total_head_rel_sample_acc / len(dataloader)
        avg_tail_rel_sample_acc = total_tail_rel_sample_acc / len(dataloader)
        
        rel_prf = get_scores(total_rel_correct_num, total_rel_pred_num, total_rel_gold_num)
        
        log_dict = {
                        "val_ent_seq_acc": avg_ent_sample_acc,
                        "val_head_rel_acc": avg_head_rel_sample_acc,
                        "val_tail_rel_acc": avg_tail_rel_sample_acc,
                        "val_prec": rel_prf[0],
                        "val_recall": rel_prf[1],
                        "val_f1": rel_prf[2],
                        "time": time.time() - t_ep,
                    }
        wandb.log(log_dict)
        pprint(log_dict)
        
        return rel_prf[2]
        
    for ep in range(num_epoch):
        train(train_dataloader, ep)   
        valid_f1 = valid(dev_dataloader, ep)
        
        global max_f1
        if valid_f1 >= max_f1: 
            max_f1 = valid_f1
            if valid_f1 > 0.5: # save the best model
                modle_state_num = len(glob.glob(model_state_dict_dir + "/model_state_dict_*.pt"))
                torch.save(rel_extractor.state_dict(), os.path.join(model_state_dict_dir, "model_state_dict_{}.pt".format(modle_state_num)))
                scheduler_state_num = len(glob.glob(schedule_state_dict_dir + "/scheduler_state_dict_*.pt"))
                torch.save(scheduler.state_dict(), os.path.join(schedule_state_dict_dir, "scheduler_state_dict_{}.pt".format(scheduler_state_num))) 
        print("Current avf_f1: {}, Best f1: {}".format(valid_f1, max_f1))

In [45]:
# torch.save(rel_extractor.state_dict(), os.path.join(model_state_dict_dir, "model_state_dict_{}.pt".format(0))) 
# torch.save(scheduler.state_dict(), os.path.join(schedule_state_dict_dir, "scheduler_state_dict_{}.pt".format(0))) 

In [46]:
def get_last_state_path(state_dir, pre_fix):
    max_file_num = -1
    last_state_path = None
    for path in glob.glob(state_dir + "/{}_*.pt".format(pre_fix)):
        file_num = re.search("state_dict_(\d+)\.pt", path).group(1)
        if int(file_num) > max_file_num:
            max_file_num = int(file_num)
            last_state_path = path
    return last_state_path

def get_model_state_path(state_dict_dir, state_dict_num):
    return os.path.join(state_dict_dir, "model_state_dict_{}.pt".format(state_dict_num))

In [None]:
# dataloader
print("preparing dataloader...")
train_dataloader, \
dev_dataloader = get_train_dev_dataloader_gen(indexed_train_data, 
                                            indexed_valid_data, 
                                            config.batch_size, 
                                            )
print("dataloaders done!")

In [None]:
# optimizer
init_learning_rate = config.lr
optimizer = torch.optim.Adam(rel_extractor.parameters(), lr = init_learning_rate)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, len(train_dataloader) * 2)

# decay_rate = 0.99
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 10, gamma = decay_rate)

In [None]:
epoch_num = config.epochs
# load the last best state (if any)
model_last_state_path = get_last_state_path(model_state_dict_dir, "model_state_dict")
if model_last_state_path is not None:
    rel_extractor.load_state_dict(torch.load(model_last_state_path))
    print("------------model state {} loaded ----------------".format(model_last_state_path.split("/")[-1]))
    
scheduler_last_state_path = get_last_state_path(schedule_state_dict_dir, "scheduler_state_dict")  
if scheduler_last_state_path is not None:
    scheduler.load_state_dict(torch.load(scheduler_last_state_path))
    print("------------scheduler state {} loaded ----------------".format(scheduler_last_state_path.split("/")[-1]))

train_n_valid(train_dataloader, dev_dataloader, optimizer, scheduler, epoch_num)

# Prediction

In [47]:
model_state_dict_dir = os.path.join(project_root, "notebook/wandb/run-20200529_153219-3nj0nz8n")

In [48]:
model_state_path = get_last_state_path(model_state_dict_dir, "model_state_dict")
# model_state_path = get_model_state_path(model_state_dict_dir, 16)
rel_extractor.load_state_dict(torch.load(model_state_path))
rel_extractor.eval()
print("------------model state {} loaded ----------------".format(model_state_path.split("/")[-1]))

------------model state model_state_dict_8.pt loaded ----------------


In [49]:
def filter_duplicates(rel_list):
    rel_memory_set = set()
    filtered_rel_list = []
    for rel in rel_list:
        rel_memory = "{}\u2E80\{}\u2E80\{}\u2E80\{}\u2E80{}".format(*rel.values())
        if rel_memory not in rel_memory_set:
            filtered_rel_list.append(rel)
            rel_memory_set.add(rel_memory)
    return filtered_rel_list

In [50]:
def predict(short_test_data):
    '''
    short_test_data: seq_len <= max_seq_len
    '''
    indexed_test_data = get_indexed_train_valid_data(short_test_data)
    test_dataloader = DataLoader(MyDataset(indexed_test_data), 
                              batch_size = config.test_batch_size, 
                              shuffle = False, 
                              num_workers = 0,
                              drop_last = False,
                              collate_fn = generate_pred_batch,
                             )
    short_pred_sample_list = []
    for batch_test_data in tqdm(test_dataloader, desc = "Predicting"):
        text_id_list, text_list, batch_input_ids, \
        batch_attention_mask, batch_token_type_ids, \
        offset_map_list = batch_test_data

        batch_input_ids, \
        batch_attention_mask, \
        batch_token_type_ids = (batch_input_ids.to(device), 
                                  batch_attention_mask.to(device), 
                                  batch_token_type_ids.to(device)
                                 )
        with torch.no_grad():
            batch_ent_shaking_outputs, \
            batch_head_rel_shaking_outputs, \
            batch_tail_rel_shaking_outputs = rel_extractor(batch_input_ids, 
                                                              batch_attention_mask, 
                                                              batch_token_type_ids, 
                                                             )

        batch_ent_shaking_tag, \
        batch_head_rel_shaking_tag, \
        batch_tail_rel_shaking_tag = torch.argmax(batch_ent_shaking_outputs, dim = -1), \
                                     torch.argmax(batch_head_rel_shaking_outputs, dim = -1), \
                                     torch.argmax(batch_tail_rel_shaking_outputs, dim = -1)

        for ind in range(len(text_list)):
            text, offset_map = text_list[ind], offset_map_list[ind]
            ent_shaking_tag, \
            head_rel_shaking_tag, \
            tail_rel_shaking_tag = batch_ent_shaking_tag[ind], \
                                    batch_head_rel_shaking_tag[ind], \
                                    batch_tail_rel_shaking_tag[ind]
            rel_list = handshaking_tagger.decode_rel_fr_shaking_tag(text, 
                                                  ent_shaking_tag, 
                                                  head_rel_shaking_tag, 
                                                  tail_rel_shaking_tag, 
                                                  offset_map)
            short_pred_sample_list.append({
                "text": text,
                "id": text_id_list[ind],
                "relation_list": rel_list,
            })
    # merge
    text_id2rel_list = {}
    for sample in short_pred_sample_list:
        text_id = sample["id"]
        if text_id not in text_id2rel_list:
            text_id2rel_list[text_id] = sample["relation_list"]
        else:
            text_id2rel_list[text_id].extend(sample["relation_list"])

    text_id2text = {sample["id"]:sample["text"] for sample in test_data}
    merged_pred_sample_list = []
    for text_id, rel_list in text_id2rel_list.items():
        merged_pred_sample_list.append({
            "id": text_id,
            "text": text_id2text[text_id],
            "relation_list": filter_duplicates(rel_list),
        })
    return merged_pred_sample_list

In [51]:
pred_sample_list = predict(short_test_data)

Generate indexed train or valid data: 552it [00:00, 2000.76it/s]
Predicting: 100%|██████████| 18/18 [00:04<00:00,  3.97it/s]


In [52]:
len([s for s in pred_sample_list if len(s["relation_list"]) > 0])

320

In [53]:
text_id2gold_n_pred = {}
for sample in test_data:
    text_id = sample["id"]
    text_id2gold_n_pred[text_id] = {
        "gold_relation_list": sample["relation_list"],
    }
def get_test_prf(pred_sample_list):
    for sample in pred_sample_list:
        text_id = sample["id"]
        text_id2gold_n_pred[text_id]["pred_relation_list"] = sample["relation_list"]

    correct_num, pred_num, gold_num = 0, 0, 0
    for gold_n_pred in text_id2gold_n_pred.values():
        gold_rel_list = gold_n_pred["gold_relation_list"]
        pred_rel_list = gold_n_pred["pred_relation_list"] if "pred_relation_list" in gold_n_pred else []
        gold_rel_set = set(["{}\u2E80{}\u2E80{}".format(rel["subject"], rel["predicate"], rel["object"]) for rel in gold_rel_list])
        pred_rel_set = set(["{}\u2E80{}\u2E80{}".format(rel["subject"], rel["predicate"], rel["object"]) for rel in pred_rel_list])

        for rel_str in pred_rel_set:
            if rel_str in gold_rel_set:
                correct_num += 1

        pred_num += len(gold_rel_set)
        gold_num += len(pred_rel_set)

    prf = get_scores(correct_num, pred_num, gold_num)
#     print(prf)
    return prf

In [54]:
# model state 16: (0.9112068965517129, 0.9034188034187924, 0.9072961372890456)
# model state 17: (0.9060344827586095, 0.9096191889218483, 0.9078232970872052)
# 18: (0.9178571428571316, 0.904600072824361, 0.9111803899493801)
get_test_prf(pred_sample_list)

(0.5972972972971359, 0.518779342722883, 0.5552763818596557)