<a href="https://colab.research.google.com/github/archqua/pipeline_training/blob/master/mnist_experiment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Default experiment parameters. A name that already exists in the notebook
# namespace when this cell runs takes precedence over the default given here
# (presumably so that a runner can inject overrides beforehand).
parameters = dict(
    seed=12309,
    batch=32,
    metrics=["accuracy"],
    drift="flip",
    alg="greedy",
)
# Publish every parameter that is not defined yet as a module-level variable.
# The value is bound directly rather than exec-ing its textual form, so
# strings containing quotes or backslashes survive intact and values whose
# repr is not valid Python source are supported as well.
for p, v in parameters.items():
    globals().setdefault(p, v)


In [None]:
import tensorflow as tf
from tensorflow.keras import datasets, layers, models

# Seed TensorFlow's global random state with the configured `seed` parameter.
tf.random.set_seed(seed)


In [None]:
def rotaug(img, lbl):
    """Randomly mirror ``img`` left-to-right and pass ``lbl`` through untouched.

    NOTE(review): despite the name, no rotation is performed here. The
    datasets in this notebook are batched before they are mapped, so ``img``
    presumably arrives as a rank-3 (batch, height, width) tensor; confirm
    which axis ``tf.image.random_flip_left_right`` flips for such input.
    """
    flipped = tf.image.random_flip_left_right(img)
    return flipped, lbl


def drifted(dataset, seed=seed, drift=drift):
    """Return ``dataset`` with the requested drift transformation mapped over it.

    Args:
        dataset: object exposing ``map``; every element is passed to ``rotaug``.
        seed: not read by the body at present.
        drift: drift name, compared case-insensitively. Only ``"flip"`` is
            implemented.

    Raises:
        ValueError: if ``drift`` is not a known drift name.

    Both defaults are taken from the module-level ``seed`` and ``drift``
    values at the moment this function is defined.
    """
    drift = drift.lower()
    if drift != "flip":
        raise ValueError(f"Drift {drift} is unknown")
    return dataset.map(rotaug)


In [None]:
class Pipeline(tf.keras.Model):
    """Small CNN classifier followed by a "factor" step (class index mod ``base``).

    The wrapped ``cnn`` ends in a ``Dense(10)`` layer without activation, so
    the model outputs unnormalised scores (logits) for 10 classes.

    Args:
        base: modulus applied to the predicted class index in ``factor``.
    """

    def __init__(self, base=2):
        super().__init__()
        # The input shape is declared once through an explicit Input object.
        # Per-layer `input_shape=` hints on later layers are ignored by
        # Sequential, trigger Keras warnings, and the previous values did not
        # match the real intermediate shapes anyway.
        self.cnn = models.Sequential([
            layers.Input(shape=(28, 28, 1)),
            layers.Conv2D(2, (5, 5), activation='relu', name="conv1"),   # -> 24x24x2
            layers.MaxPooling2D((2, 2), name="pool1"),                   # -> 12x12x2
            layers.Conv2D(4, (5, 5), activation='relu', name="conv2"),   # -> 8x8x4
            layers.MaxPooling2D((2, 2), name="pool2"),                   # -> 4x4x4
            layers.Conv2D(10, (4, 4), activation='relu', name="conv3"),  # -> 1x1x10
            layers.Flatten(name="flatten"),                              # -> 10
            layers.Dense(10, name="fc"),
        ], name="CNN")
        self.base = base

    def call(self, *args, **kwargs):
        """Forward all arguments to the wrapped CNN and return its logits."""
        return self.cnn(*args, **kwargs)

    def factor(self, logits):
        """Return the argmax class index (int32) of each row modulo ``self.base``."""
        return tf.argmax(logits, axis=-1, output_type=tf.int32) % self.base

    def refactor(self, logits, miss_mask, ninf=-999999):
        """Suppress the current top class of the flagged rows and re-factor.

        Args:
            logits: class scores; indexed as (rows, classes) with a statically
                known last dimension.
            miss_mask: boolean vector with one entry per row; only rows where
                it is True are modified.
            ninf: value written over the suppressed logit so that it can no
                longer win the argmax.

        Returns:
            A pair ``(factors, new_logits)``: the factor of the updated
            logits and the updated logits themselves.
        """
        am = tf.argmax(logits, axis=-1, output_type=tf.int32)
        # One-hot style boolean mask marking the argmax position of every row.
        mask = am[:, tf.newaxis] == tf.range(logits.shape[-1], dtype=tf.int32)[tf.newaxis, :]
        # Keep the mark only for rows that are flagged as missed.
        mask &= miss_mask[:, tf.newaxis]
        logits = tf.where(mask, ninf, logits)
        return self.factor(logits), logits

# Instantiate the pipeline with the default base and print its layer summary.
pp = Pipeline()
pp.summary()

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


In [None]:
mnist = datasets.mnist

# Unpack the two splits returned by load_data(); each one is indexed below as
# an (images, labels) pair.
Xy_train, Xy_val = mnist.load_data()
def mapimg(img, label):
    """Convert ``img`` to float32 with ``tf.image.convert_image_dtype``; keep ``label``."""
    as_float = tf.image.convert_image_dtype(img, dtype=tf.float32)
    return as_float, label

def _batched_reference(xy):
    """Turn an (images, labels) pair into a shuffled, batched, float dataset.

    The shuffle buffer is as large as the number of samples, batches hold
    ``batch`` elements, and ``mapimg`` is applied to each batch.
    """
    sample_count = xy[0].shape[0]
    ds = tf.data.Dataset.from_tensor_slices(xy)
    ds = ds.shuffle(sample_count)
    ds = ds.batch(batch)
    return ds.map(mapimg)


# The training split is built first and the validation split second.
train_ref = _batched_reference(Xy_train)
val_ref = _batched_reference(Xy_val)

In [None]:
# Configure training: Adam optimizer, sparse categorical cross-entropy that
# treats the model outputs as logits, and the metrics from the parameters.
pp.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=metrics,
)

# Train for a single epoch on the reference training data, validating on the
# reference validation data; the returned history object is kept in `history`.
history = pp.fit(
    train_ref, epochs=1,
    validation_data=val_ref,
)


1875/1875 ━━━━━━━━━━━━━━━━━━━━ 35s 17ms/step - accuracy: 0.6873 - loss: 0.9363 - val_accuracy: 0.9250 - val_loss: 0.2463


In [None]:
# Labels for the values returned by evaluate(); assumes the loss comes first,
# followed by the compiled metrics in order — TODO confirm against Keras docs.
evkeys = ["loss"] + metrics
# Evaluate on the undrifted validation data and print each value to 2 decimals.
evvals = pp.evaluate(val_ref)
for k, v in zip(evkeys, evvals):
    print(f"{k}:\t{v:.2f}")

313/313 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9230 - loss: 0.2445
loss:	0.25
accuracy:	0.93


In [None]:
# Drifted counterparts of the reference datasets, using the default drift.
train_drifted = drifted(train_ref)
val_drifted = drifted(val_ref)


In [None]:
# Same evaluation as for the reference data, now on the drifted validation
# data; relies on `evkeys` defined in an earlier cell.
evvals = pp.evaluate(val_drifted)
for k, v in zip(evkeys, evvals):
    print(f"{k}:\t{v:.2f}")

313/313 ━━━━━━━━━━━━━━━━━━━━ 3s 9ms/step - accuracy: 0.6250 - loss: 2.2281
loss:	2.21
accuracy:	0.63


In [None]:
def greedy_lblprop(model, X, trg, depth=2, _logits=False):
    """Greedily pick, per row, a class whose factor agrees with ``trg``.

    Starting from the model's logits, the top class of every row whose factor
    differs from ``trg`` is suppressed via ``model.refactor`` and the row is
    re-checked; this is repeated ``depth - 1`` times.

    Args:
        model: object providing ``factor`` and ``refactor`` and, unless
            ``_logits`` is set, callable on ``X``.
        X: model input, or ready-made logits when ``_logits`` is True.
        trg: target factor for every row.
        depth: number of candidate classes considered per row; must exceed 1.
        _logits: treat ``X`` as logits and skip the forward pass.

    Returns:
        ``(labels, miss_mask, diff_mask)``: the int32 argmax of the final
        logits, the rows that still miss the target afterwards, and the rows
        whose miss status differs between the first and the last check.
    """
    assert depth > 1, f"for greedy label propagation specify depth > 1, not {depth}"
    logits = X if _logits else model(X)
    initial_miss = model.factor(logits) != trg
    miss_mask = initial_miss
    # Rows that already hit the target are re-evaluated on every pass even
    # though their logits stay untouched, which is wasteful but simple.
    for _ in range(depth - 1):
        guess, logits = model.refactor(logits, miss_mask)
        miss_mask = guess != trg
    # A smoothed label over the remaining candidates would be an alternative
    # to this plain argmax.
    new_labels = tf.argmax(logits, axis=-1, output_type=tf.int32)
    return new_labels, miss_mask, initial_miss ^ miss_mask


In [None]:
# Hand-made logits (4 rows x 10 classes) for checking greedy_lblprop on `pp`.
# The trailing comment on each row gives the parity of its best and of its
# second-best class index.
logits = tf.convert_to_tensor([
    list(range(10)),  # odd, even
    [-i for i in range(10)],  # even, odd
    [0, 100, 0, 99, 0, 98, 0, 97, 0, 96],  # odd, odd
    [100, 0, 99, 0, 98, 0, 97, 0, 96, 0],  # even, even
], dtype=tf.float32)
# test 1: every row targets factor 0; only the "odd, odd" row is expected to
# still miss, and only the "odd, even" row is expected to change status.
expected_new_labels = tf.convert_to_tensor([8, 0, 3, 0], dtype=tf.int32)
expected_miss_mask = tf.convert_to_tensor([False, False, True, False])
expected_diff_mask = tf.convert_to_tensor([True, False, False, False])
actual_new_labels, actual_miss_mask, actual_diff_mask =greedy_lblprop(
    pp, logits, tf.convert_to_tensor([0, 0, 0, 0], dtype=tf.int32), _logits=True,
)
assert all(expected_new_labels == actual_new_labels)
assert all(expected_miss_mask == actual_miss_mask)
assert all(expected_diff_mask == actual_diff_mask)
# test 2: every row targets factor 1; only the "even, even" row is expected
# to still miss, and only the "even, odd" row is expected to change status.
expected_new_labels = tf.convert_to_tensor([9, 1, 1, 2], dtype=tf.int32)
expected_miss_mask = tf.convert_to_tensor([False, False, False, True])
expected_diff_mask = tf.convert_to_tensor([False, True, False, False])
actual_new_labels, actual_miss_mask, actual_diff_mask =greedy_lblprop(
    pp, logits, tf.convert_to_tensor([1, 1, 1, 1], dtype=tf.int32), _logits=True,
)
assert all(expected_new_labels == actual_new_labels)
assert all(expected_miss_mask == actual_miss_mask)
assert all(expected_diff_mask == actual_diff_mask)
