In [1]:
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification, logging
import tensorflow as tf
import numpy as np
import re
import string
from nltk.corpus import wordnet
from tqdm import tqdm
from sklearn import metrics
import csv
from tensorflow import keras

logging.set_verbosity_error()
np.random.seed(0)

In [2]:
# parameters
MODEL = "bert-base-uncased"
X_TRAIN = '../dataset/HF-BERT_x_train.npy'
Y_TRAIN = '../dataset/HF-BERT_y_train.npy'
X_TEST = '../dataset/HF-BERT_x_test.npy'
Y_TEST = '../dataset/HF-BERT_y_test.npy'
SAVED_MODEL = "../Baseline-HF-BERT.h5"
EPOCH = 1
BATCH_SIZE = 2

In [3]:
# 前処理
def preprocessing(text):
    # 括弧内文章の削除
    text = re.sub(r'\(.*\)',' ',text)
    text = re.sub(r'\[.*\]',' ',text)
    text = re.sub(r'\<.*\>',' ',text)
    text = re.sub(r'\{.*\}',' ',text)
    # 記号文字の削除
    text = text.translate(str.maketrans('','',string.punctuation))
    # スペースの調整
    text = re.sub(r'\s+',' ',text)
    return text

In [4]:
# preprocessing train data -----------------------------------------------------------------------
# load topic class labels
print("making train dataset...")
with open('../data/topic/classes.txt','r',encoding='utf-8') as f:
    labels = f.read().splitlines()
topic_class_hypothesis = dict()
for i,label in enumerate(labels):
    topic_class_hypothesis[i] = 'this text is about ' + ' or '.join([wordnet.synsets(word)[0].definition() for word in label.split(' & ')])

# load train data
with open('../data/topic/train_pu_half_v0.txt','r',encoding='utf-8') as f:
    texts_v0 = f.read()
with open('../data/topic/train_pu_half_v1.txt','r',encoding='utf-8') as f:
    texts_v1 = f.read()
texts = texts_v0 + texts_v1

texts = texts_v0

# ## example -------------------------------------
import random
texts = texts.splitlines()
texts = random.sample(texts,1000)
texts = "\n".join(texts)
# ## ---------------------------------------------

tokenizer = AutoTokenizer.from_pretrained(MODEL)

x_train, y_train = [],[]
first, second = [],[]
for label_text in tqdm(texts.splitlines()):
    label,text = label_text.split('\t')
    rand_base = [0,1,2,3,4,5,6,7,8,9]
    rand_base.remove(int(label))
    label_rand = np.random.choice(rand_base)
    first.append(preprocessing(text))
    second.append(topic_class_hypothesis[int(label)])
    y_train.append(1)
    first.append(preprocessing(text))
    second.append(topic_class_hypothesis[int(label_rand)])
    y_train.append(0)

x_train = tokenizer(first, second, truncation=True, return_tensors="tf", padding="max_length", max_length=512)

making train dataset...


100%|██████████| 1000/1000 [00:00<00:00, 17241.42it/s]


In [5]:
y_train = np.array(y_train)

model = TFAutoModelForSequenceClassification.from_pretrained(MODEL)
model.classifier = tf.keras.layers.Dense(units=1, activation="sigmoid", name="classifier")
model.compile(optimizer=keras.optimizers.Adam(3e-5),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
              metrics=tf.keras.metrics.BinaryAccuracy())
model.fit([x_train["input_ids"],x_train["attention_mask"]], y_train, epochs=EPOCH, batch_size=BATCH_SIZE)



<keras.callbacks.History at 0x1b23ea49b40>

In [6]:
pred = model.predict([x_train["input_ids"],x_train["attention_mask"]], batch_size=BATCH_SIZE)
print(pred.logits)
y_pred = np.where(pred.logits<0.5, 0, 1)

rep = metrics.classification_report(y_train,y_pred,digits=3)
print(rep)

[[7.9871696e-01]
 [3.1703088e-04]
 [8.0799925e-01]
 ...
 [3.1677075e-04]
 [8.0644631e-01]
 [8.0094749e-01]]
              precision    recall  f1-score   support

           0      1.000     0.540     0.701      1000
           1      0.685     1.000     0.813      1000

    accuracy                          0.770      2000
   macro avg      0.842     0.770     0.757      2000
weighted avg      0.842     0.770     0.757      2000



In [9]:
# dbpedia class ------------------------------------------------------------------------------------------------------
with open('../data/dbpedia_csv/classes.txt','r',encoding='utf-8') as f:
    classes = f.read().splitlines()
    dbpedia_class = ['this text is about '+text for text in classes]

with open('../data/dbpedia_csv/test.csv','r',encoding='utf-8') as f:
    reader = [r for r in csv.reader(f)]
    
# # example -------------------
import random
reader = random.sample(reader,1000)
# #----------------------------

x_test, y_test = [],[]
first, second = [],[]
for cls_num,auth,readtext in tqdm(reader,total=len(reader)):
    for db_class in dbpedia_class:
        text = readtext.replace(auth, "")
        first.append(preprocessing(text))
        second.append(db_class)
    y_test.append(int(cls_num)-1)           

x_test = tokenizer(first, second, truncation=True, return_tensors="tf", padding="max_length", max_length=512)   

100%|██████████| 1000/1000 [00:00<00:00, 5278.91it/s]


In [10]:
pred = model.predict([x_test["input_ids"],x_test["attention_mask"]], batch_size=BATCH_SIZE)
split_pred = np.array_split(pred.logits,len(y_test))
y_pred = [np.argmax(p) for p in split_pred]

target_names = [c[:3]+"." for c in classes]
rep = metrics.classification_report(y_test,y_pred,target_names=target_names,digits=3)
print(rep)

              precision    recall  f1-score   support

        Com.      0.034     0.085     0.049        59
        Edu.      0.147     0.319     0.202        72
        Art.      0.333     0.015     0.028        68
        Ath.      0.000     0.000     0.000        66
        Off.      0.133     0.730     0.225        74
        Mea.      0.000     0.000     0.000        63
        Bui.      0.562     0.118     0.196        76
        Nat.      0.206     0.084     0.120        83
        Vil.      0.274     0.565     0.369        85
        Ani.      0.036     0.014     0.020        73
        Pla.      0.000     0.000     0.000        85
        Alb.      1.000     0.100     0.182        60
        Fil.      0.000     0.000     0.000        76
        Wri.      1.000     0.017     0.033        60

    accuracy                          0.155      1000
   macro avg      0.266     0.146     0.102      1000
weighted avg      0.251     0.155     0.106      1000



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
