In [2]:
from datetime import datetime
import tensorflow as tf

from tensorflow.keras.datasets import cifar10
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.models import load_model, Model
from google.colab import drive
import sys
import os
sys.path.append('/content')

from ijepa_keras.utils.encoder import MobileNetV2Encoder
from ijepa_keras.utils.pretrain import JEPA
from ijepa_keras.utils.predictor import MobileNetV2Predictor
from ijepa_keras.utils.plotting import plot_training_history


In [4]:
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:

EPOCHS = 25
BATCH_SIZE = 32
SAVE_MODEL_PATH = '/content/drive/MyDrive/IJEPA_MobileNetV2_Without_Imagenet_Weights'

In [6]:
# Load CIFAR-10 dataset
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
[1m170498071/170498071[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 0us/step


In [7]:
# Initialize MobileNet V2 encoder
encoder = MobileNetV2Encoder()
context_encoder = encoder.get_context_encoder()
target_encoder = encoder.get_target_encoder()

In [8]:
# Initialize VGG16 predictor
predictor = MobileNetV2Predictor()
predictor_model = predictor.get_predictor()

In [9]:
# I-JEPA Pretraining
jepa = JEPA(context_encoder=context_encoder,
             target_encoder=target_encoder,
             predictor_model=predictor_model,
             optimizer=Adam(),
             loss_fn=MeanSquaredError(),
             model_save_path=SAVE_MODEL_PATH)

In [10]:
jepa.train(x_train, epochs=EPOCHS, batch_size=BATCH_SIZE)

Epoch 1/25: 100%|██████████| 1562/1562 [16:54<00:00,  1.54batch/s]


Epoch 1/25, Loss: 0.0003, Momentum: 0.9600
Saving best models with loss:0.0003113190267088825...


Epoch 2/25: 100%|██████████| 1562/1562 [17:26<00:00,  1.49batch/s]


Epoch 2/25, Loss: 0.0000, Momentum: 0.9616
Saving best models with loss:6.987360578176019e-26...


Epoch 3/25: 100%|██████████| 1562/1562 [17:05<00:00,  1.52batch/s]


Epoch 3/25, Loss: 0.0000, Momentum: 0.9632


Epoch 4/25: 100%|██████████| 1562/1562 [17:24<00:00,  1.50batch/s]


Epoch 4/25, Loss: 0.0000, Momentum: 0.9648


Epoch 5/25: 100%|██████████| 1562/1562 [17:18<00:00,  1.50batch/s]


Epoch 5/25, Loss: 0.0000, Momentum: 0.9664


Epoch 6/25: 100%|██████████| 1562/1562 [17:15<00:00,  1.51batch/s]


Epoch 6/25, Loss: 0.0000, Momentum: 0.9680
Saving best models with loss:6.97804352212381e-26...


Epoch 7/25: 100%|██████████| 1562/1562 [17:12<00:00,  1.51batch/s]


Epoch 7/25, Loss: 0.0000, Momentum: 0.9696


Epoch 8/25: 100%|██████████| 1562/1562 [17:07<00:00,  1.52batch/s]


Epoch 8/25, Loss: 0.0000, Momentum: 0.9712


Epoch 9/25: 100%|██████████| 1562/1562 [17:08<00:00,  1.52batch/s]


Epoch 9/25, Loss: 0.0000, Momentum: 0.9728


Epoch 10/25: 100%|██████████| 1562/1562 [17:27<00:00,  1.49batch/s]


Epoch 10/25, Loss: 0.0000, Momentum: 0.9744


Epoch 11/25: 100%|██████████| 1562/1562 [17:16<00:00,  1.51batch/s]


Epoch 11/25, Loss: 0.0000, Momentum: 0.9760


Epoch 12/25: 100%|██████████| 1562/1562 [17:08<00:00,  1.52batch/s]


Epoch 12/25, Loss: 0.0000, Momentum: 0.9776
Saving best models with loss:6.977425432070469e-26...


Epoch 13/25: 100%|██████████| 1562/1562 [17:11<00:00,  1.51batch/s]


Epoch 13/25, Loss: 0.0000, Momentum: 0.9792


Epoch 14/25: 100%|██████████| 1562/1562 [17:10<00:00,  1.52batch/s]


Epoch 14/25, Loss: 0.0000, Momentum: 0.9808


Epoch 15/25: 100%|██████████| 1562/1562 [17:09<00:00,  1.52batch/s]


Epoch 15/25, Loss: 0.0000, Momentum: 0.9824


Epoch 16/25: 100%|██████████| 1562/1562 [17:09<00:00,  1.52batch/s]


Epoch 16/25, Loss: 0.0000, Momentum: 0.9840
Saving best models with loss:6.975662906998729e-26...


Epoch 17/25: 100%|██████████| 1562/1562 [17:20<00:00,  1.50batch/s]


Epoch 17/25, Loss: 0.0000, Momentum: 0.9856


Epoch 18/25: 100%|██████████| 1562/1562 [17:01<00:00,  1.53batch/s]


Epoch 18/25, Loss: 0.0000, Momentum: 0.9872


Epoch 19/25: 100%|██████████| 1562/1562 [17:11<00:00,  1.51batch/s]


Epoch 19/25, Loss: 0.0000, Momentum: 0.9888


Epoch 20/25: 100%|██████████| 1562/1562 [16:57<00:00,  1.53batch/s]


Epoch 20/25, Loss: 0.0000, Momentum: 0.9904


Epoch 21/25: 100%|██████████| 1562/1562 [16:55<00:00,  1.54batch/s]


Epoch 21/25, Loss: 0.0000, Momentum: 0.9920


Epoch 22/25: 100%|██████████| 1562/1562 [17:00<00:00,  1.53batch/s]


Epoch 22/25, Loss: 0.0000, Momentum: 0.9936


Epoch 23/25: 100%|██████████| 1562/1562 [17:00<00:00,  1.53batch/s]


Epoch 23/25, Loss: 0.0000, Momentum: 0.9952


Epoch 24/25: 100%|██████████| 1562/1562 [17:12<00:00,  1.51batch/s]


Epoch 24/25, Loss: 0.0000, Momentum: 0.9968


Epoch 25/25: 100%|██████████| 1562/1562 [16:55<00:00,  1.54batch/s]


Epoch 25/25, Loss: 0.0000, Momentum: 0.9984
Saving best models with loss:6.963016704917395e-26...


In [10]:
model_path = '/content/drive/MyDrive/IJEPA_MobileNetV2_Without_Imagenet_Weightsbest_target_encoder.keras'
target_encoder  = tf.keras.models.load_model(model_path)
# Load saved models from disk
# target_encoder = load_model(os.path.join(SAVE_MODEL_PATH, "best_target_encoder.keras"))


In [11]:
# Freeze the target encoder
for layer in target_encoder.layers:
    layer.trainable = False


In [12]:
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model

# Build linear probing model
def build_linear_probe_model(encoder, num_classes):
    # Input for image
    input_layer = Input(shape=(32, 32, 3), name="embedding input")

    # Pass input through the frozen encoder
    x = encoder(input_layer)

    # Add linear classification head
    output_layer = Dense(num_classes, activation="softmax", name="classification_head")(x)

    # Build model
    model = Model(inputs=input_layer, outputs=output_layer)
    return model


In [13]:
# Create the linear probing model
num_classes = 10  # CIFAR-10 has 10 classes
linear_probe_model = build_linear_probe_model(target_encoder, num_classes)

In [14]:
# Compile the model
linear_probe_model.compile(
    optimizer=Adam(),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"]
)

In [15]:
# Train the linear probing model
history = linear_probe_model.fit(
    x_train, y_train,
    validation_data=(x_test, y_test),
    batch_size=64,
    epochs=60
)

Epoch 1/60
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m319s[0m 367ms/step - accuracy: 0.2293 - loss: 2.1594 - val_accuracy: 0.1000 - val_loss: 2.3185
Epoch 2/60
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m315s[0m 359ms/step - accuracy: 0.4254 - loss: 1.5792 - val_accuracy: 0.1000 - val_loss: 2.3368
Epoch 3/60
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m319s[0m 355ms/step - accuracy: 0.4793 - loss: 1.4556 - val_accuracy: 0.1000 - val_loss: 2.3189
Epoch 4/60
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m322s[0m 355ms/step - accuracy: 0.4756 - loss: 1.4670 - val_accuracy: 0.1000 - val_loss: 2.3499
Epoch 5/60
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m277s[0m 354ms/step - accuracy: 0.5023 - loss: 1.3967 - val_accuracy: 0.1829 - val_loss: 2.2729
Epoch 6/60
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m329s[0m 363ms/step - accuracy: 0.5121 - loss: 1.3666 - val_accuracy: 0.3681 - val_loss: 1.8046
Epoc

In [16]:
plot_training_history(history, SAVE_MODEL_PATH)