* Binary Relevance Learning.  
* See http://www.aic.uniovi.es/~jdiez/Jorge_Diez/Journal_Papers_files/luaces2012a.pdf
* 文章データをBinary Relevance Learningでマルチラベル分類を行う
* 2値分類器にはSGDを使う

In [1]:
import numpy as np
import matplotlib
import matplotlib.pylab as plt
from sklearn import datasets, linear_model
from sklearn.feature_extraction.text import CountVectorizer

In [2]:
class BinaryRelevance:
    
    def __init__(self, corpus):
        """クラスの初期化
        
        Args:
            corpus (string(object) np.array): コーパス
        """
        self.labels = [] # 分類ラベルリスト
        self.clfs = {} # 分類器インスタンスリスト
        # set vectorizer
        self.vectorizer = CountVectorizer(binary=True) # BoW, binary
        self.vectorizer.fit_transform(corpus)
        
    def train(self, target_label, positive_x, negative_x):
        """学習
        
        Args:
            target_label (int): どのラベルの分類器を学習させるか
            positive_x (string(object) np.array): 正例の文章リスト
            negative_x (string(object) np.array): 負例の文章リスト
        Returns:
            (bool): 
        """
        # エラーチェック
        if not self.exists_label(target_label):
            return False
        # ペアデータセットにしてシャッフルする
        dataset = []
        for x in positive_x:
            dataset.append((x,1)) # 正例
        for x in negative_x:
            dataset.append((x,0)) # 負例
        dataset = np.array(dataset)
        np.random.shuffle(dataset) # シャッフル
        x = np.array(dataset[:,0], dtype="object") # 入力
        y = np.array(dataset[:,1], dtype="int32") # ラベル
        self.clfs[target_label].fit(self.vectorizer.transform(x), y) # 学習            
        return True
    
    def predict(self, x):
        """予測
        
        Args:
            x: 予測させる文章リスト
        Returns:
            result: 
        """
        result = []
        for i in range(len(x)):
            result.append([]) # 付与されたラベルを追加していくための配列
        for label in self.clfs: # 分類器をループ
            y = self.clfs[label].predict(self.vectorizer.transform(x)) # このラベルかどうかを予測
            for i, y_ in enumerate(y):
                if y_ == 1: # このラベルがつくと予想された
                    result[i].append(label)
        return result
        
    def set_labels(self, labels):
        """ラベルとラベルに対応する分類器インスタンスをセットする
        
        Args:
            labels (int np.array): 追加するラベルリスト
        """
        for label in labels:
            self.labels.append(label) # ラベル追加
            self.clfs[label] = linear_model.SGDClassifier(loss="hinge", penalty="l2", max_iter=5, n_iter=None) # 分類器インスタンス作成
    
    def exists_label(self, label):
        """ラベルが存在するかどうか
        
        Args:
            label (int): 調べるラベル
        Returns:
            (bool): 
        """
        if (label not in self.labels) or (label not in self.clfs):
            return False
        return True

In [3]:
categories = ['alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc']
train = datasets.fetch_20newsgroups(subset='train', categories=categories)
valid = datasets.fetch_20newsgroups(subset='test', categories=categories)

br = BinaryRelevance(train.data)
br.set_labels([0,1,2])

N = 100
train0, train1, train2 = [], [], []
for x, y in zip(train.data, train.target):
    if y == 0:
        train0.append(x)
    elif y == 1:
        train1.append(x)
    elif y == 2:
        train2.append(x)
train0, train1, tarin2 = train0[:N], train1[:N], train2[:N]

negative_x = np.array(train1+train2)
np.random.shuffle(negative_x)
negative_x = negative_x[:N]
br.train(0, train0, negative_x)

negative_x = np.array(train0+train2)
np.random.shuffle(negative_x)
negative_x = negative_x[:N]
br.train(1, train1, negative_x)

negative_x = np.array(train0+train1)
np.random.shuffle(negative_x)
negative_x = negative_x[:N]
br.train(2, train2, negative_x)

True

In [4]:
valid_data_tmp = valid.data[:50]
valid_target_tmp = valid.target[:50]

preds = br.predict(valid_data_tmp)

print('y\tpred')
for y, pred in zip(valid_target_tmp, preds):
    print(y, '\t', pred)

y	pred
0 	 [2]
1 	 [2]
1 	 [2]
0 	 [0]
2 	 [2]
0 	 [0]
2 	 [2]
0 	 [0]
1 	 [2]
2 	 [2]
1 	 [1, 2]
1 	 [0]
2 	 [2]
2 	 [2]
2 	 [2]
2 	 [2]
2 	 [2]
0 	 [0]
0 	 [0]
1 	 [1]
1 	 [0]
0 	 [0]
1 	 [0]
1 	 [2]
1 	 [1]
2 	 [2]
1 	 [1, 2]
1 	 []
2 	 [2]
2 	 [2]
0 	 [1, 2]
0 	 [0]
2 	 [2]
2 	 [2]
1 	 [1]
0 	 []
2 	 [2]
2 	 [2]
0 	 [0]
1 	 [1]
1 	 [0, 1]
2 	 [2]
0 	 [0]
2 	 [2]
0 	 [0]
2 	 [2]
0 	 [0]
1 	 [2]
1 	 [1]
0 	 [0]


In [5]:
!python --version

Python 3.6.6


In [6]:
!pip freeze

alembic==0.9.9
asn1crypto==0.24.0
attrs==18.1.0
Automat==0.0.0
backcall==0.1.0
beautifulsoup4==4.6.1
bleach==2.1.3
bokeh==0.12.16
boto==2.49.0
boto3==1.7.71
botocore==1.10.71
bz2file==0.98
certifi==2018.4.16
cffi==1.11.5
chardet==3.0.4
cloudpickle==0.5.3
conda==4.5.8
constantly==15.1.0
cryptography==2.2.1
cycler==0.10.0
Cython==0.28.5
dask==0.18.2
decorator==4.3.0
dill==0.2.8.2
docutils==0.14
entrypoints==0.2.3
fastcache==1.0.2
gensim==3.5.0
gmpy2==2.0.8
h5py==2.7.1
html5lib==1.0.1
hyperlink==17.3.1
idna==2.7
imageio==2.3.0
incremental==17.5.0
ipykernel==4.8.2
ipython==6.5.0
ipython-genutils==0.2.0
ipywidgets==7.2.1
jedi==0.12.1
Jinja2==2.10
jmespath==0.9.3
jsonschema==2.6.0
jupyter-client==5.2.3
jupyter-core==4.4.0
jupyterhub==0.8.1
jupyterlab==0.33.4
jupyterlab-launcher==0.11.2
kiwisolver==1.0.1
llvmlite==0.23.0
Mako==1.0.7
MarkupSafe==1.0
matplotlib==2.2.2
mecab-python3==0.7
mistune==0.8.3
nbconvert==5.3.1
nbformat==4.4.0
netw