In [None]:
import argparse
import os
import sys
import math

import torch
from torch import nn
from torch import optim

import numpy as np
import pandas as pd
from tqdm import tqdm
import json


## vis
import matplotlib.pyplot as plt
plt.style.use('ggplot')

In [None]:
# def predict(x, a, mu):
#     '''UMAP-inspired predict function
#     x - torch tensor, shape [n_data_points, n_features]
#     a - torch tensor, shape [n_features]
#         1/a.abs() is the extent of bounding box at prediction=0.5
#     mu - torch tensor, shape [n_features]
#     b - scalar. hyper parameter for predict function. Power exponent
#     '''

#     b = 3
#     return 1 / (1 + ((a.abs() * (x - mu).abs()).pow(b)).sum(1))

# # test: UMAP-inspired predict function
# # n = 100
# # x = torch.linspace(-3,3,n).view(n,1)
# # a = torch.tensor(0.5)
# # plt.plot(x, predict(x, a))


# def compute_predicate(x0, selected, n_iter=1000, mu_init=None, a_init=0.4):
#     '''
#         x0 - numpy array, shape=[n_points, n_feature]. Data points
#         selected - boolean array. shape=[n_points] of selection
#     '''

#     # prepare training data
#     # original data extent
#     n_points, n_features = x0.shape
#     vmin = x0.min(0)
#     vmax = x0.max(0)
#     x = torch.from_numpy(x0.astype(np.float32))
#     label = torch.from_numpy(selected).float()
#     # normalize
#     mean = x.mean(0)
#     scale = x.std(0) + 0.1
#     x = (x - mean) / scale

#     # Trainable parameters
#     # since data is normalized,
#     # mu can be initialized around the mean of the positive examples
#     # a can be initialized around a constant across all axes
#     center_selected = x[selected].mean(0)
#     if mu_init is None:
#         mu_init = center_selected
#     a = (a_init + 0.1*(2*torch.rand(n_features)-1))
#     mu = mu_init + 0.1 * (2*torch.rand(x.shape[1]) - 1)
#     a.requires_grad_(True)
#     mu.requires_grad_(True)

#     # weight-balance selected vs. unselected based on their size
#     n_selected = selected.sum()
#     n_unselected = n_points - n_selected
#     instance_weight = torch.ones(x.shape[0])
#     instance_weight[selected] = n_points/n_selected
#     instance_weight[~selected] = n_points/n_unselected
#     bce = nn.BCELoss(weight=instance_weight)
#     optimizer = optim.SGD([
#         {'params': mu, 'weight_decay': 0},
#         # smaller a encourages larger reach of the bounding box
#         {'params': a, 'weight_decay': 0.01}
#     ], lr=1e-2, momentum=0.9)

#     # training loop
#     for e in range(n_iter):
#         pred = predict(x, a, mu)
#         loss = bce(pred, label)
#         loss += (mu - center_selected).pow(2).mean() * 20
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#         if e % (n_iter//5) == 0:
#             # print(pred.min().item(), pred.max().item())
#             print(f'[{e:>4}] loss {loss.item()}')
#     a.detach_()
#     mu.detach_()
#     # plt.stem(a.abs().numpy()); plt.show()

#     pred = (pred > 0.5).float()
#     correct = (pred == label).float().sum().item()
#     total = selected.shape[0]
#     accuracy = correct/total
#     # 1 meaning points are selected
#     tp = ((pred == 1).float() * (label == 1).float()).sum().item()
#     fp = ((pred == 1).float() * (label == 0).float()).sum().item()
#     fn = ((pred == 0).float() * (label == 1).float()).sum().item()
#     precision = tp/(tp+fp)
#     recall = tp/(tp+fn)
#     f1 = 1/(1/precision + 1/recall)
#     print(f'''
# accuracy = {correct/total}
# precision = {precision}
# recall = {recall}
# f1 = {f1}
#     ''')

#     # predicate clause selection
#     # r is the range of the bounding box on each dimension
#     # bounding box is defined by the level set of prediction=0.5
#     r = 1 / a.abs()
#     predicates = []
#     for k in range(mu.shape[0]):
#         # denormalize
#         r_k = (r[k] * scale[k]).item()
#         mu_k = (mu[k] * scale[k] + mean[k]).item()
#         ci = [mu_k - r_k, mu_k + r_k]
#         assert ci[0] < ci[1], 'ci[0] is not less than ci[1]'
#         if ci[0] < vmin[k]:
#             ci[0] = vmin[k]
#         if ci[1] > vmax[k]:
#             ci[1] = vmax[k]
#         # feature selection based on extent range
# #         should_include = r[k] < 1.0 * (x[:,k].max()-x[:,k].min())
#         should_include = not (ci[0] <= vmin[k] and ci[1] >= vmax[k])
#         if should_include:
#             predicates.append(dict(
#                 dim=k, interval=ci
#             ))
#     for p in predicates:
#         print(p)
#     return predicates, mu, a, [accuracy, precision, recall, f1]


## Predicate sequence

In [None]:
# from textwrap import dedent

In [None]:
b = 7  # power exponent used by predict(); read from module scope at call time
def predict(x, a, mu=0):
    '''Bump-shaped proxy function.

    x  - torch tensor with shape [n_data_points, n_features]
    a  - torch tensor with shape [n_features]; per-feature scale
         (only its absolute value is used)
    mu - center of the bump; a scalar or a tensor that broadcasts against x

    Returns a tensor with shape [n_data_points] holding
    1 / (1 + sum_k (|a_k| * |x_k - mu_k|)^b).
    '''
    distance = (x - mu).abs()
    scaled = a.abs() * distance
    penalty = scaled.pow(b).sum(1)
    return 1 / (1 + penalty)

# def predict_gaussian(x, mu=0, b=2):
#     print(x.shape)
#     return torch.exp(-(x-mu).pow(b).sum(1))

# Evaluate the proxy on a 65 x 65 grid spanning [-4, 4] on both axes.
gx, gy = torch.meshgrid(
    torch.linspace(-4, 4, 65),
    torch.linspace(-4, 4, 65),
    indexing='xy',
)
xy = torch.stack((gx.flatten(), gy.flatten()), dim=1)
z = predict(xy, a=torch.tensor([1/1, 1/2])).reshape(65, 65)
# z = predict_gaussian(xy, b=4).reshape([40,40])

# Convert everything to numpy for the plotting code.
gx, gy, z = gx.numpy(), gy.numpy(), z.numpy()

# Level sets of the proxy function.
plt.figure(figsize=(4, 4))
plt.contour(gx, gy, z, levels=[0.1, 0.25, 0.5, 0.75, 0.9])
plt.axis('square');

# Profile of the proxy along the first axis, taken at grid row 10.
plt.figure(figsize=(4, 2))
plt.plot(gx[0], z[10]);

In [None]:
# !pip install plotly

In [None]:
# !pip install chart-studio

In [None]:
# import chart_studio.plotly as py
import plotly.graph_objs as go

In [None]:
# compute a contour path for plotly:

def get_contour_verts(cn):
    '''Extract the vertex arrays of every contour line in a contour set.

    cn - the object returned by plt.contour (a matplotlib ContourSet).

    Returns a list with one entry per contour level. Each entry is a list of
    numpy arrays with shape [n_vertices, 2], one array per separate section
    of that level's contour line. A level without any contour yields an
    empty list.
    '''
    MOVETO = 1  # value of matplotlib.path.Path.MOVETO: starts a new section

    if hasattr(cn, 'get_paths'):
        # Matplotlib >= 3.8: the ContourSet is itself a collection holding one
        # (possibly multi-section) path per level.
        paths_per_level = [[pp] for pp in cn.get_paths()]
    else:
        # Older matplotlib: one collection per level, reached through
        # `cn.collections` (deprecated in 3.8 and removed in later releases).
        paths_per_level = [cc.get_paths() for cc in cn.collections]

    contours = []
    # for each contour line
    for level_paths in paths_per_level:
        sections = []
        for pp in level_paths:
            current = []
            # for each segment of the path; a MOVETO begins a new section
            for vertex, code in pp.iter_segments():
                if code == MOVETO and current:
                    sections.append(np.vstack(current))
                    current = []
                current.append(vertex)
            # an empty path contributes nothing (np.vstack([]) would raise)
            if current:
                sections.append(np.vstack(current))
        contours.append(sections)

    return contours

# Trace the 0.5 level set of z and keep the first section of that single
# level as an array of (x, y) vertices, for reuse in the 3D figure below.
cn = plt.contour(gx, gy, z, levels=[0.5])
contour_verts_2d = get_contour_verts(cn)[0][0]

In [None]:
# !pip install -U kaleido

In [None]:
# Color palette for the 3D figure.
# Alternatives noted for reference:
#   '#C2C973' = muted yellow
#   '#1f78b4' = C0 blue
surface_color = '#7B3F00'  # chocolate; fills the surface colorscale
wireframe_color = '#7B3F00'  # grid lines drawn over the surface
predicate_line_color = '#333'  # dashed predicate-range lines
pos_point_color = 'orange'  # marker fill of the first scatter trace (pos)
neg_point_color = '#fff'  # marker fill of the second scatter trace (neg)


# Wireframe: every `skip`-th grid row, then every `skip`-th grid column,
# drawn as a faint 3D polyline on the surface.
line_marker = dict(color=wireframe_color, width=2)
skip = 8
lines = [
    go.Scatter3d(x=xs, y=ys, z=zs, mode='lines', line=line_marker, opacity=0.3)
    for grid_x, grid_y, grid_z in ((gx, gy, z), (gx.T, gy.T, z.T))
    for xs, ys, zs in zip(grid_x[::skip], grid_y[::skip], grid_z[::skip])
]

# Dashed lines marking the predicate range, lifted slightly above z=0:
# x1 = -1, x1 = 1, x2 = -2 and x2 = 2.
zoffset = 0.005
predicate_line_style = dict(color=predicate_line_color, width=3, dash='dash')
for edge_x, edge_y in (
    ([-1, -1], [-4, 4]),
    ([1, 1], [-4, 4]),
    ([-4, 4], [-2, -2]),
    ([-4, 4], [2, 2]),
):
    lines.append(go.Scatter3d(
        x=edge_x, y=edge_y, z=[zoffset, zoffset],
        mode='lines', line=predicate_line_style,
    ))

# The extracted contour, drawn once near the floor and once at height 0.5.
contour_line_style = dict(color='yellow', width=5)
for contour_height, contour_name in ((0.001, 'contour_floor'), (0.5, '0.5 level set')):
    lines.append(go.Scatter3d(
        x=contour_verts_2d[:, 0],
        y=contour_verts_2d[:, 1],
        z=np.zeros(contour_verts_2d.shape[0]) + contour_height,
        mode='lines', line=contour_line_style,
        name=contour_name,
    ))


# Layout (Plotly)
# All three axes share the same white-grid-on-grey-background styling;
# the z axis additionally fixes its tick count and range.
axis_style = dict(
    gridcolor='rgb(255, 255, 255)',
    zerolinecolor='rgb(255, 255, 255)',
    showbackground=True,
    backgroundcolor='rgb(230, 230, 230)',
)
layout = go.Layout(
    title='Predicate proxy function',
    width=1000,
    height=1000,
    scene=dict(
        xaxis=dict(axis_style),
        yaxis=dict(axis_style),
        zaxis=dict(axis_style, nticks=4, range=[0, 1.05]),
    ),
    showlegend=False,
)

# scatter plot
# Synthetic example points. A dedicated, seeded generator makes the figure
# identical on every run without touching numpy's global random state.
scatter_rng = np.random.default_rng(0)
# 20 points uniform in [-0.5, 0.5] x [-1, 1]
pos = (scatter_rng.random((20, 2)) - 0.5) * 2 * np.array([0.5, 1])
# 50 candidates uniform in [-4, 4] x [-4, 4] ...
neg = (scatter_rng.random((50, 2)) - 0.5) * 2 * 4
# ... minus those with |x1| < 1.3 and |x2| < 2
neg = neg[~np.logical_and(np.abs(neg[:, 0]) < 1.3, np.abs(neg[:, 1]) < 2)]
# trace names that should appear in the legend
names = ["Pattern Points (P)", 'Background Points (B)', '0.5 level set']
# Marker traces for the two point sets: `pos` is drawn just above z=1,
# `neg` just above z=0. Both use size-20 circles with a thin black outline.
n_pos = pos.shape[0]
n_neg = neg.shape[0]
marker_outline = dict(width=1, color='#000')

scatter_pos = go.Scatter3d(
    name=names[0],
    x=pos[:, 0],
    y=pos[:, 1],
    z=np.ones(n_pos) + 0.01,
    mode='markers',
    marker=dict(
        color=[pos_point_color] * n_pos,
        symbol=['circle'] * n_pos,
        size=[20] * n_pos,
        line=dict(marker_outline),
        opacity=1.0,
    ),
)
scatter_neg = go.Scatter3d(
    name=names[1],
    x=neg[:, 0],
    y=neg[:, 1],
    z=np.zeros(n_neg) + 0.05,
    mode='markers',
    marker=dict(
        color=[neg_point_color] * n_neg,
        # alternative symbol tried earlier: 'x'
        symbol=['circle'] * n_neg,
        size=[20] * n_neg,
        line=dict(marker_outline),
        opacity=1.0,
    ),
)

# bump function surface
# Both ends of the colorscale use the same color, so the surface is drawn
# in a single flat color regardless of the surfacecolor values.
colorscale = [[stop, surface_color] for stop in (0, 1)]
surface = go.Surface(
    x=gx,
    y=gy,
    z=z,
    surfacecolor=np.ones_like(z),
    colorscale=colorscale,
    showscale=False,
    opacity=0.7,
    lighting=dict(ambient=0.8, roughness=1.0),
#     contours_z=dict(
#         show=False,
#         width=4, #not working
#         start=0.5, end=0.6, size=0.25,
#         usecolormap=False,
#         color='midnightblue',
#         highlightcolor="limegreen",
#         project_z=True,
#     )
)


# assemble the figure: surface first, then the point sets, then all lines
fig = go.Figure(
    data=[surface, scatter_pos, scatter_neg] + lines,
    layout=layout,
)

# aspect ratio (z=0.3 was used for the VIS paper)
fig.update_scenes(
    aspectmode='manual',
    aspectratio={'x': 1, 'y': 1, 'z': 0.4},
)

# axis labels
fig.update_layout(
    scene={
        'xaxis_title': 'x1',
        'yaxis_title': 'x2',
        'zaxis_title': 'f(x1, x2)',
    },
)

# legends: horizontal box placed inside the plotting area
fig.update_layout(
    showlegend=True,
    legend={
        'orientation': "h",
        'xanchor': "left",
        'yanchor': "top",
        'x': 0.10,
        'y': 0.70,
        'bgcolor': '#e6e6e6',
        'font': {'size': 18},
    },
)
# only the traces listed in `names` keep a legend entry
for trace in fig.data:
    if trace.name not in names:
        trace.showlegend = False

# camera
camera = {
    'up': {'x': 0, 'y': 0, 'z': 1},
    'center': {'x': 0, 'y': 0, 'z': 0},
    'eye': {'x': -1.5, 'y': 1.25, 'z': 0.5},
}
fig.update_layout(scene_camera=camera)


fig.show()


In [None]:
# Export the assembled figure to "fig1.pdf" (path relative to the working directory).
# NOTE(review): static export presumably relies on the kaleido package from
# the commented-out install cell above — confirm it is installed.
fig.write_image("fig1.pdf")

In [None]:
# fig['data'][1]['name']