In [1]:
import os

In [2]:
# TensorFlow reads these variables when it is first imported, so they are set
# here, before the imports cell that pulls in tensorflow.keras.
os.environ.update(
    TF_CPP_MIN_LOG_LEVEL="2",          # raise the native (C++) log threshold to quiet startup noise
    TF_FORCE_GPU_ALLOW_GROWTH="true",  # allocate GPU memory on demand rather than all up front
)

In [5]:
import wandb
from wandb.integration.keras import WandbMetricsLogger, WandbModelCheckpoint
import numpy as np
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
from tensorflow import keras as k
import os

In [6]:
# Authenticate this kernel with W&B; the returned flag is shown as the cell output.
logged_in = wandb.login()
logged_in

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /Users/rohit/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mmanoghn[0m ([33mmanoghn-northeastern-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [9]:
class LogLRCallback(k.callbacks.Callback):
    """Log the optimizer's learning rate to W&B under ``"lr"`` after every epoch.

    Handles both a fixed learning rate (a variable or plain number) and a
    callable ``LearningRateSchedule``; a schedule is evaluated at the
    optimizer's current iteration count.
    """
    def on_epoch_end(self, epoch, logs=None):
        opt = self.model.optimizer
        lr = opt.learning_rate
        # A LearningRateSchedule is callable and cannot be passed to float()
        # directly; evaluate it at the current optimizer step instead.
        if callable(lr):
            lr = lr(opt.iterations)
        lr_val = float(lr.numpy() if hasattr(lr, "numpy") else lr)
        # No explicit ``step=``: the other callbacks in this notebook log with
        # W&B's implicit step, and forcing optimizer.iterations as the step
        # makes the run's step jump around. W&B ignores logs whose explicit step
        # is lower than the current one, which could silently drop "lr".
        wandb.log({"lr": lr_val})

class LogSamplesCallback(k.callbacks.Callback):
    """At each epoch end, log a wandb.Table of predictions for a fixed set of samples.

    Args:
        x: input images; only the first ``max_rows`` are kept.
        y: one-hot targets aligned with ``x``; truncated the same way.
        labels: class names, indexed by class id.
        max_rows: how many leading samples to keep (default 32).

    Each epoch is logged under its own key, ``samples/epoch_<n>`` (1-based).
    """
    def __init__(self, x, y, labels, max_rows=32):
        super().__init__()
        self.labels = labels
        self.x, self.y = x[:max_rows], y[:max_rows]

    def on_epoch_end(self, epoch, logs=None):
        probs = self.model.predict(self.x, verbose=0)
        true_ids = np.argmax(self.y, axis=1)
        pred_ids = np.argmax(probs, axis=1)
        # Probability the model assigned to its own predicted class.
        confidences = np.max(probs, axis=1)

        columns = ["image", "y_true", "y_pred", "correct", "p(y_pred)"]
        table = wandb.Table(columns=columns)
        for row, image in enumerate(self.x):
            actual, guessed = true_ids[row], pred_ids[row]
            table.add_data(
                wandb.Image(image),
                self.labels[actual],
                self.labels[guessed],
                bool(actual == guessed),
                float(confidences[row]),
            )

        table_key = f"samples/epoch_{epoch+1}"
        wandb.log({table_key: table})

class ConfusionMatrixCallback(k.callbacks.Callback):
    """At each epoch end, log a confusion matrix computed over all of ``x_val``.

    Args:
        x_val: validation inputs, predicted in full every epoch.
        y_val: one-hot targets aligned with ``x_val``.
        labels: class names, indexed by class id.

    The plot is logged under the same key (``"confusion_matrix"``) every epoch.
    """
    def __init__(self, x_val, y_val, labels):
        super().__init__()
        self.x_val, self.y_val, self.labels = x_val, y_val, labels

    def on_epoch_end(self, epoch, logs=None):
        scores = self.model.predict(self.x_val, verbose=0)
        actual = np.argmax(self.y_val, axis=1)
        predicted = np.argmax(scores, axis=1)
        # Hard class ids are passed via ``preds``, so no probabilities are needed.
        plot = wandb.plot.confusion_matrix(
            probs=None, y_true=actual, preds=predicted, class_names=self.labels,
        )
        wandb.log({"confusion_matrix": plot})

# --- trainer -----------------------------------------------------------------

class CIFAR10Trainer:
    """Train a small CNN on (a slice of) CIFAR-10 and log the run to W&B.

    Constructing the trainer starts a W&B run (``wandb.init``) and loads the
    data. Calling :meth:`train` fits the model, logs final metrics, uploads a
    model artifact and finishes the run.

    Args:
        project_name: W&B project the run is logged to.
        run_name: display name of the W&B run.
    """
    def __init__(self, project_name="Lab1-visualize-models", run_name="cifar10_cnn"):
        # Default hyper-parameters. They are passed to wandb.init as the run
        # config, and the rest of the class reads them back via ``self.config``.
        self.cfg = dict(
            dropout=0.2,
            layer_1_size=32,
            learn_rate=0.01,
            momentum=0.9,
            epochs=5,
            batch_size=64,
            sample=10000,  # number of leading images kept from EACH split
        )
        self.run = wandb.init(
            project=project_name,
            name=run_name,
            config=self.cfg,
            settings=wandb.Settings(start_method="thread"),
        )
        self.config = wandb.config
        # CIFAR-10 class names, indexed by integer label.
        self.labels = ["Airplane", "Automobile", "Bird", "Cat", "Deer",
                       "Dog", "Frog", "Horse", "Ship", "Truck"]
        self._prepare_data()

    def _prepare_data(self):
        """Load CIFAR-10, keep the first ``config.sample`` images of each split,
        scale pixels to [0, 1] and one-hot encode the labels."""
        (xtr, ytr), (xte, yte) = cifar10.load_data()
        n = self.config.sample
        num_classes = len(self.labels)

        # CIFAR-10 images are 32x32x3 (RGB); convert to float in [0, 1].
        self.X_train = xtr[:n].astype("float32") / 255.0
        self.X_test = xte[:n].astype("float32") / 255.0

        # Labels come as (n, 1). reshape(-1) flattens them to (n,) and, unlike
        # squeeze(), still yields a 1-D array when n == 1.
        # num_classes is passed explicitly: otherwise to_categorical infers the
        # width from the largest label present, so a small ``sample`` missing
        # the last class would give train/test targets of different widths.
        self.y_train = to_categorical(ytr[:n].reshape(-1), num_classes=num_classes)
        self.y_test = to_categorical(yte[:n].reshape(-1), num_classes=num_classes)
        self.num_classes = num_classes

    def _build_model(self):
        """Build and compile a one-conv-layer CNN for 32x32x3 inputs.

        Conv2D(5x5, relu) -> MaxPool(2x2) -> Dropout -> Flatten -> Dense(softmax),
        trained with Nesterov SGD using the run's learn_rate and momentum.
        """
        inputs = k.Input(shape=(32, 32, 3))
        x = k.layers.Conv2D(self.config.layer_1_size, (5, 5), activation="relu")(inputs)
        x = k.layers.MaxPooling2D((2, 2))(x)
        x = k.layers.Dropout(self.config.dropout)(x)
        x = k.layers.Flatten()(x)
        outputs = k.layers.Dense(self.num_classes, activation="softmax")(x)
        model = k.Model(inputs, outputs)

        opt = k.optimizers.SGD(
            learning_rate=self.config.learn_rate,
            momentum=self.config.momentum,
            nesterov=True,
        )
        model.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"])
        return model

    def _log_model_artifact(self, model):
        """Write the model summary and the saved model under ``artifacts/`` and
        upload both to the run as a ``model``-type artifact."""
        summary_lines = []
        model.summary(print_fn=summary_lines.append)
        summary_txt = "\n".join(summary_lines)
        os.makedirs("artifacts", exist_ok=True)
        summary_path = "artifacts/model_summary.txt"
        # The summary text may contain non-ASCII characters; write UTF-8
        # explicitly instead of relying on the platform's default encoding.
        with open(summary_path, "w", encoding="utf-8") as f:
            f.write(summary_txt)

        model_path = "artifacts/model.h5"
        model.save(model_path)

        art = wandb.Artifact("cifar10_model", type="model")
        art.add_file(summary_path)
        art.add_file(model_path)
        self.run.log_artifact(art)

    def train(self):
        """Fit the model, log final metrics and the model artifact, then finish the run.

        The W&B run is finished even if any step raises, so a failed cell does
        not leave the run open in the notebook kernel.
        """
        try:
            model = self._build_model()

            callbacks = [
                WandbMetricsLogger(log_freq=10),
                WandbModelCheckpoint("checkpoints/model-{epoch:02d}.h5", save_weights_only=False),
                LogLRCallback(),
                LogSamplesCallback(self.X_test, self.y_test, self.labels, max_rows=32),
                ConfusionMatrixCallback(self.X_test, self.y_test, self.labels),
            ]

            # NOTE(review): the test split doubles as the validation set here and
            # is evaluated again below, so final/* is not an independent held-out score.
            model.fit(
                self.X_train, self.y_train,
                validation_data=(self.X_test, self.y_test),
                epochs=self.config.epochs,
                batch_size=self.config.batch_size,
                callbacks=callbacks,
                verbose=1,
            )

            loss, acc = model.evaluate(self.X_test, self.y_test, verbose=0)
            wandb.log({"final/loss": loss, "final/accuracy": acc})

            self._log_model_artifact(model)
        finally:
            self.run.finish()


# Build the trainer (starts the W&B run and loads CIFAR-10), then train it.
trainer = CIFAR10Trainer()
trainer.train()

Epoch 1/5
[1m155/157[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 8ms/step - accuracy: 0.2209 - loss: 2.1171



[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 29ms/step - accuracy: 0.2957 - loss: 1.9566 - val_accuracy: 0.3560 - val_loss: 1.8435
Epoch 2/5
[1m153/157[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 14ms/step - accuracy: 0.4192 - loss: 1.6284



[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 34ms/step - accuracy: 0.4387 - loss: 1.5820 - val_accuracy: 0.4076 - val_loss: 1.6349
Epoch 3/5
[1m156/157[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 7ms/step - accuracy: 0.4841 - loss: 1.4535



[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 27ms/step - accuracy: 0.4926 - loss: 1.4267 - val_accuracy: 0.4567 - val_loss: 1.5317
Epoch 4/5
[1m153/157[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 7ms/step - accuracy: 0.5282 - loss: 1.3487



[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 26ms/step - accuracy: 0.5254 - loss: 1.3510 - val_accuracy: 0.4976 - val_loss: 1.4247
Epoch 5/5
[1m149/157[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 8ms/step - accuracy: 0.5368 - loss: 1.3085



[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 27ms/step - accuracy: 0.5411 - loss: 1.3027 - val_accuracy: 0.5008 - val_loss: 1.4339




0,1
batch/accuracy,▁▂▂▃▃▃▃▄▄▅▆▆▆▆▆▇▇▇▇▇▇▇▇███████▇█████████
batch/batch_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇███
batch/learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
batch/loss,██▇▇▇▆▆▆▆▆▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▂▁▁▁▁▁
epoch/accuracy,▁▅▇██
epoch/epoch,▁▃▅▆█
epoch/learning_rate,▁▁▁▁▁
epoch/loss,█▄▂▂▁
epoch/val_accuracy,▁▃▆██
epoch/val_loss,█▅▃▁▁

0,1
batch/accuracy,0.5416
batch/batch_step,790
batch/learning_rate,0.01
batch/loss,1.30048
epoch/accuracy,0.5411
epoch/epoch,4
epoch/learning_rate,0.01
epoch/loss,1.30265
epoch/val_accuracy,0.5008
epoch/val_loss,1.43391
