In [1]:
# Reload imported modules before each cell runs, so edits to the local project
# modules imported below are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import plotly.graph_objects as go
from data.simulation.simulator import Simulator
from sampler import CVXSampler, RandomSampler
from learner import SVMLearner
# Plotly layout template applied to the figures built in this notebook.
template = "plotly_white"

# Initialize variables

In [3]:
# Initialize
# Data generator; used below to draw both the training and the test set.
simulator = Simulator("moon", noise=0.1)
# Classifier that is refit after every sampling step.
learner = SVMLearner()
# Seeded RNG. NOTE(review): not referenced by any cell visible here -- confirm
# whether it should be handed to the samplers/simulator or can be removed.
npr = np.random.RandomState(123)
# Forwarded as `sigma` to CVXSampler.
sigma = 5
# NOTE(review): presumably a precondition of CVXSampler -- confirm in sampler.py.
assert sigma > 1.

# Generating Data

In [14]:
# Dataset
N = 100             # Number of simulated samples per split
input_dim = 2       # Feature dimension
train_X, train_y = simulator.simulate(N, input_dim)

# Boolean mask over the training set; True marks the samples used as the
# initially labeled subset (the first `labeled_ratio` fraction).
# The builtin `bool` is used as dtype: the `np.bool` alias was deprecated in
# NumPy 1.20 and removed in 1.24, where `np.bool` raises AttributeError.
labeled_mask = np.zeros(N, dtype=bool)
labeled_ratio = 0.2
labeled_mask[:int(labeled_ratio * N)] = True

# Test split: a second, separate draw from the same simulator.
test_X, test_y = simulator.simulate(N, input_dim)

In [15]:
def vis_scatter(X, color):
    """Display a scatter plot of the first two columns of ``X``.

    Args:
        X: Array indexable as ``X[:, 0]`` and ``X[:, 1]``; column 0 goes on
            the x axis and column 1 on the y axis.
        color: Value forwarded to the marker ``color`` setting (the callers
            in this notebook pass labels or predictions).

    The figure is shown via ``fig.show()``; nothing is returned.
    """
    points = go.Scatter(
        x=X[:, 0],
        y=X[:, 1],
        mode='markers',
        marker=dict(color=color),
        marker_line_width=2,
    )
    fig = go.Figure(data=points)
    fig.update_layout(template=template)

    # Same axis-line styling for both axes: black line, mirrored to the
    # opposite side of the plot.
    axis_style = dict(showline=True, linewidth=1.5, linecolor='Black', mirror=True)
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)

    fig.update_layout(width=600, height=400, margin=dict(b=0, t=0, l=0, r=0))
    fig.show()

# Visualize train data, colored by label
vis_scatter(train_X, train_y)

In [16]:
def eval_acc(X, y, vis=False):
    """Print and return the accuracy of the notebook-level ``learner``.

    Args:
        X: Inputs passed to ``learner.predict_proba``.
        y: Reference labels, compared element-wise with the predicted class
            indices.
        vis: When true, also scatter-plot ``X`` colored by the predictions.

    Returns:
        The mean of ``pred == y``, i.e. the fraction of matching predictions.
    """
    # Predicted class = index of the largest value along axis 1.
    predictions = learner.predict_proba(X).argmax(1)

    accuracy = (predictions == y).mean()
    print("accuracy: {:.2f}".format(accuracy))

    if vis:
        vis_scatter(X, predictions)

    return accuracy

In [17]:
# Baseline: fit on the initially labeled subset only, then score (and plot the
# predictions) on the separately simulated test set.
learner.fit(train_X[labeled_mask], train_y[labeled_mask])
acc = eval_acc(test_X, test_y, vis=True)

accuracy: 0.94


# CVX Sampler

In [18]:
# Ask the CVX sampler for the training indices to label next. `sample` returns
# those indices plus a second value Z, which the next cell plots as a heatmap.
# NOTE(review): the meaning of K, alpha and the `None` first argument is
# defined in sampler.py and not visible here -- confirm before tuning.
cvx_sampler = CVXSampler(train_X, labeled_mask, K=2, sigma=sigma, alpha=0.5)
idx_to_label, Z = cvx_sampler.sample(None, learner)
print("Sample {} data points".format(len(idx_to_label)))

Solving Problem
Takes 12.19 sec
Sample 6 data points


In [19]:
# Heatmap of Z (second value returned by cvx_sampler.sample), styled like the
# scatter plots: same template, mirrored black axis lines, fixed size.
fig = (
    go.Figure(data=go.Heatmap(z=Z, colorscale='Viridis'))
    .update_layout(template=template)
    .update_xaxes(showline=True, linewidth=1.5, linecolor='Black', mirror=True)
    .update_yaxes(showline=True, linewidth=1.5, linecolor='Black', mirror=True)
    .update_layout(width=600, height=400, margin=dict(b=0, t=0, l=0, r=0))
)
fig.show()

In [20]:
# Mark the CVX-selected points as labeled on a copy of the initial mask
# (labeled_mask itself stays untouched), then refit and re-evaluate.
cur_labeled_mask = np.copy(labeled_mask)
cur_labeled_mask[idx_to_label] = True
learner.fit(train_X[cur_labeled_mask], train_y[cur_labeled_mask])
acc = eval_acc(test_X, test_y, vis=True)

accuracy: 0.97


# Random Sampler

In [23]:
# Random baseline: request as many points as the CVX sampler selected, starting
# again from the initial labeled_mask, then refit and evaluate.
# NOTE(review): `idx_to_label` is rebound below, so the CVX selection is lost
# after this cell runs and a re-run reads the random sample's length instead.
# NOTE(review): no seed is passed to RandomSampler here -- confirm whether the
# result is reproducible across runs.
n_data_to_label = len(idx_to_label)
random_sampler = RandomSampler(train_X, labeled_mask)
idx_to_label = random_sampler.sample(n_data_to_label)
print("Sample {} data points".format(len(idx_to_label)))
# Work on a copy so the initial mask is preserved.
cur_labeled_mask = np.copy(labeled_mask)
cur_labeled_mask[idx_to_label] = True
learner.fit(train_X[cur_labeled_mask], train_y[cur_labeled_mask])
acc = eval_acc(test_X, test_y, vis=True)

Sample 6 data points
accuracy: 0.94
