In [6]:
import json

from tokenizers.implementations import BertWordPieceTokenizer
from Bert import Bert
from layers import Train
import torch
from torch.nn import Linear
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# WordPiece tokenizer built from the project's custom vocabulary file.
_tokenizer = BertWordPieceTokenizer("../custom/vocab.txt")

# Model / run settings. `_out_dim`, `max_epoch`, `batch` and `vocab_size` are
# not referenced by the cells visible in this notebook; they are kept so that
# any other cell relying on them keeps working.
_embedding_dim = 384
_hidden_size = 3072
_num_head = 12
_out_dim = 512
max_epoch = 17
batch = 140
_num_layers = 12
vocab_size = _tokenizer.get_vocab_size()

# NOTE(review): the literal 128 is Bert's fourth positional argument --
# presumably the maximum sequence length; confirm against Bert's signature
# and consider promoting it to a named constant.
bert = Bert(_embedding_dim, _hidden_size, _num_head, 128, _num_layers, _tokenizer)

# `map_location` remaps the stored tensors onto `device`, so a checkpoint that
# was saved from a GPU still loads on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
bert.load_state_dict(
    torch.load("../bert_impl_weights/down_stream_bert_emo.pth", map_location=device)
)
# Switch to inference mode (turns off dropout and similar training-only layers).
bert.eval()

In [7]:
from torch.nn.utils.rnn import pad_sequence

# `label` maps an emotion name to its integer class id (as stored in the JSON
# file); `id_label` is the reverse lookup used to turn predicted ids into names.
with open("../bert_impl_data/emo_class.json", "r", encoding="utf-8") as fp:
    label = json.load(fp)
id_label = {class_id: name for name, class_id in label.items()}

In [8]:
# Load the evaluation split. Each non-empty line has the form
# "<sentence>;<emotion>"; the emotion name is mapped to its class id via `label`.
with open("../bert_impl_data/emo_test.txt", "r", encoding="utf-8") as fp:
    data = fp.readlines()

sentence_val = []
labels_val = []
for line_no, raw_line in enumerate(data, start=1):
    line = raw_line.rstrip("\n")
    if not line.strip():
        # Blank lines (e.g. a trailing newline at end of file) carry no sample.
        continue
    # Split on the LAST ';' so a sentence that itself contains ';' stays intact.
    text, sep, emotion = line.rpartition(';')
    if not sep:
        raise ValueError(
            f"emo_test.txt line {line_no}: expected '<sentence>;<label>', got {line!r}"
        )
    sentence_val.append(text)
    # An unknown emotion name raises KeyError here, as before.
    labels_val.append(label[emotion.strip()])

# Tokenize every sentence into a 1-D tensor of token ids on the target device.
sentence_val = [
    torch.tensor(_tokenizer.encode(sent).ids, dtype=torch.long, device=device)
    for sent in sentence_val
]
# Longest tokenized sentence in the split (0 when the file has no samples).
max_len = max((len(ids) for ids in sentence_val), default=0)

# Right-pad to a (num_sentences, max_len) batch. pad_sequence fills with 0,
# which is the value the inference cell treats as padding (`questions != 0`).
sentence_val = pad_sequence(sentence_val, batch_first=True)
labels_val = torch.tensor(labels_val)
sentence_val

tensor([[    2,  2399, 10133,  ...,     0,     0,     0],
        [    2,  2399,  5013,  ...,     0,     0,     0],
        [    2,    50,  3336,  ...,     0,     0,     0],
        ...,
        [    2,    50,  5271,  ...,     0,     0,     0],
        [    2,  2399, 10133,  ...,     0,     0,     0],
        [    2,    50,  5271,  ...,     0,     0,     0]], device='cuda:0')

In [9]:
# Classify the first ten evaluation sentences and build human-readable strings
# of the form "<decoded sentence>. [predict: <emotion>]" in `ls`.
with torch.no_grad():
    # Classification head: one output per emotion class. The input width is
    # taken from `_embedding_dim` so it cannot drift from the encoder config.
    layers = Linear(_embedding_dim, len(label), device=bert.device)
    # `map_location` lets a GPU-saved checkpoint load on a CPU-only machine.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    layers.load_state_dict(
        torch.load("../bert_impl_weights/down_stream_bert_emo_layer.pth", map_location=device)
    )

    questions = sentence_val[:10]
    # Second argument is a boolean mask that is False on padding (token id 0).
    # It is created on the same device as `questions`, so no transfer is needed.
    score = bert.forward(questions, questions != 0)
    # Apply the head, then keep only position 0 of every sequence.
    score = layers(score)[:, 0, :]

    # Softmax is monotonic, so the arg-max of the raw scores already gives the
    # predicted class; no need to normalise first.
    predict = torch.argmax(score, dim=-1).tolist()
    predict = [id_label[class_id] for class_id in predict]

    ls = [
        _tokenizer.decode(token_ids) + f". [predict: {emotion}]"
        for token_ids, emotion in zip(questions.tolist(), predict)
    ]

In [10]:
# Display the decoded sentences together with their predicted emotion labels.
ls

['im feeling rather rotten so im not very ambitious right now. [predict: sadness]',
 'im updating my blog because i feel shitty. [predict: sadness]',
 'i never make her separate from me because i don t ever want her to feel like i m ashamed with her. [predict: sadness]',
 'i left with my bouquet of red and yellow tulips under my arm feeling slightly more optimistic than when i arrived. [predict: joy]',
 'i was feeling a little vain when i did this one. [predict: sadness]',
 "i cant walk into a shop anywhere where i don't feel uncomfortable. [predict: fear]",
 'i felt anger when at the end of a telephone call. [predict: fear]',
 'i explain why i clung to a relationship with a boy who was in many ways immature and uncommitted despite the excitement i should have been feeling for getting accepted into the masters program at the university of virginia. [predict: love]',
 'i like to have the same breathless feeling as a reader eager to see what will happen next. [predict: joy]',
 'i jest i 