In [1]:
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification, logging
import tensorflow as tf
import numpy as np
import re
import string
from nltk.corpus import wordnet
from tqdm import tqdm
from sklearn import metrics
import csv
from tensorflow import keras

logging.set_verbosity_error()
np.random.seed(0)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# parameters
MODEL = "bert-base-uncased"
X_TRAIN = '../dataset/HF-BERT_x_train.npy'
Y_TRAIN = '../dataset/HF-BERT_y_train.npy'
X_TEST = '../dataset/HF-BERT_x_test.npy'
Y_TEST = '../dataset/HF-BERT_y_test.npy'
SAVED_MODEL = "../Baseline-HF-BERT.h5"
EPOCH = 10
BATCH_SIZE = 8

In [3]:
# 前処理
def preprocessing(text):
    # 括弧内文章の削除
    text = re.sub(r'\(.*\)',' ',text)
    text = re.sub(r'\[.*\]',' ',text)
    text = re.sub(r'\<.*\>',' ',text)
    text = re.sub(r'\{.*\}',' ',text)
    # 記号文字の削除
    text = text.translate(str.maketrans('','',string.punctuation))
    # スペースの調整
    text = re.sub(r'\s+',' ',text)
    return text

In [4]:
# preprocessing train data -----------------------------------------------------------------------
# load topic class labels
print("making train dataset...")
with open('../data/topic/classes.txt','r',encoding='utf-8') as f:
    labels = f.read().splitlines()
topic_class_hypothesis = dict()
for i,label in enumerate(labels):
    topic_class_hypothesis[i] = 'this text is about ' + ' or '.join([wordnet.synsets(word)[0].definition() for word in label.split(' & ')])

# load train data
with open('../data/topic/train_pu_half_v0.txt','r',encoding='utf-8') as f:
    texts_v0 = f.read()
with open('../data/topic/train_pu_half_v1.txt','r',encoding='utf-8') as f:
    texts_v1 = f.read()
texts = texts_v0 + texts_v1

# # ## example -------------------------------------
# import random
# texts = texts.splitlines()
# texts = random.sample(texts,10000)
# texts = "\n".join(texts)
# # ## ---------------------------------------------

tokenizer = AutoTokenizer.from_pretrained(MODEL)

x_train, y_train = [],[]
first, second = [],[]
for label_text in tqdm(texts.splitlines()):
    label,text = label_text.split('\t')
    rand_base = [0,1,2,3,4,5,6,7,8,9]
    rand_base.remove(int(label))
    label_rand = np.random.choice(rand_base)
    first.append(preprocessing(text))
    second.append(topic_class_hypothesis[int(label)])
    y_train.append(1)
    first.append(preprocessing(text))
    second.append(topic_class_hypothesis[int(label_rand)])
    y_train.append(0)

x_train = tokenizer(first, second, truncation=True, return_tensors="tf", padding="max_length", max_length=512)

making train dataset...


100%|██████████| 10000/10000 [00:00<00:00, 19575.35it/s]


In [5]:
y_train = np.array(y_train)

model = TFAutoModelForSequenceClassification.from_pretrained(MODEL)
model.classifier = tf.keras.layers.Dense(units=1, activation="sigmoid", name="classifier")
model.compile(optimizer=keras.optimizers.Adam(3e-5),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
              metrics=tf.keras.metrics.BinaryAccuracy())
model.fit([x_train["input_ids"],x_train["attention_mask"]], y_train, epochs=EPOCH, batch_size=BATCH_SIZE)

Epoch 1/2
  18/2500 [..............................] - ETA: 9:22 - loss: 0.7170 - binary_accuracy: 0.5694

KeyboardInterrupt: 

In [None]:
pred = model.predict([x_train["input_ids"],x_train["attention_mask"]], batch_size=BATCH_SIZE)
print(pred.logits)
y_pred = np.where(pred.logits<0.5, 0, 1)

rep = metrics.classification_report(y_train,y_pred,digits=3)
print(rep)

[[0.21305315]
 [0.24673192]
 [0.99761766]
 ...
 [0.00314936]
 [0.998161  ]
 [0.05664188]]
              precision    recall  f1-score   support

           0      0.928     0.980     0.954     10000
           1      0.979     0.924     0.951     10000

    accuracy                          0.952     20000
   macro avg      0.954     0.952     0.952     20000
weighted avg      0.954     0.952     0.952     20000



In [None]:
# dbpedia class ------------------------------------------------------------------------------------------------------
with open('../data/dbpedia_csv/classes.txt','r',encoding='utf-8') as f:
    classes = f.read().splitlines()
    dbpedia_class = ['this text is about '+text for text in classes]

with open('../data/dbpedia_csv/test.csv','r',encoding='utf-8') as f:
    reader = [r for r in csv.reader(f)]
    
# # # example -------------------
# import random
# reader = random.sample(reader,1000)
# # #----------------------------

x_test, y_test = [],[]
first, second = [],[]
for cls_num,auth,readtext in tqdm(reader,total=len(reader)):
    for db_class in dbpedia_class:
        text = readtext.replace(auth, "")
        first.append(preprocessing(text))
        second.append(db_class)
    y_test.append(int(cls_num)-1)           

x_test = tokenizer(first, second, truncation=True, return_tensors="tf", padding="max_length", max_length=512)   

100%|██████████| 1000/1000 [00:00<00:00, 5277.69it/s]


In [None]:
pred = model.predict([x_test["input_ids"],x_test["attention_mask"]], batch_size=BATCH_SIZE)
split_pred = np.array_split(pred.logits,len(y_test))
y_pred = [np.argmax(p) for p in split_pred]

target_names = [c[:3]+"." for c in classes]
rep = metrics.classification_report(y_test,y_pred,target_names=target_names,digits=3)
print(rep)

              precision    recall  f1-score   support

        Com.      0.329     0.641     0.435        78
        Edu.      0.249     0.880     0.388        50
        Art.      0.111     0.027     0.044        73
        Ath.      0.208     0.054     0.086        92
        Off.      0.952     0.244     0.388        82
        Mea.      0.167     0.014     0.026        71
        Bui.      0.423     0.159     0.232        69
        Nat.      0.179     0.085     0.116        82
        Vil.      0.314     0.960     0.474        75
        Ani.      0.947     0.250     0.396        72
        Pla.      0.569     0.899     0.697        69
        Alb.      0.782     0.597     0.677        72
        Fil.      0.912     0.517     0.660        60
        Wri.      0.132     0.218     0.164        55

    accuracy                          0.378      1000
   macro avg      0.448     0.396     0.342      1000
weighted avg      0.448     0.378     0.333      1000

