In [2]:
# Toy example showing the CORRECT DDPO batching logic.
#
# Each collected sample is tagged with its own ID (stored in "latents") so it can
# be followed through the per-inner-epoch shuffle and the rebatch into minibatches.

import torch

# Configuration
total_batch_size = 8  # total samples collected (e.g., 4 batch_size × 2 num_batches_per_epoch)
train_batch_size = 2  # minibatch size
num_inner_epochs = 1
num_timesteps = 3
seed = 0  # fixed so the printed shuffle is reproducible under Restart & Run All

# The rebatch step reshapes to (num_minibatches, train_batch_size, ...), which only
# works when the collected samples split evenly into minibatches.
assert total_batch_size % train_batch_size == 0, (
    f"total_batch_size ({total_batch_size}) must be divisible by "
    f"train_batch_size ({train_batch_size})"
)
num_minibatches = total_batch_size // train_batch_size

torch.manual_seed(seed)

print("=" * 80)
print("ACTUAL DDPO BATCHING LOGIC")
print("=" * 80)

# Simulate collected samples: every timestep column of row i holds sample ID i.
samples = {
    "latents": torch.arange(total_batch_size).reshape(-1, 1).repeat(1, num_timesteps),
    "advantages": torch.arange(total_batch_size, dtype=torch.float),
}

print(f"\nTotal samples collected: {total_batch_size}")
print(f"Sample IDs: {samples['latents'][:, 0].tolist()}")

# Number of simulated forward passes each sample ID takes part in.
usage_counts = torch.zeros(total_batch_size, dtype=torch.long)

for inner_epoch in range(num_inner_epochs):
    print(f"\n{'=' * 80}")
    print(f"Inner Epoch {inner_epoch}")
    print(f"{'=' * 80}")

    # Shuffle samples along the batch dimension (a fresh permutation every inner epoch)
    perm = torch.randperm(total_batch_size)
    samples_shuffled = {k: v[perm] for k, v in samples.items()}

    print(f"\nAfter shuffling: {samples_shuffled['latents'][:, 0].tolist()}")

    # Rebatch: (total_batch_size, ...) -> (num_minibatches, train_batch_size, ...)
    samples_batched = {
        k: v.reshape(num_minibatches, train_batch_size, *v.shape[1:])
        for k, v in samples_shuffled.items()
    }

    print(f"\nNumber of minibatches: {num_minibatches}")

    # Process each minibatch
    for i in range(num_minibatches):
        minibatch_ids = samples_batched["latents"][i, :, 0]
        minibatch_sample_ids = minibatch_ids.tolist()
        print(f"\nMinibatch {i}: samples {minibatch_sample_ids}")

        # Process each timestep in this minibatch
        for j in range(num_timesteps):
            print(f"  Timestep {j}: processing samples {minibatch_sample_ids}")
            # IDs within one minibatch are distinct (they come from a permutation),
            # so this indexed increment counts each of them exactly once.
            usage_counts[minibatch_ids] += 1

uses_per_inner_epoch = num_timesteps
uses_total = num_timesteps * num_inner_epochs
# Verify the claim printed below instead of just asserting it in text.
assert torch.all(usage_counts == uses_total), usage_counts.tolist()

print("\n" + "=" * 80)
print("KEY INSIGHT:")
print("=" * 80)
print(f"With num_inner_epochs = {num_inner_epochs}:")
print("  - Within each inner epoch, each sample appears in EXACTLY ONE minibatch")
for i in range(num_minibatches):
    first = i * train_batch_size
    last = first + train_batch_size - 1
    print(f"  - Minibatch {i} has samples: [shuffled indices {first}-{last}]")
print(f"\nEach sample is used {uses_per_inner_epoch} times per inner epoch (once per timestep)")
print(f"  -> {uses_total} times in total over {num_inner_epochs} inner epoch(s)")
print("But each sample is in a DIFFERENT minibatch (no reuse across minibatches)")

hypothetical_inner_epochs = 3
print("\n" + "=" * 80)
print(f"IF num_inner_epochs = {hypothetical_inner_epochs}:")
print("=" * 80)
print(f"  - Samples would be re-shuffled and re-batched {hypothetical_inner_epochs} times")
print(f"  - Each sample would be used {num_timesteps * hypothetical_inner_epochs} times total")
print("  - THIS is where sample reuse happens!")
print("=" * 80)

ACTUAL DDPO BATCHING LOGIC

Total samples collected: 8
Sample IDs: [0, 1, 2, 3, 4, 5, 6, 7]

Inner Epoch 0

After shuffling: [5, 3, 6, 2, 1, 0, 7, 4]

Number of minibatches: 4

Minibatch 0: samples [5, 3]
  Timestep 0: processing samples [5, 3]
  Timestep 1: processing samples [5, 3]
  Timestep 2: processing samples [5, 3]

Minibatch 1: samples [6, 2]
  Timestep 0: processing samples [6, 2]
  Timestep 1: processing samples [6, 2]
  Timestep 2: processing samples [6, 2]

Minibatch 2: samples [1, 0]
  Timestep 0: processing samples [1, 0]
  Timestep 1: processing samples [1, 0]
  Timestep 2: processing samples [1, 0]

Minibatch 3: samples [7, 4]
  Timestep 0: processing samples [7, 4]
  Timestep 1: processing samples [7, 4]
  Timestep 2: processing samples [7, 4]

KEY INSIGHT:
With num_inner_epochs = 1:
  - Each sample appears in EXACTLY ONE minibatch
  - Minibatch 0 has samples: [shuffled indices 0-1]
  - Minibatch 1 has samples: [shuffled indices 2-3]
  - Minibatch 2 has samples: [shuf

In [6]:
import torch

print("=" * 80)
print("TOY EXAMPLE: Time Dimension Shuffling and Rebatching in DDPO")
print("=" * 80)

# Configuration
total_batch_size = 4  # total samples collected
num_timesteps = 3  # number of diffusion timesteps
train_batch_size = 2  # minibatch size for training
seed = 0  # fixed so the printed permutations are reproducible under Restart & Run All

# The rebatch step reshapes to (-1, train_batch_size, ...), which needs an even split.
assert total_batch_size % train_batch_size == 0, (
    f"total_batch_size ({total_batch_size}) must be divisible by "
    f"train_batch_size ({train_batch_size})"
)

torch.manual_seed(seed)

print("\nSetup:")
print(f"  Total samples: {total_batch_size}")
print(f"  Number of timesteps: {num_timesteps}")
print(f"  Train batch size: {train_batch_size}")

# Create sample data with easy-to-track values
# Shape: (total_batch_size, num_timesteps)
# The last timestep differs per sample (100 * (sample_id + 1)) only so that rows
# remain distinguishable after shuffling.
timesteps = torch.tensor([
    [1000, 800, 100],  # Sample 0
    [1000, 800, 200],  # Sample 1
    [1000, 800, 300],  # Sample 2
    [1000, 800, 400],  # Sample 3
])

print(f"shape of timesteps: {timesteps.shape}")
# Latents encode their origin: value = sample_id * 100 + 10 * (timestep_position + 1)
latents = torch.tensor([
    [10, 20, 30],  # Sample 0: latents at t=1000, t=800, t=100
    [110, 120, 130],  # Sample 1: latents at t=1000, t=800, t=200
    [210, 220, 230],  # Sample 2: latents at t=1000, t=800, t=300
    [310, 320, 330],  # Sample 3: latents at t=1000, t=800, t=400
])

log_probs = torch.tensor([
    [0.1, 0.2, 0.3],  # Sample 0
    [1.1, 1.2, 1.3],  # Sample 1
    [2.1, 2.2, 2.3],  # Sample 2
    [3.1, 3.2, 3.3],  # Sample 3
])

advantages = torch.tensor([0.5, 1.5, -0.5, 2.5])  # One per sample
prompt_embeds = torch.randn(total_batch_size, 77, 768)  # Dummy embeddings

# The hand-written tensors above must agree with the configuration.
for name, tensor in [("timesteps", timesteps), ("latents", latents), ("log_probs", log_probs)]:
    assert tensor.shape == (total_batch_size, num_timesteps), f"{name}: {tuple(tensor.shape)}"
assert advantages.shape == (total_batch_size,)

samples = {
    "timesteps": timesteps,
    "latents": latents,
    "log_probs": log_probs,
    "advantages": advantages,
    "prompt_embeds": prompt_embeds,
}

print("\n" + "=" * 80)
print("ORIGINAL DATA")
print("=" * 80)
print("\nTimesteps (shape: [batch, time]):")
print(samples["timesteps"])
print("\nLatents (shape: [batch, time]):")
print(samples["latents"])
print("\nLog probs (shape: [batch, time]):")
print(samples["log_probs"])
print("\nAdvantages (shape: [batch]):")
print(samples["advantages"])

print("\n" + "=" * 80)
print("STEP 1: SHUFFLE ALONG BATCH DIMENSION")
print("=" * 80)

# Shuffle samples along batch dimension; row i now holds original sample perm[i]
perm = torch.randperm(total_batch_size)
print(f"\nBatch permutation: {perm.tolist()}")

samples = {k: v[perm] for k, v in samples.items()}

print("\nTimesteps after batch shuffle:")
print(samples["timesteps"])
print("\nLatents after batch shuffle:")
print(samples["latents"])

print("\n" + "=" * 80)
print("STEP 2: SHUFFLE ALONG TIME DIMENSION (INDEPENDENTLY FOR EACH SAMPLE)")
print("=" * 80)

# Generate independent permutations for each sample's timesteps
perms = torch.stack(
    [torch.randperm(num_timesteps) for _ in range(total_batch_size)]
)
print(f"shape of perms {perms.shape}")
print("\nTime permutations for each row (after the batch shuffle):")
for i in range(total_batch_size):
    print(f"  Row {i} (original sample {perm[i].item()}): {perms[i].tolist()}")

# Keep references to the pre-time-shuffle tensors for the sanity check below.
before_time_shuffle = dict(samples)

# Apply time shuffling to time-dependent keys: new[i, j] = old[i, perms[i, j]]
row_index = torch.arange(total_batch_size)[:, None]
for key in ["timesteps", "latents", "log_probs"]:
    samples[key] = samples[key][row_index, perms]

# Sanity check: the indexing above equals a per-row gather, so every
# time-dependent key was permuted with the same per-row permutation.
for key in ["timesteps", "latents", "log_probs"]:
    assert torch.equal(samples[key], before_time_shuffle[key].gather(1, perms)), key

print(f"shape of samples after time shuffling: {samples['timesteps'].shape}")
print("\n--- After Time Shuffling ---")
print("\nTimesteps (each sample's timesteps are shuffled independently):")
print(samples["timesteps"])
print("\nLatents (shuffled to match timesteps):")
print(samples["latents"])
print("\nLog probs (shuffled to match timesteps):")
print(samples["log_probs"])

print("\n** KEY INSIGHT: Each sample's timesteps are in different order **")
print("** This means we'll train on timesteps in random order **")

print("\n" + "=" * 80)
print("STEP 3: REBATCH FOR TRAINING")
print("=" * 80)

# Reshape to create minibatches: (total_batch_size, ...) -> (-1, train_batch_size, ...)
samples_batched = {
    k: v.reshape(-1, train_batch_size, *v.shape[1:])
    for k, v in samples.items()
}
print(f"shape of samples_batched['timesteps']: {samples_batched['timesteps'].shape}")

num_minibatches = samples_batched["timesteps"].shape[0]
print(f"\nNumber of minibatches: {num_minibatches}")
print(f"Each minibatch has {train_batch_size} samples")

print("\n--- Minibatch Structure ---")
for i in range(num_minibatches):
    print(f"\nMinibatch {i}:")
    print(f"  Timesteps shape: {samples_batched['timesteps'][i].shape}")
    print(f"  Timesteps:\n{samples_batched['timesteps'][i]}")
    print(f"  Latents:\n{samples_batched['latents'][i]}")
    print(f"  Advantages: {samples_batched['advantages'][i]}")

print("\n" + "=" * 80)
print("STEP 4: CONVERT TO LIST OF DICTS (for easier iteration)")
print("=" * 80)

# Convert dict of tensors -> list of dicts (one dict per minibatch)
samples_batched_list = [
    dict(zip(samples_batched, x)) for x in zip(*samples_batched.values())
]
# Print shapes only: dumping the dicts would print the (77, 768) prompt embeddings.
print("Keys and shapes per minibatch dict:")
for i, minibatch in enumerate(samples_batched_list):
    print(f"  Minibatch {i}: {{{', '.join(f'{k}: {tuple(v.shape)}' for k, v in minibatch.items())}}}")
print(f"\nNumber of minibatch dicts: {len(samples_batched_list)}")

print("\n" + "=" * 80)
print("STEP 5: TRAINING LOOP SIMULATION")
print("=" * 80)

for i, minibatch in enumerate(samples_batched_list):
    print(f"\n--- Processing Minibatch {i} ---")
    print(f"Batch size: {minibatch['timesteps'].shape[0]}")
    print(f"Timesteps in this minibatch:\n{minibatch['timesteps']}")

    # Loop over timestep positions (the actual timestep at a position differs per sample)
    for j in range(num_timesteps):
        print(f"\n  Timestep position {j}:")
        print(f"    Timestep values: {minibatch['timesteps'][:, j].tolist()}")
        print(f"    Latents: {minibatch['latents'][:, j].tolist()}")
        print(f"    Log probs: {minibatch['log_probs'][:, j].tolist()}")
        print(f"    Advantages: {minibatch['advantages'].tolist()}")
        print("    -> Forward pass with UNet using these values")
        print("    -> Compute loss and accumulate gradients")

print("\n" + "=" * 80)
print("SUMMARY: WHY THIS DESIGN?")
print("=" * 80)
print("""
1. BATCH SHUFFLE: Groups samples differently in each epoch
   - Reduces correlation between samples in a minibatch

2. TIME SHUFFLE: Each sample's timesteps are processed in random order
   - Sample 0 might process: t=100, then t=1000, then t=800
   - Sample 1 might process: t=800, then t=200, then t=1000
   - This breaks temporal correlation in training

3. REBATCHING: Split samples into minibatches of size train_batch_size
   - Allows gradient accumulation across multiple minibatches
   - Each minibatch processes all its timesteps before moving to next minibatch

4. TRAINING: For each minibatch, loop through all timestep positions
   - Process position 0 for all samples in minibatch (actual timestep differs per sample)
   - Process position 1 for all samples in minibatch
   - Process position 2 for all samples in minibatch
   - Accumulate gradients across timesteps within minibatch
""")
print("=" * 80)

TOY EXAMPLE: Time Dimension Shuffling and Rebatching in DDPO

Setup:
  Total samples: 4
  Number of timesteps: 3
  Train batch size: 2
shape of timesteps: torch.Size([4, 3])

ORIGINAL DATA

Timesteps (shape: [batch, time]):
tensor([[1000,  800,  100],
        [1000,  800,  200],
        [1000,  800,  300],
        [1000,  800,  400]])

Latents (shape: [batch, time]):
tensor([[ 10,  20,  30],
        [110, 120, 130],
        [210, 220, 230],
        [310, 320, 330]])

Log probs (shape: [batch, time]):
tensor([[0.1000, 0.2000, 0.3000],
        [1.1000, 1.2000, 1.3000],
        [2.1000, 2.2000, 2.3000],
        [3.1000, 3.2000, 3.3000]])

Advantages (shape: [batch]):
tensor([ 0.5000,  1.5000, -0.5000,  2.5000])

STEP 1: SHUFFLE ALONG BATCH DIMENSION

Batch permutation: [1, 2, 3, 0]

Timesteps after batch shuffle:
tensor([[1000,  800,  200],
        [1000,  800,  300],
        [1000,  800,  400],
        [1000,  800,  100]])

Latents after batch shuffle:
tensor([[110, 120, 130],
        [2