In [40]:
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification, logging
from sklearn.metrics import classification_report
from sklearn.metrics.pairwise import cosine_similarity
from tensorflow import keras
import tensorflow as tf
import gensim.downloader
from tqdm import tqdm
import numpy as np
import string
import time
import csv
import re

np.random.seed(0)
logging.set_verbosity_error()

In [41]:
# parameters
MODEL = "bert-base-uncased"
THRESHOLD = 0.005
MAXLEN_GET_PSEUDO = 100
EPOCH = 1
BATCH_SIZE = 1

In [42]:
# 前処理
def preprocessing(text):
    # 括弧内文章の削除
    text = re.sub(r'\(.*\)',' ',text)
    text = re.sub(r'\[.*\]',' ',text)
    text = re.sub(r'\<.*\>',' ',text)
    text = re.sub(r'\{.*\}',' ',text)
    # 記号文字の削除
    text = text.translate(str.maketrans('','',string.punctuation))
    # スペースの調整
    text = re.sub(r'\s+',' ',text)
    return text

In [43]:
# 20 newsgroups datasets
from sklearn.datasets import fetch_20newsgroups
newsgroups = fetch_20newsgroups(subset="all")
newsgroups_datasets = list()

# example ------------------------------------------------
for texts in tqdm(newsgroups.data[:10000]):
  texts = texts.split("\n\n")
  texts = " ".join(texts[1:])
  newsgroups_datasets.append(preprocessing(texts))
# --------------------------------------------------------

# for texts in tqdm(newsgroups.data):
#   texts = texts.split("\n\n")
#   texts = " ".join(texts[1:])
#   newsgroups_datasets.append(preprocessing(texts))

100%|██████████| 10000/10000 [00:00<00:00, 15922.76it/s]


In [44]:
# yahoo topic datasets
with open('../data/topic/train_pu_half_v0.txt','r',encoding='utf-8') as f:
    texts_v0 = f.read()
with open('../data/topic/train_pu_half_v1.txt','r',encoding='utf-8') as f:
    texts_v1 = f.read()
texts = texts_v0 + texts_v1
topic_datasets = list()

# example ----------------------------------------------
for label_text in tqdm(texts.splitlines()[:10000]):
  _, text = label_text.split("\t")
  topic_datasets.append(preprocessing(text))
# -------------------------------------------------------

# for label_text in tqdm(texts.splitlines()):
#   _, text = label_text.split("\t")
#   topic_datasets.append(preprocessing(text))

100%|██████████| 10000/10000 [00:00<00:00, 51808.78it/s]


In [45]:
# reuters datasets
with open("../data/reuter/sourceall.txt", "r", encoding="utf-8") as f:
  reuter = f.read().split("\n")[:-1]

# example -----------------------------------
reuter = reuter[:10000]
# -------------------------------------------

# reuter = reuter

reuters_datasets = list()
for label_text in tqdm(reuter):
  _, text = label_text.split("\t")
  reuters_datasets.append(preprocessing(text))

100%|██████████| 10000/10000 [00:00<00:00, 23809.38it/s]


In [46]:
# dbpedia datasets train
with open('../data/dbpedia_csv/train.csv','r',encoding='utf-8') as f:
    reader = [r for r in csv.reader(f)]
    
# example -------------------
reader = reader[:10000]
#----------------------------

# reader = reader

dbpedia_train_datasets = list()
for _, auth, text in tqdm(reader):
    text = text.replace(auth,'')
    dbpedia_train_datasets.append(preprocessing(text))

100%|██████████| 10000/10000 [00:00<00:00, 51622.77it/s]


In [47]:
# dbpedia classes
with open("../data/dbpedia_csv/classes.txt", "r", encoding="utf-8") as f:
  classes = f.read().splitlines()

In [48]:
datasets_texts = newsgroups_datasets + topic_datasets + reuters_datasets + dbpedia_train_datasets

In [49]:
word2vec = gensim.downloader.load('word2vec-google-news-300')

def w2v_avg_vector(sentence):
  vector = np.zeros((300,), dtype="float32")
  count = 0
  for word in sentence.split():
    try:
      vector = np.add(vector, word2vec[word])
      count += 1
    except:
      pass
  if count > 0:
    vector = np.divide(vector, len(word))
  return vector

In [50]:
classes_vector = list()
for cls in classes:
  classes_vector.append(w2v_avg_vector(cls))

In [51]:
diff_datasets = {i:[] for i in range(len(classes))}
for texts in tqdm(datasets_texts):
  texts_vector = w2v_avg_vector(texts)
  similarity = cosine_similarity([texts_vector], classes_vector)[0]
  sim_argsorted = np.argsort(similarity)
  diff = similarity[sim_argsorted[-1]] - similarity[sim_argsorted[-2]]
  if diff > THRESHOLD:
    diff_datasets[sim_argsorted[-1]].append((diff, texts))

pseudo_texts = list()
pseudo_labels = list()
for i in range(len(classes)):
  sorted_diff_data = sorted(diff_datasets[i], reverse=True)[:MAXLEN_GET_PSEUDO]
  pseudo_texts.extend([i[1] for i in sorted_diff_data])
  pseudo_labels.extend([i]*len(sorted_diff_data[:MAXLEN_GET_PSEUDO]))

100%|██████████| 40000/40000 [00:14<00:00, 2741.37it/s]


In [52]:
print("Number of all selected data")
for i in diff_datasets:
  print(classes[i][:3]+". : "+str(len(diff_datasets[i])))

Number of all selected data
Com. : 5155
Edu. : 1112
Art. : 114
Ath. : 112
Off. : 10734
Mea. : 11927
Bui. : 493
Nat. : 2713
Vil. : 210
Ani. : 217
Pla. : 287
Alb. : 570
Fil. : 570
Wri. : 889


In [53]:
tokenizer = AutoTokenizer.from_pretrained(MODEL)
x_train = tokenizer(pseudo_texts, truncation=True, return_tensors="tf", padding="max_length", max_length=512)
y_train = np.array(pseudo_labels)

model = TFAutoModelForSequenceClassification.from_pretrained(MODEL)
model.classifier = tf.keras.layers.Dense(units=14, activation="softmax", name="classifer")
model.compile(optimizer=keras.optimizers.Adam(3e-5), 
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=tf.metrics.SparseCategoricalAccuracy())
model.fit(x_train["input_ids"], y_train, batch_size=BATCH_SIZE, epochs=EPOCH)

  output, from_logits = _get_logits(




<keras.callbacks.History at 0x268daef77c0>

In [54]:
pred = model.predict(x_train["input_ids"], batch_size=BATCH_SIZE)
y_pred = [np.argmax(i) for i in pred.logits]

target_names = [c[:3]+"." for c in classes]
rep = classification_report(y_train, y_pred, target_names=target_names, digits=3)
print(rep)

              precision    recall  f1-score   support

        Com.      0.000     0.000     0.000       100
        Edu.      0.000     0.000     0.000       100
        Art.      0.000     0.000     0.000       100
        Ath.      0.000     0.000     0.000       100
        Off.      0.000     0.000     0.000       100
        Mea.      0.000     0.000     0.000       100
        Bui.      0.000     0.000     0.000       100
        Nat.      0.000     0.000     0.000       100
        Vil.      0.000     0.000     0.000       100
        Ani.      0.000     0.000     0.000       100
        Pla.      0.000     0.000     0.000       100
        Alb.      0.071     1.000     0.133       100
        Fil.      0.000     0.000     0.000       100
        Wri.      0.000     0.000     0.000       100

    accuracy                          0.071      1400
   macro avg      0.005     0.071     0.010      1400
weighted avg      0.005     0.071     0.010      1400



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [55]:
# load test data
# dbpedia datasets train
with open('../data/dbpedia_csv/test.csv','r',encoding='utf-8') as f:
    reader = [r for r in csv.reader(f)]
    
# example -------------------
import random
reader = random.sample(reader, 1000)
#----------------------------

# reader = reader

test_texts = list()
test_labels = list()
for labels, auth, text in tqdm(reader):
    text = text.replace(auth,'')
    test_texts.append(preprocessing(text))
    test_labels.append(int(labels)-1)

100%|██████████| 1000/1000 [00:00<00:00, 55544.87it/s]


In [56]:
x_test = tokenizer(test_texts, truncation=True, return_tensors="tf", padding="max_length", max_length=512)
y_test = np.array(test_labels)

In [57]:
pred = model.predict(x_test["input_ids"], batch_size=BATCH_SIZE)
y_pred = [np.argmax(i) for i in pred.logits]
target_names = [c[:3]+"." for c in classes]
rep = classification_report(y_test, y_pred, target_names=target_names, digits=3)
print(rep)

              precision    recall  f1-score   support

        Com.      0.000     0.000     0.000        73
        Edu.      0.000     0.000     0.000        69
        Art.      0.000     0.000     0.000        77
        Ath.      0.000     0.000     0.000        75
        Off.      0.000     0.000     0.000        67
        Mea.      0.000     0.000     0.000        71
        Bui.      0.000     0.000     0.000        63
        Nat.      0.000     0.000     0.000        63
        Vil.      0.000     0.000     0.000        79
        Ani.      0.000     0.000     0.000        75
        Pla.      0.000     0.000     0.000        74
        Alb.      0.082     1.000     0.152        82
        Fil.      0.000     0.000     0.000        62
        Wri.      0.000     0.000     0.000        70

    accuracy                          0.082      1000
   macro avg      0.006     0.071     0.011      1000
weighted avg      0.007     0.082     0.012      1000



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
