In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import numpy as np
def set_gpu_devices(gpu):
    physical_devices = tf.config.experimental.list_physical_devices('GPU')
    assert len(physical_devices) > 0, "Not enough GPU hardware devices available"
    tf.config.experimental.set_visible_devices(physical_devices[gpu], 'GPU')
    tf.config.experimental.set_memory_growth(physical_devices[gpu], True)
set_gpu_devices(0)

In [11]:
def sequential_concat(x_slice, y_slice, duration=20):
    """Opposite operation of sequential_slice. 
    x_slice's shape will change 
    from (batch * (duration - order_sprt), order_sprt + 1, feat dim )
    to  (batch, (duration - order_sprt), order_sprt + 1, feat dim).
    y changes accordingly.
    Args:
        x_slice: A Tensor with shape (batch * (duration - order_sprt), order_sprt + 1, feat dim). This is the output of models.backbones_lstm.LSTMModel.__call__(inputs, training). 
        y_slice: A Tensor with shape (batch*(duration - order_sprt),).
        duration: An int. 20 for nosaic MNIST.
    Returns:
        x_cocnat: A Tensor with shape (batch, (duration - order_sprt), order_sprt + 1, feat dim).
        y_concat: A Tensor with shape (batch).
    """
    x_shape = x_slice.shape
    order_sprt = int(x_shape[1] - 1)
    batch = int(x_shape[0] / (duration - order_sprt))
    feat_dim = x_shape[-1]

    # Cancat time-sliced, augumented batch
    x_concat = tf.reshape(x_slice, (duration - order_sprt, batch, order_sprt + 1, feat_dim))
    x_concat = tf.transpose(x_concat, [1, 0, 2, 3]) # (batch, duration - order_sprt, order_sprt + 1, feat_dim)
    y_concat = y_slice[:batch]

    return x_concat, y_concat

def _sequentially_calc_binary_llrs(logits_concat):
    """Calculate the frame-by-frame confusion matrices based on the log-likelihood ratios.
    Args:
        logits_concat: A logit Tensor with shape (batch, (duration - order_sprt), order_sprt + 1, 2). This is the output of utils.data_processing.sequential_concat(logit_slice, labels_slice).
    Returns:
        dict_llrs: A dictionary of log-likelihood ratio Tensors. Only the order_sprt-th order, not all order.
    Remark:
        - Binary classification (num classes = 2) is assumed.
    """
    logits_concat_shape = logits_concat.shape

    # Start calc of LLR loss. See the N-th-order SPRT formula.
    order_sprt = int(logits_concat_shape[2] - 1)
    duration = int(logits_concat_shape[1] + order_sprt)
    dict_llrs = dict()

    for iter_frame in range(duration):
        # i.i.d. SPRT (0th-order SPRT)
        if order_sprt == 0:
            llrs_all_frames = logits_concat[:, :, order_sprt, 1] - logits_concat[:, :, order_sprt, 0] # (batch, duration-order_sprt, order_sprt+1, nb_cls=2) -> (batch, duration-order_sprt)
            llrs = tf.reduce_sum(llrs_all_frames[:, :iter_frame+1], -1) # (batch,)
            dict_llrs["llr_{}th_order_frame{:03d}".format("000", iter_frame+1)] = llrs

        # N-th-order SPRT
        else:
            if iter_frame < order_sprt + 1:
                llrs = logits_concat[:, 0, iter_frame, 1] - logits_concat[:, 0, iter_frame, 0] 
                dict_llrs["llr_{:03d}th_order_frame{:03d}".format(order_sprt, iter_frame+1)] = llrs

            else:
                llrs1 = logits_concat[:, :iter_frame, order_sprt, 1] - logits_concat[:, :iter_frame, order_sprt, 0] # (batch, iter_frame)
                llrs1 = tf.reduce_sum(llrs1, -1) # (batch,)
                llrs2 = logits_concat[:, 1:iter_frame, order_sprt-1, 1] - logits_concat[:, 1:iter_frame, order_sprt-1, 0] # (batch, iter_frame-1)
                llrs2 = tf.reduce_sum(llrs2, -1) # (batch,)
                llrs = llrs1 - llrs2 # (batch, )
                dict_llrs["llr_{:03d}th_order_frame{:03d}".format(order_sprt, iter_frame+1)] = llrs

    return dict_llrs

In [32]:
# Example dummy data
batch = 4
duration = 20
order_sprt = 19
nb_cls = 2
feat_dim = nb_cls
assert duration >= order_sprt + 1

#data = [k for k in range(batch*duration*feat_dim)]
data = np.random.rand(batch*duration*feat_dim)
data = np.reshape(data, (batch, duration, feat_dim))
label = [k%nb_cls for k in range(batch)]
print(data) # (batch, duration, feat_dim)
print(label) # (batch,)

[[[0.4555872  0.80769406]
  [0.21401404 0.21298602]
  [0.54932954 0.44270126]
  [0.92474932 0.45696688]
  [0.66159267 0.23012716]
  [0.90822107 0.0534633 ]
  [0.15924537 0.16743213]
  [0.57618239 0.858917  ]
  [0.05886835 0.52406228]
  [0.16078754 0.46109362]
  [0.6392866  0.02212261]
  [0.03224048 0.23299577]
  [0.2233559  0.18707425]
  [0.39689337 0.81043285]
  [0.22554088 0.95713654]
  [0.55684142 0.44890078]
  [0.35851801 0.54158167]
  [0.36457194 0.88120864]
  [0.17070167 0.43887175]
  [0.49051366 0.64076217]]

 [[0.32640182 0.26471594]
  [0.58733699 0.57063145]
  [0.03227289 0.12678464]
  [0.74592963 0.24596626]
  [0.05214763 0.72570692]
  [0.24775064 0.15336312]
  [0.09577495 0.37922493]
  [0.84594573 0.41747837]
  [0.58004797 0.23566674]
  [0.41203506 0.01800845]
  [0.76823589 0.83605781]
  [0.86238905 0.68698813]
  [0.10239206 0.91056564]
  [0.95286417 0.09652983]
  [0.37795248 0.98080644]
  [0.88838314 0.65193546]
  [0.69810418 0.56829116]
  [0.20478651 0.98007408]
  [0.90343

In [33]:
# Slice and concat a batch to make a time-sliced, augumented bathch
for i in range(duration - order_sprt):
    if i == 0:
        data_timeslice = data[:, i:i+order_sprt+1, :]
        label_timeslice = label
    else:
        data_timeslice = tf.concat([data_timeslice, data[:, i:i+order_sprt+1, :]],0)
        label_timeslice = tf.concat([label_timeslice, label], 0)
data_timeslice = tf.cast(data_timeslice, tf.float32)
label_timeslice = tf.cast(label_timeslice, tf.int32)
print(data_timeslice)
print(label_timeslice)

tf.Tensor(
[[[0.4555872  0.8076941 ]
  [0.21401404 0.21298602]
  [0.5493295  0.44270125]
  [0.9247493  0.45696688]
  [0.66159266 0.23012716]
  [0.90822107 0.0534633 ]
  [0.15924537 0.16743213]
  [0.57618237 0.858917  ]
  [0.05886836 0.5240623 ]
  [0.16078754 0.46109363]
  [0.6392866  0.02212261]
  [0.03224048 0.23299578]
  [0.2233559  0.18707426]
  [0.39689335 0.81043285]
  [0.22554088 0.9571365 ]
  [0.55684143 0.4489008 ]
  [0.358518   0.5415817 ]
  [0.36457193 0.88120866]
  [0.17070167 0.43887174]
  [0.49051365 0.64076215]]

 [[0.32640183 0.26471594]
  [0.587337   0.57063144]
  [0.03227289 0.12678464]
  [0.74592966 0.24596626]
  [0.05214763 0.72570693]
  [0.24775064 0.15336312]
  [0.09577495 0.37922493]
  [0.8459457  0.41747838]
  [0.58004797 0.23566674]
  [0.41203505 0.01800845]
  [0.7682359  0.83605784]
  [0.862389   0.6869881 ]
  [0.10239206 0.9105656 ]
  [0.95286417 0.09652983]
  [0.3779525  0.98080647]
  [0.88838315 0.65193546]
  [0.6981042  0.5682912 ]
  [0.20478651 0.9800741 ]

In [34]:
# Concat
logits_concat, labels_concat = sequential_concat(data_timeslice, label_timeslice, duration=duration)
print(logits_concat) # (batch, (duration - order_sprt), order_sprt + 1, nb_cls)
print(labels_concat) # (batch, )

tf.Tensor(
[[[[0.4555872  0.8076941 ]
   [0.21401404 0.21298602]
   [0.5493295  0.44270125]
   [0.9247493  0.45696688]
   [0.66159266 0.23012716]
   [0.90822107 0.0534633 ]
   [0.15924537 0.16743213]
   [0.57618237 0.858917  ]
   [0.05886836 0.5240623 ]
   [0.16078754 0.46109363]
   [0.6392866  0.02212261]
   [0.03224048 0.23299578]
   [0.2233559  0.18707426]
   [0.39689335 0.81043285]
   [0.22554088 0.9571365 ]
   [0.55684143 0.4489008 ]
   [0.358518   0.5415817 ]
   [0.36457193 0.88120866]
   [0.17070167 0.43887174]
   [0.49051365 0.64076215]]]


 [[[0.32640183 0.26471594]
   [0.587337   0.57063144]
   [0.03227289 0.12678464]
   [0.74592966 0.24596626]
   [0.05214763 0.72570693]
   [0.24775064 0.15336312]
   [0.09577495 0.37922493]
   [0.8459457  0.41747838]
   [0.58004797 0.23566674]
   [0.41203505 0.01800845]
   [0.7682359  0.83605784]
   [0.862389   0.6869881 ]
   [0.10239206 0.9105656 ]
   [0.95286417 0.09652983]
   [0.3779525  0.98080647]
   [0.88838315 0.65193546]
   [0.6981042

# binary_sprt_confmx 

In [50]:
al = 3e-2
be = 3e-2

In [51]:
# Calc thresholds
thresh = [np.log(be/(1-al)), np.log((1-be)/al)]
if not ( (thresh[1] >= thresh[0]) and (thresh[1] * thresh[0] < 0) ):
    raise ValueError("thresh must be thresh[1] >= thresh[0] and thresh[1] * thresh[0] < 0. Now thresh = {}".format(thresh))

# Calc log-likelihood ratios
dict_llrs = _sequentially_calc_binary_llrs(logits_concat)

In [52]:
thresh

[-3.4760986898352733, 3.4760986898352733]

In [53]:
dict_llrs

{'llr_019th_order_frame001': <tf.Tensor: id=3020, shape=(4,), dtype=float32, numpy=array([ 0.35210687, -0.06168589, -0.883421  , -0.19616163], dtype=float32)>,
 'llr_019th_order_frame002': <tf.Tensor: id=3029, shape=(4,), dtype=float32, numpy=array([-0.00102802, -0.01670557,  0.07032281,  0.38635477], dtype=float32)>,
 'llr_019th_order_frame003': <tf.Tensor: id=3038, shape=(4,), dtype=float32, numpy=array([-0.10662827,  0.09451175, -0.0805136 , -0.19959724], dtype=float32)>,
 'llr_019th_order_frame004': <tf.Tensor: id=3047, shape=(4,), dtype=float32, numpy=array([-0.46778244, -0.4999634 , -0.4373281 , -0.43428487], dtype=float32)>,
 'llr_019th_order_frame005': <tf.Tensor: id=3056, shape=(4,), dtype=float32, numpy=array([-0.4314655 ,  0.6735593 , -0.33924285,  0.3339158 ], dtype=float32)>,
 'llr_019th_order_frame006': <tf.Tensor: id=3065, shape=(4,), dtype=float32, numpy=array([-0.8547578 , -0.09438752, -0.41665122, -0.44763365], dtype=float32)>,
 'llr_019th_order_frame007': <tf.Tensor:

In [54]:
list_preds = []
list_hittimes = []
list_optout = [] ######################

logits_concat_shape = logits_concat.shape
order_sprt = int(logits_concat_shape[2] - 1)
duration = int(logits_concat_shape[1] + order_sprt)
batch_size = labels_concat.shape[0]

# Truncated Sequential Probability Ratio Test
for iter_batch in range(batch_size):
    for iter_frame in range(duration):
        # Get log-likelihood ratios
        if order_sprt == 0:
            key = "llr_{}th_order_frame{:03d}".format("000", iter_frame+1)
        else:
            key = "llr_{:03d}th_order_frame{:03d}".format(order_sprt, iter_frame+1)

        llr = dict_llrs[key][iter_batch] # scalar

        # Decision: reject null hypothesis (classified to class 1)
        if llr > thresh[1]:
            # Prediction
            list_preds.append(tf.constant(1, dtype=tf.int32))
            # Hitting time
            list_hittimes.append(iter_frame+1)
            break

        # Decision: accept null hypothesis (classified to class 0)
        elif llr < thresh[0]:
            # Prediction
            list_preds.append(tf.constant(0, dtype=tf.int32))
            # Hitting time
            list_hittimes.append(iter_frame+1)
            break

        # Truncate and add to optout list
        elif iter_frame == duration - 1:
            # Prediction
            pred = tf.cast(tf.round(tf.nn.sigmoid(llr)), tf.int32)
            list_preds.append(pred)
            # Hitting time
            list_hittimes.append(iter_frame+1)
            # Add to optout list
            list_optout.append(iter_batch)
                
        # Hold
        else:
            continue

In [55]:
# Confusion matrix
confmx = tf.math.confusion_matrix(labels=labels_concat, predictions=list_preds, num_classes=2, dtype=tf.int32)
# Mean hitting time
mean_hittime = tf.reduce_mean(list_hittimes)

In [56]:
confmx

<tf.Tensor: id=3871, shape=(2, 2), dtype=int32, numpy=
array([[0, 2],
       [0, 2]], dtype=int32)>

In [57]:
mean_hittime

<tf.Tensor: id=3878, shape=(), dtype=int32, numpy=20>

In [58]:
list_optout

[0, 1, 2, 3]

In [None]:
def binary_truncated_sprt_confmx_with_hittime(logits_concat, labels_concat, alpha, beta):
    """Calculate the confusion matrix and the mean hitting time of the truncated Sequential Probability Ratio Test.
    Args:
        logits_concat: A logit Tensor with shape (batch, (duration - order_sprt), order_sprt + 1, 2). This is the output of utils.data_processing.sequential_concat(logit_slice, labels_slice).
        labels_concat: A non-one-hot label Tensor with shape (batch,). This is the output of utils.data_processing.sequential_conclogit_slice, labels_slice).
        alpha: A float number. This is the user-defined false positive rate to be used to calculate thresholds. Note that class 1 is defined as the true class.
        beta: A float number. This is the user-defined false negative rate to be used to calculate thresholds. Note that class 1 is defined as the true class.
    Returns:
        confmx: A confusion matrix Tensor with shape (2, 2).
        mean_hittime: A Tensor. The mean hitting time of a batch.
        list_optout: A list of offsets. Those sequences didn't finish decision and thus truncated.
    Remark:
        - Binary classification (num classes = 2) is assumed.
        - According to Wald's theory,
            false positive rate (alpha) = 0.01
            false negative rate (beta) = 0.01
        is achievable under the Wald approximation (ignore overshoots), if
            thresh[1] (A) = np.log((1-beta)/alpha)
            thresh[0] (B) = np.log(beta/(1-alpha)) .
        - For example, 
            thresh = [-4.59511985013459, 4.59511985013459] for alpha = beta = 1e-2.
            thresh = [-23.025850929840455, 23.025850929840455] for alpha = beta = 1e-10
    """
    # Calc thresholds
    thresh = [np.log(be/(1-al)), np.log((1-be)/al)]
    if not ( (thresh[1] >= thresh[0]) and (thresh[1] * thresh[0] < 0) ):
        raise ValueError("thresh must be thresh[1] >= thresh[0] and thresh[1] * thresh[0] < 0. Now thresh = {}".format(thresh))

    # Calc log-likelihood ratios
    dict_llrs = _sequentially_calc_binary_llrs_v2(logits_concat)
    
    # Truncated Sequential Probability Ratio Test
    list_preds = []
    list_hittimes = []
    list_optout = []

    logits_concat_shape = logits_concat.shape
    order_sprt = int(logits_concat_shape[2] - 1)
    duration = int(logits_concat_shape[1] + order_sprt)
    batch_size = labels_concat.shape[0]

    for iter_batch in range(batch_size):
        for iter_frame in range(duration):
            # Get log-likelihood ratios
            if order_sprt == 0:
                key = "llr_{}th_order_frame{:03d}".format("000", iter_frame+1)
            else:
                key = "llr_{:03d}th_order_frame{:03d}".format(order_sprt, iter_frame+1)

            llr = dict_llrs[key][iter_batch] # scalar

            # Decision: reject null hypothesis (classified to class 1)
            if llr > thresh[1]:
                # Prediction
                list_preds.append(tf.constant(1, dtype=tf.int32))
                # Hitting time
                list_hittimes.append(iter_frame+1)
                break

            # Decision: accept null hypothesis (classified to class 0)
            elif llr < thresh[0]:
                # Prediction
                list_preds.append(tf.constant(0, dtype=tf.int32))
                # Hitting time
                list_hittimes.append(iter_frame+1)
                break

            # Truncate and add to optout list
            elif iter_frame == duration - 1:
                # Prediction
                pred = tf.cast(tf.round(tf.nn.sigmoid(llr)), tf.int32)
                list_preds.append(pred)
                # Hitting time
                list_hittimes.append(iter_frame+1)
                # Add to optout list
                list_optout.append(iter_batch)

            # Hold
            else:
                continue
    
    # Confusion matrix
    confmx = tf.math.confusion_matrix(labels=labels_concat, predictions=list_preds, num_classes=2, dtype=tf.int32)
    # Mean hitting time
    mean_hittime = tf.reduce_mean(list_hittimes)
    
    return confmx, mean_hittime, list_optout

In [None]:
# retry, optout_SPRT