# Load diffuser

In [None]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
from diffusers import UNet2DModel, DDPMScheduler

scheduler = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256", use_safetensors = True)
pretrained_model = UNet2DModel.from_pretrained("google/ddpm-celebahq-256")

# By disabling the clipping, the mean of the distribution used by the scheduler will match exactly the predicted_mean
scheduler.config["clip_sample"] = False
scheduler = DDPMScheduler.from_config(scheduler.config)

pretrained_model = pretrained_model.to(device)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


scheduler_config.json:   0%|          | 0.00/256 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/790 [00:00<?, ?B/s]

An error occurred while trying to fetch google/ddpm-celebahq-256: google/ddpm-celebahq-256 does not appear to have a file named diffusion_pytorch_model.safetensors.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.


diffusion_pytorch_model.bin:   0%|          | 0.00/455M [00:00<?, ?B/s]

# Helper functions

In [None]:
import PIL.Image
import numpy as np

def display_sample(sample, label = ""):
  image_processed = sample.cpu().permute(0, 2, 3, 1)
  image_processed = (image_processed + 1.0) * 127.5
  image_processed = image_processed.numpy().astype(np.uint8)
  image = PIL.Image.fromarray(image_processed[0])
  display(label)
  display(image)

# Diffusion trajectory classes

In [None]:
import torch
from torch import Tensor
from torch.distributions import Normal
from typing import List, Dict, Iterator, Tuple, Optional, Callable
from diffusers import DDPMScheduler
from dataclasses import dataclass, asdict


@dataclass
class DiffusionStep:
    """Represents a single step in the diffusion trajectory.

    Attributes:
        timestep: The timestep index in the diffusion process
        current_sample: The noisy latent at this timestep (Xt)
        prev_sample: The previous sample in the trajectory (Xt-1)
        pred_noise: The predicted noise at this timestep (from original model)
        log_prob: The log probability of this step (from original model)
    """
    timestep: int
    current_sample: Tensor
    prev_sample: Tensor
    pred_noise: Tensor
    log_prob: Tensor
    mean: Tensor
    variance: Tensor

    def compute_log_prob(self, model, scheduler) -> Tensor:
        """Compute log probability using a different model with the same trajectory data.

        Args:
            model: The updated diffusion model
            scheduler: The noise scheduler

        Returns:
            Tensor: Updated log probability for this step
        """
        # Get device
        device = next(model.parameters()).device

        # Get model prediction for the noise
        current_sample = self.current_sample.detach()
        prev_sample = self.prev_sample.detach()

        residual = model(current_sample, self.timestep).sample

        # Get the distribution parameters for p(Xt-1 | Xt)
        t = self.timestep
        alpha_t = scheduler.alphas[t].to(device)
        alpha_t_bar = scheduler.alphas_cumprod[t].to(device)
        beta_t = scheduler.betas[t].to(device)

        # Previous timestep's alpha_cumprod (ensuring it exists)
        prev_t = max(0, t-1) if isinstance(t, int) else torch.maximum(torch.zeros_like(t), t-1)
        alpha_t_prev_bar = scheduler.alphas_cumprod[prev_t].to(device)

        # Calculate predicted mean and variance
        predicted_mean = (1.0 / torch.sqrt(alpha_t)) * (
            current_sample - (beta_t / torch.sqrt(1.0 - alpha_t_bar)) * residual
        )
        variance = beta_t * (1 - alpha_t_prev_bar) / (1 - alpha_t_bar)

        # Create normal distribution and compute log probability
        dist = Normal(predicted_mean, torch.sqrt(variance))
        log_prob = dist.log_prob(prev_sample)

        # Sum over all dimensions except batch
        log_prob = log_prob.mean(dim=list(range(1, len(log_prob.shape))))

        return log_prob, predicted_mean, variance


class Trajectory:
    """Represents a full diffusion trajectory from noise to image.

    This class allows iteration over diffusion steps and supports re-computation
    of log probabilities with updated models without re-sampling the trajectory.
    """

    def __init__(
        self,
        model,
        scheduler: DDPMScheduler,
        device: torch.device,
        num_inference_steps: Optional[int] = None,
        starting_noise: Optional[Tensor] = None
    ):
        """Initialize and generate a diffusion trajectory.

        Args:
            model: The diffusion model used to generate the trajectory
            scheduler: The noise scheduler
            device: The device to use for computation
            num_inference_steps: Optional number of inference steps
            starting_noise: Optional pre-defined noise to start from
        """
        self.model = model
        self.scheduler = scheduler
        self.device = device

        # Use provided number of timesteps or default to scheduler
        self.num_inference_steps = num_inference_steps or scheduler.config.num_train_timesteps // 50

        # Generate or use provided starting noise
        if starting_noise is None:
            latent_shape = (1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
            self.starting_noise = torch.randn(latent_shape, device=device)
        else:
            self.starting_noise = starting_noise.to(device)

        # Generate the trajectory
        self.steps = self._generate_trajectory()

    def _generate_trajectory(self) -> List[DiffusionStep]:
        """Generate the complete diffusion trajectory.

        Returns:
            List[DiffusionStep]: A list of diffusion steps from noise to image
        """
        steps = []

        # Initialize the scheduler and start with pure noise
        self.scheduler.set_timesteps(self.num_inference_steps)
        latent = self.starting_noise.clone()

        # Generate trajectory by iterating through the diffusion process
        for i, t in enumerate(self.scheduler.timesteps):
            # Get model prediction
            with torch.no_grad():
                pred_noise = self.model(latent, t).sample

            # Step the scheduler to get the next latent
            scheduler_output = self.scheduler.step(pred_noise, t, latent)
            prev_sample = scheduler_output.prev_sample

            # Create step
            step = DiffusionStep(
                timestep=t,
                current_sample=latent.clone(),  # Current latent (Xt)
                prev_sample=prev_sample.clone(),  # Next latent in the trajectory (Xt-1)
                pred_noise=pred_noise,
                log_prob=None,
                mean=None,
                variance=None
            )
            steps.append(step)

            # Compute log probability
            step.log_prob, step.mean, step.variance = step.compute_log_prob(self.model, self.scheduler)
            step.log_prob = step.log_prob.detach()
            step.mean = step.mean.detach()
            step.variance = step.variance.detach()

            # Move to the next step
            latent = prev_sample

        return steps

    def __iter__(self) -> Iterator[DiffusionStep]:
        """Allow iteration over trajectory steps."""
        return iter(self.steps)

    def __len__(self) -> int:
        """Return the number of steps in the trajectory."""
        return len(self.steps)

    def __getitem__(self, idx) -> DiffusionStep:
        """Get a specific step by index."""
        return self.steps[idx]

# Reward model

In [None]:
from transformers import pipeline

def gender_reward(img_tensor):

  pipe = pipeline("image-classification", model="rizvandwiki/gender-classification")

  image_processed = img_tensor.cpu().permute(0, 2, 3, 1)
  image_processed = (image_processed + 1.0) * 127.5
  image_processed = image_processed.numpy().astype(np.uint8)
  image = PIL.Image.fromarray(image_processed[0])

  classification = pipe(image)
  for class_pred in classification:
    if class_pred["label"] == "male":
      if class_pred["score"] >= 0.5:
        return class_pred["score"]*2
      else:
        return class_pred["score"]

# Analyze trajectories

In [None]:
trajectories_val = []
total_reward = 0
for _ in range(2):
  trajectories_val.append(Trajectory(model, scheduler, device, num_inference_steps=50))
  reward = gender_reward(trajectories_val[-1][-1].prev_sample)
  total_reward += reward
  display_sample(trajectories_val[-1][-1].prev_sample)
print(f"Total reward: {total_reward}")

In [None]:
asdict(trajectories_val[0][0])

In [None]:
from IPython.display import clear_output
import time

for step in trajectories_val[0]:
  if step.timestep%100 == 0:
    display_sample(step.prev_sample, f"Timestep {step.timestep}")
    time.sleep(2)
    clear_output(wait=True)

In [None]:
import matplotlib.pyplot as plt
# Prepare data for visualization
all_timesteps = []
all_log_probs = []

# Extract timesteps and log probabilities from trajectories
for i, trajectory in enumerate(trajectories_val[:3]):
  timesteps = []
  log_probs = []
  for step in trajectory:
    timesteps.append(step.timestep.item() if torch.is_tensor(step.timestep) else step.timestep)
    log_probs.append(step.log_prob.item())

  all_timesteps.append(timesteps)
  all_log_probs.append(log_probs)

# Plot individual trajectories
for i in range(len(all_timesteps)):
  plt.plot(all_timesteps[i], all_log_probs[i], label=f'Trajectory {i+1}', alpha=0.7)

plt.title('Log Probabilities Across Timesteps (Individual Trajectories)')
plt.xlabel('Timestep')
plt.ylabel('Log Probability')
plt.legend()
plt.grid(True, alpha=0.3)

# Fine-tune with REINFORCE

In [None]:
def get_gradient_norm(model):
    total_norm = 0.0
    for param in model.parameters():
        if param.grad is not None:
            param_norm = param.grad.detach().data.norm(2)  # L2 norm
            total_norm += param_norm.item() ** 2
    return total_norm ** 0.5  # Square root to get final norm


In [None]:
def generate_dataset(model, scheduler, device, prompts, batch_size, group_size):
    """
    The final dataset will have (prompts x group_size) trajectories.
    The shape will be (n_batches, batch_size (same as n_groups_per_batch), group_size).
    The last batch can have fewer groups if the number of prompts is not divisible by batch_size.
    """
    steps_dataset = []

    # Iterate over prompts in chunks of batch_size
    for i in range(0, len(prompts), batch_size):
        batch = []
        for prompt in prompts[i:i + batch_size]:
            group = [Trajectory(model, scheduler, device, num_inference_steps=50) for _ in range(group_size)]
            batch.append(group)
        steps_dataset.append(batch)

    return steps_dataset


In [None]:
import copy
from dataclasses import dataclass, asdict
import pandas as pd

def train_with_reinforce(model, scheduler, dataset, train_epochs=5, lr=1e-6):
    updated_model = model
    updated_model.train()
    optimizer = torch.optim.AdamW(updated_model.parameters(), lr=lr)

    # Logging data
    log_data = []

    # Training loop
    for epoch in range(train_epochs):
        total_loss = 0
        total_reward = 0
        for batch_idx, batch in enumerate(dataset):
            n_trajectories = len(batch)*len(batch[0])
            for group_idx, group in enumerate(batch):
              trajectories = group
              # Get advantages of each trajectory in the group
              rewards = [gender_reward(trajectory[-1].prev_sample) for trajectory in trajectories]
              print("Rewards: ", rewards)
              total_reward += sum(rewards)
              if max(rewards) < 0.5: # No men
                print("Skipping group")
                n_trajectories -= len(group)
                continue

              group_mean = torch.tensor(rewards).mean()
              group_std = (torch.tensor(rewards).std()+1e-5)
              advantages = [(reward-group_mean)/group_std for reward in rewards]
              print("Advantages: ", advantages)

              # Process each trajectory
              for i, trajectory in enumerate(trajectories):
                  sample_prob = 0.0
                  for step_idx, step in enumerate(trajectory):
                    log_prob_new, mean, var = step.compute_log_prob(updated_model, trajectory.scheduler)

                    # Importance Sampling Ratio (exp(log(p_new/p_old)) = exp(log_p_new - log_p_old))
                    importance_ratio = torch.exp(log_prob_new - step.log_prob)

                    # Advantage
                    advantage = advantages[i]

                    # PPO clipping
                    #clipped_ratio = torch.clamp(importance_ratio, 1 - 1e-4, 1 + 1e-4)
                    clipped_ratio = importance_ratio
                    loss_clip = torch.min(importance_ratio * advantage, clipped_ratio * advantage)
                    #print(f"Importance ratio = {importance_ratio} | Clipped ratio = {clipped_ratio}")

                    # KL regularization
                    kl = torch.distributions.kl_divergence(
                        torch.distributions.Normal(step.mean, torch.sqrt(step.variance)),
                        torch.distributions.Normal(mean, torch.sqrt(var))
                    )
                    kl = kl.mean(dim=list(range(1, len(kl.shape))))

                    # Compute the loss
                    loss_base = -1*loss_clip + 0.0*kl
                    loss_total = loss_base.view(loss_base.size(0), -1).sum(dim=1).mean()
                    # Average across trajectories
                    loss_total = loss_total/n_trajectories
                    loss_total.backward()
                    sample_prob += log_prob_new.item()
                    # if get_gradient_norm(updated_model) == 0:
                    #   print("Gradient norm: ", get_gradient_norm(updated_model))

                    # Log metrics
                    log_data.append({
                        "epoch": epoch,
                        "batch": batch_idx,
                        "group": group_idx,
                        "sample": i,
                        "step": step_idx,
                        "log_prob_new": log_prob_new.item(),
                        "log_prob_old": step.log_prob.item(),
                        "importance_ratio": importance_ratio.item(),
                        "clipped_ratio": clipped_ratio.item(),
                        "kl_loss": kl.item(),
                        "reward": rewards[i],
                        "advantage": advantage.item(),
                        "loss_clip": loss_clip.item(),
                        "loss_base": loss_base.item(),
                        "loss_total": loss_total.item(),
                        "gradient": get_gradient_norm(updated_model)
                    })

                  display_sample(trajectory[-1].prev_sample, label = f"Advantage = {advantage} | Total Log prob = {sample_prob}")
                  total_loss += sample_prob

            # Update the model
            #torch.nn.utils.clip_grad_norm_(updated_model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()
        print(f"Epoch reward: {total_reward}")

    # Convert log data to DataFrame
    log_df = pd.DataFrame(log_data)

    return updated_model, log_df


In [None]:
full_epochs = 1 # Each full epoch include sampling + training # 2/2
train_epochs = 20 # Specifies how many times the model iterates over the same sampled trajectories

In [None]:
import copy
model = copy.deepcopy(pretrained_model).to(device)
dfs =[]
for i in range(full_epochs):
  # Generate dataset of trajectories (n_batches, n_groups_per_batch (batch_size), group_size)
  print(f"{20*'-*'} Full epoch {i} {20*'-*'}")
  dataset = generate_dataset(model=model,
                             scheduler=scheduler,
                             device=device,
                             prompts = 4*[None],
                             batch_size=1,
                             group_size=5)
  model, df = train_with_reinforce(model, scheduler, dataset, train_epochs, lr=1e-7)
  dfs.append(df)

In [None]:
for i in range(len(dfs)):
  dfs[i]["sampling_epoch"] = i
concated_df = pd.concat(dfs)
concated_df.to_csv("reinforce_log_1_sampling_20_epochs_no_clip.csv")

In [None]:
from google.colab import files
files.download('reinforce_log_1_sampling_20_epochs_no_clip.csv')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
len(concated_df)

2500

In [None]:
total_reward = 0
for _ in range(30):
  trajectorie = Trajectory(model, scheduler, device, num_inference_steps=50)
  reward = gender_reward(trajectorie[-1].prev_sample)
  total_reward += reward
  display_sample(trajectorie[-1].prev_sample, label=f"Reward = {reward}")
print(f"Total reward: {total_reward}")

In [None]:
total_reward = 0
for _ in range(30):
  trajectorie_ref = Trajectory(pretrained_model, scheduler, device, num_inference_steps=50)
  reward = gender_reward(trajectorie_ref[-1].prev_sample)
  total_reward += reward
  display_sample(trajectorie_ref[-1].prev_sample)
print(f"Total reward: {total_reward}")

In [None]:
concated_df.groupby("sampling_epoch")["epoch"].value_counts()

Unnamed: 0_level_0,Unnamed: 1_level_0,count
sampling_epoch,epoch,Unnamed: 2_level_1
0,0,500
0,1,500
0,2,500
0,3,500
1,0,750
1,1,750
1,2,750
1,3,750


In [None]:
concated_df.groupby("sampling_epoch")

Unnamed: 0,epoch,batch,group,sample,step,log_prob_new,log_prob_old,importance_ratio,clipped_ratio,kl_loss,reward,advantage,loss_clip,loss_base,loss_total,gradient,sampling_epoch
0,0,1,0,0,0,-7.916809,-7.916809,1.000000,1.000000,0.000000e+00,1.943117,tensor(1.0817),1.081703,-1.081703,-0.216341,0.061077,0
1,0,1,0,0,1,-8.003198,-8.003198,1.000000,1.000000,0.000000e+00,1.943117,tensor(1.0817),1.081703,-1.081703,-0.216341,0.120144,0
2,0,1,0,0,2,-7.968703,-7.968703,1.000000,1.000000,0.000000e+00,1.943117,tensor(1.0817),1.081703,-1.081703,-0.216341,0.177481,0
3,0,1,0,0,3,-7.958009,-7.958009,1.000000,1.000000,0.000000e+00,1.943117,tensor(1.0817),1.081703,-1.081703,-0.216341,0.234267,0
4,0,1,0,0,4,-7.973542,-7.973542,1.000000,1.000000,0.000000e+00,1.943117,tensor(1.0817),1.081703,-1.081703,-0.216341,0.290235,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2995,3,2,0,4,45,-6.276108,-6.275907,0.999799,0.999900,4.913418e-07,0.015227,tensor(-0.4466),-0.446536,0.446536,0.089307,0.437074,1
2996,3,2,0,4,46,-5.748247,-5.747997,0.999750,0.999900,6.206404e-07,0.015227,tensor(-0.4466),-0.446536,0.446536,0.089307,0.437074,1
2997,3,2,0,4,47,-4.793023,-4.792648,0.999625,0.999900,8.707821e-07,0.015227,tensor(-0.4466),-0.446536,0.446536,0.089307,0.437074,1
2998,3,2,0,4,48,-2.636160,-2.636009,0.999850,0.999900,1.432957e-06,0.015227,tensor(-0.4466),-0.446536,0.446536,0.089307,0.437074,1


Investigate impact of each interval of timesteps
If importance ragio is clipped, it is infficient

In [None]:
concated_df["batch"].value_counts()

Unnamed: 0_level_0,count
batch,Unnamed: 1_level_1
1,2000
2,2000
0,1000


In [None]:
import pandas as pd
import numpy as np  # Needed for np.where()

# Assign 'reward_sign' based on whether 'advantage' is positive or not
concated_df['reward_sign'] = np.where(concated_df["advantage"] > 0, "positive", "negative")

# Group by epoch and reward sign, then calculate the mean of 'loss_total'
loss_by_epoch_and_reward = concated_df.groupby(['sampling_epoch', 'epoch', 'reward_sign'])['loss_total'].mean().reset_index()

# Display the results
print(loss_by_epoch_and_reward)


    sampling_epoch  epoch reward_sign  loss_total
0                0      0    negative    0.175073
1                0      0    positive   -0.175073
2                0      1    negative    0.175064
3                0      1    positive   -0.175067
4                0      2    negative    0.175072
5                0      2    positive   -0.175082
6                0      3    negative    0.175071
7                0      3    positive   -0.175082
8                1      0    negative    0.114325
9                1      0    positive   -0.228650
10               1      1    negative    0.114325
11               1      1    positive   -0.228660
12               1      2    negative    0.114322
13               1      2    positive   -0.228659
14               1      3    negative    0.114318
15               1      3    positive   -0.228653


In [None]:
concated_df.groupby(['sampling_epoch', 'epoch', 'batch', 'sample']).sum()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,group,step,log_prob_new,log_prob_old,importance_ratio,clipped_ratio,kl_loss,reward,advantage,loss_clip,loss_base,loss_total,gradient,reward_sign
sampling_epoch,epoch,batch,sample,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
0,0,1,0,0,1225,-365.012650,-365.012650,50.000000,50.000000,0.000000,97.155839,tensor(54.0852),54.085141,-54.085141,-10.817029,67.890319,positivepositivepositivepositivepositivepositi...
0,0,1,1,0,1225,-364.538139,-364.538139,50.000000,50.000000,0.000000,0.734839,tensor(-36.4757),-36.475667,36.475667,7.295134,149.370738,negativenegativenegativenegativenegativenegati...
0,0,1,2,0,1225,-364.959413,-364.959413,50.000000,50.000000,0.000000,0.687021,tensor(-36.5206),-36.520579,36.520579,7.304116,70.983645,negativenegativenegativenegativenegativenegati...
0,0,1,3,0,1225,-364.654691,-364.654691,50.000000,50.000000,0.000000,0.664115,tensor(-36.5421),-36.542094,36.542094,7.308419,122.866284,negativenegativenegativenegativenegativenegati...
0,0,1,4,0,1225,-365.519147,-365.519147,50.000000,50.000000,0.000000,98.612428,tensor(55.4532),55.453205,-55.453205,-11.090641,132.946472,positivepositivepositivepositivepositivepositi...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1,3,2,0,0,1225,-365.157062,-365.148899,49.991840,49.997261,0.000085,0.748135,tensor(-22.3441),-22.342906,22.342906,4.468581,24.110515,negativenegativenegativenegativenegativenegati...
1,3,2,1,0,1225,-365.374617,-365.364704,49.990092,49.996812,0.000096,0.703902,tensor(-22.3948),-22.393336,22.393336,4.478667,58.501948,negativenegativenegativenegativenegativenegati...
1,3,2,2,0,1225,-365.215573,-365.206342,49.990774,49.996733,0.000076,0.722233,tensor(-22.3738),-22.372318,22.372318,4.474464,89.906045,negativenegativenegativenegativenegativenegati...
1,3,2,3,0,1225,-364.808795,-364.816023,50.007233,50.001114,0.000072,98.406029,tensor(89.4417),89.443680,-89.443680,-17.888736,46.663392,positivepositivepositivepositivepositivepositi...


In [None]:
pipe = pipeline("image-classification", model="rizvandwiki/gender-classification")

Device set to use cuda:0


In [None]:
from transformers import pipeline

def gender_reward_test(img_tensor):

  pipe = pipeline("image-classification", model="rizvandwiki/gender-classification")

  image_processed = img_tensor.cpu().permute(0, 2, 3, 1)
  image_processed = (image_processed + 1.0) * 127.5
  image_processed = image_processed.numpy().astype(np.uint8)
  image = PIL.Image.fromarray(image_processed[0])

  classification = pipe(image)
  print(classification)
  for class_pred in classification:
    if class_pred["label"] == "male":
      if class_pred["score"] >= 0.7:
        return class_pred["score"]*2
      else:
        return class_pred["score"]

In [None]:
for batch_idx, batch in enumerate(dataset):
  for group_idx, group in enumerate(batch):
    for i, trajectory in enumerate(group):
      print(gender_reward_test(trajectory[-1].prev_sample))
      display_sample(trajectory[-1].prev_sample)


In [None]:
sampling_1 = dfs[0]
sampling_2 = dfs[1]

In [None]:
sampling_1.groupby("epoch").mean()

Unnamed: 0_level_0,batch,group,sample,step,importance_ratio,kl_loss,reward,advantage,loss_base,loss_total
epoch,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
0,1.5,0.0,2.0,24.5,0.999948,9.848513e-08,0.409476,0.0,1e-05,2e-06
1,1.5,0.0,2.0,24.5,1.000018,2.35602e-07,0.409476,0.0,-1.5e-05,-3e-06


In [None]:
sampling_2.groupby("epoch").mean()

Unnamed: 0_level_0,batch,group,sample,step,importance_ratio,kl_loss,reward,advantage,loss_base,loss_total
epoch,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
0,1.5,0.0,2.0,24.5,0.999995,6.103175e-08,0.985205,-0.0,5e-06,9.533912e-07
1,1.5,0.0,2.0,24.5,1.000074,3.553911e-07,0.985205,-0.0,-1e-05,-1.974672e-06


To do:
- Check loss of reference trajectories over the iterations. I expect the loss of male samples to decrease and the loss of female samples to increase. Since the trajectories will be fixed and neither importance sampling nor regularization is being used, maybe the loss can increase at some point as the updated model diverges from the original
- Batches
- Advantages ok
- PPO ok
- KL
- Gradient norm
- What about sampling the timesteps as Pinterest paper?