# Active Learner sobre 20 Newsgroups

## Obtengo el dataset desde sklearn

In [None]:
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer

import numpy as np

import seaborn as sns

In [None]:
from core import ActiveLearner, Dataset, Oracle
from sklearn.naive_bayes import MultinomialNB
from querys import CertaintySelector, UncertaintySelector, RandomSelector, MinDiffSelector, EntropySelector

In [None]:
import matplotlib.pyplot as plt

In [None]:
def get_n_each_category(dataset, n):
    """Return the sorted positions of the first ``n`` samples of each category.

    Category ids are ``0 .. len(dataset.target_names) - 1``. For every id,
    ``dataset.target`` is scanned from the beginning and the first ``n``
    positions carrying that id are kept (fewer if the category is smaller).
    The union of all kept positions is returned in ascending order.
    """
    picked = []
    for category in range(len(dataset.target_names)):
        taken = 0
        for pos, label in enumerate(dataset.target):
            if taken >= n:
                break
            if label == category:
                picked.append(pos)
                taken += 1
    return sorted(picked)

In [None]:
def remove_from_dataset(dataset, i):
    """Drop sample ``i`` from ``dataset`` and return the same object.

    ``dataset.data`` (a list) is edited in place with ``del``. ``target`` and
    ``filenames`` are each replaced by a new array that omits position ``i``.
    """
    del dataset.data[i]
    # Same order as before: target first, then filenames.
    for field in ('target', 'filenames'):
        setattr(dataset, field, np.delete(getattr(dataset, field), i))
    return dataset

In [None]:
def remove_many_from_dataset(dataset, indices):
    """Remove every sample whose position is in ``indices`` and return ``dataset``.

    ``indices`` may be in any order and may contain duplicates; each position
    is removed exactly once. ``dataset.data`` (a list) is edited in place, while
    ``target`` and ``filenames`` are replaced by new arrays without those rows.
    An empty ``indices`` leaves the dataset untouched.
    """
    # Deduplicate and sort. The previous "index - i" shift trick silently
    # removed the wrong rows when indices were unsorted or repeated.
    unique = sorted(set(indices))
    if not unique:
        return dataset
    # Delete from the list back-to-front so that earlier deletions never shift
    # positions that are still pending. The largest index goes first, so an
    # out-of-range index fails before anything is mutated.
    for index in reversed(unique):
        del dataset.data[index]
    # One np.delete per array instead of one full copy per removed row.
    dataset.target = np.delete(dataset.target, unique)
    dataset.filenames = np.delete(dataset.filenames, unique)
    return dataset

In [None]:
def split_train_data(dataset, train_indices):
    """Extract the samples at ``train_indices`` and remove them from ``dataset``.

    Returns ``(dataset, data, target)``. ``data`` and ``target`` are lists of
    the selected documents and labels, in the order given by ``train_indices``.
    ``dataset`` is what ``remove_many_from_dataset`` returns after dropping
    those positions. To be safe, pass indices sorted ascending and without
    duplicates, as ``get_n_each_category`` produces them.
    """
    pairs = [(dataset.data[k], dataset.target[k]) for k in train_indices]
    picked_docs = [doc for doc, _ in pairs]
    picked_labels = [label for _, label in pairs]
    remaining = remove_many_from_dataset(dataset, train_indices)
    return remaining, picked_docs, picked_labels

In [None]:
def clean_dataset(dataset):
    """Normalize every document in place and drop those with no ASCII letter.

    Each entry of ``dataset.data`` is stripped of surrounding whitespace and
    lower-cased. Documents that then contain no character in ``a``-``z`` are
    removed together with their target and filename; empty documents fall
    into this group. Returns the result of ``remove_many_from_dataset``.
    """
    letters = set("abcdefghijklmnopqrstuvwxyz")
    to_remove = []
    for pos, doc in enumerate(dataset.data):
        doc = doc.strip().lower()
        dataset.data[pos] = doc
        # An empty document has no letters either, so this single check covers
        # both cases. Each index is recorded at most once, so the list handed
        # to remove_many_from_dataset is strictly ascending and duplicate-free.
        # The old separate "empty" check appended empty documents twice.
        if not any(ch in letters for ch in doc):
            to_remove.append(pos)

    return remove_many_from_dataset(dataset, to_remove)

In [None]:
# Newsgroups included in the experiment (passed to fetch_20newsgroups below).
# Commented-out entries are excluded; uncomment a line to add that group.
categories = [
#    'alt.atheism',
#    'comp.graphics',
#    'comp.os.ms-windows.misc',
#    'comp.sys.ibm.pc.hardware',
    'comp.sys.mac.hardware',
#    'comp.windows.x',
    'misc.forsale',
#    'rec.autos',
#    'rec.motorcycles',
#    'rec.sport.baseball',
    'rec.sport.hockey',
#    'sci.crypt',
#    'sci.electronics',
    'sci.med',
#    'sci.space',
    'soc.religion.christian',
#    'talk.politics.guns',
    'talk.politics.mideast',
#    'talk.politics.misc',
#    'talk.religion.misc',
]

In [None]:
# Load every post (train + test subsets) of the chosen groups, keeping only
# the message body: headers, signature footers and quoted replies are removed.
dataset = fetch_20newsgroups(
    subset='all',
    categories=categories,
    remove=('headers', 'footers', 'quotes'),
)

In [None]:
# Strip and lower-case each document; drop those without any a-z letter.
dataset = clean_dataset(dataset)

In [None]:
# Initial labeled seed set: the first 10 documents of each category. They are
# removed from `dataset`, which from here on holds only the remaining documents.
train_indices = get_n_each_category(dataset, 10)
dataset, train_data, train_target = split_train_data(dataset, train_indices)

In [None]:
# Fixed test set: the first 100 remaining documents of each category.
# split_train_data is reused here simply to extract and remove those rows.
# Whatever is left in `dataset` becomes the unlabeled pool (X_unlabeled below).
test_indices = get_n_each_category(dataset, 100)
dataset, test_data, test_target = split_train_data(dataset, test_indices)

## Obtengo los features TF-IDF

In [None]:
# TF-IDF features capped at 1000 terms.
# NOTE(review): the vocabulary and IDF weights are fitted on the labeled seed
# documents only (train_data). Terms that appear only in the unlabeled pool or
# the test set are therefore dropped. Fitting on train_data + dataset.data
# (which uses no labels) may be worth considering.
vectorizer = TfidfVectorizer(max_features=1000)
X_train = vectorizer.fit_transform(train_data)
X_unlabeled = vectorizer.transform(dataset.data)
X_test = vectorizer.transform(test_data)

## Instancio lo mínimo necesario para el framework

In [None]:
class NGDataset(Dataset):
    # Class attribute bound to the global `dataset` as it is when this cell
    # runs, i.e. the unlabeled pool (seed and test documents already removed).
    dataset = dataset
    
    def get_unlabeled_readable(self, i):
        # Returns the true label of pool sample `i` instead of its text. It is
        # presumably handed to the oracle as X_readable (see NewsGroupOracle),
        # which lets the simulated oracle answer with the ground truth.
        # NOTE(review): assumes `i` indexes the pool in its original order (the
        # same order as X_unlabeled / self.dataset.target). If core.Dataset
        # removes tagged rows from its pool, these indices may drift; confirm.
        # Alternative kept for reference: return the raw document text instead.
        #return self.dataset.data[i]
        return self.dataset.target[i]

# Wrap the seed features, seed labels (as an array) and the unlabeled pool
# features in the framework's dataset type.
y_train = np.array(train_target)
ngdataset = NGDataset(X_train, y_train, X_unlabeled)

    
class NewsGroupOracle(Oracle):
    # Names of the selected newsgroups; position k is the name of label id k.
    target_names = dataset.target_names
    
    def ask(self, X_readable, recoms):
        # Simulated oracle: echoes back whatever "readable" items it receives.
        # With NGDataset those are presumably the true labels already, so the
        # answers are perfect. `recoms` (presumably the model's suggested
        # labels) is ignored.
        return X_readable


# Multinomial Naive Bayes with small additive smoothing (alpha=0.01), the
# simulated oracle, and the MinDiffSelector query strategy. `scores` collects
# the test accuracy after each fit.
model = MultinomialNB(alpha=.01)
oracle = NewsGroupOracle()
al = ActiveLearner(model, ngdataset, MinDiffSelector, oracle)
scores = []

In [None]:
# Initial fit on the seed set; its test accuracy becomes scores[0].
al.fit()
scores.append(al.model.score(X_test, test_target))

In [None]:
# 100 active-learning rounds. Each round queries 10 pool samples with the
# current selector, gets their labels from the oracle, and tags them
# (presumably moving them into the labeled set). It then refits and logs
# test accuracy.
for _ in range(100):
    selected = al.select(10)
    y = al.ask(selected)
    al.tag_elements(selected, y)
    al.fit()
    scores.append(al.model.score(X_test, test_target))

In [None]:
# Plot both score curves on one chart. Labels and a legend are added because
# the two lines were otherwise indistinguishable. `scores` is the test accuracy
# (model.score on X_test) after the initial fit and after every query round.
# What al.get_scores() records is defined in core.ActiveLearner.
#plt.ylim(0,1)
#plt.xlim(200,1000)
plt.plot(al.get_scores(), label='al.get_scores()')
plt.plot(scores, label='test accuracy')
plt.ylabel('score')
plt.legend()
plt.show()

In [None]:
# NOTE(review): `al` was constructed with MinDiffSelector, so this appears to
# re-select the same strategy. Presumably the intent was to switch to another
# selector imported from `querys` (e.g. EntropySelector) before re-running the
# loop cell; confirm.
al.change_selector(MinDiffSelector)