<a href="https://colab.research.google.com/github/archqua/pipeline_training/blob/master/mnist_experiment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Default experiment parameters. Each entry is published below as a
# notebook-level variable of the same name.
parameters = dict(
    seed=12309,
    batch=128,
    metrics=["accuracy"],
    drift="flip",
    alg="greedy",
)
# Only fill in names that are not already defined, so a value set before this
# cell runs (presumably by a parameterising runner — confirm) takes precedence.
# Binding through the namespace dict instead of exec() keeps arbitrary values
# intact: exec-built source broke on strings containing quotes or backslashes
# and on any value whose repr is not valid Python source.
# Note: the published variables reference the same objects held in `parameters`.
for p, v in parameters.items():
    globals().setdefault(p, v)


In [2]:
import tensorflow as tf
from tensorflow.keras import datasets, layers, models

# Seed TensorFlow's global random state with the `seed` experiment parameter
# defined in the parameter cell above.
tf.random.set_seed(seed)


In [3]:
def printeval(evvals, evkeys=["loss"] + metrics):
    """Print evaluation results as ``name:<TAB>value`` lines, 4 decimals each.

    Args:
        evvals: iterable of numeric results (e.g. what ``evaluate`` returned).
        evkeys: names paired positionally with ``evvals``; the default is
            ``"loss"`` followed by the notebook-level ``metrics`` list as it
            was when this function was defined.

    Pairing stops at the shorter of the two iterables.
    """
    for pair in zip(evkeys, evvals):
        print("{}:\t{:.4f}".format(*pair))

def flipaug(img, lbl):
    """Apply a random left-right flip to ``img`` and pass ``lbl`` through.

    NOTE(review): in this notebook the function is mapped over datasets that
    were batched beforehand, so ``img`` is a whole batch here — confirm that
    ``tf.image.random_flip_left_right`` interprets the axes as intended.
    """
    flipped = tf.image.random_flip_left_right(img)
    return flipped, lbl


def drifted(dataset, seed=seed, drift=drift):
    """Return ``dataset`` with the requested drift transformation mapped over it.

    Args:
        dataset: object exposing ``map`` (a ``tf.data.Dataset`` in this notebook).
        seed: accepted for interface purposes; not used by the current body.
        drift: drift kind, matched case-insensitively. Only ``"flip"`` is known.

    Raises:
        ValueError: if ``drift`` is not a known kind.
    """
    kind = drift.lower()
    if kind != "flip":
        raise ValueError(f"Drift {kind} is unknown")
    return dataset.map(flipaug)


In [4]:
# Handle to the MNIST loader shipped with Keras.
mnist = datasets.mnist

# The loaded data is unpacked into a training split and a validation split.
# Each split is indexed below as a sequence whose element 0 is the image array.
Xy_train, Xy_val = mnist.load_data()
def mapimg(img, label):
    """Convert ``img`` to float32 image dtype; ``label`` is returned unchanged.

    NOTE(review): ``tf.image.convert_image_dtype`` is documented to rescale
    integer inputs when converting to float — verify the resulting value range
    against the TensorFlow docs.
    """
    as_float = tf.image.convert_image_dtype(img, dtype=tf.float32)
    return as_float, label

def _reference_dataset(arrays):
    """Build the shuffled, batched, float-converted dataset for one data split.

    ``arrays`` is sliced along its first axis; the shuffle buffer is sized to
    the number of examples in ``arrays[0]`` so the whole split is shuffled.
    """
    n_examples = arrays[0].shape[0]
    ds = tf.data.Dataset.from_tensor_slices(arrays)
    ds = ds.shuffle(n_examples)
    ds = ds.batch(batch)
    return ds.map(mapimg)


# Reference (undrifted) splits.
train_ref = _reference_dataset(Xy_train)
val_ref = _reference_dataset(Xy_val)

# Drifted counterparts of the same splits.
train_drifted = drifted(train_ref)
val_drifted = drifted(val_ref)


In [5]:
def greedy_loss_fn(
    model,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    patience=None,
    factoring_layer="factor",
):
    """Build a loss callable that trains logits against greedily inferred labels.

    Args:
        model: model exposing ``get_layer``; the layer named by
            ``factoring_layer`` must provide ``greedy_lblprop``.
        loss: callable applied as ``loss(labels, logits)`` to the kept rows.
        patience: forwarded to ``greedy_lblprop``.
        factoring_layer: name of the layer to look up on ``model``.

    Returns:
        ``wrapped(trg, logits)``: runs label propagation on gradient-stopped
        logits, drops the rows still flagged as misses, and returns ``loss``
        evaluated on the remaining inferred labels and logits.
    """
    propagator = model.get_layer(factoring_layer)

    def wrapped(trg, logits):
        # Label inference must not contribute gradients of its own.
        frozen_logits = tf.stop_gradient(logits)
        inferred_lbls, miss_mask, _diff_mask = propagator.greedy_lblprop(
            frozen_logits, trg, patience=patience,
        )
        # Keep only rows whose propagated guess agrees with the target.
        keep = ~miss_mask
        kept_labels = tf.boolean_mask(inferred_lbls, keep)
        kept_logits = tf.boolean_mask(logits, keep)
        return loss(kept_labels, kept_logits)

    return wrapped

class Factor(layers.Layer):
    """Layer mapping class logits to the predicted class index modulo ``base``.

    It also implements greedy label propagation: given a target factor, it
    searches for a class whose factor matches by repeatedly discarding the
    currently top-scoring class of every mismatching row.
    """

    def __init__(self, base=2, patience=1, name=None):
        """Store the modulus ``base`` (as a scalar int32 constant) and the
        default number of propagation rounds ``patience``."""
        super().__init__(name=name)
        self.base = tf.constant(base, dtype=tf.int32, shape=())
        self.patience = patience

    def call(self, logits):
        """Return the factor of the arg-max class along the last axis."""
        predicted = tf.argmax(logits, axis=-1)
        return self.factor(predicted)

    def factor(self, labels):
        """Return ``labels mod base`` cast to int32."""
        residue = tf.keras.ops.mod(labels, self.base)
        return tf.cast(residue, tf.int32)

    def refactor(self, logits, miss_mask, ninf=-1.0e+06):
        """Knock out the top class of every row flagged in ``miss_mask``.

        The top-scoring entry of each flagged row is overwritten with ``ninf``.

        Returns:
            ``(factors, logits)``: factors recomputed from the modified
            logits, and the modified logits themselves.
        """
        top_class = tf.argmax(logits, axis=-1, output_type=tf.int32)
        class_ids = tf.range(logits.shape[-1], dtype=tf.int32)
        # One-hot style mask marking each row's current arg-max position.
        is_top = top_class[:, tf.newaxis] == class_ids[tf.newaxis, :]
        # Restrict the knock-out to rows that are currently misses.
        suppress = is_top & miss_mask[:, tf.newaxis]
        demoted = tf.where(suppress, ninf, logits)
        return self(demoted), demoted

    def greedy_lblprop(self, logits, trg, patience=None):
        """Greedily infer class labels whose factor agrees with ``trg``.

        Args:
            logits: class scores; the last axis indexes classes.
            trg: target factors, cast to int32 here.
            patience: number of knock-out rounds; defaults to ``self.patience``.

        Returns:
            ``(labels, miss_mask, diff_mask)``: the int32 arg-max of the final
            logits, the mask of rows still mismatching after the last round,
            and the XOR of the initial and final miss masks.
        """
        # The cast exists purely to keep TensorFlow's static checks happy.
        trg = tf.cast(trg, tf.int32)
        if patience is None:
            patience = self.patience
        assert patience >= 0, f"for greedy label propagation specify patience >= 0, not {patience}"
        initial_miss = self(logits) != trg
        current_miss = initial_miss
        # Suboptimal: every round recomputes rows that already match.
        for _ in range(patience):
            guess, logits = self.refactor(logits, current_miss)
            current_miss = guess != trg
        changed = initial_miss ^ current_miss
        # Possible refinement: smooth over the miss mask instead of a plain argmax.
        labels = tf.argmax(logits, axis=-1, output_type=tf.int32)
        return labels, current_miss, changed

# Pipeline input: 28x28 single-channel images.
inp = tf.keras.Input(shape=(28, 28, 1), name="img")

# Small convolutional classifier producing 10 logits. Note that both this
# Sequential model and its final Dense layer are named "logit".
cnn = tf.keras.Sequential(
    [
        layers.InputLayer((28, 28, 1), name="inp"),
        layers.Conv2D(filters=2, kernel_size=(5, 5), activation="relu", name="conv1"),
        layers.MaxPooling2D(pool_size=(2, 2), name="pool1"),
        layers.Conv2D(filters=4, kernel_size=(5, 5), activation="relu", name="conv2"),
        layers.MaxPooling2D(pool_size=(2, 2), name="pool2"),
        layers.Conv2D(filters=10, kernel_size=(4, 4), activation="relu", name="conv3"),
        layers.Flatten(name="flatten"),
        layers.Dense(units=10, name="logit"),
    ],
    name="logit",
)

# Two outputs: the raw class logits and their factor (predicted class mod 2).
logit = cnn(inp)
factor = Factor(base=2, name="factor")(logit)

# The pipeline returns [factor, logit], in that order.
pp = tf.keras.Model(inputs=inp, outputs=[factor, logit], name="Pipeline")
pp.summary()



In [6]:
def factor_labels(model, factoring_layer="factor", duplicate=True):
    """Build a dataset mapper that replaces labels by their factor.

    Args:
        model: object exposing ``get_layer``; the layer named by
            ``factoring_layer`` must provide ``factor``.
        factoring_layer: name of the layer to look up on ``model``.
        duplicate: when true the original label is kept as a third element.

    Returns:
        ``wrapped(img, lbl)`` yielding ``(img, factor(lbl), lbl)`` when
        ``duplicate`` is true, otherwise ``(img, factor(lbl))``.
    """
    layer = model.get_layer(factoring_layer)

    # (left undecorated; a tf.function decorator was previously tried here)
    def wrapped(img, lbl):
        factored = layer.factor(lbl)
        return (img, factored, lbl) if duplicate else (img, factored)

    return wrapped

# Attach factor targets to the drifted splits. With the default
# duplicate=True each element becomes (img, factor(lbl), lbl), which is the
# triple the training loop unpacks.
train_drifted_factored = train_drifted.map(factor_labels(pp))
val_drifted_factored = val_drifted.map(factor_labels(pp))


In [7]:
# Pre-train only the "logit" sub-model on the reference (undrifted) data,
# then report its scores on the reference validation split.
logit_model = pp.get_layer("logit")
logit_model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=metrics,
)
logit_model.fit(train_ref, epochs=1, validation_data=val_ref)
printeval(logit_model.evaluate(val_ref))


[1m469/469[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m31s[0m 57ms/step - accuracy: 0.4529 - loss: 1.5859 - val_accuracy: 0.8418 - val_loss: 0.4952
[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 15ms/step - accuracy: 0.8416 - loss: 0.4910
loss:	0.4952
accuracy:	0.8418


In [8]:
# Score the same pre-trained "logit" sub-model on the drifted validation split
# to compare against the reference scores printed by the previous cell.
printeval(pp.get_layer("logit").evaluate(val_drifted))

[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 15ms/step - accuracy: 0.5749 - loss: 2.2798
loss:	2.3525
accuracy:	0.5637


In [9]:
from types import NoneType
import ipywidgets as widgets
from IPython.display import display
import time

def pipeline_train_loop(
    pipeline,
    train_dataset,
    optimizer,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=metrics,
    epochs=1,
    val_dataset=None,
    early_stopping_rounds=None,
    use_tqdm=True,
):
    """Train ``pipeline`` with the greedy label-propagation loss.

    Progress is shown with ipywidgets: one overall bar, plus one bar and one
    live metrics line per epoch.

    Args:
        pipeline: model called as ``pipeline(img, training=...)`` and returning
            ``(factors, logits)``.
        train_dataset: iterable of ``(img, trg, lbl)`` triples, where ``trg``
            is the target factor and ``lbl`` the class label.
        optimizer: object exposing ``apply_gradients``.
        loss: base loss handed to ``greedy_loss_fn``.
        metrics: metric identifiers; each is capitalised and resolved with
            ``tf.keras.metrics.get``, once for the logits and once for the
            factors.
        epochs: number of passes over ``train_dataset``.
        val_dataset: optional iterable of the same triples, evaluated after
            every epoch and reported with ``print``.
        early_stopping_rounds: accepted but currently unused.
        use_tqdm: accepted but currently unused.
    """
    Metrics = {"logit": [], "factor": []}
    for m in metrics:
        for k in Metrics:
            Metrics[k].append(tf.keras.metrics.get(m.capitalize()))

    def reset_metrics():
        # Clear every tracked metric before a new train / validation pass.
        for tracked in Metrics.values():
            for M in tracked:
                M.reset_state()

    def update_metrics(lbl, trg, factors, logits):
        # Logit metrics compare class labels; factor metrics compare factors.
        predicted = tf.argmax(logits, axis=-1)
        for M in Metrics["logit"]:
            M.update_state(lbl, predicted)
        for M in Metrics["factor"]:
            M.update_state(trg, factors)

    def smetrics() -> str:
        res = ""
        for k in Metrics:
            for M in Metrics[k]:
                res += f"\t{k}_{M.name}: {M.result():.4f}"
        return res

    # Build the loss closure once and honour the caller-supplied base loss
    # (previously it was rebuilt on every step and `loss` was ignored).
    loss_fn = greedy_loss_fn(pipeline, loss=loss)

    # Datasets of unknown or infinite size do not support len(); in that case
    # the per-epoch bar is created without an explicit maximum.
    try:
        steps_per_epoch = len(train_dataset)
    except TypeError:
        steps_per_epoch = None

    total_progress = widgets.IntProgress(
        value=0, min=0, max=epochs,
        description=f"0/{epochs} after 0.0s",
        style={"description_width": "100px"},
    )
    display(total_progress)
    start_time = time.time()
    passed = 0.0

    def sdesc(ep):
        # Reads `passed` from the enclosing scope at call time.
        return f"{ep}/{epochs} after {passed:.1f}s"

    for e in range(epochs):
        epoch_bar_kwargs = {"value": 0, "min": 0}
        if steps_per_epoch is not None:
            epoch_bar_kwargs["max"] = steps_per_epoch
        epoch_progress = widgets.IntProgress(**epoch_bar_kwargs)
        display(epoch_progress)

        reset_metrics()
        metrics_html = widgets.HTMLMath(
            value=smetrics(),
            placeholder="HTMLM",
            description=f"Epoch {e+1}/{epochs} (train) --",
            style={"description_width": "initial"},
        )
        display(metrics_html)

        for img, trg, lbl in train_dataset:
            # Only the forward pass and the loss need to be recorded.
            with tf.GradientTape() as tape:
                factors, logits = pipeline(img, training=True)
                loss_value = loss_fn(trg, logits)
            update_metrics(lbl, trg, factors, logits)

            grads = tape.gradient(loss_value, pipeline.trainable_weights)
            optimizer.apply_gradients(zip(grads, pipeline.trainable_weights))

            metrics_html.value = smetrics()
            passed = time.time() - start_time
            # `e` epochs are complete while epoch e+1 is in progress.
            total_progress.description = sdesc(e)
            epoch_progress.value += 1

        if val_dataset is not None:
            reset_metrics()
            for img, trg, lbl in val_dataset:
                # Validation is inference: run the model in non-training mode.
                factors, logits = pipeline(img, training=False)
                update_metrics(lbl, trg, factors, logits)
            print(f"Validation: {smetrics()}")

        passed = time.time() - start_time
        # Report the number of epochs actually completed so far.
        total_progress.description = sdesc(e + 1)
        total_progress.value += 1

def pipeline_eval_loop(pipeline, val_dataset, loss, metrics, use_tqdm=True):
    """Placeholder for a standalone evaluation loop.

    Not implemented yet: every argument is ignored and ``None`` is returned.
    """
    return None


In [10]:
# Fine-tune the whole pipeline on the drifted, factor-annotated training split
# for 5 epochs with a fresh Adam optimizer, validating on the drifted split
# after every epoch.
pipeline_train_loop(
    pp,
    train_drifted_factored,
    val_dataset=val_drifted_factored,
    optimizer=tf.keras.optimizers.Adam(),
    epochs=5,
)

IntProgress(value=0, description='0/5 after 0.0s', max=5, style=ProgressStyle(description_width='100px'))

IntProgress(value=0, max=469)

HTMLMath(value='\tlogit_accuracy: 0.0000\tfactor_accuracy: 0.0000', description='Epoch 1/5 (train) --', placeh…

Validation: 	logit_accuracy: 0.5312	factor_accuracy: 0.8219


IntProgress(value=0, max=469)

HTMLMath(value='\tlogit_accuracy: 0.0000\tfactor_accuracy: 0.0000', description='Epoch 2/5 (train) --', placeh…

Validation: 	logit_accuracy: 0.5264	factor_accuracy: 0.8715


IntProgress(value=0, max=469)

HTMLMath(value='\tlogit_accuracy: 0.0000\tfactor_accuracy: 0.0000', description='Epoch 3/5 (train) --', placeh…

Validation: 	logit_accuracy: 0.5150	factor_accuracy: 0.8807


IntProgress(value=0, max=469)

HTMLMath(value='\tlogit_accuracy: 0.0000\tfactor_accuracy: 0.0000', description='Epoch 4/5 (train) --', placeh…

Validation: 	logit_accuracy: 0.5145	factor_accuracy: 0.8860


IntProgress(value=0, max=469)

HTMLMath(value='\tlogit_accuracy: 0.0000\tfactor_accuracy: 0.0000', description='Epoch 5/5 (train) --', placeh…

Validation: 	logit_accuracy: 0.5347	factor_accuracy: 0.9008


In [12]:
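# NOTE: the checks below are disabled. They call a free function
# `greedy_lblprop(pp, logits, trg, _logits=True)`, which is not defined in this
# notebook; the implementation here is the method
# `Factor.greedy_lblprop(logits, trg, patience=None)`. Port the calls before
# re-enabling them.
# logits = tf.convert_to_tensor([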
#     list(range(10)),  # odd, even
#     [-i for i in range(10)],  # even, odd
#     [0, 100, 0, 99, 0, 98, 0, 97, 0, 96],  # odd, odd
#     [100, 0, 99, 0, 98, 0, 97, 0, 96, 0],  # even, even
# ], dtype=tf.float32)
# # test 1
# expected_new_labels = tf.convert_to_tensor([8, 0, 3, 0], dtype=tf.int32)
# expected_miss_mask = tf.convert_to_tensor([False, False, True, False])
# expected_diff_mask = tf.convert_to_tensor([True, False, False, False])
# actual_new_labels, actual_miss_mask, actual_diff_mask =greedy_lblprop(
#     pp, logits, tf.convert_to_tensor([0, 0, 0, 0], dtype=tf.int32), _logits=True,
# )
# assert all(expected_new_labels == actual_new_labels)
# assert all(expected_miss_mask == actual_miss_mask)
# assert all(expected_diff_mask == actual_diff_mask)
# # test 2
# expected_new_labels = tf.convert_to_tensor([9, 1, 1, 2], dtype=tf.int32)
# expected_miss_mask = tf.convert_to_tensor([False, False, False, True])
# expected_diff_mask = tf.convert_to_tensor([False, True, False, False])
# actual_new_labels, actual_miss_mask, actual_diff_mask =greedy_lblprop(
#     pp, logits, tf.convert_to_tensor([1, 1, 1, 1], dtype=tf.int32), _logits=True,
# )
# assert all(expected_new_labels == actual_new_labels)
# assert all(expected_miss_mask == actual_miss_mask)
# assert all(expected_diff_mask == actual_diff_mask)
