In [81]:
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification, logging
from sklearn.metrics import classification_report
from sklearn.metrics.pairwise import cosine_similarity
from tensorflow import keras
import tensorflow as tf
import gensim.downloader
from tqdm import tqdm
import numpy as np
import string
import time
import csv
import re

np.random.seed(0)
logging.set_verbosity_error()

In [82]:
# parameters
MODEL = "bert-base-uncased"
THRESHOLD = 0.005
MAXLEN_GET_PSEUDO = 50
EPOCH = 1
BATCH_SIZE = 1

In [83]:
# 前処理
def preprocessing(text):
    # 括弧内文章の削除
    text = re.sub(r'\(.*\)',' ',text)
    text = re.sub(r'\[.*\]',' ',text)
    text = re.sub(r'\<.*\>',' ',text)
    text = re.sub(r'\{.*\}',' ',text)
    # 記号文字の削除
    text = text.translate(str.maketrans('','',string.punctuation))
    # スペースの調整
    text = re.sub(r'\s+',' ',text)
    return text

In [84]:
# 20 newsgroups datasets
from sklearn.datasets import fetch_20newsgroups
newsgroups = fetch_20newsgroups(subset="all")
newsgroups_datasets = list()

# example ------------------------------------------------
for texts in tqdm(newsgroups.data[:1000]):
  texts = texts.split("\n\n")
  texts = " ".join(texts[1:])
  newsgroups_datasets.append(preprocessing(texts))
# --------------------------------------------------------

# for texts in tqdm(newsgroups.data):
#   texts = texts.split("\n\n")
#   texts = " ".join(texts[1:])
#   newsgroups_datasets.append(preprocessing(texts))

100%|██████████| 1000/1000 [00:00<00:00, 15126.49it/s]


In [85]:
# yahoo topic datasets
with open('../data/topic/train_pu_half_v0.txt','r',encoding='utf-8') as f:
    texts_v0 = f.read()
with open('../data/topic/train_pu_half_v1.txt','r',encoding='utf-8') as f:
    texts_v1 = f.read()
texts = texts_v0 + texts_v1
topic_datasets = list()

# example ----------------------------------------------
for label_text in tqdm(texts.splitlines()[:1000]):
  _, text = label_text.split("\t")
  topic_datasets.append(preprocessing(text))
# -------------------------------------------------------

# for label_text in tqdm(texts.splitlines()):
#   _, text = label_text.split("\t")
#   topic_datasets.append(preprocessing(text))

100%|██████████| 1000/1000 [00:00<00:00, 46139.93it/s]


In [86]:
# reuters datasets
with open("../data/reuter/sourceall.txt", "r", encoding="utf-8") as f:
  reuter = f.read().split("\n")[:-1]

# example -----------------------------------
reuter = reuter[:1000]
# -------------------------------------------

# reuter = reuter

reuters_datasets = list()
for label_text in tqdm(reuter):
  _, text = label_text.split("\t")
  reuters_datasets.append(preprocessing(text))

100%|██████████| 1000/1000 [00:00<00:00, 24972.49it/s]


In [87]:
# dbpedia datasets train
with open('../data/dbpedia_csv/train.csv','r',encoding='utf-8') as f:
    reader = [r for r in csv.reader(f)]
    
# example -------------------
reader = reader[:1000]
#----------------------------

# reader = reader

dbpedia_train_datasets = list()
for _, auth, text in tqdm(reader):
    text = text.replace(auth,'')
    dbpedia_train_datasets.append(preprocessing(text))

100%|██████████| 1000/1000 [00:00<00:00, 62106.55it/s]


In [88]:
# dbpedia classes
with open("../data/dbpedia_csv/classes.txt", "r", encoding="utf-8") as f:
  classes = f.read().splitlines()

In [89]:
datasets_texts = newsgroups_datasets + topic_datasets + reuters_datasets + dbpedia_train_datasets

In [90]:
word2vec = gensim.downloader.load('word2vec-google-news-300')

def w2v_avg_vector(sentence):
  vector = np.zeros((300,), dtype="float32")
  count = 0
  for word in sentence.split():
    try:
      vector = np.add(vector, word2vec[word])
      count += 1
    except:
      pass
  if count > 0:
    vector = np.divide(vector, len(word))
  return vector

In [91]:
classes_vector = list()
for cls in classes:
  classes_vector.append(w2v_avg_vector(cls))

In [92]:
diff_datasets = {i:[] for i in range(len(classes))}
for texts in tqdm(datasets_texts):
  texts_vector = w2v_avg_vector(texts)
  similarity = cosine_similarity([texts_vector], classes_vector)[0]
  sim_argsorted = np.argsort(similarity)
  diff = similarity[sim_argsorted[-1]] - similarity[sim_argsorted[-2]]
  if diff > THRESHOLD:
    diff_datasets[sim_argsorted[-1]].append((similarity[sim_argsorted[-1]], texts))

pseudo_texts = list()
pseudo_labels = list()
for i in range(len(classes)):
  sorted_diff_data = sorted(diff_datasets[i], reverse=True)[:MAXLEN_GET_PSEUDO]
  pseudo_texts.extend([i[1] for i in sorted_diff_data])
  pseudo_labels.extend([i]*len(sorted_diff_data[:MAXLEN_GET_PSEUDO]))

100%|██████████| 4000/4000 [00:01<00:00, 2191.09it/s]


In [93]:
print("Number of all selected data")
for i in diff_datasets:
  print(classes[i][:3]+". : "+str(len(diff_datasets[i])))

Number of all selected data
Com. : 540
Edu. : 91
Art. : 12
Ath. : 8
Off. : 1084
Mea. : 1172
Bui. : 53
Nat. : 281
Vil. : 13
Ani. : 23
Pla. : 24
Alb. : 55
Fil. : 61
Wri. : 96


In [94]:
tokenizer = AutoTokenizer.from_pretrained(MODEL)
x_train = tokenizer(pseudo_texts, truncation=True, return_tensors="tf", padding="max_length", max_length=512)
y_train = np.array(pseudo_labels)

model = TFAutoModelForSequenceClassification.from_pretrained(MODEL)
model.classifier = tf.keras.layers.Dense(units=14, activation="softmax", name="classifer")
model.compile(optimizer=keras.optimizers.Adam(3e-5), 
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=tf.metrics.SparseCategoricalAccuracy())
model.fit(x_train["input_ids"], y_train, batch_size=BATCH_SIZE, epochs=EPOCH)

  output, from_logits = _get_logits(




<keras.callbacks.History at 0x26535043970>

In [95]:
pred = model.predict(x_train["input_ids"], batch_size=BATCH_SIZE)
y_pred = [np.argmax(i) for i in pred.logits]

target_names = [c[:3]+"." for c in classes]
rep = classification_report(y_train, y_pred, target_names=target_names, digits=3)
print(rep)

              precision    recall  f1-score   support

        Com.      0.000     0.000     0.000        50
        Edu.      0.000     0.000     0.000        50
        Art.      0.000     0.000     0.000        12
        Ath.      0.000     0.000     0.000         8
        Off.      0.333     0.380     0.355        50
        Mea.      0.214     0.240     0.226        50
        Bui.      0.000     0.000     0.000        50
        Nat.      0.000     0.000     0.000        50
        Vil.      0.000     0.000     0.000        13
        Ani.      0.000     0.000     0.000        23
        Pla.      0.000     0.000     0.000        24
        Alb.      0.449     0.880     0.595        50
        Fil.      0.147     0.940     0.255        50
        Wri.      0.000     0.000     0.000        50

    accuracy                          0.230       530
   macro avg      0.082     0.174     0.102       530
weighted avg      0.108     0.230     0.135       530



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [96]:
# load test data
# dbpedia datasets train
with open('../data/dbpedia_csv/test.csv','r',encoding='utf-8') as f:
    reader = [r for r in csv.reader(f)]
    
# example -------------------
import random
reader = random.sample(reader, 1000)
#----------------------------

# reader = reader

test_texts = list()
test_labels = list()
for labels, auth, text in tqdm(reader):
    text = text.replace(auth,'')
    test_texts.append(preprocessing(text))
    test_labels.append(int(labels)-1)

100%|██████████| 1000/1000 [00:00<00:00, 57097.99it/s]


In [97]:
x_test = tokenizer(test_texts, truncation=True, return_tensors="tf", padding="max_length", max_length=512)
y_test = np.array(test_labels)

In [98]:
pred = model.predict(x_test["input_ids"], batch_size=BATCH_SIZE)
y_pred = [np.argmax(i) for i in pred.logits]
target_names = [c[:3]+"." for c in classes]
rep = classification_report(y_test, y_pred, target_names=target_names, digits=3)
print(rep)

              precision    recall  f1-score   support

        Com.      0.000     0.000     0.000        72
        Edu.      0.000     0.000     0.000        81
        Art.      0.000     0.000     0.000        57
        Ath.      0.000     0.000     0.000        86
        Off.      0.040     0.036     0.038        83
        Mea.      0.179     0.145     0.160        69
        Bui.      0.000     0.000     0.000        71
        Nat.      0.000     0.000     0.000        69
        Vil.      0.000     0.000     0.000        71
        Ani.      0.000     0.000     0.000        71
        Pla.      0.000     0.000     0.000        68
        Alb.      0.463     0.912     0.614        68
        Fil.      0.084     0.925     0.155        67
        Wri.      0.000     0.000     0.000        67

    accuracy                          0.137      1000
   macro avg      0.055     0.144     0.069      1000
weighted avg      0.053     0.137     0.066      1000



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
