In [1]:
# Quote the requirement: unquoted, the shell treats `>=0.29.0` as an output redirect
# (the version floor is lost and pip's output goes to a file named `=0.29.0`).
%pip install "diffusers[torch]>=0.29.0"

In [2]:
# Quoted for the same reason as above: `>` would otherwise be a shell redirect.
%pip install "peft>=0.6.0"

In [3]:
# Clone the ImageReward repository (it contains data for testing). Skip the clone
# when the directory already exists so the cell can be re-run safely.
![ -d ImageReward ] || git clone https://github.com/THUDM/ImageReward.git
# Note: `!cd ImageReward` would run in a throwaway subshell and not change the
# notebook's working directory; use `%cd ImageReward` if later cells must run there.

# Install the integrated package `image-reward` and OpenAI CLIP into the kernel's environment
%pip install image-reward
%pip install git+https://github.com/openai/CLIP.git

Cloning into 'ImageReward'...
remote: Enumerating objects: 224, done.[K
remote: Counting objects: 100% (91/91), done.[K
remote: Compressing objects: 100% (49/49), done.[K
remote: Total 224 (delta 52), reused 49 (delta 41), pack-reused 133 (from 1)[K
Receiving objects: 100% (224/224), 4.30 MiB | 34.64 MiB/s, done.
Resolving deltas: 100% (89/89), done.
Collecting image-reward
  Downloading image_reward-1.5-py3-none-any.whl.metadata (12 kB)
Collecting timm==0.6.13 (from image-reward)
  Downloading timm-0.6.13-py3-none-any.whl.metadata (38 kB)
Collecting fairscale==0.4.13 (from image-reward)
  Downloading fairscale-0.4.13.tar.gz (266 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m266.3/266.3 kB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Installing backend dependencies ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Downloa

## Importing modules

In [4]:
import logging
import torch
import contextlib
import random
from functools import partial
import math
import ImageReward as RM
import os
import wandb
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Union, Callable
import tempfile

from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
import csv
import collections
import numpy as np
import functools
import time
import tqdm
tqdm = partial(tqdm.tqdm, dynamic_ncols=True)

from PIL import Image
import torch.nn as nn
from torchvision.transforms import Compose, Resize, CenterCrop, Normalize
try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC


import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import cast_training_params
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.torch_utils import is_compiled_module
from diffusers.models.embeddings import TimestepEmbedding, Timesteps, GaussianFourierProjection
from diffusers.models.unets.unet_2d_blocks import get_down_block, DownBlock2D, CrossAttnDownBlock2D

from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
    StableDiffusionPipeline,
    rescale_noise_cfg,
)
try:
    from diffusers.utils import randn_tensor
except ImportError:
    from diffusers.utils.torch_utils import randn_tensor

from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput, DDIMScheduler

In [5]:
@dataclass()
class Config:
    """Training configuration: checkpointing, model, sampling and optimisation settings.

    Every attribute is a real dataclass field, so ``repr(cfg)`` lists the full
    configuration, and individual values can be overridden at construction time
    (e.g. ``Config(seed=1)``). The nested dicts are built per instance by
    ``default_factory``, so mutating ``cfg.train[...]`` on one instance does not
    leak into others.
    """

    save_freq: int = 10
    num_checkpoint_limit: int = 3
    # "fp16" / "bf16" select the weight dtype in main(); any other value means float32.
    mixed_precision: str = "bf16"
    allow_tf32: bool = True
    use_lora: bool = True
    # Hugging Face model id and revision passed to StableDiffusionPipeline.from_pretrained.
    pretrained: Dict[str, Any] = field(default_factory=lambda: {
        "model": "runwayml/stable-diffusion-v1-5",
        "revision": "main",
    })
    sample: Dict[str, Any] = field(default_factory=lambda: {
        "num_steps": 50,
        "eta": 1.0,
        "guidance_scale": 5.0,
        "batch_size": 2,
        "num_batches_per_epoch": 2,
    })
    train: Dict[str, Any] = field(default_factory=lambda: {
        "use_8bit_adam": False,
        "learning_rate": 3.0e-4,
        "adam_beta1": .9,
        "adam_beta2": .999,
        "adam_weight_decay": 1.0e-4,
        "adam_epsilon": 1.0e-8,
        "max_grad_norm": 1.0,
        "num_inner_epochs": 1,
        "cfg": True,
        "adv_clip_max": 5,
        "clip_range": 1.0e-4,
        "timestep_fraction": 1.0,
        # Used in main() as both the LoRA rank `r` and `lora_alpha`.
        "lora_rank": 4,
        "batch_size": 2,
        "gradient_accumulation_steps": 2,
        "reward_exp": 1.0e+2,
        "flow_learning_rate": 3.0e-4,
        "anneal": "linear",
        "unetreg": 1.0e+0,
        "klpf": -1,
    })
    seed: int = 0
    num_epochs: int = 100
    wandb: bool = False
    prompt_fn: str = "drawbench"
    reward_fn: str = "imagereward"
    prompt_fn_kwargs: Dict[str, Any] = field(default_factory=dict)

In [6]:
def set_seed(seed):
    """Seed every RNG used in this notebook and request deterministic cuDNN.

    Seeds Python's ``random``, exports ``PYTHONHASHSEED`` (this only affects child
    processes, not the running interpreter), seeds NumPy's global RNG, and seeds
    torch on the CPU and, when available, all CUDA devices. It then turns off
    cuDNN autotuning, frees cached CUDA memory and logs the seed.
    """
    import random
    import numpy as np

    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)

    for torch_seeder in (torch.manual_seed, torch.random.manual_seed):
        torch_seeder(seed)
    if torch.cuda.is_available():
        for cuda_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seeder(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False
    torch.cuda.empty_cache()

    logging.info(f'Using seed: {seed}')

def image_postprocess(x):
    """Rescale a tensor from [-1, 1] to [0, 1], clipping values that fall outside."""
    shifted = (x + 1) / 2
    return shifted.clamp(0, 1)

In [7]:
# Directory holding "DrawBench Prompts.csv". Defaults to the Kaggle dataset mount;
# set ALIGNMENT_GFLOWNET_ASSETS to point elsewhere when running outside Kaggle.
BASE_PATH = os.environ.get(
    "ALIGNMENT_GFLOWNET_ASSETS", "/kaggle/input/alignment-gflownet-assets/assets"
)

@functools.lru_cache()
def read_csv(path):
    """Load DrawBench prompts from a CSV with ``Category`` and ``Prompts`` columns.

    Prompts are grouped by category in first-seen order. The "Misspellings" and
    "Rare Words" categories are dropped, the first two prompts of every remaining
    category are skipped, and the survivors are returned as one flat list.

    The result is memoised per ``path`` by ``lru_cache``: every call returns the
    same list object, so callers must not mutate it.
    """
    excluded_categories = {"Misspellings", "Rare Words"}

    # newline="" is required by the csv module so quoted fields containing line
    # breaks parse correctly; "utf-8-sig" also strips a leading BOM, which would
    # otherwise corrupt the first header name and break row["Category"].
    with open(path, "r", newline="", encoding="utf-8-sig") as f:
        rows = list(csv.DictReader(f))

    prompts_by_category = collections.defaultdict(list)
    for row in rows:
        prompts_by_category[row["Category"]].append(row["Prompts"])

    # A flat comprehension avoids the quadratic cost of sum(list_of_lists, []).
    return [
        prompt
        for category, prompts in prompts_by_category.items()
        if category not in excluded_categories
        for prompt in prompts[2:]
    ]

def drawbench():
    """Pick one DrawBench prompt uniformly at random.

    Returns:
        A ``(prompt, metadata)`` pair; ``metadata`` is always an empty dict.
    """
    csv_path = os.path.join(BASE_PATH, "DrawBench Prompts.csv")
    candidates = read_csv(csv_path)
    chosen_prompt = random.choice(candidates)
    return chosen_prompt, {}

In [8]:
def imagereward(dtype=torch.float32, device="cuda"):
    """Build a reward function backed by the ``ImageReward-v1.0`` model.

    The model is loaded, cast to ``dtype`` and moved to ``device``. The returned
    closure ``_fn(images, prompts, metadata)`` tokenizes ``prompts`` with the
    model's BLIP tokenizer, preprocesses ``images`` (resize/crop to 224 and
    normalise), and returns ``(rewards, {})``. ``rewards`` has one float32 entry
    per image.

    ``images`` is presumably a float tensor batch already scaled to [0, 1]; only
    resize, crop and normalisation happen here (TODO confirm against callers).
    ``metadata`` is accepted for interface compatibility and ignored.
    """
    scorer = RM.load("ImageReward-v1.0")
    scorer.to(dtype).to(device)
    tokenizer = scorer.blip.tokenizer

    # Per-channel mean/std of the OpenAI CLIP image preprocessing.
    clip_mean = (0.48145466, 0.4578275, 0.40821073)
    clip_std = (0.26862954, 0.26130258, 0.27577711)
    rm_preprocess = Compose([
        Resize(224, interpolation=BICUBIC),
        CenterCrop(224),
        Normalize(clip_mean, clip_std),
    ])

    def _fn(images, prompts, metadata):
        target_device = images.device
        tokens = tokenizer(
            prompts,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
            max_length=tokenizer.model_max_length,
        )
        token_ids = tokens.input_ids.to(target_device)
        token_mask = tokens.attention_mask.to(target_device)
        # `score_gard` is the ImageReward API's own spelling; do not "fix" it.
        scores = scorer.score_gard(token_ids, token_mask, rm_preprocess(images))
        return scores.reshape(images.shape[0]).float(), {}

    return _fn

def unwrap_model(model):
    """Return the original module behind a ``torch.compile`` wrapper, else ``model`` itself."""
    if is_compiled_module(model):
        return model._orig_mod
    return model

In [9]:
class ConditionalFlow(torch.nn.Module):
    """Timestep- and text-conditioned scalar head built from UNet down blocks.

    Mirrors the encoder half of ``UNet2DConditionModel``: a sinusoidal timestep
    projection plus MLP embedding, an input convolution and a stack of
    (cross-attention) down blocks. These are followed by average pooling and a
    linear layer that produces one scalar per sample. Presumably this is the
    learned (log-)flow estimator used during training (TODO confirm against the
    training loop).

    Every down block downsamples (``add_downsample=True``), so the four default
    blocks shrink the spatial size by 16x. ``AvgPool2d(4)`` must then see a 4x4
    map for ``fc`` to receive ``block_out_channels[-1]`` features, i.e. the
    defaults expect 64x64 latents.

    Args (subset of ``UNet2DConditionModel``'s):
        in_channels: channels of the input latents.
        flip_sin_to_cos, freq_shift: forwarded to ``Timesteps``.
        down_block_types: diffusers down-block class names, one per stage.
        block_out_channels: output channels per stage; ``[0]`` also sizes the time embedding.
        layers_per_block: resnet layers per stage (int or per-stage sequence).
        downsample_padding, act_fn, norm_num_groups, norm_eps: forwarded to the blocks.
        cross_attention_dim: text-embedding width for cross-attention (int or per stage).
        attention_head_dim: passed as both head count and head dim (int or per stage).
        timestep_post_act, time_cond_proj_dim: forwarded to ``TimestepEmbedding``.
        conv_in_kernel: kernel size of the input convolution ("same" padding).
        class_embeddings_concat: doubles the time-embedding width seen by the blocks.
        only_cross_attention, encoder_hid_dim, encoder_hid_dim_type: accepted for
            signature compatibility but unused.
    """

    def __init__(self,
        in_channels: int = 4,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: Union[int, Tuple[int]] = 2,
        downsample_padding: int = 1,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: Union[int, Tuple[int]] = 1280,
        encoder_hid_dim: Optional[int] = None,
        encoder_hid_dim_type: Optional[str] = None,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        timestep_post_act: Optional[str] = None,
        time_cond_proj_dim: Optional[int] = None,
        conv_in_kernel: int = 3,
        class_embeddings_concat: bool = False):

        super().__init__()
        timestep_input_dim = block_out_channels[0]
        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos = flip_sin_to_cos, downscale_freq_shift = freq_shift)
        time_embed_dim = block_out_channels[0] * 4
        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn,
                                               post_act_fn=timestep_post_act, cond_proj_dim=time_cond_proj_dim)
        # "same" padding for odd kernel sizes.
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size = conv_in_kernel, padding = conv_in_padding
        )
        self.encoder_hid_proj = None
        self.down_blocks = nn.ModuleList([])

        # Expand scalar per-stage settings to one entry per down block.
        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        if isinstance(cross_attention_dim, int):
            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

        if isinstance(layers_per_block, int):
            layers_per_block = [layers_per_block] * len(down_block_types)

        if class_embeddings_concat:
            blocks_time_embed_dim = time_embed_dim * 2
        else:
            blocks_time_embed_dim = time_embed_dim

        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]

            down_block = get_down_block(
                    down_block_type,
                    num_layers=layers_per_block[i],
                    in_channels=input_channel,
                    out_channels=output_channel,
                    temb_channels=blocks_time_embed_dim,
                    # Unlike the UNet, the last block also downsamples so the
                    # pooled map is small enough for the linear head.
                    add_downsample=True,
                    resnet_eps=norm_eps,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    cross_attention_dim=cross_attention_dim[i],
                    num_attention_heads=attention_head_dim[i],
                    attention_head_dim=attention_head_dim[i],
                    downsample_padding=downsample_padding,
            )
            self.down_blocks.append(down_block)

        self.pool = nn.AvgPool2d(4, stride = 4)
        self.fc = nn.Linear(block_out_channels[-1], 1)


    def forward(self, sample, timesteps, encoder_hidden_states, attention_mask: Optional[torch.Tensor] = None,
               cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.Tensor] = None,):
        """Return one scalar per sample for noisy latents ``sample`` at ``timesteps``.

        Args:
            sample: latent batch; its first dim is the batch size B.
            timesteps: a Python number, a 0-d tensor, or a 1-D tensor of length 1 or B.
                Scalars and length-1 tensors are broadcast to the whole batch.
            encoder_hidden_states: text-encoder states consumed by cross-attention blocks.
            attention_mask, cross_attention_kwargs, encoder_attention_mask: forwarded
                to the cross-attention down blocks.

        Returns:
            Tensor of shape ``(B,)``, including when ``B == 1``.
        """
        batch_size = sample.shape[0]
        # `Timesteps` needs a 1-D tensor, so normalise scalar inputs and broadcast
        # them across the batch (the same handling UNet2DConditionModel applies).
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], device=sample.device)
        elif timesteps.ndim == 0:
            timesteps = timesteps[None].to(sample.device)
        timesteps = timesteps.expand(batch_size)

        # The sinusoidal projection is computed in float32; cast it to the blocks' dtype.
        dtype = next(self.down_blocks.parameters()).dtype
        t_emb = self.time_proj(timesteps).to(dtype=dtype)
        emb = self.time_embedding(t_emb)

        sample = self.conv_in(sample)
        for downsample_block in self.down_blocks:
            if getattr(downsample_block, "has_cross_attention", False):
                sample, _ = downsample_block(
                    hidden_states = sample,
                    temb = emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample, _ = downsample_block(hidden_states=sample, temb=emb)

        sample = self.pool(sample)
        sample = sample.flatten(start_dim=1)
        # squeeze(-1), not squeeze(): a bare squeeze would also drop the batch
        # dim when B == 1 and return a 0-d tensor.
        return self.fc(sample).squeeze(-1)

In [10]:
def _left_broadcast(t, shape):
    assert t.ndim <= len(shape)
    return t.reshape(t.shape + (1,) * (len(shape) - t.ndim)).broadcast_to(shape)

def _get_variance(self, timestep, prev_timestep):
    alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep.cpu()).to(timestep.device)
    alpha_prod_t_prev = torch.where(
        prev_timestep.cpu() >= 0,
        self.alphas_cumprod.gather(0, prev_timestep.cpu()),
        self.final_alpha_cumprod,
    ).to(timestep.device)
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
    return variance

def ddim_step_with_logprob(
    self: DDIMScheduler,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    eta: float = 1.0,
    use_clipped_model_output: bool = False,
    generator = None,
    prev_sample: Optional[torch.FloatTensor] = None,
    calculate_pb: bool = False,
    logp_mean = True,
    prev_timestep: int = None,
) -> Union[DDIMSchedulerOutput, Tuple]:
    """One stochastic DDIM step x_t -> x_{t-1} that also returns its log-probability.

    A re-implementation of ``DDIMScheduler.step``. The transition is the Gaussian
    ``N(prev_sample_mean, std_dev_t**2)`` with ``std_dev_t = eta * sqrt(variance)``.
    Its log-density at ``prev_sample`` is returned alongside the sample.

    Args:
        self: the DDIMScheduler (passed explicitly; this is a free function).
        model_output: UNet output at ``timestep``, interpreted according to
            ``self.config.prediction_type`` ("epsilon", "sample" or "v_prediction").
        timestep: current timestep as a tensor (``.cpu()`` is called on it), either
            a scalar or one entry per sample.
        sample: current latent x_t.
        eta: DDIM stochasticity. With ``eta == 0`` the std is zero and the
            log-probability divides by zero.
        use_clipped_model_output: recompute epsilon from the (clipped) x_0 estimate.
        generator: RNG for the noise draw; mutually exclusive with ``prev_sample``.
        prev_sample: if given, evaluate the log-probability of this x_{t-1}
            instead of sampling a new one.
        calculate_pb: also return the log-density of ``sample`` under
            ``N(sqrt(a_t / a_prev) * prev_sample, 1 - a_t / a_prev)``.
        logp_mean: average (True) or sum (False) log-densities over all non-batch dims.
        prev_timestep: target timestep. Defaults to
            ``timestep - num_train_timesteps // num_inference_steps``.

    Returns:
        ``(prev_sample, log_prob)``, or ``(prev_sample, log_prob, log_pb)`` when
        ``calculate_pb``. ``prev_sample`` is cast to ``sample.dtype``; each
        log-probability has one entry per batch element. A plain tuple is always
        returned despite the ``DDIMSchedulerOutput`` annotation.

    Raises:
        ValueError: if ``set_timesteps`` has not been called, if the prediction type
            is unknown, or if both ``generator`` and ``prev_sample`` are given.
    """
    assert isinstance(self, DDIMScheduler)
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # 1. get previous step value (=t-1)
    if prev_timestep is None:
        prev_timestep = (
            timestep - self.config.num_train_timesteps // self.num_inference_steps
        )
    # After this clamp prev_timestep >= 0, so the `final_alpha_cumprod` branch
    # below is never selected and `gather` never sees negative indices.
    prev_timestep = torch.clamp(prev_timestep, 0, self.config.num_train_timesteps - 1)

    # 2. compute alphas, betas
    alpha_prod_t = self.alphas_cumprod.gather(0, timestep.cpu())
    alpha_prod_t_prev = torch.where(
        prev_timestep.cpu() >= 0,
        self.alphas_cumprod.gather(0, prev_timestep.cpu()),
        self.final_alpha_cumprod,
    )
    alpha_prod_t = _left_broadcast(alpha_prod_t, sample.shape).to(sample.device)
    alpha_prod_t_prev = _left_broadcast(alpha_prod_t_prev, sample.shape).to(sample.device)
    beta_prod_t = 1 - alpha_prod_t

    # 3. compute the predicted original sample ("predicted x_0") and the predicted noise
    if self.config.prediction_type == "epsilon":
        pred_original_sample = (
            sample - beta_prod_t ** (0.5) * model_output
        ) / alpha_prod_t ** 0.5
        pred_epsilon = model_output
    elif self.config.prediction_type == "sample":
        pred_original_sample = model_output
        pred_epsilon = (
            sample - alpha_prod_t ** (0.5) * pred_original_sample
        ) / beta_prod_t ** 0.5
    elif self.config.prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t ** 0.5) * sample - (
            beta_prod_t ** 0.5
        ) * model_output
        pred_epsilon = (alpha_prod_t**0.5) * model_output + (
            beta_prod_t**0.5
        ) * sample
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
            " `v_prediction`"
        )
    # 4. Clip or threshold "predicted x_0"
    if self.config.thresholding:
        pred_original_sample = self._threshold_sample(pred_original_sample)
    elif self.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(
            -self.config.clip_sample_range, self.config.clip_sample_range
        )

    # 5. compute variance: "sigma_t(η)" -> see formula (16)
    variance = _get_variance(self, timestep, prev_timestep)
    std_dev_t = eta * variance ** (0.5)
    std_dev_t = _left_broadcast(std_dev_t, sample.shape).to(sample.device)

    if use_clipped_model_output:
        pred_epsilon = (
            sample - alpha_prod_t ** (0.5) * pred_original_sample
        ) / (beta_prod_t) ** 0.5

    # 6. compute "direction pointing to x_t" of formula (12)
    prev_sample_direction = (1 - alpha_prod_t_prev - std_dev_t ** 2) ** (0.5) * pred_epsilon

    # 7. compute x_t without "random noise" (the mean of the Gaussian transition)
    prev_sample_mean = (alpha_prod_t_prev ** (0.5) * pred_original_sample + prev_sample_direction)

    if prev_sample is not None and generator is not None:
        raise ValueError(
            "Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
            " `prev_sample` stays `None`."
        )
    if prev_sample is None:
        variance_noise = randn_tensor(
            model_output.shape,
            generator = generator,
            device = model_output.device,
            dtype = model_output.dtype,
        )
        prev_sample = prev_sample_mean + std_dev_t * variance_noise

    # Gaussian log-density of prev_sample. prev_sample is detached so gradients
    # flow only through the predicted mean.
    log_prob = (
        -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t ** 2))
        - torch.log(std_dev_t)
        - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
    )

    # Reduce over every non-batch dim.
    if logp_mean:
        log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
    else:
        log_prob = log_prob.sum(dim=tuple(range(1, log_prob.ndim)))

    if calculate_pb:
        assert prev_sample is not None
        # Backward (noising) step x_{t-1} -> x_t: N(sqrt(a_t / a_prev) * x_{t-1}, 1 - a_t / a_prev).
        alpha_ddim = alpha_prod_t / alpha_prod_t_prev
        pb_mean = alpha_ddim.sqrt() * prev_sample
        pb_std = (1 - alpha_ddim).sqrt()
        log_pb = (
                -((sample.detach() - pb_mean.detach()) ** 2) / (2 * (pb_std ** 2))
                - torch.log(pb_std)
                - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
        )
        if logp_mean:
            log_pb = log_pb.mean(dim=tuple(range(1, sample.ndim)))
        else:
            log_pb = log_pb.sum(dim=tuple(range(1, sample.ndim)))
        return prev_sample.type(sample.dtype), log_prob, log_pb

    else:
        return prev_sample.type(sample.dtype), log_prob

@torch.no_grad()
def pred_orig_latent(self: DDIMScheduler, model_output, sample: torch.FloatTensor, timestep: int):
    """Estimate the clean latent x_0 from x_t (``sample``) and a UNet output, without gradients.

    ``model_output`` is interpreted according to ``self.config.prediction_type``
    ("epsilon", "sample" or "v_prediction"). ``timestep`` must be a tensor because
    ``.cpu()`` is called on it. The alpha coefficients are cast to ``sample.dtype``.
    Unlike ``ddim_step_with_logprob``, no clipping or thresholding is applied.

    Raises:
        ValueError: for an unknown prediction type.
    """
    alpha_prod_t = self.alphas_cumprod.gather(0, timestep.cpu())
    alpha_prod_t = _left_broadcast(alpha_prod_t, sample.shape).to(sample.device)
    alpha_prod_t = alpha_prod_t.to(sample.dtype)
    beta_prod_t = 1 - alpha_prod_t

    if self.config.prediction_type == "epsilon":
        pred_original_sample = (
            sample - beta_prod_t ** (0.5) * model_output
        ) / alpha_prod_t ** (0.5)
    elif self.config.prediction_type == "sample":
        pred_original_sample = model_output
    elif self.config.prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t**0.5) * sample - (
            beta_prod_t**0.5
        ) * model_output
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
            " `v_prediction`"
        )
    return pred_original_sample

In [11]:
@torch.no_grad()
def pipeline_with_logprob(
    self: StableDiffusionPipeline,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,

    batch_size = None, dtype=None,
    device = None,
    calculate_pb = False, logp_mean = True,
    return_unetoutput = False,
):
    """Sample with DDIM while recording every latent and per-step transition log-probability.

    Adapted from ``StableDiffusionPipeline.__call__``. Each denoising step goes
    through ``ddim_step_with_logprob``, so the trajectory and its log-probabilities
    are returned. Runs under ``torch.no_grad``.

    Two conditioning modes:
      * ``prompt_embeds`` given: they are passed through ``self._encode_prompt``
        (with ``do_classifier_free_guidance`` and ``negative_prompt_embeds``), and
        latents come from ``self.prepare_latents``.
      * ``prompt_embeds`` is None: the UNet is called without text conditioning,
        and latents of shape ``(batch_size, C, height, width)`` are drawn with
        ``randn_tensor`` in ``dtype``. Passing ``prompt`` alone does NOT trigger
        text encoding.

    Extra arguments beyond the diffusers pipeline:
        batch_size: overrides the batch size inferred from ``prompt`` / ``prompt_embeds``.
        dtype: latent dtype in the unconditional mode.
        device: defaults to ``self._execution_device``.
        calculate_pb: also collect backward log-probabilities per step.
        logp_mean: average (True) or sum (False) log-probabilities over latent dims.
        return_unetoutput: also collect the guided UNet output of every step.
    ``return_dict`` is accepted for signature compatibility and ignored.

    Returns:
        ``(image, has_nsfw_concept, all_latents, all_log_probs)``, extended with
        ``all_log_pbs`` when ``calculate_pb`` or with the UNet outputs when
        ``return_unetoutput``. ``all_latents`` starts with the initial noise and
        gains one entry per denoising step.

    Raises:
        AssertionError: if both ``calculate_pb`` and ``return_unetoutput`` are set.
    """
    # Validate mutually exclusive options before spending time on sampling.
    assert not (calculate_pb and return_unetoutput), "Cannot return both log_pb and unet_outputs"

    # 0. Default height and width to unet
    if height is None:
        height = self.unet.config.sample_size * self.vae_scale_factor
    if width is None:
        width = self.unet.config.sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    if hasattr(self, "check_inputs"):  # DDPMPipeline does not have this method
        self.check_inputs(
            prompt,
            height,
            width,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
        )

    # 2. Define call parameters
    if batch_size is None:
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

    if device is None:
        device = self._execution_device
    # `guidance_scale` is the guidance weight `w` of eq. (2) in the Imagen paper
    # (https://arxiv.org/pdf/2205.11487.pdf); `guidance_scale = 1` disables
    # classifier-free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt (text-conditioned mode only)
    if prompt_embeds is not None:
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None)
            if cross_attention_kwargs is not None
            else None
        )
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
        )

    # 4. Prepare timesteps
    if num_inference_steps is None:
        timesteps = self.scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    if prompt_embeds is not None:
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs (e.g. {'eta': 1.0, 'generator': None}).
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    else:
        shape = (batch_size, num_channels_latents, height, width)
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        extra_step_kwargs = {'eta': eta, 'generator': generator}

    # 7. Denoising loop
    # scheduler.order is 1 for DDIM, so there are no warm-up steps and the
    # progress bar advances on every step.
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

    all_latents = [latents]
    all_log_probs = []
    all_log_pbs = []
    unet_outputs = []
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = (
                torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            )
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            if prompt_embeds is not None:
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]
            else:
                noise_pred = self.unet(
                    latent_model_input, t, return_dict=False
                )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )
            if return_unetoutput:
                unet_outputs.append(noise_pred.detach())

            # by default not used (as guidance_rescale = 0.0)
            if do_classifier_free_guidance and guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(
                    noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
                )

            # compute the previous noisy sample x_t -> x_t-1; on the last step
            # prev_timestep is None and ddim_step_with_logprob derives it.
            prev_timestep = timesteps[i + 1] if i < num_inference_steps-1 else None
            if calculate_pb:
                latents, log_prob, log_pb = ddim_step_with_logprob(
                    self.scheduler, noise_pred, t, latents,
                    calculate_pb=calculate_pb, logp_mean=logp_mean,
                    prev_timestep=prev_timestep,
                    **extra_step_kwargs
                )
                all_log_pbs.append(log_pb)
            else:
                # Forward logp_mean here too so the reduction honours the caller's choice.
                latents, log_prob = ddim_step_with_logprob(
                    self.scheduler, noise_pred, t, latents,
                    logp_mean=logp_mean,
                    prev_timestep=prev_timestep,
                    **extra_step_kwargs
                )

            all_latents.append(latents)
            all_log_probs.append(log_prob)

            # call the callback, if provided
            if i == len(timesteps) - 1 or (
                (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
            ):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    if not output_type == "latent":
        image = self.vae.decode(
            latents / self.vae.config.scaling_factor, return_dict=False
        )[0]
        # prompt_embeds is None in the unconditional mode; fall back to the latent dtype.
        checker_dtype = prompt_embeds.dtype if prompt_embeds is not None else latents.dtype
        image, has_nsfw_concept = self.run_safety_checker(
            image, device, checker_dtype
        )
    else:
        image = latents
        has_nsfw_concept = None

    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    # At least for the cifar10 DDPM, the generated image is in [-1, 1],
    # so we need this postprocessing to make it [0, 1]
    if prompt_embeds is not None:
        image = self.image_processor.postprocess(
            image, output_type=output_type, do_denormalize=do_denormalize
        )
        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()
    else:
        image = image_postprocess(image)

    if calculate_pb:
        return image, has_nsfw_concept, all_latents, all_log_probs, all_log_pbs
    if return_unetoutput:
        return image, has_nsfw_concept, all_latents, all_log_probs, unet_outputs

    return image, has_nsfw_concept, all_latents, all_log_probs

In [12]:
def main():
    """Fine-tune Stable Diffusion (LoRA on the UNet) jointly with a conditional
    flow network, using a detailed-balance-style squared residual
    ``(F(s) + log P_F - log P_B - F(s'))**2`` with an ImageReward reward.

    Each epoch:
      1. samples ``cfg.sample["num_batches_per_epoch"]`` batches of images with
         the current model (no gradients), recording the latent trajectory,
         per-step log-probs and, when ``cfg.train["unetreg"] > 0``, the UNet
         outputs;
      2. scores the final images with ``reward_fn``;
      3. for ``cfg.train["num_inner_epochs"]`` passes, shuffles the collected
         transitions and minimises the flow loss (optionally plus a clipped-ratio
         "klpf" policy term and a UNet regulariser), stepping the optimizer every
         ``cfg.train["gradient_accumulation_steps"]`` batches;
      4. writes the metrics pickle and, every ``cfg.save_freq`` epochs (and on the
         last epoch), a LoRA checkpoint under ``./output``.

    Relies on names defined in other notebook cells (``Config``, ``set_seed``,
    ``pipeline_with_logprob``, ``ddim_step_with_logprob``, ``pred_orig_latent``,
    ``ConditionalFlow``, ``imagereward``, ``drawbench``, ``cast_training_params``,
    ``unwrap_model``, ``convert_state_dict_to_diffusers`` and the diffusers
    pipeline/scheduler classes).
    """
    # Not imported in the notebook's import cell; needed for the metrics dump.
    import gzip
    import pickle

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger = logging.getLogger(__name__)
    cfg = Config()
    cfg.gpu_type = torch.cuda.get_device_name() if torch.cuda.is_available() else "CPU"
    logger.info(f"GPU type: {cfg.gpu_type}")
    output_dir = os.path.join("./output")
    os.makedirs(output_dir, exist_ok=True)
    logger.info(f"\n{cfg}")
    set_seed(cfg.seed)

    # wandb.log() raises without an active run; start one unless a previous cell did.
    if cfg.wandb and wandb.run is None:
        wandb.init()

    weight_dtype = torch.float32
    if cfg.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif cfg.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    pipeline = StableDiffusionPipeline.from_pretrained(
        cfg.pretrained["model"], revision=cfg.pretrained["revision"], torch_dtype=weight_dtype,
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    scheduler_config = {}
    scheduler_config.update(pipeline.scheduler.config)
    pipeline.scheduler = DDIMScheduler.from_config(scheduler_config)
    pipeline.vae.requires_grad_(False)
    pipeline.text_encoder.requires_grad_(False)
    pipeline.vae.to(device, dtype=weight_dtype)
    pipeline.text_encoder.to(device, dtype=weight_dtype)
    # Decode batches one image at a time: same images, much lower peak memory.
    # The training loop below decodes two batches per timestep.
    pipeline.vae.enable_slicing()

    pipeline.safety_checker = None
    pipeline.set_progress_bar_config(
        position=1,
        disable=False,
        leave=False,
        desc="Timestep",
        dynamic_ncols=True,
    )

    unet = pipeline.unet
    unet.requires_grad_(False)
    unet.to(device, dtype=weight_dtype)

    unet_lora_config = LoraConfig(
        r=cfg.train["lora_rank"], lora_alpha=cfg.train["lora_rank"],
        init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    )
    unet.add_adapter(unet_lora_config)
    if cfg.mixed_precision in ["fp16", "bf16"]:
        # Upcast the trainable (LoRA) parameters to fp32.
        cast_training_params(unet, dtype=torch.float32)
    # A list, not a `filter` iterator: an iterator is exhausted after one pass.
    lora_layers = [p for p in unet.parameters() if p.requires_grad]
    # Loss scaling protects fp16 gradients from underflow; bf16 has fp32's
    # exponent range and does not need it.
    scaler = torch.cuda.amp.GradScaler() if cfg.mixed_precision == "fp16" else None

    if cfg.train["use_8bit_adam"]:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )
        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    reward_fn = imagereward(weight_dtype, device)
    # Embedding of the empty prompt: the unconditional branch of classifier-free guidance.
    with torch.no_grad():
        neg_prompt_embed = pipeline.text_encoder(
            pipeline.tokenizer(
                [""],
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=pipeline.tokenizer.model_max_length,
            ).input_ids.to(device)
        )[0]

    sample_neg_prompt_embeds = neg_prompt_embed.repeat(cfg.sample["batch_size"], 1, 1)
    train_neg_prompt_embeds = neg_prompt_embed.repeat(cfg.train["batch_size"], 1, 1)

    def func_autocast():
        return torch.cuda.amp.autocast(dtype=weight_dtype)

    # With LoRA the UNet runs directly in `weight_dtype`; otherwise rely on autocast.
    autocast = contextlib.nullcontext if cfg.use_lora else func_autocast

    def decode(latents):
        """Decode VAE latents to image tensors via the pipeline postprocessor (output_type="pt")."""
        image = pipeline.vae.decode(
            latents / pipeline.vae.config.scaling_factor, return_dict=False
        )[0]
        do_denormalize = [True] * image.shape[0]
        return pipeline.image_processor.postprocess(
            image, output_type="pt", do_denormalize=do_denormalize
        )

    flow_model = ConditionalFlow(in_channels=4, block_out_channels=(64, 128, 256, 256),
                                 layers_per_block=1, cross_attention_dim=pipeline.text_encoder.config.hidden_size)
    flow_model = flow_model.to(device, dtype=torch.float32)
    autocast_flow = func_autocast
    params = [
        {"params": lora_layers, "lr": cfg.train["learning_rate"]},
        {"params": flow_model.parameters(), "lr": cfg.train["learning_rate"]},
    ]
    optimizer = optimizer_cls(
        params,
        betas=(cfg.train["adam_beta1"], cfg.train["adam_beta2"]),
        weight_decay=cfg.train["adam_weight_decay"],
        eps=cfg.train["adam_epsilon"],
    )
    result = collections.defaultdict(dict)
    result["config"] = cfg
    start_time = time.time()

    samples_per_epoch = cfg.sample["batch_size"] * cfg.sample["num_batches_per_epoch"]
    total_train_batch_size = cfg.train["batch_size"] * cfg.train["gradient_accumulation_steps"]

    logger.info("***** Running training *****")
    logger.info(f"  Num Epochs = {cfg.num_epochs}")
    logger.info(f"  Sample batch size per device = {cfg.sample['batch_size']}")
    logger.info(f"  Train batch size per device = {cfg.train['batch_size']}")
    logger.info(
        f"  Gradient Accumulation steps = {cfg.train['gradient_accumulation_steps']}"
    )
    logger.info("")
    logger.info(f"  Total number of samples per epoch = test_bs * num_batch_per_epoch = {samples_per_epoch}")
    logger.info(
        f"  Total train batch size = train_bs * grad_accumul = {total_train_batch_size}"
    )
    logger.info(
        f"  Number of gradient updates per inner epoch = samples_per_epoch // total_train_batch_size = {samples_per_epoch // total_train_batch_size}"
    )
    logger.info(f"  Number of inner epochs = {cfg.train['num_inner_epochs']}")

    assert cfg.sample['batch_size'] >= cfg.train['batch_size']
    assert cfg.sample['batch_size'] % cfg.train['batch_size'] == 0  # not necessary
    # Guarantees every accumulation window is complete, so no gradients leak across inner epochs.
    assert samples_per_epoch % total_train_batch_size == 0

    list_keys = ["prompts", "prompt_metadata"]
    train_bs = cfg.train["batch_size"]
    grad_accum = cfg.train["gradient_accumulation_steps"]

    first_epoch = 0
    global_step = 0

    for epoch in range(first_epoch, cfg.num_epochs):
        # Optionally anneal the reward exponent linearly over the first half of training.
        if cfg.train["anneal"] in ["linear"]:
            anneal_ratio = min(1, epoch / (0.5 * cfg.num_epochs))
        else:
            anneal_ratio = 1.0
        reward_exp_ep = cfg.train["reward_exp"] * anneal_ratio

        def reward_transform(value):
            return value * reward_exp_ep

        num_diffusion_steps = cfg.sample["num_steps"]
        pipeline.scheduler.set_timesteps(num_diffusion_steps, device=device)
        # Spacing between consecutive DDIM timesteps; used to get the timestep of s'.
        scheduler_dt = pipeline.scheduler.timesteps[0] - pipeline.scheduler.timesteps[1]
        num_train_timesteps = int(num_diffusion_steps * cfg.train['timestep_fraction'])
        # Each optimizer step sums `grad_accum` batches x `num_train_timesteps` backward calls.
        accumulation_steps = grad_accum * num_train_timesteps

        torch.cuda.empty_cache()
        unet.zero_grad()
        unet.eval()
        flow_model.zero_grad()

        ############################## SAMPLING #####################
        # no_grad rather than inference_mode: these tensors are reused as inputs
        # to the trainable flow model / LoRA layers below, and PyTorch refuses
        # to save inference-mode tensors for backward.
        with torch.no_grad():
            sample_batches = []
            for _ in tqdm(
                range(cfg.sample["num_batches_per_epoch"]),
                desc=f"Epoch {epoch}: sampling",
                disable=False,
                position=0,
            ):
                prompts, prompt_metadata = zip(
                    *[drawbench() for _ in range(cfg.sample["batch_size"])]
                )
                prompt_ids = pipeline.tokenizer(
                    prompts,
                    return_tensors="pt",
                    padding="max_length",
                    truncation=True,
                    max_length=pipeline.tokenizer.model_max_length,
                ).input_ids.to(device)
                prompt_embeds = pipeline.text_encoder(prompt_ids)[0]

                with autocast():
                    ret_tuple = pipeline_with_logprob(
                        pipeline,
                        prompt_embeds=prompt_embeds,
                        negative_prompt_embeds=sample_neg_prompt_embeds,
                        num_inference_steps=num_diffusion_steps,
                        guidance_scale=cfg.sample["guidance_scale"],
                        eta=cfg.sample["eta"],
                        output_type="pt",
                        return_unetoutput=cfg.train['unetreg'] > 0.,
                    )

                if cfg.train["unetreg"] > 0:
                    images, _, latents, log_probs, unet_outputs = ret_tuple
                    unet_outputs = torch.stack(unet_outputs, dim=1)
                else:
                    images, _, latents, log_probs = ret_tuple

                latents = torch.stack(latents, dim=1)
                log_probs = torch.stack(log_probs, dim=1)
                timesteps = pipeline.scheduler.timesteps.repeat(cfg.sample["batch_size"], 1)

                batch = {
                    "prompts": prompts,
                    "prompt_metadata": prompt_metadata,
                    "prompt_embeds": prompt_embeds,
                    "timesteps": timesteps,
                    # latents holds num_steps + 1 states: pair each state with its successor.
                    "latents": latents[:, :-1],
                    "next_latents": latents[:, 1:],
                    "log_probs": log_probs,
                    "rewards": reward_fn(images, prompts, prompt_metadata),
                }
                if cfg.train["unetreg"] > 0:
                    batch["unet_outputs"] = unet_outputs
                sample_batches.append(batch)

            # reward_fn returns (rewards, metadata); keep only the rewards as tensors.
            for batch in tqdm(
                sample_batches,
                desc="Waiting for rewards",
                disable=False,
                position=0,
            ):
                rewards, reward_metadata = batch["rewards"]
                batch["rewards"] = torch.as_tensor(rewards, device=device)

        # Collate: list fields are flattened, tensor fields concatenated along batch.
        samples = {}
        for k in sample_batches[0].keys():
            if k in list_keys:
                samples[k] = [item for b in sample_batches for item in b[k]]
            else:
                samples[k] = torch.cat([b[k] for b in sample_batches])
        del sample_batches

        if cfg.wandb:
            # Log the last sampled batch; saving via JPEG files forces wandb to upload JPEGs, not PNGs.
            with tempfile.TemporaryDirectory() as tmpdir:
                for i, image in enumerate(images):
                    pil = Image.fromarray(
                        (image.cpu().float().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
                    )
                    pil = pil.resize((256, 256))
                    pil.save(os.path.join(tmpdir, f"{i}.jpg"))
                wandb.log(
                    {
                        "images": [
                            wandb.Image(
                                os.path.join(tmpdir, f"{i}.jpg"),
                                caption=f"{prompt} | {reward:.2f}",
                            )
                            for i, (prompt, reward) in enumerate(zip(prompts, rewards))
                        ],
                    },
                    step=global_step,
                )

        rewards = samples["rewards"].detach().cpu().float().numpy()
        result["reward_mean"][global_step] = rewards.mean()
        result["reward_std"][global_step] = rewards.std()
        logger.info(f"global_step: {global_step}  rewards: {rewards.mean().item():.3f}")
        if cfg.wandb:
            wandb.log(
                {
                    "reward_mean": rewards.mean(),
                    "reward_std": rewards.std(),
                },
                step=global_step,
            )
        total_batch_size, num_timesteps = samples["timesteps"].shape
        assert total_batch_size == cfg.sample["batch_size"] * cfg.sample["num_batches_per_epoch"]
        assert num_timesteps == num_diffusion_steps

        ############################## TRAINING #####################
        row_index = torch.arange(total_batch_size, device=device)[:, None]
        time_keys = ["timesteps", "latents", "next_latents", "log_probs"]
        if cfg.train["unetreg"] > 0:
            time_keys.append("unet_outputs")

        for inner_epoch in range(cfg.train["num_inner_epochs"]):
            # Shuffle samples along the batch dimension.
            perm = torch.randperm(total_batch_size, device=device)
            perm_list = perm.tolist()
            samples = {
                k: [v[p] for p in perm_list] if k in list_keys else v[perm]
                for k, v in samples.items()
            }
            # Independently shuffle the time axis of every trajectory.
            perms = torch.stack(
                [torch.randperm(num_timesteps, device=device) for _ in range(total_batch_size)]
            )
            for key in time_keys:
                samples[key] = samples[key][row_index, perms]

            # Rebatch into a list of per-training-batch dicts.
            samples_batched = {}
            for k, v in samples.items():
                if k in list_keys:
                    samples_batched[k] = [v[s: s + train_bs] for s in range(0, len(v), train_bs)]
                else:
                    samples_batched[k] = v.reshape(-1, train_bs, *v.shape[1:])
            samples_batched = [
                dict(zip(samples_batched, x)) for x in zip(*samples_batched.values())
            ]

            unet.train()
            flow_model.train()
            info = defaultdict(list)
            for i, sample in tqdm(
                list(enumerate(samples_batched)),
                desc=f"Epoch {epoch}.{inner_epoch}: training",
                position=0,
                disable=False,
            ):
                torch.cuda.empty_cache()
                if cfg.train["cfg"]:
                    embeds = torch.cat([train_neg_prompt_embeds, sample["prompt_embeds"]])
                else:
                    embeds = sample["prompt_embeds"]

                for j in tqdm(range(num_train_timesteps), desc="Timestep", position=1, leave=False, disable=False):
                    x_t = sample["latents"][:, j]
                    t = sample["timesteps"][:, j]
                    x_next = sample["next_latents"][:, j]

                    # Policy log-probs log P_F(s'|s) and log P_B(s|s') under the current UNet.
                    with autocast():
                        if cfg.train["cfg"]:
                            noise_pred = unet(torch.cat([x_t] * 2), torch.cat([t] * 2), embeds).sample
                            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                            noise_pred = (
                                noise_pred_uncond + cfg.sample["guidance_scale"] * (noise_pred_text - noise_pred_uncond)
                            )
                        else:
                            noise_pred = unet(x_t, t, embeds).sample
                        if cfg.train["unetreg"] > 0:
                            # Per-sample MSE to the UNet output recorded while sampling.
                            unetreg = torch.mean((noise_pred - sample["unet_outputs"][:, j]) ** 2, dim=(1, 2, 3))

                        _, log_pf, log_pb = ddim_step_with_logprob(
                            pipeline.scheduler, noise_pred, t, x_t,
                            eta=cfg.sample["eta"],
                            prev_sample=x_next, calculate_pb=True,
                        )

                    # Learned (log-)flows at s and s'.
                    with autocast_flow():
                        flow = flow_model(x_t, t, sample["prompt_embeds"])
                        timestep_next = torch.clamp(t - scheduler_dt, min=0)
                        flow_next = flow_model(x_next, timestep_next, sample["prompt_embeds"])

                    # Add the (scaled) reward of the predicted clean latent at s to its flow.
                    with torch.no_grad():
                        with autocast():
                            unet_output = unet(x_t, t, sample["prompt_embeds"]).sample
                            latent = pred_orig_latent(pipeline.scheduler, unet_output, x_t, t)
                        logr_temp = reward_fn(decode(latent), sample["prompts"], sample["prompt_metadata"])[0]
                    logr = reward_transform(logr_temp)
                    flow = flow + logr

                    # Same for s'.
                    with torch.no_grad():
                        with autocast():
                            unet_output = unet(x_next, timestep_next, sample["prompt_embeds"]).sample
                            latent_next = pred_orig_latent(pipeline.scheduler, unet_output, x_next, timestep_next)
                        logr_next_temp = reward_fn(decode(latent_next), sample["prompts"], sample["prompt_metadata"])[0]
                    logr_next = reward_transform(logr_next_temp)
                    flow_next = flow_next + logr_next

                    info["log_pf"].append(torch.mean(log_pf).detach())
                    info["flow"].append(torch.mean(flow).detach())
                    info["log_pb"].append(torch.mean(log_pb).detach())

                    if cfg.train["klpf"] > 0:
                        # Flow loss trains only the flow model; the policy is trained by the clipped-ratio term.
                        losses_flow = (flow + log_pf.detach() - log_pb.detach() - flow_next) ** 2

                        reward_db = (flow_next.detach() + log_pb.detach() - log_pf.detach() - flow).detach()

                        # Leave-one-out (RLOO) baseline and variance computed within this batch.
                        assert len(reward_db) > 1
                        rloo_baseline = (reward_db.sum() - reward_db) / (len(reward_db) - 1)
                        reward_ = (reward_db - rloo_baseline) ** 2
                        rloo_var = (reward_.sum() - reward_) / (len(reward_db) - 1)
                        advantages = (reward_db - rloo_baseline) / (rloo_var.sqrt() + 1e-8)
                        advantages = torch.clamp(advantages, -cfg.train['adv_clip_max'], cfg.train['adv_clip_max'])

                        importance_ratio = torch.exp(log_pf - sample["log_probs"][:, j])
                        unclipped_losses = -advantages * importance_ratio
                        clipped_losses = -advantages * torch.clamp(
                            importance_ratio,
                            1.0 - cfg.train['clip_range'],
                            1.0 + cfg.train['clip_range'],
                        )
                        losses_klpf = torch.maximum(unclipped_losses, clipped_losses)
                        info["ratio"].append(torch.mean(importance_ratio).detach())

                        losses = losses_flow + cfg.train['klpf'] * losses_klpf
                        info["loss"].append(losses_flow.mean().detach())
                        info["loss_klpf"].append(losses_klpf.mean().detach())
                    else:
                        losses_gfn = (flow + log_pf - log_pb - flow_next) ** 2
                        info["loss"].append(losses_gfn.mean().detach())
                        losses = losses_gfn

                    if cfg.train["unetreg"] > 0:
                        losses = losses + cfg.train["unetreg"] * unetreg
                        info["unetreg"].append(unetreg.mean().detach())
                    loss = torch.mean(losses)

                    info["logr"].append(torch.mean(logr_temp).detach())

                    loss = loss / accumulation_steps
                    if scaler is not None:
                        scaler.scale(loss).backward()
                    else:
                        loss.backward()

                    # Drop references so activations are released before the next timestep.
                    _ = noise_pred = noise_pred_uncond = noise_pred_text = None
                    unet_output = latent = latent_next = None
                    flow = flow_next = logr = logr_next = logr_temp = logr_next_temp = None
                    log_pf = log_pb = reward_db = advantages = None
                    losses = losses_flow = losses_klpf = losses_gfn = loss = unetreg = None

                # Step once per `gradient_accumulation_steps` batches; gradients accumulate until then.
                if (i + 1) % grad_accum == 0:
                    if scaler is not None:
                        scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(unet.parameters(), cfg.train['max_grad_norm'])
                    torch.nn.utils.clip_grad_norm_(flow_model.parameters(), cfg.train['max_grad_norm'])
                    if scaler is not None:
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

                    info = {k: torch.mean(torch.stack(v)).item() for k, v in info.items()}
                    for k, v in info.items():
                        result[k][global_step] = v

                    info.update({"epoch": epoch})
                    result["epoch"][global_step] = epoch
                    result["time"][global_step] = time.time() - start_time

                    if cfg.wandb:
                        wandb.log(info, step=global_step)
                    logger.info(f"global_step={global_step}  " +
                                " ".join([f"{k}={v:.3f}" for k, v in info.items()]))
                    info = defaultdict(list)  # reset info dict

        # Persist metrics every epoch so progress survives a crash (e.g. CUDA OOM).
        # NOTE: despite the .json name this is a gzip-compressed pickle; name kept for compatibility.
        with gzip.open(os.path.join(output_dir, f"result.json"), 'wb') as result_file:
            pickle.dump(result, result_file)
        if epoch % cfg.save_freq == 0 or epoch == cfg.num_epochs - 1:
            save_path = os.path.join(output_dir, f"checkpoint_epoch{epoch}")
            unwrapped_unet = unwrap_model(unet)
            unet_lora_state_dict = convert_state_dict_to_diffusers(
                get_peft_model_state_dict(unwrapped_unet)
            )
            StableDiffusionPipeline.save_lora_weights(
                save_directory=save_path,
                unet_lora_layers=unet_lora_state_dict,
                is_main_process=True,
                safe_serialization=True,
            )
            logger.info(f"Saved state to {save_path}")

    if cfg.wandb:
        wandb.finish()

In [13]:
# Run the full sampling + training loop. On first run this downloads the base
# Stable Diffusion weights and the ImageReward checkpoint (see output below).
main()

model_index.json:   0%|          | 0.00/541 [00:00<?, ?B/s]

Fetching 15 files:   0%|          | 0/15 [00:00<?, ?it/s]

scheduler/scheduler_config.json:   0%|          | 0.00/308 [00:00<?, ?B/s]

safety_checker/config.json:   0%|          | 0.00/4.72k [00:00<?, ?B/s]

text_encoder/config.json:   0%|          | 0.00/617 [00:00<?, ?B/s]

tokenizer/special_tokens_map.json:   0%|          | 0.00/472 [00:00<?, ?B/s]

(…)ature_extractor/preprocessor_config.json:   0%|          | 0.00/342 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.22G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/492M [00:00<?, ?B/s]

tokenizer/merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

tokenizer/vocab.json:   0%|          | 0.00/1.06M [00:00<?, ?B/s]

tokenizer/tokenizer_config.json:   0%|          | 0.00/806 [00:00<?, ?B/s]

unet/config.json:   0%|          | 0.00/743 [00:00<?, ?B/s]

vae/config.json:   0%|          | 0.00/547 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/3.44G [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

ImageReward.pt:   0%|          | 0.00/1.79G [00:00<?, ?B/s]

load checkpoint from /root/.cache/ImageReward/ImageReward.pt


med_config.json:   0%|          | 0.00/485 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

checkpoint loaded


Epoch 0: sampling:   0%|          | 0/2 [00:00<?, ?it/s]

Timestep:   0%|          | 0/50 [00:00<?, ?it/s]

Epoch 0: sampling:  50%|█████     | 1/2 [01:38<01:38, 98.78s/it]

Timestep:   0%|          | 0/50 [00:00<?, ?it/s]

Epoch 0: sampling: 100%|██████████| 2/2 [03:21<00:00, 100.59s/it]
Waiting for rewards: 100%|██████████| 2/2 [00:00<00:00, 10727.12it/s]
Epoch 0.0: training:   0%|          | 0/2 [00:00<?, ?it/s]
Timestep:   0%|          | 0/50 [00:00<?, ?it/s][A
Epoch 0.0: training:   0%|          | 0/2 [00:02<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 512.00 MiB. GPU 0 has a total capacity of 14.74 GiB of which 426.12 MiB is free. Process 2512 has 14.32 GiB memory in use. Of the allocated memory 13.91 GiB is allocated by PyTorch, and 296.33 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)