In [1]:
import os

In [2]:
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"

In [3]:

import wandb
import numpy as np
import time
import tensorflow as tf
import tensorflow.keras as k
from sklearn.metrics import confusion_matrix
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.applications.mobilenet_v2 import (
    preprocess_input,
    MobileNetV2
)

In [4]:
# Configure GPU memory growth to prevent OOM
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(f"Memory growth setting error: {e}")

# Enable mixed precision for better T4 performance
tf.keras.mixed_precision.set_global_policy('mixed_float16')

In [5]:
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mrishg[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [6]:
class LogLRCallback(k.callbacks.Callback):
    """Logs optimizer learning rate after each epoch."""
    def on_epoch_end(self, epoch, logs=None):
        opt = self.model.optimizer
        lr = opt.learning_rate
        lr_val = float(lr.numpy() if hasattr(lr, "numpy") else lr)
        wandb.log({"lr": lr_val})

class LogSamplesCallback(k.callbacks.Callback):
    """Logs sample predictions after each epoch."""
    def __init__(self, x_test, y_test, labels):
        self.x = x_test[:16]
        self.y = y_test[:16]
        self.labels = labels

    def on_epoch_end(self, epoch, logs=None):
        preds = self.model.predict(self.x, verbose=0)
        preds = np.argmax(preds, axis=1)
        true  = np.argmax(self.y, axis=1)

        images = []
        for i in range(len(self.x)):
            images.append(
                wandb.Image(
                    self.x[i].astype(np.uint8),
                    caption=f"Pred: {self.labels[preds[i]]}, True: {self.labels[true[i]]}"
                )
            )
        wandb.log({"sample_predictions": images})


class ConfusionMatrixCallback(k.callbacks.Callback):
    """Logs confusion matrix each epoch."""
    def __init__(self, x_test, y_test, labels):
        self.x = x_test
        self.y = y_test
        self.labels = labels

    def on_epoch_end(self, epoch, logs=None):
        preds = self.model.predict(self.x, verbose=0)
        preds = np.argmax(preds, axis=1)
        true  = np.argmax(self.y, axis=1)

        wandb.log({
            "confusion_matrix": wandb.plot.confusion_matrix(
                y_true=true,
                preds=preds,
                class_names=self.labels
            )
        })


In [7]:
# --------------------- NEW CALLBACK ---------------------------
class EpochTimeCallback(k.callbacks.Callback):
    """Logs how long each epoch takes."""
    def on_epoch_begin(self, epoch, logs=None):
        self.start_time = time.time()

    def on_epoch_end(self, epoch, logs=None):
        epoch_time = time.time() - self.start_time
        wandb.log({"epoch_time_sec": epoch_time})



In [10]:
class CIFAR10Trainer:

    def __init__(self):

        self.labels = [
            "airplane","automobile","bird","cat","deer",
            "dog","frog","horse","ship","truck"
        ]
        self.num_classes = 10

        self.config = {
            "dropout": 0.25,
            "learn_rate": 0.0005,
            "epochs": 5,
            "batch_size": 32,   # smaller batch to avoid GPU OOM
            "base_trainable": False
        }

        self.run = wandb.init(
            project="CIFAR10-MobileNetV2",
            config=self.config,
            name="cifar10_transfer_learning_run"
        )

        self._prepare_data()


    # ===========================================================
    # DATA PIPELINE (memory safe)
    # ===========================================================
    def _prepare_data(self):

        (xtr, ytr), (xte, yte) = tf.keras.datasets.cifar10.load_data()

        AUTOTUNE = tf.data.AUTOTUNE

        xtr = tf.cast(xtr, tf.float32)
        xte = tf.cast(xte, tf.float32)

        target_size = (160, 160)

        def preprocess(image, label):
            # image: (32,32,3), label: shape (1,)
            image = tf.image.resize(image, target_size)
            image = preprocess_input(image)

            # Fix label shape: (1,) -> ()
            label = tf.squeeze(label, axis=0)
            label = tf.one_hot(label, depth=self.num_classes)  # shape (10,)

            return image, label

        # TRAIN DS
        train_ds = tf.data.Dataset.from_tensor_slices((xtr, ytr))
        train_ds = train_ds.shuffle(50000)
        train_ds = train_ds.map(preprocess, num_parallel_calls=AUTOTUNE)
        train_ds = train_ds.batch(self.config["batch_size"])
        train_ds = train_ds.prefetch(AUTOTUNE)

        # TEST DS
        test_ds = tf.data.Dataset.from_tensor_slices((xte, yte))
        test_ds = test_ds.map(preprocess, num_parallel_calls=AUTOTUNE)
        test_ds = test_ds.batch(self.config["batch_size"])
        test_ds = test_ds.prefetch(AUTOTUNE)

        self.train_ds = train_ds
        self.test_ds = test_ds

        # Extract 1 small numpy batch for callbacks (sample + confusion matrix)
        for xb, yb in test_ds.take(1):
            self.sample_x_test = xb.numpy()
            self.sample_y_test = yb.numpy()


    # ===========================================================
    # MODEL: MobileNetV2 (transfer learning)
    # ===========================================================
    def _build_model(self):

        base = MobileNetV2(
            include_top=False,
            weights="imagenet",
            pooling="avg",
            input_shape=(160, 160, 3)
        )

        base.trainable = self.config["base_trainable"]

        x = k.layers.Dropout(self.config["dropout"])(base.output)
        output = k.layers.Dense(self.num_classes, activation="softmax")(x)

        model = k.Model(inputs=base.input, outputs=output)

        opt = k.optimizers.Adam(self.config["learn_rate"])

        model.compile(
            optimizer=opt,
            loss="categorical_crossentropy",
            metrics=["accuracy"]
        )

        return model


    # ===========================================================
    # TRAIN
    # ===========================================================
    def train(self):

        model = self._build_model()

        callbacks = [
            LogLRCallback(),
            LogSamplesCallback(self.sample_x_test, self.sample_y_test, self.labels),
            ConfusionMatrixCallback(self.sample_x_test, self.sample_y_test, self.labels),
            EpochTimeCallback()
            # NOTE: We intentionally removed wandb.keras.WandbCallback
        ]

        model.fit(
            self.train_ds,
            epochs=self.config["epochs"],
            validation_data=self.test_ds,
            callbacks=callbacks,
            verbose=1
        )

        # Final evaluation
        loss, acc = model.evaluate(self.test_ds, verbose=0)
        wandb.log({"final/loss": loss, "final/accuracy": acc})

        self._log_model_artifact(model)
        self.run.finish()


    # ===========================================================
    # MODEL ARTIFACT SAVE
    # ===========================================================
    def _log_model_artifact(self, model):
        model.save("mobilenetv2_cifar10.h5")
        artifact = wandb.Artifact("mobilenetv2_cifar10", type="model")
        artifact.add_file("mobilenetv2_cifar10.h5")
        self.run.log_artifact(artifact)


In [11]:
# ============================================================
# RUN TRAINING
# ============================================================

CIFAR10Trainer().train()

Epoch 1/5
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m110s[0m 59ms/step - accuracy: 0.6747 - loss: 0.9837 - val_accuracy: 0.8490 - val_loss: 0.4436
Epoch 2/5
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 12ms/step - accuracy: 0.8391 - loss: 0.4644 - val_accuracy: 0.8572 - val_loss: 0.4103
Epoch 3/5
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 12ms/step - accuracy: 0.8500 - loss: 0.4326 - val_accuracy: 0.8641 - val_loss: 0.4028
Epoch 4/5
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 12ms/step - accuracy: 0.8546 - loss: 0.4141 - val_accuracy: 0.8595 - val_loss: 0.4221
Epoch 5/5
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 12ms/step - accuracy: 0.8635 - loss: 0.3956 - val_accuracy: 0.8686 - val_loss: 0.3909




0,1
epoch_time_sec,█▁▁▁▁
final/accuracy,▁
final/loss,▁
lr,▁▁▁▁▁

0,1
epoch_time_sec,18.84802
final/accuracy,0.8686
final/loss,0.3909
lr,0.0005
