In [1]:
import json
import os
from tqdm import tqdm
import re
from IPython.core.debugger import set_trace
from pprint import pprint
import unicodedata
from transformers import AutoModel, BasicTokenizer, BertTokenizerFast
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import glob
import time
from layers import LayerNorm
import wandb
from utils import Preprocessor, HandshakingTaggingScheme
import logging
from glove import Glove
import numpy as np

In [2]:
logger = logging.getLogger()
logger.setLevel(logging.WARNING)

In [3]:
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device {} will be used".format(device))

device cuda:0 will be used


In [9]:
pretrained_model_home = "/data/yubowen/experiments/relextr/pretrained_model"
project_root = "/data/yubowen/experiments/relextr"
data_home = os.path.join(project_root, "data")

experiment_dir = os.path.join(project_root, "exp")
experiment_name = "nyt_single"
    
nyt_data_dir = os.path.join(data_home, experiment_name)
nyt_train_data_path = os.path.join(nyt_data_dir, "train_triples.json")
nyt_valid_data_path = os.path.join(nyt_data_dir, "valid_triples.json")
nyt_test_data_path_dict = {
    "test_triples": os.path.join(nyt_data_dir, "test_triples.json"),
    "test_triples_epo": os.path.join(nyt_data_dir, "test_split_by_type", "test_triples_epo.json"),
    "test_triples_seo": os.path.join(nyt_data_dir, "test_split_by_type", "test_triples_seo.json"),
    "test_triples_normal": os.path.join(nyt_data_dir, "test_split_by_type", "test_triples_normal.json"),
    "test_triples_1": os.path.join(nyt_data_dir, "test_split_by_num", "test_triples_1.json"),
    "test_triples_2": os.path.join(nyt_data_dir, "test_split_by_num", "test_triples_2.json"),
    "test_triples_3": os.path.join(nyt_data_dir, "test_split_by_num", "test_triples_3.json"),
    "test_triples_4": os.path.join(nyt_data_dir, "test_split_by_num", "test_triples_4.json"),
    "test_triples_5": os.path.join(nyt_data_dir, "test_split_by_num", "test_triples_5.json"),
}

In [5]:
wandb.init(project = experiment_name, name = "BiLSTM")

Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable


W&B Run: https://app.wandb.ai/wycheng/nyt_single/runs/33ahq1zb

In [6]:
# hyperparameters
config = wandb.config          # Initialize config
config.batch_size = 6          # input batch size for training (default: 64)
config.test_batch_size = 100    # input batch size for testing (default: 1000)
config.epochs = 50             # number of epochs to train (default: 10)
config.lr = 1e-3               # learning rate (default: 0.01)
config.seed = 2333               # random seed (default: 42)
config.log_interval = 10  
config.max_seq_len = 100
config.sliding_len = 20
config.loss_weight_recover_steps = 10000

config.word_embedding_dim = 100
config.rnn_hidden_size = 256
config.dropout = 0.1

config.word_embedding_path = os.path.join(data_home, "pretrained_word_embeddings", "glove_100_nyt.emb")

torch.manual_seed(config.seed) # pytorch random seed
torch.backends.cudnn.deterministic = True

model_state_dict_dir = wandb.run.dir
schedule_state_dict_dir = wandb.run.dir

Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable


# Load Data

In [10]:
nyt_train_data = json.load(open(nyt_train_data_path, "r", encoding = "utf-8"))
nyt_valid_data = json.load(open(nyt_valid_data_path, "r", encoding = "utf-8"))
nyt_test_data_dict = {}
for file_name, path in nyt_test_data_path_dict.items():
    nyt_test_data_dict[file_name] = json.load(open(path, "r", encoding = "utf-8"))

# Preprocess

In [11]:
# miss the last token
# def get_tok2char_span_map(text):
#     res = []
#     left = 0
#     sign = False
#     for i,c in enumerate(text):
#         if text[i].isalnum():
#             if sign == False:
#                 left = i
#                 sign = True
#         else:
#             if text[i] == ' ':
#                 if sign == True:
#                     res.append((left, i))
#                     sign = False
#             else:
#                 if sign == True:
#                     res.append((left, i))  
#                     sign = False
#                 res.append((i, i+1))
#     return res

In [12]:
def get_tok2char_span_map(text):
    tokens = text.split(" ")
    tok2char_span = []
    char_num = 0
    for tok in tokens:
        tok2char_span.append((char_num, char_num + len(tok)))
        char_num += len(tok) + 1 # +1: whitespace
    return tok2char_span

In [13]:
def tran2normal_samples(data):
    normal_sample_list = []
    for sample in tqdm(data, desc = "Transforming data format"):
        text = sample["text"]
        spo_list = sample["triple_list"]
        normal_sample = {
            "text": text,
            "id": sample["id"],
        }
        normal_rel_list = []
        for rel in spo_list:
            normal_rel_list.append({
                "subject": rel[0],
                "predicate": rel[1],
                "object": rel[2],
            })
        normal_sample["relation_list"] = normal_rel_list
        normal_sample_list.append(normal_sample)
    return normal_sample_list

In [14]:
preprocessor = Preprocessor(transform_func = tran2normal_samples, 
                            get_tok2char_span_map_func = get_tok2char_span_map)

train_data = preprocessor.get_normal_dataset(nyt_train_data, add_id = True, dataset_name = "train")
valid_data = preprocessor.get_normal_dataset(nyt_valid_data, add_id = True, dataset_name = "valid")
test_data_dict = {}
for file_name, data in nyt_test_data_dict.items():
    preprocessed_data = preprocessor.get_normal_dataset(data, add_id = True, dataset_name = "file_name")
    test_data_dict[file_name] = preprocessed_data

56195it [00:00, 526419.16it/s]
Transforming data format: 100%|██████████| 56195/56195 [00:00<00:00, 69666.28it/s] 
Adding token level spans: 100%|██████████| 56195/56195 [00:16<00:00, 3344.33it/s]
4999it [00:00, 468031.11it/s]
Transforming data format: 100%|██████████| 4999/4999 [00:00<00:00, 144721.02it/s]
Adding token level spans: 100%|██████████| 4999/4999 [00:01<00:00, 3538.75it/s]
5000it [00:00, 487075.44it/s]
Transforming data format: 100%|██████████| 5000/5000 [00:00<00:00, 154476.09it/s]
Adding token level spans: 100%|██████████| 5000/5000 [00:01<00:00, 3552.82it/s]
978it [00:00, 359385.78it/s]
Transforming data format: 100%|██████████| 978/978 [00:00<00:00, 104873.68it/s]
Adding token level spans: 100%|██████████| 978/978 [00:00<00:00, 3600.68it/s]
1297it [00:00, 389993.00it/s]
Transforming data format: 100%|██████████| 1297/1297 [00:00<00:00, 88573.58it/s]
Adding token level spans: 100%|██████████| 1297/1297 [00:00<00:00, 3350.38it/s]
3266it [00:00, 418674.07it/s]
Transformin

# Split

In [15]:
def split_into_short_samples(sample_list, sliding_len = 50):
    new_sample_list = []
    for sample in tqdm(sample_list, desc = "Splitting"):
        text_id = sample["id"]
        text = sample["text"]
        
        offset_map = get_tok2char_span_map(text)
        tokens = [text[a:b] for a,b in offset_map]
        
        # sliding on token level
        split_sample_list = []
        for start_ind in range(0, len(tokens), sliding_len):
            end_ind = start_ind + config.max_seq_len
#             while "##" in tokens[end_ind]:
#                 end_ind += 1
            char_span_list = offset_map[start_ind:end_ind]
            char_level_span = (char_span_list[0][0], char_span_list[-1][1])
            sub_text = text[char_level_span[0]:char_level_span[1]]

            new_sample = {
                "id": text_id,
                "text": sub_text,
                "relation_list": []
            }
            for rel in sample["relation_list"]:
                subj_span = rel["subj_span"]
                obj_span = rel["obj_span"]
                if subj_span[0] >= start_ind and subj_span[1] <= end_ind \
                    and obj_span[0] >= start_ind and obj_span[1] <= end_ind:
                    new_rel = copy.deepcopy(rel)
                    new_rel["subj_span"] = (subj_span[0] - start_ind, subj_span[1] - start_ind)
                    new_rel["obj_span"] = (obj_span[0] - start_ind, obj_span[1] - start_ind)
                    new_sample["relation_list"].append(new_rel)
#                 else:
#                     set_trace()
            if len(new_sample["relation_list"]) > 0:
                split_sample_list.append(new_sample)
#         if len(split_sample_list) == 0:
#             set_trace()
        new_sample_list.extend(split_sample_list)
    return new_sample_list

In [16]:
short_train_data = split_into_short_samples(train_data, sliding_len = config.sliding_len)
short_valid_data = split_into_short_samples(valid_data, sliding_len = config.sliding_len)

Splitting: 100%|██████████| 56195/56195 [00:08<00:00, 6752.46it/s]
Splitting: 100%|██████████| 4999/4999 [00:01<00:00, 4512.29it/s]


In [17]:
short_test_data_dict = {}
for file_name, data in test_data_dict.items():
    short_test_data = split_into_short_samples(data, sliding_len = config.sliding_len)
    short_test_data_dict[file_name] = short_test_data

Splitting: 100%|██████████| 5000/5000 [00:00<00:00, 7786.09it/s]
Splitting: 100%|██████████| 978/978 [00:00<00:00, 2500.67it/s]
Splitting: 100%|██████████| 1297/1297 [00:00<00:00, 2753.30it/s]
Splitting: 100%|██████████| 3266/3266 [00:00<00:00, 9862.15it/s] 
Splitting: 100%|██████████| 3244/3244 [00:00<00:00, 9727.61it/s]
Splitting: 100%|██████████| 1045/1045 [00:00<00:00, 6995.80it/s]
Splitting: 100%|██████████| 312/312 [00:00<00:00, 4922.24it/s]
Splitting: 100%|██████████| 291/291 [00:00<00:00, 4375.99it/s]
Splitting: 100%|██████████| 108/108 [00:00<00:00, 2577.17it/s]


In [18]:
print("train: {}".format(len(short_train_data)), "valid: {}".format(len(short_valid_data)))
for fil_name, data in short_test_data_dict.items():
        print("{}: {}".format(fil_name, len(data)))

train: 75129 valid: 6724
test_triples: 6677
test_triples_epo: 1294
test_triples_seo: 1851
test_triples_normal: 4283
test_triples_1: 4243
test_triples_2: 1415
test_triples_3: 448
test_triples_4: 387
test_triples_5: 184


# Tagging

In [19]:
from collections import defaultdict
rel_set = set()
word2num = defaultdict(int)
word2idx = {'<PAD>':0, '<UNK>':1}
idx2word = {}
idx = 2

all_data = train_data + valid_data 
for data in list(short_test_data_dict.values()):
    all_data.extend(data)
    
for sample in tqdm(all_data):
    for rel in sample["relation_list"]:
        rel_set.add(rel["predicate"])
    text = sample['text']
    span_list = get_tok2char_span_map(text)
    for span in span_list:
        word = text[span[0]:span[1]]
        word2num[word] += 1

#过滤出现次数小于3的word
for k,v in word2num.items():
    if v < 3:
        continue
    word2idx[k] = idx
    idx += 1
for k,v in word2idx.items():
    idx2word[v] = k
rel_set = sorted(rel_set)

100%|██████████| 81976/81976 [00:05<00:00, 14907.94it/s]


In [20]:
len(word2idx)

39708

In [21]:
rel2id = {rel:ind for ind, rel in enumerate(rel_set)}

In [22]:
handshaking_tagger = HandshakingTaggingScheme(rel2id = rel2id, max_seq_len = config.max_seq_len)

In [23]:
def sample_equal_to(sample1, sample2):
    assert sample1["id"] == sample2["id"]
    assert sample1["text"] == sample2["text"]
    memory_set = set()
    for rel in sample2["relation_list"]:
        memory = "{}\u2E80{}\u2E80{}\u2E80{}\u2E80{}".format(rel["subject"], 
                                                             rel["predicate"], 
                                                             rel["object"], 
                                                             str(rel["subj_span"]), 
                                                             str(rel["obj_span"]))
        memory_set.add(memory)
    for rel in sample1["relation_list"]:
        memory = "{}\u2E80{}\u2E80{}\u2E80{}\u2E80{}".format(rel["subject"], 
                                                             rel["predicate"], 
                                                             rel["object"], 
                                                             str(rel["subj_span"]), 
                                                             str(rel["obj_span"]))
        if memory not in memory_set:
            set_trace()
            return False
    return True

# Dataset

In [24]:
# @specific
def get_indexed_train_valid_data(data):
    indexed_samples = []
    for ind, sample in tqdm(enumerate(data), desc = "Generate indexed train or valid data"):
        text = sample["text"]
        text_id = sample["id"]
        
        
        # tagging
        spots_tuple = handshaking_tagger.get_spots(sample)
        offset_map = get_tok2char_span_map(text)
        input_ids = []
        for idx in offset_map:
            word = text[idx[0]:idx[1]]
            if word not in word2idx:
                input_ids.append(word2idx['<UNK>'])
            else:
                input_ids.append(word2idx[word])
        if len(input_ids) < config.max_seq_len:
            input_ids.extend([word2idx['<PAD>']] * (config.max_seq_len - len(input_ids)))
        input_ids = torch.tensor(input_ids[:config.max_seq_len])
        
        sample_tp = (text_id,
                     text, 
                     input_ids,
                     offset_map,
                     spots_tuple,
                    )
        indexed_samples.append(sample_tp)       
    return indexed_samples

In [25]:
# @specific
def get_indexed_pred_data(data):
    indexed_samples = []
    for ind, sample in tqdm(enumerate(data), desc = "Generate indexed pred data"):
        text = sample["text"] 
        text_id = sample["id"]
        # @specific
        codes = tokenizer.encode_plus(text, 
                                    return_offsets_mapping = True, 
                                    add_special_tokens = False,
                                    max_length = max_seq_len, 
                                    pad_to_max_length = True)
        
        input_ids = torch.tensor(codes["input_ids"]).long()
        attention_mask = torch.tensor(codes["attention_mask"]).long()
        token_type_ids = torch.tensor(codes["token_type_ids"]).long()
        offset_map = codes["offset_mapping"]

        sample_tp = (text_id,
                     text, 
                     input_ids,
                     attention_mask,
                     token_type_ids,
                     offset_map,
                     )
        indexed_samples.append(sample_tp)       
    return indexed_samples

In [26]:
class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data
        
    def __getitem__(self, index):
        return self.data[index]
    
    def __len__(self):
        return len(self.data)

In [27]:
def generate_train_dev_batch(batch_data):
    text_id_list = []
    text_list = []
    input_ids_list = []
    attention_mask_list = []
    token_type_ids_list = [] 
    offset_map_list = []
    ent_spots_list = []
    head_rel_spots_list = []
    tail_rel_spots_list = []
    
    for sample in batch_data:
        text_id_list.append(sample[0])
        text_list.append(sample[1])
        input_ids_list.append(sample[2])    
        offset_map_list.append(sample[3])
        
        ent_matrix_spots, head_rel_matrix_spots, tail_rel_matrix_spots = sample[4]
        ent_spots_list.append(ent_matrix_spots)
        head_rel_spots_list.append(head_rel_matrix_spots)
        tail_rel_spots_list.append(tail_rel_matrix_spots)
    
    # @specific: codes indexed by bert tokenizer
    batch_input_ids = torch.stack(input_ids_list, dim = 0)
    batch_attention_mask = None
    batch_token_type_ids = None

    batch_ent_shaking_tag = handshaking_tagger.sharing_spots2shaking_tag4batch(ent_spots_list)
    batch_head_rel_shaking_tag = handshaking_tagger.spots2shaking_tag4batch(head_rel_spots_list)
    batch_tail_rel_shaking_tag = handshaking_tagger.spots2shaking_tag4batch(tail_rel_spots_list)

    return text_id_list, text_list, batch_input_ids, batch_attention_mask, batch_token_type_ids, offset_map_list, batch_ent_shaking_tag, batch_head_rel_shaking_tag, batch_tail_rel_shaking_tag

In [28]:
def generate_pred_batch(batch_data):
    text_ids = []
    text_list = []
    input_ids = []
    attention_mask = []
    token_type_ids = [] 
    offset_map = []
    for sample in batch_data:
        text_ids.append(sample[0])
        text_list.append(sample[1])
        input_ids.append(sample[2])
        attention_mask.append(sample[3])        
        token_type_ids.append(sample[4])        
        offset_map.append(sample[5])
    input_ids = torch.stack(input_ids, dim = 0)
    attention_mask = torch.stack(attention_mask, dim = 0)
    token_type_ids = torch.stack(token_type_ids, dim = 0)
    return text_ids, text_list, input_ids, attention_mask, token_type_ids, offset_map

In [29]:
# @uni
def get_train_dev_dataloader_gen(indexed_train_sample_list, indexed_dev_sample_list, batch_size):
    train_dataloader = DataLoader(MyDataset(indexed_train_sample_list), 
                                      batch_size = batch_size, 
                                      shuffle = True, 
                                      num_workers = 6,
                                      drop_last = False,
                                      collate_fn = generate_train_dev_batch,
                                     )
    dev_dataloader = DataLoader(MyDataset(indexed_dev_sample_list), 
                              batch_size = batch_size, 
                              shuffle = True, 
                              num_workers = 6,
                              drop_last = False,
                              collate_fn = generate_train_dev_batch,
                             )
    return train_dataloader, dev_dataloader

In [30]:
indexed_train_data = get_indexed_train_valid_data(short_train_data)

Generate indexed train or valid data: 75129it [00:11, 6723.32it/s]


In [31]:
indexed_valid_data = get_indexed_train_valid_data(short_valid_data)

Generate indexed train or valid data: 6724it [00:00, 8828.21it/s]


In [32]:
# have a look at dataloader
train_dataloader, dev_dataloader = get_train_dev_dataloader_gen(indexed_train_data, indexed_valid_data, 32)

In [33]:
train_data_iter = iter(train_dataloader)
batch_data = next(train_data_iter)
text_id_list, text_list, batch_input_ids, \
batch_attention_mask, batch_token_type_ids, \
offset_map_list, batch_ent_shaking_tag, \
batch_head_rel_shaking_tag, batch_tail_rel_shaking_tag = batch_data

print(text_list[0])
print()
print(batch_input_ids[0].tolist())
print(batch_input_ids.size())
# print(batch_attention_mask.size())
# print(batch_token_type_ids.size())
print(len(offset_map_list))
print(batch_ent_shaking_tag.size())
print(batch_head_rel_shaking_tag.size())
print(batch_tail_rel_shaking_tag.size())

By comparison , just 57 percent of blacks in Atlanta were born in Georgia .

[633, 10365, 12, 1178, 8994, 1330, 39, 5703, 49, 1758, 309, 897, 49, 2054, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
torch.Size([32, 100])
32
torch.Size([32, 5050])
torch.Size([32, 24, 5050])
torch.Size([32, 24, 5050])


# Load Word Embedding Matrix

In [34]:
glove = Glove()
glove = glove.load(config.word_embedding_path)

In [35]:
# prepare embedding matrix
word_embedding_init_matrix = np.random.normal(-1, 1, size=(len(word2idx), config.word_embedding_dim))
count_in = 0

# 在预训练词向量中的用该预训练向量
# 不在预训练集里的用随机向量
for ind, tok in tqdm(idx2word.items(), desc="Embedding matrix initializing..."):
    if tok in glove.dictionary:
        count_in += 1
        word_embedding_init_matrix[ind] = glove.word_vectors[glove.dictionary[tok]]
        
print(count_in / len(idx2word)) # 命中预训练词向量的比例

Embedding matrix initializing...: 100%|██████████| 39708/39708 [00:00<00:00, 233891.93it/s]

0.9999496323159062





In [36]:
word_embedding_init_matrix = torch.FloatTensor(word_embedding_init_matrix)
word_embedding_init_matrix.size()

torch.Size([39708, 100])

# Model

In [37]:
class RelExtractor(nn.Module):
    def __init__(self, init_word_embedding_matrix, hidden_size, dropout, rel_size):
        super().__init__()
        # BiLSTM encoder
        self.word_embeds = nn.Embedding.from_pretrained(init_word_embedding_matrix, freeze = False)
        self.dropout = nn.Dropout(dropout)
        self.lstm = nn.LSTM(init_word_embedding_matrix.size()[-1], 
                            hidden_size // 2, 
                            num_layers = 2, 
                            bidirectional = True, 
                            batch_first = True)


        self.ent_fc = nn.Linear(hidden_size, 2)
        self.head_rel_fc_list = [nn.Linear(hidden_size, 3) for _ in range(rel_size)]
        self.tail_rel_fc_list = [nn.Linear(hidden_size, 3) for _ in range(rel_size)]
        
        for ind, fc in enumerate(self.head_rel_fc_list):
            self.register_parameter("weight_4_head_rel{}".format(ind), fc.weight)
            self.register_parameter("bias_4_head_rel{}".format(ind), fc.bias)
        for ind, fc in enumerate(self.tail_rel_fc_list):
            self.register_parameter("weight_4_tail_rel{}".format(ind), fc.weight)
            self.register_parameter("bias_4_tail_rel{}".format(ind), fc.bias)
            
        # conditional layer normaliztion
        fake_inputs = torch.zeros([config.batch_size, config.max_seq_len, hidden_size])
        self.cond_layer_norm = LayerNorm(fake_inputs.size(), hidden_size, conditional = True)
        
    def forward(self, input_ids, attention_mask, token_type_ids):
        # BiLSTM encoder
        embedding = self.word_embeds(input_ids)
        outputs, hidden = self.lstm(embedding)
        last_hidden_state = self.dropout(outputs)  # last_hidden_state: (batch_size, seq_len, hidden_size) 
        
        # shaking_hiddens: (batch_size, 1 + ... + seq_len, hidden_size)
        shaking_hiddens = self.shake_hands_afterwards(last_hidden_state)
        
        ent_shaking_outputs = self.ent_fc(shaking_hiddens)
            
        head_rel_shaking_outputs_list = []
        for fc in self.head_rel_fc_list:
            head_rel_shaking_outputs_list.append(fc(shaking_hiddens))
            
        tail_rel_shaking_outputs_list = []
        for fc in self.tail_rel_fc_list:
            tail_rel_shaking_outputs_list.append(fc(shaking_hiddens))
        
        head_rel_shaking_outputs = torch.stack(head_rel_shaking_outputs_list, dim = 1)
        tail_rel_shaking_outputs = torch.stack(tail_rel_shaking_outputs_list, dim = 1)
        
        return ent_shaking_outputs, head_rel_shaking_outputs, tail_rel_shaking_outputs

    def shake_hands_afterwards(self, seq_hiddens):
        '''
        seq_hiddens: (batch_size, seq_len, hidden_size) (32, 3, 5)
        return shake_hands_matrix_hiddens: (batch_size, (1 + seq_len) * seq_len / 2, hidden_size) (32, 5+4+3+2+1, 5)
        '''
        seq_len = seq_hiddens.size()[-2]
        shake_hands_hidden_list = []
        for ind in range(seq_len):
            hidden_each_step = seq_hiddens[:, ind, :]
            # seq_len - ind: only shake afterwards
            repeat_hidden_each_step = hidden_each_step[:, None, :].repeat(1, seq_len - ind, 1) 
    #         shake_hands_hidden = torch.cat([repeat_hidden_each_step, seq_hiddens[:, ind:, :]], dim = -1)
            shake_hands_hidden = self.cond_layer_norm(seq_hiddens[:, ind:, :], repeat_hidden_each_step)
            shake_hands_hidden_list.append(shake_hands_hidden)
        shake_hands_matrix_hiddens = torch.cat(shake_hands_hidden_list, dim = 1)
        return shake_hands_matrix_hiddens

In [38]:
rel_extractor = RelExtractor(word_embedding_init_matrix, config.rnn_hidden_size, config.dropout, len(rel2id))
rel_extractor = rel_extractor.to(device)

In [39]:
def bias_loss(weights = None):
    if weights is not None:
        weights = torch.FloatTensor(weights).to(device)
    cross_en = nn.CrossEntropyLoss(weight = weights)  
    return lambda pred, target: cross_en(pred.view(-1, pred.size()[-1]), target.view(-1))
loss_func = bias_loss()

In [40]:
def get_sample_accuracy(pred, truth):
    '''
    计算所有抽取字段都正确的样本比例
    即该batch的输出与truth全等的样本比例
    '''
    # (batch_size, ..., seq_len, tag_size) -> (batch_size, ..., seq_len)
    pred_id = torch.argmax(pred, dim = -1)
    # (batch_size, ..., seq_len) -> (batch_size, )，把每个sample压成一条seq
    pred_id = pred_id.view(pred_id.size()[0], -1)
    truth = truth.view(truth.size()[0], -1)
    
    # (batch_size, )，每个元素是pred与truth之间tag相同的数量
    correct_tag_num = torch.sum(torch.eq(truth, pred_id).float(), dim = 1)

    # seq维上所有tag必须正确，所以correct_tag_num必须等于seq的长度才算一个correct的sample
    sample_acc_ = torch.eq(correct_tag_num, torch.ones_like(correct_tag_num) * truth.size()[-1]).float()
    sample_acc = torch.mean(sample_acc_)
    
    return sample_acc

In [41]:
def get_rel_cpg(text_list, offset_map_list, 
                 batch_pred_ent_shaking_outputs,
                 batch_pred_head_rel_shaking_outputs,
                 batch_pred_tail_rel_shaking_outputs,
                 batch_gold_ent_shaking_tag,
                 batch_gold_head_rel_shaking_tag,
                 batch_gold_tail_rel_shaking_tag):
    batch_pred_ent_shaking_tag = torch.argmax(batch_pred_ent_shaking_outputs, dim = -1)
    batch_pred_head_rel_shaking_tag = torch.argmax(batch_pred_head_rel_shaking_outputs, dim = -1)
    batch_pred_tail_rel_shaking_tag = torch.argmax(batch_pred_tail_rel_shaking_outputs, dim = -1)
    
    correct_num, pred_num, gold_num = 0, 0, 0
    for ind in range(len(text_list)):
        text = text_list[ind]
        offset_map = offset_map_list[ind]
        gold_ent_shaking_tag, pred_ent_shaking_tag = batch_gold_ent_shaking_tag[ind], batch_pred_ent_shaking_tag[ind]
        gold_head_rel_shaking_tag, pred_head_rel_shaking_tag = batch_gold_head_rel_shaking_tag[ind], batch_pred_head_rel_shaking_tag[ind]
        gold_tail_rel_shaking_tag, pred_tail_rel_shaking_tag = batch_gold_tail_rel_shaking_tag[ind], batch_pred_tail_rel_shaking_tag[ind]
        
        pred_rel_list = handshaking_tagger.decode_rel_fr_shaking_tag(text, 
                                                  pred_ent_shaking_tag, 
                                                  pred_head_rel_shaking_tag, 
                                                  pred_tail_rel_shaking_tag, 
                                                  offset_map)
        gold_rel_list = handshaking_tagger.decode_rel_fr_shaking_tag(text, 
                                                  gold_ent_shaking_tag, 
                                                  gold_head_rel_shaking_tag, 
                                                  gold_tail_rel_shaking_tag, 
                                                  offset_map)

        gold_rel_set = set(["{}\u2E80{}\u2E80{}".format(rel["subject"], rel["predicate"], rel["object"]) for rel in gold_rel_list])
        pred_rel_set = set(["{}\u2E80{}\u2E80{}".format(rel["subject"], rel["predicate"], rel["object"]) for rel in pred_rel_list])
        
        for rel_str in pred_rel_set:
            if rel_str in gold_rel_set:
                correct_num += 1
        
        pred_num += len(gold_rel_set)
        gold_num += len(pred_rel_set)
        
    return correct_num, pred_num, gold_num

In [42]:
def get_scores(correct_num, pred_num, gold_num):
    minimini = 1e-10
    precision = correct_num / (pred_num + minimini)
    recall = correct_num / (gold_num + minimini)
    f1 = 2 * precision * recall / (precision + recall + minimini)
    return precision, recall, f1

# Train

In [43]:
# train step
def train_step(batch_train_data, optimizer, loss_weights):
    text_id_list, text_list, batch_input_ids, \
    batch_attention_mask, batch_token_type_ids, \
    offset_map_list, batch_ent_shaking_tag, \
    batch_head_rel_shaking_tag, batch_tail_rel_shaking_tag = batch_train_data
    
    batch_input_ids, \
    batch_attention_mask, \
    batch_token_type_ids, \
    batch_ent_shaking_tag, \
    batch_head_rel_shaking_tag, \
    batch_tail_rel_shaking_tag = (batch_input_ids.to(device), 
                                  None,
                                  None,
                              batch_ent_shaking_tag.to(device), 
                              batch_head_rel_shaking_tag.to(device), 
                              batch_tail_rel_shaking_tag.to(device)
                             )

    # zero the parameter gradients
    optimizer.zero_grad()
    
    ent_shaking_outputs, \
    head_rel_shaking_outputs, \
    tail_rel_shaking_outputs = rel_extractor(batch_input_ids, 
                                              batch_attention_mask, 
                                              batch_token_type_ids, 
                                             )

    w_ent, w_rel = loss_weights["ent"], loss_weights["rel"]
    loss = w_ent * loss_func(ent_shaking_outputs, batch_ent_shaking_tag) + \
            w_rel * loss_func(head_rel_shaking_outputs, batch_head_rel_shaking_tag) + \
            w_rel * loss_func(tail_rel_shaking_outputs, batch_tail_rel_shaking_tag)
    
    # bp time: 2s
    loss.backward()
    optimizer.step()
    
    ent_sample_acc = get_sample_accuracy(ent_shaking_outputs, 
                                          batch_ent_shaking_tag)
    head_rel_sample_acc = get_sample_accuracy(head_rel_shaking_outputs, 
                                             batch_head_rel_shaking_tag)
    tail_rel_sample_acc = get_sample_accuracy(tail_rel_shaking_outputs, 
                                             batch_tail_rel_shaking_tag)
    
    return loss.item(), ent_sample_acc.item(), head_rel_sample_acc.item(), tail_rel_sample_acc.item()

# valid step
def valid_step(batch_valid_data):
    text_id_list, text_list, batch_input_ids, \
    batch_attention_mask, batch_token_type_ids, \
    offset_map_list, batch_ent_shaking_tag, \
    batch_head_rel_shaking_tag, batch_tail_rel_shaking_tag = batch_valid_data
    
    batch_input_ids, \
    batch_attention_mask, \
    batch_token_type_ids, \
    batch_ent_shaking_tag, \
    batch_head_rel_shaking_tag, \
    batch_tail_rel_shaking_tag = (batch_input_ids.to(device), 
                                  None,
                                  None,
                              batch_ent_shaking_tag.to(device), 
                              batch_head_rel_shaking_tag.to(device), 
                              batch_tail_rel_shaking_tag.to(device)
                             )
    with torch.no_grad():
        ent_shaking_outputs, \
        head_rel_shaking_outputs, \
        tail_rel_shaking_outputs = rel_extractor(batch_input_ids, 
                                                  batch_attention_mask, 
                                                  batch_token_type_ids, 
                                                 )
    
    ent_sample_acc = get_sample_accuracy(ent_shaking_outputs, 
                                          batch_ent_shaking_tag)
    head_rel_sample_acc = get_sample_accuracy(head_rel_shaking_outputs, 
                                             batch_head_rel_shaking_tag)
    tail_rel_sample_acc = get_sample_accuracy(tail_rel_shaking_outputs, 
                                             batch_tail_rel_shaking_tag)
    
    rel_cpg = get_rel_cpg(text_list, offset_map_list, 
                            ent_shaking_outputs,
                            head_rel_shaking_outputs,
                            tail_rel_shaking_outputs,
                            batch_ent_shaking_tag,
                            batch_head_rel_shaking_tag,
                            batch_tail_rel_shaking_tag)
    
    return ent_sample_acc.item(), head_rel_sample_acc.item(), tail_rel_sample_acc.item(), rel_cpg

In [44]:
max_f1 = 0.
def train_n_valid(train_dataloader, dev_dataloader, optimizer, scheduler, num_epoch):  
    def train(dataloader, ep):
        # train
        rel_extractor.train()
        
        t_ep = time.time()
        start_lr = optimizer.param_groups[0]['lr']
        total_loss, total_ent_sample_acc, total_head_rel_sample_acc, total_tail_rel_sample_acc = 0., 0., 0., 0.
        for batch_ind, batch_train_data in enumerate(dataloader):
            t_batch = time.time()
            z = (2 * len(rel2id) + 1)
            steps_per_ep = len(dataloader)
            total_steps = config.loss_weight_recover_steps
            current_step = steps_per_ep * ep + batch_ind
            w_ent = max(1 / z + 1 - current_step / total_steps, 1 / z)
            w_rel = min((len(rel2id) / z) * current_step / total_steps, (len(rel2id) / z))
            loss_weights = {"ent": w_ent, "rel": w_rel}
            loss, ent_sample_acc, head_rel_sample_acc, tail_rel_sample_acc = train_step(batch_train_data, optimizer, loss_weights)
            scheduler.step()
            
            total_loss += loss
            total_ent_sample_acc += ent_sample_acc
            total_head_rel_sample_acc += head_rel_sample_acc
            total_tail_rel_sample_acc += tail_rel_sample_acc
            
            avg_loss = total_loss / (batch_ind + 1)
            avg_ent_sample_acc = total_ent_sample_acc / (batch_ind + 1)
            avg_head_rel_sample_acc = total_head_rel_sample_acc / (batch_ind + 1)
            avg_tail_rel_sample_acc = total_tail_rel_sample_acc / (batch_ind + 1)
            
            batch_print_format = "\rEpoch: {}/{}, batch: {}/{}, train_loss: {}, " + \
                                "t_ent_sample_acc: {}, t_head_rel_sample_acc: {}, t_tail_rel_sample_acc: {}," + \
                                 "lr: {}, batch_time: {}, total_time: {} -------------"
                    
            print(batch_print_format.format(ep + 1, num_epoch, 
                                            batch_ind + 1, len(dataloader), 
                                            avg_loss, 
                                            avg_ent_sample_acc,
                                            avg_head_rel_sample_acc,
                                            avg_tail_rel_sample_acc,
                                            optimizer.param_groups[0]['lr'],
                                            time.time() - t_batch,
                                            time.time() - t_ep,
                                           ), end="")
            
            if batch_ind % config.log_interval == 0:
                wandb.log({
                    "train_loss": avg_loss,
                    "train_ent_seq_acc": avg_ent_sample_acc,
                    "train_head_rel_acc": avg_head_rel_sample_acc,
                    "train_tail_rel_acc": avg_tail_rel_sample_acc,
                    "learning_rate": optimizer.param_groups[0]['lr'],
                    "time": time.time() - t_ep,
                })
        
    def valid(dataloader, ep):
        # valid
        rel_extractor.eval()
        
        t_ep = time.time()
        total_ent_sample_acc, total_head_rel_sample_acc, total_tail_rel_sample_acc = 0., 0., 0.
        total_rel_correct_num, total_rel_pred_num, total_rel_gold_num = 0, 0, 0
        for batch_ind, batch_valid_data in enumerate(tqdm(dataloader, desc = "Validating")):
            ent_sample_acc, head_rel_sample_acc, tail_rel_sample_acc, rel_cpg = valid_step(batch_valid_data)

            total_ent_sample_acc += ent_sample_acc
            total_head_rel_sample_acc += head_rel_sample_acc
            total_tail_rel_sample_acc += tail_rel_sample_acc
            
            total_rel_correct_num += rel_cpg[0]
            total_rel_pred_num += rel_cpg[1]
            total_rel_gold_num += rel_cpg[2]

        avg_ent_sample_acc = total_ent_sample_acc / len(dataloader)
        avg_head_rel_sample_acc = total_head_rel_sample_acc / len(dataloader)
        avg_tail_rel_sample_acc = total_tail_rel_sample_acc / len(dataloader)
        
        rel_prf = get_scores(total_rel_correct_num, total_rel_pred_num, total_rel_gold_num)
        
        log_dict = {
                        "val_ent_seq_acc": avg_ent_sample_acc,
                        "val_head_rel_acc": avg_head_rel_sample_acc,
                        "val_tail_rel_acc": avg_tail_rel_sample_acc,
                        "val_prec": rel_prf[0],
                        "val_recall": rel_prf[1],
                        "val_f1": rel_prf[2],
                        "time": time.time() - t_ep,
                    }
        pprint(log_dict)
        wandb.log(log_dict)
        
        return rel_prf[2]
        
    for ep in range(num_epoch):
        train(train_dataloader, ep)   
        valid_f1 = valid(dev_dataloader, ep)
        
        global max_f1
        if valid_f1 >= max_f1: 
            max_f1 = valid_f1
            if valid_f1 > 0.5: # save the best model
                modle_state_num = len(glob.glob(model_state_dict_dir + "/model_state_dict_*.pt"))
                torch.save(rel_extractor.state_dict(), os.path.join(model_state_dict_dir, "model_state_dict_{}.pt".format(modle_state_num)))
                scheduler_state_num = len(glob.glob(schedule_state_dict_dir + "/scheduler_state_dict_*.pt"))
                torch.save(scheduler.state_dict(), os.path.join(schedule_state_dict_dir, "scheduler_state_dict_{}.pt".format(scheduler_state_num))) 
        print("Current avf_f1: {}, Best f1: {}".format(valid_f1, max_f1))

In [45]:
def get_last_state_path(state_dir, pre_fix):
    max_file_num = -1
    last_state_path = None
    for path in glob.glob(state_dir + "/{}_*.pt".format(pre_fix)):
        file_num = re.search("state_dict_(\d+)\.pt", path).group(1)
        if int(file_num) > max_file_num:
            max_file_num = int(file_num)
            last_state_path = path
    return last_state_path

def get_model_state_path(state_dict_dir, state_dict_num):
    return os.path.join(state_dict_dir, "model_state_dict_{}.pt".format(state_dict_num))

In [46]:
# dataloader
print("preparing dataloader...")
train_dataloader, \
dev_dataloader = get_train_dev_dataloader_gen(indexed_train_data, 
                                            indexed_valid_data, 
                                            config.batch_size, 
                                            )
print("dataloaders done!")

preparing dataloader...
dataloaders done!


In [47]:
# optimizer
init_learning_rate = config.lr
optimizer = torch.optim.Adam(rel_extractor.parameters(), lr = init_learning_rate)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, len(train_dataloader) * 2)

# decay_rate = 0.99
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 10, gamma = decay_rate)

In [None]:
epoch_num = config.epochs

# load the last best state (if any)
model_last_state_path = get_last_state_path(model_state_dict_dir, "model_state_dict")
if model_last_state_path is not None:
    rel_extractor.load_state_dict(torch.load(model_last_state_path))
    print("------------model state {} loaded ----------------".format(model_last_state_path.split("/")[-1]))
    
scheduler_last_state_path = get_last_state_path(schedule_state_dict_dir, "scheduler_state_dict")  
if scheduler_last_state_path is not None:
    scheduler.load_state_dict(torch.load(scheduler_last_state_path))
    print("------------scheduler state {} loaded ----------------".format(scheduler_last_state_path.split("/")[-1]))

train_n_valid(train_dataloader, dev_dataloader, optimizer, scheduler, epoch_num)

Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable


Epoch: 1/50, batch: 908/12522, train_loss: 2.927136591668723, t_ent_sample_acc: 0.0, t_head_rel_sample_acc: 0.0, t_tail_rel_sample_acc: 0.0,lr: 0.0009967672033531418, batch_time: 0.7537217140197754, total_time: 718.4217524528503 ---------------

requests_with_retry encountered retryable exception: ('Connection aborted.', OSError("(104, 'ECONNRESET')")). args: ('https://api.wandb.ai/files/wycheng/nyt_single/33ahq1zb/file_stream',), kwargs: {'json': {'files': {'output.log': {'offset': 0, 'content': ['2020-06-01T02:27:17.106680 Epoch: 1/50, batch: 606/12522, train_loss: 4.382330671613636, t_ent_sample_acc: 0.0, t_head_rel_sample_acc: 0.0, t_tail_rel_sample_acc: 0.0,lr: 0.000998560755259447, batch_time: 0.7810513973236084, total_time: 476.4979655742645 -------------\r']}, 'wandb-events.jsonl': {'offset': 15, 'content': ['{"system.gpu.0.gpu": 0.0, "system.gpu.0.memory": 0.0, "system.gpu.0.memoryAllocated": 49.98, "system.gpu.0.temp": 39.0, "system.gpu.process.0.gpu": 0.0, "system.gpu.process.0.memory": 0.0, "system.gpu.process.0.memoryAllocated": 49.98, "system.gpu.process.0.temp": 39.0, "system.gpu.0.powerWatts": 38.68, "system.gpu.0.powerPercent": 15.47, "system.gpu.process.0.powerWatts": 38.68, "system.gpu.process.0.powerPercent

Epoch: 1/50, batch: 1252/12522, train_loss: 2.124385475890507, t_ent_sample_acc: 0.0, t_head_rel_sample_acc: 0.0, t_tail_rel_sample_acc: 0.0,lr: 0.000993855938860604, batch_time: 0.7679905891418457, total_time: 991.7341613769531 ---------------

requests_with_retry encountered retryable exception: ('Connection aborted.', RemoteDisconnected('Remote end closed connection without response')). args: ('https://api.wandb.ai/files/wycheng/nyt_single/33ahq1zb/file_stream',), kwargs: {'json': {'files': {'output.log': {'offset': 1, 'content': ['2020-06-01T02:32:22.385921 Epoch: 1/50, batch: 986/12522, train_loss: 2.696053864032457, t_ent_sample_acc: 0.0, t_head_rel_sample_acc: 0.0, t_tail_rel_sample_acc: 0.0,lr: 0.0009961880059927454, batch_time: 0.6412737369537354, total_time: 781.8273639678955 -------------\r']}, 'wandb-events.jsonl': {'offset': 25, 'content': ['{"system.gpu.0.gpu": 0.0, "system.gpu.0.memory": 0.0, "system.gpu.0.memoryAllocated": 49.98, "system.gpu.0.temp": 39.0, "system.gpu.process.0.gpu": 0.0, "system.gpu.process.0.memory": 0.0, "system.gpu.process.0.memoryAllocated": 49.98, "system.gpu.process.0.temp": 39.0, "system.gpu.0.powerWatts": 38.73, "system.gpu.0.powerPercent": 15.49, "system.gpu.process.0.powerWatts": 38.

Epoch: 1/50, batch: 3378/12522, train_loss: 0.7894555139235087, t_ent_sample_acc: 4.933886075598591e-05, t_head_rel_sample_acc: 0.0, t_tail_rel_sample_acc: 0.0,lr: 0.0009558032771385767, batch_time: 0.9431967735290527, total_time: 2693.2807099819183 --------------

requests_with_retry encountered retryable exception: ('Connection aborted.', RemoteDisconnected('Remote end closed connection without response')). args: ('https://api.wandb.ai/files/wycheng/nyt_single/33ahq1zb/file_stream',), kwargs: {'json': {'files': {'output.log': {'offset': 2, 'content': ['2020-06-01T03:00:43.682316 Epoch: 1/50, batch: 3113/12522, train_loss: 0.8564584258270435, t_ent_sample_acc: 5.353892439245757e-05, t_head_rel_sample_acc: 0.0, t_tail_rel_sample_acc: 0.0,lr: 0.0009623825675411611, batch_time: 0.8159642219543457, total_time: 2483.1148495674133 -------------\r']}, 'wandb-events.jsonl': {'offset': 78, 'content': ['{"system.gpu.0.gpu": 0.0, "system.gpu.0.memory": 0.0, "system.gpu.0.memoryAllocated": 49.98, "system.gpu.0.temp": 39.0, "system.gpu.process.0.gpu": 0.0, "system.gpu.process.0.memory": 0.0, "system.gpu.process.0.memoryAllocated": 49.98, "system.gpu.process.0.temp": 39.0, "system.gpu.0.powerWatts": 38.71, "system.gpu.0.powerPercent": 15.48, "system.gpu.proce

Epoch: 1/50, batch: 10021/12522, train_loss: 0.2666869108274467, t_ent_sample_acc: 0.031450621204549424, t_head_rel_sample_acc: 0.00024947610762456903, t_tail_rel_sample_acc: 0.0002827395886411782,lr: 0.0006543653262120819, batch_time: 1.029634952545166, total_time: 7915.689510583878 ----------------

requests_with_retry encountered retryable exception: ('Connection aborted.', RemoteDisconnected('Remote end closed connection without response')). args: ('https://api.wandb.ai/files/wycheng/nyt_single/33ahq1zb/file_stream',), kwargs: {'json': {'files': {'output.log': {'offset': 3, 'content': ['2020-06-01T04:27:46.361190 Epoch: 1/50, batch: 9758/12522, train_loss: 0.2738715800962235, t_ent_sample_acc: 0.027755005190083667, t_head_rel_sample_acc: 0.00025620004862736276, t_tail_rel_sample_acc: 0.00027328005186918694,lr: 0.0006699683987361997, batch_time: 0.5889179706573486, total_time: 7705.772531986237 -------------\r']}, 'wandb-events.jsonl': {'offset': 240, 'content': ['{"system.gpu.0.gpu": 0.0, "system.gpu.0.memory": 0.0, "system.gpu.0.memoryAllocated": 49.98, "system.gpu.0.temp": 39.0, "system.gpu.process.0.gpu": 0.0, "system.gpu.process.0.memory": 0.0, "system.gpu.process.0.memoryAllocated": 49.98, "system.gpu.process.0.temp": 39.0, "system.gpu.0.powerWatts": 38.7, "system.gpu.0.pow

Epoch: 1/50, batch: 12522/12522, train_loss: 0.21344221384965473, t_ent_sample_acc: 0.06159825529382549, t_head_rel_sample_acc: 0.00067880532289728, t_tail_rel_sample_acc: 0.0007985944975262118,lr: 0.000500062721463132, batch_time: 0.6430239677429199, total_time: 9918.337591409683 -------------------

Validating: 100%|██████████| 1121/1121 [08:08<00:00,  2.29it/s]

{'time': 488.9634954929352,
 'val_ent_seq_acc': 0.2468034549417717,
 'val_f1': 0.007253115117204482,
 'val_head_rel_acc': 0.0002973535622367893,
 'val_prec': 0.0036778574122972117,
 'val_recall': 0.2599999999998267,
 'val_tail_rel_acc': 0.0002973535622367893}
Current avf_f1: 0.007253115117204482, Best f1: 0.007253115117204482





Epoch: 2/50, batch: 3180/12522, train_loss: 8.90295035845059e-05, t_ent_sample_acc: 0.2394130033563893, t_head_rel_sample_acc: 0.01058700240258151, t_tail_rel_sample_acc: 0.01006289338165859,lr: 0.0003058514091904477, batch_time: 0.8183455467224121, total_time: 2544.6383781433105 --------------------

requests_with_retry encountered retryable exception: ('Connection aborted.', OSError("(104, 'ECONNRESET')")). args: ('https://api.wandb.ai/files/wycheng/nyt_single/33ahq1zb/file_stream',), kwargs: {'json': {'files': {'output.log': {'offset': 13, 'content': ['2020-06-01T05:41:08.872935 Epoch: 2/50, batch: 2127/12522, train_loss: 9.11448772579304e-05, t_ent_sample_acc: 0.22598339337724221, t_head_rel_sample_acc: 0.0076006897666054144, t_tail_rel_sample_acc: 0.007287259267482838,lr: 0.0003682292485679278, batch_time: 0.684828519821167, total_time: 1700.826170682907 -------------\r']}, 'wandb-events.jsonl': {'offset': 377, 'content': ['{"system.gpu.0.gpu": 0.0, "system.gpu.0.memory": 0.0, "system.gpu.0.memoryAllocated": 49.98, "system.gpu.0.temp": 39.0, "system.gpu.process.0.gpu": 0.0, "system.gpu.process.0.memory": 0.0, "system.gpu.process.0.memoryAllocated": 49.98, "system.gpu.process.0.temp": 39.0, "system.gpu.0.powerWatts": 38.7, "system.gpu.0.powerPercent": 15.48, "system.gpu.process.

Epoch: 2/50, batch: 3751/12522, train_loss: 8.931135063842537e-05, t_ent_sample_acc: 0.23935839873461748, t_head_rel_sample_acc: 0.011152581858240869, t_tail_rel_sample_acc: 0.01048609291003946,lr: 0.0002733736704236526, batch_time: 0.8042082786560059, total_time: 2998.8418617248535 ---------------

requests_with_retry encountered retryable exception: ('Connection aborted.', RemoteDisconnected('Remote end closed connection without response')). args: ('https://api.wandb.ai/files/wycheng/nyt_single/33ahq1zb/file_stream',), kwargs: {'json': {'files': {'output.log': {'offset': 14, 'content': ['2020-06-01T05:59:17.091024 Epoch: 2/50, batch: 3483/12522, train_loss: 8.946477320597094e-05, t_ent_sample_acc: 0.24045363736844014, t_head_rel_sample_acc: 0.011245095550996833, t_tail_rel_sample_acc: 0.010575174973026761,lr: 0.0002884824600038317, batch_time: 0.8718113899230957, total_time: 2788.7366518974304 -------------\r']}, 'wandb-events.jsonl': {'offset': 410, 'content': ['{"system.gpu.0.gpu": 0.0, "system.gpu.0.memory": 0.0, "system.gpu.0.memoryAllocated": 49.98, "system.gpu.0.temp": 39.0, "system.gpu.process.0.gpu": 0.0, "system.gpu.process.0.memory": 0.0, "system.gpu.process.0.memoryAllocated": 49.98, "system.gpu.process.0.temp": 39.0, "system.gpu.0.powerWatts": 38.62, "system.gpu.0.po

Epoch: 2/50, batch: 8495/12522, train_loss: 8.0314490642591e-05, t_ent_sample_acc: 0.2996076186679264, t_head_rel_sample_acc: 0.02878163709157772, t_tail_rel_sample_acc: 0.02878163709333183,lr: 6.248150730496549e-05, batch_time: 0.6181893348693848, total_time: 6790.400936365128 --------------------

requests_with_retry encountered retryable exception: ('Connection aborted.', OSError("(110, 'ETIMEDOUT')")). args: ('https://api.wandb.ai/files/wycheng/nyt_single/33ahq1zb/file_stream',), kwargs: {'json': {'files': {'output.log': {'offset': 15, 'content': ['2020-06-01T06:49:54.561027 Epoch: 2/50, batch: 7264/12522, train_loss: 8.280002716128165e-05, t_ent_sample_acc: 0.28434747305292524, t_head_rel_sample_acc: 0.022416483766772674, t_tail_rel_sample_acc: 0.022462372169604256,lr: 0.00010491326780865851, batch_time: 0.804511308670044, total_time: 5826.207480669022 -------------\r']}, 'wandb-events.jsonl': {'offset': 505, 'content': ['{"system.gpu.0.gpu": 0.0, "system.gpu.0.memory": 0.0, "system.gpu.0.memoryAllocated": 49.98, "system.gpu.0.temp": 39.0, "system.gpu.process.0.gpu": 0.0, "system.gpu.process.0.memory": 0.0, "system.gpu.process.0.memoryAllocated": 49.98, "system.gpu.process.0.temp": 39.0, "system.gpu.0.powerWatts": 38.66, "system.gpu.0.powerPercent": 15.47, "system.gpu.process

Epoch: 2/50, batch: 12522/12522, train_loss: 7.478649492281683e-05, t_ent_sample_acc: 0.3431427426799249, t_head_rel_sample_acc: 0.047303413991596044, t_tail_rel_sample_acc: 0.047170314917005,lr: 3.933981973514023e-12, batch_time: 0.6021895408630371, total_time: 10021.204030752182 -----------------

Validating: 100%|██████████| 1121/1121 [08:32<00:00,  2.19it/s]

{'time': 512.8701512813568,
 'val_ent_seq_acc': 0.45376153171966377,
 'val_f1': 0.13568661212357924,
 'val_head_rel_acc': 0.05835563650256709,
 'val_prec': 0.07431158053564622,
 'val_recall': 0.7794263105835035,
 'val_tail_rel_acc': 0.05850431324380727}
Current avf_f1: 0.13568661212357924, Best f1: 0.13568661212357924





Epoch: 3/50, batch: 398/12522, train_loss: 0.00042603906673211345, t_ent_sample_acc: 0.02010050311160447, t_head_rel_sample_acc: 0.0004187604814917598, t_tail_rel_sample_acc: 0.0004187604814917598,lr: 0.000999380097173113, batch_time: 0.8122684955596924, total_time: 319.0021708011627 ------------------

requests_with_retry encountered retryable exception: ('Connection aborted.', RemoteDisconnected('Remote end closed connection without response')). args: ('https://api.wandb.ai/files/wycheng/nyt_single/33ahq1zb/file_stream',), kwargs: {'json': {'files': {'output.log': {'offset': 25, 'content': ['2020-06-01T08:10:11.446885 Epoch: 3/50, batch: 136/12522, train_loss: 0.0004152763802481144, t_ent_sample_acc: 0.0330882362802239, t_head_rel_sample_acc: 0.0012254902326008853, t_tail_rel_sample_acc: 0.0012254902326008853,lr: 0.0009999283048922763, batch_time: 0.996302604675293, total_time: 108.78156280517578 -------------\r']}, 'wandb-events.jsonl': {'offset': 655, 'content': ['{"system.gpu.0.gpu": 0.0, "system.gpu.0.memory": 0.0, "system.gpu.0.memoryAllocated": 49.98, "system.gpu.0.temp": 39.0, "system.gpu.process.0.gpu": 0.0, "system.gpu.process.0.memory": 0.0, "system.gpu.process.0.memoryAllocated": 49.98, "system.gpu.process.0.temp": 39.0, "system.gpu.0.powerWatts": 38.69, "system.gpu.0.pow

Epoch: 3/50, batch: 2673/12522, train_loss: 0.00013524545477271174, t_ent_sample_acc: 0.2574510596193627, t_head_rel_sample_acc: 0.055742612808909346, t_tail_rel_sample_acc: 0.05599202049545612,lr: 0.0009721749823187451, batch_time: 0.7539916038513184, total_time: 2140.920558452606 --------------------

requests_with_retry encountered retryable exception: 500 Server Error: Internal Server Error for url: https://api.wandb.ai/files/wycheng/nyt_single/33ahq1zb/file_stream. args: ('https://api.wandb.ai/files/wycheng/nyt_single/33ahq1zb/file_stream',), kwargs: {'json': {'files': {'output.log': {'offset': 26, 'content': ['2020-06-01T08:43:33.569801 Epoch: 3/50, batch: 2635/12522, train_loss: 0.000136390538821907, t_ent_sample_acc: 0.2531309355368424, t_head_rel_sample_acc: 0.05433270229447273, t_tail_rel_sample_acc: 0.054206200105189374,lr: 0.0009729536214200448, batch_time: 0.7968289852142334, total_time: 2110.9469542503357 -------------\r']}, 'wandb-events.jsonl': {'offset': 717, 'content': ['{"system.gpu.0.gpu": 0.0, "system.gpu.0.memory": 0.0, "system.gpu.0.memoryAllocated": 49.98, "system.gpu.0.temp": 39.0, "system.gpu.process.0.gpu": 0.0, "system.gpu.process.0.memory": 0.0, "system.gpu.process.0.memoryAllocated": 49.98, "system.gpu.process.0.temp": 39.0, "system.gpu.0.powerWatts": 38.

Epoch: 3/50, batch: 2746/12522, train_loss: 0.0001333079223215029, t_ent_sample_acc: 0.26323137282022835, t_head_rel_sample_acc: 0.05784171070830282, t_tail_rel_sample_acc: 0.058266571145681074,lr: 0.000970649088031254, batch_time: 0.7930846214294434, total_time: 2198.3573830127716 ----------------

requests_with_retry encountered retryable exception: 500 Server Error: Internal Server Error for url: https://api.wandb.ai/files/wycheng/nyt_single/33ahq1zb/file_stream. args: ('https://api.wandb.ai/files/wycheng/nyt_single/33ahq1zb/file_stream',), kwargs: {'json': {'files': {'output.log': {'offset': 26, 'content': ['2020-06-01T08:43:33.569801 Epoch: 3/50, batch: 2635/12522, train_loss: 0.000136390538821907, t_ent_sample_acc: 0.2531309355368424, t_head_rel_sample_acc: 0.05433270229447273, t_tail_rel_sample_acc: 0.054206200105189374,lr: 0.0009729536214200448, batch_time: 0.7968289852142334, total_time: 2110.9469542503357 -------------\r']}, 'wandb-events.jsonl': {'offset': 717, 'content': ['{"system.gpu.0.gpu": 0.0, "system.gpu.0.memory": 0.0, "system.gpu.0.memoryAllocated": 49.98, "system.gpu.0.temp": 39.0, "system.gpu.process.0.gpu": 0.0, "system.gpu.process.0.memory": 0.0, "system.gpu.process.0.memoryAllocated": 49.98, "system.gpu.process.0.temp": 39.0, "system.gpu.0.powerWatts": 38.

Epoch: 3/50, batch: 2752/12522, train_loss: 0.00013319963816784788, t_ent_sample_acc: 0.26380814552350446, t_head_rel_sample_acc: 0.058018412413941914, t_tail_rel_sample_acc: 0.05838178453945317,lr: 0.0009705219150645193, batch_time: 0.7654774188995361, total_time: 2203.0928170681 ---------------

requests_with_retry encountered retryable exception: 500 Server Error: Internal Server Error for url: https://api.wandb.ai/files/wycheng/nyt_single/33ahq1zb/file_stream. args: ('https://api.wandb.ai/files/wycheng/nyt_single/33ahq1zb/file_stream',), kwargs: {'json': {'files': {'output.log': {'offset': 26, 'content': ['2020-06-01T08:43:33.569801 Epoch: 3/50, batch: 2635/12522, train_loss: 0.000136390538821907, t_ent_sample_acc: 0.2531309355368424, t_head_rel_sample_acc: 0.05433270229447273, t_tail_rel_sample_acc: 0.054206200105189374,lr: 0.0009729536214200448, batch_time: 0.7968289852142334, total_time: 2110.9469542503357 -------------\r']}, 'wandb-events.jsonl': {'offset': 717, 'content': ['{"system.gpu.0.gpu": 0.0, "system.gpu.0.memory": 0.0, "system.gpu.0.memoryAllocated": 49.98, "system.gpu.0.temp": 39.0, "system.gpu.process.0.gpu": 0.0, "system.gpu.process.0.memory": 0.0, "system.gpu.process.0.memoryAllocated": 49.98, "system.gpu.process.0.temp": 39.0, "system.gpu.0.powerWatts": 38.

Epoch: 3/50, batch: 2763/12522, train_loss: 0.00013289412510967744, t_ent_sample_acc: 0.26468814693194476, t_head_rel_sample_acc: 0.05857160249526113, t_tail_rel_sample_acc: 0.05881288615293016,lr: 0.0009702880723943708, batch_time: 0.6937253475189209, total_time: 2211.6190333366394 --------------

requests_with_retry encountered retryable exception: 500 Server Error: Internal Server Error for url: https://api.wandb.ai/files/wycheng/nyt_single/33ahq1zb/file_stream. args: ('https://api.wandb.ai/files/wycheng/nyt_single/33ahq1zb/file_stream',), kwargs: {'json': {'files': {'output.log': {'offset': 26, 'content': ['2020-06-01T08:43:33.569801 Epoch: 3/50, batch: 2635/12522, train_loss: 0.000136390538821907, t_ent_sample_acc: 0.2531309355368424, t_head_rel_sample_acc: 0.05433270229447273, t_tail_rel_sample_acc: 0.054206200105189374,lr: 0.0009729536214200448, batch_time: 0.7968289852142334, total_time: 2110.9469542503357 -------------\r']}, 'wandb-events.jsonl': {'offset': 717, 'content': ['{"system.gpu.0.gpu": 0.0, "system.gpu.0.memory": 0.0, "system.gpu.0.memoryAllocated": 49.98, "system.gpu.0.temp": 39.0, "system.gpu.process.0.gpu": 0.0, "system.gpu.process.0.memory": 0.0, "system.gpu.process.0.memoryAllocated": 49.98, "system.gpu.process.0.temp": 39.0, "system.gpu.0.powerWatts": 38.

Epoch: 3/50, batch: 2787/12522, train_loss: 0.0001322591519213298, t_ent_sample_acc: 0.26617630076827564, t_head_rel_sample_acc: 0.05890443885801169, t_tail_rel_sample_acc: 0.05926324763967865,lr: 0.0009697747626406643, batch_time: 0.7905662059783936, total_time: 2230.4373519420624 ---------------

requests_with_retry encountered retryable exception: 500 Server Error: Internal Server Error for url: https://api.wandb.ai/files/wycheng/nyt_single/33ahq1zb/file_stream. args: ('https://api.wandb.ai/files/wycheng/nyt_single/33ahq1zb/file_stream',), kwargs: {'json': {'files': {'output.log': {'offset': 26, 'content': ['2020-06-01T08:43:33.569801 Epoch: 3/50, batch: 2635/12522, train_loss: 0.000136390538821907, t_ent_sample_acc: 0.2531309355368424, t_head_rel_sample_acc: 0.05433270229447273, t_tail_rel_sample_acc: 0.054206200105189374,lr: 0.0009729536214200448, batch_time: 0.7968289852142334, total_time: 2110.9469542503357 -------------\r']}, 'wandb-events.jsonl': {'offset': 717, 'content': ['{"system.gpu.0.gpu": 0.0, "system.gpu.0.memory": 0.0, "system.gpu.0.memoryAllocated": 49.98, "system.gpu.0.temp": 39.0, "system.gpu.process.0.gpu": 0.0, "system.gpu.process.0.memory": 0.0, "system.gpu.process.0.memoryAllocated": 49.98, "system.gpu.process.0.temp": 39.0, "system.gpu.0.powerWatts": 38.

Epoch: 3/50, batch: 2829/12522, train_loss: 0.0001311384810970858, t_ent_sample_acc: 0.269824443514712, t_head_rel_sample_acc: 0.06091669775458687, t_tail_rel_sample_acc: 0.061211265947507774,lr: 0.0009688662276143198, batch_time: 0.9676027297973633, total_time: 2266.4697387218475 ----------------

requests_with_retry encountered retryable exception: 500 Server Error: Internal Server Error for url: https://api.wandb.ai/files/wycheng/nyt_single/33ahq1zb/file_stream. args: ('https://api.wandb.ai/files/wycheng/nyt_single/33ahq1zb/file_stream',), kwargs: {'json': {'files': {'output.log': {'offset': 26, 'content': ['2020-06-01T08:43:33.569801 Epoch: 3/50, batch: 2635/12522, train_loss: 0.000136390538821907, t_ent_sample_acc: 0.2531309355368424, t_head_rel_sample_acc: 0.05433270229447273, t_tail_rel_sample_acc: 0.054206200105189374,lr: 0.0009729536214200448, batch_time: 0.7968289852142334, total_time: 2110.9469542503357 -------------\r']}, 'wandb-events.jsonl': {'offset': 717, 'content': ['{"system.gpu.0.gpu": 0.0, "system.gpu.0.memory": 0.0, "system.gpu.0.memoryAllocated": 49.98, "system.gpu.0.temp": 39.0, "system.gpu.process.0.gpu": 0.0, "system.gpu.process.0.memory": 0.0, "system.gpu.process.0.memoryAllocated": 49.98, "system.gpu.process.0.temp": 39.0, "system.gpu.0.powerWatts": 38.

Epoch: 3/50, batch: 3461/12522, train_loss: 0.00011717135685489423, t_ent_sample_acc: 0.3186458707613533, t_head_rel_sample_acc: 0.08032360793933879, t_tail_rel_sample_acc: 0.08070885306732808,lr: 0.0009536386607099499, batch_time: 0.7860622406005859, total_time: 2771.5501346588135 ----------------

requests_with_retry encountered retryable exception: 408 Client Error: Request Timeout for url: https://api.wandb.ai/files/wycheng/nyt_single/33ahq1zb/file_stream. args: ('https://api.wandb.ai/files/wycheng/nyt_single/33ahq1zb/file_stream',), kwargs: {'json': {'files': {'output.log': {'offset': 32, 'content': ['2020-06-01T08:53:34.065444 Epoch: 3/50, batch: 3384/12522, train_loss: 0.00011878353513972624, t_ent_sample_acc: 0.3135342861461865, t_head_rel_sample_acc: 0.07821119192714669, t_tail_rel_sample_acc: 0.07894996262153271,lr: 0.0009556484528906814, batch_time: 0.795285701751709, total_time: 2711.379138469696 -------------\r']}, 'wandb-events.jsonl': {'offset': 735, 'content': ['{"system.gpu.0.gpu": 0.0, "system.gpu.0.memory": 0.0, "system.gpu.0.memoryAllocated": 49.98, "system.gpu.0.temp": 39.0, "system.gpu.process.0.gpu": 0.0, "system.gpu.process.0.memory": 0.0, "system.gpu.process.0.memoryAllocated": 49.98, "system.gpu.process.0.temp": 39.0, "system.gpu.0.powerWatts": 38.72, "sy

Epoch: 3/50, batch: 4118/12522, train_loss: 0.00010653125025101956, t_ent_sample_acc: 0.35502672028619725, t_head_rel_sample_acc: 0.09867249718347529, t_tail_rel_sample_acc: 0.10053424239375743,lr: 0.0009347891793014522, batch_time: 0.8084201812744141, total_time: 3319.139408826828 --------------

requests_with_retry encountered retryable exception: ('Connection aborted.', RemoteDisconnected('Remote end closed connection without response')). args: ('https://api.wandb.ai/files/wycheng/nyt_single/33ahq1zb/file_stream',), kwargs: {'json': {'files': {'output.log': {'offset': 33, 'content': ['2020-06-01T09:00:12.045701 Epoch: 3/50, batch: 3866/12522, train_loss: 0.00011006269985144283, t_ent_sample_acc: 0.3423866263546909, t_head_rel_sample_acc: 0.09247284246135486, t_tail_rel_sample_acc: 0.0935506144807367,lr: 0.0009423754853335001, batch_time: 0.7707746028900146, total_time: 3109.229909658432 -------------\r']}, 'wandb-events.jsonl': {'offset': 748, 'content': ['{"system.gpu.0.gpu": 0.0, "system.gpu.0.memory": 0.0, "system.gpu.0.memoryAllocated": 49.98, "system.gpu.0.temp": 39.0, "system.gpu.process.0.gpu": 0.0, "system.gpu.process.0.memory": 0.0, "system.gpu.process.0.memoryAllocated": 49.98, "system.gpu.process.0.temp": 39.0, "system.gpu.0.powerWatts": 38.51, "system.gpu.0.powerP

Epoch: 3/50, batch: 9392/12522, train_loss: 6.883118530607706e-05, t_ent_sample_acc: 0.486264918237493, t_head_rel_sample_acc: 0.24872232233753247, t_tail_rel_sample_acc: 0.250763066189151,lr: 0.0006913706893442529, batch_time: 0.970381498336792, total_time: 7324.600524902344 --------------------

requests_with_retry encountered retryable exception: ('Connection aborted.', RemoteDisconnected('Remote end closed connection without response')). args: ('https://api.wandb.ai/files/wycheng/nyt_single/33ahq1zb/file_stream',), kwargs: {'json': {'files': {'output.log': {'offset': 34, 'content': ['2020-06-01T10:06:57.247203 Epoch: 3/50, batch: 9120/12522, train_loss: 6.987054800917074e-05, t_ent_sample_acc: 0.48278509953892546, t_head_rel_sample_acc: 0.24283626264843503, t_tail_rel_sample_acc: 0.2449561456463447,lr: 0.0007070174397815256, batch_time: 0.7972805500030518, total_time: 7114.6274082660675 -------------\r']}, 'wandb-history.jsonl': {'offset': 3417, 'content': ['{"train_loss": 6.99830355455485e-05, "train_ent_seq_acc": 0.4824735204069518, "train_head_rel_acc": 0.2422175831496945, "train_tail_rel_acc": 0.24430756220892572, "learning_rate": 0.0007086717604910251, "time": 7091.848155260086, "_runtime": 28534.10676431656, "_timestamp": 1591005993.7554958, "_step": 3417}\n', '{"train

Epoch: 3/50, batch: 12522/12522, train_loss: 5.9389115647569825e-05, t_ent_sample_acc: 0.5208566386780406, t_head_rel_sample_acc: 0.31077304633232833, t_tail_rel_sample_acc: 0.31258319383625444,lr: 0.000500062721463132, batch_time: 0.5248210430145264, total_time: 9863.859137296677 ---------------

Validating: 100%|██████████| 1121/1121 [08:47<00:00,  2.13it/s]

{'time': 527.2157592773438,
 'val_ent_seq_acc': 0.6259292456637952,
 'val_f1': 0.7601860276522033,
 'val_head_rel_acc': 0.5443056925107708,
 'val_prec': 0.7090720482836599,
 'val_recall': 0.819241664850721,
 'val_tail_rel_acc': 0.5420012026227659}
Current avf_f1: 0.7601860276522033, Best f1: 0.7601860276522033





Epoch: 4/50, batch: 2518/12522, train_loss: 2.4904311343373314e-05, t_ent_sample_acc: 0.6744771166700897, t_head_rel_sample_acc: 0.5727429317933307, t_tail_rel_sample_acc: 0.5723457911404662,lr: 0.0003447400852543908, batch_time: 0.803868293762207, total_time: 2064.985038995743 ----------------

requests_with_retry encountered retryable exception: ('Connection aborted.', OSError("(104, 'ECONNRESET')")). args: ('https://api.wandb.ai/files/wycheng/nyt_single/33ahq1zb/file_stream',), kwargs: {'json': {'files': {'output.log': {'offset': 44, 'content': ['2020-06-01T11:31:57.482768 Epoch: 4/50, batch: 2225/12522, train_loss: 2.4893407472222095e-05, t_ent_sample_acc: 0.6738576963644349, t_head_rel_sample_acc: 0.5725093774916081, t_tail_rel_sample_acc: 0.5722097520747881,lr: 0.0003623099494040241, batch_time: 0.7974627017974854, total_time: 1823.2837853431702 -------------\r']}, 'wandb-events.jsonl': {'offset': 1031, 'content': ['{"system.gpu.0.gpu": 0.0, "system.gpu.0.memory": 0.0, "system.gpu.0.memoryAllocated": 31.67, "system.gpu.0.temp": 39.0, "system.gpu.process.0.gpu": 0.0, "system.gpu.process.0.memory": 0.0, "system.gpu.process.0.memoryAllocated": 31.67, "system.gpu.process.0.temp": 39.0, "system.gpu.0.powerWatts": 38.66, "system.gpu.0.powerPercent": 15.47, "system.gpu.process.

Epoch: 4/50, batch: 4945/12522, train_loss: 2.4283667427234087e-05, t_ent_sample_acc: 0.6795079389133154, t_head_rel_sample_acc: 0.5828109345855559, t_tail_rel_sample_acc: 0.5826761182386304,lr: 0.00020940517606903136, batch_time: 0.7115352153778076, total_time: 4008.5340819358826 -------------

requests_with_retry encountered retryable exception: ('Connection aborted.', RemoteDisconnected('Remote end closed connection without response')). args: ('https://api.wandb.ai/files/wycheng/nyt_single/33ahq1zb/file_stream',), kwargs: {'json': {'files': {'output.log': {'offset': 45, 'content': ['2020-06-01T12:04:52.245022 Epoch: 4/50, batch: 4685/12522, train_loss: 2.4337786228377773e-05, t_ent_sample_acc: 0.6795090898855519, t_head_rel_sample_acc: 0.5829598152141306, t_tail_rel_sample_acc: 0.5827819425338361,lr: 0.0002228279591596653, batch_time: 0.8143208026885986, total_time: 3798.1882536411285 -------------\r']}, 'wandb-events.jsonl': {'offset': 1093, 'content': ['{"system.gpu.0.gpu": 0.0, "system.gpu.0.memory": 0.0, "system.gpu.0.memoryAllocated": 31.67, "system.gpu.0.temp": 39.0, "system.gpu.process.0.gpu": 0.0, "system.gpu.process.0.memory": 0.0, "system.gpu.process.0.memoryAllocated": 31.67, "system.gpu.process.0.temp": 39.0, "system.gpu.0.powerWatts": 38.62, "system.gpu.0.power

Epoch: 4/50, batch: 6365/12522, train_loss: 2.390560076504111e-05, t_ent_sample_acc: 0.6814349489870356, t_head_rel_sample_acc: 0.5901021357300894, t_tail_rel_sample_acc: 0.5900235811299633,lr: 0.00014190811811531362, batch_time: 0.6270740032196045, total_time: 5102.4829177856445 --------------

requests_with_retry encountered retryable exception: ('Connection aborted.', RemoteDisconnected('Remote end closed connection without response')). args: ('https://api.wandb.ai/files/wycheng/nyt_single/33ahq1zb/file_stream',), kwargs: {'json': {'files': {'output.log': {'offset': 46, 'content': ['2020-06-01T12:23:06.429319 Epoch: 4/50, batch: 6099/12522, train_loss: 2.396558646875849e-05, t_ent_sample_acc: 0.6813412219393467, t_head_rel_sample_acc: 0.5886757539166527, t_tail_rel_sample_acc: 0.5885391194720916,lr: 0.00015374918713494852, batch_time: 0.760833740234375, total_time: 4892.25328373909 -------------\r']}, 'wandb-events.jsonl': {'offset': 1127, 'content': ['{"system.gpu.0.gpu": 0.0, "system.gpu.0.memory": 0.0, "system.gpu.0.memoryAllocated": 31.67, "system.gpu.0.temp": 39.0, "system.gpu.process.0.gpu": 0.0, "system.gpu.process.0.memory": 0.0, "system.gpu.process.0.memoryAllocated": 31.67, "system.gpu.process.0.temp": 39.0, "system.gpu.0.powerWatts": 38.7, "system.gpu.0.powerPerc

Epoch: 4/50, batch: 9440/12522, train_loss: 2.318571841748478e-05, t_ent_sample_acc: 0.6879767134627801, t_head_rel_sample_acc: 0.6033545350782195, t_tail_rel_sample_acc: 0.6029308062757097,lr: 3.6928323783813945e-05, batch_time: 0.7831687927246094, total_time: 7586.55038189888 ----------------

requests_with_retry encountered retryable exception: ('Connection aborted.', OSError("(104, 'ECONNRESET')")). args: ('https://api.wandb.ai/files/wycheng/nyt_single/33ahq1zb/file_stream',), kwargs: {'json': {'files': {'output.log': {'offset': 47, 'content': ['2020-06-01T12:53:57.467135 Epoch: 4/50, batch: 8387/12522, train_loss: 2.3463719413202905e-05, t_ent_sample_acc: 0.6857239562901266, t_head_rel_sample_acc: 0.599538984188693, t_tail_rel_sample_acc: 0.5991415437273179,lr: 6.580051441344598e-05, batch_time: 0.7865791320800781, total_time: 6743.249205350876 -------------\r']}, 'wandb-events.jsonl': {'offset': 1185, 'content': ['{"system.gpu.0.gpu": 0.0, "system.gpu.0.memory": 0.0, "system.gpu.0.memoryAllocated": 31.67, "system.gpu.0.temp": 39.0, "system.gpu.process.0.gpu": 0.0, "system.gpu.process.0.memory": 0.0, "system.gpu.process.0.memoryAllocated": 31.67, "system.gpu.process.0.temp": 39.0, "system.gpu.0.powerWatts": 38.59, "system.gpu.0.powerPercent": 15.43, "system.gpu.process.0.

Epoch: 4/50, batch: 11793/12522, train_loss: 2.281792128197077e-05, t_ent_sample_acc: 0.6909324928068138, t_head_rel_sample_acc: 0.6101218393290608, t_tail_rel_sample_acc: 0.6094858688507272,lr: 2.094954404022464e-06, batch_time: 0.7645599842071533, total_time: 9469.119141578674 ----------------

requests_with_retry encountered retryable exception: ('Connection aborted.', RemoteDisconnected('Remote end closed connection without response')). args: ('https://api.wandb.ai/files/wycheng/nyt_single/33ahq1zb/file_stream',), kwargs: {'json': {'files': {'output.log': {'offset': 48, 'content': ['2020-06-01T13:35:53.193328 Epoch: 4/50, batch: 11524/12522, train_loss: 2.2874389080019242e-05, t_ent_sample_acc: 0.6905010021804046, t_head_rel_sample_acc: 0.6092213507714536, t_tail_rel_sample_acc: 0.608570535088922,lr: 3.920982483080149e-06, batch_time: 0.8224115371704102, total_time: 9258.855890512466 -------------\r']}, 'wandb-events.jsonl': {'offset': 1263, 'content': ['{"system.gpu.0.gpu": 0.0, "system.gpu.0.memory": 0.0, "system.gpu.0.memoryAllocated": 31.67, "system.gpu.0.temp": 39.0, "system.gpu.process.0.gpu": 0.0, "system.gpu.process.0.memory": 0.0, "system.gpu.process.0.memoryAllocated": 31.67, "system.gpu.process.0.temp": 39.0, "system.gpu.0.powerWatts": 38.62, "system.gpu.0.powerP

Epoch: 4/50, batch: 12238/12522, train_loss: 2.2783883838798392e-05, t_ent_sample_acc: 0.691180494827789, t_head_rel_sample_acc: 0.6108296719407271, t_tail_rel_sample_acc: 0.6101078764590432,lr: 3.1950365121463654e-07, batch_time: 0.8205771446228027, total_time: 9833.19199848175 ---------------

requests_with_retry encountered retryable exception: ('Connection aborted.', RemoteDisconnected('Remote end closed connection without response')). args: ('https://api.wandb.ai/files/wycheng/nyt_single/33ahq1zb/file_stream',), kwargs: {'json': {'files': {'output.log': {'offset': 49, 'content': ['2020-06-01T13:41:57.479720 Epoch: 4/50, batch: 11983/12522, train_loss: 2.2808574203597033e-05, t_ent_sample_acc: 0.6908398195485643, t_head_rel_sample_acc: 0.6103090645760474, t_tail_rel_sample_acc: 0.6096553607368712,lr: 1.1467105556698942e-06, batch_time: 0.922905683517456, total_time: 9623.18696641922 -------------\r']}, 'wandb-events.jsonl': {'offset': 1275, 'content': ['{"system.gpu.0.gpu": 0.0, "system.gpu.0.memory": 0.0, "system.gpu.0.memoryAllocated": 31.67, "system.gpu.0.temp": 39.0, "system.gpu.process.0.gpu": 0.0, "system.gpu.process.0.memory": 0.0, "system.gpu.process.0.memoryAllocated": 31.67, "system.gpu.process.0.temp": 39.0, "system.gpu.0.powerWatts": 38.62, "system.gpu.0.powerP

Epoch: 4/50, batch: 12522/12522, train_loss: 2.2762913136907708e-05, t_ent_sample_acc: 0.6911968456346508, t_head_rel_sample_acc: 0.6112176060707515, t_tail_rel_sample_acc: 0.6105521106870863,lr: 3.933981973514023e-12, batch_time: 0.4957880973815918, total_time: 10057.934579849243 --------------

Validating: 100%|██████████| 1121/1121 [09:05<00:00,  2.05it/s]

{'time': 545.7005305290222,
 'val_ent_seq_acc': 0.6328427165374492,
 'val_f1': 0.80289629696627,
 'val_head_rel_acc': 0.6071216333109177,
 'val_prec': 0.7790456431535197,
 'val_recall': 0.8282534589933744,
 'val_tail_rel_acc': 0.6071216333242105}
Current avf_f1: 0.80289629696627, Best f1: 0.80289629696627





Epoch: 5/50, batch: 5456/12522, train_loss: 2.6687543656576243e-05, t_ent_sample_acc: 0.6718902917485957, t_head_rel_sample_acc: 0.5700757713972037, t_tail_rel_sample_acc: 0.5693120861868443,lr: 0.0008874336567760226, batch_time: 0.6323835849761963, total_time: 4239.50582075119 ---------------

requests_with_retry encountered retryable exception: 408 Client Error: Request Timeout for url: https://api.wandb.ai/files/wycheng/nyt_single/33ahq1zb/file_stream. args: ('https://api.wandb.ai/files/wycheng/nyt_single/33ahq1zb/file_stream',), kwargs: {'json': {'files': {'output.log': {'offset': 59, 'content': ['2020-06-01T15:07:57.327771 Epoch: 5/50, batch: 5377/12522, train_loss: 2.6727320810695185e-05, t_ent_sample_acc: 0.67224600407045, t_head_rel_sample_acc: 0.5697725015805516, t_tail_rel_sample_acc: 0.5690905848095048,lr: 0.0008905467523641983, batch_time: 0.6288461685180664, total_time: 4179.140299081802 -------------\r']}, 'wandb-events.jsonl': {'offset': 1435, 'content': ['{"system.gpu.0.gpu": 0.0, "system.gpu.0.memory": 0.0, "system.gpu.0.memoryAllocated": 31.67, "system.gpu.0.temp": 39.0, "system.gpu.process.0.gpu": 0.0, "system.gpu.process.0.memory": 0.0, "system.gpu.process.0.memoryAllocated": 31.67, "system.gpu.process.0.temp": 39.0, "system.gpu.0.powerWatts": 38.62, "syst

Epoch: 5/50, batch: 7909/12522, train_loss: 2.5881381424818405e-05, t_ent_sample_acc: 0.6771399854166738, t_head_rel_sample_acc: 0.5814262375316572, t_tail_rel_sample_acc: 0.5809204846428384,lr: 0.0007735072413438614, batch_time: 0.7040383815765381, total_time: 6084.407439470291 --------------

requests_with_retry encountered retryable exception: ('Connection aborted.', OSError("(110, 'ETIMEDOUT')")). args: ('https://api.wandb.ai/files/wycheng/nyt_single/33ahq1zb/file_stream',), kwargs: {'json': {'files': {'output.log': {'offset': 60, 'content': ['2020-06-01T15:23:38.573026 Epoch: 5/50, batch: 6620/12522, train_loss: 2.6276710301912388e-05, t_ent_sample_acc: 0.6739929688030859, t_head_rel_sample_acc: 0.5747230753550177, t_tail_rel_sample_acc: 0.5741691982390478,lr: 0.0008373246909065484, batch_time: 0.6973292827606201, total_time: 5120.379776239395 -------------\r']}, 'wandb-events.jsonl': {'offset': 1465, 'content': ['{"system.gpu.0.gpu": 0.0, "system.gpu.0.memory": 0.0, "system.gpu.0.memoryAllocated": 40.43, "system.gpu.0.temp": 39.0, "system.gpu.process.0.gpu": 0.0, "system.gpu.process.0.memory": 0.0, "system.gpu.process.0.memoryAllocated": 40.43, "system.gpu.process.0.temp": 39.0, "system.gpu.0.powerWatts": 38.69, "system.gpu.0.powerPercent": 15.48, "system.gpu.process.0.

Epoch: 5/50, batch: 9919/12522, train_loss: 2.5054834242704273e-05, t_ent_sample_acc: 0.6821084299302892, t_head_rel_sample_acc: 0.591944767162016, t_tail_rel_sample_acc: 0.5915583035123602,lr: 0.0006604375881311916, batch_time: 0.7412741184234619, total_time: 7627.5588574409485 --------------

requests_with_retry encountered retryable exception: ('Connection aborted.', OSError("(110, 'ETIMEDOUT')")). args: ('https://api.wandb.ai/files/wycheng/nyt_single/33ahq1zb/file_stream',), kwargs: {'json': {'files': {'output.log': {'offset': 61, 'content': ['2020-06-01T15:49:22.072106 Epoch: 5/50, batch: 8658/12522, train_loss: 2.55763157609319e-05, t_ent_sample_acc: 0.6786594470314463, t_head_rel_sample_acc: 0.5854315998789537, t_tail_rel_sample_acc: 0.5850850995909702,lr: 0.0007330320363725775, batch_time: 0.6473555564880371, total_time: 6663.741533994675 -------------\r']}, 'wandb-events.jsonl': {'offset': 1513, 'content': ['{"system.gpu.0.gpu": 0.0, "system.gpu.0.memory": 0.0, "system.gpu.0.memoryAllocated": 40.43, "system.gpu.0.temp": 39.0, "system.gpu.process.0.gpu": 0.0, "system.gpu.process.0.memory": 0.0, "system.gpu.process.0.memoryAllocated": 40.43, "system.gpu.process.0.temp": 39.0, "system.gpu.0.powerWatts": 38.58, "system.gpu.0.powerPercent": 15.43, "system.gpu.process.0.po

Epoch: 5/50, batch: 12522/12522, train_loss: 2.4206393254654794e-05, t_ent_sample_acc: 0.6875898606398105, t_head_rel_sample_acc: 0.6033514499952286, t_tail_rel_sample_acc: 0.6029787726154097,lr: 0.000500062721463132, batch_time: 0.4984614849090576, total_time: 9500.345129728317 ---------------

Validating: 100%|██████████| 1121/1121 [05:42<00:00,  3.27it/s]

{'time': 342.47619676589966,
 'val_ent_seq_acc': 0.6699375740578938,
 'val_f1': 0.820606468305049,
 'val_head_rel_acc': 0.6387154497172979,
 'val_prec': 0.805167861184451,
 'val_recall': 0.8366487016168462,
 'val_tail_rel_acc': 0.6366339747949331}
Current avf_f1: 0.820606468305049, Best f1: 0.820606468305049





Epoch: 6/50, batch: 2/12522, train_loss: 1.6303388292726595e-05, t_ent_sample_acc: 0.8333333730697632, t_head_rel_sample_acc: 0.7500000298023224, t_tail_rel_sample_acc: 0.7500000298023224,lr: 0.0004999372785368682, batch_time: 0.6811408996582031, total_time: 2.8166308403015137 -------------

# Prediction

In [None]:
model_state_path = get_last_state_path(model_state_dict_dir)
# model_state_path = get_state_path(model_state_dict_dir, 16)
rel_extractor.load_state_dict(torch.load(model_state_path))
rel_extractor.eval()
print("------------model state {} loaded ----------------".format(model_state_path.split("/")[-1]))

In [None]:
def filter_duplicates(rel_list):
    rel_memory_set = set()
    filtered_rel_list = []
    for rel in rel_list:
        rel_memory = "{}\u2E80\{}\u2E80\{}\u2E80\{}\u2E80{}".format(*rel.values())
        if rel_memory not in rel_memory_set:
            filtered_rel_list.append(rel)
            rel_memory_set.add(rel_memory)
    return filtered_rel_list

In [None]:
def predict(short_test_data):
    '''
    short_test_data: seq_len <= max_seq_len
    '''
    indexed_test_data = get_indexed_train_valid_data(short_test_data)
    test_dataloader = DataLoader(MyDataset(indexed_test_data), 
                              batch_size = batch_size, 
                              shuffle = False, 
                              num_workers = 0,
                              drop_last = False,
                              collate_fn = generate_pred_batch,
                             )
    short_pred_sample_list = []
    for batch_test_data in tqdm(test_dataloader, desc = "Predicting"):
        text_id_list, text_list, batch_input_ids, \
        batch_attention_mask, batch_token_type_ids, \
        offset_map_list = batch_test_data

        batch_input_ids, \
        batch_attention_mask, \
        batch_token_type_ids = (batch_input_ids.to(device), 
                                  batch_attention_mask.to(device), 
                                  batch_token_type_ids.to(device)
                                 )
        with torch.no_grad():
            batch_ent_shaking_outputs, \
            batch_head_rel_shaking_outputs, \
            batch_tail_rel_shaking_outputs = rel_extractor(batch_input_ids, 
                                                              batch_attention_mask, 
                                                              batch_token_type_ids, 
                                                             )

        batch_ent_shaking_tag, \
        batch_head_rel_shaking_tag, \
        batch_tail_rel_shaking_tag = torch.argmax(batch_ent_shaking_outputs, dim = -1), \
                                     torch.argmax(batch_head_rel_shaking_outputs, dim = -1), \
                                     torch.argmax(batch_tail_rel_shaking_outputs, dim = -1)

        for ind in range(len(text_list)):
            text, offset_map = text_list[ind], offset_map_list[ind]
            ent_shaking_tag, \
            head_rel_shaking_tag, \
            tail_rel_shaking_tag = batch_ent_shaking_tag[ind], \
                                    batch_head_rel_shaking_tag[ind], \
                                    batch_tail_rel_shaking_tag[ind]
            rel_list = handshaking_tagger.decode_rel_fr_shaking_tag(text, 
                                                  ent_shaking_tag, 
                                                  head_rel_shaking_tag, 
                                                  tail_rel_shaking_tag, 
                                                  offset_map)
            short_pred_sample_list.append({
                "text": text,
                "id": text_id_list[ind],
                "relation_list": rel_list,
            })
    # merge
    text_id2rel_list = {}
    for sample in short_pred_sample_list:
        text_id = sample["id"]
        if text_id not in text_id2rel_list:
            text_id2rel_list[text_id] = sample["relation_list"]
        else:
            text_id2rel_list[text_id].extend(sample["relation_list"])

    text_id2text = {sample["id"]:sample["text"] for sample in test_data}
    merged_pred_sample_list = []
    for text_id, rel_list in text_id2rel_list.items():
        merged_pred_sample_list.append({
            "id": text_id,
            "text": text_id2text[text_id],
            "relation_list": filter_duplicates(rel_list),
        })
    return merged_pred_sample_list

In [None]:
pred_sample_list = predict(short_test_data)

In [None]:
len([s for s in pred_sample_list if len(s["relation_list"]) > 0])

In [None]:
text_id2gold_n_pred = {}
for sample in test_data:
    text_id = sample["id"]
    text_id2gold_n_pred[text_id] = {
        "gold_relation_list": sample["relation_list"],
    }
def get_test_prf(pred_sample_list):
    for sample in pred_sample_list:
        text_id = sample["id"]
        text_id2gold_n_pred[text_id]["pred_relation_list"] = sample["relation_list"]

    correct_num, pred_num, gold_num = 0, 0, 0
    for gold_n_pred in text_id2gold_n_pred.values():
        gold_rel_list = gold_n_pred["gold_relation_list"]
        pred_rel_list = gold_n_pred["pred_relation_list"] if "pred_relation_list" in gold_n_pred else []
        gold_rel_set = set(["{}\u2E80{}\u2E80{}".format(rel["subject"], rel["predicate"], rel["object"]) for rel in gold_rel_list])
        pred_rel_set = set(["{}\u2E80{}\u2E80{}".format(rel["subject"], rel["predicate"], rel["object"]) for rel in pred_rel_list])

        for rel_str in pred_rel_set:
            if rel_str in gold_rel_set:
                correct_num += 1

        pred_num += len(gold_rel_set)
        gold_num += len(pred_rel_set)

    prf = get_scores(correct_num, pred_num, gold_num)
#     print(prf)
    return prf

In [None]:
# model state 16: (0.9112068965517129, 0.9034188034187924, 0.9072961372890456)
# model state 17: (0.9060344827586095, 0.9096191889218483, 0.9078232970872052)
# 18: (0.9178571428571316, 0.904600072824361, 0.9111803899493801)
get_test_prf(pred_sample_list)