In [1]:
import json
import os
from tqdm import tqdm
import re
from IPython.core.debugger import set_trace
from pprint import pprint
import unicodedata
from transformers import AutoModel, BasicTokenizer, BertTokenizerFast
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import glob
import time
from bs4 import BeautifulSoup
import random
import math

In [2]:
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device {} will be used".format(device))

device cuda:0 will be used


In [3]:
batch_size = 8
max_seq_len = 109

In [4]:
pretrained_model_home = "/home/wangyucheng/opt/transformers_models_h5"
project_root = "/home/wangyucheng/opt/data/research/relextr"
data_home = os.path.join(project_root, "data")

experiment = "only_shaking+roberta+restartwarm@webnlg"
model_state_dict_dir = os.path.join(project_root, "state_dict", experiment, "model")
schedule_state_dict_dir = os.path.join(project_root, "state_dict", experiment, "opt_schedule")

if not os.path.exists(model_state_dict_dir):
    os.makedirs(model_state_dict_dir)
if not os.path.exists(schedule_state_dict_dir):
    os.makedirs(schedule_state_dict_dir)

In [5]:
webnlg_data_dir = os.path.join(data_home, "webnlg", "original")
webnlg_train_data_dir = os.path.join(webnlg_data_dir, "train")
webnlg_valid_data_dir = os.path.join(webnlg_data_dir, "dev")

# Load and Preprocess Data

In [6]:
webnlg_train_data_file_path_list = glob.glob(webnlg_train_data_dir + "/*/*.xml")
webnlg_valid_data_file_path_list = glob.glob(webnlg_valid_data_dir + "/*/*.xml")

In [7]:
def get_normal_sample_list(file_path_list):
    all_entry_list = []
    for file_path in tqdm(file_path_list, desc = "Loading entries"):
        soup = BeautifulSoup(open(file_path, "r", encoding = "utf-8"), "lxml")
        entry_list = soup.select("entry")
        all_entry_list.extend(entry_list)

    normal_sample_list = []
    sent2spo_list = {}
    for entry in tqdm(all_entry_list, desc = "Transforming into normal format"):
        sents = [lex.get_text() for lex in entry.select("lex")]
        spo_list = [[re.sub("_", " ", e.strip()) for e in triple.get_text().split("|")] for triple in entry.select("modifiedtripleset > mtriple")]
        spo_list = [{"subject": spo[0], "predicate": spo[1], "object": spo[2]} for spo in spo_list]
        for sent in sents:
            if sent not in sent2spo_list:
                sent2spo_list[sent] = spo_list
            else:
                sent2spo_list[sent].extend(spo_list)
    for sent, spo_list in sent2spo_list.items():
        # filter duplicates
        spo_memory = set()
        filtered_spo_list = []
        for spo in spo_list:
            memory = "{}\u2E80{}\u2E80{}".format(spo["subject"], spo["predicate"], spo["object"])
            if memory in spo_memory:
                continue
            if spo["subject"] not in sent or spo["object"] not in sent:
                continue
            filtered_spo_list.append(spo)
            spo_memory.add(memory)
        if len(filtered_spo_list) == 0:
            continue
        normal_sample_list.append({
            "text": sent,
            "relation_list": filtered_spo_list,
        })
    return normal_sample_list

In [8]:
train_normal_sample_list = get_normal_sample_list(webnlg_train_data_file_path_list)
valid_normal_sample_list = get_normal_sample_list(webnlg_valid_data_file_path_list)
print(len(train_normal_sample_list))
print(len(valid_normal_sample_list))

Loading entries: 100%|██████████| 52/52 [00:10<00:00,  5.13it/s]
Transforming into normal format: 100%|██████████| 6940/6940 [00:08<00:00, 772.74it/s] 
Loading entries: 100%|██████████| 52/52 [00:01<00:00, 43.88it/s]
Transforming into normal format: 100%|██████████| 872/872 [00:01<00:00, 782.26it/s]

12995
1671





In [9]:
# split train data into train and valid set, take original valid set as test set 
random.seed(233)
random.shuffle(train_normal_sample_list)
train_data, valid_data = train_normal_sample_list[:-2000], train_normal_sample_list[-2000:]
test_data = valid_normal_sample_list

In [10]:
# add text id
text_id = 0
for sample in tqdm(train_data + valid_data + test_data):
    sample["id"] = text_id
    text_id += 1

100%|██████████| 14666/14666 [00:00<00:00, 779453.14it/s]


# Span

In [11]:
model_path = os.path.join(pretrained_model_home, "bert-base-cased")
tokenizer = BertTokenizerFast.from_pretrained(model_path, add_special_tokens = False, do_lower_case = False)

In [12]:
def get_char_ind2tok_ind(offset_map):
    char_num = None
    for tok_ind in range(len(offset_map) - 1, -1, -1):
        if offset_map[tok_ind][1] != 0:
            char_num = offset_map[tok_ind][1]
            break
    char_ind2tok_ind = [0 for _ in range(char_num)] # 除了空格(0)，其他字符均有对应token
    for tok_ind, sp in enumerate(offset_map):
        for char_ind in range(sp[0], sp[1]):
            char_ind2tok_ind[char_ind] = tok_ind
    return char_ind2tok_ind

In [13]:
def get_ent2char_spans(text, entities):
    entities = sorted(entities, key = lambda x: len(x), reverse = True)
    text_cp = text[:]
    ent2char_spans = {}
    for ent in entities:
        spans = []
        for m in re.finditer(re.escape(ent), text_cp):
            spans.append(m.span())
        ent2char_spans[ent] = spans
    return ent2char_spans

In [14]:
def char_sp2tok_sp(span, char_ind2tok_ind):
    tok_span = (char_ind2tok_ind[span[0]], char_ind2tok_ind[span[1] - 1] + 1)
    return tok_span

In [15]:
def get_ent2tok_spans(text, entities):
    ent2char_spans = get_ent2char_spans(text, entities)
    offset_map = tokenizer.encode_plus(text, 
                                       return_offsets_mapping = True, 
                                       add_special_tokens = False)["offset_mapping"]
    char_ind2tok_ind = get_char_ind2tok_ind(offset_map)
    ent2tok_spans = {}
    for ent, char_spans in ent2char_spans.items():
        tok_spans = [char_sp2tok_sp(sp, char_ind2tok_ind) for sp in char_spans]
        ent2tok_spans[ent] = tok_spans
    return ent2tok_spans

In [16]:
# # test get_ent2tok_spans # debug
# sample = train_data[0]
# entities = [rel["subject"] for rel in sample["relation_list"]]
# entities.extend([rel["object"] for rel in sample["relation_list"]])
# get_ent2tok_spans(sample["text"], entities)

In [17]:
# check token level span
dif_ent_pairs = []
for sample in tqdm(train_data + valid_data + test_data):
    entities = [rel["subject"] for rel in sample["relation_list"]]
    entities.extend([rel["object"] for rel in sample["relation_list"]])
    ent2tok_spans = get_ent2tok_spans(sample["text"], entities)
    offset_map = tokenizer.encode_plus(sample["text"], 
                                       return_offsets_mapping = True, 
                                       add_special_tokens = False)["offset_mapping"]
    for ent, tok_spans in ent2tok_spans.items():
        for sp in tok_spans:
            char_span_list = offset_map[sp[0]:sp[1]]
            char_span = (char_span_list[0][0], char_span_list[-1][1])
            decoded_ent = sample["text"][char_span[0]:char_span[1]]
            if ent != decoded_ent:
                dif_ent_pairs.append((ent, decoded_ent))
#             assert ent == decoded_ent

100%|██████████| 14666/14666 [00:11<00:00, 1262.23it/s]


In [18]:
# add span
for sample in tqdm(train_data + valid_data + test_data):
    entities = [rel["subject"] for rel in sample["relation_list"]]
    entities.extend([rel["object"] for rel in sample["relation_list"]])
    ent2tok_spans = get_ent2tok_spans(sample["text"], entities)
    
    new_relation_list = []
    for rel in sample["relation_list"]:
        subj_spans = ent2tok_spans[rel["subject"]]
        obj_spans = ent2tok_spans[rel["object"]]
        if len(subj_spans) == 0 or len(obj_spans) == 0:
            set_trace()
        for subj_sp in subj_spans:
            for obj_sp in obj_spans:
                new_relation_list.append({
                    "subject": rel["subject"],
                    "object": rel["object"],
                    "subj_span": subj_sp,
                    "obj_span": obj_sp,
                    "predicate": rel["predicate"],
                })
    sample["relation_list"] = new_relation_list

100%|██████████| 14666/14666 [00:08<00:00, 1716.30it/s]


In [19]:
# # check if any empty relation list
# for sample in tqdm(train_data + valid_data + test_data):
#     if len(sample["relation_list"]) == 0:
#         set_trace()

# Tagging

In [20]:
rel_set = set()
for sample in tqdm(train_data + valid_data + test_data):
    for rel in sample["relation_list"]:
        rel_set.add(rel["predicate"])
rel_set = sorted(rel_set)

100%|██████████| 14666/14666 [00:00<00:00, 249080.48it/s]


In [21]:
rel2id = {rel:ind for ind, rel in enumerate(rel_set)}
id2rel = {ind:rel for rel, ind in rel2id.items()}

In [22]:
def get_spots(sample):
    '''
    spot: (rel_id, span_pos1, span_pos2, tag_id)
    tag_id: 
        ent_matrix_spots: 0.other, 1.entity
        head_rel_matrix_spots: 0.other, 1.subj head -> obj head 2. obj head -> subj head
        tail_rel_matrix_spots: 0.other, 1.subj tail -> obj tail 2. obj tail -> subj tail
    '''
    matrix_spots = [] 
    spot_memory_set = set()
    def add_spot(spot):
        memory = "{},{},{},{}".format(*spot)
        if memory not in spot_memory_set:
            matrix_spots.append(spot)
            spot_memory_set.add(memory)
            
    for rel in sample["relation_list"]:
        subj_span = rel["subj_span"]
        obj_span = rel["obj_span"]
        # entity head and tail: 1
        add_spot((rel2id[rel["predicate"]], subj_span[0], subj_span[1] - 1, 1))
        add_spot((rel2id[rel["predicate"]], obj_span[0], obj_span[1] - 1, 1))
        if  subj_span[0] <= obj_span[0]:
            # subj head -> obj head: 2
            add_spot((rel2id[rel["predicate"]], subj_span[0], obj_span[0], 2))
        else:
            # obj head -> subj head: 4
            add_spot((rel2id[rel["predicate"]], obj_span[0], subj_span[0], 4))
            
        if subj_span[1] <= obj_span[1]:
            # subj tail -> obj tail: 3
            add_spot((rel2id[rel["predicate"]], subj_span[1] - 1, obj_span[1] - 1, 3))
        else:
            # obj tail -> subj tail: 5
            add_spot((rel2id[rel["predicate"]], obj_span[1] - 1, subj_span[1] - 1, 5))
            
    return matrix_spots

In [23]:
matrix_size = max_seq_len
shaking_ind2matrix_ind = [(ind, end_ind) for ind in range(matrix_size) for end_ind in list(range(matrix_size))[ind:]]

matrix_ind2shaking_ind = [[0 for i in range(matrix_size)] for j in range(matrix_size)]
for shaking_ind, matrix_ind in enumerate(shaking_ind2matrix_ind):
    matrix_ind2shaking_ind[matrix_ind[0]][matrix_ind[1]] = shaking_ind
    
def get_shaking_ind2matrix_ind():
    '''
    return: e.g. [(0, 0), (0, 1), (0, 2), (1, 1), (1, 2), (2, 2)]
    
    '''
    return shaking_ind2matrix_ind

def get_matrix_ind2shaking_ind():
    return matrix_ind2shaking_ind

In [24]:
def spots2shaking_tag(spots):
    '''
    convert spots to shaking seq tag
    spots: [(predicate_id, start_ind, end_ind, tag_id), ]
    return: shake_seq_tag
    '''
    matrix_ind2shaking_ind = get_matrix_ind2shaking_ind()
    shaking_seq_len = matrix_size * (matrix_size + 1) // 2
    shaking_seq_tag = torch.zeros(len(rel2id), shaking_seq_len).long()
    for sp in spots:
        shaking_ind = matrix_ind2shaking_ind[sp[1]][sp[2]]
        tag_id = sp[3]
        shaking_seq_tag[sp[0]][shaking_ind] += 2**(tag_id - 1)
    return shaking_seq_tag

In [25]:
def spots2shaking_tag4batch(batch_spots):
    '''
    convert spots to batch shaking seq tag
    因长序列的stack是费时操作，所以写这个函数转用作生成批量shaking tag
    如果每个样本生成一条shaking tag再stack，一个32的batch耗时1s，太昂贵
    spots: [(predicate_id, start_ind, end_ind, tag_id), ]
    return: 
        batch_shake_seq_tag: (batch_size, rel_size, shaking_seq_len)
    '''
    matrix_ind2shaking_ind = get_matrix_ind2shaking_ind()
    shaking_seq_len = matrix_size * (matrix_size + 1) // 2
    batch_shaking_seq_tag = torch.zeros(len(batch_spots), len(rel2id), shaking_seq_len).long()
    for batch_id, spots in enumerate(batch_spots):
        for sp in spots:
            shaking_ind = matrix_ind2shaking_ind[sp[1]][sp[2]]
            tag_id = sp[3]
            rel_id = sp[0]
            batch_shaking_seq_tag[batch_id][rel_id][shaking_ind] += 2**(tag_id - 1)
    return batch_shaking_seq_tag

In [26]:
def get_spots_fr_shaking_tag(shaking_tag):
    spots = []
    shaking_ind2matrix_ind = get_shaking_ind2matrix_ind()
#     %timeit shaking_tag.nonzero() # 4.5ms
    for shaking_inds in shaking_tag.nonzero():
        rel_id = shaking_inds[0].item()
        matrix_inds = shaking_ind2matrix_ind[shaking_inds[1]]
        tag_id_sum = shaking_tag[rel_id][shaking_inds[1]].item()
        for tag_id in range(1, 6):
            if tag_id_sum & 2**(tag_id - 1) != 0:
                spot = (rel_id, matrix_inds[0], matrix_inds[1], tag_id)
                spots.append(spot)
    return spots

In [27]:
# # check spots decoding
# sample = train_data[0]
# # pprint(sample)
# matrix_spots = get_spots(sample)
# print(matrix_spots)
# shaking_tag = spots2shaking_tag(matrix_spots)
# # %timeit spots2shaking_tag(matrix_spots)
# decoded_spots = get_spots_fr_shaking_tag(shaking_tag)
# # %timeit get_spots_fr_shaking_tag(shaking_tag) # 5ms
# print(decoded_spots)

In [28]:
def decode_rel_fr_shaking_tag(text, shaking_tag, offset_map):
    '''
    shaking_tag: size = (rel_size, shaking_seq_len, )
    '''
    rel_list = []
    
    matrix_spots = get_spots_fr_shaking_tag(shaking_tag)
#     spots = sorted(spots, key = lambda sp: (sp[0], sp[3])) # group by relation id, and entity first

    # entity
    head_ind2entities = {}
    for sp in matrix_spots:
        tag_id = sp[3]
        if tag_id != 1:
            continue
            
        rel_id = sp[0]
        
        char_span_list = offset_map[sp[1]:sp[2] + 1]
        char_sp = (char_span_list[0][0], char_span_list[-1][1])
        ent_text = text[char_sp[0]:char_sp[1]] 
        
        rel_head_key = "{}-{}".format(rel_id, sp[1])
        if rel_head_key not in head_ind2entities:
            head_ind2entities[rel_head_key] = []
        head_ind2entities[rel_head_key].append({
            "text": ent_text,
            "span": (sp[1], sp[2] + 1),
        })
        
    # tail relation
    tail_rel_memory_set = set()
    for sp in matrix_spots:
        rel_id = sp[0]
        tag_id = sp[3]
        if tag_id == 3:
            tail_rel_memory = "{}-{}-{}".format(rel_id, sp[1], sp[2])
            tail_rel_memory_set.add(tail_rel_memory)
        elif tag_id == 5:
            tail_rel_memory = "{}-{}-{}".format(rel_id, sp[2], sp[1])
            tail_rel_memory_set.add(tail_rel_memory)

    # head relation
    for sp in matrix_spots:
        rel_id = sp[0]
        tag_id = sp[3]
        
        if tag_id == 2:
            subj_head_key, obj_head_key = "{}-{}".format(rel_id, sp[1]), "{}-{}".format(rel_id, sp[2])
            if subj_head_key not in head_ind2entities or obj_head_key not in head_ind2entities:
                continue
            subj_list = head_ind2entities[subj_head_key]
            obj_list = head_ind2entities[obj_head_key]

            for subj in subj_list:
                for obj in obj_list:
                    tail_rel_memory = "{}-{}-{}".format(rel_id, subj["span"][1] - 1, obj["span"][1] - 1)
                    if tail_rel_memory not in tail_rel_memory_set:
                        continue
                    rel_list.append({
                        "subject": subj["text"],
                        "object": obj["text"],
                        "subj_span": subj["span"],
                        "obj_span": obj["span"],
                        "predicate": id2rel[rel_id],
                    })
        elif tag_id == 4:
            subj_head_key, obj_head_key = "{}-{}".format(rel_id, sp[2]), "{}-{}".format(rel_id, sp[1])
            if subj_head_key not in head_ind2entities or obj_head_key not in head_ind2entities:
                continue
            subj_list = head_ind2entities[subj_head_key]
            obj_list = head_ind2entities[obj_head_key]

            for subj in subj_list:
                for obj in obj_list:
                    tail_rel_memory = "{}-{}-{}".format(rel_id, subj["span"][1] - 1, obj["span"][1] - 1)
                    if tail_rel_memory not in tail_rel_memory_set:
                        continue
                    rel_list.append({
                        "subject": subj["text"],
                        "object": obj["text"],
                        "subj_span": subj["span"],
                        "obj_span": obj["span"],
                        "predicate": id2rel[rel_id],
                    })
    return rel_list

In [29]:
def sample_equal_to(sample1, sample2):
    assert sample1["id"] == sample2["id"]
    assert sample1["text"] == sample2["text"]
    memory_set = set()
    for rel in sample2["relation_list"]:
        memory = "{}\u2E80{}\u2E80{}".format(str(rel["subj_span"]), 
                                             rel["predicate"], 
                                             str(rel["obj_span"]))
        memory_set.add(memory)
    for rel in sample1["relation_list"]:
        memory = "{}\u2E80{}\u2E80{}".format(str(rel["subj_span"]), 
                                             rel["predicate"], 
                                             str(rel["obj_span"]))
        if memory not in memory_set:
            set_trace()
            return False
    return True

In [30]:
# sample = train_data[102]
# matrix_spots = get_spots(sample) # 79 us
# # print(matrix_spots)
# shaking_tag = spots2shaking_tag(matrix_spots) # 10 ms
# # %timeit spots2shaking_tag(matrix_spots)
# text = sample["text"]
# offset_map = tokenizer.encode_plus(text, return_offsets_mapping = True, 
#                                    add_special_tokens = False)["offset_mapping"]

# decoded_rel_list = decode_rel_fr_shaking_tag(text, 
#                                              shaking_tag, 
#                                              offset_map) # 4ms
# # %timeit decode_rel_fr_shaking_tag(text, shaking_tag, offset_map)
# pred_sample = {
#     "id": sample["id"],
#     "text": text,
#     "relation_list": decoded_rel_list,
# }
# # print(sample)
# # print()
# # print(pred_sample)
# sample_equal_to(sample, pred_sample)

In [31]:
# # check tagging and decoding
# for sample in tqdm(train_data + valid_data + test_data):
#     matrix_spots = get_spots(sample)
#     shaking_tag = spots2shaking_tag(matrix_spots)

#     text = sample["text"]
#     offset_map = tokenizer.encode_plus(text, return_offsets_mapping = True, 
#                                        add_special_tokens = False)["offset_mapping"]
#     decoded_rel_list = decode_rel_fr_shaking_tag(text, 
#                                                  shaking_tag,
#                                                  offset_map)
#     pred_sample = {
#         "id": sample["id"],
#         "text": text,
#         "relation_list": decoded_rel_list,
#     }
#     if not sample_equal_to(pred_sample, sample):
#         set_trace()

In [32]:
# # check batch tagging and decoding
# batch_spots = []
# batch_samples = train_data[:100]
# for sample in tqdm(batch_samples):
#     batch_spots.append(get_spots(sample))
# batch_shaking_tag = spots2shaking_tag4batch(batch_spots)

# for ind, sample in tqdm(enumerate(batch_samples)):
#     text = sample["text"]
#     offset_map = tokenizer.encode_plus(text, return_offsets_mapping = True, 
#                                        add_special_tokens = False)["offset_mapping"]
#     shaking_tag = batch_shaking_tag[ind]
#     decoded_rel_list = decode_rel_fr_shaking_tag(text, 
#                                                  shaking_tag,
#                                                  offset_map)
#     pred_sample = {
#         "id": sample["id"],
#         "text": text,
#         "relation_list": decoded_rel_list,
#     }
#     if not sample_equal_to(pred_sample, sample):
#         set_trace()

# Dataset

In [33]:
def get_indexed_train_valid_data(data):
    indexed_samples = []
    for ind, sample in tqdm(enumerate(data), desc = "Generate indexed train or valid data"):
        text = sample["text"]
        text_id = sample["id"]
        # codes for bert input
        codes = tokenizer.encode_plus(text, 
                                    return_offsets_mapping = True, 
                                    add_special_tokens = False,
                                    max_length = max_seq_len, 
                                    pad_to_max_length = True)
        
        
        # tagging
        matrix_spots = get_spots(sample)
        
        # get codes
        input_ids = torch.tensor(codes["input_ids"]).long()
        attention_mask = torch.tensor(codes["attention_mask"]).long()
        token_type_ids = torch.tensor(codes["token_type_ids"]).long()
        offset_map = codes["offset_mapping"]

        sample_tp = (text_id,
                     text, 
                     input_ids,
                     attention_mask,
                     token_type_ids,
                     offset_map,
                     matrix_spots,
                    )
        indexed_samples.append(sample_tp)       
    return indexed_samples

In [34]:
def get_indexed_pred_data(data):
    indexed_samples = []
    for ind, sample in tqdm(enumerate(data), desc = "Generate indexed pred data"):
        text = sample["text"] 
        text_id = sample["id"]
        codes = tokenizer.encode_plus(text, 
                                    return_offsets_mapping = True, 
                                    add_special_tokens = False,
                                    max_length = max_seq_len, 
                                    pad_to_max_length = True)
        
        input_ids = torch.tensor(codes["input_ids"]).long()
        attention_mask = torch.tensor(codes["attention_mask"]).long()
        token_type_ids = torch.tensor(codes["token_type_ids"]).long()
        offset_map = codes["offset_mapping"]

        sample_tp = (text_id,
                     text, 
                     input_ids,
                     attention_mask,
                     token_type_ids,
                     offset_map,
                     )
        indexed_samples.append(sample_tp)       
    return indexed_samples

In [35]:
class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data
        
    def __getitem__(self, index):
        return self.data[index]
    
    def __len__(self):
        return len(self.data)

In [36]:
def generate_train_dev_batch(batch_data):
    text_id_list = []
    text_list = []
    input_ids_list = []
    attention_mask_list = []
    token_type_ids_list = [] 
    offset_map_list = []
    matrix_spots_list = []
    
    for sample in batch_data:
        text_id_list.append(sample[0])
        text_list.append(sample[1])
        input_ids_list.append(sample[2])
        attention_mask_list.append(sample[3])        
        token_type_ids_list.append(sample[4])        
        offset_map_list.append(sample[5])
        
        matrix_spots_list.append(sample[6])
    
    batch_input_ids = torch.stack(input_ids_list, dim = 0)
    batch_attention_mask = torch.stack(attention_mask_list, dim = 0)
    batch_token_type_ids = torch.stack(token_type_ids_list, dim = 0)

    batch_shaking_tag = spots2shaking_tag4batch(matrix_spots_list)
#     batch_shaking_tag = torch.stack(shaking_tag_list, dim = 0)
    return text_id_list, text_list, batch_input_ids, batch_attention_mask, batch_token_type_ids, offset_map_list, batch_shaking_tag

In [37]:
def generate_pred_batch(batch_data):
    text_ids = []
    text_list = []
    input_ids = []
    attention_mask = []
    token_type_ids = [] 
    offset_map = []
    for sample in batch_data:
        text_ids.append(sample[0])
        text_list.append(sample[1])
        input_ids.append(sample[2])
        attention_mask.append(sample[3])        
        token_type_ids.append(sample[4])        
        offset_map.append(sample[5])
    input_ids = torch.stack(input_ids, dim = 0)
    attention_mask = torch.stack(attention_mask, dim = 0)
    token_type_ids = torch.stack(token_type_ids, dim = 0)
    return text_ids, text_list, input_ids, attention_mask, token_type_ids, offset_map

In [38]:
# @uni
def get_train_dev_dataloader_gen(indexed_train_sample_list, indexed_dev_sample_list, batch_size):
    train_dataloader = DataLoader(MyDataset(indexed_train_sample_list), 
                                      batch_size = batch_size, 
                                      shuffle = True, 
                                      num_workers = 0,
                                      drop_last = False,
                                      collate_fn = generate_train_dev_batch,
                                     )
    dev_dataloader = DataLoader(MyDataset(indexed_dev_sample_list), 
                              batch_size = batch_size, 
                              shuffle = True, 
                              num_workers = 0,
                              drop_last = False,
                              collate_fn = generate_train_dev_batch,
                             )
    return train_dataloader, dev_dataloader

In [39]:
indexed_train_data = get_indexed_train_valid_data(train_data)

Generate indexed train or valid data: 10995it [00:05, 2195.31it/s]


In [40]:
indexed_valid_data = get_indexed_train_valid_data(valid_data)

Generate indexed train or valid data: 2000it [00:00, 2184.91it/s]


In [41]:
# have a look at dataloader
train_dataloader, dev_dataloader = get_train_dev_dataloader_gen(indexed_train_data, indexed_valid_data, 24)

In [42]:
# train_data_iter = iter(train_dataloader)
# %timeit next(train_data_iter)

In [43]:
train_data_iter = iter(train_dataloader)
batch_data = next(train_data_iter)
text_id_list, text_list, batch_input_ids, \
batch_attention_mask, batch_token_type_ids, \
offset_map_list, batch_shaking_tag = batch_data

print(text_list[0])
print()
print(tokenizer.decode(batch_input_ids[0].tolist()))
print(batch_input_ids.size())
print(batch_attention_mask.size())
print(batch_token_type_ids.size())
print(len(offset_map_list))
print(batch_shaking_tag.size())

Aenir written by Garth Nix is available in paperback with the OCLC number 45644811.

Aenir written by Garth Nix is available in paperback with the OCLC number 45644811. [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
torch.Size([24, 109])
torch.Size([24, 109])
torch.Size([24, 109])
24
torch.Size([24, 208, 5995])


# Model

In [44]:
def shake_hands_afterwards(seq_hiddens):
    '''
    seq_hiddens: (batch_size, seq_len, hidden_size) (32, 3, 5)
    return shake_hands_matrix_hiddens: (batch_size, (1 + seq_len) * seq_len / 2, hidden_size) (32, 5+4+3+2+1, 5)
    '''
    seq_len = seq_hiddens.size()[-2]
    shake_hands_hidden_list = []
    for ind in range(seq_len):
        hidden_each_step = seq_hiddens[:, ind, :]
        # seq_len - ind: only shake afterwards
        repeat_hidden_each_step = hidden_each_step[:, None, :].repeat(1, seq_len - ind, 1) 
        shake_hands_hidden = torch.cat([repeat_hidden_each_step, seq_hiddens[:, ind:, :]], dim = -1)
        shake_hands_hidden_list.append(shake_hands_hidden)
    shake_hands_matrix_hiddens = torch.cat(shake_hands_hidden_list, dim = 1)
    return shake_hands_matrix_hiddens

In [45]:
# # 尝试用卷积层代替全连接层做并行运算，因为for循环太慢。但repeat操作会成倍增加显存（200G），不现实
# x = torch.tensor([[[1, 3, 4.], [2, 6, 8.]]])
# x = x.repeat(1, 1, 2)
# x = x.view(-1, x.size()[-1]).unsqueeze(-1)
# print(x)
# print(x.size())
# conv_fc = nn.Conv1d(in_channels = 3 * 2, 
#                     out_channels = 4 * 2, 
#                     kernel_size = 1, 
#                     groups = 2)

# outputs = conv_fc(x)
# print(outputs.size())
# print(outputs)

# outputs = outputs.view(1, -1, 2, 4)
# print(outputs.size())
# print(outputs)

In [46]:
class RelExtractor(nn.Module):
    def __init__(self, encoder, rel_size):
        super().__init__()
        self.encoder = encoder
        hidden_size = encoder.config.hidden_size

        self.shaking_fc_list = [nn.Linear(hidden_size * 2, 2**5) for _ in range(rel_size)]
        
        # register fcs
        for ind, fc in enumerate(self.shaking_fc_list):
            self.register_parameter("weight_4_ent_in_rel{}".format(ind), fc.weight)
            self.register_parameter("bias_4_ent_in_rel{}".format(ind), fc.bias)
        
    def forward(self, input_ids, attention_mask, token_type_ids):
        # input_ids, attention_mask, token_type_ids: (batch_size, seq_len)
        context_outputs = self.encoder(input_ids, attention_mask, token_type_ids)
        # last_hidden_state: (batch_size, seq_len, hidden_size)
        last_hidden_state = context_outputs[0]
        
        # shaking_hiddens: (batch_size, 1 + ... + seq_len, hidden_size)
        shaking_hiddens = shake_hands_afterwards(last_hidden_state)
        
        shaking_outputs_list = []
        for fc in self.shaking_fc_list:
            shaking_outputs_list.append(fc(shaking_hiddens))

        shaking_outputs = torch.stack(shaking_outputs_list, dim = 1)
        return shaking_outputs

In [47]:
roberta = AutoModel.from_pretrained(model_path)

In [48]:
rel_extractor = RelExtractor(roberta, len(rel2id))
rel_extractor = rel_extractor.to(device)

In [49]:
# # test model
# batch_input_ids, \
# batch_attention_mask,\
# batch_token_type_ids = batch_input_ids.to(device), \
#                          batch_attention_mask.to(device), \
#                          batch_token_type_ids.to(device)
                       
# shaking_outputs = rel_extractor(batch_input_ids, 
#                                   batch_attention_mask, 
#                                   batch_token_type_ids)

# print(shaking_outputs.size())

In [50]:
def bias_loss(weights = None):
    if weights is not None:
        weights = torch.FloatTensor(weights).to(device)
    cross_en = nn.CrossEntropyLoss(weight = weights)  
    return lambda pred, target: cross_en(pred.view(-1, pred.size()[-1]), target.view(-1))
loss_func = bias_loss()

In [51]:
def get_sample_accuracy(pred, truth):
    '''
    计算所有抽取字段都正确的样本比例
    即该batch的输出与truth全等的样本比例
    '''
    # (batch_size, ..., seq_len, tag_size) -> (batch_size, ..., seq_len)
    pred_id = torch.argmax(pred, dim = -1)
    # (batch_size, ..., seq_len) -> (batch_size, )，把每个sample压成一条seq
    pred_id = pred_id.view(pred_id.size()[0], -1)
    truth = truth.view(truth.size()[0], -1)
    
    # (batch_size, )，每个元素是pred与truth之间tag相同的数量
    correct_tag_num = torch.sum(torch.eq(truth, pred_id).float(), dim = 1)

    # seq维上所有tag必须正确，所以correct_tag_num必须等于seq的长度才算一个correct的sample
    sample_acc_ = torch.eq(correct_tag_num, torch.ones_like(correct_tag_num) * truth.size()[-1]).float()
    sample_acc = torch.mean(sample_acc_, axis=0)
    
    return sample_acc

In [52]:
def get_rel_cpg(text_list, offset_map_list, 
                 batch_pred_shaking_outputs,
                 batch_gold_shaking_tag):
    batch_pred_shaking_tag = torch.argmax(batch_pred_shaking_outputs, dim = -1)
    
    correct_num, pred_num, gold_num = 0, 0, 0
    for ind in range(len(text_list)):
        text = text_list[ind]
        offset_map = offset_map_list[ind]
        gold_shaking_tag, pred_shaking_tag = batch_gold_shaking_tag[ind], batch_pred_shaking_tag[ind]
 
        pred_rel_list = decode_rel_fr_shaking_tag(text, 
                                                  pred_shaking_tag,
                                                  offset_map)
        gold_rel_list = decode_rel_fr_shaking_tag(text, 
                                                  gold_shaking_tag,
                                                  offset_map)

        gold_rel_set = set(["{}\u2E80{}\u2E80{}".format(rel["subject"], rel["predicate"], rel["object"]) for rel in gold_rel_list])
        pred_rel_set = set(["{}\u2E80{}\u2E80{}".format(rel["subject"], rel["predicate"], rel["object"]) for rel in pred_rel_list])
        
        for rel_str in pred_rel_set:
            if rel_str in gold_rel_set:
                correct_num += 1
        
        pred_num += len(gold_rel_set)
        gold_num += len(pred_rel_set)
        
    return correct_num, pred_num, gold_num

In [53]:
def get_scores(correct_num, pred_num, gold_num):
    minimini = 1e-10
    precision = correct_num / (pred_num + minimini)
    recall = correct_num / (gold_num + minimini)
    f1 = 2 * precision * recall / (precision + recall + minimini)
    return precision, recall, f1

# Train

In [54]:
# train step
def train_step(train_data, optimizer):
    text_id_list, text_list, batch_input_ids, \
    batch_attention_mask, batch_token_type_ids, \
    offset_map_list, batch_shaking_tag = train_data
    
    batch_input_ids, \
    batch_attention_mask, \
    batch_token_type_ids, \
    batch_shaking_tag = (batch_input_ids.to(device), 
                          batch_attention_mask.to(device), 
                          batch_token_type_ids.to(device), 
                          batch_shaking_tag.to(device)
                         )
    
    # zero the parameter gradients
    optimizer.zero_grad()
    
    shaking_outputs = rel_extractor(batch_input_ids, 
                                  batch_attention_mask, 
                                  batch_token_type_ids, 
                                 )
    
    # bp
    loss = loss_func(shaking_outputs, batch_shaking_tag)
    loss.backward()
    optimizer.step()
    
    sample_acc = get_sample_accuracy(shaking_outputs, 
                                          batch_shaking_tag)
    
    return loss.item(), sample_acc.item()

# valid step
def valid_step(valid_data):
    text_id_list, text_list, batch_input_ids, \
    batch_attention_mask, batch_token_type_ids, \
    offset_map_list, batch_shaking_tag = valid_data
    
    batch_input_ids, \
    batch_attention_mask, \
    batch_token_type_ids, \
    batch_shaking_tag = (batch_input_ids.to(device), 
                          batch_attention_mask.to(device), 
                          batch_token_type_ids.to(device), 
                          batch_shaking_tag.to(device)
                         )
    with torch.no_grad():
        shaking_outputs = rel_extractor(batch_input_ids, 
                                      batch_attention_mask, 
                                      batch_token_type_ids, 
                                     )
    
    sample_acc = get_sample_accuracy(shaking_outputs, 
                                      batch_shaking_tag)
    
    rel_cpg = get_rel_cpg(text_list, offset_map_list, 
                          shaking_outputs,
                          batch_shaking_tag)
    
    return sample_acc.item(), rel_cpg

In [55]:
max_f1 = 0.
def train_n_valid(train_dataloader, dev_dataloader, optimizer, scheduler, num_epoch):  
    def train(dataloader, ep):
        # train
        rel_extractor.train()
        
        t_ep = time.time()
        total_loss, total_sample_acc = 0., 0.
        for batch_ind, train_data in enumerate(dataloader):
            t_batch = time.time()
            loss, sample_acc = train_step(train_data, optimizer)
            scheduler.step()
            
            total_loss += loss
            total_sample_acc += sample_acc
            
            avg_loss = total_loss / (batch_ind + 1)
            avg_sample_acc = total_sample_acc / (batch_ind + 1)
            
            batch_print_format = "\rEpoch: {}/{}, batch: {}/{}, train_loss: {}, " + \
                                "t_sample_acc: {}," + \
                                 "lr: {}, batch_time: {}, total_time: {} -------------"
                    
            print(batch_print_format.format(ep + 1, num_epoch, 
                                            batch_ind + 1, len(dataloader), 
                                            avg_loss, 
                                            avg_sample_acc,
                                            optimizer.param_groups[0]['lr'],
                                            time.time() - t_batch,
                                            time.time() - t_ep,
                                           ), end="")
    def valid(dataloader, ep):
        # valid
        rel_extractor.eval()
        
        t_ep = time.time()
        total_sample_acc = 0.
        total_rel_correct_num, total_rel_pred_num, total_rel_gold_num = 0, 0, 0
        for batch_ind, dev_data in enumerate(tqdm(dataloader, desc = "Validating")):
            sample_acc, rel_cpg = valid_step(dev_data)

            total_sample_acc += sample_acc
            
            total_rel_correct_num += rel_cpg[0]
            total_rel_pred_num += rel_cpg[1]
            total_rel_gold_num += rel_cpg[2]

        avg_sample_acc = total_sample_acc / len(dataloader)
        
        rel_prf = get_scores(total_rel_correct_num, total_rel_pred_num, total_rel_gold_num)
             
        print_format = "Epoch: {}/{}, val_sample_acc: {}, " + \
                        "val_rel_prec: {}, val_rel_rec: {}, val_rel_f1: {},\n" + \
                        "val_time: {}"
        print(print_format.format(ep + 1, num_epoch,  
                                  avg_sample_acc,
                                  *rel_prf, 
                                  time.time() - t_ep,
                                 ))
        return rel_prf[2]
        
    for ep in range(num_epoch):
        train(train_dataloader, ep)   
        print()
        valid_f1 = valid(dev_dataloader, ep)
        
        global max_f1
        if valid_f1 >= max_f1: 
            max_f1 = valid_f1
            if valid_f1 > 0.7: # save the best model
                file_num = len(glob.glob(model_state_dict_dir + "/*.pt"))
                torch.save(rel_extractor.state_dict(), os.path.join(model_state_dict_dir, "model_state_dict_{}.pt".format(file_num))) 
                torch.save(scheduler.state_dict(), os.path.join(schedule_state_dict_dir, "scheduler_state_dict_{}.pt".format(file_num))) 
        print("Current avf_f1: {}, Best f1: {}".format(valid_f1, max_f1))

In [56]:
def get_last_state_path(state_dir):
    max_file_num = -1
    last_state_path = None
    for path in glob.glob(state_dir + "/*.pt"):
        file_num = re.search("state_dict_(\d+)\.pt", path).group(1)
        if int(file_num) > max_file_num:
            max_file_num = int(file_num)
            last_state_path = path
    return last_state_path

In [57]:
# dataloader
print("preparing dataloader...")
train_dataloader, \
dev_dataloader = get_train_dev_dataloader_gen(indexed_train_data, 
                                            indexed_valid_data, 
                                            batch_size, 
                                            )
print("dataloaders done!")

preparing dataloader...
dataloaders done!


In [58]:
# optimizer
init_learning_rate = 5e-5
optimizer = torch.optim.Adam(rel_extractor.parameters(), lr = init_learning_rate)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, len(train_dataloader) * 2)

In [59]:
epoch_num = 20
# load the last best state (if any)
model_all_state_paths = glob.glob(model_state_dict_dir + "/*.pt")
if len(model_all_state_paths) > 0:
    model_last_state_path = get_last_state_path(model_state_dict_dir)
    scheduler_last_state_path = get_last_state_path(schedule_state_dict_dir)
    
    rel_extractor.load_state_dict(torch.load(model_last_state_path))
    print("------------model state {} loaded ----------------".format(model_last_state_path.split("/")[-1]))
    scheduler.load_state_dict(torch.load(scheduler_last_state_path))
    print("------------scheduler state {} loaded ----------------".format(scheduler_last_state_path.split("/")[-1]))

train_n_valid(train_dataloader, dev_dataloader, optimizer, scheduler, epoch_num)

Epoch: 1/20, batch: 1374/1375, train_loss: 0.05591354720241775, t_sample_acc: 0.0,lr: 2.5028559927002325e-05, batch_time: 1.7606003284454346, total_time: 2460.0167138576508 --------------

Validating:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch: 1/20, batch: 1375/1375, train_loss: 0.05587333338165825, t_sample_acc: 0.0,lr: 2.5000000000000008e-05, batch_time: 0.9092319011688232, total_time: 2460.9425253868103 -------------


Validating: 100%|██████████| 250/250 [01:32<00:00,  2.70it/s]


Epoch: 1/20, val_sample_acc: 0.0, val_rel_prec: 0.0, val_rel_rec: 0.0, val_rel_f1: 0.0,
val_time: 92.59789848327637
Current avf_f1: 0.0, Best f1: 0.0
Epoch: 2/20, batch: 1374/1375, train_loss: 0.000495179559847542, t_sample_acc: 0.0,lr: 1.6313393930156295e-11, batch_time: 1.7613189220428467, total_time: 2457.3490686416626 --------------

Validating:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch: 2/20, batch: 1375/1375, train_loss: 0.0004951634890696203, t_sample_acc: 0.0,lr: 5e-05, batch_time: 0.8829731941223145, total_time: 2458.2559745311737 -------------


Validating: 100%|██████████| 250/250 [01:30<00:00,  2.76it/s]


Epoch: 2/20, val_sample_acc: 0.0, val_rel_prec: 0.0, val_rel_rec: 0.0, val_rel_f1: 0.0,
val_time: 90.53565335273743
Current avf_f1: 0.0, Best f1: 0.0
Epoch: 3/20, batch: 1374/1375, train_loss: 0.0002213461258051035, t_sample_acc: 0.0,lr: 2.5028559927002325e-05, batch_time: 1.7622177600860596, total_time: 2460.1938219070435 --------------

Validating:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch: 3/20, batch: 1375/1375, train_loss: 0.00022129190748620948, t_sample_acc: 0.0,lr: 2.5000000000000008e-05, batch_time: 0.8812675476074219, total_time: 2461.090471506119 -------------


Validating: 100%|██████████| 250/250 [01:32<00:00,  2.69it/s]


Epoch: 3/20, val_sample_acc: 0.0, val_rel_prec: 0.0, val_rel_rec: 0.0, val_rel_f1: 0.0,
val_time: 92.9315197467804
Current avf_f1: 0.0, Best f1: 0.0
Epoch: 4/20, batch: 1374/1375, train_loss: 0.00013934246172913787, t_sample_acc: 0.0,lr: 1.6313393930156295e-11, batch_time: 1.7619554996490479, total_time: 2465.8048458099365 -------------

Validating:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch: 4/20, batch: 1375/1375, train_loss: 0.0001394144866922447, t_sample_acc: 0.0,lr: 5e-05, batch_time: 0.8835532665252686, total_time: 2466.709825515747 -------------


Validating: 100%|██████████| 250/250 [01:32<00:00,  2.71it/s]


Epoch: 4/20, val_sample_acc: 0.0, val_rel_prec: 0.0, val_rel_rec: 0.0, val_rel_f1: 0.0,
val_time: 92.09782719612122
Current avf_f1: 0.0, Best f1: 0.0
Epoch: 5/20, batch: 1374/1375, train_loss: 0.00011600634373649985, t_sample_acc: 0.0,lr: 2.5028559927002325e-05, batch_time: 1.7495250701904297, total_time: 2466.1472256183624 -------------

Validating:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch: 5/20, batch: 1375/1375, train_loss: 0.0001159651218304961, t_sample_acc: 0.0,lr: 2.5000000000000008e-05, batch_time: 0.8820629119873047, total_time: 2467.0430574417114 -------------


Validating: 100%|██████████| 250/250 [01:32<00:00,  2.69it/s]


Epoch: 5/20, val_sample_acc: 0.0, val_rel_prec: 0.0, val_rel_rec: 0.0, val_rel_f1: 0.0,
val_time: 92.78888750076294
Current avf_f1: 0.0, Best f1: 0.0
Epoch: 6/20, batch: 1374/1375, train_loss: 9.921700558685742e-05, t_sample_acc: 0.0,lr: 1.6313393930156295e-11, batch_time: 1.7670164108276367, total_time: 2465.483920812607 ---------------

Validating:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch: 6/20, batch: 1375/1375, train_loss: 9.917968506025235e-05, t_sample_acc: 0.0,lr: 5e-05, batch_time: 0.8815395832061768, total_time: 2466.3890283107758 -------------


Validating: 100%|██████████| 250/250 [01:33<00:00,  2.69it/s]


Epoch: 6/20, val_sample_acc: 0.0, val_rel_prec: 0.0, val_rel_rec: 0.0, val_rel_f1: 0.0,
val_time: 93.07560300827026
Current avf_f1: 0.0, Best f1: 0.0
Epoch: 7/20, batch: 1374/1375, train_loss: 7.061909339068308e-05, t_sample_acc: 0.0,lr: 2.5028559927002325e-05, batch_time: 1.7627124786376953, total_time: 2461.3381419181824 -------------

Validating:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch: 7/20, batch: 1375/1375, train_loss: 7.060951090005057e-05, t_sample_acc: 0.0,lr: 2.5000000000000008e-05, batch_time: 0.8847787380218506, total_time: 2462.24200296402 -------------


Validating: 100%|██████████| 250/250 [01:32<00:00,  2.71it/s]


Epoch: 7/20, val_sample_acc: 0.0, val_rel_prec: 0.0, val_rel_rec: 0.0, val_rel_f1: 0.0,
val_time: 92.35467410087585
Current avf_f1: 0.0, Best f1: 0.0
Epoch: 8/20, batch: 1374/1375, train_loss: 5.209650205624037e-05, t_sample_acc: 0.0,lr: 1.6313393930156295e-11, batch_time: 1.7680702209472656, total_time: 2457.2851297855377 --------------

Validating:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch: 8/20, batch: 1375/1375, train_loss: 5.2096006303062575e-05, t_sample_acc: 0.0,lr: 5e-05, batch_time: 0.8844068050384521, total_time: 2458.1837944984436 -------------


Validating: 100%|██████████| 250/250 [01:31<00:00,  2.72it/s]


Epoch: 8/20, val_sample_acc: 0.0025, val_rel_prec: 0.010826998387468076, val_rel_rec: 0.7580645161278096, val_rel_f1: 0.02134908016983201,
val_time: 91.75291776657104
Current avf_f1: 0.02134908016983201, Best f1: 0.02134908016983201
Epoch: 9/20, batch: 1374/1375, train_loss: 3.624642441064088e-05, t_sample_acc: 0.010917030567685589,lr: 2.5028559927002325e-05, batch_time: 1.7616968154907227, total_time: 2459.0186688899994 --------------

Validating:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch: 9/20, batch: 1375/1375, train_loss: 3.623532440740911e-05, t_sample_acc: 0.01090909090909091,lr: 2.5000000000000008e-05, batch_time: 0.8807082176208496, total_time: 2459.913102388382 -------------


Validating: 100%|██████████| 250/250 [01:33<00:00,  2.67it/s]


Epoch: 9/20, val_sample_acc: 0.059, val_rel_prec: 0.26791061967288027, val_rel_rec: 0.8042876901797508, val_rel_f1: 0.4019353723488823,
val_time: 93.67486476898193
Current avf_f1: 0.4019353723488823, Best f1: 0.4019353723488823
Epoch: 10/20, batch: 1374/1375, train_loss: 2.35909072734893e-05, t_sample_acc: 0.048125909752547304,lr: 1.6313393930156295e-11, batch_time: 1.758556604385376, total_time: 2458.7368807792664 ----------------

Validating:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch: 10/20, batch: 1375/1375, train_loss: 2.3589992674467133e-05, t_sample_acc: 0.048090909090909094,lr: 5e-05, batch_time: 0.8775858879089355, total_time: 2459.6380870342255 -------------


Validating: 100%|██████████| 250/250 [01:34<00:00,  2.65it/s]


Epoch: 10/20, val_sample_acc: 0.0725, val_rel_prec: 0.4118866620594238, val_rel_rec: 0.7534766118836598, val_rel_f1: 0.5326184092482926,
val_time: 94.28609156608582
Current avf_f1: 0.5326184092482926, Best f1: 0.5326184092482926
Epoch: 11/20, batch: 1374/1375, train_loss: 1.93703652762736e-05, t_sample_acc: 0.08642649199417758,lr: 2.5028559927002325e-05, batch_time: 1.7640302181243896, total_time: 2460.225976228714 ----------------

Validating:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch: 11/20, batch: 1375/1375, train_loss: 1.936306426366421e-05, t_sample_acc: 0.08660606061328541,lr: 2.5000000000000008e-05, batch_time: 0.8817868232727051, total_time: 2461.133918762207 -------------


Validating: 100%|██████████| 250/250 [01:35<00:00,  2.62it/s]


Epoch: 11/20, val_sample_acc: 0.1575, val_rel_prec: 0.6222068647776866, val_rel_rec: 0.6970322580644981, val_rel_f1: 0.6574975656755586,
val_time: 95.41891622543335
Current avf_f1: 0.6574975656755586, Best f1: 0.6574975656755586
Epoch: 12/20, batch: 1374/1375, train_loss: 1.3654548929030257e-05, t_sample_acc: 0.17758369723435224,lr: 1.6313393930156295e-11, batch_time: 1.7563307285308838, total_time: 2461.918289422989 --------------

Validating:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch: 12/20, batch: 1375/1375, train_loss: 1.3662360651572023e-05, t_sample_acc: 0.17745454545454545,lr: 5e-05, batch_time: 0.8842799663543701, total_time: 2462.818258047104 -------------


Validating: 100%|██████████| 250/250 [01:33<00:00,  2.67it/s]


Epoch: 12/20, val_sample_acc: 0.204, val_rel_prec: 0.6602165399677341, val_rel_rec: 0.7654914529914325, val_rel_f1: 0.7089672232031926,
val_time: 93.62155079841614
Current avf_f1: 0.7089672232031926, Best f1: 0.7089672232031926
Epoch: 13/20, batch: 1374/1375, train_loss: 1.2507895986230653e-05, t_sample_acc: 0.19068413391557495,lr: 2.5028559927002325e-05, batch_time: 1.761150598526001, total_time: 2459.3586146831512 --------------

Validating:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch: 13/20, batch: 1375/1375, train_loss: 1.2505258118835628e-05, t_sample_acc: 0.1907878787951036,lr: 2.5000000000000008e-05, batch_time: 0.8808040618896484, total_time: 2460.2542552948 -------------


Validating: 100%|██████████| 250/250 [01:35<00:00,  2.63it/s]


Epoch: 13/20, val_sample_acc: 0.304, val_rel_prec: 0.7383091453581954, val_rel_rec: 0.7919446503582706, val_rel_f1: 0.7641869336648739,
val_time: 95.2155933380127
Current avf_f1: 0.7641869336648739, Best f1: 0.7641869336648739
Epoch: 14/20, batch: 1374/1375, train_loss: 8.90948975051171e-06, t_sample_acc: 0.3076783114992722,lr: 1.6313393930156295e-11, batch_time: 1.7615888118743896, total_time: 2462.4932713508606 ---------------

Validating:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch: 14/20, batch: 1375/1375, train_loss: 8.915239660960982e-06, t_sample_acc: 0.3076969697041945,lr: 5e-05, batch_time: 0.885399580001831, total_time: 2463.402421236038 -------------


Validating: 100%|██████████| 250/250 [01:35<00:00,  2.61it/s]


Epoch: 14/20, val_sample_acc: 0.325, val_rel_prec: 0.7991246256622714, val_rel_rec: 0.8052460538532775, val_rel_f1: 0.8021736616449757,
val_time: 95.6311240196228
Current avf_f1: 0.8021736616449757, Best f1: 0.8021736616449757
Epoch: 15/20, batch: 1374/1375, train_loss: 8.678583629915982e-06, t_sample_acc: 0.3056768558951965,lr: 2.5028559927002325e-05, batch_time: 1.7650160789489746, total_time: 2462.0267374515533 --------------

Validating:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch: 15/20, batch: 1375/1375, train_loss: 8.684439946591737e-06, t_sample_acc: 0.3054545454545455,lr: 2.5000000000000008e-05, batch_time: 0.8861005306243896, total_time: 2462.927896499634 -------------


Validating: 100%|██████████| 250/250 [01:35<00:00,  2.62it/s]


Epoch: 15/20, val_sample_acc: 0.359, val_rel_prec: 0.8442755125546915, val_rel_rec: 0.8173505798394108, val_rel_f1: 0.8305949007998528,
val_time: 95.3206787109375
Current avf_f1: 0.8305949007998528, Best f1: 0.8305949007998528
Epoch: 16/20, batch: 1374/1375, train_loss: 6.343265523323078e-06, t_sample_acc: 0.43649927219796214,lr: 1.6313393930156295e-11, batch_time: 1.7650213241577148, total_time: 2470.6102521419525 --------------

Validating:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch: 16/20, batch: 1375/1375, train_loss: 6.341546696272747e-06, t_sample_acc: 0.4364242424314672,lr: 5e-05, batch_time: 0.8841009140014648, total_time: 2471.508337497711 -------------


Validating: 100%|██████████| 250/250 [01:39<00:00,  2.51it/s]


Epoch: 16/20, val_sample_acc: 0.4205, val_rel_prec: 0.86708131766872, val_rel_rec: 0.8386809269162023, val_rel_f1: 0.8526446935732814,
val_time: 99.67811155319214
Current avf_f1: 0.8526446935732814, Best f1: 0.8526446935732814
Epoch: 17/20, batch: 1374/1375, train_loss: 6.4506647034652365e-06, t_sample_acc: 0.3988355167394469,lr: 2.5028559927002325e-05, batch_time: 1.7578082084655762, total_time: 2499.0904314517975 --------------

Validating:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch: 17/20, batch: 1375/1375, train_loss: 6.450740577913001e-06, t_sample_acc: 0.39854545454545454,lr: 2.5000000000000008e-05, batch_time: 0.8813259601593018, total_time: 2499.9938719272614 -------------


Validating: 100%|██████████| 250/250 [01:41<00:00,  2.47it/s]


Epoch: 17/20, val_sample_acc: 0.4595, val_rel_prec: 0.8762957843814587, val_rel_rec: 0.8554081403192971, val_rel_f1: 0.8657259899363325,
val_time: 101.31661486625671
Current avf_f1: 0.8657259899363325, Best f1: 0.8657259899363325
Epoch: 18/20, batch: 1374/1375, train_loss: 4.755216707425484e-06, t_sample_acc: 0.5264737991266376,lr: 1.6313393930156295e-11, batch_time: 1.7613909244537354, total_time: 2506.2628321647644 --------------

Validating:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch: 18/20, batch: 1375/1375, train_loss: 4.7543773016514025e-06, t_sample_acc: 0.5265757575902071,lr: 5e-05, batch_time: 0.8829264640808105, total_time: 2507.1632475852966 -------------


Validating: 100%|██████████| 250/250 [01:44<00:00,  2.40it/s]


Epoch: 18/20, val_sample_acc: 0.4875, val_rel_prec: 0.8977194194885764, val_rel_rec: 0.8631229235880208, val_rel_f1: 0.8800813007630075,
val_time: 104.19761347770691
Current avf_f1: 0.8800813007630075, Best f1: 0.8800813007630075
Epoch: 19/20, batch: 1374/1375, train_loss: 5.03668416756983e-06, t_sample_acc: 0.48034934497816595,lr: 2.5028559927002325e-05, batch_time: 1.7640719413757324, total_time: 2472.086092710495 ----------------

Validating:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch: 19/20, batch: 1375/1375, train_loss: 5.0350387408054136e-06, t_sample_acc: 0.4804848484992981,lr: 2.5000000000000008e-05, batch_time: 0.8874900341033936, total_time: 2472.99281001091 -------------


Validating: 100%|██████████| 250/250 [01:34<00:00,  2.63it/s]


Epoch: 19/20, val_sample_acc: 0.5015, val_rel_prec: 0.8958765261460286, val_rel_rec: 0.8766907123534519, val_rel_f1: 0.886179788032474,
val_time: 94.96892619132996
Current avf_f1: 0.886179788032474, Best f1: 0.886179788032474
Epoch: 20/20, batch: 1374/1375, train_loss: 3.752542988107898e-06, t_sample_acc: 0.5953420669577875,lr: 1.6313393930156295e-11, batch_time: 1.760580062866211, total_time: 2460.476558446884 ----------------

Validating:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch: 20/20, batch: 1375/1375, train_loss: 3.7528164102282478e-06, t_sample_acc: 0.595393939408389,lr: 5e-05, batch_time: 0.8824944496154785, total_time: 2461.3783543109894 -------------


Validating: 100%|██████████| 250/250 [01:34<00:00,  2.64it/s]


Epoch: 20/20, val_sample_acc: 0.537, val_rel_prec: 0.9129232895645955, val_rel_rec: 0.8816462736373553, val_rel_f1: 0.89701222267521,
val_time: 94.56996297836304
Current avf_f1: 0.89701222267521, Best f1: 0.89701222267521
