In [1]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
    
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from src.config import get_dataset_and_user
from src.user import DummyUser
%matplotlib inline

In [24]:
# get data
X_iris, user_iris = get_dataset_and_user('iris')
y_iris = user_iris.get_label(X_iris, update_counter=False)

X_housing, user_housing = get_dataset_and_user('housing')
y_housing = user_housing.get_label(X_housing, update_counter=False)

X_sdss, user_sdss = get_dataset_and_user('sdss_Q1.1')
print((user_sdss.get_label(X_sdss) == 1).sum())
user_sdss.clear()

62


In [13]:
from src.active_learning.svm import SimpleMargin, SolverMethod, OptimalMargin
from src.active_learning.svm.base import OptimalMarginCholesky
from src.active_learning.linear import LinearMajorityVote, KernelMajorityVote, KernelCholeskyMajorityVote
from src.active_learning.agnostic import RandomLearner
from src.showdown import Showdown
from src.plotting import plot_showdown
from src.initial_sampling import StratifiedSampler
from sklearn.svm import SVC

user_housing.max_iter = 35
user_iris.max_iter = 10
user_sdss.max_iter = 40

active_learners_list = [
    #("random", RandomLearner(SVC(C=100000, kernel='rbf'))),
    ("optimal margin", OptimalMargin(top=-1, chain_length=64, sample_size=8, kind='kernel', kernel='rbf', C=100000)),
    ("optimal margin cholesky", OptimalMarginCholesky(top=-1, chain_length=64, sample_size=8, kind='kernel', kernel='rbf', C=100000)),
    ("simple margin", SimpleMargin(top=-1, kind='kernel', kernel='rbf', C=100000)),
    ("Kernel majority vote", KernelMajorityVote(top=-1, chain_length=64, sample_size=8)),
    ("Kernel cholesky majority vote", KernelCholeskyMajorityVote(top=-1, chain_length=64, sample_size=8))
]

datasets_list = [
    ('sdss', X_sdss, user_sdss)
    #("housing", X_housing, user_housing),
    #("iris", X_iris, user_iris)
]

times = 10
showdown = Showdown()
output = showdown.run(datasets_list, active_learners_list, times, StratifiedSampler(1,1))

#print(output.head())
plot_showdown(output, times, metrics_list=['fscore', 'iteration_time'])

ValueError: a must be greater than 0

In [11]:
output['sdss']['fscore']

Unnamed: 0_level_0,Kernel cholesky majority vote,Kernel cholesky majority vote,Kernel cholesky majority vote,Kernel cholesky majority vote,Kernel majority vote,Kernel majority vote,Kernel majority vote,Kernel majority vote,optimal margin,optimal margin,optimal margin,optimal margin,optimal margin cholesky,optimal margin cholesky,optimal margin cholesky,optimal margin cholesky,simple margin,simple margin,simple margin,simple margin
Unnamed: 0_level_1,max,mean,min,std,max,mean,min,std,max,mean,min,std,max,mean,min,std,max,mean,min,std
2,0.009641,0.003979,0.002511,0.00222,0.021543,0.006495,0.002594,0.006008,0.003819,0.002911,0.002563,0.000369,0.003819,0.002911,0.002563,0.000369,0.003819,0.002911,0.002563,0.000369
3,0.020708,0.007637,0.002871,0.005306,0.048343,0.010667,0.003373,0.013627,0.007746,0.004895,0.003009,0.001584,0.006579,0.004199,0.003131,0.001127,0.004731,0.003944,0.003407,0.000402
4,0.022661,0.009443,0.004389,0.005479,0.127572,0.02386,0.003025,0.037017,0.009965,0.006535,0.003291,0.00213,0.007972,0.005819,0.003313,0.001412,0.006874,0.005787,0.003774,0.001026
5,0.043417,0.017011,0.006511,0.010834,0.086532,0.023931,0.002656,0.023386,0.063008,0.012878,0.004216,0.017749,0.015072,0.008552,0.005876,0.002602,0.009561,0.006837,0.004019,0.001839
6,0.035632,0.018237,0.011347,0.007452,0.093868,0.036944,0.001659,0.023122,0.075841,0.017325,0.004631,0.020928,0.020868,0.01176,0.006369,0.004877,0.01081,0.007886,0.004866,0.002144
7,0.051282,0.031496,0.021468,0.009185,0.150121,0.048979,0.004637,0.038514,0.103247,0.023265,0.00621,0.028782,0.018566,0.013225,0.00901,0.003482,0.013214,0.008635,0.005285,0.002457
8,0.105174,0.047782,0.01632,0.030688,0.33795,0.082965,0.012923,0.096266,0.16273,0.038919,0.016459,0.044353,0.034714,0.019214,0.013161,0.006167,0.016511,0.010051,0.006884,0.002829
9,0.08384,0.04067,0.015744,0.022466,0.098901,0.055979,0.015796,0.02587,0.16273,0.042846,0.012529,0.044914,0.040013,0.024173,0.016638,0.007221,0.017694,0.01195,0.008601,0.003185
10,0.137143,0.052023,0.026271,0.036449,0.141553,0.061743,0.022025,0.038151,0.105085,0.04484,0.01591,0.0283,0.041196,0.029652,0.018273,0.007821,0.017963,0.013126,0.009819,0.002893
11,0.19375,0.079442,0.027778,0.054674,0.135693,0.078021,0.026992,0.040416,0.083389,0.057737,0.023588,0.016772,0.100813,0.038515,0.021998,0.024329,0.023436,0.015901,0.010987,0.004204


In [5]:
from datetime import datetime

def point_parser(s):
    s = s[1:-1]  # remove square brackets
    arrays = s.split('\n')
    arrays_without_brackets = map(lambda x: x.strip()[1:-1], arrays)
    final = list(map(lambda x: [float(y) for y in x.split()], arrays_without_brackets))
    return final[0]

def parse_log(path='task.log'):
    with open(path, 'r') as f:
        for line in f:
            timestamp, _, _, _, iteration, point, label = line.strip().split('\t')
            timestamp = datetime.strptime(timestamp, "%Y-%m-%d %H:%M:%S,%f").timestamp()
            point = point_parser(point)
            label = float(label)
            iteration = int(iteration)
            yield timestamp, iteration, point, label

logs = parse_log()            
t0, i0, x0, l0 = next(logs)
points = []
labels = []
for t, i, x, l in logs:
    points.append(x)
    labels.append(l)
    if i != i0:
        i0 = i
        # compute metrics
        pass

print(points)
print(labels)

[[1.06763984, 0.120323], [0.64291794, 0.7749268], [0.7352499, 1.3297939], [0.31095826, -0.5516086], [0.03080411, -1.3922941], [0.30335195, 0.3285162], [0.51688648, 0.0122643], [0.25762603, -0.9693189], [0.6490522, -0.251936], [0.03770038, -0.645166], [0.23121638, 0.162121], [0.27197924, -0.0064406], [0.10793421, -0.0382355], [0.24735609, -0.4865311], [0.37994231, -0.1391162], [0.20752543, 0.000673], [0.15093067, 0.0927536], [0.05696656, -0.3690867], [0.04461903, -0.1547663], [0.13760706, -0.26648], [0.16803627, -0.1718175], [0.32095196, 0.0572935], [0.09500872, 0.0439689], [0.23776605, 0.0460453], [0.23912695, -0.1371897], [0.21604, -0.107509], [0.21554814, -0.0458295], [0.14981188, 0.0609587], [0.20587145, 0.0454314], [0.12435991, 0.0078047], [0.10323021, -0.046474], [0.17668082, -0.0602428], [0.15368051, 0.0391121], [0.08266119, -0.0527572], [0.07762279, -0.0185194], [0.18959411, -0.0480755], [0.15154231, -0.0542052], [0.08189409, -0.0385244], [0.13922537, -0.0585438], [0.20362366, 0