# CIFAR-10 Efficient Training & Distillation Pipeline

Este notebook contém uma pipeline completa e comentada para treinar modelos eficientes no CIFAR-10, com foco em aplicação posterior em sistemas do tipo CAPTCHA. O fluxo inclui:

- carregamento e pré-processamento eficiente com `tf.data`
- data augmentation (random crop, flip, brightness/contrast, cutout, mixup)
- transferência de aprendizado com backbone pré-treinado em ImageNet (EfficientNetB0)
- definição de um *student* leve (MobileNetV2)
- treinamento por destilação (teacher -> student)
- dicas/células para pruning, quantização e exportação para TFLite

OBS: as chamadas de treinamento (`.fit()`) vêm comentadas para que o notebook rode de forma segura (ex.: no Colab); descomente-as para treinar. Ajuste `EPOCHS_TEACHER`, `EPOCHS_STUDENT` e `BATCH_SIZE` conforme seu hardware.

----


In [None]:
import tensorflow as tf
import numpy as np
import os
import time
# Record the runtime so training results can be tied to a specific TF build.
_tf_version = tf.__version__
_eager_enabled = tf.executing_eagerly()
print('TensorFlow version:', _tf_version)
print('Eager execution:', _eager_enabled)

## 1) Pipeline de dados (tf.data) com augmentations
Funções de pré-processamento, random crop, cutout e mixup, além da construção dos datasets.

In [None]:
AUTOTUNE = tf.data.AUTOTUNE

def random_crop_and_resize(image, target_size=96, crop_pad=8):
    """Randomly jitter an image by padding it and cropping back to size.

    The image is first center-padded (or center-cropped) to
    ``target_size + crop_pad`` on each spatial side, then a random
    ``target_size x target_size`` window is cut from it. Despite the name,
    no interpolation/resizing happens here.

    Args:
        image: 3-D image tensor (height, width, 3).
        target_size: Side length of the returned square crop.
        crop_pad: Extra pixels added before cropping; controls jitter range.

    Returns:
        A ``(target_size, target_size, 3)`` tensor.
    """
    padded_side = target_size + crop_pad
    padded = tf.image.resize_with_crop_or_pad(image, padded_side, padded_side)
    return tf.image.random_crop(padded, size=[target_size, target_size, 3])

def preprocess_train(image, label, target_size=96):
    """Turn a raw training example into an augmented float image in [0, 1].

    Steps: scale pixels by 1/255, resize to ``target_size``, apply a random
    pad-and-crop, a random horizontal flip and small brightness/contrast
    jitter, then clip back to [0, 1].

    Args:
        image: Raw image tensor (e.g. uint8 CIFAR-10 pixels).
        label: Label, passed through unchanged.
        target_size: Side length of the square output image.

    Returns:
        ``(image, label)`` with ``image`` of shape ``(target_size, target_size, 3)``.
    """
    image = tf.cast(image, tf.float32) / 255.0
    # Resize to the requested size so training and evaluation
    # (preprocess_eval) produce images of the same resolution.
    image = tf.image.resize(image, [target_size, target_size])
    image = random_crop_and_resize(image, target_size=target_size, crop_pad=8)
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, 0.06)
    image = tf.image.random_contrast(image, 0.95, 1.05)
    # Brightness/contrast jitter can push values outside [0, 1]; clip so the
    # training range matches the [0, 1] range produced by preprocess_eval.
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image, label

def preprocess_eval(image, label, target_size=96):
    """Deterministically prepare an evaluation example (no augmentation).

    Pixels are scaled by 1/255 and the image is resized to a
    ``target_size x target_size`` square.

    Args:
        image: Raw image tensor (e.g. uint8 CIFAR-10 pixels).
        label: Label, passed through unchanged.
        target_size: Side length of the square output image.

    Returns:
        ``(image, label)``.
    """
    scaled = tf.cast(image, tf.float32) / 255.0
    return tf.image.resize(scaled, [target_size, target_size]), label

def cutout(image, size=16):
    """Zero out a random square patch of ``image`` (Cutout augmentation).

    A center ``(y, x)`` is drawn uniformly over the image. The zeroed patch
    spans rows ``[y - size//2, y + size//2)`` and columns
    ``[x - size//2, x + size//2)``, clipped to the image bounds, so patches
    near the border are smaller.

    The patch is applied with a broadcast mask, so any channel count and
    dtype are supported (no hard-coded 3-channel float32 zero block).

    Args:
        image: Image tensor of shape (height, width, channels).
        size: Nominal side length of the square patch.

    Returns:
        Tensor with the same shape and dtype as ``image``.
    """
    shape = tf.shape(image)
    h, w = shape[0], shape[1]
    y = tf.random.uniform([], 0, h, dtype=tf.int32)
    x = tf.random.uniform([], 0, w, dtype=tf.int32)
    half = size // 2
    y1 = tf.clip_by_value(y - half, 0, h)
    y2 = tf.clip_by_value(y + half, 0, h)
    x1 = tf.clip_by_value(x - half, 0, w)
    x2 = tf.clip_by_value(x + half, 0, w)

    rows = tf.range(h)[:, tf.newaxis]
    cols = tf.range(w)[tf.newaxis, :]
    inside = (rows >= y1) & (rows < y2) & (cols >= x1) & (cols < x2)
    # Trailing axis lets the (h, w) mask broadcast over all channels.
    return tf.where(inside[:, :, tf.newaxis], tf.zeros_like(image), image)

def mixup(images, labels, alpha=0.2):
    """Apply MixUp to a batch: blend each example with a shuffled partner.

    A single mixing weight ``lam ~ Beta(alpha, alpha)`` is drawn per call and
    used for both images and labels. The draw uses TensorFlow ops (Beta
    sampled as ``G1 / (G1 + G2)`` with ``G1, G2 ~ Gamma(alpha)``), so a fresh
    value is produced on every call even inside ``tf.data.Dataset.map`` or
    ``tf.function``. A NumPy draw there would be frozen at trace time.

    Args:
        images: Float batch tensor whose first axis is the batch.
        labels: Batch of (one-hot / soft) labels; cast to the images' dtype.
        alpha: Beta concentration; ``alpha <= 0`` disables mixing.

    Returns:
        ``(mixed_images, mixed_labels)``.
    """
    if alpha <= 0:
        return images, labels
    images = tf.convert_to_tensor(images)
    labels = tf.cast(labels, images.dtype)  # mixed labels are necessarily fractional

    g1 = tf.random.gamma([], alpha, dtype=images.dtype)
    g2 = tf.random.gamma([], alpha, dtype=images.dtype)
    # divide_no_nan guards the (vanishingly rare) case where both draws underflow to 0.
    lam = tf.math.divide_no_nan(g1, g1 + g2)

    index = tf.random.shuffle(tf.range(tf.shape(images)[0]))
    mixed_images = lam * images + (1.0 - lam) * tf.gather(images, index)
    mixed_labels = lam * labels + (1.0 - lam) * tf.gather(labels, index)
    return mixed_images, mixed_labels

def build_datasets(batch_size=128, target_size=96, use_mixup=True, mixup_alpha=0.2):
    """Load CIFAR-10 and build batched, prefetched train/validation pipelines.

    Training examples are shuffled, augmented with ``preprocess_train`` and
    ``cutout``, and batched. When ``use_mixup`` is true and
    ``mixup_alpha > 0``, MixUp is then applied per batch. Validation uses
    ``preprocess_eval`` only. Labels are one-hot encoded over 10 classes.

    Args:
        batch_size: Batch size for both datasets.
        target_size: Side length of the square images fed to the models.
        use_mixup: Whether to apply MixUp to training batches.
        mixup_alpha: Beta concentration passed to ``mixup``.

    Returns:
        ``(train_ds, val_ds, (x_test, y_test))``, where the raw test arrays
        keep uint8 images and one-hot labels.
    """
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
    num_classes = 10
    y_train = tf.keras.utils.to_categorical(y_train, num_classes)
    y_test = tf.keras.utils.to_categorical(y_test, num_classes)

    train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    # Buffer spans the whole training set -> a full reshuffle every epoch.
    train_ds = train_ds.shuffle(len(x_train))
    train_ds = train_ds.map(lambda x, y: preprocess_train(x, y, target_size),
                            num_parallel_calls=AUTOTUNE)
    train_ds = train_ds.map(lambda x, y: (cutout(x, size=16), y),
                            num_parallel_calls=AUTOTUNE)
    train_ds = train_ds.batch(batch_size)
    if use_mixup and mixup_alpha > 0:
        # MixUp pairs examples within a batch, so it must run after batching.
        train_ds = train_ds.map(lambda x, y: mixup(x, y, alpha=mixup_alpha),
                                num_parallel_calls=AUTOTUNE)
    train_ds = train_ds.prefetch(AUTOTUNE)

    val_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    val_ds = val_ds.map(lambda x, y: preprocess_eval(x, y, target_size),
                        num_parallel_calls=AUTOTUNE)
    val_ds = val_ds.batch(batch_size).prefetch(AUTOTUNE)

    return train_ds, val_ds, (x_test, y_test)


## 2) Modelos: Teacher (EfficientNetB0 pré-treinado) e Student (MobileNetV2)

In [None]:
from tensorflow.keras import layers, models, applications

def build_teacher(input_shape=(96,96,3), num_classes=10, input_scale=255.0):
    """Build the teacher: ImageNet-pretrained EfficientNetB0 + new softmax head.

    Keras' EfficientNet models contain their own input normalisation and
    expect pixel values in [0, 255]. The data pipeline in this notebook
    yields images scaled to [0, 1]. A ``Rescaling(input_scale)`` layer
    therefore maps inputs back to [0, 255] before the backbone; without it
    the pretrained features would see near-black images. Pass
    ``input_scale=1.0`` if your inputs are already in [0, 255].

    Args:
        input_shape: Shape of one input image (height, width, 3).
        num_classes: Number of output classes.
        input_scale: Factor applied to inputs before the backbone.

    Returns:
        A Keras ``Model`` named ``teacher_effnetb0`` whose output is a
        float32 probability distribution (softmax).
    """
    inputs = layers.Input(shape=input_shape)
    x = layers.Rescaling(input_scale)(inputs)
    base = tf.keras.applications.EfficientNetB0(include_top=False, weights='imagenet',
                                                input_shape=input_shape)
    x = base(x)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(0.3)(x)
    # float32 output keeps the softmax numerically safe under mixed precision.
    outputs = layers.Dense(num_classes, activation='softmax', dtype='float32')(x)
    return models.Model(inputs=inputs, outputs=outputs, name='teacher_effnetb0')

def build_student(input_shape=(96,96,3), num_classes=10, alpha=1.0):
    """Build the lightweight student: randomly initialised MobileNetV2 + softmax head.

    Args:
        input_shape: Shape of one input image (height, width, 3).
        num_classes: Number of output classes.
        alpha: MobileNetV2 width multiplier (smaller -> fewer filters).

    Returns:
        A Keras ``Model`` named ``student_mobilenetv2`` whose output is a
        float32 probability distribution (softmax).
    """
    backbone = tf.keras.applications.MobileNetV2(
        input_shape=input_shape,
        alpha=alpha,
        include_top=False,
        weights=None,
    )
    features = layers.GlobalAveragePooling2D()(backbone.output)
    features = layers.Dropout(0.3)(features)
    probs = layers.Dense(num_classes, activation='softmax', dtype='float32')(features)
    return models.Model(inputs=backbone.input, outputs=probs, name='student_mobilenetv2')

# Build both networks with default settings and print their architectures.
# build_teacher() requests ImageNet weights, which may trigger a download
# the first time it runs in a fresh environment.
teacher = build_teacher()
student = build_student()
for _model in (teacher, student):
    _model.summary()


## 3) Treinamento por destilação (Distiller)

In [None]:
class Distiller(tf.keras.Model):
    """Knowledge-distillation wrapper that trains ``student`` to mimic ``teacher``.

    The models built in this notebook end in a softmax layer, i.e. they output
    probabilities, not logits. Temperature scaling must act on logits, so the
    probabilities are mapped back to log-probabilities first:
    ``softmax(log(p) / T)`` equals ``softmax(z / T)`` for the underlying
    logits ``z``, because the log-normaliser is a constant shift that softmax
    ignores. Applying ``softmax(p / T)`` directly to probabilities would
    flatten both distributions towards uniform and remove most of the
    distillation signal.

    Training loss::

        alpha * student_loss_fn(y, student) +
        (1 - alpha) * T**2 * distill_loss_fn(teacher_T, student_T)

    Only the student's trainable variables are updated.
    """

    # Lower bound for probabilities before taking the log (avoids log(0)).
    _EPSILON = 1e-7

    def __init__(self, student, teacher, temperature=4.0, alpha=0.5):
        super(Distiller, self).__init__()
        self.student = student
        self.teacher = teacher
        self.temperature = temperature
        self.alpha = alpha  # weight of the hard-label loss

    def compile(self, optimizer, metrics, student_loss_fn, distill_loss_fn):
        """Compile with separate hard-label and distillation loss functions."""
        super(Distiller, self).compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distill_loss_fn = distill_loss_fn

    def call(self, inputs, training=False):
        """Forward pass = the student's forward pass (used by predict/evaluate)."""
        return self.student(inputs, training=training)

    def _soften(self, probs):
        """Return the temperature-softened distribution for softmax outputs."""
        log_probs = tf.math.log(tf.clip_by_value(probs, self._EPSILON, 1.0))
        return tf.nn.softmax(log_probs / self.temperature, axis=-1)

    def train_step(self, data):
        x, y = data
        # The teacher is a fixed target: inference mode, outside the tape.
        teacher_preds = self.teacher(x, training=False)
        t_soft = self._soften(teacher_preds)
        with tf.GradientTape() as tape:
            student_preds = self.student(x, training=True)
            s_loss = self.student_loss_fn(y, student_preds)
            s_soft = self._soften(student_preds)
            d_loss = self.distill_loss_fn(t_soft, s_soft)
            # T**2 keeps the soft-target term's gradient scale comparable
            # across temperatures (Hinton et al., 2015).
            loss = self.alpha * s_loss + (1.0 - self.alpha) * (self.temperature ** 2) * d_loss
        gradients = tape.gradient(loss, self.student.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.student.trainable_variables))
        self.compiled_metrics.update_state(y, student_preds)
        results = {m.name: m.result() for m in self.metrics}
        results.update({'loss': loss, 'student_loss': s_loss, 'distillation_loss': d_loss})
        return results

    def test_step(self, data):
        x, y = data
        y_pred = self.student(x, training=False)
        # Validation reports only the hard-label loss of the student.
        t_loss = self.student_loss_fn(y, y_pred)
        self.compiled_metrics.update_state(y, y_pred)
        results = {m.name: m.result() for m in self.metrics}
        results.update({'loss': t_loss})
        return results

def prepare_distiller(teacher, student, lr=1e-3, temperature=4.0, alpha=0.5):
    """Create and compile a ``Distiller`` ready for ``fit``.

    Uses Adam with learning rate ``lr``, categorical cross-entropy as the
    hard-label loss, KL divergence as the distillation loss, and reports
    categorical accuracy of the student.

    Returns:
        The compiled ``Distiller``.
    """
    distiller = Distiller(student, teacher, temperature=temperature, alpha=alpha)
    distiller.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
        metrics=[tf.keras.metrics.CategoricalAccuracy()],
        student_loss_fn=tf.keras.losses.CategoricalCrossentropy(),
        distill_loss_fn=tf.keras.losses.KLDivergence(),
    )
    return distiller


## 4) Orquestração do treinamento (exemplo)

In [None]:
# Adjust these values to your hardware.
EPOCHS_TEACHER = 20
EPOCHS_STUDENT = 40
BATCH_SIZE = 128
TARGET_SIZE = 96
INPUT_SHAPE = (TARGET_SIZE, TARGET_SIZE, 3)

train_ds, val_ds, (x_test, y_test) = build_datasets(
    batch_size=BATCH_SIZE,
    target_size=TARGET_SIZE,
    use_mixup=False,
)

# Teacher: fine-tune it here (uncomment .fit) or load previously trained weights.
teacher = build_teacher(input_shape=INPUT_SHAPE)
teacher.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=['accuracy'],
)
# teacher.fit(train_ds, epochs=EPOCHS_TEACHER, validation_data=val_ds)

# Student (MobileNetV2, width 0.35) distilled from the teacher. Train the
# teacher first: its classification head is freshly initialised otherwise.
student = build_student(input_shape=INPUT_SHAPE, alpha=0.35)
distiller = prepare_distiller(teacher, student, lr=1e-3, temperature=4.0, alpha=0.5)
# distiller.fit(train_ds, epochs=EPOCHS_STUDENT, validation_data=val_ds)

print('Fluxo preparado. Descomente as chamadas .fit() para treinar no seu ambiente.')


## 5) Pruning e Quantização (snippets)

In [None]:
# Optional magnitude pruning of the student (best-effort: never aborts the notebook).
try:
    import tensorflow_model_optimization as tfmot
except ImportError as e:
    print('Pruning not configured (tensorflow_model_optimization missing?) :', e)
else:
    print('tfmot available for pruning/QAT')
    try:
        prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude
        # Sparsity ramps from 0% to 50% between training steps 200 and 2000.
        pruning_params = {
            'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.0,
                                                                     final_sparsity=0.5,
                                                                     begin_step=200,
                                                                     end_step=2000)
        }
        student_for_prune = prune_low_magnitude(student, **pruning_params)
        student_for_prune.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
        # Training a pruned model requires the UpdatePruningStep callback:
        # student_for_prune.fit(..., callbacks=[tfmot.sparsity.keras.UpdatePruningStep()])
    except Exception as e:
        # tfmot is installed but could not wrap/compile this model (e.g. an
        # unsupported layer or Keras version); report it instead of crashing.
        print('Pruning setup failed (tfmot installed but incompatible?) :', e)


# TFLite conversion helper
def convert_to_tflite(model, quantize=False, representative_data_gen=None, filename='model.tflite'):
    """Convert a Keras model to TFLite and write the flatbuffer to ``filename``.

    Args:
        model: Keras model to convert.
        quantize: If true, enable ``tf.lite.Optimize.DEFAULT``. Without a
            representative dataset this gives dynamic-range (weight)
            quantization.
        representative_data_gen: Optional generator used by the converter to
            calibrate activation ranges when ``quantize`` is true.
        filename: Output path of the ``.tflite`` file.

    Returns:
        The serialized TFLite model (bytes), which is also written to disk.
    """
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    if quantize:
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        if representative_data_gen is not None:
            converter.representative_dataset = representative_data_gen
    tflite_model = converter.convert()
    # Context manager guarantees the file is flushed and closed.
    with open(filename, 'wb') as fh:
        fh.write(tflite_model)
    print('Saved TFLite model to', filename)
    return tflite_model


## 6) Export e avaliação

In [None]:
# Export the student: first as a Keras/SavedModel artifact, then as float TFLite.
# Each step is best-effort and only reports failures.
save_dir = 'saved_student_model'
try:
    student.save(save_dir, include_optimizer=False)
except Exception as e:
    print('Failed to save student model:', e)
else:
    print('Saved student to', save_dir)

# Float conversion; quantized variants may need a representative dataset.
try:
    convert_to_tflite(student, quantize=False, filename='student_float.tflite')
except Exception as e:
    print('TFLite conversion failed:', e)


### Notas finais
- Ajuste `EPOCHS_TEACHER`, `EPOCHS_STUDENT` e `BATCH_SIZE` conforme seu hardware.
- Para deploy em dispositivos restritos use quantização e pruning.
- Treine com exemplos sintéticos de CAPTCHA para melhorar a robustez.