In [1]:
# Notebook Setup & Imports
import os
from pathlib import Path
import sys
import tensorflow as tf

# Put the notebook's parent directory on sys.path so project-local packages
# (e.g. raw_gen_model) can be imported.
# NOTE(review): this relies on the kernel's working directory being the
# notebook's own folder; launching from elsewhere will break these imports.
current_dir = Path.cwd()
parent_dir = current_dir.parent
sys.path.append(str(parent_dir))
from raw_gen_model.raw_gen_encoders.raw_gen_drone_encoder import DroneEncoder
# Optional CPU-only switch (hides all GPUs). NOTE(review): it only takes effect
# if set before TensorFlow initializes its GPU devices, and tensorflow is
# already imported above.
# os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

from gen_dataset_generator import create_datasets
from helper import  set_up_environment
from config import EPOCHS, TRAIN_DATASET_FILE_PATH, VAL_DATASET_FILE_PATH
# NOTE(review): the loss, metrics and optimizer in this notebook come from
# tf_keras. On TF >= 2.16 `tensorflow.keras` may resolve to Keras 3, which is a
# different backend, so `K` may not act on tf_keras state.
from tensorflow.keras import backend as K
import tf_keras
from tf_keras.saving import register_keras_serializable
from tqdm import tqdm
 
# Force every tf.function (including the @tf.function train_step below) to
# execute eagerly. This disables graph compilation and slows training.
# NOTE(review): probably a debugging leftover; consider False for long runs.
tf.config.run_functions_eagerly(True)

2025-06-25 09:05:04.329884: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1750842304.359502    1141 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1750842304.368281    1141 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-06-25 09:05:04.398078: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
I0000 00:00:1750842319.487382    1141 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 22

In [2]:
# Now we define the loss
@register_keras_serializable()
class TranslationLoss(tf_keras.losses.Loss):
    """Mean squared error between reference and predicted token tensors.

    In this notebook ``y_true`` holds the satellite encoder's tokens for real
    satellite images, and ``y_pred`` holds the satellite-like tokens the drone
    encoder predicts from drone images. The squared difference is averaged
    over every element (batch axis included), so ``call`` returns a scalar.
    """

    def __init__(self, name="translation_loss", **kwargs):
        # Forward extra keyword arguments (notably ``reduction``). The base
        # ``Loss.get_config()`` includes them, and a name-only __init__ would
        # make ``from_config(get_config())`` fail when this registered loss is
        # deserialized.
        super().__init__(name=name, **kwargs)

    def call(self, y_true, y_pred):
        """Return mean((y_pred - y_true) ** 2) over all elements.

        Plain L2 is used here; L1 or a cosine-similarity loss would be
        drop-in alternatives.
        """
        return tf.reduce_mean(tf.square(y_pred - y_true))

In [3]:
# Training Step & Loop

# Two-branch setup: the satellite encoder tokenizes both drone and satellite
# images, and only the DroneEncoder is trained to translate drone tokens into
# satellite tokens.

@tf.function
def train_step(drone_encoder, satellite_encoder, optimizer, loss_fn, drone_batch, sat_batch):
    """Run one optimization step on ``drone_encoder``.

    Args:
        drone_encoder: Trainable model. Called as
            ``drone_encoder(tokens, training=True)``; its first return value
            is used as the predicted satellite-like tokens.
        satellite_encoder: Model that tokenizes both image batches. Its
            weights are not updated by this step.
        optimizer: Optimizer applied to ``drone_encoder.trainable_variables``.
        loss_fn: Callable ``loss_fn(y_true, y_pred)`` returning a scalar loss.
        drone_batch: Batch of drone images.
        sat_batch: Batch of satellite images paired with ``drone_batch``.

    Returns:
        ``(loss, predicted_tokens)`` for this batch.
    """
    # Both satellite-encoder passes run outside the tape. Gradients are taken
    # only w.r.t. drone_encoder, so recording these ops would just keep their
    # activations alive until tape.gradient without changing the update.
    drone_tokens = satellite_encoder(drone_batch)
    satellite_tokens = satellite_encoder(sat_batch)

    with tf.GradientTape() as tape:
        # Produce satellite-like tokens from the drone tokens.
        predicted_tokens, _ = drone_encoder(drone_tokens, training=True)
        # Compare the prediction against the satellite tokens.
        loss = loss_fn(satellite_tokens, predicted_tokens)

    # Update only the DroneEncoder's parameters.
    gradients = tape.gradient(loss, drone_encoder.trainable_variables)
    optimizer.apply_gradients(zip(gradients, drone_encoder.trainable_variables))

    return loss, predicted_tokens

def train_model(drone_encoder: DroneEncoder, satellite_encoder, train_dataset, val_dataset, epochs, optimizer, callbacks=None, sat_first=True):
    """Custom training loop that teaches ``drone_encoder`` to emit satellite-like tokens.

    Each epoch runs one optimization pass over ``train_dataset`` (via
    ``train_step``) and then one evaluation pass over ``val_dataset``. It
    reports the mean ``TranslationLoss`` of each pass.

    Args:
        drone_encoder: Model being trained. Called as
            ``drone_encoder(tokens, training=...)``; its first return value is
            the predicted satellite-like tokens.
        satellite_encoder: Model that tokenizes both drone and satellite
            images. It is only called here, never optimized.
        train_dataset: Iterable of 2-element batches. Must support ``len()``,
            which sizes the progress bar.
        val_dataset: Iterable of 2-element batches in the same element order
            as ``train_dataset``.
        epochs: Number of passes over ``train_dataset``.
        optimizer: tf_keras optimizer used by ``train_step``.
        callbacks: Optional list of objects exposing ``on_epoch_begin``,
            ``on_batch_begin``, ``on_batch_end`` and ``on_epoch_end``. Each one
            gets ``drone_model`` set to ``drone_encoder`` before training.
        sat_first: Element order of every batch in BOTH datasets. ``True``
            (default) means ``(sat_batch, drone_batch)``, the order the
            training loop has always used. Pass ``False`` for
            ``(drone_batch, sat_batch)``.
            NOTE(review): confirm this against ``create_datasets``.
    """
    train_loss_metric = tf_keras.metrics.Mean()
    val_loss_metric   = tf_keras.metrics.Mean()

    loss_fn = TranslationLoss()

    def _split_batch(batch):
        # Always return (drone_batch, sat_batch). Both loops unpack through
        # this helper so training and validation can never disagree on which
        # element is which domain.
        first, second = batch
        if sat_first:
            return second, first
        return first, second

    # Initialize callbacks if provided
    callbacks = callbacks or []
    for callback in callbacks:
        callback.drone_model = drone_encoder

    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")
        for callback in callbacks:
            callback.on_epoch_begin(epoch)

        # --- TRAIN LOOP ---
        train_loss_metric.reset_state()
        with tqdm(total=len(train_dataset), desc="Training", ncols=100, leave=False) as pbar:
            for step, batch in enumerate(train_dataset):
                drone_batch, sat_batch = _split_batch(batch)
                for callback in callbacks:
                    callback.on_batch_begin(step)

                loss, _ = train_step(drone_encoder, satellite_encoder, optimizer, loss_fn, drone_batch, sat_batch)
                train_loss_metric.update_state(loss)

                pbar.set_postfix({"loss": f"{train_loss_metric.result().numpy():.4f}"})
                pbar.update(1)

                for callback in callbacks:
                    callback.on_batch_end(step, {'loss': loss.numpy()})

        # --- VALIDATION LOOP ---
        val_loss_metric.reset_state()
        for batch in val_dataset:
            drone_batch, sat_batch = _split_batch(batch)
            drone_tokens = satellite_encoder(drone_batch)
            sat_like_tokens, _ = drone_encoder(drone_tokens, training=False)
            real_sat_tokens = satellite_encoder(sat_batch)
            val_loss_metric.update_state(loss_fn(real_sat_tokens, sat_like_tokens))

        train_loss_val = train_loss_metric.result().numpy()
        val_loss_val   = val_loss_metric.result().numpy()

        print(f"Epoch {epoch+1}: train_loss={train_loss_val:.4f} - val_loss={val_loss_val:.4f}")

        # on_epoch_end
        logs = {'train_loss': train_loss_val, 'val_loss': val_loss_val}
        for callback in callbacks:
            callback.on_epoch_end(epoch, logs)

In [None]:
import gc
# Model Initialization & Run
from raw_gen_model.raw_gen_config import PATCH_SIZE
from raw_gen_model.raw_gen_helper import create_satellite_decoder, create_satellite_encoder
from raw_gen_model.raw_gen_callbacks.raw_get_save_model_callback import SaveBestRawModelCallback
from config import IMAGE_CHANNELS, IMAGE_RESOLUTION, LOGS_DIR, TRAINING_ID
from raw_gen_model.raw_gen_callbacks.raw_gen_attention_callback import AttentionVisualizationCallback
from raw_gen_model.raw_gen_callbacks.raw_gen_token_match_callback import TokenMatchingCallback
from raw_gen_model.raw_gen_callbacks.raw_gen_image_transform_callback import ImageTransformCallback

set_up_environment()

# --- Data ---
# Create the train & validation datasets.
train_dataset, val_dataset = create_datasets(TRAIN_DATASET_FILE_PATH, VAL_DATASET_FILE_PATH)

train_steps = len(train_dataset)
val_steps   = len(val_dataset)

print("Number of training steps:", train_steps)
print("Number of validation steps:", val_steps)

# --- Models ---
# The encoder and decoder are built for the same (height, width, channels) input.
image_shape = (IMAGE_RESOLUTION["height"], IMAGE_RESOLUTION["width"], IMAGE_CHANNELS)

drone_to_sat = DroneEncoder()
satellite_encoder = create_satellite_encoder(patch_size=PATCH_SIZE, input_shape=image_shape)
satellite_decoder = create_satellite_decoder(patch_size=PATCH_SIZE, input_shape=image_shape)

optimizer = tf_keras.optimizers.Adam(learning_rate=1e-4)

# --- Callbacks ---
attention_callback = AttentionVisualizationCallback(
    validation_data=val_dataset,
    log_dir=LOGS_DIR / "attention_visualizations",
    save_every_n_epochs=1,
)
attention_callback.set_model(drone_to_sat)
attention_callback.set_encoder(satellite_encoder)

token_matching_callback = TokenMatchingCallback(validation_data=val_dataset)
token_matching_callback.set_model(drone_to_sat)
token_matching_callback.set_encoder(satellite_encoder)

image_transform_callback = ImageTransformCallback(
    validation_data=val_dataset,
    encoder=satellite_encoder,
    decoder=satellite_decoder,
    save_dir=LOGS_DIR / "image_transformations",
)
image_transform_callback.set_model(drone_to_sat)

save_model_callback = SaveBestRawModelCallback(drone_to_sat)

callbacks = [
    attention_callback,
    token_matching_callback,
    image_transform_callback,
    save_model_callback,
]

# --- Train ---
train_model(
    drone_to_sat,
    satellite_encoder,
    train_dataset,
    val_dataset,
    epochs=EPOCHS,
    optimizer=optimizer,
    callbacks=callbacks,
)
print("TRAINING COMPLETE")

# --- Cleanup ---
# Drop every reference this cell created to the models, datasets, optimizer
# and callbacks (the callbacks themselves hold the models and val_dataset), so
# garbage collection can actually reclaim them.
del (
    attention_callback,
    token_matching_callback,
    image_transform_callback,
    save_model_callback,
    callbacks,
    optimizer,
    drone_to_sat,
    satellite_encoder,
    satellite_decoder,
    train_dataset,
    val_dataset,
)
# The loss, metrics and optimizer above are tf_keras objects, so clear
# tf_keras's global state rather than tensorflow.keras's.
tf_keras.backend.clear_session()
gc.collect()

ModuleNotFoundError: No module named 'raw_gen_model.raw_gen_callbacks.raw_gen_token_viz_callback'