In [1]:
import logging
import math
import os
import pickle
import random
import sys
from time import strftime, localtime

import numpy as np
import spacy
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from nltk.corpus import wordnet as wn
from sklearn import metrics
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel

# Seed applied by main() to random, numpy and torch.
seed = 777

# Root logger at INFO level. No handler is attached here (the addHandler
# calls in this notebook are commented out).
logger = logging.getLogger()
logger.setLevel(logging.INFO)
# logger.addHandler(logging.StreamHandler(sys.stdout))

transformers.logging.set_verbosity_error()

# Local directory holding the pretrained BERT weights/vocabulary.
pretrained_bert_name = '/hy-tmp/models/bert-base-uncased'
# Module-level tokenizer; gcn_Dataset builds its own instance, and this one is
# not referenced elsewhere in the visible code.
tokenizer = BertTokenizer.from_pretrained(pretrained_bert_name)
# Maximum number of tokens kept per text by the collate functions.
max_seq_len = 100

# Image directory and the pickled train/valid/test splits.
img_dir = '/hy-tmp/data/dataset_image'
train_file = '/hy-tmp/data/processed_train.data'
valid_file = '/hy-tmp/data/processed_valid.data'
test_file = '/hy-tmp/data/processed_test.data'

# Output locations. log_file is only referenced from a commented-out
# FileHandler line in main().
model_name = 'CM_GCN'
check_point_path = '/hy-tmp/models'
log_file = f'/root/logs/{model_name}-{strftime("%y%m%d-%H%M", localtime())}.log'
result_file = f'/root/results/{model_name}_predicts.txt'
model_checkpoint = f'{check_point_path}/best_state/{model_name}'

# Batch keys, in order, that train()/eval_() move to `device` and hand to the model.
inputs_cols = ['labels', 'box_vit', 'text_indices', 'graph']
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Layer sizes used by CM_GCN: LSTM input size, ViT feature size, number of
# output classes, and LSTM hidden size per direction.
bert_dim = 768
vit_dim = 768
polarities_dim = 2
hidden_dim = 512

# spaCy pipeline used by get_doc() for tokenisation and dependency parsing.
sp_nlp = spacy.load('en_core_web_sm')

# Directory listing of img_dir; not referenced elsewhere in the visible code.
filenames = os.listdir(img_dir)

def load_sentic_word(path='/hy-tmp/data/senticNet/senticnet_word.txt'):
    """
    Load a word-level SenticNet lexicon from a tab-separated text file.

    Every non-blank line must have the form ``word<TAB>score``; blank lines
    are skipped.

    :param path: location of the lexicon file (defaults to the path this
        notebook has always used).
    :return: dict mapping each word to its score as a float.
    :raises ValueError: if a line does not contain exactly two tab-separated
        fields or the score is not a valid float.
    """
    senticNet = {}
    # The context manager closes the file even when a malformed line raises.
    with open(path, 'r', encoding='utf-8') as fp:
        for line in fp:
            line = line.strip()
            if not line:
                continue
            word, sentic = line.split('\t')
            senticNet[word] = float(sentic)
    return senticNet

# Word -> float score lookup, loaded once when this cell runs.
senticNet = load_sentic_word()

In [2]:
def get_sentic_score(word_i, word_j, y=None):
    """
    Score the sentiment contrast between two words using the SenticNet lexicon.

    Computes ``|s_i - s_j| * y ** (-s_i * s_j)`` where ``s_i``/``s_j`` are the
    lexicon scores of the two words.

    :param word_i: first lexicon key.
    :param word_j: second lexicon key.
    :param y: base of the exponential term. When omitted, a module-level
        variable named ``y`` is used if one exists. No cell in this notebook
        defines it, so it must be supplied one way or the other.
    :return: 0 when either word is missing from the lexicon or both words are
        equal, otherwise the score described above.
    :raises NameError: if a score has to be computed but no base is available.
    """
    if word_i not in senticNet or word_j not in senticNet or word_i == word_j:
        return 0
    if y is None:
        if 'y' not in globals():
            raise NameError(
                "get_sentic_score needs the exponent base 'y': pass y=... "
                "or define a module-level 'y'"
            )
        y = globals()['y']
    score_i = senticNet[word_i]
    score_j = senticNet[word_j]
    return abs(float(score_i - score_j)) * y ** (-1 * score_i * score_j)

def get_doc(text, max_len=0):
    """
    Lower-case and strip ``text``, run it through the spaCy pipeline and
    return the parse together with its tokens.

    :param text: raw input string.
    :param max_len: accepted for compatibility; it is not used.
    :return: tuple ``(document, joined, tokens)`` where ``document`` is the
        object returned by ``sp_nlp``, ``tokens`` is the list of token
        strings and ``joined`` is those tokens joined by single spaces with
        surrounding whitespace removed.
    """
    document = sp_nlp(text.lower().strip())
    spacy_token = [str(tok) for tok in document]
    joined = ' '.join(spacy_token).strip()
    return document, joined, spacy_token

def generate_dep_graph(document):
    """
    Build a symmetric dependency adjacency matrix for a parsed document.

    The matrix has ones on the diagonal and a one in both directions for
    every (token, child) pair reported by ``token.children``.

    :param document: iterable of tokens exposing ``i`` (position) and
        ``children``.
    :return: float32 array of shape ``(n_tokens, n_tokens)``.
    """
    n_tokens = sum(1 for _ in document)
    adjacency = np.eye(n_tokens, dtype='float32')

    for head in document:
        for dependent in head.children:
            adjacency[head.i, dependent.i] = 1
            adjacency[dependent.i, head.i] = 1

    return adjacency

def generate_image_graph(document, attribute_object):
    """
    Build the token-to-image-box weight matrix.

    For every token and every (attribute, object) pair, the weight is the
    WordNet path similarity between the token's first synset and the object's
    first synset, multiplied by ``get_sentic_score`` of the token's first
    synset and the attribute's first synset. Entries stay 0 when the token,
    the attribute or the object has no synset, or when WordNet reports no
    path similarity.

    :param document: iterable of tokens; ``str(token)`` is looked up in WordNet.
    :param attribute_object: sequence of exactly 10 ``(attribute, object)``
        string pairs.
    :return: float32 array of shape ``(n_tokens, 10)``.
    """
    spacy_token = [str(x) for x in document]
    spacy_len = len(spacy_token)
    assert len(attribute_object) == 10

    graph = np.zeros((spacy_len, 10)).astype('float32')

    # The attribute/object synsets do not depend on the token, so resolve
    # them once instead of once per token.
    box_synsets = []
    for attr, obj in attribute_object:
        sattr = wn.synsets(attr)
        sobj = wn.synsets(obj)
        if len(sattr) == 0 or len(sobj) == 0:
            box_synsets.append(None)
        else:
            box_synsets.append((sattr[0], sobj[0]))

    for i, token_i in enumerate(spacy_token):
        si = wn.synsets(token_i)
        if len(si) == 0:
            continue
        si = si[0]
        for cur, pair in enumerate(box_synsets):
            if pair is None:
                continue
            sattr, sobj = pair
            similarity = wn.path_similarity(si, sobj)
            if similarity is None:
                # No path between the synsets: leave the weight at 0 rather
                # than failing on ``None * score``.
                continue
            # NOTE(review): `si`/`sattr` are WordNet synsets, while the
            # lexicon used by get_sentic_score is keyed by the word strings
            # read from the SenticNet file, so this lookup presumably never
            # matches and the product is always 0. Left as-is because the
            # saved checkpoint was produced with this behaviour - confirm
            # whether `token_i`/`attr` should be passed instead.
            graph[i][cur] = similarity * get_sentic_score(si, sattr)

    return graph

def pad_and_truncate(sequence, maxlen, dtype='int64', padding='post', truncating='post', value=0):
    """
    Return ``sequence`` cut or padded to exactly ``maxlen`` entries.

    :param sequence: list or 1-D array of numbers.
    :param maxlen: length of the returned array.
    :param dtype: numpy dtype of the result.
    :param padding: ``'post'`` writes the data at the start and pads at the
        end; any other value pads at the start.
    :param truncating: ``'pre'`` keeps the last ``maxlen`` items; any other
        value keeps the first ``maxlen`` items.
    :param value: fill value for padded positions.
    :return: numpy array of shape ``(maxlen,)``.
    """
    x = (np.ones(maxlen) * value).astype(dtype)
    if truncating == 'pre':
        # An explicit start index avoids ``sequence[-0:]``, which would keep
        # the whole sequence when maxlen is 0.
        trunc = sequence[max(len(sequence) - maxlen, 0):]
    else:
        trunc = sequence[:maxlen]
    trunc = np.asarray(trunc, dtype=dtype)
    if len(trunc) == 0:
        # Nothing to copy; ``x[-0:] = trunc`` would try to overwrite all of x.
        return x
    if padding == 'post':
        x[:len(trunc)] = trunc
    else:
        x[-len(trunc):] = trunc
    return x

class gcn_Dataset(Dataset):
    """
    Dataset over a pickled dict of samples, plus three batch collate functions.

    Each stored item is a dict with the keys ``img_id``, ``label`` (int),
    ``box_vit`` (list of numpy arrays), ``text``, ``text_in_img`` and
    ``attribute_objects``.

    Collate functions:
      * ``my_collate_fn``  - BERT tokenizer encoding, no graph.
      * ``my_collate_fn2`` - spaCy tokens mapped to ids, plus the combined
        text/image adjacency matrix.
      * ``my_collate_fn3`` - spaCy tokens mapped to ids, no graph.
    """

    def __init__(self, data_file):
        """
        :param data_file: path of a pickle file holding a dict whose values
            provide ``id``, ``label``, ``text``, ``attribute_objects``,
            ``text_in_img`` and ``box_vit`` (items exposing ``.numpy()``).
        """
        self.tokenizer = BertTokenizer.from_pretrained(pretrained_bert_name)
        # NOTE: unpickling can execute arbitrary code - only load trusted files.
        with open(data_file, 'rb') as fin:
            data = pickle.load(fin)

        print("{}.data".format(data_file))
        all_data = []
        for key, value in data.items():
            all_data.append({
                'img_id': value['id'],
                'label': int(value['label']),
                'box_vit': [x.numpy() for x in value["box_vit"]],
                'text': value['text'],
                'text_in_img': value['text_in_img'],
                'attribute_objects': value['attribute_objects'],
            })
        self.data = all_data

    def text_to_indices(self, text, text_pair=None):
        """
        Encode pre-split token lists with the BERT tokenizer.

        :param text: batch of token lists.
        :param text_pair: optional second batch of token lists, encoded as
            the second segment.
        :return: the tokenizer's encoding, padded/truncated to ``max_seq_len``
            and returned as numpy arrays.
        """
        truncation = True if text_pair is None else 'longest_first'
        return self.tokenizer(
            text,                            # Sentence(s) to encode.
            text_pair=text_pair,             # Optional second segment.
            add_special_tokens=True,         # Add '[CLS]' and '[SEP]'
            padding='max_length',
            truncation=truncation,
            max_length=max_seq_len,          # Pad & truncate all sentences.
            return_attention_mask=True,      # Construct attn. masks.
            return_tensors='np',             # Return numpy arrays.
            return_length=True,
            is_split_into_words=True,
        )

    def _encode_tokens(self, tokens):
        """
        Wrap ``tokens`` (cut to ``max_seq_len``) in [CLS]/[SEP], convert them
        to vocabulary ids and zero-pad to ``max_seq_len + 2`` entries.
        """
        wrapped = ["[CLS]"] + tokens[:max_seq_len] + ["[SEP]"]
        padded = np.zeros(max_seq_len + 2).astype('int64')
        ids = self.tokenizer.convert_tokens_to_ids(wrapped)
        padded[:len(ids)] = ids
        return padded

    def my_collate_fn(self, data):
        """Collate with the BERT tokenizer; no graph is returned."""
        b_img_id = []
        b_label = []
        b_box_vit = []
        b_text_indices = []
        b_text_in_img_indices = []

        for item in data:
            b_img_id.append(item['img_id'])
            b_label.append(item['label'])
            b_box_vit.append(item['box_vit'])

            _, _, text_token = get_doc(item['text'])
            _, _, text_in_img_token = get_doc(item['text_in_img'])

            # Empty texts are replaced by a single empty token.
            b_text_indices.append(text_token or [''])
            b_text_in_img_indices.append(text_in_img_token or [''])

        text_encoded_dict = self.text_to_indices(b_text_indices)
        text_in_img_encoded_dict = self.text_to_indices(b_text_in_img_indices)
        text_merge_encoded_dict = self.text_to_indices(b_text_indices, b_text_in_img_indices)

        return {
                    'img_ids': b_img_id,
                    'labels': torch.tensor(b_label),
                    'box_vit':torch.tensor(np.array(b_box_vit)),
                    'text_indices':torch.tensor(text_encoded_dict.input_ids),
                    'text_in_img_indices':torch.tensor(text_in_img_encoded_dict.input_ids),
                    'text_merge_indices':torch.tensor(text_merge_encoded_dict.input_ids),
                    }

    def my_collate_fn2(self, data):
        """
        Collate with spaCy tokens and build the per-sample adjacency matrix.

        The returned ``graph`` has shape
        ``(batch, max_text_len + n_boxes, max_text_len + n_boxes)`` where
        ``max_text_len`` is the longest wrapped text in the batch. It holds
        the dependency graph of the text (with [CLS]/[SEP] as self-connected
        nodes), ones on the whole diagonal, and ``image weight + 1`` between
        every text position and every image box.
        """
        b_img_id = []
        b_label = []
        b_box_vit = []
        b_text_indices = []
        b_text_in_img_indices = []
        b_text_merge_indices = []
        b_dep_graph = []
        b_img_graph = []
        b_graph = []

        for item in data:
            b_img_id.append(item['img_id'])
            b_label.append(item['label'])
            b_box_vit.append(item['box_vit'])

            text_doc, _, text_token = get_doc(item['text'])
            _, _, text_in_img_token = get_doc(item['text_in_img'])

            # Merge the raw tokens before adding the special tokens so the
            # merged sequence is "[CLS] text [SEP] text_in_img [SEP]" without
            # duplicated [CLS]/[SEP] markers.
            text_merge_token = text_token + ["[SEP]"] + text_in_img_token

            b_text_indices.append(self._encode_tokens(text_token))
            b_text_in_img_indices.append(self._encode_tokens(text_in_img_token))
            b_text_merge_indices.append(self._encode_tokens(text_merge_token))

            b_dep_graph.append(generate_dep_graph(text_doc))
            b_img_graph.append(generate_image_graph(text_doc, item['attribute_objects']))

        # Id 0 is the padding written by _encode_tokens, so the non-zero count
        # is the wrapped length of each text.
        text_len = np.count_nonzero(b_text_indices, axis=-1)
        max_text_len = np.max(text_len)

        for s in range(len(data)):
            n_text = text_len[s]
            dep_graph = b_dep_graph[s][:max_seq_len, :max_seq_len]
            img_graph = b_img_graph[s][:max_seq_len, :]
            n_boxes = img_graph.shape[1]

            # Add rows/columns for [CLS] and [SEP] and give them self-loops.
            dep_graph = np.pad(dep_graph, ((1,1), (1,1)), 'constant')
            dep_graph[0][0] = 1
            dep_graph[-1][-1] = 1
            assert n_text == np.size(dep_graph,0)
            img_graph = np.pad(img_graph, ((1,1), (0,0)), 'constant')

            # Pad the text part to the batch maximum, then append the image nodes.
            extra = max_text_len - n_text
            dep_graph = np.pad(dep_graph, ((0,extra), (0,extra)), 'constant')
            graph = np.pad(dep_graph, ((0,n_boxes), (0,n_boxes)), 'constant')
            np.fill_diagonal(graph, 1)

            # NOTE(review): the image nodes are written at offset
            # text_len[s], whereas CM_GCN.forward concatenates the image
            # features after the text features padded to the batch maximum.
            # For samples shorter than the longest one these two layouts do
            # not appear to line up. Kept unchanged because the saved
            # checkpoint was trained with this layout - confirm before
            # retraining.
            cross = img_graph[:n_text] + 1
            graph[:n_text, n_text:n_text + n_boxes] = cross
            graph[n_text:n_text + n_boxes, :n_text] = cross.T
            b_graph.append(graph)

        return {
                    'img_ids': b_img_id,
                    'labels': torch.tensor(b_label),
                    'box_vit':torch.tensor(np.array(b_box_vit)),
                    'text_indices':torch.tensor(np.array(b_text_indices)),
                    'text_in_img_indices':torch.tensor(np.array(b_text_in_img_indices)),
                    'text_merge_indices':torch.tensor(np.array(b_text_merge_indices)),
                    'graph':torch.tensor(np.array(b_graph)),
                    }

    def my_collate_fn3(self, data):
        """
        Collate with spaCy tokens; no graph is returned.

        The merged sequence here is ordered text-in-image first, then the text.
        """
        b_img_id = []
        b_label = []
        b_box_vit = []
        b_text_indices = []
        b_text_in_img_indices = []
        b_text_merge_indices = []

        for item in data:
            b_img_id.append(item['img_id'])
            b_label.append(item['label'])
            b_box_vit.append(item['box_vit'])

            _, _, text_token = get_doc(item['text'])
            _, _, text_in_img_token = get_doc(item['text_in_img'])

            text_merge_token = text_in_img_token + ["[SEP]"] + text_token

            b_text_indices.append(self._encode_tokens(text_token))
            b_text_in_img_indices.append(self._encode_tokens(text_in_img_token))
            b_text_merge_indices.append(self._encode_tokens(text_merge_token))

        return {
                    'img_ids': b_img_id,
                    'labels': torch.tensor(b_label),
                    'box_vit':torch.tensor(np.array(b_box_vit)),
                    'text_indices':torch.tensor(np.array(b_text_indices)),
                    'text_in_img_indices':torch.tensor(np.array(b_text_in_img_indices)),
                    'text_merge_indices':torch.tensor(np.array(b_text_merge_indices)),
                    }

    def __getitem__(self, index):
        """Return the stored sample dict at ``index``."""
        return self.data[index]

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.data)

In [3]:
# Load the three pickled splits.
train_dataset = gcn_Dataset(data_file=train_file)
valid_dataset = gcn_Dataset(data_file=valid_file)
test_dataset = gcn_Dataset(data_file=test_file)

# Only the training split is shuffled; every loader collates with the
# graph-building my_collate_fn2 of its own dataset.
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=train_dataset.my_collate_fn2,
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=32,
    shuffle=False,
    collate_fn=valid_dataset.my_collate_fn2,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
    collate_fn=test_dataset.my_collate_fn2,
)

print(len(train_dataset), len(valid_dataset), len(test_dataset))

/hy-tmp/data/processed_train.data.data
/hy-tmp/data/processed_valid.data.data
/hy-tmp/data/processed_test.data.data
19816 2410 2409


In [4]:
class DynamicLSTM(nn.Module):
    """
    Recurrent layer (LSTM, GRU or RNN) that takes a padded batch together
    with the true length of every sequence.

    Sequences are sorted by length, packed, run through the recurrent layer,
    unpacked and restored to their original order.
    """

    def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=True, dropout=0,
                 bidirectional=False, only_use_last_hidden_state=False, rnn_type = 'LSTM'):
        """
        LSTM which can hold variable length sequence, use like TensorFlow's RNN(input, length...).

        :param input_size:The number of expected features in the input x
        :param hidden_size:The number of features in the hidden state h
        :param num_layers:Number of recurrent layers.
        :param bias:If False, then the layer does not use bias weights b_ih and b_hh. Default: True
        :param batch_first:If True, then the input and output tensors are provided as (batch, seq, feature)
        :param dropout:If non-zero, introduces a dropout layer on the outputs of each RNN layer except the last layer
        :param bidirectional:If True, becomes a bidirectional RNN. Default: False
        :param only_use_last_hidden_state:If True, forward() returns only the final hidden state
        :param rnn_type: {LSTM, GRU, RNN}
        :raises ValueError: if rnn_type is not one of the supported names
        """
        super(DynamicLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = dropout
        self.bidirectional = bidirectional
        self.only_use_last_hidden_state = only_use_last_hidden_state
        self.rnn_type = rnn_type

        rnn_classes = {'LSTM': nn.LSTM, 'GRU': nn.GRU, 'RNN': nn.RNN}
        if rnn_type not in rnn_classes:
            # Fail at construction time instead of leaving self.RNN undefined.
            raise ValueError(
                "rnn_type must be one of {}, got {!r}".format(sorted(rnn_classes), rnn_type))
        self.RNN = rnn_classes[rnn_type](
            input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
            bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional)

    @staticmethod
    def _reorder_batch(tensor, index, batch_dim):
        """Reorder ``tensor`` along its batch dimension (0 or 1) by ``index``."""
        return tensor[index] if batch_dim == 0 else tensor[:, index]

    def forward(self, x, x_len, h0=None):
        """
        sequence -> sort -> pad and pack ->process using RNN -> unpack ->unsort

        :param x: sequence embedding vectors, laid out according to batch_first
        :param x_len: 1-D tensor with the length of every sequence in the batch
        :param h0: optional initial hidden state shaped (layers * directions,
            batch, hidden), given in the same batch order as ``x``. For the
            LSTM it is used as both the initial hidden and cell state.
        :return: the final hidden state ``ht`` if only_use_last_hidden_state
            is set, otherwise ``(out, (ht, ct))`` where ``out`` is padded to
            the longest sequence in the batch and ``ct`` is None for GRU/RNN.
        """
        batch_dim = 0 if self.batch_first else 1

        # sort by decreasing length, as required for packing
        x_sort_idx = torch.argsort(-x_len)
        x_unsort_idx = torch.argsort(x_sort_idx).long()
        x_len = x_len[x_sort_idx]
        x = self._reorder_batch(x, x_sort_idx.long(), batch_dim)
        if h0 is not None:
            # The initial state must follow the same permutation as the batch.
            h0 = h0[:, x_sort_idx.long()].contiguous()

        # pack
        x_emb_p = torch.nn.utils.rnn.pack_padded_sequence(x, x_len, batch_first=self.batch_first)

        if self.rnn_type == 'LSTM':
            initial_state = None if h0 is None else (h0, h0)
            out_pack, (ht, ct) = self.RNN(x_emb_p, initial_state)
        else:
            out_pack, ht = self.RNN(x_emb_p, h0)
            ct = None

        # unsort: h (batch is dimension 1 of the state tensors)
        ht = ht[:, x_unsort_idx]

        if self.only_use_last_hidden_state:
            return ht

        # unpack: out
        out = torch.nn.utils.rnn.pad_packed_sequence(out_pack, batch_first=self.batch_first)[0]
        out = self._reorder_batch(out, x_unsort_idx, batch_dim)

        # unsort: c
        if ct is not None:
            ct = ct[:, x_unsort_idx]

        return out, (ht, ct)

class GraphConvolution(nn.Module):
    """
    Graph convolution layer: ``adj @ (text @ weight) / (row_sum(adj) + 1) + bias``.

    Parameters are initialised uniformly in ``[-1/sqrt(out_features),
    1/sqrt(out_features)]`` at construction, so the layer holds defined
    values even when no external initialiser or checkpoint is applied.
    """

    def __init__(self, in_features, out_features, bias=True):
        """
        :param in_features: size of each input node feature.
        :param out_features: size of each output node feature.
        :param bias: if False, no bias term is added.
        """
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(in_features, out_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Fill weight (and bias, if present) with small uniform values."""
        stdv = 1.0 / (self.out_features ** 0.5)
        nn.init.uniform_(self.weight, -stdv, stdv)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -stdv, stdv)

    def forward(self, text, adj):
        """
        :param text: node features; the last dimension must be in_features.
        :param adj: adjacency weights, multiplied with the projected features
            and summed over dim 2 for the normaliser.
        :return: normalised aggregated features (plus bias if enabled).
        """
        hidden = torch.matmul(text, self.weight)
        # +1 keeps the denominator non-zero for rows without any edge.
        denom = torch.sum(adj, dim=2, keepdim=True) + 1
        output = torch.matmul(adj, hidden.float()) / denom
        if self.bias is not None:
            return output + self.bias
        else:
            return output

class CM_GCN(nn.Module):
    """
    Text/image classifier: BERT -> bidirectional LSTM for the text, a linear
    projection for the image box features, two graph convolution layers over
    the concatenated nodes, attention pooling and a final linear classifier.
    """

    def __init__(self, pretrained_bert_name):
        """
        :param pretrained_bert_name: name or path passed to
            ``BertModel.from_pretrained``.
        """
        super(CM_GCN, self).__init__()
        self.bert = BertModel.from_pretrained(pretrained_bert_name)
        self.lstm1 = DynamicLSTM(bert_dim, hidden_dim, num_layers=1, batch_first=True, bidirectional=True)

        # Project image box features to the width of the BiLSTM output.
        self.vit_fc = nn.Linear(vit_dim, 2*hidden_dim)

        self.gc1 = GraphConvolution(2*hidden_dim, 2*hidden_dim)
        self.gc2 = GraphConvolution(2*hidden_dim, 2*hidden_dim)

        self.fc = nn.Linear(2*hidden_dim, polarities_dim)

    def forward(self, inputs):
        """
        :param inputs: sequence ``(labels, box_vit, text_indices, graph)``;
            ``labels`` is not used here.
        :return: class logits, one row per sample.
        """
        _, box_vit, text_indices, graph = inputs
        # Id 0 is treated as padding when measuring the text length.
        text_len = torch.sum(text_indices != 0, dim=-1)
        # NOTE(review): no attention_mask is passed, so BERT presumably also
        # attends to the padded positions - confirm this is intended.
        text_out = self.bert(text_indices,  output_hidden_states=False)
        text_out, (_, _) = self.lstm1(text_out.last_hidden_state, text_len.cpu())

        box_vit = self.vit_fc(box_vit)
        # Text nodes first, image nodes appended along the sequence dimension.
        features = torch.cat([text_out,box_vit],dim = 1)

        x = F.relu(self.gc1(features, graph))
        x = F.relu(self.gc2(x, graph))

        # Attention pooling: score every graph node against all input nodes,
        # normalise with softmax and take the weighted sum of node states.
        alpha_mat = torch.matmul(features, x.transpose(1, 2))
        alpha = F.softmax(alpha_mat.sum(1, keepdim=True), dim=2)
        x = torch.matmul(alpha, x).squeeze(1)
        output = self.fc(x)
        return output

    def reset_params(self):
        """
        Re-initialise every trainable parameter outside the BERT encoder:
        Xavier-uniform for tensors with more than one dimension, uniform in
        ``[-1/sqrt(n), 1/sqrt(n)]`` for one-dimensional tensors of length n.
        """
        for child in self.children():
            if child is not self.bert:
                for p in child.parameters():
                    if p.requires_grad:
                        if len(p.shape) > 1:
                            torch.nn.init.xavier_uniform_(p)
                        else:
                            stdv = 1. / math.sqrt(p.shape[0])
                            torch.nn.init.uniform_(p, a=-stdv, b=stdv)


In [5]:
def eval_(model, data_loader, save_path=None):
    """
    Evaluate ``model`` on every batch of ``data_loader``.

    :param model: module called with the list of batch tensors selected by
        ``inputs_cols``; must return one row of class scores per sample.
    :param data_loader: iterable of batch dicts containing the
        ``inputs_cols`` keys plus ``labels`` and ``img_ids``.
    :param save_path: optional file path; when given, one line per sample is
        written as ``img_id predicted_label true_label scores``.
    :return: tuple ``(accuracy, f1, precision, recall)``.
    :raises ValueError: if ``data_loader`` yields no batch.
    """
    model.eval()

    targets_per_batch, outputs_per_batch = [], []
    img_ids_all = []  # fresh list, so the batches' own id lists are never modified

    with torch.no_grad():
        for t_batch in data_loader:
            t_inputs = [t_batch[col].to(device)   for col in inputs_cols]
            t_targets = t_batch['labels'].to(device)

            t_outputs = model(t_inputs)

            targets_per_batch.append(t_targets.cpu())
            outputs_per_batch.append(t_outputs.cpu())
            img_ids_all.extend(t_batch['img_ids'])

    if not outputs_per_batch:
        raise ValueError('eval_ received an empty data_loader; nothing to evaluate')

    # Concatenate once instead of growing a tensor batch by batch.
    t_targets_all = torch.cat(targets_per_batch, dim=0)
    t_outputs_all = torch.cat(outputs_per_batch, dim=0)
    t_predicts_all = torch.argmax(t_outputs_all, -1)

    if save_path:
        with open(save_path,'w',encoding='utf-8') as fout:
            predicts_all = t_predicts_all.numpy().tolist()
            labels_all = t_targets_all.numpy().tolist()
            outputs_all = t_outputs_all.numpy().tolist()
            assert len(img_ids_all) == len(predicts_all) == len(labels_all) == len(outputs_all)

            for img_id, predict, label, output in zip(img_ids_all, predicts_all, labels_all, outputs_all):
                fout.write(f'{str(img_id)} {str(predict)} {str(label)} {str(output)} \n')

    n_correct = (t_predicts_all == t_targets_all).sum().item()
    n_total = len(t_outputs_all)
    acc = n_correct / n_total
    f1 = metrics.f1_score(t_targets_all, t_predicts_all)
    precision =  metrics.precision_score(t_targets_all, t_predicts_all)
    recall = metrics.recall_score(t_targets_all, t_predicts_all)
    return acc, f1 ,precision, recall

def train(model, train_data_loader, val_data_loader, test_data_loader):
    """
    Train ``model``, keep the checkpoint with the best validation accuracy,
    reload it and evaluate on the test loader.

    The BERT encoder is optimised with learning rate 2e-5; every other
    parameter of the model uses 1e-3. Validation runs every 20 optimisation
    steps, and training stops once three epochs have passed without a new
    best validation accuracy (at most 100 epochs).

    :param model: module exposing ``bert`` and ``reset_params()``.
    :param train_data_loader: iterable of training batch dicts.
    :param val_data_loader: loader passed to ``eval_`` for validation.
    :param test_data_loader: loader passed to ``eval_`` for the final test.
    :return: tuple ``(test_acc, test_f1, test_precision, test_recall)``.
    """
    criterion = nn.CrossEntropyLoss()

    # Every parameter must belong to a group, otherwise it never receives an
    # update: BERT gets the small learning rate, all remaining layers 1e-3.
    bert_params = list(model.bert.parameters())
    bert_param_ids = {id(p) for p in bert_params}
    other_params = [p for p in model.parameters() if id(p) not in bert_param_ids]
    optimizer = torch.optim.Adam([{'params':bert_params,'lr':2e-5},
                            {'params':other_params,'lr':1e-3} ], lr=1e-3, weight_decay=1e-5)
    global_step = 0
    max_val_acc = 0
    max_val_f1 = 0
    max_val_epoch = 0
    checkpoint_saved = False

    checkpoint_dir = os.path.dirname(model_checkpoint)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    model.reset_params()

    for i_epoch in range(100):
        logger.info('>' * 100)
        logger.info('epoch: {}'.format(i_epoch))
        n_correct, n_total, loss_total = 0, 0, 0

        for i_batch, batch in enumerate(train_data_loader):
            # eval_() switches the model to eval mode, so re-enable training
            # mode before every step.
            model.train()
            global_step += 1

            inputs = [batch[col].to(device)   for col in inputs_cols]
            outputs = model(inputs)
            targets = batch['labels'].to(device)

            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            n_correct += (torch.argmax(outputs, -1) == targets).sum().item()
            n_total += len(outputs)
            loss_total += loss.item() * len(outputs)

            train_acc = n_correct / n_total
            train_loss = loss_total / n_total
            logger.info('loss: {:.4f}, acc: {:.4f}'.format(train_loss, train_acc))

            if global_step % 20 == 0:
                val_acc, val_f1,val_precision,val_recall = eval_(model, val_data_loader)
                logger.info('> max_val_f1: {:.4f}, max_val_acc: {:.4f}'.format(max_val_f1,max_val_acc))
                logger.info('> val_acc: {:.4f}, val_f1: {:.4f}, val_precision: {:.4f}, val_recall: {:.4f}'.format(val_acc,val_f1,val_precision,val_recall))

                if val_acc > max_val_acc:
                    max_val_f1 = val_f1
                    max_val_acc = val_acc
                    max_val_epoch = i_epoch

                    torch.save(model.state_dict(), model_checkpoint)
                    checkpoint_saved = True
                    logger.info(f'>> saved: {model_checkpoint}')

        # The checkpoint is only written on a new best validation accuracy,
        # so the file reloaded below always holds the best state.
        if i_epoch - max_val_epoch >= 3:
            logger.info('>> early stop.')
            break

    if not checkpoint_saved:
        # Validation never improved (or never ran): fall back to the current
        # weights so the reload below has a file to read.
        torch.save(model.state_dict(), model_checkpoint)

    model.load_state_dict(torch.load(model_checkpoint, map_location=device))
    model = model.to(device)

    test_acc, test_f1,test_precision,test_recall = eval_(model, test_data_loader, save_path=result_file)

    logger.info(f"{test_acc} {test_f1} {test_precision} {test_recall}")

    return (test_acc, test_f1,test_precision,test_recall)

In [6]:
def main():
    """
    Seed the random number generators, build CM_GCN, load the saved
    checkpoint and print the test metrics returned by ``eval_``.

    Training is disabled here (the ``train`` call is commented out); the
    function only evaluates an existing checkpoint.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # NOTE(review): this is set after the interpreter has started, so it
    # presumably only affects subprocesses - confirm if hash determinism of
    # the current kernel matters.
    os.environ['PYTHONHASHSEED'] = str(seed)

    # logger.addHandler(logging.FileHandler(log_file))

    model = CM_GCN(pretrained_bert_name).to(device)

    # train(model, train_loader, valid_loader, test_loader)

    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    model.load_state_dict(torch.load(model_checkpoint, map_location=device))
    model = model.to(device)
    print(eval_(model, test_loader, save_path=result_file))
    
# Run the evaluation entry point when this cell executes.
main()

(0.8298048982980489, 0.7945891783567134, 0.7647058823529411, 0.8269030239833159)
