In [2]:
from datetime import datetime
import tensorflow as tf

from tensorflow.keras.datasets import cifar10
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.models import load_model, Model
from google.colab import drive
import sys
import os
sys.path.append('/content')

from ijepa_keras.utils.encoder import MobileNetV2Encoder
from ijepa_keras.utils.pretrain import JEPA
from ijepa_keras.utils.predictor import MobileNetV2Predictor
from ijepa_keras.utils.plotting import plot_training_history


In [3]:
# Mount Google Drive at /content/drive so that paths below it
# (SAVE_MODEL_PATH, the saved encoder checkpoint) are reachable from this runtime.
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:

# Pretraining configuration.
EPOCHS = 100  # number of pretraining epochs passed to jepa.train()
BATCH_SIZE = 32  # pretraining batch size passed to jepa.train()
# Passed to JEPA as model_save_path and later to plot_training_history().
# NOTE(review): there is no trailing slash, and the checkpoint loaded further down
# lives at this value immediately followed by 'best_target_encoder.keras', so it
# appears to act as a file-name prefix rather than a directory - confirm.
SAVE_MODEL_PATH = '/content/drive/MyDrive/IJEPA_MobileNetV2'

In [5]:
# Load CIFAR-10 dataset
# x_train alone feeds the I-JEPA pretraining (no labels are passed to jepa.train);
# (x_train, y_train) and (x_test, y_test) are used later to fit and validate the linear probe.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
[1m170498071/170498071[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 0us/step


In [6]:
# Initialize MobileNet V2 encoder
# A single MobileNetV2Encoder instance supplies both networks handed to JEPA below.
encoder = MobileNetV2Encoder()
context_encoder = encoder.get_context_encoder()  # passed to JEPA as context_encoder
target_encoder = encoder.get_target_encoder()  # passed to JEPA as target_encoder; rebound later to the checkpoint loaded from disk

In [7]:
# Initialize MobileNetV2 predictor
predictor = MobileNetV2Predictor()
predictor_model = predictor.get_predictor()  # passed to JEPA as predictor_model

In [8]:
# I-JEPA Pretraining
jepa = JEPA(context_encoder=context_encoder,
             target_encoder=target_encoder,
             predictor_model=predictor_model,
             optimizer=Adam(),
             loss_fn=MeanSquaredError(),
             model_save_path=SAVE_MODEL_PATH)

In [None]:
# Run pretraining on the training images only (labels are not passed).
# Long-running: the recorded output shows roughly 9 minutes per epoch.
jepa.train(x_train, epochs=EPOCHS, batch_size=BATCH_SIZE)

Epoch 1/100: 100%|██████████| 1562/1562 [09:00<00:00,  2.89batch/s]


Epoch 1/100, Loss: 0.0087, Momentum: 0.9600
Saving best models with loss:0.00874591620082118...


Epoch 2/100: 100%|██████████| 1562/1562 [08:56<00:00,  2.91batch/s]


Epoch 2/100, Loss: 0.0093, Momentum: 0.9604


Epoch 3/100: 100%|██████████| 1562/1562 [08:56<00:00,  2.91batch/s]


Epoch 3/100, Loss: 0.0089, Momentum: 0.9608


Epoch 4/100: 100%|██████████| 1562/1562 [08:59<00:00,  2.90batch/s]


Epoch 4/100, Loss: 0.0085, Momentum: 0.9612
Saving best models with loss:0.0084965111322286...


Epoch 5/100: 100%|██████████| 1562/1562 [09:01<00:00,  2.88batch/s]


Epoch 5/100, Loss: 0.0075, Momentum: 0.9616
Saving best models with loss:0.007499184537703999...


Epoch 6/100: 100%|██████████| 1562/1562 [09:06<00:00,  2.86batch/s]


Epoch 6/100, Loss: 0.0042, Momentum: 0.9620
Saving best models with loss:0.004179295548178237...


Epoch 7/100: 100%|██████████| 1562/1562 [09:06<00:00,  2.86batch/s]


Epoch 7/100, Loss: 0.0041, Momentum: 0.9624
Saving best models with loss:0.004099608184924771...


Epoch 8/100: 100%|██████████| 1562/1562 [09:05<00:00,  2.86batch/s]


Epoch 8/100, Loss: 0.0041, Momentum: 0.9628
Saving best models with loss:0.0040644316642787235...


Epoch 9/100: 100%|██████████| 1562/1562 [09:06<00:00,  2.86batch/s]


Epoch 9/100, Loss: 0.0040, Momentum: 0.9632
Saving best models with loss:0.004037005058728235...


Epoch 10/100: 100%|██████████| 1562/1562 [09:03<00:00,  2.87batch/s]


Epoch 10/100, Loss: 0.0040, Momentum: 0.9636
Saving best models with loss:0.004026301478085832...


Epoch 11/100: 100%|██████████| 1562/1562 [09:04<00:00,  2.87batch/s]


Epoch 11/100, Loss: 0.0040, Momentum: 0.9640
Saving best models with loss:0.004004268750796397...


Epoch 12/100: 100%|██████████| 1562/1562 [09:00<00:00,  2.89batch/s]


Epoch 12/100, Loss: 0.0040, Momentum: 0.9644
Saving best models with loss:0.00399460602001968...


Epoch 13/100: 100%|██████████| 1562/1562 [09:00<00:00,  2.89batch/s]


Epoch 13/100, Loss: 0.0039, Momentum: 0.9648
Saving best models with loss:0.003948906510645015...


Epoch 14/100: 100%|██████████| 1562/1562 [09:02<00:00,  2.88batch/s]


Epoch 14/100, Loss: 0.0038, Momentum: 0.9652
Saving best models with loss:0.003833010878590402...


Epoch 15/100: 100%|██████████| 1562/1562 [09:07<00:00,  2.85batch/s]


Epoch 15/100, Loss: 0.0038, Momentum: 0.9656
Saving best models with loss:0.003830316759289389...


Epoch 16/100: 100%|██████████| 1562/1562 [09:03<00:00,  2.87batch/s]


Epoch 16/100, Loss: 0.0038, Momentum: 0.9660
Saving best models with loss:0.0038220635415549733...


Epoch 17/100: 100%|██████████| 1562/1562 [09:04<00:00,  2.87batch/s]


Epoch 17/100, Loss: 0.0038, Momentum: 0.9664
Saving best models with loss:0.0038150843657987258...


Epoch 18/100: 100%|██████████| 1562/1562 [09:07<00:00,  2.85batch/s]


Epoch 18/100, Loss: 0.0038, Momentum: 0.9668


Epoch 19/100: 100%|██████████| 1562/1562 [09:03<00:00,  2.87batch/s]


Epoch 19/100, Loss: 0.0038, Momentum: 0.9672
Saving best models with loss:0.003814449214833227...


Epoch 20/100: 100%|██████████| 1562/1562 [09:03<00:00,  2.87batch/s]


Epoch 20/100, Loss: 0.0038, Momentum: 0.9676
Saving best models with loss:0.0037877608623377093...


Epoch 21/100: 100%|██████████| 1562/1562 [09:05<00:00,  2.86batch/s]


Epoch 21/100, Loss: 0.0038, Momentum: 0.9680


Epoch 22/100: 100%|██████████| 1562/1562 [09:03<00:00,  2.88batch/s]


Epoch 22/100, Loss: 0.0038, Momentum: 0.9684


Epoch 23/100: 100%|██████████| 1562/1562 [09:01<00:00,  2.89batch/s]


Epoch 23/100, Loss: 0.0038, Momentum: 0.9688
Saving best models with loss:0.0037874090753581553...


Epoch 24/100: 100%|██████████| 1562/1562 [09:00<00:00,  2.89batch/s]


Epoch 24/100, Loss: 0.0038, Momentum: 0.9692
Saving best models with loss:0.003777984813989764...


Epoch 25/100: 100%|██████████| 1562/1562 [09:04<00:00,  2.87batch/s]


Epoch 25/100, Loss: 0.0037, Momentum: 0.9696
Saving best models with loss:0.003703092954541281...


Epoch 26/100: 100%|██████████| 1562/1562 [08:59<00:00,  2.89batch/s]


Epoch 26/100, Loss: 0.0034, Momentum: 0.9700
Saving best models with loss:0.0034002044223840427...


Epoch 27/100: 100%|██████████| 1562/1562 [08:58<00:00,  2.90batch/s]


Epoch 27/100, Loss: 0.0034, Momentum: 0.9704
Saving best models with loss:0.003388345530721694...


Epoch 28/100: 100%|██████████| 1562/1562 [08:59<00:00,  2.90batch/s]


Epoch 28/100, Loss: 0.0034, Momentum: 0.9708
Saving best models with loss:0.0033835856078214058...


Epoch 29/100: 100%|██████████| 1562/1562 [08:58<00:00,  2.90batch/s]


Epoch 29/100, Loss: 0.0034, Momentum: 0.9712


Epoch 30/100: 100%|██████████| 1562/1562 [09:09<00:00,  2.84batch/s]


Epoch 30/100, Loss: 0.0034, Momentum: 0.9716
Saving best models with loss:0.0033736133989108854...


Epoch 31/100: 100%|██████████| 1562/1562 [09:07<00:00,  2.85batch/s]


Epoch 31/100, Loss: 0.0034, Momentum: 0.9720


Epoch 32/100: 100%|██████████| 1562/1562 [09:06<00:00,  2.86batch/s]


Epoch 32/100, Loss: 0.0034, Momentum: 0.9724


Epoch 33/100: 100%|██████████| 1562/1562 [09:04<00:00,  2.87batch/s]


Epoch 33/100, Loss: 0.0034, Momentum: 0.9728


Epoch 34/100: 100%|██████████| 1562/1562 [09:02<00:00,  2.88batch/s]


Epoch 34/100, Loss: 0.0034, Momentum: 0.9732


Epoch 35/100: 100%|██████████| 1562/1562 [08:58<00:00,  2.90batch/s]


Epoch 35/100, Loss: 0.0034, Momentum: 0.9736
Saving best models with loss:0.0033644919870087398...


Epoch 36/100: 100%|██████████| 1562/1562 [08:59<00:00,  2.89batch/s]


Epoch 36/100, Loss: 0.0034, Momentum: 0.9740


Epoch 37/100: 100%|██████████| 1562/1562 [08:59<00:00,  2.89batch/s]


Epoch 37/100, Loss: 0.0034, Momentum: 0.9744


Epoch 38/100: 100%|██████████| 1562/1562 [08:57<00:00,  2.91batch/s]


Epoch 38/100, Loss: 0.0034, Momentum: 0.9748


Epoch 39/100: 100%|██████████| 1562/1562 [08:59<00:00,  2.89batch/s]


Epoch 39/100, Loss: 0.0034, Momentum: 0.9752
Saving best models with loss:0.0033560724214682885...


Epoch 40/100: 100%|██████████| 1562/1562 [09:00<00:00,  2.89batch/s]


Epoch 40/100, Loss: 0.0034, Momentum: 0.9756


Epoch 41/100: 100%|██████████| 1562/1562 [08:59<00:00,  2.90batch/s]


Epoch 41/100, Loss: 0.0034, Momentum: 0.9760
Saving best models with loss:0.0033556232649222415...


Epoch 42/100: 100%|██████████| 1562/1562 [09:02<00:00,  2.88batch/s]


Epoch 42/100, Loss: 0.0034, Momentum: 0.9764


Epoch 43/100: 100%|██████████| 1562/1562 [09:01<00:00,  2.88batch/s]


Epoch 43/100, Loss: 0.0034, Momentum: 0.9768


Epoch 44/100: 100%|██████████| 1562/1562 [08:58<00:00,  2.90batch/s]


Epoch 44/100, Loss: 0.0034, Momentum: 0.9772
Saving best models with loss:0.003351482979490609...


Epoch 45/100: 100%|██████████| 1562/1562 [08:59<00:00,  2.89batch/s]


Epoch 45/100, Loss: 0.0034, Momentum: 0.9776


Epoch 46/100: 100%|██████████| 1562/1562 [09:00<00:00,  2.89batch/s]


Epoch 46/100, Loss: 0.0034, Momentum: 0.9780


Epoch 47/100: 100%|██████████| 1562/1562 [09:01<00:00,  2.89batch/s]


Epoch 47/100, Loss: 0.0033, Momentum: 0.9784
Saving best models with loss:0.0033426965581176853...


Epoch 48/100: 100%|██████████| 1562/1562 [08:58<00:00,  2.90batch/s]


Epoch 48/100, Loss: 0.0034, Momentum: 0.9788


Epoch 49/100: 100%|██████████| 1562/1562 [08:59<00:00,  2.89batch/s]


Epoch 49/100, Loss: 0.0034, Momentum: 0.9792


Epoch 50/100: 100%|██████████| 1562/1562 [08:58<00:00,  2.90batch/s]


Epoch 50/100, Loss: 0.0034, Momentum: 0.9796


Epoch 51/100: 100%|██████████| 1562/1562 [09:01<00:00,  2.89batch/s]


Epoch 51/100, Loss: 0.0034, Momentum: 0.9800


Epoch 52/100: 100%|██████████| 1562/1562 [09:02<00:00,  2.88batch/s]


Epoch 52/100, Loss: 0.0033, Momentum: 0.9804
Saving best models with loss:0.0033180453531651563...


Epoch 53/100: 100%|██████████| 1562/1562 [08:59<00:00,  2.90batch/s]


Epoch 53/100, Loss: 0.0033, Momentum: 0.9808
Saving best models with loss:0.0033174012657958495...


Epoch 54/100: 100%|██████████| 1562/1562 [09:00<00:00,  2.89batch/s]


Epoch 54/100, Loss: 0.0033, Momentum: 0.9812
Saving best models with loss:0.0033123240319930885...


Epoch 55/100: 100%|██████████| 1562/1562 [09:00<00:00,  2.89batch/s]


Epoch 55/100, Loss: 0.0033, Momentum: 0.9816
Saving best models with loss:0.003307375573510216...


Epoch 56/100: 100%|██████████| 1562/1562 [09:00<00:00,  2.89batch/s]


Epoch 56/100, Loss: 0.0033, Momentum: 0.9820
Saving best models with loss:0.003303120848836041...


Epoch 57/100: 100%|██████████| 1562/1562 [08:59<00:00,  2.89batch/s]


Epoch 57/100, Loss: 0.0033, Momentum: 0.9824


Epoch 58/100: 100%|██████████| 1562/1562 [09:02<00:00,  2.88batch/s]


Epoch 58/100, Loss: 0.0033, Momentum: 0.9828


Epoch 59/100: 100%|██████████| 1562/1562 [08:59<00:00,  2.90batch/s]


Epoch 59/100, Loss: 0.0033, Momentum: 0.9832
Saving best models with loss:0.003299148352017743...


Epoch 60/100: 100%|██████████| 1562/1562 [08:59<00:00,  2.90batch/s]


Epoch 60/100, Loss: 0.0033, Momentum: 0.9836


Epoch 61/100: 100%|██████████| 1562/1562 [09:00<00:00,  2.89batch/s]


Epoch 61/100, Loss: 0.0033, Momentum: 0.9840


Epoch 62/100: 100%|██████████| 1562/1562 [08:59<00:00,  2.89batch/s]


Epoch 62/100, Loss: 0.0033, Momentum: 0.9844
Saving best models with loss:0.0032964172147312196...


Epoch 63/100: 100%|██████████| 1562/1562 [08:59<00:00,  2.90batch/s]


Epoch 63/100, Loss: 0.0033, Momentum: 0.9848


Epoch 64/100: 100%|██████████| 1562/1562 [09:00<00:00,  2.89batch/s]


Epoch 64/100, Loss: 0.0033, Momentum: 0.9852


Epoch 65/100: 100%|██████████| 1562/1562 [08:59<00:00,  2.89batch/s]


Epoch 65/100, Loss: 0.0033, Momentum: 0.9856


Epoch 66/100: 100%|██████████| 1562/1562 [08:58<00:00,  2.90batch/s]


Epoch 66/100, Loss: 0.0033, Momentum: 0.9860
Saving best models with loss:0.00329293635427575...


Epoch 67/100: 100%|██████████| 1562/1562 [09:00<00:00,  2.89batch/s]


Epoch 67/100, Loss: 0.0033, Momentum: 0.9864


Epoch 68/100: 100%|██████████| 1562/1562 [09:06<00:00,  2.86batch/s]


Epoch 68/100, Loss: 0.0033, Momentum: 0.9868
Saving best models with loss:0.0032917736101151468...


Epoch 69/100: 100%|██████████| 1562/1562 [09:02<00:00,  2.88batch/s]


Epoch 69/100, Loss: 0.0033, Momentum: 0.9872


Epoch 70/100:  62%|██████▏   | 969/1562 [05:36<03:22,  2.92batch/s]

In [9]:
# Reload the best target-encoder checkpoint written during pretraining.
# The file name is appended to SAVE_MODEL_PATH with no path separator, which
# reproduces the exact location this notebook loaded from originally
# ('/content/drive/MyDrive/IJEPA_MobileNetV2best_target_encoder.keras').
# NOTE(review): presumably JEPA builds its checkpoint names by plain string
# concatenation with model_save_path; confirm in ijepa_keras.utils.pretrain
# before switching to os.path.join, which would point at a different file.
model_path = SAVE_MODEL_PATH + 'best_target_encoder.keras'
target_encoder = load_model(model_path)


In [10]:
# Freeze the target encoder so that only the classification head is updated
# during linear probing. Setting `trainable` on the model applies the flag to
# every layer it contains and also marks the model itself as frozen.
target_encoder.trainable = False


In [11]:
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model

# Build linear probing model
def build_linear_probe_model(encoder, num_classes, input_shape=(32, 32, 3)):
    """Build a linear-probe classifier on top of a frozen encoder.

    Args:
        encoder: Keras model/layer that maps a batch of images to features.
            The caller is expected to have frozen it; this function does not
            touch its `trainable` flag.
        num_classes: Number of classes, i.e. units of the softmax head.
        input_shape: Shape of a single input image without the batch axis.
            Defaults to (32, 32, 3), the value that used to be hard-coded.

    Returns:
        A Keras `Model` mapping images to per-class probabilities.
    """
    # Input for image. The layer name contains no whitespace because names with
    # spaces are not accepted as name scopes by every TensorFlow/Keras version.
    input_layer = Input(shape=input_shape, name="image_input")

    # Run the frozen encoder in inference mode so that layers whose behaviour
    # depends on the training flag (e.g. batch normalisation, dropout) produce
    # the same features while fitting the head and while evaluating it.
    features = encoder(input_layer, training=False)

    # Linear classification head: a single softmax Dense layer.
    # Assumes the encoder yields one feature vector per image - TODO confirm.
    output_layer = Dense(num_classes, activation="softmax", name="classification_head")(features)

    return Model(inputs=input_layer, outputs=output_layer)


In [12]:
# Create the linear probing model
# target_encoder here is the checkpoint reloaded from disk and frozen above,
# not the freshly initialised network from the pretraining cells.
num_classes = 10  # CIFAR-10 has 10 classes
linear_probe_model = build_linear_probe_model(target_encoder, num_classes)

In [13]:
# Compile the model
# Adam is created with its default settings. The sparse variant of categorical
# cross-entropy takes integer class labels; the labels are passed to fit() as
# loaded, with no one-hot encoding step anywhere in this notebook.
linear_probe_model.compile(
    optimizer=Adam(),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"]
)

In [14]:
# Train the linear probing model
# NOTE(review): the CIFAR-10 test split doubles as validation data here, so the
# reported val_* numbers are test-set metrics rather than a held-out validation score.
# NOTE(review): the images are passed exactly as loaded, with no preprocessing in
# this notebook; presumably any scaling happens inside the encoder - confirm.
history = linear_probe_model.fit(
    x_train, y_train,
    validation_data=(x_test, y_test),
    batch_size=64,
    epochs=60
)

Epoch 1/60
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m285s[0m 325ms/step - accuracy: 0.4838 - loss: 1.6100 - val_accuracy: 0.3254 - val_loss: 4.6461
Epoch 2/60
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m263s[0m 327ms/step - accuracy: 0.6607 - loss: 1.0172 - val_accuracy: 0.4255 - val_loss: 3.3290
Epoch 3/60
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m262s[0m 335ms/step - accuracy: 0.5945 - loss: 1.2062 - val_accuracy: 0.3926 - val_loss: 2.2459
Epoch 4/60
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m256s[0m 327ms/step - accuracy: 0.6184 - loss: 1.1201 - val_accuracy: 0.2977 - val_loss: 4.8124
Epoch 5/60
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m257s[0m 321ms/step - accuracy: 0.6352 - loss: 1.0790 - val_accuracy: 0.2643 - val_loss: 4.1825
Epoch 6/60
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m266s[0m 326ms/step - accuracy: 0.5996 - loss: 1.1875 - val_accuracy: 0.3271 - val_loss: 2.9720
Epoc

In [15]:
# Plot the linear-probe training curves recorded in `history`.
# NOTE(review): SAVE_MODEL_PATH is passed as the second argument - presumably the
# location where the figure is written; confirm against ijepa_keras.utils.plotting.
plot_training_history(history, SAVE_MODEL_PATH)