<a href="https://colab.research.google.com/github/archqua/pipeline_training/blob/master/mnist_experiment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Multi-Step Classification Pipeline with Human Feedback

## 1. Problem Statement

Given an input $x \in \mathcal{X}$, we want to predict a final output $y \in \Delta^K$ through a series of classification steps, with human feedback on the final output.

## 2. Pipeline Definition

Let the pipeline consist of $N$ classifiers !are they really only classifiers?!, denoted as $f_1, f_2, \ldots, f_N$. !replaced the commas with Cartesian products!

- $f_1: \mathcal{X} \rightarrow \mathcal{Y}_1$ (e.g., digit recognition)
- $f_2: \mathcal{X} \times \mathcal{Y}_1 \rightarrow \mathcal{Y}_2$
- ...
- $f_N: \mathcal{X} \times \mathcal{Y}_1 \times \ldots \times \mathcal{Y}_{N-1} \rightarrow \Delta^K$

Where $\mathcal{Y}_i$ is the output space of the $i$-th classifier.

!I suggest adding something like:!
For convenience we will sometimes use the following:
- $p_j: \mathcal{X}\times\mathcal{Y}_1\times\ldots\times\mathcal{Y}_{j-1} \rightarrow \mathcal{X}\times\mathcal{Y}_1\times\ldots\times\mathcal{Y}_{j}$
with $p_j\left(x, y_1, \ldots, y_{j-1}\right) = \left(x, y_1, \ldots, y_{j-1}, f_j\left(x, y_1, \ldots, y_{j-1}\right)\right)$

## 3. Forward Pass

For an input $x$:

$y_1 = f_1(x)$

$y_2 = f_2(x, y_1)$

$\dots$

$y_N = f_N(x, y_1, \ldots, y_{N-1}) \in \Delta^K$
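
The forward pass above can be sketched in a few lines of plain Python. This is only an illustration: the `forward` helper and the two toy stages `f1` and `f2` are hypothetical names, not part of the experiment code further down.

```python
from typing import Callable, List, Sequence, Tuple

# A stage receives the original input x followed by all earlier outputs.
Stage = Callable[..., object]

def forward(stages: Sequence[Stage], x) -> Tuple:
    """Run every stage in order and return (y_1, ..., y_N)."""
    outputs: List[object] = []
    for f in stages:
        # f_j(x, y_1, ..., y_{j-1}); appending the result plays the role of p_j
        outputs.append(f(x, *outputs))
    return tuple(outputs)

# Hypothetical toy stages: parity of x, then a distribution over K = 2 classes.
f1 = lambda x: x % 2
f2 = lambda x, y1: (0.9, 0.1) if y1 == 0 else (0.2, 0.8)

ys = forward([f1, f2], 7)
print(ys)
```

Keeping all intermediate outputs around, rather than only the last one, is what lets later stages condition on the full history.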

## 4. Human Feedback

Let $h: \Delta^K \rightarrow \{0, 1\}$ !added ${}^K$! represent human feedback, where:

$h(y_N) = \begin{cases}
1 & \text{if human confirms } y_N \text{ is correct} \\
0 & \text{if human denies } y_N \text{ is correct}
\end{cases}$

## 5. Error Attribution

Let $e_i$ be the probability that classifier $f_i$ made an error:

$e_i = P(\text{error in } f_i | h(y_N) = 0)$

We can estimate $e_i$ based on classifier confidence or historical performance.
Note that we assume an error made by one classifier cannot compensate for errors made by earlier classifiers. Such cases could be addressed in future work.
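
One possible way to estimate $e_i$ from classifier confidence is sketched below. It rests on an assumption that is not stated in the text: blame is split in proportion to each stage's "doubt" $1 - c_i$, where $c_i$ is the confidence of stage $i$ in its own output. The function name `attribute_errors` is hypothetical.

```python
def attribute_errors(confidences):
    """Split blame across stages in proportion to 1 - confidence.

    Assumption (not from the text): the less confident a stage was,
    the more likely it is responsible for the rejected final output.
    """
    doubts = [1.0 - c for c in confidences]
    total = sum(doubts)
    if total == 0.0:
        # Every stage was fully confident: fall back to uniform blame.
        return [1.0 / len(confidences)] * len(confidences)
    return [d / total for d in doubts]

# Three stages with confidences 0.9, 0.6 and 0.5 on a rejected example.
e = attribute_errors([0.9, 0.6, 0.5])
print(e)
```

Historical per-stage error rates could replace the confidences here without changing the structure.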

## 6. Learning Objective

Our goal is to minimize the error of the entire pipeline:

$\min_{\theta_1, ..., \theta_N} \mathbb{E}_{x \sim \mathcal{X}}[L(f_N(f_{N-1}(...f_1(x)...)), h(y_N))]$

!this composition is only correct if $f_j$ is replaced with $p_j$ ($j < N$),
and it may be simpler to write it as $(f_N \circ p_{N-1} \circ \ldots \circ p_1)(x)$!

Where $\theta_i$ are the parameters of classifier $f_i$, and $L$ is a suitable loss function (e.g., binary cross-entropy).

## 7. White box classifiers with differentiable functions

### 7.1 Parameter Updates
If we can compute gradients, we update the parameters of each classifier whenever $h(y_N) = 0$:

$\theta_i \leftarrow \theta_i - \alpha \cdot e_i \cdot \nabla_{\theta_i} L_i$

Where $\alpha$ is the learning rate, and $L_i$ is the loss specific to classifier $f_i$.
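
As a minimal sketch of this update rule, assume the parameters and the gradient $\nabla_{\theta_i} L_i$ are available as flat lists of floats. The helper `update_stage` is a hypothetical name; a real implementation would use the framework's optimizer.

```python
def update_stage(theta, grad, e_i, alpha=0.1):
    """One step of theta_i <- theta_i - alpha * e_i * grad_i."""
    return [t - alpha * e_i * g for t, g in zip(theta, grad)]

theta = [1.0, -2.0]
grad = [0.5, -1.0]          # assumed to be the gradient of L_i at theta
new_theta = update_stage(theta, grad, e_i=0.4, alpha=0.1)
print(new_theta)
```

The attribution $e_i$ acts as a per-stage learning-rate multiplier: a stage that is unlikely to be at fault barely moves.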

### 7.2 End-to-End Variant

In this case, we can formulate this as an end-to-end learning problem:

$f: \mathcal{X} \rightarrow \{0, 1\}$

With the learning objective:

$\min_{\theta} \mathbb{E}_{x \sim \mathcal{X}}[L(f(x), h(f(x)))]$

Where $\theta$ are the parameters of the end-to-end model $f$.

## 8. Parameter Updates for Black-box Classifiers

### 8.1 Parameter Updates
When dealing with black-box classifiers, we can't compute gradients directly. Instead, we can use gradient-free optimization methods. Here are three approaches:

### a) Bayesian Optimization

We can model the performance of each classifier $f_i$ as a function of its parameters $\theta_i$:

$p_i(\theta_i) = P(f_i \text{ is correct} | \theta_i)$

We then use Bayesian optimization to find the optimal parameters:

$\theta_i^* = \arg\max_{\theta_i} \mathbb{E}[p_i(\theta_i)]$

This involves:

1.  Building a probabilistic model (e.g., Gaussian Process) of $p_i(\theta_i)$
2.  Using an acquisition function to decide which $\theta_i$ to try next
3.  Updating the model based on observed performance

### b) Evolutionary Strategies

We can use evolutionary algorithms to optimize the parameters:

1.  Initialize a population of parameter sets: $\{\theta_i^1, \theta_i^2, \ldots, \theta_i^M\}$
2.  Evaluate the fitness of each set based on classifier performance
3.  Select the best-performing sets
4.  Generate new parameter sets through crossover and mutation
5.  Repeat steps 2-4 for multiple generations

The fitness function could be:

$F(\theta_i) = \sum_{x \in \mathcal{X}} h(y_N) \cdot (1 - e_i)$

Where $h(y_N)$ is the human feedback and $e_i$ is the error attribution for classifier $i$.
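
The five steps above can be sketched with the standard library only. The `evolve` function, its hyperparameters and the quadratic `toy_fitness` are all illustrative assumptions; in the pipeline setting the fitness would be $F(\theta_i)$ computed from collected feedback.

```python
import random

def evolve(fitness, dim, pop_size=20, elite=5, sigma=0.3, generations=30, seed=0):
    """Elitist evolutionary search over real-valued parameter vectors."""
    rng = random.Random(seed)
    # Step 1: random initial population
    population = [[rng.uniform(-1, 1) for _ in range(dim)] for _ in range(pop_size)]
    for _ in range(generations):
        # Steps 2-3: evaluate and keep the best-performing sets
        population.sort(key=fitness, reverse=True)
        parents = population[:elite]
        # Step 4: uniform crossover of two parents plus Gaussian mutation
        children = []
        while len(parents) + len(children) < pop_size:
            a, b = rng.sample(parents, 2)
            children.append([rng.choice(pair) + rng.gauss(0, sigma) for pair in zip(a, b)])
        population = parents + children
    return max(population, key=fitness)

# Toy stand-in for F(theta_i): peaks at theta = (0.5, 0.5).
def toy_fitness(theta):
    return -sum((t - 0.5) ** 2 for t in theta)

best = evolve(toy_fitness, dim=2)
print(best, toy_fitness(best))
```

Because the parents are carried over unchanged, the best fitness in the population never decreases from one generation to the next.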

### c) Online Learning for Black-box Classifiers

For online learning scenarios, we can use multi-armed bandit algorithms:

1.  Treat each classifier as having multiple "arms" (parameter configurations)
2.  Use algorithms like Upper Confidence Bound (UCB) or Thompson Sampling to balance exploration and exploitation
3.  Update the probability distribution over arms based on observed performance

The reward for each arm could be:

$r(\theta_i) = h(y_N) \cdot (1 - e_i)$
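
A minimal UCB1 sketch follows, assuming each arm is a fixed parameter configuration whose reward can be sampled on demand. The `ucb1` function and the Bernoulli arms are hypothetical stand-ins for $r(\theta_i)$ observed from real feedback.

```python
import math
import random

def ucb1(arm_reward_fns, rounds, seed=0):
    """Play `rounds` pulls with UCB1; return per-arm counts and mean rewards."""
    rng = random.Random(seed)
    n_arms = len(arm_reward_fns)
    counts = [0] * n_arms
    means = [0.0] * n_arms
    for t in range(1, rounds + 1):
        if t <= n_arms:
            arm = t - 1  # pull every arm once before using the bound
        else:
            arm = max(
                range(n_arms),
                key=lambda a: means[a] + math.sqrt(2 * math.log(t) / counts[a]),
            )
        reward = arm_reward_fns[arm](rng)
        counts[arm] += 1
        means[arm] += (reward - means[arm]) / counts[arm]  # running mean
    return counts, means

# Toy arms: Bernoulli rewards with success rates 0.2, 0.5 and 0.8.
arms = [lambda rng, p=p: 1.0 if rng.random() < p else 0.0 for p in (0.2, 0.5, 0.8)]
counts, means = ucb1(arms, rounds=2000)
print(counts, means)
```

The exploration bonus shrinks as an arm is pulled more often, so configurations that keep earning positive feedback end up being selected most of the time.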

### 8.2 End-to-End Black-box Optimization

For the end-to-end variant with a black-box model:

$f: \mathcal{X} \rightarrow \{0, 1\}$

We can use similar techniques (Bayesian Optimization, Evolutionary Strategies, or Bandit Algorithms) to optimize the entire pipeline as a single black-box function.

The objective function becomes:

$J(\theta) = \mathbb{E}_{x \sim \mathcal{X}}[h(f(x; \theta))]$

Where $\theta$ are the parameters of the end-to-end model $f$.

## 9. Pipeline with Fit/Predict Classifiers

*Everything below needs to be revised.*

When classifiers only expose fit and predict functions, we need to adjust our approach. Let's define these functions for each classifier $f_i$:

-   $fit_i(X_i, Y_i) \rightarrow f_i$: Trains the classifier on input data $X_i$ and labels $Y_i$
-   $predict_i(X_i) \rightarrow \hat{Y}_i$: Makes predictions on input data $X_i$

### 9.1 Data Flow

For a pipeline with N classifiers:

!the same inaccuracy here as in the composition above!

1.  $f_1: fit_1(X, Y_1), predict_1(X) \rightarrow \hat{Y}_1$
2.  $f_2: fit_2(\hat{Y}_1, Y_2), predict_2(\hat{Y}_1) \rightarrow \hat{Y}_2$
3.  ...
4.  $f_N: fit_N(\hat{Y}_{N-1}, Y_N), predict_N(\hat{Y}_{N-1}) \rightarrow \hat{Y}_N$

Where $X$ is the original input data, and $Y_i$ are the labels for each classifier.

### 9.2 Training Process

1.  Initial Training:
    -   Train $f_1$ on the original data: $fit_1(X, Y_1)$
    -   Generate predictions: $\hat{Y}_1 = predict_1(X)$
    -   Train $f_2$ on these predictions: $fit_2(\hat{Y}_1, Y_2)$
    -   Continue this process for all classifiers
2.  Feedback Incorporation:
    -   Collect human feedback $h(\hat{Y}_N)$ on the final predictions
    -   Create feedback-adjusted datasets for each classifier

### 9.3 Feedback-Adjusted Training

#### 9.3.1 Only positive examples
For each classifier $f_i$, create a feedback-adjusted dataset:

$X_i' = \{\hat{Y}_{i-1}[j] \mid h(\hat{Y}_N[j]) = 1\}$

$Y_i' = \{Y_i[j] \mid h(\hat{Y}_N[j]) = 1\}$

Then retrain: $fit_i(X_i', Y_i')$
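
In code, building the positive-only dataset is a simple filter. The helper name `positive_only` and the toy data are illustrative.

```python
def positive_only(inputs, labels, feedback):
    """Keep only the examples whose final prediction was confirmed (h = 1)."""
    kept = [(x, y) for x, y, h in zip(inputs, labels, feedback) if h == 1]
    xs = [x for x, _ in kept]
    ys = [y for _, y in kept]
    return xs, ys

prev_outputs = ["a", "b", "c", "d"]   # \hat{Y}_{i-1}
labels = [0, 1, 1, 0]                 # Y_i
feedback = [1, 0, 1, 1]               # h(\hat{Y}_N[j])
X_pos, Y_pos = positive_only(prev_outputs, labels, feedback)
print(X_pos, Y_pos)
```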

#### 9.3.2 Incorporating negative examples

Instead of using only positive examples, we'll use both positive and negative feedback to create a more balanced and informative training set. We'll also introduce a weighting scheme to control the influence of each type of feedback.


For each classifier $f_i$, create a feedback-adjusted dataset:

$X_i' = \{\hat{Y}_{i-1}[j] \mid j \in \text{all examples}\}$

$Y_i' = \{Y_i'[j] \mid j \in \text{all examples}\}$

$W_i' = \{w_i[j] \mid j \in \text{all examples}\}$

Where $Y_i'[j]$ and $w_i[j]$ are defined as:

$Y_i'[j] = \begin{cases} Y_i[j] & \text{if } h(\hat{Y}_N[j]) = 1 \text{ (positive feedback)} \\ \text{uncertain} & \text{if } h(\hat{Y}_N[j]) = 0 \text{ (negative feedback)} \end{cases}$

$w_i[j] = \begin{cases} 1 + \alpha & \text{if } h(\hat{Y}_N[j]) = 1 \text{ (positive feedback)} \\ \beta & \text{if } h(\hat{Y}_N[j]) = 0 \text{ (negative feedback)} \end{cases}$

Here, $\alpha$ and $\beta$ are hyperparameters that control the emphasis on positive and negative feedback respectively, where $0 \leq \alpha, \beta \leq 1$.

#### 9.3.1 Handling Negative Feedback

!something is off with the numbering: is this 9.3.2.1 or 9.3.3?!

For examples with negative feedback, we have several options:

1.  Uncertainty Labeling: Mark these examples as "uncertain" and use a loss function that doesn't penalize predictions for uncertain labels.
2.  Label Flipping: If the task is binary classification, we can flip the label:
$Y_i'[j] = 1 - Y_i[j]$ if $h(\hat{Y}_N[j]) = 0$
3.  Soft Labeling: Assign probabilistic labels based on the classifier's current predictions and the negative feedback:
$Y_i'[j] = (1 - \lambda) \cdot Y_i[j] + \lambda \cdot (1 - Y_i[j])$ if $h(\hat{Y}_N[j]) = 0$,
where $\lambda$ is a hyperparameter controlling the strength of label adjustment.
4.  Exclusion: In some cases, it might be best to exclude negative feedback examples from training certain classifiers if we can't determine the correct label.
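
For the binary case, the four options can be sketched as one function. The name `adjust_label`, the strategy strings and the use of `None` to mark an uncertain or excluded example are conventions assumed for this illustration.

```python
def adjust_label(y, h, strategy="soft", lam=0.3):
    """Return the training label for a binary example.

    y: original label in {0, 1}; h: human feedback in {0, 1}.
    None means "no usable label" (uncertain or excluded).
    """
    if h == 1:
        return float(y)                       # confirmed: keep the label
    if strategy == "flip":
        return 1.0 - y                        # label flipping
    if strategy == "soft":
        return (1 - lam) * y + lam * (1 - y)  # soft labeling
    if strategy in ("uncertain", "exclude"):
        return None
    raise ValueError(f"unknown strategy: {strategy}")

print(adjust_label(1, 0, "flip"), adjust_label(1, 0, "soft", lam=0.25))
```

With $\lambda = 1$ soft labeling reduces to label flipping, and with $\lambda = 0$ it leaves the label unchanged.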

#### 9.3.2 Training Process

The training process now depends on how we handle negative feedback:

1.  For Uncertainty Labeling: Implement a custom loss function that ignores uncertain labels.
2.  For Label Flipping or Soft Labeling: $fit_i(X_i', Y_i', sample\_weight=W_i')$
3.  For Exclusion: $fit_i(X_{i,\text{positive}}', Y_{i,\text{positive}}', sample\_weight=W_{i,\text{positive}}')$

#### 9.3.3 Error Attribution in Multi-Step Pipelines

For multi-step pipelines, we still need to attribute errors to specific classifiers:

$e_i[j] = P(\text{error in } f_i \mid h(\hat{Y}_N[j]) = 0)$

Adjust the weights and handling of negative feedback based on these error attributions:

$w_i[j] = \begin{cases} 1 + \alpha & \text{if } h(\hat{Y}_N[j]) = 1 \\ \beta \cdot e_i[j] & \text{if } h(\hat{Y}_N[j]) = 0 \end{cases}$

This approach concentrates the impact of negative feedback on the classifiers most likely to have caused the error.
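
The weighting rule translates directly into code. The function name `sample_weight` and the default values of `alpha` and `beta` are illustrative; passing `e_ij=1.0` recovers the unattributed scheme from 9.3.2.

```python
def sample_weight(h, e_ij=1.0, alpha=0.5, beta=0.5):
    """Weight of example j for classifier i.

    h: human feedback on the final prediction for example j.
    e_ij: estimated probability that classifier i caused the error.
    """
    return 1.0 + alpha if h == 1 else beta * e_ij

feedback = [1, 0, 0]
attribution = [0.0, 0.8, 0.1]   # e_i[j]; unused when feedback is positive
weights = [sample_weight(h, e) for h, e in zip(feedback, attribution)]
print(weights)
```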

### 9.4 Inferring Positive Examples from Negative Feedback using EM Algorithm

We can use the Expectation-Maximization (EM) algorithm to infer likely positive examples from instances where we received negative feedback. This approach is particularly useful when we have a multi-class problem and negative feedback only tells us that the predicted class was incorrect, but doesn't specify which class would have been correct.

#### 9.4.1 EM Algorithm for Feedback Incorporation

Let's define our problem:

-   $X = \{x_1, \ldots, x_N\}$: Our set of examples
-   $Y = \{y_1, \ldots, y_K\}$: The set of possible classes
-   $F = \{f_1, \ldots, f_N\}$: Feedback for each example (1 for positive, 0 for negative)
-   $\theta$: Parameters of our classifier

The EM algorithm proceeds as follows:

1.  Initialization:
    -   Initialize $\theta$ using the current classifier parameters
2.  E-step: For each example $x_i$ with negative feedback ($f_i = 0$): Calculate $P(y_k \mid x_i, \theta)$ for all classes $k$ except the predicted class.
3.  M-step: Update $\theta$ to maximize the expected log-likelihood: $\theta = \arg\max_\theta \sum_{i: f_i=0} \sum_{k \neq \text{predicted}} P(y_k \mid x_i, \theta) \log P(x_i, y_k \mid \theta)$
4.  Repeat steps 2-3 until convergence or for a fixed number of iterations.
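
A sketch of the E-step for a single rejected example is given below. It assumes that the remaining probabilities are renormalized after the rejected class is removed, which the text does not state explicitly; `e_step` is a hypothetical helper.

```python
def e_step(probs, predicted):
    """Posterior over classes given that `predicted` was rejected.

    probs: current model probabilities P(y_k | x_i, theta) for all K classes.
    Assumption: the rejected class gets probability 0 and the rest
    are renormalized to sum to 1.
    """
    masked = [0.0 if k == predicted else p for k, p in enumerate(probs)]
    total = sum(masked)
    if total == 0.0:
        # The model put all mass on the rejected class: spread it uniformly.
        others = len(probs) - 1
        return [0.0 if k == predicted else 1.0 / others for k in range(len(probs))]
    return [m / total for m in masked]

posterior = e_step([0.6, 0.3, 0.1], predicted=0)
print(posterior)
```

The M-step would then refit the classifier with these posteriors as soft targets.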

#### 9.4.2 Incorporating EM Results into Training

After running the EM algorithm:

1.  For positive feedback examples: Use the original labels and weights as before.
2.  For negative feedback examples:
    -   Use the class probabilities inferred by EM as soft labels.
    -   Adjust weights based on the confidence of these inferred labels.

$X_i' = \{\hat{Y}_{i-1}[j] \mid j \in \text{all examples}\}$

$Y_i'[j] = \begin{cases} Y_i[j] & \text{if } f_j = 1 \\ P(y \mid x_j, \theta) & \text{if } f_j = 0 \end{cases}$

$W_i'[j] = \begin{cases} 1 + \alpha & \text{if } f_j = 1 \\ \beta \cdot \max_k P(y_k \mid x_j, \theta) & \text{if } f_j = 0 \end{cases}$

Where $P(y | x_j, \theta)$ is the probability distribution over classes inferred by the EM algorithm.
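
The two case definitions can be combined into one helper that returns a target distribution and a weight per example. The name `em_targets` and the one-hot encoding of confirmed labels are choices made for this sketch.

```python
def em_targets(label, posterior, feedback, num_classes, alpha=0.5, beta=0.5):
    """Return (target distribution, sample weight) for one example."""
    if feedback == 1:
        # Confirmed example: one-hot target with boosted weight 1 + alpha.
        target = [1.0 if k == label else 0.0 for k in range(num_classes)]
        return target, 1.0 + alpha
    # Rejected example: EM posterior as soft label, weight scaled by its peak.
    return list(posterior), beta * max(posterior)

t_pos, w_pos = em_targets(2, None, 1, num_classes=3)
t_neg, w_neg = em_targets(0, [0.0, 0.75, 0.25], 0, num_classes=3)
print(t_pos, w_pos, t_neg, w_neg)
```

Scaling by $\max_k P(y_k \mid x_j, \theta)$ means a rejected example only carries weight when EM is fairly sure which class should replace the rejected one.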

#### 9.4.3 Training Process

Train the classifier using these EM-inferred soft labels and weights:

$fit_i(X_i', Y_i', sample\_weight=W_i')$

Note: Ensure your classifier can handle soft labels (probability distributions) as targets. If not, you may need to implement a custom loss function.

#### 9.4.4 Iterative Refinement

This process can be done iteratively:

1.  Run the EM algorithm to infer labels for negative feedback examples.
2.  Train the classifier using both positive feedback and EM-inferred examples.
3.  Use the updated classifier to get new predictions.
4.  Repeat the process with new feedback.

### 9.5 Online Learning

For online learning, we can use a sliding window or incremental learning approach:

1.  Maintain a buffer of recent data points and their human feedback
2.  Periodically retrain classifiers on this buffer: $fit_i(X_i^{buffer}, Y_i^{buffer})$
3.  For classifiers that support partial_fit: $partial\_fit_i(X_i^{new}, Y_i^{new})$
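
A sliding-window buffer can be sketched with `collections.deque`. The class `FeedbackBuffer` and its `retrain_every` trigger are hypothetical; the caller is expected to run $fit_i$ on `training_data()` whenever `add` returns `True`.

```python
from collections import deque

class FeedbackBuffer:
    """Keep the most recent (x, y, h) triples and signal when to retrain."""

    def __init__(self, maxlen=1000, retrain_every=100):
        self.items = deque(maxlen=maxlen)  # old entries fall off automatically
        self.retrain_every = retrain_every
        self._since_retrain = 0

    def add(self, x, y, h):
        """Store one example; return True when a retrain is due."""
        self.items.append((x, y, h))
        self._since_retrain += 1
        if self._since_retrain >= self.retrain_every:
            self._since_retrain = 0
            return True
        return False

    def training_data(self):
        """Inputs and labels of buffered examples with positive feedback."""
        xs = [x for x, _, h in self.items if h == 1]
        ys = [y for _, y, h in self.items if h == 1]
        return xs, ys

buf = FeedbackBuffer(maxlen=3, retrain_every=2)
flags = [buf.add(i, i % 2, 1) for i in range(4)]
print(flags, buf.training_data())
```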

### 9.6 Optimization Strategies

1.  Hyperparameter Tuning:
    -   Use techniques like Random Search or Bayesian Optimization to tune hyperparameters of each $fit_i$ function
    -   Objective: Maximize pipeline accuracy based on human feedback
2.  Ensemble Methods:
    -   Train multiple versions of each classifier with different hyperparameters
    -   Combine their predictions (e.g., voting, averaging)
3.  Data Augmentation:
    -   Generate synthetic data points based on confident predictions and human feedback
    -   Use these to augment training data for each classifier

### 9.7 Performance Metric

Define a performance metric based on human feedback:

$Performance = \frac{1}{|X|} \sum_{x \in X} h(predict_N(predict_{N-1}(...predict_1(x)...)))$

Optimize this metric through iterative training and feedback incorporation.
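
The metric is the fraction of inputs whose final prediction a human accepts. In the sketch below, `predict_pipeline` and `human_feedback` are hypothetical callables standing in for the chained $predict_i$ calls and for $h$.

```python
def pipeline_performance(inputs, predict_pipeline, human_feedback):
    """Mean human feedback over the final predictions for all inputs."""
    if not inputs:
        raise ValueError("need at least one input")
    return sum(human_feedback(predict_pipeline(x)) for x in inputs) / len(inputs)

# Toy setup: the pipeline predicts parity, the "human" accepts only "even".
predict_pipeline = lambda x: "even" if x % 2 == 0 else "odd"
human_feedback = lambda y: 1 if y == "even" else 0

score = pipeline_performance([0, 1, 2, 3], predict_pipeline, human_feedback)
print(score)
```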


In [31]:
parameters = dict(
    seed  = 12309,
    batch = 128,
    metrics = ["accuracy",],
    drift = "flip",
    alg = "greedy",
)
# Define each parameter as a notebook-level variable unless it has
# already been set (e.g. injected by an earlier cell).
for p, v in parameters.items():
    globals().setdefault(p, v)


In [32]:
import tensorflow as tf
from tensorflow.keras import datasets, layers, models

tf.random.set_seed(seed)


In [33]:
def printeval(evvals, evkeys=["loss"] + metrics):
    for k, v in zip(evkeys, evvals):
        print(f"{k}:\t{v:.4f}")

def flipaug(img, lbl):
    return tf.image.random_flip_left_right(img), lbl


def drifted(dataset, seed=seed, drift=drift):
    drift = drift.lower()
    if drift == "flip":
        return dataset.map(flipaug)
    else:
        raise ValueError(f"Drift {drift} is unknown")


In [34]:
mnist = datasets.mnist

Xy_train, Xy_val = mnist.load_data()
def mapimg(img, label):
    return tf.image.convert_image_dtype(img, dtype=tf.float32), label

train_ref = (
    tf.data.Dataset.from_tensor_slices(Xy_train)
    .shuffle(Xy_train[0].shape[0])
    .batch(batch)
    .map(mapimg)
)
val_ref = (
    tf.data.Dataset.from_tensor_slices(Xy_val)
    .shuffle(Xy_val[0].shape[0])
    .batch(batch)
    .map(mapimg)
)
train_drifted = drifted(train_ref)
val_drifted = drifted(val_ref)


In [35]:
def greedy_loss_fn(
    model,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    patience=None,
    factoring_layer="factor",
):
    factoring_layer = model.get_layer(factoring_layer)
    def wrapped(trg, logits):
        inferred_lbls, miss_mask, diff_mask = factoring_layer.greedy_lblprop(
            tf.stop_gradient(logits), trg, patience=patience,
        )
        match_mask = ~miss_mask
        return loss(
            tf.boolean_mask(inferred_lbls, match_mask),
            tf.boolean_mask(logits, match_mask),
        )
    return wrapped

class Factor(layers.Layer):
    def __init__(self, base=2, patience=1, name=None):
        super().__init__(name=name)
        self.base = tf.constant(base, dtype=tf.int32, shape=())
        self.patience = patience

    def call(self, logits):
        return self.factor(tf.argmax(logits, axis=-1))

    def factor(self, labels):
        return tf.cast(tf.keras.ops.mod(labels, self.base), tf.int32)

    def refactor(self, logits, miss_mask, ninf=-1.0e+06):
        am = tf.argmax(logits, axis=-1, output_type=tf.int32)
        mask = am[:, tf.newaxis] == tf.range(logits.shape[-1], dtype=tf.int32)[tf.newaxis, :]
        mask &= miss_mask[:, tf.newaxis]
        logits = tf.where(mask, ninf, logits)
        return self(logits), logits

    def greedy_lblprop(self, logits, trg, patience=None):
        # this line is entirely to satisfy tf static checker
        trg = tf.cast(trg, tf.int32)
        if patience is None:
            patience = self.patience
        assert patience >= 0, f"for greedy label propagation specify patience >= 0, not {patience}"
        guess = self(logits)
        first_miss_mask = guess != trg
        miss_mask = first_miss_mask
        # this implementation is suboptimal
        # because it recomputes good values
        for d in range(patience):
            guess, logits = self.refactor(logits, miss_mask)
            miss_mask = guess != trg
        diff_mask = first_miss_mask ^ miss_mask
        # we could smooth over miss_mask, not simply argmax
        return tf.argmax(logits, axis=-1, output_type=tf.int32), miss_mask, diff_mask

# input
inp = tf.keras.Input((28, 28, 1), name="img")
# cnn
cnn = tf.keras.Sequential([
    layers.InputLayer((28, 28, 1), name="inp"),
    layers.Conv2D(2, (5, 5), activation='relu', name="conv1"),
    layers.MaxPooling2D((2, 2), name="pool1"),
    layers.Conv2D(4, (5, 5), activation='relu', name="conv2"),
    layers.MaxPooling2D((2, 2), name="pool2"),
    layers.Conv2D(10, (4, 4), activation='relu', name="conv3"),
    layers.Flatten(name="flatten"),
    layers.Dense(10, name="logit"),
], name="logit")
# logit = layers.Conv2D(2, (5, 5), activation='relu', name="conv1")(inp)
# logit = layers.MaxPooling2D((2, 2), name="pool1")(logit)
# logit = layers.Conv2D(4, (5, 5), activation='relu', name="conv2")(logit)
# logit = layers.MaxPooling2D((2, 2), name="pool2")(logit)
# logit = layers.Conv2D(10, (4, 4), activation='relu', name="conv3")(logit)
# logit = layers.Flatten(name="flatten")(logit)
# outputs
logit = cnn(inp)
factor = Factor(base=2, name="factor")(logit)

pp = tf.keras.Model(inp, [factor, logit], name="Pipeline")
pp.summary()



In [36]:
def factor_labels(model, factoring_layer="factor"):
    factoring_layer = model.get_layer(factoring_layer)
    # @tf.function
    def wrapped(img, lbl):
        return img, factoring_layer.factor(lbl)
    return wrapped

train_drifted_factored = train_drifted.map(factor_labels(pp))
val_drifted_factored = val_drifted.map(factor_labels(pp))


In [37]:
pp.get_layer("logit").compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=metrics,
)
pp.get_layer("logit").fit(
    train_ref, epochs=1,
    validation_data=val_ref,
)
printeval(pp.get_layer("logit").evaluate(val_ref))


469/469 ━━━━━━━━━━━━━━━━━━━━ 30s 59ms/step - accuracy: 0.4775 - loss: 1.5542 - val_accuracy: 0.8780 - val_loss: 0.3982
79/79 ━━━━━━━━━━━━━━━━━━━━ 2s 21ms/step - accuracy: 0.8742 - loss: 0.4048
loss:	0.3982
accuracy:	0.8780


In [38]:
printeval(pp.get_layer("logit").evaluate(val_drifted))

79/79 ━━━━━━━━━━━━━━━━━━━━ 1s 13ms/step - accuracy: 0.5945 - loss: 2.0370
loss:	2.1133
accuracy:	0.5846


In [None]:
def pipeline_train_loop(
    pipeline,
    train_dataset,
    optimizer,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics="accuracy",
    epochs=1,
    val_dataset=None,
    early_stopping_rounds=None,
    use_tqdm=True,
):
    # use tf.keras.metrics.get("Accuracy")
    pass

def pipeline_eval_loop(
    pipeline,
    val_dataset,
    loss,
    metrics,
    use_tqdm=True,
):
    pass


In [39]:
pp.compile(
#     optimizer='adam',
#     loss=greedy_loss_fn(pp),
    metrics=metrics,
)

# pp.fit(train_drifted_factored, epochs=1, validation_data=val_drifted_factored)
from tqdm.auto import tqdm

optimizer = tf.keras.optimizers.Adam()
epochs = 1
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))

    # Iterate over the batches of the dataset.
    for step, (X_batch, y_batch) in enumerate(tqdm(train_drifted_factored)):
        with tf.GradientTape() as tape:
            factors, logits = pp(X_batch, training=True)

            loss_value = greedy_loss_fn(pp)(y_batch, logits)
            for metric in pp.metrics:
                if metric.name != "loss":
                    metric.update_state(y_batch, factors)


        grads = tape.gradient(loss_value, pp.trainable_weights)

        optimizer.apply_gradients(zip(grads, pp.trainable_weights))

        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
            print("Seen so far: %s samples" % ((step + 1) * batch))
            for metric in pp.metrics:
                if metric.name != "loss":
                    m = metric.result()
                    for mk, mv in m.items():
                        print(f"{mk}: {mv:.4f}")



Start of epoch 0


  0%|          | 0/469 [00:00<?, ?it/s]

Training loss (for one batch) at step 0: 0.2363
Seen so far: 128 samples
accuracy: 0.5829
factor_accuracy: 1.0000
Training loss (for one batch) at step 200: 0.3857
Seen so far: 25728 samples
accuracy: 0.4407
factor_accuracy: 0.6866
Training loss (for one batch) at step 400: 0.3750
Seen so far: 51328 samples
accuracy: 0.3951
factor_accuracy: 0.7232


In [77]:
# logits = tf.convert_to_tensor([
#     list(range(10)),  # odd, even
#     [-i for i in range(10)],  # even, odd
#     [0, 100, 0, 99, 0, 98, 0, 97, 0, 96],  # odd, odd
#     [100, 0, 99, 0, 98, 0, 97, 0, 96, 0],  # even, even
# ], dtype=tf.float32)
# # test 1
# expected_new_labels = tf.convert_to_tensor([8, 0, 3, 0], dtype=tf.int32)
# expected_miss_mask = tf.convert_to_tensor([False, False, True, False])
# expected_diff_mask = tf.convert_to_tensor([True, False, False, False])
# actual_new_labels, actual_miss_mask, actual_diff_mask =greedy_lblprop(
#     pp, logits, tf.convert_to_tensor([0, 0, 0, 0], dtype=tf.int32), _logits=True,
# )
# assert all(expected_new_labels == actual_new_labels)
# assert all(expected_miss_mask == actual_miss_mask)
# assert all(expected_diff_mask == actual_diff_mask)
# # test 2
# expected_new_labels = tf.convert_to_tensor([9, 1, 1, 2], dtype=tf.int32)
# expected_miss_mask = tf.convert_to_tensor([False, False, False, True])
# expected_diff_mask = tf.convert_to_tensor([False, True, False, False])
# actual_new_labels, actual_miss_mask, actual_diff_mask =greedy_lblprop(
#     pp, logits, tf.convert_to_tensor([1, 1, 1, 1], dtype=tf.int32), _logits=True,
# )
# assert all(expected_new_labels == actual_new_labels)
# assert all(expected_miss_mask == actual_miss_mask)
# assert all(expected_diff_mask == actual_diff_mask)
