In [None]:
import argparse, os, sys, glob
from collections import defaultdict
from ast import parse
import cv2
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
import time
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext
import matplotlib.pyplot as plt

from structured_stable_diffusion.util import instantiate_from_config
from structured_stable_diffusion.models.diffusion.ddim import DDIMSampler
from structured_stable_diffusion.models.diffusion.plms import PLMSSampler


import stanza
from nltk.tree import Tree
# Stanza English pipeline with constituency parsing. `sampling` parses each
# prompt with it and passes the resulting tree to get_all_nps to extract noun
# phrases. Constructing it loads the models (and may check for resource updates).
nlp = stanza.Pipeline(lang='en', processors='tokenize,pos,constituency')
import pdb
import json
import sng_parser

# Global seed for the whole notebook, applied once via Lightning's seed_everything.
seed = 42
seed_everything(seed)


def preprocess_prompts(prompts):
    """Normalize one prompt or a list/tuple of prompts.

    Each prompt is lower-cased and has surrounding whitespace and periods
    removed.

    Args:
        prompts: a string, or a list/tuple of strings.

    Returns:
        The cleaned string, or a list of cleaned strings (also for tuple input).

    Raises:
        NotImplementedError: for any other input type.
    """
    def _clean(text):
        return text.lower().strip().strip(".").strip()

    if isinstance(prompts, str):
        return _clean(prompts)
    if isinstance(prompts, (list, tuple)):
        return list(map(_clean, prompts))
    raise NotImplementedError


def get_all_nps(tree, full_sent=None):
    """Collect every multi-word noun phrase (NP) of a constituency tree.

    Args:
        tree: constituency tree (built with nltk ``Tree.fromstring`` in this
            file); must provide ``leaves()``, ``label()`` and iteration over
            its children.
        full_sent: optional sentence string. If given and not already one of
            the NPs, it is prepended together with the span of the whole tree.

    Returns:
        (nps, spans, lowest_nps):
          nps: NP strings (leaves joined by spaces) in pre-order.
          spans: half-open (start, end) leaf-index spans aligned with ``nps``.
          lowest_nps: ``[text, span]`` pairs for the NPs whose span contains
            no other (different) NP span.
        Sentences without any multi-word NP yield empty lists (plus the
        prepended ``full_sent`` if given) instead of raising.
    """
    start = 0
    end = len(tree.leaves())

    def get_sub_nps(subtree, left, right):
        # Leaves and single-word constituents cannot hold a multi-word NP.
        if isinstance(subtree, str) or len(subtree.leaves()) == 1:
            return []
        n_leaves = len(subtree.leaves())
        child_sizes = [len(child.leaves()) for child in subtree]
        # Leaf offset of each child relative to `left`.
        offsets = np.cumsum([0] + child_sizes)[:len(child_sizes)]
        assert right - left == n_leaves
        sub_nps = []
        if subtree.label() == 'NP' and n_leaves > 1:
            sub_nps.append([" ".join(subtree.leaves()), (int(left), int(right))])
        for child, offset, size in zip(subtree, offsets, child_sizes):
            sub_nps += get_sub_nps(child, left=left + offset, right=left + offset + size)
        return sub_nps

    all_nps = get_sub_nps(tree, left=start, right=end)

    def _contains_other_np(span):
        # An NP is not "lowest" if some *other* NP span lies inside it. The
        # original check also matched the NP against itself, so no NP ever
        # qualified. Identical spans (unary NP chains) do not disqualify each other.
        return any(
            other != span and other[0] >= span[0] and other[1] <= span[1]
            for _, other in all_nps
        )

    lowest_nps = [item for item in all_nps if not _contains_other_np(item[1])]

    if all_nps:
        nps, spans = (list(column) for column in zip(*all_nps))
    else:
        # zip(*[]) yields nothing, so the unpacking above would raise.
        nps, spans = [], []

    if full_sent and full_sent not in nps:
        nps = [full_sent] + nps
        spans = [(start, end)] + spans

    return nps, spans, lowest_nps


def get_all_spans_from_scene_graph(caption):
    """Extract candidate phrases and spans from a scene-graph parse of `caption`.

    Phrases are, in order:
      * the whole (stripped) caption with span (0, number of whitespace words);
      * every scene-graph entity span that is longer than one unit and is not
        the caption itself;
      * for every relation, the words covering both subject and object
        entities (again skipping the whole caption).

    NOTE(review): entity ``span_bounds`` come from sng_parser and are used to
    slice ``caption.split()``; this assumes they are whitespace-word indices
    — TODO confirm.

    Returns:
        (phrases, spans, None). The trailing None mirrors the third value
        returned by get_all_nps.
    """
    caption = caption.strip()
    graph = sng_parser.parse(caption)
    words = caption.split()
    entities = graph['entities']

    phrases, bounds = [], []
    for entity in entities:
        lo, hi = entity['span_bounds']
        if entity['span'] == caption or hi - lo == 1:
            continue
        phrases.append(entity['span'])
        bounds.append(entity['span_bounds'])

    for relation in graph['relations']:
        subj_lo, subj_hi = entities[relation['subject']]['span_bounds']
        obj_lo, obj_hi = entities[relation['object']]['span_bounds']
        lo = min(subj_lo, obj_lo)
        hi = max(subj_hi, obj_hi)
        covered = " ".join(words[lo:hi])
        if covered == caption:
            continue
        phrases.append(covered)
        bounds.append((lo, hi))

    return [caption] + phrases, [(0, len(words))] + bounds, None



def expand_sequence(seq, length, dim=1):
    """Tile the first `length` content positions of `seq` along `dim`, in place.

    Along `dim`, position 0 is left alone. Positions 1..length are repeated as
    many whole times as fit into ``size - 2`` slots and written from position
    1 onward. The element originally at ``length + 1`` is then written right
    after the tiled run. Positions after that keep their old values.

    `seq` is modified in place (through a transposed view), and the same
    storage is returned with the original axis order.
    """
    view = seq.transpose(0, dim)
    reps = (view.size(0) - 2) // length
    trailing = [1] * (view.dim() - 1)

    end_token = view[length + 1, ...].clone()
    tiled = view[1:length + 1, ...].repeat(reps, *trailing)
    stop = tiled.size(0) + 1
    view[1:stop] = tiled
    view[stop] = end_token

    return view.transpose(0, dim)


def align_sequence(main_seq, seq, span, eos_loc, dim=1, zero_out=False, replace_pad=False):
    """Splice a phrase embedding into a sentence embedding at the phrase's span.

    Positions are indexed along `dim`; both sequences are offset by one
    leading slot (position 0 is skipped). Positions ``span[0]+1 .. span[1]``
    of `main_seq` receive ``seq[1 : 1+width]``, where ``width = span[1] - span[0]``.

    Args:
        main_seq: sentence sequence; modified in place (callers pass a clone).
        seq: phrase sequence supplying the replacement values.
        span: (start, end) word span of the phrase in the sentence.
        eos_loc: position in `main_seq` where zeroing stops and padding starts.
        zero_out: also zero ``main_seq[1:start]`` and ``main_seq[end:eos_loc]``.
        replace_pad: overwrite ``main_seq[eos_loc:]`` with the positions of
            `seq` that follow the copied phrase.

    Returns:
        `main_seq` (same storage) in its original axis order.
    """
    target = main_seq.transpose(0, dim)
    source = seq.transpose(0, dim)
    first = span[0] + 1
    last = span[1] + 1
    width = last - first

    target[first:last] = source[1:width + 1]

    if zero_out:
        target[1:first] = 0
        target[last:eos_loc] = 0

    if replace_pad:
        n_pad = len(target) - eos_loc
        target[eos_loc:] = source[width + 1:width + 1 + n_pad]

    return target.transpose(0, dim)


def get_actions(tree, SHIFT=0, REDUCE=1, OPEN='(', CLOSE=')'):
    """Convert a bracketed binary tree string into a shift/reduce action list.

    Example: ``"((A B) (C D))"`` -> ``[S, S, R, S, S, R, R]``.

    A SHIFT is emitted at the first character of every terminal token (a run
    of characters other than space/OPEN/CLOSE that starts the string or
    follows OPEN or a space). A REDUCE is emitted for every CLOSE.

    Args:
        tree: bracketed tree string; surrounding whitespace is ignored.
        SHIFT, REDUCE: values used for the two action kinds.
        OPEN, CLOSE: bracket characters.

    Returns:
        List of SHIFT/REDUCE values in reading order.

    Raises:
        AssertionError: if the string is not a well-formed binary tree
            (#shifts != #reduces + 1).
    """
    # The leftover `pdb.set_trace()` breakpoint has been removed: it halted
    # every call.
    actions = []
    tree = tree.strip()
    num_shift = 0
    num_reduce = 0
    for i, char in enumerate(tree):
        if char == CLOSE:
            actions.append(REDUCE)
            num_reduce += 1
        elif char != ' ' and char != OPEN:
            # i == 0 must count as a token start; tree[i - 1] would wrap
            # around to the last character.
            if i == 0 or tree[i - 1] == OPEN or tree[i - 1] == ' ':
                actions.append(SHIFT)
                num_shift += 1
    assert num_shift == num_reduce + 1
    return actions


def chunk(it, size):
    """Lazily split iterable `it` into consecutive tuples of up to `size` items.

    The last tuple may be shorter; iteration stops at the first empty tuple.
    """
    source = iter(it)

    def _batches():
        while True:
            batch = tuple(islice(source, size))
            if not batch:
                return
            yield batch

    return _batches()


def numpy_to_pil(images):
    """
    Convert one HWC float image or an NHWC batch of them into a list of PIL images.

    Values are scaled by 255, rounded and cast to uint8. They are expected in
    [0, 1]; out-of-range values are not clipped.
    """
    batch = images[None, ...] if images.ndim == 3 else images
    pixels = (batch * 255).round().astype("uint8")
    return list(map(Image.fromarray, pixels))


def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate ``config.model`` and load weights from checkpoint file `ckpt`.

    Args:
        config: OmegaConf config whose ``model`` section is understood by
            instantiate_from_config.
        ckpt: path to a checkpoint readable by torch.load. It holds either a
            dict with a "state_dict" entry (optionally "global_step") or a
            bare state dict.
        verbose: print missing / unexpected state-dict keys.

    Returns:
        The model in eval mode, on CUDA when available and on CPU otherwise.
    """
    print(f"Loading model from {ckpt}")
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    # Wrapped checkpoints keep the weights under "state_dict"; also accept a bare state dict.
    sd = pl_sd.get("state_dict", pl_sd)
    model = instantiate_from_config(config.model)
    # strict=False: report key mismatches instead of failing on them.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if len(missing) > 0 and verbose:
        print("missing keys:")
        print(missing)
    if len(unexpected) > 0 and verbose:
        print("unexpected keys:")
        print(unexpected)

    # An unconditional .cuda() crashed CPU-only hosts, even though the notebook
    # selects a CPU fallback device when it moves the model afterwards.
    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    return model

def load_replacement(x):
    """Return the placeholder image "assets/rick.jpeg" resized to x's H×W.

    The placeholder is scaled to [0, 1] and cast to ``x.dtype``. This is
    deliberately best-effort: on any failure (missing file, shape mismatch,
    unexpected input) `x` is returned unchanged.
    """
    try:
        shape = x.shape
        placeholder = Image.open("assets/rick.jpeg").convert("RGB")
        placeholder = placeholder.resize((shape[1], shape[0]))
        replacement = (np.array(placeholder) / 255.0).astype(x.dtype)
        assert replacement.shape == x.shape
    except Exception:
        return x
    return replacement

# generating images from pre-trained model
def _structured_conditioning(model, prompts, conjunction=False):
    """Build the Structured-Diffusion key/value conditioning for a prompt group.

    Only ``prompts[0]`` is parsed. Its noun phrases (from the module-level
    stanza pipeline and get_all_nps) are each encoded separately and spliced
    into the whole-prompt embedding at their span with align_sequence.

    Returns:
        ``{'k': ..., 'v': ...}`` where, with A = [full] + aligned phrase
        embeddings: without `conjunction`, k = [full] and v = A; with
        `conjunction`, k = A and v = [A[-1]].
    """
    doc = nlp(prompts[0])
    tree = Tree.fromstring(str(doc.sentences[0].constituency))
    nps, spans, _ = get_all_nps(tree, prompts[0])

    # Token count per phrase minus 2 (presumably the BOS/EOS tokens the
    # tokenizer adds). nps[0] is the whole prompt, so its EOS sits at len + 1.
    nps_length = [len(ids) - 2 for ids in model.cond_stage_model.tokenizer(nps).input_ids]
    eos_loc = nps_length[0] + 1

    # Encode each phrase once, repeated to match the group's batch size.
    embeddings = [model.get_learned_conditioning([phrase] * len(prompts)) for phrase in nps]
    full = embeddings[0]
    aligned = [full] + [
        align_sequence(full.clone(), seq, span, eos_loc)
        for seq, span in zip(embeddings[1:], spans[1:])
    ]

    if not conjunction:
        return {'k': aligned[:1], 'v': aligned}
    return {'k': aligned, 'v': aligned[-1:]}


# generating images from pre-trained model
def sampling(model, sampler, prompt, n_samples, scale=7.5, steps=50, conjunction=False):
    """Generate images with Structured Diffusion Guidance conditioning.

    Args:
        model: latent diffusion model (provides get_learned_conditioning,
            decode_first_stage, ema_scope and cond_stage_model.tokenizer).
        sampler: sampler whose ``sample`` accepts the dict conditioning built
            by _structured_conditioning (a PLMSSampler in this notebook).
        prompt: iterable of prompt groups. Each group is a list/tuple of
            prompt strings or a single string (treated as a one-item group).
            Only the first prompt of a group is parsed for noun phrases.
        n_samples: number of passes over `prompt`.
        scale: classifier-free guidance scale.
        steps: number of sampler steps.
        conjunction: swap key/value roles (see _structured_conditioning).

    Returns:
        List of ``n_samples * len(prompt)`` CPU tensors in sample-major order.
        Each is a decoded 4-D (batch, C, H, W) tensor clamped to [0, 1], with
        batch 1 since the sampler is called with batch_size=1.
    """
    H = W = 512
    C = 4
    f = 8  # latent downsampling factor
    shape = [C, H // f, W // f]
    all_samples = []
    with torch.no_grad(), autocast("cuda"), model.ema_scope():
        # The unconditional (empty-prompt) embedding does not depend on the loop.
        uc = model.get_learned_conditioning([""])
        for _ in trange(n_samples, desc="Sampling"):
            for group in prompt:
                p = preprocess_prompts(group)
                # A bare string would otherwise be indexed per character (p[0], len(p)).
                p = [p] if isinstance(p, str) else list(p)
                c = _structured_conditioning(model, p, conjunction)

                samples_ddim, _ = sampler.sample(S=steps,
                                                 conditioning=c,
                                                 batch_size=1,
                                                 shape=shape,
                                                 verbose=False,
                                                 unconditional_guidance_scale=scale,
                                                 unconditional_conditioning=uc,
                                                 eta=0.0,
                                                 x_T=None,
                                                 quiet=True)

                x_samples = model.decode_first_stage(samples_ddim)
                x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
                all_samples.append(x_samples.cpu())
    return all_samples


2022-11-15 12:11:06 INFO: Checking for updates to resources.json in case models have been updated.  Note: this behavior can be turned off with download_method=None or download_method=DownloadMethod.REUSE_RESOURCES
Downloading https://raw.githubusercontent.com/stanfordnlp/stanza-resources/main/resources_1.4.1.json: 193kB [00:00, 79.0MB/s]                    
2022-11-15 12:11:07 INFO: Loading these models for language: en (English):
| Processor    | Package  |
---------------------------
| tokenize     | combined |
| pos          | combined |
| constituency | wsj      |

2022-11-15 12:11:07 INFO: Use device: gpu
2022-11-15 12:11:07 INFO: Loading: tokenize
2022-11-15 12:11:07 INFO: Loading: pos
2022-11-15 12:11:07 INFO: Loading: constituency
2022-11-15 12:11:08 INFO: Done loading processors!


In [None]:
import os
import sys
import torch
import torchvision.transforms as TS

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The notebook is run from a subdirectory of the repo; the repo root and the
# sibling checkouts (GroundingDINO, Tag2Text, fastchat) live one level up.
current_dir = os.getcwd()

parent_dir = os.path.dirname(current_dir)

sys.path.insert(0, parent_dir)

grounding_dino_path = os.path.join(parent_dir, 'GroundingDINO')
tag2text_path = os.path.join(parent_dir, 'Tag2Text')
fastchat_path = os.path.join(parent_dir, 'fastchat')

# Inserting at index 0 gives these checkouts precedence over installed
# packages of the same name; the last insert (fastchat) ends up first.
sys.path.insert(0, grounding_dino_path)
sys.path.insert(0, tag2text_path)
sys.path.insert(0, fastchat_path)
print(f"Updated sys.path: {sys.path}")

import GroundingDINO.groundingdino.datasets.transforms as T
from GroundingDINO.groundingdino.models import build_model
from GroundingDINO.groundingdino.util import box_ops
from GroundingDINO.groundingdino.util.slconfig import SLConfig
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap


# from segment_anything import build_sam, SamPredictor

from Tag2Text.models import tag2text
from Tag2Text import inference

# Model files expected under <parent_dir>/weight/. Only tag2text_checkpoint is
# used to build a model in the visible notebook. The GroundingDINO and SAM
# paths are checked but otherwise unused (the segment_anything import above is
# commented out), yet the asserts still require those files to exist.
config_file = os.path.join(parent_dir, "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py")
tag2text_checkpoint = os.path.join(parent_dir, 'weight/tag2text_swin_14m.pth')
grounded_checkpoint = os.path.join(parent_dir, 'weight/groundingdino_swint_ogc.pth')
sam_checkpoint = os.path.join(parent_dir, 'weight/sam_vit_h_4b8939.pth')


# NOTE: asserts are stripped under `python -O`; they are only a convenience check here.
assert os.path.isfile(tag2text_checkpoint), f"Checkpoint file {tag2text_checkpoint} not found."
assert os.path.isfile(grounded_checkpoint), f"Checkpoint file {grounded_checkpoint} not found."
assert os.path.isfile(sam_checkpoint), f"Checkpoint file {sam_checkpoint} not found."

# Tag2Text input pipeline: resize to 384x384, convert to tensor, and normalize
# with the standard ImageNet mean/std values. Used by get_tags.
normalize = TS.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
transform = TS.Compose([
    TS.Resize((384, 384)),
    TS.ToTensor(), normalize
])

# NOTE(review): tag indices 3012-3428 are passed as delete_tag_index; their
# meaning is not visible here — TODO document which tags these are.
delete_tag_index = [i for i in range(3012, 3429)]
tag2text_model = tag2text.tag2text_caption(pretrained=tag2text_checkpoint,
                                           image_size=384,
                                           vit='swin_b',
                                           delete_tag_index=delete_tag_index)
# Presumably the tag-confidence cutoff — confirm against Tag2Text.
tag2text_model.threshold = 0.64
tag2text_model.eval()
tag2text_model = tag2text_model.to(device)

# Passed to inference.inference in get_tags. Note this is the *string* 'None',
# not the None object — presumably what Tag2Text's inference expects; confirm.
specified_tags = 'None'

from fastchat.serve.inference import ChatIO, generate_stream, load_model
from fastchat.conversation import get_default_conv_template

class SimpleChatIO(ChatIO):
    """Non-interactive ChatIO used to ask Vicuna for tag-plausibility scores."""

    def prompt_for_input(self, role, prompt, tags) -> str:
        """Return the scoring instruction for `prompt` and its comma-separated `tags`.

        `role` is accepted for ChatIO compatibility and is not used.
        """
        return f"Please assess the described scene based on the provided prompt and determine the likelihood of each tag appearing in the scene. Assign a score to each tag according to the following criteria:  If a tag is certain to appear, assign a score of 3. If a tag may appear, assign a score of 2. If a tag is unlikely to appear, assign a score of 1.\nPrompt: {prompt}; Tags: {tags}"

    def prompt_for_output(self, role: str):
        """Print the speaker label without a newline (not called in this notebook)."""
        print(f"{role}: ", end="", flush=True)

    def stream_output(self, output_stream):
        """Collapse a stream of generated-text snapshots into the final reply.

        Each item of `output_stream` is taken to be the full text generated so
        far (a string), as the word indexing below requires. The last word of
        each snapshot is held back because it may still be incomplete.

        Returns:
            The reply with words joined by single spaces; "" for an empty stream.
        """
        pre = 0
        words = []
        pieces = []
        for outputs in output_stream:
            words = outputs.strip().split(" ")
            now = len(words) - 1
            if now > pre:
                # Collect words and join once at the end. Concatenating the
                # joined chunks directly dropped the space between chunks.
                pieces.extend(words[pre:now])
                pre = now
        pieces.extend(words[pre:])
        return " ".join(pieces)

# Chat I/O adapter that formats the tag-scoring prompt (see SimpleChatIO).
chatio = SimpleChatIO()
vicuna_path = "lmsys/vicuna-7b-v1.5"
# NOTE(review): the positional arguments follow the installed FastChat
# version's load_model signature (device "cuda", 1 GPU, ...) — confirm.
# The device is hard-coded to "cuda" regardless of `device` above.
vicuna, tokenizer = load_model(vicuna_path, "cuda", 1, None, False, False, False)
# Conversation template that get_sam_score builds its prompts from.
conv = get_default_conv_template(vicuna_path)


In [None]:
import os
import sys
import torch
import torchvision.transforms as TS
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
from transformers import Blip2Processor, Blip2ForConditionalGeneration
import clip

current_dir = os.getcwd()

parent_dir = os.path.dirname(current_dir)

# Prompt list loaded with allow_pickle=True (presumably a NumPy object array).
data_path = os.path.join(parent_dir, 'data/coco_train.npy')

assert os.path.isfile(data_path), f"Data file {data_path} not found."

# load dataset
# allow_pickle=True unpickles the file: only load trusted data.
coco_train = np.load(data_path, allow_pickle=True).tolist()
# Number of prompts; the training cell also reuses this global as `topk`.
length = len(coco_train)
# One row per prompt. Later cells add per-epoch 'Epoch{n} Scores' columns.
ftprompts = pd.DataFrame({'prompts': coco_train[:length]})

class MSCOCODataset:
    """Map-style dataset returning one whitespace-stripped prompt string per item.

    Args:
        prompts: optional prompt source. A DataFrame (its first column is
            used), a Series, or any iterable of values. Defaults to the
            module-level ``ftprompts`` DataFrame.
    """

    def __init__(self, prompts=None):
        if prompts is None:
            # Read the module global only when no source is given. No `global`
            # statement is needed just to read it.
            prompts = ftprompts
        if isinstance(prompts, pd.DataFrame):
            prompts = prompts.iloc[:, 0]
        elif not isinstance(prompts, pd.Series):
            prompts = pd.Series(list(prompts))
        # Attribute name kept as-is for compatibility with existing code.
        self.ftprompts = prompts

    def __len__(self):
        return len(self.ftprompts)

    def __getitem__(self, index):
        """Return prompt `index` as a stripped string (non-strings are str()-converted)."""
        prompt = self.ftprompts.iloc[index]
        if not isinstance(prompt, str):
            prompt = str(prompt)
        return prompt.strip()


batch_size = 8
dataset = MSCOCODataset()
# shuffle=False matters: the training loop derives ftprompts row indices from
# the batch position (step * batch_size + ...).
finetune_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

# load CLIP model
# CLIP text encoder, used to compare prompts with generated captions.
clip_model, preprocess = clip.load('ViT-B/32', device=device)
# params = torch.load("../hpc.pt")['state_dict']
# clip_model.load_state_dict(params)

# load BLIP-2 model
# BLIP-2 captioner. Note the model is bound to the name `caption`.
# NOTE(review): torch_dtype is probably ignored by the processor — confirm.
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)
caption = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)
caption.to(device)


In [None]:
def get_sam_score(prompt, tags, length):
    """Ask Vicuna how plausible each detected tag is for `prompt`, normalized.

    Args:
        prompt: text prompt the image was generated from.
        tags: comma-separated tag string (as produced by get_tags).
        length: number of tags in `tags`; must be non-zero.

    Returns:
        ``(S - 2*length) / (2*length)``, where S is the sum of every decimal
        digit in the model's reply. If the reply contains exactly one 1-3
        score per tag, the result lies in [-0.5, 0.5].
        NOTE(review): the parsing is heuristic. Any other digits in the reply
        (numbering, explanations) are summed too. Generation samples at
        temperature 0.7, so the score is not deterministic.
    """
    # Work on a copy of the shared template. Appending to `conv` itself made
    # every call re-send all earlier queries (each with an empty assistant
    # turn), so the prompt grew without bound and mixed unrelated tags.
    dialog = conv.copy()
    user_msg = chatio.prompt_for_input(dialog.roles[0], prompt, tags)
    dialog.append_message(dialog.roles[0], user_msg)
    dialog.append_message(dialog.roles[1], None)

    gen_params = {
        "model": vicuna_path,
        "prompt": dialog.get_prompt(),
        "temperature": 0.7,
        "max_new_tokens": 512,
        "stop": dialog.stop_str,
        "stop_token_ids": dialog.stop_token_ids,
        "echo": False,
    }

    output_stream = generate_stream(vicuna, tokenizer, gen_params, device)
    reply = chatio.stream_output(output_stream)
    # isdecimal rather than isdigit: characters such as '²' pass isdigit but
    # make int() raise.
    sam_score = sum(int(ch) for ch in reply if ch.isdecimal())
    return (sam_score - 2 * length) / (2 * length)


def get_tags(image_pil):
    """Tag a PIL image with Tag2Text.

    Returns:
        (tags, count): the tags as a comma-separated string (Tag2Text's ' |'
        separators replaced by ','), and the number of comma-separated entries.
    """
    batch = transform(image_pil.resize((384, 384))).unsqueeze(0).to(device)
    result = inference.inference(batch, tag2text_model, specified_tags)
    tags = result[0].replace(' |', ',')
    return tags, len(tags.split(','))


def get_image_score(image, prompt):
    """Combined reward for one generated image: caption reward + tag reward.

    Delegates to get_cap_reward and get_sam_reward so each reward is defined
    in one place. The inlined copy had drifted: it cast only the prompt
    features to float16.

    Args:
        image: PIL image (get_tags calls ``image.resize``).
        prompt: text prompt the image was generated from.

    Returns:
        float reward.
    """
    with torch.no_grad():
        cap_score = get_cap_reward(image, prompt)
        sam_score = get_sam_reward(image, prompt)
        return cap_score + sam_score


def _normalized_clip_text(text):
    """Encode `text` with the CLIP text encoder and L2-normalize it, in float32."""
    tokens = clip.tokenize([text]).to(device)
    features = clip_model.encode_text(tokens).float()
    return features / features.norm(dim=-1, keepdim=True)


def get_cap_reward(image, prompt):
    """Caption-alignment reward for one generated image.

    Captions `image` with BLIP-2 and returns the cosine similarity between
    the CLIP text embeddings of `prompt` and of that caption.

    Returns:
        float in [-1, 1].
    """
    with torch.no_grad():
        # Both embeddings are built in float32. Casting only the prompt side
        # to float16 mixed dtypes whenever CLIP runs in float32.
        text_features = _normalized_clip_text(prompt)

        inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
        generated_ids = caption.generate(**inputs)
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()

        caption_features = _normalized_clip_text(generated_text)

        cos_sim = torch.cosine_similarity(text_features, caption_features, dim=1)
        return float(cos_sim)


def get_sam_reward(image, prompt):
    """Tag-plausibility reward: get_sam_score over the image's Tag2Text tags, as a float."""
    with torch.no_grad():
        tag_text, tag_count = get_tags(image)
        return float(get_sam_score(prompt, tag_text, tag_count))


def _rank_positions(scores):
    """Return r where r[k] is item k's 0-based position in a stable descending sort of `scores`."""
    order = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
    ranks = [0] * len(scores)
    for position, item in enumerate(order):
        ranks[item] = position
    return ranks


def get_max_score(prompt_list, image_list, index, epoch=0, ARL=True):
    """Pick the best generated image and record its score in ``ftprompts``.

    ARL=True: ``image_list[i]`` holds the images generated for
    ``prompt_list[i]``.
      * Every (prompt, image) pair is scored with get_cap_reward and
        get_sam_reward.
      * Pairs are ranked separately by each reward (best = rank 0).
      * The pair with the smallest rank sum wins; ties go to the earliest pair.
    Returns ``[cap + sam of the winner, prompt index, image index]``.

    ARL=False: ``image_list[i]`` is scored as a single image against
    ``prompt_list[i]`` with get_image_score. Returns ``[max score, index]``
    (two elements).

    Side effect: writes the winning score to
    ``ftprompts.loc[index, f'Epoch{epoch} Scores']``.

    Raises:
        ValueError: ARL=True and there is no image to score.
    """
    if ARL:
        print(f"Length of image_list: {len(image_list)}")
        print(f"Length of prompt_list: {len(prompt_list)}")

        cap_scores, sam_scores, locations = [], [], []
        for prompt_idx, (prompt, img_set) in enumerate(zip(prompt_list, image_list)):
            print(f"Processing image set {prompt_idx}, length of image set {len(img_set)}")
            for img_idx, image in enumerate(img_set):
                cap_scores.append(get_cap_reward(image, prompt))
                sam_scores.append(get_sam_reward(image, prompt))
                locations.append((prompt_idx, img_idx))

        if not locations:
            raise ValueError("get_max_score: no images to score")

        # Rank lookups via dicts/lists instead of list.index (was O(n^2)).
        cap_rank = _rank_positions(cap_scores)
        sam_rank = _rank_positions(sam_scores)
        best = min(range(len(locations)), key=lambda k: cap_rank[k] + sam_rank[k])
        best_prompt_idx, best_img_idx = locations[best]

        # Score of the *winning* pair. Previously this indexed the rank lists
        # with the prompt index, which returned unrelated scores.
        best_score = cap_scores[best] + sam_scores[best]

        ftprompts.loc[index, f'Epoch{epoch} Scores'] = best_score

        return [best_score, best_prompt_idx, best_img_idx]

    score_list = [get_image_score(image_list[i], prompt) for i, prompt in enumerate(prompt_list)]
    torch.cuda.empty_cache()
    best_score = max(score_list)
    ftprompts.loc[index, f'Epoch{epoch} Scores'] = best_score
    return [best_score, score_list.index(best_score)]


In [None]:
#@title Settings for the model

#@markdown All settings have been configured to achieve optimal output. Changing them is not advisable.

#@markdown Enter value for `resolution`.
# Passed to train_lora.py via the RESOLUTION env var; `sampling` itself always generates 512x512.
resolution=512 #@param {type:"integer"}

#@markdown Enter value for `num_images_per_prompt`.
# Number of candidate images generated per prompt before scoring.
num_images_per_prompt=10 #@param {type:"integer"}

#@markdown Enter value for `epochs`.
# The training cell runs epochs + 1 generate/score rounds and trains after all but the last.
epochs=2 #@param {type:"integer"}

#@markdown Enter value for `seed`.
# The seed itself is set in the first cell (`seed = 42`); this generator line is disabled.
#generator = torch.Generator(device=device).manual_seed(seed)

In [7]:
# Stable Diffusion v1 checkpoint and inference config, relative to the notebook directory.
ckpt = "../models/ldm/stable-diffusion-v1/model.ckpt"

config = OmegaConf.load("../configs/stable-diffusion/v1-inference.yaml")
model = load_model_from_config(config, f"{ckpt}")

# Same device choice as the earlier setup cell (redefined here).
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)

# PLMS sampler; the training cell passes it to `sampling` for all image generation.
sampler = PLMSSampler(model)

Loading model from ../models/ldm/stable-diffusion-v1/sd-v1-4.ckpt
Global Step: 470000
LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 859.52 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels


Some weights of the model checkpoint at openai/clip-vit-large-patch14 were not used when initializing CLIPTextModel: ['vision_model.encoder.layers.17.layer_norm2.bias', 'vision_model.encoder.layers.3.self_attn.out_proj.bias', 'vision_model.encoder.layers.14.layer_norm1.bias', 'vision_model.encoder.layers.0.self_attn.v_proj.bias', 'vision_model.encoder.layers.9.self_attn.k_proj.bias', 'vision_model.encoder.layers.23.self_attn.out_proj.bias', 'vision_model.encoder.layers.7.self_attn.out_proj.bias', 'vision_model.encoder.layers.2.layer_norm1.bias', 'vision_model.encoder.layers.13.mlp.fc2.bias', 'vision_model.encoder.layers.16.self_attn.k_proj.bias', 'vision_model.encoder.layers.20.layer_norm2.weight', 'vision_model.encoder.layers.21.self_attn.v_proj.bias', 'vision_model.encoder.layers.10.self_attn.out_proj.bias', 'vision_model.encoder.layers.21.self_attn.v_proj.weight', 'vision_model.encoder.layers.16.layer_norm1.bias', 'vision_model.encoder.layers.20.self_attn.k_proj.weight', 'vision_mod

In [None]:
import shutil
import random
import pandas as pd
import numpy as np
import torch
import torch.utils.checkpoint
import concurrent
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from PIL import Image

#os.environ['MODEL_NAME'] = model_id
os.environ['OUTPUT_DIR'] = f"./CustomModel/"
os.environ['TOKENIZERS_PARALLELISM'] = "false"
topk = length
training_steps_per_epoch = topk * 10
os.environ['CHECKPOINTING_STEPS'] = str(training_steps_per_epoch)
os.environ['RESOLUTION'] = str(resolution)
os.environ['LEARNING_RATE'] = str(9e-6)

try:
    shutil.rmtree('./CustomModel')
except:
    pass
try:
    shutil.rmtree('./trainingdataset/imagefolder/')
except:
    pass

total = 0

for epoch in range(epochs + 1):
    print("Epoch: ", epoch)
    # training step
    training_steps = str(training_steps_per_epoch * (epoch + 1))
    os.environ['TRAINING_STEPS'] = training_steps
    os.environ['TRAINING_DIR'] = f'./trainingdataset/imagefolder/{epoch}'

    training_prompts = []

    ftprompts[f'Epoch{epoch} Scores'] = np.nan

    for step, prompt_list in enumerate(finetune_dataloader):
        print(prompt_list)

        image_list = []
        for prompt in prompt_list:
            print(prompt)
            all_samples = sampling(model, sampler, [[prompt]], num_images_per_prompt, scale=7.5, steps=50)


            images = []

            for sample in all_samples:
                for img in sample:
                    img = img * 255.0
                    img = img.permute(1, 2, 0).cpu().numpy()
                    img = Image.fromarray(img.astype(np.uint8))

                    images.append(img)


            image_list.append(images)
        torch.cuda.empty_cache()



        step_list = [i for i in range(step * batch_size, (step + 1) * batch_size)]
        for idx in step_list:
            best_score, best_prompt_idx, best_img_idx = get_max_score(prompt_list, image_list, idx, epoch)
            print(f"best_score: {best_score}, best_prompt_idx: {best_prompt_idx}, best_img_idx: {best_img_idx}")

            best_image = image_list[best_prompt_idx][best_img_idx]
            best_prompt = prompt_list[best_prompt_idx]

            training_prompts.append([best_score, best_image, best_prompt])



    training_prompts = [row[1:3] for row in sorted(training_prompts, key=lambda x: x[0], reverse=True)[:topk]]
    training_prompts = pd.DataFrame(training_prompts)

    if not os.path.exists(f"./trainingdataset/imagefolder/{epoch}/train/"):
        os.makedirs(f"./trainingdataset/imagefolder/{epoch}/train/")
    if not os.path.exists(f"./CustomModel/"):
        os.makedirs(f"./CustomModel/")

    for i in range(len(training_prompts)):
        training_prompts.iloc[i, 0].save(f'./trainingdataset/imagefolder/{epoch}/train/{i}.png')

    training_prompts['file_name'] = [f"{i}.png" for i in range(len(training_prompts))]
    training_prompts.columns = ['image', 'text', 'file_name']
    training_prompts.drop('image', axis=1, inplace=True)
    training_prompts.to_csv(f'./trainingdataset/imagefolder/{epoch}/train/metadata.csv', index=False)

    # start training
    if epoch < epochs:
        !accelerate launch --num_processes=1 --mixed_precision='fp16' --dynamo_backend='no' --num_machines=1 train_lora.py \
            --pretrained_model_name_or_path=$MODEL_NAME \
            --train_data_dir=$TRAINING_DIR \
            --resolution=$RESOLUTION \
            --train_batch_size=8 \
            --gradient_accumulation_steps=1 \
            --gradient_checkpointing \
            --max_grad_norm=1 \
            --mixed_precision="fp16" \
            --max_train_steps=$TRAINING_STEPS \
            --learning_rate=$LEARNING_RATE \
            --lr_warmup_steps=0 \
            --enable_xformers_memory_efficient_attention \
            --dataloader_num_workers=5 \
            --output_dir=$OUTPUT_DIR \
            --lr_warmup_steps=0 \
            --seed=1234 \
            --checkpointing_steps=$CHECKPOINTING_STEPS \
            --resume_from_checkpoint="latest" \
            --lr_scheduler='constant'

    torch.cuda.empty_cache()
