In [3]:
# Import libraries to check installation and TPU availability
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
import numpy as np
import os
from PIL import Image
import imageio
from transformers import CLIPTokenizer # We'll use the CLIP tokenizer

# Check JAX and TPU setup
print("JAX Version:", jax.__version__)
print("Available Devices:", jax.devices())
if not jax.devices('tpu'):
    print("Warning: TPU not detected. Check Runtime -> Change runtime type.")
else:
    print(f"TPU detected with {jax.device_count()} devices.")

JAX Version: 0.5.3
Available Devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
TPU detected with 8 devices.


In [None]:
#@title Define Paths for Preprocessing (Using Local Colab Storage)

import os
import pandas as pd
import time
import requests
import numpy as np
from PIL import Image
import imageio
# No need for tfds here yet
# from transformers import CLIPTokenizer # Assume loaded from previous cell
# import tensorflow as tf # No need for tf.io.gfile for local storage

# --- Configuration ---
TSV_FILE_PATH = '/content/tgif-v1.0.tsv' # Replace with your TSV file path


# --- Use Local Colab Storage ---
# IMPORTANT: This directory is temporary and will be deleted when the runtime restarts!
PREPROCESSED_DATA_DIR = '/content/tgif_preprocessed_local_64px_16f'
# Manifest file path is also local
PROCESSED_MANIFEST_FILE = os.path.join(PREPROCESSED_DATA_DIR, 'processed_manifest.csv')

# Preprocessing constants (ensure these match your desired settings)
IMAGE_SIZE = 64
N_FRAMES = 16
MAX_LENGTH = 77

# Temporary local directory for downloads (this is fine locally)
TEMP_DOWNLOAD_DIR = '/content/temp_gifs'

# --- Create directories ---
os.makedirs(PREPROCESSED_DATA_DIR, exist_ok=True)
os.makedirs(TEMP_DOWNLOAD_DIR, exist_ok=True)
print(f"Temporary download directory: {TEMP_DOWNLOAD_DIR}")
print(f"Preprocessed data directory (Local Colab Disk): {PREPROCESSED_DATA_DIR}") # Emphasize local
print(f"Manifest file path (Local Colab Disk): {PROCESSED_MANIFEST_FILE}")

# --- Initialize Tokenizer (ensure this cell was run before) ---
try:
    if 'tokenizer' not in globals():
      from transformers import CLIPTokenizer
      tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    print("CLIP Tokenizer loaded.")
except Exception as e:
    print(f"Error loading tokenizer. Make sure transformers is installed and model exists: {e}")

# --- Helper functions (make sure they are defined) ---
# Define or ensure sample_frames_from_path and preprocess_text_for_saving are available
# (Definitions are same as before, no changes needed for local storage option)
def sample_frames_from_path(gif_path, n_frames, image_size):
    """Reads a GIF from path, samples, resizes, normalizes."""
    try:
        reader = imageio.get_reader(gif_path)
        frames_raw = []
        for frame in reader:
            pil_frame = Image.fromarray(frame).convert('RGB')
            frames_raw.append(np.array(pil_frame))
        reader.close()
    except Exception as e:
        # print(f"Error reading GIF {os.path.basename(gif_path)}: {e}")
        return None

    if not frames_raw: return None
    total_frames = len(frames_raw)
    if total_frames == 0: return None

    indices = np.linspace(0, total_frames - 1, n_frames, dtype=int)
    processed_frames = []
    for idx in indices:
        frame = frames_raw[idx]
        pil_img = Image.fromarray(frame)
        resized_img = pil_img.resize((image_size, image_size), Image.Resampling.LANCZOS)
        np_img = np.array(resized_img)
        normalized_img = (np_img / 255.0) * 2.0 - 1.0
        processed_frames.append(normalized_img)

    return np.stack(processed_frames).astype(np.float32)

def preprocess_text_for_saving(text, tokenizer, max_length):
    """Tokenizes text using CLIP tokenizer."""
    inputs = tokenizer(text, padding='max_length', truncation=True, max_length=max_length, return_tensors="np")
    return inputs['input_ids'][0].astype(np.int32)

print("Helper functions defined/verified.")
print("--- !!! ---")
print("WARNING: Preprocessed data will be saved to the local Colab disk.")
print("This data will be LOST if the runtime restarts.")
print("Save model checkpoints frequently to Google Drive during training.")
print("--- !!! ---")

In [None]:
#@title Load CLIP Text Encoder Model (Corrected SyntaxError)

import jax
# Ensure flax and transformers are imported
import flax.linen as nn
from transformers import CLIPTokenizer, FlaxCLIPTextModelWithProjection
import numpy as np # Ensure numpy is imported

# Model name (using a common one)
CLIP_MODEL_NAME = "openai/clip-vit-base-patch32"

# --- Initialize variable BEFORE try/except blocks ---
text_embedding_dim = None # Initialize to None

print(f"Loading CLIP Tokenizer: {CLIP_MODEL_NAME}")
try:
    # Tokenizer should already be loaded from preprocessing setup cell
    if 'tokenizer' not in globals():
         print("Error: Tokenizer not found. Please run the preprocessing setup cell first.")
         # Attempt to load it again just in case
         tokenizer = CLIPTokenizer.from_pretrained(CLIP_MODEL_NAME)
         print("Tokenizer loaded now.")
    else:
        print("Tokenizer already exists.")

    print(f"Loading Flax CLIP Text Model (with Projection): {CLIP_MODEL_NAME}")
    # Load the Flax version of the model for JAX compatibility
    clip_text_model = FlaxCLIPTextModelWithProjection.from_pretrained(CLIP_MODEL_NAME)
    print("Flax CLIP Text Model loaded successfully.")

    # --- Determine text_embedding_dim ---
    try:
        # Need MAX_LENGTH (use constant from pipeline setup)
        if 'MAX_LENGTH' not in globals():
             raise NameError("MAX_LENGTH not defined. Ensure pipeline setup cell was run.")
        dummy_input_ids = np.array([[tokenizer.eos_token_id] * MAX_LENGTH])
        output = clip_text_model(input_ids=dummy_input_ids)
        # --- Assign directly (no 'global' needed) ---
        text_embedding_dim = output.text_embeds.shape[-1]
        print(f"CLIP text embedding dimension determined: {text_embedding_dim}")
    except Exception as e_inner:
        print(f"Could not determine embedding dimension automatically: {e_inner}")
        # --- Assign directly (no 'global' needed) ---
        text_embedding_dim = 512 # Fallback
        print(f"Using fallback text embedding dimension: {text_embedding_dim}")

except Exception as e_outer:
    print(f"Error during CLIP model loading or dimension determination: {e_outer}")
    # --- Assign directly only if still None (no 'global' needed) ---
    if text_embedding_dim is None:
         text_embedding_dim = 512 # Placeholder
         print(f"Using placeholder text embedding dimension due to outer error: {text_embedding_dim}")

# --- Final check outside all try/except ---
if text_embedding_dim is None:
    print("ERROR: text_embedding_dim could not be determined or set.")
    # Optionally raise an error or set a default again
    # text_embedding_dim = 512
    # print(f"Setting default text_embedding_dim = {text_embedding_dim}")
else:
    print(f"Final text_embedding_dim value: {text_embedding_dim}")

In [7]:
#@title Load TGIF TSV File and Create IDs

import csv
import pandas as pd # Keep pandas import if needed elsewhere, but not for primary loading here

annotations_to_process = []
# TSV_FILE_PATH should be defined in the previous cell ([17])
print(f"Attempting to load TSV from: {TSV_FILE_PATH}")

try:
    with open(TSV_FILE_PATH, 'r', encoding='utf-8') as f:
        # Use csv.reader, specifying tab delimiter.
        # Assumes NO header row in your TSV based on the previous error.
        # If there IS a header row you want to skip, uncomment the next line:
        # next(f) # Skips the first line (header)
        reader = csv.reader(f, delimiter='\t')

        for i, row in enumerate(reader):
            if len(row) == 2:
                url, description = row
                # Add more robust checks
                if isinstance(url, str) and url.lower().startswith('http') and url.lower().endswith('.gif') and isinstance(description, str) and description.strip():
                    # Create a unique ID - using index is simple and effective
                    unique_id = f"item_{i:07d}"
                    annotations_to_process.append({
                        'id': unique_id,        # Add the generated ID
                        'url': url,             # Add the URL
                        'text': description.strip() # Add the cleaned text
                    })
                else:
                    print(f"Skipping row {i+1} due to invalid format/content: {row}") # Log invalid rows
            else:
                 print(f"Skipping row {i+1} due to incorrect number of columns ({len(row)}): {row}") # Log rows with wrong column count

    print(f"Loaded {len(annotations_to_process)} valid annotations from {TSV_FILE_PATH}")
    if annotations_to_process:
        print("Sample of loaded data structure:", annotations_to_process[0])
    else:
        print("Warning: No valid annotations were loaded. Check TSV path and format.")

except FileNotFoundError:
    print(f"ERROR: TSV file not found at {TSV_FILE_PATH}. Please upload it or correct the path.")
except Exception as e:
    print(f"Error reading TSV file: {e}")

Attempting to load TSV from: /content/tgif-v1.0.tsv
Loaded 125782 valid annotations from /content/tgif-v1.0.tsv
Sample of loaded data structure: {'id': 'item_0000000', 'url': 'https://38.media.tumblr.com/9f6c25cc350f12aa74a7dc386a5c4985/tumblr_mevmyaKtDf1rgvhr8o1_500.gif', 'text': 'a man is glaring, and someone with sunglasses appears.'}


In [None]:
#@title Run Preprocessing Loop (Download, Process, Save Locally)

import pandas as pd # Make sure pandas is imported

processed_list = [] # To store info for the new manifest
processed_ids = set() # Keep track of IDs already processed (for resuming)
failed_downloads = []
failed_processing = []

# --- Optional: Load existing manifest from local path to resume ---
resume_run = True # Set to False to re-process everything
if resume_run and os.path.exists(PROCESSED_MANIFEST_FILE):
    print(f"Found existing manifest: {PROCESSED_MANIFEST_FILE}. Resuming...")
    try:
        # Use standard pandas to read local CSV
        existing_df = pd.read_csv(PROCESSED_MANIFEST_FILE)
        processed_ids = set(existing_df['id'])
        processed_list = existing_df.to_dict('records') # Load existing records
        print(f"Loaded {len(processed_ids)} previously processed IDs.")
    except Exception as e:
        print(f"Warning: Could not load or parse existing manifest. Starting fresh. Error: {e}")
        processed_list = [] # Reset if manifest is corrupted
        processed_ids = set()


# --- Iterate and process ---
start_time = time.time()
process_limit = 500 # Set to None to process all, or e.g., 100 for testing
save_interval = 500 # Save manifest backup periodically

headers = {'User-Agent': 'Mozilla/5.0'}

# Ensure annotations_to_process is loaded from the previous cell
if 'annotations_to_process' not in globals() or not annotations_to_process:
     print("ERROR: annotations_to_process not found or empty. Please load the TSV file first.")
else:
    print(f"Starting preprocessing for {len(annotations_to_process)} items...")
    for i, item in enumerate(annotations_to_process):
        item_id = item['id']
        url = item['url']
        description = item['text']

        if process_limit is not None and i >= process_limit:
          print(f"Reached processing limit of {process_limit}. Stopping.")
          break

        if item_id in processed_ids:
            continue # Skip if already processed in a previous run

        # --- 1. Download ---
        temp_gif_path = os.path.join(TEMP_DOWNLOAD_DIR, f"{item_id}.gif")
        download_success = False
        try:
            response = requests.get(url, headers=headers, timeout=20, stream=True)
            response.raise_for_status()
            with open(temp_gif_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
            download_success = True
        except requests.exceptions.RequestException as e:
            failed_downloads.append({'id': item_id, 'url': url, 'error': str(e)})
        except Exception as e: # Catch other potential errors
            failed_downloads.append({'id': item_id, 'url': url, 'error': f"Unexpected: {str(e)}"})

        if not download_success:
            if os.path.exists(temp_gif_path): os.remove(temp_gif_path)
            continue

        # --- 2. Process GIF Frames ---
        frames = sample_frames_from_path(temp_gif_path, N_FRAMES, IMAGE_SIZE)

        # --- 3. Process Text ---
        tokens = preprocess_text_for_saving(description, tokenizer, MAX_LENGTH)

        # --- 4. Save to Local Colab Disk ---
        if frames is not None:
            # Define local paths for .npy files
            frames_save_path = os.path.join(PREPROCESSED_DATA_DIR, f"{item_id}_frames.npy")
            tokens_save_path = os.path.join(PREPROCESSED_DATA_DIR, f"{item_id}_tokens.npy")

            try:
                # Use standard np.save for local paths
                np.save(frames_save_path, frames)
                np.save(tokens_save_path, tokens)

                # Record success
                processed_list.append({
                    'id': item_id,
                    'frames_path': frames_save_path, # Store local path
                    'tokens_path': tokens_save_path, # Store local path
                    'original_text': description
                })
                processed_ids.add(item_id)

            except Exception as e:
                print(f"ERROR saving files locally for {item_id}: {e}")
                failed_processing.append({'id': item_id, 'url': url, 'error': f"Saving failed: {str(e)}"})
                # Clean up potentially corrupted local files
                if os.path.exists(frames_save_path): os.remove(frames_save_path)
                if os.path.exists(tokens_save_path): os.remove(tokens_save_path)
                if item_id in processed_ids: processed_ids.remove(item_id)

        else:
            failed_processing.append({'id': item_id, 'url': url, 'error': "Frame sampling returned None"})

        # --- 5. Clean up downloaded GIF ---
        if os.path.exists(temp_gif_path):
            os.remove(temp_gif_path)

        # --- Progress Update & Periodic Manifest Save ---
        processed_count = len(processed_ids)
        if processed_count > 0 and processed_count % save_interval == 0:
            elapsed_time = time.time() - start_time
            # Calculate estimated total time if possible
            total_items = len(annotations_to_process) if process_limit is None else process_limit
            if processed_count > 0 and total_items > 0:
                eta_seconds = (elapsed_time / processed_count) * (total_items - processed_count)
                eta_minutes = eta_seconds / 60
                print(f"Processed {processed_count}/{total_items} items... Time: {elapsed_time:.2f}s (ETA: {eta_minutes:.1f} min)")
            else:
                 print(f"Processed {processed_count}/{total_items} items... Time: {elapsed_time:.2f}s")

            # Save backup manifest locally
            backup_manifest_path = os.path.join(PREPROCESSED_DATA_DIR, 'processed_manifest_backup.csv')
            temp_df = pd.DataFrame(processed_list)
            try:
              # Use standard pandas to write local CSV
              temp_df.to_csv(backup_manifest_path, index=False)
            except Exception as e:
              print(f"Warning: Failed to save manifest backup: {e}")


# --- Final Save of Manifest ---
end_time = time.time()
print(f"\nPreprocessing finished. Total time: {end_time - start_time:.2f} seconds.")
print(f"Successfully processed: {len(processed_ids)}")
print(f"Failed downloads: {len(failed_downloads)}")
print(f"Failed processing/saving: {len(failed_processing)}")

if processed_list:
    final_df = pd.DataFrame(processed_list)
    try:
        # Use standard pandas to write final local CSV
        final_df.to_csv(PROCESSED_MANIFEST_FILE, index=False)
        print(f"Final manifest saved locally to: {PROCESSED_MANIFEST_FILE}")
        print(final_df.head())
    except Exception as e:
        print(f"ERROR: Failed to save final manifest: {e}")
else:
    print("No items were successfully processed and saved.")


In [None]:
#@title Create tf.data Pipeline (Using Preprocessed Data from Local Disk)

import tensorflow as tf
import pandas as pd
import numpy as np
import jax
import jax.numpy as jnp
import tensorflow_datasets as tfds
import os # Need os for checking path
import time # For basic timing

# --- Disable TF GPU usage to prevent conflicts with JAX on TPU ---
# Important if you have a GPU available in Colab but are using TPU
try:
    tf.config.set_visible_devices([], 'GPU')
    print("TensorFlow GPU visibility disabled.")
except Exception as e:
    print(f"Could not disable TensorFlow GPU visibility: {e}")

# --- Constants (should match preprocessing) ---
# Make sure these constants are defined and match the preprocessing step
# If not, redefine them here:
IMAGE_SIZE = 64
N_FRAMES = 16
MAX_LENGTH = 77

# --- Local Path to Manifest ---
# Ensure this path points to the manifest created in the preprocessing step
PREPROCESSED_DATA_DIR = '/content/tgif_preprocessed_local_64px_16f' # Base directory where .npy files are
PROCESSED_MANIFEST_FILE = os.path.join(PREPROCESSED_DATA_DIR, 'processed_manifest.csv')

# --- Batching and Shuffling Parameters ---
# Adjust BATCH_SIZE_PER_DEVICE based on TPU memory during training
# Start small (e.g., 4 or 8) and increase if possible
BATCH_SIZE_PER_DEVICE = 8
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_DEVICE * jax.device_count()
BUFFER_SIZE = 1000 # How many items to preload for shuffling

print(f"Targeting Global Batch Size: {GLOBAL_BATCH_SIZE} ({BATCH_SIZE_PER_DEVICE} per device)")
print(f"Loading data manifest from local path: {PROCESSED_MANIFEST_FILE}")

# --- Load the processed manifest from local disk ---
manifest_list = []
try:
    if not os.path.exists(PROCESSED_MANIFEST_FILE):
         raise FileNotFoundError(f"Manifest file not found at {PROCESSED_MANIFEST_FILE}")
    # Use standard pandas to read local CSV
    manifest_df = pd.read_csv(PROCESSED_MANIFEST_FILE)
    # Ensure required columns exist
    if not all(col in manifest_df.columns for col in ['frames_path', 'tokens_path']):
        raise ValueError("Manifest CSV missing required columns: 'frames_path' and 'tokens_path'")
    manifest_list = manifest_df.to_dict('records')
    print(f"Loaded manifest with {len(manifest_list)} processed items.")
    if not manifest_list:
        print("Warning: Manifest file loaded but is empty.")

except FileNotFoundError as e:
    print(f"ERROR: {e}")
    print("Please ensure the preprocessing step writing locally completed successfully and the path is correct.")
except Exception as e:
    print(f"Error loading or parsing processed manifest: {e}")
    manifest_list = [] # Ensure it's empty on error

# --- Data Generator using Local Preprocessed Files ---
def preprocessed_data_generator():
    """Yields preprocessed (tokens, frames) pairs by reading local .npy files."""
    if not manifest_list:
         print("Generator stopping: Manifest list is empty.")
         raise StopIteration # Stop if no data to load

    # Shuffle indices at the start of each epoch iteration
    indices = np.arange(len(manifest_list))
    np.random.shuffle(indices)
    print(f"Data Generator: Starting epoch with {len(indices)} shuffled indices.")

    files_loaded = 0
    files_failed = 0
    start_time = time.time()

    for i in indices:
        item = manifest_list[i]
        # Paths are local Colab paths - already checked if manifest loaded correctly
        frames_path = item['frames_path']
        tokens_path = item['tokens_path']

        # Minimal check: ensure paths are strings (already done by manifest loading)
        # More robust check: Verify file existence here? Can slow down if many files.
        # Let's rely on the try-except block during loading.

        try:
            # Use standard np.load for local .npy files
            # No need for tf.io.gfile here
            tokens = np.load(tokens_path)
            frames = np.load(frames_path)

            # Basic shape and type check (crucial)
            if tokens.shape == (MAX_LENGTH,) and frames.shape == (N_FRAMES, IMAGE_SIZE, IMAGE_SIZE, 3):
                 # Ensure correct dtypes before yielding
                 yield {'input_ids': tokens.astype(np.int32), 'pixel_values': frames.astype(np.float32)}
                 files_loaded += 1
            else:
                 # Optional: Log shape mismatches more verbosely if needed
                 # print(f"Warning: Shape/Type mismatch for item {item.get('id', 'N/A')}. "
                 #       f"Tokens: {tokens.shape} ({tokens.dtype}), Frames: {frames.shape} ({frames.dtype}). Skipping.")
                 files_failed += 1


        except FileNotFoundError:
             # print(f"Warning: File not found for item {item.get('id', 'N/A')} (Paths: {tokens_path}, {frames_path}). Skipping.")
             files_failed += 1
        except Exception as e:
            # Log other errors loading specific files but continue
            print(f"Warning: Error loading local data for item {item.get('id', 'N/A')} (Paths: {tokens_path}, {frames_path}). Error: {e}. Skipping.")
            files_failed += 1

        # Optional: Print progress within the generator for debugging long epochs
        # if (files_loaded + files_failed) % 1000 == 0 and (files_loaded + files_failed) > 0:
        #     elapsed = time.time() - start_time
        #     print(f"Generator progress: {(files_loaded + files_failed)} items processed ({files_loaded} loaded, {files_failed} failed) in {elapsed:.2f}s")


    print(f"Data Generator: Epoch finished. Loaded {files_loaded}, Failed {files_failed}.")


# Define output signature for tf.data (matches yielded dictionary)
output_signature = {
    'input_ids': tf.TensorSpec(shape=(MAX_LENGTH,), dtype=tf.int32),
    'pixel_values': tf.TensorSpec(shape=(N_FRAMES, IMAGE_SIZE, IMAGE_SIZE, 3), dtype=tf.float32)
}

# --- Create the tf.data.Dataset ---
# Only create dataset if manifest was loaded successfully and is not empty
if manifest_list:
    print("Creating tf.data.Dataset from generator...")
    dataset = tf.data.Dataset.from_generator(
        preprocessed_data_generator,
        output_signature=output_signature
    )

    print("Applying dataset transformations (shuffle, repeat, batch, prefetch)...")
    # --- Apply transformations ---
    # Shuffle the data items. BUFFER_SIZE determines how many items are loaded ahead for shuffling.
    dataset = dataset.shuffle(BUFFER_SIZE)

    # Repeat indefinitely for training epochs. Ensures the iterator doesn't stop after one pass.
    dataset = dataset.repeat()

    # Batch the data. drop_remainder=True is crucial for TPUs requiring fixed batch sizes per device.
    dataset = dataset.batch(GLOBAL_BATCH_SIZE, drop_remainder=True)

    # Prefetch data to overlap data loading/preprocessing with model execution.
    # AUTOTUNE lets tf.data decide the optimal number of batches to prefetch.
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    print("Dataset transformations applied.")

    # --- Function to shard batch for JAX pmap ---
    def shard_batch(batch):
        """Reshapes and distributes batch across JAX devices."""
        # Ensure the batch data are JAX arrays (tfds.as_numpy usually returns numpy)
        batch = jax.tree_map(lambda x: jnp.array(x), batch)
        # Calculate the number of devices
        num_devices = jax.local_device_count()
        # Use jax.tree_map to apply the reshape to both 'input_ids' and 'pixel_values'
        return jax.tree_map(
            lambda x: x.reshape((num_devices, -1) + x.shape[1:]), batch
        )

    # --- Convert TF dataset to NumPy iterator and test ---
    print("\nConverting dataset to NumPy iterator for JAX...")
    numpy_iterator = tfds.as_numpy(dataset)

    print("Testing the data pipeline by fetching one batch...")
    start_fetch_time = time.time()
    try:
        # Get one batch from the iterator
        example_batch = next(iter(numpy_iterator))
        fetch_time = time.time() - start_fetch_time
        print(f"Successfully fetched one batch in {fetch_time:.2f} seconds.")

        # Shard the batch for pmap
        sharded_batch = shard_batch(example_batch)
        print("Batch sharded successfully.")

        print("\n--- Batch Shapes ---")
        print("Original batch shapes (Global Batch):")
        jax.tree_map(lambda x: print(f"  {x.shape}, {x.dtype}"), example_batch)

        print("\nSharded batch shapes (Per Device):")
        jax.tree_map(lambda x: print(f"  {x.shape}, {x.dtype}"), sharded_batch) # Shape should be [num_devices, BATCH_SIZE_PER_DEVICE, ...]

        print("\nData pipeline setup and test successful!")

    except StopIteration:
         print("\nERROR: The data generator did not yield any data even after creating the dataset.")
         print("This likely means the manifest was empty OR all .npy files failed to load during the first attempt.")
         print("Check the generator logs and the integrity of the preprocessed files.")
    except Exception as e:
        print(f"\nERROR during pipeline testing (fetching/sharding batch): {e}")
        import traceback
        traceback.print_exc() # Print detailed traceback
        print("Check manifest file path, .npy file integrity/paths, batch sizes, and sharding logic.")

else:
    print("\nSkipping dataset creation and testing because the manifest failed to load or was empty.")
    print("Please ensure the preprocessing step ran correctly and created a non-empty manifest file.")

In [19]:
#@title Define U-Net Components and Diffusion Model (Flax)

import jax.numpy as jnp
import flax.linen as nn
import math
from typing import Sequence, Optional, Tuple, Union, Any

# --- Configuration ---
# These can be adjusted, but start reasonably small
unet_dim = 64 # Base channel dimension for the U-Net
unet_dim_mults = (1, 2, 4) # Multipliers for channels at different resolutions (e.g., 64, 128, 256)
unet_num_resnet_blocks = 2 # Number of ResNet blocks per resolution level
unet_attn_resolutions = (1,) # Which resolution levels (as factors of original, e.g., 1 = 64x64, 2 = 32x32) to use attention. Adjust based on IMAGE_SIZE
# Ensure text_embedding_dim is available from the previous cell
# text_embedding_dim = 512 # Should be set by CLIP loading cell


# --- Helper Functions & Modules ---

def normalization(num_groups: int = 32):
  """Group Normalization."""
  # Consider LayerNorm as an alternative if GroupNorm gives issues
  # LayerNorm might be more stable for TPUs sometimes
  # return nn.LayerNorm()
  return nn.GroupNorm(num_groups=num_groups, epsilon=1e-5)

class SinusoidalPosEmb(nn.Module):
    """Sinusoidal Positional Embedding for Timesteps."""
    dim: int

    @nn.compact
    def __call__(self, time):
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = jnp.exp(jnp.arange(half_dim) * -embeddings)
        # Expand dims for broadcasting: time=[B], embeddings=[half_dim] -> embeddings=[B, half_dim]
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = jnp.concatenate([jnp.sin(embeddings), jnp.cos(embeddings)], axis=-1)
        # Ensure output dim matches self.dim, handle odd dim if necessary
        if self.dim % 2 == 1:
             embeddings = jnp.pad(embeddings, [(0,0), (0,1)])
        return embeddings

class ResnetBlock3D(nn.Module):
    """3D ResNet Block with Time and Text Conditioning."""
    out_channels: int
    time_emb_dim: Optional[int] = None
    text_emb_dim: Optional[int] = None
    groups: int = 32 # For GroupNorm
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x, t_emb=None, text_emb=None, *, train: bool):
        # x shape: [B, F, H, W, C]
        # t_emb shape: [B, time_emb_dim]
        # text_emb shape: [B, text_emb_dim]

        input_channels = x.shape[-1]
        hidden_channels = self.out_channels

        # Normalize and apply first convolution
        h = normalization(self.groups)(x)
        h = nn.silu(h)
        h = nn.Conv(features=hidden_channels, kernel_size=(3, 3, 3), padding='SAME')(h)

        # --- Inject Time Embedding ---
        if t_emb is not None and self.time_emb_dim is not None:
            # Project time embedding to match channel dimension
            time_proj = nn.silu(nn.Dense(features=hidden_channels)(t_emb))
            # Add to feature map: needs reshape [B, emb] -> [B, 1, 1, 1, emb] for broadcasting
            h = h + time_proj[:, None, None, None, :]

        # --- Inject Text Embedding (Optional, simple addition here) ---
        # More advanced conditioning uses cross-attention
        if text_emb is not None and self.text_emb_dim is not None:
             text_proj = nn.silu(nn.Dense(features=hidden_channels)(text_emb))
             h = h + text_proj[:, None, None, None, :] # Broadcast similarly

        # Normalize, activate, dropout, and second convolution
        h = normalization(self.groups)(h)
        h = nn.silu(h)
        h = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(h)
        h = nn.Conv(features=self.out_channels, kernel_size=(3, 3, 3), padding='SAME')(h)

        # --- Residual Connection ---
        if input_channels != self.out_channels:
            # Use 1x1x1 convolution to match dimensions if needed
            x_res = nn.Conv(features=self.out_channels, kernel_size=(1, 1, 1))(x)
        else:
            x_res = x

        return h + x_res


class AttentionBlock(nn.Module):
    """Multi-Head Attention Block."""
    num_heads: int = 8
    head_dim: int = 64 # Dimension per head
    groups: int = 32 # For GroupNorm
    use_cross_attention: bool = False

    @nn.compact
    def __call__(self, x, context=None):
        # x shape: [B, F, H, W, C]
        # context shape: [B, context_len, context_dim] (e.g., from text encoder)
        batch_size, num_frames, height, width, channels = x.shape
        residual = x

        inner_dim = self.num_heads * self.head_dim

        # Normalize
        h = normalization(self.groups)(x)

        # --- Reshape for Attention ---
        # Combine spatial/temporal dimensions: [B, F*H*W, C]
        h_reshaped = h.reshape(batch_size, num_frames * height * width, channels)

        # --- Self-Attention ---
        qkv = nn.Dense(features=inner_dim * 3, name='to_qkv')(h_reshaped)
        # Split into q, k, v: [B, F*H*W, inner_dim] each
        q, k, v = jnp.array_split(qkv, 3, axis=-1)

        # Apply attention (using Flax built-in)
        attn_output = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            qkv_features=inner_dim, # Total dimension across heads
            out_features=channels, # Project back to original channel dim
            name='self_attn'
        )(inputs_q=q, inputs_kv=k) # Q=q, K=k, V=v (implicitly handled by MHA)

        # --- Optional: Cross-Attention ---
        if self.use_cross_attention and context is not None:
            # Project query from self-attention output, keys/values from context
            q_cross = nn.Dense(features=inner_dim, name='to_q_cross')(attn_output)
            kv_cross = nn.Dense(features=inner_dim * 2, name='to_kv_cross')(context)
            k_cross, v_cross = jnp.array_split(kv_cross, 2, axis=-1)

            attn_output = nn.MultiHeadDotProductAttention(
                num_heads=self.num_heads,
                qkv_features=inner_dim, # Q dim
                out_features=channels,
                name='cross_attn'
            )(inputs_q=q_cross, inputs_kv=k_cross) # Use k_cross, v_cross (implicitly)


        # Reshape back to original spatial/temporal dimensions: [B, F, H, W, C]
        attn_output_reshaped = attn_output.reshape(batch_size, num_frames, height, width, channels)

        # Add residual connection
        return attn_output_reshaped + residual


class Downsample(nn.Module):
    """Downsampling using strided 3D convolution."""
    features: int

    @nn.compact
    def __call__(self, x):
        # Use stride (1, 2, 2) to downsample H and W, keep F the same
        # Or use (2, 2, 2) to downsample F, H, W
        # Let's start with spatial downsampling only:
        return nn.Conv(
            features=self.features,
            kernel_size=(3, 3, 3),
            strides=(1, 2, 2), # Downsample H, W
            padding='SAME' # Or calculate padding manually
        )(x)

class Upsample(nn.Module):
    """Upsampling using 3D transposed convolution."""
    features: int

    @nn.compact
    def __call__(self, x):
        batch, frames, height, width, channels = x.shape
        # Use ConvTranspose to double H and W
        return nn.ConvTranspose(
            features=self.features,
            kernel_size=(3, 3, 3),
            strides=(1, 2, 2), # Upsample H, W
            padding='SAME' # padding='SAME' works with strides > 1 in ConvTranspose
        )(x)


# --- Main U-Net Model ---

class UNetConditional3D(nn.Module):
    """Conditional 3D U-Net for Video Diffusion."""
    dim: int = 64
    dim_mults: Sequence[int] = (1, 2, 4)
    num_resnet_blocks: int = 2
    attn_resolutions_factor: Sequence[int] = (2, 4) # Resolution factors (e.g., H/2, H/4) where attention is applied
    out_dim: Optional[int] = None
    time_emb_dim: int = 256 # Internal dimension for time embedding MLP
    text_emb_dim: int = text_embedding_dim # From CLIP
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x, time, text_embedding, *, train: bool):
        # x shape: [B, F, H, W, C] (e.g., B, 16, 64, 64, 3)
        # time shape: [B] (timesteps)
        # text_embedding shape: [B, E] (e.g., B, 512)

        assert x.ndim == 5, f"Input x must be 5D (B, F, H, W, C), got {x.ndim}"
        batch_size, num_frames, height, width, channels = x.shape
        out_dim = self.out_dim if self.out_dim is not None else channels

        # --- Time Embedding ---
        t_emb = SinusoidalPosEmb(dim=self.dim)(time) # [B, dim]
        # MLP for time embedding
        t_emb = nn.Dense(features=self.time_emb_dim)(t_emb)
        t_emb = nn.silu(t_emb)
        t_emb = nn.Dense(features=self.time_emb_dim)(t_emb) # [B, time_emb_dim]

        # --- Initial Convolution ---
        h = nn.Conv(features=self.dim, kernel_size=(3, 3, 3), padding='SAME')(x)
        hs = [h] # Store hidden states for skip connections

        # --- Downsampling Path ---
        current_res_factor = 1
        block_dims = [self.dim]
        for i, mult in enumerate(self.dim_mults):
            is_last = (i == len(self.dim_mults) - 1)
            features = self.dim * mult
            block_dims.append(features)

            for _ in range(self.num_resnet_blocks):
                h = ResnetBlock3D(
                    out_channels=features,
                    time_emb_dim=self.time_emb_dim,
                    text_emb_dim=self.text_emb_dim, # Pass text dim for conditioning
                    dropout_rate=self.dropout_rate
                )(h, t_emb, text_embedding, train=train) # Pass embeddings

                # --- Attention ---
                if current_res_factor in self.attn_resolutions_factor:
                    h = AttentionBlock(
                        num_heads=8,
                        head_dim=features // 8, # Example head dim calculation
                        use_cross_attention=False # Set True to use text_embedding in attention
                    )(h, context=text_embedding) # Pass text_embedding if use_cross_attention=True

                hs.append(h) # Store for skip connection

            # Downsample at the end of the level (unless it's the last level)
            if not is_last:
                h = Downsample(features=features)(h)
                hs.append(h)
                current_res_factor *= 2 # Assuming spatial downsampling by 2

        # --- Bottleneck ---
        mid_features = block_dims[-1]
        h = ResnetBlock3D(
            out_channels=mid_features,
            time_emb_dim=self.time_emb_dim,
            text_emb_dim=self.text_emb_dim,
            dropout_rate=self.dropout_rate
        )(h, t_emb, text_embedding, train=train)

        h = AttentionBlock(
            num_heads=8, head_dim=mid_features // 8, use_cross_attention=False
        )(h, context=text_embedding)

        h = ResnetBlock3D(
            out_channels=mid_features,
            time_emb_dim=self.time_emb_dim,
            text_emb_dim=self.text_emb_dim,
            dropout_rate=self.dropout_rate
        )(h, t_emb, text_embedding, train=train)

        # --- Upsampling Path ---
        for i, mult in reversed(list(enumerate(self.dim_mults))):
            is_first = (i == 0)
            features = self.dim * mult
            prev_features = block_dims[i] # Dim from the corresponding downsampling level

            for _ in range(self.num_resnet_blocks + 1): # +1 because we pop skip connections
                # Pop skip connection and concatenate
                skip_h = hs.pop()
                h = jnp.concatenate([h, skip_h], axis=-1)

                h = ResnetBlock3D(
                    out_channels=features, # Target features for this level
                    time_emb_dim=self.time_emb_dim,
                    text_emb_dim=self.text_emb_dim,
                    dropout_rate=self.dropout_rate
                )(h, t_emb, text_embedding, train=train)

                # --- Attention ---
                if current_res_factor in self.attn_resolutions_factor:
                     h = AttentionBlock(
                         num_heads=8, head_dim=features // 8, use_cross_attention=False
                     )(h, context=text_embedding)

            # Upsample at the end of the level (unless it's the first level)
            if not is_first:
                h = Upsample(features=features)(h)
                current_res_factor //= 2 # Assuming spatial upsampling by 2

        # Ensure we've used all skip connections
        assert not hs, "Skip connections remaining after upsampling path."

        # --- Final Layers ---
        h = normalization(32)(h) # Final normalization
        h = nn.silu(h)
        # Final convolution maps back to the number of output channels (e.g., 3 for RGB noise)
        output_noise = nn.Conv(features=out_dim, kernel_size=(1, 1, 1))(h)

        return output_noise


# --- Quick Test (Initialization) ---
print("\n--- Testing Model Initialization ---")
key = jax.random.PRNGKey(0)
dummy_batch_size = 2 # Small batch for testing init
dummy_frames = 16
dummy_height = IMAGE_SIZE # Use constant from pipeline
dummy_width = IMAGE_SIZE
dummy_channels = 3
dummy_text_len = 77 # CLIP max length
dummy_text_emb_dim = text_embedding_dim # From CLIP loading

# Dummy Inputs
dummy_x = jnp.ones((dummy_batch_size, dummy_frames, dummy_height, dummy_width, dummy_channels))
dummy_time = jnp.ones((dummy_batch_size,))
dummy_text_emb = jnp.ones((dummy_batch_size, dummy_text_emb_dim)) # Match CLIP output dim

# Instantiate the model
unet_model = UNetConditional3D(
    dim=unet_dim,
    dim_mults=unet_dim_mults,
    num_resnet_blocks=unet_num_resnet_blocks,
    attn_resolutions_factor=[dummy_height // pow(2, i) for i in range(len(unet_dim_mults))], # Example: Apply attention based on downsampled H
    out_dim=dummy_channels,
    time_emb_dim=unet_dim * 4, # Often larger than base dim
    text_emb_dim=dummy_text_emb_dim,
    dropout_rate=0.1
)


# Initialize parameters
print("Initializing U-Net parameters...")
try:
    # Use train=False for initialization typically, dropout won't be active
    params = unet_model.init(key, dummy_x, dummy_time, dummy_text_emb, train=False)['params']
    print("U-Net initialization successful.")

    # Test forward pass
    print("Testing U-Net forward pass...")
    output = unet_model.apply({'params': params}, dummy_x, dummy_time, dummy_text_emb, train=False)
    print(f"Output shape: {output.shape}")

    # Check if output shape matches input shape
    if output.shape == dummy_x.shape:
        print("Output shape matches input shape. Looks good!")
    else:
        print(f"WARNING: Output shape {output.shape} does not match input shape {dummy_x.shape}")

    # Optional: Print number of parameters
    num_params = sum(x.size for x in jax.tree_util.tree_leaves(params))
    print(f"Total number of parameters in U-Net: {num_params:,}")

except Exception as e:
    print(f"\nERROR during U-Net initialization or forward pass: {e}")
    import traceback
    traceback.print_exc()
    print("\n--- Debugging Tips ---")
    print("*   Check input/output dimensions in Conv3D, ConvTranspose, Attention, ResNet blocks.")
    print("*   Ensure skip connections are handled correctly (concatenation dimension, number of pops).")
    print("*   Verify time/text embedding injection logic and shapes.")
    print("*   Make sure attention resolution factors are compatible with image size.")
    print("*   Try simplifying the network (fewer blocks/layers/attention) to isolate the issue.")


--- Testing Model Initialization ---
Initializing U-Net parameters...
U-Net initialization successful.
Testing U-Net forward pass...
Output shape: (2, 16, 64, 64, 3)
Output shape matches input shape. Looks good!
Total number of parameters in U-Net: 41,601,219


In [36]:
#@title Define Diffusion Schedule, Sampling Functions, and Loss

import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax.core import FrozenDict # Keep import, useful type

# --- Diffusion Hyperparameters ---
NUM_DIFFUSION_TIMESTEPS = 1000 # Standard number of timesteps
BETA_START = 0.0001
BETA_END = 0.02

# --- Noise Schedule Calculation ---

def linear_beta_schedule(timesteps):
    """Generates linear beta schedule."""
    scale = 1000 / timesteps # Adjusts beta range based on actual timesteps
    beta_start = scale * BETA_START
    beta_end = scale * BETA_END
    return jnp.linspace(beta_start, beta_end, timesteps, dtype=jnp.float32)

def get_diffusion_params(schedule_fn, timesteps):
    """Calculates alphas, alphas_cumprod, etc. from betas."""
    betas = schedule_fn(timesteps)

    alphas = 1. - betas
    alphas_cumprod = jnp.cumprod(alphas, axis=0)
    alphas_cumprod_prev = jnp.pad(alphas_cumprod[:-1], (1, 0), constant_values=1.0)

    sqrt_alphas_cumprod = jnp.sqrt(alphas_cumprod)
    sqrt_one_minus_alphas_cumprod = jnp.sqrt(1. - alphas_cumprod)

    posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
    posterior_variance = jnp.maximum(posterior_variance, 1e-20)
    posterior_log_variance_clipped = jnp.log(posterior_variance)

    posterior_mean_coef1 = betas * jnp.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
    posterior_mean_coef2 = (1. - alphas_cumprod_prev) * jnp.sqrt(alphas) / (1. - alphas_cumprod)

    # Return as a standard dictionary (or FrozenDict, either works now)
    params_dict = {
        "betas": betas,
        "alphas": alphas,
        "alphas_cumprod": alphas_cumprod,
        "alphas_cumprod_prev": alphas_cumprod_prev,
        "sqrt_alphas_cumprod": sqrt_alphas_cumprod,
        "sqrt_one_minus_alphas_cumprod": sqrt_one_minus_alphas_cumprod,
        "posterior_variance": posterior_variance,
        "posterior_log_variance_clipped": posterior_log_variance_clipped,
        "posterior_mean_coef1": posterior_mean_coef1,
        "posterior_mean_coef2": posterior_mean_coef2,
    }
    # Convert to FrozenDict (optional but good practice for immutable schedules)
    return FrozenDict(params_dict)


# --- Instantiate Diffusion Parameters ---
diffusion_params = get_diffusion_params(linear_beta_schedule, NUM_DIFFUSION_TIMESTEPS)
print(f"Calculated diffusion parameters (as {type(diffusion_params).__name__}) for {NUM_DIFFUSION_TIMESTEPS} timesteps.")
print(f"Beta range: {diffusion_params['betas'][0]:.4f} to {diffusion_params['betas'][-1]:.4f}")

# --- Helper function to extract schedule values for specific timesteps ---
def extract(arr, timesteps, broadcast_shape):
    """Extract values from arr at given timesteps and reshape."""
    batch_size = timesteps.shape[0]
    out = arr[timesteps]
    return out.reshape(batch_size, *((1,) * (len(broadcast_shape) - 1)))

# --- Forward Process (q_sample) ---
# @jax.jit # Can optionally jit this function
def q_sample(x_start, t, noise, schedule):
    """Diffuse the data (t == 0 means no noise)."""
    sqrt_alphas_cumprod_t = extract(schedule['sqrt_alphas_cumprod'], t, x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(schedule['sqrt_one_minus_alphas_cumprod'], t, x_start.shape)
    noisy_frames = sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
    return noisy_frames

# --- Loss Function ---
def noise_prediction_loss(predicted_noise, target_noise):
    """Calculates the Mean Squared Error between predicted and target noise."""
    # Can add options for L1 loss etc. later if needed
    return jnp.mean((predicted_noise - target_noise) ** 2)

# --- Optional: Test q_sample ---
print("\n--- Testing q_sample ---")
try:
    if 'dummy_x' not in globals(): # Check if dummy data exists from U-Net init test
         # Recreate minimal dummy data if needed
         print("Recreating dummy data for q_sample test...")
         key = jax.random.PRNGKey(1)
         # Use constants if available, otherwise define fallback values
         _N_FRAMES = N_FRAMES if 'N_FRAMES' in globals() else 16
         _IMAGE_SIZE = IMAGE_SIZE if 'IMAGE_SIZE' in globals() else 64
         dummy_batch_size=2; dummy_channels=3
         dummy_x = jax.random.normal(key, (dummy_batch_size, _N_FRAMES, _IMAGE_SIZE, _IMAGE_SIZE, dummy_channels))

    key, noise_key = jax.random.split(jax.random.PRNGKey(42))
    dummy_noise = jax.random.normal(noise_key, dummy_x.shape)
    t0 = jnp.array([0] * dummy_batch_size)
    noisy_x_t0 = q_sample(dummy_x, t0, dummy_noise, diffusion_params)
    diff_t0 = jnp.mean((dummy_x - noisy_x_t0)**2)
    print(f"MSE between x_start and noisy_x at t=0: {diff_t0:.6f}")

    t_mid = jnp.array([NUM_DIFFUSION_TIMESTEPS // 2] * dummy_batch_size)
    noisy_x_tmid = q_sample(dummy_x, t_mid, dummy_noise, diffusion_params)
    diff_tmid = jnp.mean((dummy_x - noisy_x_tmid)**2)
    print(f"MSE between x_start and noisy_x at t={t_mid[0]}: {diff_tmid:.4f}")

    t_max = jnp.array([NUM_DIFFUSION_TIMESTEPS - 1] * dummy_batch_size)
    noisy_x_tmax = q_sample(dummy_x, t_max, dummy_noise, diffusion_params)
    diff_tmax = jnp.mean((dummy_x - noisy_x_tmax)**2)
    var_tmax = jnp.var(noisy_x_tmax)
    print(f"MSE between x_start and noisy_x at t={t_max[0]}: {diff_tmax:.4f}")
    print(f"Variance of noisy_x at t={t_max[0]}: {var_tmax:.4f}")

    print("q_sample test completed successfully.")

except NameError as ne:
     print(f"Error during q_sample test: Missing variable '{ne}'. Ensure previous cells ran.")
except Exception as e:
    print(f"Error during q_sample test: {e}")
    import traceback
    traceback.print_exc()

Calculated diffusion parameters (as FrozenDict) for 1000 timesteps.
Beta range: 0.0001 to 0.0200

--- Testing q_sample ---
MSE between x_start and noisy_x at t=0: 0.000100
MSE between x_start and noisy_x at t=500: 1.4395
MSE between x_start and noisy_x at t=999: 1.9837
Variance of noisy_x at t=999: 0.9993
q_sample test completed successfully.


In [37]:
#@title Define Optimizer and Training Step Function (Corrected)

import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from functools import partial
from flax.core import FrozenDict

# --- Optimizer Configuration ---
LEARNING_RATE = 1e-4 # Starting learning rate

optimizer = optax.adamw(learning_rate=LEARNING_RATE)

# --- Ensure necessary components are available ---
if 'unet_model' not in globals(): raise NameError("unet_model not defined.")
if 'clip_text_model' not in globals(): raise NameError("clip_text_model not defined.")
if 'numpy_iterator' not in globals(): raise NameError("numpy_iterator (iterable) not defined.")
if 'shard_batch' not in globals(): raise NameError("shard_batch function not defined.")
if 'diffusion_params' not in globals(): raise NameError("diffusion_params not defined.")
if 'q_sample' not in globals(): raise NameError("q_sample function not defined.")
if 'noise_prediction_loss' not in globals(): raise NameError("noise_prediction_loss function not defined.")
if 'params' not in globals(): raise NameError("U-Net parameters 'params' not initialized.")
if 'NUM_DIFFUSION_TIMESTEPS' not in globals(): raise NameError("NUM_DIFFUSION_TIMESTEPS not defined.")


# --- Define the Training Step ---

# Only model definitions are static. Schedule will be replicated.
@partial(jax.pmap, axis_name='batch',
         static_broadcasted_argnums=(4, 5)) # unet_def, clip_encoder
def train_step(unet_params, clip_params, opt_state, schedule, unet_def, clip_encoder, batch, key):
    """Performs a single training step on a batch distributed across devices."""

    step_key, t_key, noise_key = jax.random.split(key, 3)

    # 1. Get Text Embeddings
    # NOTE: Ensure clip_params has the correct structure expected by clip_encoder.apply
    # For HuggingFace Flax models, usually it's just model.params
    text_embeddings = clip_encoder.apply(
        {'params': clip_params}, batch['input_ids']
    ).text_embeds

    # 2. Prepare Diffusion Inputs
    clean_frames = batch['pixel_values']
    local_batch_size = clean_frames.shape[0]
    t = jax.random.randint(t_key, shape=(local_batch_size,), minval=0, maxval=NUM_DIFFUSION_TIMESTEPS)
    noise = jax.random.normal(noise_key, clean_frames.shape)
    noisy_frames = q_sample(clean_frames, t, noise, schedule) # schedule is device-local replicated version

    # 3. Define Loss Function for Grad Calculation
    def compute_loss(current_unet_params):
        predicted_noise = unet_def.apply(
            {'params': current_unet_params},
            noisy_frames, t, text_embeddings, train=True
        )
        loss = noise_prediction_loss(predicted_noise, noise)
        return loss

    # 4. Calculate Loss and Gradients
    loss_val, grads = jax.value_and_grad(compute_loss)(unet_params)

    # Average gradients and loss across devices
    grads = jax.lax.pmean(grads, axis_name='batch')
    loss_val = jax.lax.pmean(loss_val, axis_name='batch')

    # 5. Update Optimizer State and U-Net Parameters
    updates, new_opt_state = optimizer.update(grads, opt_state, unet_params)
    new_unet_params = optax.apply_updates(unet_params, updates)

    metrics = {'loss': loss_val}
    return new_unet_params, new_opt_state, metrics, step_key

print("Optimizer and train_step function defined.")
print("train_step configured for pmap execution (schedule is replicated, models static).")

# --- Prepare State for Training Loop Execution ---
opt_state = None
replicated_params = None
replicated_opt_state = None
replicated_clip_params = None
replicated_schedule = None
step_keys = None
num_devices = jax.local_device_count()

try:
    # 1. Initialize Optimizer State
    opt_state = optimizer.init(params) # 'params' should be from U-Net init cell
    print("Optimizer state initialized.")

    # 2. Replicate State Across Devices
    replicated_params = jax.device_put_replicated(params, jax.local_devices())
    replicated_opt_state = jax.device_put_replicated(opt_state, jax.local_devices())
    # Ensure clip_text_model.params is the correct way to access CLIP params
    replicated_clip_params = jax.device_put_replicated(clip_text_model.params, jax.local_devices())
    # Replicate the diffusion schedule (FrozenDict or dict)
    replicated_schedule = jax.device_put_replicated(diffusion_params, jax.local_devices())

    print("U-Net params, CLIP params, optimizer state, and schedule replicated across devices.")

    # 3. Prepare PRNG Keys
    rng = jax.random.PRNGKey(0)
    step_keys = jax.random.split(rng, num_devices)
    print(f"PRNG keys split for {num_devices} devices.")

except NameError as ne:
    print(f"ERROR: Missing variable '{ne}', cannot prepare for training loop.")
    print("Ensure U-Net params ('params'), diffusion_params, clip_text_model, and optimizer are defined.")
except Exception as e:
    print(f"ERROR preparing state for training loop: {e}")
    import traceback
    traceback.print_exc()

Optimizer and train_step function defined.
train_step configured for pmap execution (schedule is replicated, models static).
Optimizer state initialized.
U-Net params, CLIP params, optimizer state, and schedule replicated across devices.
PRNG keys split for 8 devices.


In [38]:
#@title Training Loop (Corrected)

import time
from flax.training import checkpoints
import os
from google.colab import drive
import tensorflow_datasets as tfds # Ensure tfds is imported

# --- Training Configuration ---
NUM_TRAIN_STEPS = 500 # Set desired number of steps
LOG_EVERY_STEPS = 100
CHECKPOINT_EVERY_STEPS = 1000 # Frequency to save checkpoints

# --- Checkpoint Configuration ---
CHECKPOINT_DIR = '/content/drive/MyDrive/tgif_checkpoints_local_data'
try:
    # Ensure Drive is mounted before creating directory
    if not os.path.ismount('/content/drive'):
      drive.mount('/content/drive')
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    print(f"Checkpoints will be saved to: {CHECKPOINT_DIR}")
except Exception as e:
    print(f"Error setting up checkpoint directory on Google Drive: {e}")
    print("Checkpoints will not be saved.")
    CHECKPOINT_DIR = None # Disable checkpointing if Drive fails

# --- Verify necessary components exist ---
if 'replicated_params' not in globals() or replicated_params is None: raise NameError("replicated_params not ready.")
if 'replicated_clip_params' not in globals() or replicated_clip_params is None: raise NameError("replicated_clip_params not ready.")
if 'replicated_opt_state' not in globals() or replicated_opt_state is None: raise NameError("replicated_opt_state not ready.")
if 'replicated_schedule' not in globals() or replicated_schedule is None: raise NameError("replicated_schedule not ready.")
if 'train_step' not in globals(): raise NameError("train_step function (pmap'ed) not defined.")
if 'numpy_iterator' not in globals(): raise NameError("numpy_iterator (iterable) not defined.")
if 'shard_batch' not in globals(): raise NameError("shard_batch function not defined.")
if 'step_keys' not in globals() or step_keys is None: raise NameError("step_keys (per-device PRNG keys) not defined.")
if 'diffusion_params' not in globals(): raise NameError("diffusion_params not defined.")
if 'unet_model' not in globals(): raise NameError("unet_model definition not defined.")
if 'clip_text_model' not in globals(): raise NameError("clip_text_model instance not defined.")


print("Starting training loop...")
start_loop_time = time.time()

# Initialize current state from replicated state
current_params = replicated_params
current_opt_state = replicated_opt_state
current_keys = step_keys

# Create the actual iterator from the dataset iterable
try:
    data_iterator = iter(numpy_iterator)
    print("Data iterator created successfully.")
except Exception as e:
    print(f"ERROR creating data iterator: {e}")
    # Stop if data iterator fails
    raise SystemExit("Stopping: Failed to create data iterator.") from e


# --- Main Training Loop ---
train_metrics = [] # To store metrics over time
for step in range(1, NUM_TRAIN_STEPS + 1):
    step_start_time = time.time()

    # 1. Get next batch
    try:
        batch = next(data_iterator)
    except StopIteration:
        print("Warning: Data iterator exhausted. Re-creating (ensure tf.data.Dataset uses .repeat()).")
        data_iterator = iter(numpy_iterator)
        batch = next(data_iterator)
    except Exception as e:
        print(f"ERROR fetching batch at step {step}: {e}")
        break

    # 2. Shard the batch
    try:
      sharded_batch = shard_batch(batch)
    except Exception as e:
      print(f"ERROR sharding batch at step {step}: {e}")
      break

    # 3. Perform one training step
    try:
        # Pass the replicated schedule as a regular argument
        current_params, current_opt_state, metrics, current_keys = train_step(
            current_params,
            replicated_clip_params,
            current_opt_state,
            replicated_schedule, # Use replicated schedule
            unet_model,          # Static model def
            clip_text_model,     # Static model instance
            sharded_batch,
            current_keys
        )
    except Exception as e:
        print(f"\nERROR during train_step execution at step {step}: {e}")
        if "RESOURCE_EXHAUSTED" in str(e):
             print("\n--- OOM Error Detected! ---")
             print("The model/batch size is likely too large for TPU memory.")
             print("Consider implementing Factorized Attention, Gradient Checkpointing, or reducing input/batch size.")
             print("Stopping training.")
        import traceback
        traceback.print_exc()
        break # Stop training on error

    step_time = time.time() - step_start_time
    train_metrics.append(metrics) # Store metrics

    # 4. Log Metrics
    if step == 1 or step % LOG_EVERY_STEPS == 0:
        try:
            # Get loss from the first device (it's already averaged)
            loss_value = jax.device_get(metrics['loss'][0])
            print(f"Step: {step}/{NUM_TRAIN_STEPS} | Loss: {loss_value:.6f} | Step Time: {step_time:.3f}s")
        except Exception as e:
            print(f"Error logging metrics at step {step}: {e}")

    # 5. Save Checkpoint
    if CHECKPOINT_DIR is not None and step % CHECKPOINT_EVERY_STEPS == 0:
        print(f"Saving checkpoint at step {step}...")
        checkpoint_start_time = time.time()
        try:
            # Retrieve state from device 0
            unet_params_host = jax.device_get(current_params)[0]
            opt_state_host = jax.device_get(current_opt_state)[0]
            ckpt_state = {
                'unet_params': unet_params_host,
                'opt_state': opt_state_host,
                'step': step
            }
            checkpoints.save_checkpoint(
                ckpt_dir=CHECKPOINT_DIR,
                target=ckpt_state,
                step=step,
                overwrite=True,
                keep=3
            )
            checkpoint_time = time.time() - checkpoint_start_time
            print(f"Checkpoint saved successfully in {checkpoint_time:.2f}s.")
        except Exception as e:
            print(f"ERROR saving checkpoint at step {step}: {e}")
            import traceback
            traceback.print_exc()


# --- End of Training Loop ---
end_loop_time = time.time()
total_training_time = end_loop_time - start_loop_time
print(f"\n--- Training Finished ---")
print(f"Total Steps: {NUM_TRAIN_STEPS}")
print(f"Total Training Time: {total_training_time / 60:.2f} minutes")

# Save final model state if checkpointing enabled
if CHECKPOINT_DIR is not None:
    print("Saving final model state...")
    try:
         final_unet_params_host = jax.device_get(current_params)[0]
         final_opt_state_host = jax.device_get(current_opt_state)[0]
         final_state = {
             'unet_params': final_unet_params_host,
             'opt_state': final_opt_state_host,
             'step': NUM_TRAIN_STEPS
         }
         checkpoints.save_checkpoint(CHECKPOINT_DIR, final_state, step=NUM_TRAIN_STEPS, prefix="final_checkpoint_", overwrite=True)
         print("Final model state saved.")
    except Exception as e:
         print(f"ERROR saving final model state: {e}")

Checkpoints will be saved to: /content/drive/MyDrive/tgif_checkpoints_local_data
Starting training loop...
Data iterator created successfully.
Data Generator: Starting epoch with 500 shuffled indices.
Data Generator: Epoch finished. Loaded 500, Failed 0.
Data Generator: Starting epoch with 500 shuffled indices.
Data Generator: Epoch finished. Loaded 500, Failed 0.
Data Generator: Starting epoch with 500 shuffled indices.

ERROR during train_step execution at step 1: Non-hashable static arguments are not supported. An error occurred while trying to hash an object of type <class '__main__.UNetConditional3D'>, UNetConditional3D(
    # attributes
    dim = 64
    dim_mults = (1, 2, 4)
    num_resnet_blocks = 2
    attn_resolutions_factor = [64, 32, 16]
    out_dim = 3
    time_emb_dim = 256
    text_emb_dim = 512
    dropout_rate = 0.1
). The error was:
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_cod

Traceback (most recent call last):
  File "<ipython-input-38-e811d35872b0>", line 85, in <cell line: 0>
    current_params, current_opt_state, metrics, current_keys = train_step(
                                                               ^^^^^^^^^^^
ValueError: Non-hashable static arguments are not supported. An error occurred while trying to hash an object of type <class '__main__.UNetConditional3D'>, UNetConditional3D(
    # attributes
    dim = 64
    dim_mults = (1, 2, 4)
    num_resnet_blocks = 2
    attn_resolutions_factor = [64, 32, 16]
    out_dim = 3
    time_emb_dim = 256
    text_emb_dim = 512
    dropout_rate = 0.1
). The error was:
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/usr/local/lib/python3.11/dist-packages/colab_kernel_launcher.py", line 37, in <module>
  File "/usr/local/lib/python3.11/dist-packages/traitlets/config/application.py", line 992, in launch_insta

ERROR saving final model state: 0
