In [1]:
import sys

import numpy as np
import tqdm
import tqdm.notebook
from pycox.datasets import kkbox_v1
from scipy.sparse import csr_matrix

sys.path.append("../")
sys.path.append("../../")

from tools import preprocess_kkbox

In [2]:
# Number of comparison partners sampled for every observation when the pair
# matrices are built.
pairs_per_sample = 3

In [3]:
def get_time_bins(t, n_time_bins):
    """
    Assign every value of ``t`` to one of ``n_time_bins`` percentile bins.

    The bin edges are the percentiles of ``t`` at ``n_time_bins - 1`` interior
    percent levels spaced evenly between 0 and 100. The levels are truncated
    to whole percents, so the bins hold approximately (not exactly) the same
    number of values.

    Parameters
    ----------
    t : array-like
        Values to discretise.
    n_time_bins : int
        Requested number of bins.

    Returns
    -------
    numpy.ndarray
        Bin index in ``0 .. n_time_bins - 1`` for each element of ``t``.

    Raises
    ------
    ValueError
        If at least one bin ends up empty, i.e. ``t`` does not have enough
        distinct values to populate ``n_time_bins`` bins.
    """
    # Builtin ``int`` instead of the ``np.int`` alias: the alias was deprecated
    # in NumPy 1.20 and removed in 1.24, where it raises AttributeError.
    percent_list = np.linspace(0, 100, n_time_bins + 1, dtype=int)
    # Drop the 0 and 100 levels: only interior edges are needed for digitize.
    bins = np.percentile(a=t, q=percent_list[1:-1])
    q = np.digitize(t, bins)
    if n_time_bins != np.unique(q).shape[0]:
        raise ValueError("There is too large value for n_time_bins selected")
    return q

In [4]:
# Load the training split and turn it into features x, times t and labels y.
df_train = kkbox_v1.read_df(subset='train')
x, t, y = preprocess_kkbox(df_train)
# Percentile-based time bin of every observation (12 bins).
q = get_time_bins(t, 12)
n = t.shape[0]
m = np.max(t)

# y as a sparse column so it can be multiplied element-wise with the (n, 1)
# time masks below (csr_matrix(y) is a single row for a 1-D y, hence the
# transpose).
y_csr = csr_matrix(y)
y_csr = y_csr.transpose()
# Row indices with y != 0 and their complement.
# NOTE(review): presumably y == 1 marks an observed event and y == 0 a
# censored observation -- confirm against preprocess_kkbox.
y_nonzero_ind = y_csr.nonzero()[0]
y_zero_ind = np.setdiff1d(np.arange(0, y.shape[0]), y_nonzero_ind)

# (n, n) accumulators, indexed as [observation, sampled partner]:
comparability_m = csr_matrix((n, n))  # 1 for every sampled pair
target_m = csr_matrix((n, n))         # target value of the pair
dq_m = csr_matrix((n, n))             # q[partner] - q[observation]
dq_m_0 = csr_matrix((n, n))           # 1 where both fall into the same time bin

# Distinct time values and, for every row, the index of its value in t_uniq.
t_uniq, uniq_indices = np.unique(t, return_inverse=True)

# Rewritten for each t_cur: marks the rows with t < t_cur (starts empty).
t_less_ti = csr_matrix((n, 1))
# Rewritten for each t_cur: rows already swept are removed, so it marks the
# rows with t >= t_cur (starts with every row set).
t_more_ti = csr_matrix(np.ones((n, 1)))
# Position of the value 1 in t_uniq (empty when no row has t == 1).
ind_t_un = np.where(t_uniq == 1)[0]
# Rows with t == 1, the first group to be moved between the two masks.
# Compare t directly: matching uniq_indices against ind_t_un relies on
# broadcasting with a possibly empty array, which NumPy >= 1.25 rejects.
ind_t_cur = np.where(t == 1)[0]

def _pairs_to_csr(values, rows, cols, size):
    """Return a (size, size) CSR matrix holding values[k] at (rows[k], cols[k])."""
    return csr_matrix((values, (rows, cols)), shape=(size, size))


# Column with 1 where y == 0 (assuming y is 0/1). It does not depend on t_cur,
# so it is built once instead of on every iteration.
y_zero_csr = csr_matrix(np.ones((n, 1))) - y_csr

# Sweep over time values; each step handles the observations with t == t_cur.
# NOTE(review): range(2, m) never treats rows with t == 1 or t == max(t) as
# the "observation" side of a pair -- confirm this is intended.
# NOTE(review): np.random.choice is used without a seed, so the sampled pairs
# differ between runs.
for t_cur in tqdm.notebook.tqdm(range(2, m)):

    # Rows handled in the previous step (t == t_cur - 1) as an (n, 1) indicator.
    prev_rows = csr_matrix(
        (
            np.ones(ind_t_cur.shape[0]),
            (ind_t_cur, np.repeat(0, ind_t_cur.shape[0])),
        ),
        shape=(n, 1),
    )
    # After the update t_less_ti marks t < t_cur ...
    t_less_ti += prev_rows
    # ... and t_more_ti marks t >= t_cur (rows of the current step included).
    t_more_ti -= prev_rows

    # Rows of the current step. Comparing t directly yields an empty index
    # array when no observation has t == t_cur; matching against the
    # np.unique inverse would broadcast against an empty array instead,
    # which NumPy >= 1.25 rejects.
    ind_t_cur = np.where(t == t_cur)[0]

    # ---- observations with t == t_cur and y == 0 -------------------------
    t_cur_y_0 = np.intersect1d(y_zero_ind, ind_t_cur)
    # Candidate partners: rows with y != 0 and t < t_cur.
    res_0 = y_csr.multiply(t_less_ti)
    m2 = t_cur_y_0.shape[0]
    # The whole group is skipped when there are not enough distinct candidates.
    if res_0.count_nonzero() >= m2 * pairs_per_sample:
        # pairs_per_sample distinct partners for every observation of the group
        ind_nonzero_sampled = np.random.choice(
            res_0.nonzero()[0], size=m2 * pairs_per_sample, replace=False
        )
        pair_rows = np.repeat(t_cur_y_0, pairs_per_sample)
        final_comp_pairs = _pairs_to_csr(
            np.ones((pairs_per_sample * m2, )), pair_rows, ind_nonzero_sampled, n
        )
        comparability_m += final_comp_pairs
        # Every partner here has t < t_cur, so the target is 1 for all pairs.
        target_m += final_comp_pairs
        dq_m += _pairs_to_csr(
            q[ind_nonzero_sampled] - q[pair_rows], pair_rows, ind_nonzero_sampled, n
        )
        dq_m_0 += _pairs_to_csr(
            (q[ind_nonzero_sampled] == q[pair_rows]).astype(int),
            pair_rows,
            ind_nonzero_sampled,
            n,
        )

    # ---- observations with t == t_cur and y != 0 -------------------------
    t_cur_y_1 = np.intersect1d(y_nonzero_ind, ind_t_cur)
    # Candidate partners: every row with y != 0 (any time, the observation
    # itself included) plus the rows with y == 0 and t >= t_cur.
    res_1 = t_more_ti.multiply(y_zero_csr) + y_csr
    m2 = t_cur_y_1.shape[0]
    if res_1.count_nonzero() >= m2 * pairs_per_sample:
        # pairs_per_sample distinct partners for every observation of the group
        ind_nonzero_sampled = np.random.choice(
            res_1.nonzero()[0], size=m2 * pairs_per_sample, replace=False
        )
        pair_rows = np.repeat(t_cur_y_1, pairs_per_sample)
        comparability_m += _pairs_to_csr(
            np.ones((pairs_per_sample * m2, )), pair_rows, ind_nonzero_sampled, n
        )
        # Target is 1 only when the partner has y != 0 and t < t_cur, which is
        # exactly what res_0 marks; all other sampled pairs get 0.
        target_m += _pairs_to_csr(
            np.reshape(res_0[ind_nonzero_sampled, :].toarray(), (ind_nonzero_sampled.shape)),
            pair_rows,
            ind_nonzero_sampled,
            n,
        )
        dq_m += _pairs_to_csr(
            q[ind_nonzero_sampled] - q[pair_rows], pair_rows, ind_nonzero_sampled, n
        )
        dq_m_0 += _pairs_to_csr(
            (q[ind_nonzero_sampled] == q[pair_rows]).astype(int),
            pair_rows,
            ind_nonzero_sampled,
            n,
        )

HBox(children=(FloatProgress(value=0.0, max=818.0), HTML(value='')))




In [5]:
# Number of sampled (observation, partner) pairs; each is stored as a 1.
comparability_m.count_nonzero()

5311992

In [6]:
# Number of sampled pairs whose target value is non-zero.
target_m.count_nonzero()

2924241

In [7]:
# Number of sampled pairs whose time-bin difference q[partner] - q[observation]
# is non-zero (pairs inside the same bin contribute a 0 and are not counted).
dq_m.count_nonzero()

4898231

In [8]:
# Number of sampled pairs whose two observations share the same time bin.
dq_m_0.count_nonzero()

413761

In [16]:
# Coordinates (observation, partner) of the non-zero entries of target_m.
t_rows, t_cols = target_m.nonzero()

In [17]:
# First ten observation indices that have a non-zero target.
t_rows[:10]

array([1, 1, 1, 2, 2, 2, 4, 4, 6, 7], dtype=int32)

In [18]:
# Partner indices matching the ten rows shown above.
t_cols[:10]

array([ 381227,  852183, 1477193,  537121,  565235,  685680,  213639,
        613665,  663279,  227555], dtype=int32)

In [19]:
# Read target_m back at its own non-zero coordinates.
target_m[t_rows[:10], t_cols[:10]]

matrix([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])

In [106]:
# Spot check: times of observations 6 and 951601.
t[6], t[951601]

(96, 29)

In [107]:
# Spot check: labels of the same two observations.
y[6], y[951601]

(1, 1)

In [122]:
# t_rows / t_cols are reused: from here on they hold the coordinates of all
# sampled pairs in comparability_m, not those of target_m.
t_rows, t_cols = comparability_m.nonzero()

In [129]:
# First ten observation indices that have a sampled partner.
t_rows[:10]

array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4], dtype=int32)

In [130]:
# Partner indices matching the ten rows shown above.
t_cols[:10]

array([1378213, 1649047,  402448, 1715926,  775703, 1454283, 1416750,
       1549935,  458694,  641318], dtype=int32)

In [131]:
# Target values of the first ten sampled pairs (a mix of 0 and 1 is expected).
target_m[t_rows[:10], t_cols[:10]]

matrix([[0., 0., 0., 1., 1., 1., 0., 1., 1., 0.]])

In [135]:
# Pick one sampled pair (observation i, partner j) for a manual check.
i = 0
j = 1649047

In [136]:
# Times of the chosen pair.
t[i], t[j]

(5, 183)

In [137]:
# Labels of the chosen pair.
y[i], y[j]

(1, 0)

In [138]:
# Time-bin difference q[j] - q[i] stored for the chosen pair (row i is
# densified first, then column j is selected).
dq_m[i].toarray()[:, j]

array([6.])