In [None]:
# Authenticate this session with Google when running in Colab.
# Outside Colab (e.g. a local CPU run) google.colab is not installed, so the
# import is guarded and authentication is skipped instead of crashing the notebook.
try:
    from google.colab import auth
except ImportError:
    auth = None

if auth is not None:
    auth.authenticate_user()

### Let's import some modules and download data

In [None]:
import itertools
import logging

import tensorflow as tf
import tensorflow_datasets as tfds

# NOTE(review): Dataset, ds_names, create_basic_model and create_eff_net_frozen are used by
# later cells but were never imported; the modules they are imported from here are
# assumed from naming — confirm against the util package.
from util.dataset import Dataset, create_flowers_ds, ds_names
from util.models import create_basic_model, create_eff_net_frozen, create_eff_net_trainable, create_eff_net_pre_trained
from util.utils import prediction_round
from util.adversarial import adversarial_round
from util.pretext import RotationPretextTrainer, JigsawPretextTrainer
from util.tpu import init_tpu, TpuStub
from util.evaluation import eval_round
from util.config import batch_size

# Configure notebook logging. basicConfig does nothing if the root logger already has
# handlers (which may be the case in hosted notebook runtimes such as Colab), so
# force=True (Python 3.8+) replaces them; otherwise the INFO messages below and in
# later cells might never be shown.
logging.basicConfig(
    format='%(funcName)s()->%(levelname)s: %(message)s',
    level=logging.INFO,
    force=True,
)
# Only let TensorFlow's own logger emit ERROR and above.
tf.get_logger().setLevel('ERROR')
log = logging.getLogger(__name__)
log.info(f'Tensorflow version: {tf.__version__}')
log.info(f'Tensorflow datasets version: {tfds.__version__}')

### Configure hardware acceleration

In [None]:
# Choose the device strategy handed to the model/trainer helpers.
# Set USE_TPU = True on a TPU runtime to use init_tpu(); otherwise the TpuStub
# imported from util.tpu is used (presumably a no-op stand-in strategy — confirm).
USE_TPU = False
device_strategy = init_tpu() if USE_TPU else TpuStub()

### Let's create datasets

In [None]:
# Build the dataset wrapper for the first entry of ds_names.
dataset_name = ds_names[0]
ds = Dataset(dataset_name)

### Let's create the model

In [None]:
# Create the model under test (M_UT), with one output class per dataset class.
M_UT = create_basic_model(
    num_classes=ds.num_classes,
    device_strategy=device_strategy,
)

In [None]:
# Print the model summary (assumes M_UT is a Keras model — confirm).
M_UT.summary()

### Verify that nothing is broken during downstream training

In [None]:
log.info('Training %s', M_UT.name)
# Smoke test: run fit with its default settings to confirm downstream training works end to end.
M_UT.fit(
    ds.train,
    validation_data=ds.val,
)

### Test how well the model predicts

In [None]:
result_imgs, result_labels = prediction_round(dataset=ds, model=M_UT)
# NOTE(review): this total assumes every test batch holds exactly batch_size images; a partial
# final batch would make it overstate the image count — confirm how ds.test is batched.
n_test_imgs = len(ds.test) * batch_size
log.info(f'Out of {n_test_imgs} images {len(result_labels)} were correctly classified')

### Now let's try to fool our network

In [None]:
sf = adversarial_round(images=result_imgs, labels=result_labels, model=M_UT)
# NOTE(review): `sf` is assumed to be array-like with len() and .mean() (e.g. a NumPy array),
# holding one value per misclassified image — confirm in util.adversarial.
if len(sf) > 0:
    log.info(f'{len(sf)} out of {len(result_labels)} were misclassified with mean of {sf.mean()}')
else:
    # Don't report the mean of an empty array (NumPy gives NaN plus a RuntimeWarning).
    log.info(f'0 out of {len(result_labels)} were misclassified')

### Train to identify rotation

In [None]:
# Pretext task: train M_UT to identify image rotation, keeping the model the trainer returns.
rotationTrainer = RotationPretextTrainer()
M_UT = rotationTrainer.train_pretrext_task(
    dataset=ds,
    model=M_UT,
    device_strategy=device_strategy,
    epochs=1,
)

### Jigsaw pretext training

In [None]:
# Pretext task: jigsaw training on M_UT. Like the rotation cell, the model returned by
# train_pretrext_task is assigned back to M_UT so a returned model is not discarded.
jigsawTrainer = JigsawPretextTrainer()
M_UT = jigsawTrainer.train_pretrext_task(dataset=ds, model=M_UT, device_strategy=device_strategy, epochs=1)

### Let's evaluate our (hopefully successful) results

In [None]:
# Each entry is the list of pretext trainers passed to a single eval_round call.
# NOTE(review): trainer instances are created once and reused across every epoch
# configuration below; if a trainer keeps state between calls, confirm that is intended.
pretext_tasks = [
    [RotationPretextTrainer()],
    [JigsawPretextTrainer()],
    [RotationPretextTrainer(), JigsawPretextTrainer()],
]

# Epoch budgets tried for each stage. `epochs` holds every (downstream, pretext) pair,
# in the same order as the original hand-written list: (10, 10), (10, 20), ..., (50, 50).
EPOCH_BUDGETS = [10, 20, 30, 50]
epochs = list(itertools.product(EPOCH_BUDGETS, EPOCH_BUDGETS))

# All model constructors; the baseline cell below iterates over this list.
models = [
    create_basic_model,
    create_eff_net_frozen,
    create_eff_net_trainable,
]
# Constructor evaluated with pretext training in this grid (the last entry of `models`).
M = create_eff_net_trainable

log.info(f'Collecting data for {ds.name}')
for downstream_ep, pretext_ep in epochs:
    for pt in pretext_tasks:
        eval_round(model_constr=M, dataset=ds, pretext_trainers=pt, device_strategy=device_strategy,
                   downstream_epochs=downstream_ep, pretext_epochs=pretext_ep)

In [None]:
# Baselines: downstream training only (no pretext trainers, zero pretext epochs),
# run for every model constructor at every downstream epoch budget.
downstream_epochs = [10, 20, 30, 50]
for model_constr in models:
    for n_epochs in downstream_epochs:
        eval_round(
            model_constr=model_constr,
            dataset=ds,
            pretext_trainers=[],
            device_strategy=device_strategy,
            downstream_epochs=n_epochs,
            pretext_epochs=0,
        )