In [1]:
import json
import torch
from tqdm import tqdm
from torch import nn
from transformers import BertJapaneseTokenizer

In [2]:
BATCH_SIZE = 5

In [3]:
BERT = torch.load(f'../savepoint/bert-fm/bert-fm-42.pt')
BERT = BERT.module.to('cuda')

BERT_WEIGHT = BERT.classifier.weight
BERT_BIAS = BERT.classifier.bias

bert_w_lst = BERT_WEIGHT.tolist()
grand_w_lst_l = []
grand_w_lst_r = []
for i in range(BATCH_SIZE):
    grand_w_lst_l += bert_w_lst[0]
    grand_w_lst_r += bert_w_lst[1]
BERT_WEIGHTED_WEIGHT = torch.tensor([grand_w_lst_l, grand_w_lst_r]).to('cuda')

In [4]:
def classification(work):
    tensor_ids = torch.tensor(work['batch']).to('cuda')
    out = BERT(tensor_ids, output_hidden_states=True)
    
    last_hidden_state = out.hidden_states[-1]
    pooler_output = BERT.bert.pooler(last_hidden_state)
    
    average = cls_average(pooler_output)
    grand = cls_grand(pooler_output)
    
    softmax = nn.Softmax(dim=0)
    
    work_json = {
        'label': work['label'],
        'avg': average.tolist(),
        'grand': grand.tolist(),
        'avg_softmax': softmax(average).tolist(),
        'grand_softmax': softmax(grand).tolist()
    }
    return work_json

In [5]:
def cls_average(pooler_output):
    cls_sum = list(0.0 for i in range(768))
    for pool_cls in pooler_output:
        cls_sum = [a+b for a, b in zip(pool_cls, cls_sum)]
    cls_avg = torch.tensor([a/BATCH_SIZE for a in cls_sum]).to('cuda')
    return BERT.classifier(cls_avg)

In [6]:
def cls_grand(pooler_output):
    cls_append = []
    for pool_cls in pooler_output:
        cls_append += pool_cls
    cls_gnd = torch.tensor(cls_append).to('cuda')
    
    gnd_linear = nn.Linear(768*BATCH_SIZE, 2).to('cuda')
    nn.init.normal_(gnd_linear.weight, std=0.0001)
    nn.init.normal_(gnd_linear.bias, 0.01)
    gnd_linear.weight = torch.nn.Parameter(BERT_WEIGHTED_WEIGHT)
    gnd_linear.bias = BERT_BIAS
    
    return gnd_linear(cls_gnd)

In [9]:
works_path = "../tsv/first-match-scatter/42/test.json"
with open(works_path, "r") as f:
    works = json.load(f)

tokenizer = BertJapaneseTokenizer.from_pretrained("cl-tohoku/bert-base-japanese-whole-word-masking", return_tensors='pt', padding='max_length', max_length=1024)
    
iter_works = []
for work in tqdm(works[:20]):
    label = work['label']
    contents = [w['paragraph'] for w in work['contents']]
    batch = tokenizer.batch_encode_plus(contents, pad_to_max_length=True, max_length=512, truncation=True, add_special_tokens=True)
    print(batch)
    iter_works.append({'label': label, 'batch': batch['input_ids']})

100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 113.47it/s]

{'input_ids': [[2, 4153, 18, 1805, 28517, 28475, 14, 6, 3228, 5, 52, 29140, 7, 3138, 29228, 15, 16, 33, 5228, 5, 11385, 12098, 12, 9, 5410, 16577, 14, 2463, 4627, 308, 10, 5, 7, 6, 4794, 5, 284, 7, 29, 393, 9, 3741, 14, 7181, 1165, 16, 21, 10, 2972, 11, 1878, 7244, 3741, 419, 7, 6, 5423, 28564, 9, 2649, 255, 2972, 5, 393, 7, 21494, 11, 1100, 10, 73, 23813, 2375, 6, 2606, 28, 3246, 9, 1541, 28575, 75, 283, 14, 2575, 4794, 5, 284, 7, 29, 563, 7, 1276, 16, 21, 10, 5903, 9318, 14, 6, 5880, 18, 2889, 28549, 2126, 11834, 387, 11, 735, 895, 1450, 5, 5693, 11, 287, 6, 240, 559, 15883, 24607, 10, 5423, 28564, 9, 70, 11704, 1504, 199, 5, 73, 7252, 21, 2375, 11, 212, 7642, 13, 6, 1285, 28736, 7, 6063, 4605, 28614, 11, 10357, 28844, 16, 1177, 11, 856, 7, 22608, 10, 4422, 6, 2941, 223, 28781, 2992, 5, 12, 6, 73, 704, 7, 2753, 28469, 255, 1778, 9, 3083, 16, 28, 9444, 15, 16, 33, 5, 14, 1661, 14604, 75, 5062, 8415, 9, 11353, 1058, 12, 28, 80, 14, 18, 3059, 895, 6, 5903, 9318, 9, 10127, 5, 109, 5, 130




In [8]:
for work in iter_works:
    print(classification(work))

{'label': 0, 'avg': [-0.20940428972244263, 0.5997509360313416], 'grand': [-1.0465319156646729, 2.998265027999878], 'avg_softmax': [0.308070570230484, 0.6919294595718384], 'grand_softmax': [0.017211824655532837, 0.9827881455421448]}
{'label': 1, 'avg': [-0.17428520321846008, 0.5690593719482422], 'grand': [-0.8709368109703064, 2.8448076248168945], 'avg_softmax': [0.32227322459220886, 0.6777268052101135], 'grand_softmax': [0.023759083822369576, 0.9762409329414368]}
{'label': 1, 'avg': [0.11449330300092697, 0.2659098505973816], 'grand': [0.5729559063911438, 1.3290596008300781], 'avg_softmax': [0.4622180163860321, 0.5377819538116455], 'grand_softmax': [0.31949278712272644, 0.6805071830749512]}
{'label': 1, 'avg': [0.2964443564414978, 0.0020526384469121695], 'grand': [1.482710838317871, 0.009773526340723038], 'avg_softmax': [0.5730709433555603, 0.4269290268421173], 'grand_softmax': [0.8135033845901489, 0.1864965558052063]}
{'label': 1, 'avg': [0.5158535838127136, -0.28108707070350647], 'gran