In [None]:
# Install the notebook's dependencies.
# diffusers is installed from the GitHub main branch; the commented-out line
# below is the PyPI alternative.
# !pip install diffusers
!pip install git+https://github.com/huggingface/diffusers.git
!pip install botorch==0.10.0
!pip install transformers scipy ftfy accelerate torchmetrics gpytorch umap-learn pytorch-lightning openai-clip open-clip-torch peft
# The JAX wheel is resolved through the extra index passed with -f.
!pip install "jax[cuda12_pip]==0.4.27" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

In [None]:
# Fetch the HPSv2 sources; the next cell installs them in editable mode.
!git clone https://github.com/tgxs002/HPSv2.git

In [None]:
# Enter the cloned HPSv2 checkout and install it in editable mode.
# NOTE: %cd changes the kernel's working directory for all following cells.
%cd HPSv2
!pip install -e .

In [None]:
# Built-in Python modules
import copy
import functools
import io
import json
import math
import os
import random
import warnings
from importlib import resources
from os.path import join
from warnings import filterwarnings

# Image and visualization libraries
import cv2 as cv
import matplotlib.pyplot as plt
from PIL import Image, ImageFile

# Torch and related libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor, autocast, randint
from torch.nn.utils import parametrize
from torch.nn.utils.parametrizations import orthogonal
from torch.utils.data import DataLoader, Dataset
import torchmetrics.functional.multimodal
import torchmetrics.multimodal.clip_score
import torchvision
import torchvision.models as models
from torchvision import datasets, transforms
from torchvision.models import ResNet18_Weights, ResNet50_Weights
from torchvision.transforms.functional import pil_to_tensor

# Machine learning and deep learning libraries
import clip
import hpsv2
import numpy as np
import pandas as pd
import pytorch_lightning as pl
from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
from peft import LoraConfig
import tqdm
from tqdm import tqdm
from transformers import (
    CLIPModel, 
    CLIPProcessor, 
    CLIPTextModel, 
    CLIPTokenizer
)

# Diffusion libraries
from diffusers import (
    AutoencoderKL, 
    EulerDiscreteScheduler, 
    LMSDiscreteScheduler, 
    PNDMScheduler, 
    UNet2DConditionModel
)

# Networking and file handling
import requests

In [None]:
# Directories and file paths used throughout the notebook.
# Edit these to match your environment before running the cells below.
AESTHETIC_MODEL_PATH = "/content/sac+logos+ava1-l14-linearMSE.pth" # checkpoint loaded into the aesthetic scorer's MLP head
PROMPT_DATA_PATH = "/content/prompts_diffusiondb_DATA" # text file read line by line, one prompt per line
INDIVIDUAL_TRAINING_PATH = '/content/Individual_Training' # output directory for the per-prompt ("individual") training checkpoints
BATCH_TRAINING_PATH = '/content/Batch_Training' # output directory for the batch training results

In [None]:
# CLIP-score metric bound to the ViT-L/14 checkpoint.
# The names are fully qualified because the import cell only brings in the
# `functools` and `torchmetrics.functional.multimodal` modules; the bare
# names `partial` and `clip_score` are never imported and raised NameError.
clip_score_fn = functools.partial(
    torchmetrics.functional.multimodal.clip_score,
    model_name_or_path="openai/clip-vit-large-patch14",
)

def calculate_clip_score(images, prompts):
    """Return the CLIP score of ``images`` against ``prompts`` as a Python float.

    ``images`` is a numpy array whose axes are permuted (0, 3, 1, 2) before
    scoring — presumably NHWC -> NCHW; confirm against the caller. The score
    is detached from the autograd graph and rounded to 4 decimal places.
    """
    channels_first = torch.from_numpy(images).permute(0, 3, 1, 2)
    score = clip_score_fn(channels_first, prompts).detach()
    return round(float(score), 4)

def calculate_clip_score_diff(images, prompts): #images: torch Tensor
    """Return the raw CLIP score tensor for ``images`` and ``prompts``.

    Unlike ``calculate_clip_score`` the result is neither detached nor
    rounded; the tensor produced by ``clip_score_fn`` is returned as is.
    """
    return clip_score_fn(images, prompts)

In [None]:
# Based on https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/fe88a163f4661b4ddabba0751ff645e2e620746e/simple_inference.py
# and https://github.com/mihirp1998/AlignProp/blob/trl_main/alignprop_trainer.py

def Freeze_Model(models):
    """Disable gradient tracking for every parameter of every module in ``models``.

    ``models`` is an iterable of objects exposing ``.parameters()``.
    The modules are modified in place; nothing is returned.
    """
    for param in (p for model in models for p in model.parameters()):
        param.requires_grad_(False)

class MLP(nn.Module):
    """Regression head mapping a 768-dim embedding to a single scalar.

    The network is a stack of linear layers (768 -> 1024 -> 128 -> 64 -> 16 -> 1)
    with dropout after the first three and no non-linearities. The stack is
    stored under ``self.layers`` so checkpoint keys look like
    ``layers.<index>.weight``.
    """

    def __init__(self):
        super().__init__()
        # Order matters: the Sequential indices are part of the state-dict keys.
        stack = [
            nn.Linear(768, 1024),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            nn.Linear(16, 1),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, embed):
        """Apply the layer stack to ``embed`` (last dimension must be 768)."""
        scores = self.layers(embed)
        return scores

class AestheticScorer(nn.Module):
    """Non-differentiable aesthetic scorer: CLIP ViT-L/14 image features -> MLP head.

    Images go through the Hugging Face ``CLIPProcessor``, are embedded with
    ``CLIPModel.get_image_features``, L2-normalised and scored by ``MLP``.
    Scoring runs under ``torch.no_grad()``.

    Args:
        dtype: dtype the processed inputs are cast to before the CLIP forward pass.
    """

    def __init__(self, dtype):
        super().__init__()
        self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
        self.mlp = MLP()
        # Load the head from the path set in the configuration cell (the same
        # constant AestheticScorerDiff uses) instead of a hard-coded Drive path.
        # map_location="cpu" lets the checkpoint load on hosts without the device
        # it was saved from; the module is moved later with .to().
        state_dict = torch.load(AESTHETIC_MODEL_PATH, map_location="cpu")
        self.mlp.load_state_dict(state_dict)
        self.dtype = dtype
        self.eval()

    @torch.no_grad()
    def __call__(self, images):
        """Return one aesthetic score per image (tensor with the MLP's trailing dim squeezed)."""
        # Run on whichever device the module's parameters currently live on.
        device = next(self.parameters()).device
        inputs = self.processor(images=images, return_tensors="pt")
        inputs = {k: v.to(self.dtype).to(device) for k, v in inputs.items()}
        embed = self.clip.get_image_features(**inputs)
        # normalize embedding
        embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
        return self.mlp(embed).squeeze(1)


class AestheticScorerDiff(torch.nn.Module):
    """Differentiable aesthetic scorer operating directly on image tensors.

    Unlike ``AestheticScorer`` it skips the ``CLIPProcessor``: images are
    resized and normalised with torchvision transforms so gradients can flow
    back to the pixels.

    Args:
        dtype: dtype the normalised pixels are cast to before the CLIP forward pass.
    """

    def __init__(self, dtype):
        super().__init__()
        self.target_size = 224
        # Same mean/std constants as the other CLIP-based scorers in this notebook.
        self.normalize = torchvision.transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                                          std=[0.26862954, 0.26130258, 0.27577711])
        self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
        self.mlp = MLP()
        # map_location="cpu" lets the checkpoint load on hosts without the device
        # it was saved from; the module is moved later with .to().
        state_dict = torch.load(AESTHETIC_MODEL_PATH, map_location="cpu")
        self.mlp.load_state_dict(state_dict)
        self.dtype = dtype
        self.eval()

    def __call__(self, images):
        """Return one aesthetic score per image.

        ``images`` is an image tensor accepted by torchvision ``Resize`` /
        ``Normalize`` — presumably (N, 3, H, W) in [0, 1]; confirm against the
        caller. It must already be on the same device as the module.
        """
        im_pix = torchvision.transforms.Resize(self.target_size)(images)
        im_pix = self.normalize(im_pix).to(self.dtype)
        embed = self.clip.get_image_features(pixel_values=im_pix)
        # Unit-normalise the embedding before the MLP head.
        embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
        return self.mlp(embed).squeeze(1)

class CLIP_Score_Diff(torch.nn.Module):
    """Differentiable CLIP image-text cosine similarity.

    Args:
        dtype: dtype the normalised pixels are cast to before the CLIP forward pass.
        device: device the CLIP model is placed on and token ids are moved to.
    """

    def __init__(self, dtype, device):
        super().__init__()
        self.target_size = 224
        self.device = device
        self.normalize = torchvision.transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                                          std=[0.26862954, 0.26130258, 0.27577711])
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=dtype)
        # self.clip_model, _ = clip.load("ViT-B/32", device=device)
        self.dtype = dtype
        self.eval()

    def __call__(self, images, prompts):
        """Return the cosine-similarity matrix between ``images`` and ``prompts``.

        The result is ``image_features @ text_features.T`` on unit-normalised
        features, i.e. one row per image and one column per prompt.
        ``images`` must already be on ``self.device``.
        """
        im_pix = torchvision.transforms.Resize(self.target_size)(images)
        im_pix = self.normalize(im_pix).to(self.dtype)
        # padding=True pads to the longest prompt so a list of prompts with
        # different token counts can be stacked into one tensor (padding=False
        # raised in that case); truncation keeps over-long prompts within the
        # tokenizer's maximum length. A single prompt is tokenised as before.
        text = self.tokenizer(prompts, padding=True, truncation=True, return_tensors="pt")

        image_features = self.clip_model.get_image_features(pixel_values=im_pix)
        text_features = self.clip_model.get_text_features(
            input_ids=text.input_ids.to(self.device),
            attention_mask=text.attention_mask.to(self.device),
        )

        # Normalize features and compute cosine similarity
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        similarity = (image_features @ text_features.T)
        return similarity

def _download_hps_checkpoint(url, checkpoint_path):
    """Stream ``url`` to ``checkpoint_path``, showing a tqdm progress bar.

    The payload is written to a temporary ``.part`` file and renamed only after
    the transfer finished, so an interrupted download never leaves a truncated
    checkpoint that later runs would mistake for a valid cache entry.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    partial_path = checkpoint_path + ".part"
    with requests.get(url, stream=True, timeout=60) as response:
        # Fail loudly instead of saving an HTML error page as the checkpoint.
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))

        with open(partial_path, 'wb') as file, tqdm(
            desc="Downloading HPS_v2_compressed.pt",
            total=total_size,
            unit='iB',
            unit_scale=True,
            unit_divisor=1024,
        ) as progress_bar:
            for data in response.iter_content(chunk_size=1024):
                size = file.write(data)
                progress_bar.update(size)
    os.replace(partial_path, checkpoint_path)


def hps_loss_fn(inference_dtype=None, device=None):
    """Build a differentiable HPSv2 reward/loss closure.

    Creates the ViT-H-14 open_clip model shipped with ``hpsv2``, loads the
    HPS v2 checkpoint (downloading it to ``~/.cache/hpsv2`` when missing),
    freezes the model and returns ``loss_fn``.

    Args:
        inference_dtype: forwarded as ``precision`` to the model factory and
            used as the dtype the model is cast to.
        device: device for the model, the checkpoint and the tokenised captions.

    Returns:
        ``loss_fn(im_pix, prompts) -> (loss, scores)`` where ``scores`` is the
        diagonal of ``image_features @ text_features.T`` and ``loss = 1 - scores``.
    """
    model_name = "ViT-H-14"
    # Only the model is needed; the two preprocessing transforms are unused here.
    model, _, _ = create_model_and_transforms(
        model_name,
        'laion2B-s32B-b79K',
        precision=inference_dtype,
        device=device,
        jit=False,
        force_quick_gelu=False,
        force_custom_text=False,
        force_patch_dropout=False,
        force_image_size=None,
        pretrained_image=False,
        image_mean=None,
        image_std=None,
        light_augmentation=True,
        aug_cfg={},
        output_dict=True,
        with_score_predictor=False,
        with_region_predictor=False
    )

    link = "https://huggingface.co/spaces/xswu/HPSv2/resolve/main/HPS_v2_compressed.pt"

    # Create the directory if it doesn't exist
    os.makedirs(os.path.expanduser('~/.cache/hpsv2'), exist_ok=True)
    checkpoint_path = f"{os.path.expanduser('~')}/.cache/hpsv2/HPS_v2_compressed.pt"

    # Download the file if it doesn't exist
    if not os.path.exists(checkpoint_path):
        _download_hps_checkpoint(link, checkpoint_path)
    # force download of model via score
    hpsv2.score([], "")

    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    tokenizer = get_tokenizer(model_name)
    model = model.to(device, dtype=inference_dtype)
    Freeze_Model([model])
    model.eval()

    target_size =  224
    normalize = torchvision.transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                                std=[0.26862954, 0.26130258, 0.27577711])

    def loss_fn(im_pix, prompts):
        """Score ``im_pix`` against ``prompts``; returns ``(1 - scores, scores)``."""
        x_var = torchvision.transforms.Resize(target_size)(im_pix)
        x_var = normalize(x_var).to(im_pix.dtype)
        caption = tokenizer(prompts)
        caption = caption.to(device)
        outputs = model(x_var, caption)
        image_features, text_features = outputs["image_features"], outputs["text_features"]
        logits = image_features @ text_features.T
        # Only matching (image i, prompt i) pairs are scored.
        scores = torch.diagonal(logits)
        loss = 1.0 - scores
        return  loss, scores

    return loss_fn

In [None]:
#Rotation
def design_embedding_2d(embed: Tensor, angles: Tensor, text_emb_dim: int, start_dim: int):
    """Apply successive 2-D rotations to interleaved column pairs of ``embed``.

    For the i-th angle, the "even" columns ``(start_dim + 0, 2, 4, ... + i)``
    and the "odd" columns ``(start_dim + 1, 3, 5, ... + i)`` (both taken modulo
    ``text_emb_dim``) are paired element-wise and rotated by that angle.

    Args:
        embed: tensor of shape (n, text_emb_dim); modified in place.
        angles: tensor of shape (1, num_angles); only row 0 is read.
        text_emb_dim: number of columns of ``embed``.
        start_dim: column offset added before wrapping.

    Returns:
        The same ``embed`` tensor after the in-place rotations.
    """
    even_offsets = torch.arange(0, text_emb_dim, 2)
    odd_offsets = torch.arange(1, text_emb_dim, 2)
    for step in range(angles.shape[1]):
        theta = angles[0][step]
        first_cols = (start_dim + even_offsets + step) % text_emb_dim
        second_cols = (start_dim + odd_offsets + step) % text_emb_dim
        # Advanced indexing copies, so both reads see the pre-rotation values.
        first = embed[:, first_cols]
        second = embed[:, second_cols]
        rotated_first = first * torch.cos(theta) - second * torch.sin(theta)
        rotated_second = first * torch.sin(theta) + second * torch.cos(theta)
        embed[:, first_cols] = rotated_first
        embed[:, second_cols] = rotated_second
    return embed

class clamp_var(nn.Module):
    """Column-wise clamp used as a parametrization.

    The first ``num_tokens * proj_dim`` columns are clamped to ``[lb, ub]``;
    all remaining columns are clamped to ``[-pi, pi]``.
    """

    def __init__(self, proj_dim, num_tokens, lb, ub):
        super().__init__()
        self.proj_dim = proj_dim
        self.num_tokens = num_tokens
        self.lb = lb
        self.ub = ub

    def forward(self, w):
        """Return a new tensor shaped like ``w`` holding the clamped values."""
        split = self.num_tokens * self.proj_dim
        clamped = torch.zeros_like(w)
        clamped[:, :split] = w[:, :split].clamp(self.lb, self.ub)
        clamped[:, split:] = w[:, split:].clamp(-torch.pi, torch.pi)
        return clamped

#IPGO tokens
class ProjX(nn.Module):
    """Learnable tokens built from a low-dimensional parameter, a shared
    orthogonal projection and per-token 2-D rotations.

    The parameter ``x`` has one row per token. Its first
    ``num_tokens * proj_dim`` columns are projected to ``text_emb_dim`` by
    ``L_total``; its last two columns are rotation angles passed to
    ``design_embedding_2d``. ``x`` is wrapped in the ``clamp_var``
    parametrization, which keeps coordinates in ``[lb, ub]`` and angles in
    ``[-pi, pi]``.

    Args:
        dtype, device: dtype/device of the parameters.
        text_emb_dim: output embedding width.
        proj_dim, num_tokens: the projected part of each row has
            ``proj_dim * num_tokens`` entries; ``num_tokens`` tokens are produced.
        lb, ub: clamp bounds for the projected coordinates; ``ub`` also scales
            their uniform initialisation.
    """

    def __init__(self, dtype, device,
                 text_emb_dim, proj_dim, num_tokens,
                 lb, ub):
        super(ProjX, self).__init__()
        self.data_type = dtype
        self.text_emb_dim = text_emb_dim
        self.proj_dim = proj_dim
        self.num_tokens = num_tokens
        self.L = []
        self.L_total = orthogonal(nn.Linear(proj_dim*num_tokens, text_emb_dim, dtype=self.data_type, device=device))
        # Coordinates start uniform in [-ub, ub], angles uniform in [-pi, pi].
        self.x = nn.Parameter(torch.cat(((torch.rand(num_tokens,num_tokens*proj_dim,dtype=self.data_type,device=device)*2-1)*ub,
                                         (torch.rand(num_tokens,2,dtype=self.data_type,device=device)*2-1)*torch.pi), dim=1))
        parametrize.register_parametrization(self, "x", clamp_var(proj_dim, num_tokens, lb, ub))

    def forward(self):
        """Return the token embeddings, shape (1, num_tokens, text_emb_dim)."""
        split = self.num_tokens * self.proj_dim
        # Read the (clamped) parameter once; each access re-runs the parametrization.
        params = self.x
        # angles: (num_tokens, 2) — the trailing columns of x.
        angles = params[:, split:]
        proj_xs = []
        for i in range(self.num_tokens):
            proj_x = params[i:(i+1), :split]
            # NOTE(review): L_total.weight is produced by the orthogonal
            # parametrization, so assigning .data on it presumably does not
            # persist — confirm whether this normalisation has any effect.
            with torch.no_grad():
                self.L_total.weight.data = F.normalize(self.L_total.weight.data, p=2, dim=1)
            proj_x = self.L_total(proj_x)
            # Pass this token's row of angles. The previous code sliced the
            # already-extracted angle columns again from `split` onwards, which
            # yields an empty (1, 0) tensor, so no rotation was ever applied and
            # the angle parameters received no gradient.
            proj_x = design_embedding_2d(proj_x, angles[i:(i+1)], self.text_emb_dim, i).unsqueeze(1)
            proj_xs.append(proj_x)
        proj_xs = torch.cat(proj_xs, dim=1)
        # Leading dim is 1 here, so the mean only keeps the (1, num_tokens, dim) shape.
        return torch.mean(proj_xs, dim=0, keepdim=True)

In [None]:
# Numeric precision and device configuration.
# data_type is the dtype of the trainable ProjX parameters;
# inference_data_type is the dtype the frozen diffusion/reward models are loaded in.
data_type = torch.float32
inference_data_type = torch.float16
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Load the Stable Diffusion components individually (instead of a pipeline
# object) so the sampling loop can be written out by hand in later cells.
model_id = "sd-legacy/stable-diffusion-v1-5" #"CompVis/stable-diffusion-v1-4"

# 1. Load the autoencoder model which will be used to decode the latents into image space.
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=inference_data_type)
# vae.eval()

# 2. Load the tokenizer and text encoder to tokenize and encode the text.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=inference_data_type)
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=inference_data_type)
# text_encoder.eval()

# 3. The UNet model for generating the latents.
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", torch_dtype=inference_data_type) #
# unet.eval()

# Noise scheduler used by every sampling loop in this notebook.
scheduler = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
# scheduler = EulerDiscreteScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler", torch_dtype=data_type)

# The diffusion components stay frozen; only the ProjX tokens are optimised.
Freeze_Model([vae,text_encoder,unet])

In [None]:
# Read the prompt file: one prompt per line, trailing whitespace stripped.
with open(PROMPT_DATA_PATH) as file:
    prompt_bases_tot = [line.rstrip() for line in file]
# Total number of prompts read from the file.
num_prompts = len(prompt_bases_tot)

In [None]:
def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Returns ``None`` when the optimizer has no parameter groups.
    """
    return next((group['lr'] for group in optimizer.param_groups), None)

# Reward Definitions:
# Only the HPSv2 reward is active; the aesthetic and CLIP rewards are left
# commented out and can be re-enabled by uncommenting the lines below.
# aesthetic_model = AestheticScorerDiff(dtype=data_type)
# Freeze_Model([aesthetic_model])
# loss_fn = aesthetic_model.to(device)
# loss_fn_hps(image, prompts) returns (loss, scores); see hps_loss_fn.
loss_fn_hps = hps_loss_fn(inference_dtype=inference_data_type, device=device)
# loss_fn_clip = CLIP_Score_Diff(dtype=inference_data_type, device=device)

In [None]:
def image_pipeline_default(prompt_base, vae, text_encoder, unet, tokenizer, scheduler,
                           *, height=512, width=512, num_inference_steps=50,
                           guidance_scale=7.5, seed=0):
    """Generate one image for ``prompt_base`` with classifier-free guidance.

    The whole function runs under ``torch.no_grad()``, so the result carries
    no autograd graph.

    Args:
        prompt_base: prompt(s) passed to ``tokenizer``; the pipeline is built
            for a single sample (``batch_size`` is fixed to 1).
        vae, text_encoder, unet, tokenizer, scheduler: Stable Diffusion
            components; the three models are moved to the global ``device``.
        height, width: output resolution in pixels; latents use 1/8 of each.
        num_inference_steps: number of denoising steps.
        guidance_scale: classifier-free guidance scale.
        seed: seed for the initial latent noise. NOTE: ``torch.manual_seed``
            also reseeds the global torch RNG, as the original code did.

    Returns:
        The decoded image tensor, rescaled from [-1, 1] and clamped to [0, 1].
    """
    batch_size = 1

    with torch.no_grad():
        generator = torch.manual_seed(seed)   # Seed generator to create the initial latent noise

        #Load Models
        vae = vae.to(device)
        text_encoder = text_encoder.to(device)
        unet = unet.to(device)

        # Conditional embeddings: the prompt is tokenised without padding.
        text_input = tokenizer(prompt_base, padding=False, max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
        text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

        # Unconditional embeddings are padded to the prompt's token count so
        # both halves can be concatenated along the batch dimension.
        max_length = text_input.input_ids.shape[-1]
        uncond_input = tokenizer(
            [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
        )
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]

        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        #Latent Space
        # Noise is drawn on the CPU generator in float16 and then moved to the device.
        latents = torch.randn(
            (batch_size, unet.config.in_channels, height // 8, width // 8),
            generator=generator,
            dtype=torch.float16
            )
        latents = latents.to(device)

        scheduler.set_timesteps(num_inference_steps)
        latents = latents * scheduler.init_noise_sigma

        # The former `i < 48` split only differed by a .detach(), which is a
        # no-op under no_grad, so every step uses the same code path.
        for t in scheduler.timesteps:
            # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
            latent_model_input = torch.cat([latents] * 2)

            latent_model_input = scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

            # perform guidance
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = scheduler.step(noise_pred, t, latents).prev_sample

        # scale and decode the image latents with vae
        latents = 1 / 0.18215 * latents

        image = vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
    return image

def image_pipeline_IPGO(prompt_base, vae, text_encoder, unet, tokenizer, scheduler, proj_xs, proj_xs_suffix,
                        *, height=512, width=512, num_inference_steps=50,
                        guidance_scale=7.5, seed=0):
    """Generate one image for ``prompt_base`` with learned prefix/suffix tokens.

    Same sampling procedure as ``image_pipeline_default`` except that
    ``proj_xs`` is prepended and ``proj_xs_suffix`` appended to the prompt
    embeddings along the token dimension. The whole function runs under
    ``torch.no_grad()``, so nothing here back-propagates into the tokens.

    Args:
        prompt_base: prompt(s) passed to ``tokenizer``; the pipeline is built
            for a single sample (``batch_size`` is fixed to 1).
        vae, text_encoder, unet, tokenizer, scheduler: Stable Diffusion
            components; the three models are moved to the global ``device``.
        proj_xs, proj_xs_suffix: prefix / suffix token embeddings — assumed
            shape (1, n_tokens, emb_dim) as returned by ``ProjX``; TODO confirm.
        height, width: output resolution in pixels; latents use 1/8 of each.
        num_inference_steps: number of denoising steps.
        guidance_scale: classifier-free guidance scale.
        seed: seed for the initial latent noise. NOTE: ``torch.manual_seed``
            also reseeds the global torch RNG, as the original code did.

    Returns:
        The decoded image tensor, rescaled from [-1, 1] and clamped to [0, 1].

    Raises:
        AssertionError: if prompt tokens plus prefix and suffix exceed 77 positions.
    """
    batch_size = 1

    with torch.no_grad():
        generator = torch.manual_seed(seed)   # Seed generator to create the initial latent noise

        #Load Models
        vae = vae.to(device)
        text_encoder = text_encoder.to(device)
        unet = unet.to(device)

        text_input = tokenizer(prompt_base, padding=False, max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
        text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

        # [prefix tokens | prompt tokens | suffix tokens] along the sequence axis.
        # (The penalty term the training loop computes here was unused in this
        # inference-only function and has been dropped.)
        # NOTE(review): the learned tokens may be float32 while the encoder
        # output is float16 — presumably callers run this under autocast; confirm.
        text_embeddings = torch.cat((proj_xs.repeat(batch_size,1,1),
                                        text_embeddings,
                                        proj_xs_suffix.repeat(batch_size,1,1)),dim=1)

        max_length = text_input.input_ids.shape[-1]+proj_xs_suffix.shape[1]+proj_xs.shape[1]
        assert max_length <= 77, "Token number should be <= 77"
        # Unconditional embeddings are padded to the extended sequence length.
        uncond_input = tokenizer(
            [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
        )
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]

        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        #Latent Space
        # Noise is drawn on the CPU generator in float16 and then moved to the device.
        latents = torch.randn(
            (batch_size, unet.config.in_channels, height // 8, width // 8),
            generator=generator,
            dtype=torch.float16
            )
        latents = latents.to(device)

        scheduler.set_timesteps(num_inference_steps)
        latents = latents * scheduler.init_noise_sigma

        # The former `i < 48` split only differed by a .detach(), which is a
        # no-op under no_grad, so every step uses the same code path.
        for t in scheduler.timesteps:
            # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
            latent_model_input = torch.cat([latents] * 2)

            latent_model_input = scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

            # perform guidance
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = scheduler.step(noise_pred, t, latents).prev_sample

        # scale and decode the image latents with vae
        latents = 1 / 0.18215 * latents

        image = vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
    return image

In [None]:
#Individual Training
# For every prompt, a fresh pair of prefix/suffix ProjX modules is optimised for
# `Epochs` steps to maximise the HPSv2 score of the generated image. Whenever the
# score improves, a checkpoint is written to INDIVIDUAL_TRAINING_PATH.
if not os.path.exists(INDIVIDUAL_TRAINING_PATH):
    os.makedirs(INDIVIDUAL_TRAINING_PATH)
seed = 0

for k, prompt in enumerate(prompt_bases_tot):
    print('Prompt #'+str(k))
    # flag turns False when prompt + learned tokens exceed 77 positions.
    flag = True
    # Reseed so every prompt starts from the same ProjX initialisation.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Clamp bound for the projected coordinates (passed as lb=-bound, ub=bound).
    bound = 1

    #prefix
    PX = ProjX(data_type, device,
            text_emb_dim=768,
            proj_dim=30,
            num_tokens=10,
            lb=-bound, ub=bound)
    #suffix
    PX_suffix = ProjX(data_type, device,
            text_emb_dim=768,
            proj_dim=30,
            num_tokens=10,
            lb=-bound, ub=bound)

    lr = 1e-3
    weight_decay = 0
    # Only the prefix/suffix token parameters are optimised.
    optimizer = torch.optim.Adam(list(PX.parameters()) + list(PX_suffix.parameters()), #list(lora_layers) + list(PX.parameters()) +
                                lr=lr,
                                weight_decay=weight_decay)
    # Multiply the learning rate by 0.9 every 10 optimizer steps.
    scheduler_LR = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)
    scaler = torch.cuda.amp.GradScaler()

    Epochs = 50
    # Every avg_record_steps epochs the collected rewards are summarised into l_avgs / l_stds.
    avg_record_steps = 1
    image_best = None
    l_avg = []
    l_avgs = []
    l_stds = []
    image_best = [None]#*len(prompt_bases)
    images_store = []
    # Best reward seen so far per inner prompt group; -100 acts as "nothing yet".
    y_best = [-100]#*len(prompt_bases)

    # Baseline image without learned tokens; stored and scored in each checkpoint.
    print("Generating a default image I0...")
    I0 = image_pipeline_default([prompt], vae, text_encoder, unet, tokenizer, scheduler)

    random.seed(seed)

    for epoch in range(Epochs):
        optimizer.zero_grad()

        with torch.autocast(device_type='cuda', dtype=torch.float16):

            #Some default params
            height = 512                        # default height of Stable Diffusion
            width = 512                         # default width of Stable Diffusion
            num_inference_steps = 50            # Number of denoising steps
            guidance_scale = 7.5                # Scale for classifier-free guidance
            generator = torch.manual_seed(seed)   # Seed generator to create the initial latent noise


            #Load Models
            vae = vae.to(device)
            text_encoder = text_encoder.to(device)
            unet = unet.to(device)

            #Text Embeddings
            # l_image accumulates the reward over the inner prompt groups.
            l_image = 0
            # A single group holding the current prompt, so batch_size is 1.
            prompt_bases = [[prompt]]
            batch_size = len(prompt_bases[0])
            image_init = []

            for count, prompt_base in enumerate(prompt_bases):
                text_input = tokenizer(prompt_base, padding=False, max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
                with torch.no_grad():
                    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
                    # Mean prompt embedding over tokens; reference point for sum_penalty.
                    s_o = torch.mean(text_embeddings,dim=1)

                # Current learned prefix / suffix token embeddings (carry gradients).
                proj_xs = PX()
                proj_xs_suffix = PX_suffix()
                L = text_input.input_ids.shape[-1]
                
                # [prefix tokens | prompt tokens | suffix tokens] along the sequence axis.
                text_embeddings = torch.cat((proj_xs.repeat(batch_size,1,1),
                                            text_embeddings,
                                            proj_xs_suffix.repeat(batch_size,1,1)),dim=1)   
                # Mean absolute shift of the token-averaged embedding caused by the added tokens.
                sum_penalty = torch.sum(torch.abs(torch.mean(text_embeddings,dim=1)-s_o))/text_embeddings.shape[-1]

                max_length = text_input.input_ids.shape[-1]+proj_xs_suffix.shape[1]+proj_xs.shape[1]
                # Skip prompts that do not fit into 77 positions once the learned tokens are added.
                if max_length > 77:
                    flag = False
                    continue
                uncond_input = tokenizer(
                    [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
                )
                with torch.no_grad():
                    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]

                text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

                #Latent Space
                latents = torch.randn(
                    (batch_size, unet.config.in_channels, height // 8, width // 8),
                    generator=generator,
                    dtype=inference_data_type #data_type
                    )
                latents = latents.to(device)

                scheduler.set_timesteps(num_inference_steps)
                latents = latents * scheduler.init_noise_sigma

                timesteps = scheduler.timesteps
                for i, t in tqdm(enumerate(timesteps), total=len(timesteps), disable=True):
                    # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
                    latent_model_input = torch.cat([latents] * 2)

                    latent_model_input = scheduler.scale_model_input(latent_model_input, t)

                    # predict the noise residual
                    # The first 48 steps are detached; gradients reach the learned
                    # tokens only through the UNet calls of the remaining steps.
                    if i < 48:
                        # with torch.no_grad():
                        noise_pred = (unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample).detach()
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    else:
                        noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

                    # perform guidance
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                    # compute the previous noisy sample x_t -> x_t-1
                    latents = scheduler.step(noise_pred, t, latents).prev_sample

                # scale and decode the image latents with vae
                latents = 1 / 0.18215 * latents

                image = vae.decode(latents).sample
                image = (image / 2 + 0.5).clamp(0, 1)
                if epoch == 0:
                    image_init.append(image.detach().to('cpu'))

                # Only the HPSv2 term is active; the aesthetic and CLIP terms are stubbed to 0.
                l_aes = 0 #loss_fn(image)
                # l_hps = 0
                # loss_fn_hps returns (loss, scores); the score (higher is better) is kept.
                _, l_hps = loss_fn_hps(image, prompt_base)
                l_clip = 0 #loss_fn_clip(image, prompt_base).squeeze()
                l_curr = l_aes + l_hps + l_clip
                l_curr = torch.mean(l_curr)
                l_image += l_curr
                # New best reward for this prompt: record it and write a checkpoint.
                if l_curr.item() > y_best[count]:
                    y_best[count] = l_curr.item()
                    print('Epoch: ', epoch, "Loss: ", l_image.item(), sum_penalty, prompt_base[0])
                    image_best[count] = image.detach().to('cpu')
                    images_store.append(image.detach().to('cpu'))
                    # NOTE(review): the baseline score is recomputed and the whole
                    # images_store list is re-saved on every improvement.
                    torch.save({
                        'epoch': epoch,
                        'prompt': prompt_base,
                        'initial_image': I0.detach().to('cpu'),
                        'initial_score': loss_fn_hps(I0, prompt_base)[1].squeeze(),
                        'prefix': proj_xs.detach().to('cpu'),
                        'suffix': proj_xs_suffix.detach().to('cpu'),
                        'loss_aes': l_aes,
                        'loss_hps': l_hps,
                        'loss_clip': l_clip,
                        'Image_Sequences': images_store,
                        'l_avgs': l_avgs,
                        'l_stds': l_stds,
                    }, INDIVIDUAL_TRAINING_PATH+'/prompt'+str(k)+'_checkpoint.pt')
            # Over-long prompt: skip the optimisation step for this epoch.
            # NOTE(review): flag is never reset inside the epoch loop, so every
            # remaining epoch of this prompt is skipped the same way.
            if flag == False:
                continue
            # Maximise the mean reward (hence the minus sign) with a small penalty term.
            loss = -l_image/len(prompt_bases) + 0.001*sum_penalty
            # NOTE(review): y_best[count] was already raised to l_curr in the inner
            # loop above, so this branch appears unreachable — confirm.
            if l_curr.item() > y_best[count]:
                y_best[count] = l_curr.item()

                print('Epoch: ', epoch, "Loss: ", l_image.item(), sum_penalty, prompt_base[0])


            l_avg.append(l_image.item())
            # Summarise the rewards collected since the last summary, then reset the buffer.
            if epoch%avg_record_steps == avg_record_steps-1:
                l_avg = np.array(l_avg)
                # print(np.mean(l_avg), np.std(l_avg))
                l_avgs.append(np.mean(l_avg))
                l_stds.append(np.std(l_avg))
                l_avg = []

        # Backward pass and optimizer step run outside the autocast context.
        scaler.scale(loss).backward()
        # Unscales the gradients of optimizer's assigned params in-place
        scaler.unscale_(optimizer)

        # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
        torch.nn.utils.clip_grad_norm_(list(PX.parameters()) + list(PX_suffix.parameters()), 1)
        scaler.step(optimizer)
        scaler.update()

        scheduler_LR.step()

In [None]:
# Batch training setup.
# Creates the checkpoint directory, seeds torch, and builds the two trainable
# ProjX modules plus the optimizer / LR scheduler / AMP grad scaler consumed by
# the batch training loop that follows.
BATCH_TRAINING_PATH = '/content/Batch_Training'
# exist_ok avoids the separate exists() check (and the race between check and create).
os.makedirs(BATCH_TRAINING_PATH, exist_ok=True)

seed = 0
Epochs = 20
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Number of prompts whose gradients are accumulated before one optimizer step.
accumulation_steps = 4

# Symmetric bound passed to ProjX as lb / ub.
bound = 1

# Both projection modules share exactly the same hyper-parameters.
proj_kwargs = dict(
    text_emb_dim=768,
    proj_dim=30,
    num_tokens=10,
    lb=-bound,
    ub=bound,
)
# PX()'s output is concatenated in front of the prompt embeddings in the
# training loop, PX_suffix()'s output behind them.
PX = ProjX(data_type, device, **proj_kwargs)
PX_suffix = ProjX(data_type, device, **proj_kwargs)

lr = 1e-3
weight_decay = 0
# Only the prefix / suffix projection parameters are optimised.
optimizer = torch.optim.Adam(
    list(PX.parameters()) + list(PX_suffix.parameters()),
    lr=lr,
    weight_decay=weight_decay,
)
# Multiply the learning rate by 0.95 every 15 scheduler steps.
scheduler_LR = torch.optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.95)
scaler = torch.cuda.amp.GradScaler()
# One reward entry per successfully generated image, across all epochs.
l_per_generation = []

# Batch training loop.
# For every epoch and every prompt in prompt_bases_tot: build the text
# conditioning as [PX() prefix | prompt embeddings | PX_suffix() suffix], run
# the diffusion sampling loop, decode the image, score it with loss_fn_hps and
# accumulate gradients of the (negated) score. An optimizer step is taken every
# `accumulation_steps` prompts and after the last prompt of the epoch.

# Some default params (loop-invariant, so set once).
height = 512                        # default height of Stable Diffusion
width = 512                         # default width of Stable Diffusion
num_inference_steps = 50            # Number of denoising steps
guidance_scale = 7.5                # Scale for classifier-free guidance
# Denoising steps with index below this are detached from the autograd graph;
# only the remaining final steps are differentiated (48 for 50 steps).
first_grad_step = num_inference_steps - 2

# Load Models (moving to the device is only needed once).
vae = vae.to(device)
text_encoder = text_encoder.to(device)
unet = unet.to(device)

# Use the length of the sequence that is actually iterated instead of relying
# on a `num_prompts` variable defined in some other cell.
total_prompts = len(prompt_bases_tot)
# Number of backward passes accumulated since the last optimizer step.
pending_backward = 0

for epoch in range(Epochs):
    # Index of the first reward recorded in this epoch (for the epoch statistic).
    epoch_start = len(l_per_generation)

    for k, prompt in enumerate(prompt_bases_tot):
        print('Prompt #'+str(k))
        flag = True  # becomes False when the prompt is too long and must be skipped

        random.seed(seed)

        with torch.autocast(device_type='cuda', dtype=torch.float16):

            # Re-seed so every prompt starts from the same initial latent noise.
            generator = torch.manual_seed(seed)

            #Text Embeddings
            l_image = 0
            prompt_bases = [[prompt]]
            batch_size = len(prompt_bases[0])
            image_init = []

            for count, prompt_base in enumerate(prompt_bases):
                text_input = tokenizer(prompt_base, padding=False, max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
                with torch.no_grad():
                    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
                    # Mean embedding of the unmodified prompt, used as the penalty reference.
                    s_o = torch.mean(text_embeddings,dim=1)

                proj_xs = PX()
                proj_xs_suffix = PX_suffix()
                L = text_input.input_ids.shape[-1]
                text_embeddings = torch.cat((proj_xs.repeat(batch_size,1,1),
                                            text_embeddings,
                                            proj_xs_suffix.repeat(batch_size,1,1)),dim=1)

                # Mean absolute drift of the averaged embedding away from the original prompt's.
                sum_penalty = torch.sum(torch.abs(torch.mean(text_embeddings,dim=1)-s_o))/text_embeddings.shape[-1]

                # Total sequence length: prefix + prompt tokens + suffix; skip prompts exceeding 77.
                max_length = L+proj_xs_suffix.shape[1]+proj_xs.shape[1]
                if max_length > 77:
                    flag = False
                    continue
                uncond_input = tokenizer(
                    [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
                )
                with torch.no_grad():
                    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]

                text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

                #Latent Space
                latents = torch.randn(
                    (batch_size, unet.config.in_channels, height // 8, width // 8),
                    generator=generator,
                    dtype=inference_data_type #data_type
                    )
                latents = latents.to(device)

                scheduler.set_timesteps(num_inference_steps)
                latents = latents * scheduler.init_noise_sigma

                timesteps = scheduler.timesteps
                for i, t in enumerate(timesteps):
                    # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
                    latent_model_input = torch.cat([latents] * 2)

                    latent_model_input = scheduler.scale_model_input(latent_model_input, t)

                    # predict the noise residual; early steps are cut out of the autograd graph
                    noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
                    if i < first_grad_step:
                        noise_pred = noise_pred.detach()
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

                    # perform guidance
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                    # compute the previous noisy sample x_t -> x_t-1
                    latents = scheduler.step(noise_pred, t, latents).prev_sample

                # scale and decode the image latents with vae
                latents = 1 / 0.18215 * latents

                image = vae.decode(latents).sample
                image = (image / 2 + 0.5).clamp(0, 1)
                if epoch == 0:
                    image_init.append(image.detach().to('cpu'))

                # Only the HPS score is active; the aesthetic and CLIP terms are
                # disabled and explicitly zeroed so no stale value from another
                # cell can leak into the objective or the checkpoint.
                l_aes = 0 #loss_fn(image)
                l_clip = 0 #loss_fn_clip(image, prompt_base).squeeze()
                _, l_hps = loss_fn_hps(image, prompt_base)
                l_curr = l_aes + l_hps + l_clip
                l_curr = torch.mean(l_curr)
                l_image += l_curr

            if flag:
                # Record the reward that is actually being optimised.
                l_per_generation.append(l_image.item())
                # Maximise the reward (hence the minus) with a small embedding-drift penalty.
                loss = -l_image/len(prompt_bases) + 0.001*sum_penalty
                loss = loss / accumulation_steps
                print('Epoch: ', epoch, "Loss: ", l_image.item(), sum_penalty, prompt_base[0])

        if flag:
            scaler.scale(loss).backward()
            pending_backward += 1

        # The boundary check runs even when this prompt was skipped, so that
        # gradients already accumulated from earlier prompts are not left unapplied.
        at_boundary = ((k + 1) % accumulation_steps == 0) or (k + 1 == total_prompts)
        if at_boundary and pending_backward > 0:
            torch.save({
                    'epoch': epoch,
                    'k+1': k+1,
                    'text_encoder': text_encoder.state_dict(),
                    'prefix_model_state_dict': PX.state_dict(),
                    'suffix_model_state_dict': PX_suffix.state_dict(),
                    'prefix': proj_xs.detach().to('cpu'),
                    'suffix': proj_xs_suffix.detach().to('cpu'),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss_aes': l_aes,
                    'loss_hps': l_hps,
                    'loss_clip': l_clip,
                    'l_per_generation': l_per_generation
                }, BATCH_TRAINING_PATH+'/checkpoint.pt')

            # Unscales the gradients of optimizer's assigned params in-place
            scaler.unscale_(optimizer)

            # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
            torch.nn.utils.clip_grad_norm_(list(PX.parameters()) + list(PX_suffix.parameters()), 1)
            scaler.step(optimizer)
            scaler.update()

            scheduler_LR.step()
            optimizer.zero_grad()
            pending_backward = 0

    # Mean reward over the prompts that were actually generated in this epoch.
    epoch_scores = l_per_generation[epoch_start:]
    print("Epoch Total Stat: ", np.mean(epoch_scores) if epoch_scores else float('nan'))
    print()