## Initial Setup

In [None]:
# #@title Install the required libs
# %pip install -U -qq git+https://github.com/huggingface/diffusers.git
# %pip install -q accelerate transformers ftfy fairscale bitsandbytes gradio natsort safetensors xformers datasets pytorch_lightning timm 
# %pip install -qq "ipywidgets>=7,<8"
# !pip install git+https://github.com/openai/CLIP.git

In [None]:
#@title Import required libraries
import shutil
import random
import pandas as pd
import numpy as np
import torch
import torch.utils.checkpoint
import concurrent
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from PIL import Image
import torch.nn as nn
#import bitsandbytes as bnb
from torch.utils.data import DataLoader
def image_grid(imgs, rows, cols):
    """Tile PIL images into a ``rows`` x ``cols`` RGBA contact sheet.

    Images are laid out row by row in the order given. Neighbouring tiles are
    separated by a one-pixel fully transparent gutter.

    Args:
        imgs: Sequence of PIL images. The size of the first image is used as the
            tile size for every cell.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGBA ``PIL.Image`` of size
        ``(cols * w + cols - 1, rows * h + rows - 1)``.

    Raises:
        AssertionError: If ``len(imgs) != rows * cols``.
    """
    assert len(imgs) == rows * cols

    w, h = imgs[0].size
    grid_w = cols * w + (cols - 1)  # total width including the 1px gutters
    grid_h = rows * h + (rows - 1)  # total height including the 1px gutters

    # The canvas starts fully transparent, so the gutters between tiles need no
    # separate fill pass.
    grid = Image.new('RGBA', size=(grid_w, grid_h), color=(0, 0, 0, 0))

    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        # Block-copy the whole tile instead of moving it pixel by pixel in Python.
        grid.paste(img, box=(col * (w + 1), row * (h + 1)))

    return grid
# Seed every random number generator the notebook relies on (Python, NumPy,
# PyTorch CPU and all CUDA devices) with the same value.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Ask cuDNN for deterministic kernels and switch off its autotuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
import tomesd
import os 
# Device handle that the later cells pass to the model loaders and `.to(...)` calls.
# Hard-coded to CUDA: there is no CPU fallback in this notebook.
device = torch.device("cuda")

In [None]:
import os
import numpy as np
# Grounding DINO
import sys
sys.path.append('GroundingDINO')
import GroundingDINO.groundingdino.datasets.transforms as T
from GroundingDINO.groundingdino.models import build_model
from GroundingDINO.groundingdino.util import box_ops
from GroundingDINO.groundingdino.util.slconfig import SLConfig
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap

# segment anything
from segment_anything import build_sam, SamPredictor 
import cv2
import numpy as np
import matplotlib.pyplot as plt

# Tag2Text
sys.path.append('Tag2Text')
from Tag2Text.models import tag2text
from Tag2Text import inference
import torchvision.transforms as TS

    

# cfg
# Paths and thresholds for the grounding / tagging stack.
# NOTE(review): in the visible part of this notebook only `tag2text_checkpoint`
# is actually read (by the Tag2Text loader below); the remaining names are not
# referenced by any visible cell.
config_file = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"  # change the path of the model config file
tag2text_checkpoint = '../tag2text_swin_14m.pth'  # change the path of the model
grounded_checkpoint = '../groundingdino_swint_ogc.pth'  # change the path of the model
sam_checkpoint = '../sam_vit_h_4b8939.pth'
split = ","
output_dir = "outputs"
box_threshold = 0.25
text_threshold = 0.2
iou_threshold = 0.5


# initialize Tag2Text
# Tag2Text preprocessing: resize to 384x384, convert to a tensor, then apply
# per-channel mean/std normalisation.
normalize = TS.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transform = TS.Compose([TS.Resize((384, 384)), TS.ToTensor(), normalize])

# filter out attributes and action categories which are difficult to grounding
delete_tag_index = list(range(3012, 3429))

# Passed unchanged to `inference.inference`; note this is the string 'None',
# not the `None` object.
specified_tags='None'

# load model
tag2text_model = tag2text.tag2text_caption(
    pretrained=tag2text_checkpoint,
    image_size=384,
    vit='swin_b',
    delete_tag_index=delete_tag_index,
)
# threshold for tagging
# we reduce the threshold to obtain more tags
tag2text_model.threshold = 0.64
tag2text_model.eval()
tag2text_model = tag2text_model.to(device)



In [None]:
import sys
sys.path.append('fastchat')
from fastchat.serve.inference import ChatIO, generate_stream,load_model
from fastchat.conversation import get_default_conv_template
# Instruction text that SimpleChatIO.prompt_for_input puts in front of every
# "prompt:...;tags:..." request sent to the chat model. It asks for one score
# from 1 to 3 per tag.
default = "Please assess the described scene based on the provided prompt and determine the likelihood of each tag appearing in the scene. Assign a score to each tag according to the following criteria:  If a tag is certain to appear, assign a score of 3. If a tag may appear, assign a score of 2. If a tag is unlikely to appear, assign a score of 1."
class SimpleChatIO(ChatIO):
    """Non-interactive ChatIO that builds the scoring request and collects the reply."""

    def prompt_for_input(self, role, prompt, tags) -> str:
        """Return the scoring instruction followed by the prompt and its tags.

        `role` is accepted for interface compatibility and is not used.
        """
        return default+'\n'+"prompt:"+prompt+';tags:'+tags

    def prompt_for_output(self, role: str):
        """Print the role label without a trailing newline."""
        print(f"{role}: ", end="", flush=True)

    def stream_output(self, output_stream):
        """Drain `output_stream` and return the final generated text.

        Assumes every item is the cumulative text generated so far, so the last
        item is the complete reply -- TODO confirm against the installed
        fastchat `generate_stream`.

        Returning the last item directly keeps the spaces between words, where
        gluing the incremental word slices together would run the chunks into
        each other. An empty stream yields "" instead of raising.
        """
        final_text = ""
        for final_text in output_stream:
            pass
        return final_text.strip()
# Chat helper used by get_sam_score to format requests and read replies.
chatio = SimpleChatIO()
# Local directory holding the Vicuna weights; also passed as the "model" name
# in the generation parameters.
vicuna_path = "../vicuna"


# NOTE(review): the positional arguments after the path are presumably
# (device, num_gpus, max_gpu_memory, load_8bit, ...) -- their order depends on
# the installed fastchat version, confirm before changing them.
vicuna, tokenizer = load_model(
        vicuna_path, "cuda", 4, None, True, True, False
    )

# Conversation template that get_sam_score uses to build each request.
conv = get_default_conv_template(vicuna_path)


        

## Loading Dataset

In [None]:
# #@title Creating Dataloader

# device = "cuda" if torch.cuda.is_available() else "cpu"
# ftprompts=['airplane','automobile','bird','deer','dog','cat','frog','horse','ship','truck']  # CIFAR labels
# ftprompts = pd.DataFrame({'prompts': ftprompts}) #converting prompts list into a pandas dataframe

# class CIFAR10Dataset():
#     def __init__(self):
#         global ftprompts
#         self.ftprompts=ftprompts.iloc[:,0]
        
#     def __len__(self):
#         return len(self.ftprompts)
    
#     def __getitem__(self,index):
#         return self.ftprompts.iloc[index]

# #@markdown Please mention the batch size.
# batch_size = 5 #@param {type:"integer"}


# dataset = CIFAR10Dataset()
# finetune_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

In [None]:
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary Python objects;
# only load .npy files from a trusted source.
coco_train = np.load('./data/coco_train.npy', allow_pickle=True).tolist()
coco_val = np.load('./data/coco_val.npy', allow_pickle=True).tolist()
coco_test = np.load('./data/coco_test.npy', allow_pickle=True).tolist()
human_prompts = np.load('./data/human_prompts.npy', allow_pickle=True).tolist()

# Fine-tune on at most 10000 training prompts (use len(coco_train) for the full
# set). `length` later drives `topk` and the number of training steps, so clamp
# it to what was actually loaded instead of assuming 10000 prompts exist.
length = min(10000, len(coco_train))

# Single-column frame; the training loop adds one 'Epoch{n} Scores' column per epoch.
ftprompts = pd.DataFrame({'prompts': coco_train[:length]})
class MSCOCODataset():
    """Map-style dataset that serves prompts from the first column of a DataFrame.

    Args:
        prompts: DataFrame whose first column holds the prompts. When omitted,
            the module-level `ftprompts` frame is used, so `MSCOCODataset()`
            keeps working as before.
    """
    def __init__(self, prompts=None):
        if prompts is None:
            # Backward-compatible default: read the notebook-level frame.
            prompts = ftprompts
        self.ftprompts=prompts.iloc[:,0]

    def __len__(self):
        """Number of prompts."""
        return len(self.ftprompts)

    def __getitem__(self,index):
        """Return the prompt at positional `index`."""
        return self.ftprompts.iloc[index]

#@markdown Please mention the batch size.
batch_size = 8 #@param {type:"integer"}
# Prompts are served in their original order (shuffle=False); the training loop
# relies on this to map a batch position back to a row of `ftprompts`.
dataset = MSCOCODataset()
finetune_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)


## Loading CLIP and BLIP-2

In [None]:

import pytorch_lightning as pl
import clip
import torch.nn.functional as F



# CLIP ViT-L/14 text/image encoder; its weights are then replaced by the
# 'state_dict' stored in ../hpc.pt.
# NOTE(review): torch.load may unpickle arbitrary objects depending on the torch
# version and `weights_only` setting -- only load checkpoints you trust.
model, preprocess = clip.load('ViT-L/14', device=device)
params = torch.load("../hpc.pt")['state_dict']
model.load_state_dict(params)


from transformers import Blip2Processor, Blip2ForConditionalGeneration, Blip2Model, AutoTokenizer
import torch
# BLIP-2 processor and captioning model (requested in float16); the reward
# functions use them to caption each generated image.
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)
caption = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16
)
caption.to(device)


## Evaluating Reward Scores (caption similarity and tag consistency)

In [None]:

def get_sam_score(prompt,tags,length):
    """Ask the chat model to rate each tag against `prompt` and normalise the total.

    Args:
        prompt: Text prompt the image was generated from.
        tags: Comma-separated tag string detected in the image.
        length: Number of tags in `tags`; must be non-zero.

    Returns:
        ``(digit_sum - 2 * length) / (2 * length)``, where ``digit_sum`` adds up
        every decimal digit character found in the model's reply. If the reply
        contains exactly one score in 1..3 per tag, the result lies in
        [-0.5, 0.5].

    Raises:
        ZeroDivisionError: If `length` is 0.
    """
    import copy

    # Build the request on a private copy of the conversation template. Appending
    # to the shared `conv` would make every call re-send all earlier requests
    # (the prompt would grow without bound) and is unsafe when this function is
    # called from several worker threads.
    dialog = copy.deepcopy(conv)
    inp = chatio.prompt_for_input(dialog.roles[0], prompt, tags)
    dialog.append_message(dialog.roles[0], inp)
    dialog.append_message(dialog.roles[1], None)  # empty slot for the model's reply

    gen_params = {
        "model": vicuna_path,
        "prompt": dialog.get_prompt(),
        "temperature": 0.7,
        "max_new_tokens": 512,
        "stop": dialog.stop_str,
        "stop_token_ids": dialog.stop_token_ids,
        "echo": False,
    }

    output_stream = generate_stream(vicuna, tokenizer, gen_params, device)
    outputs = chatio.stream_output(output_stream)

    # NOTE(review): every digit character in the reply is counted, including any
    # digits the model emits that are not tag scores (e.g. list numbering).
    # isdecimal() is used rather than isdigit() so that characters such as
    # superscript digits, which int() cannot parse, are skipped.
    sam_score = sum(int(ch) for ch in outputs if ch.isdecimal())
    return (sam_score-2*length)/(2*length)
def get_tags(image_pil):
    """Run Tag2Text on a PIL image and return its tags as a comma-separated string.

    The image is resized to 384x384, passed through the module-level `transform`
    and sent to `inference.inference` as a batch of one.

    Returns:
        (text_prompt, length): the first element of the inference result with
        every ' |' replaced by ',', and the number of comma-separated items in
        that string.
    """
    resized = image_pil.resize((384, 384))
    batch = transform(resized).unsqueeze(0).to(device)

    tagging_result = inference.inference(batch, tag2text_model, specified_tags)

    # Element 0 is presumably the ' | '-separated tag string -- confirm against
    # Tag2Text's inference module.
    text_prompt = tagging_result[0].replace(' |', ',')
    return text_prompt, len(text_prompt.split(','))
def get_image_score(image,prompt):
    """Combined reward for one image: caption similarity plus tag-consistency score.

    Delegates to `get_cap_reward` and `get_sam_reward` so the caption and tag
    pipelines exist in one place each instead of being duplicated here.

    Args:
        image: PIL image to score.
        prompt: Text prompt the image was generated from.

    Returns:
        float: ``get_cap_reward(image, prompt) + get_sam_reward(image, prompt)``.
    """
    return get_cap_reward(image, prompt) + get_sam_reward(image, prompt)

def get_cap_reward(image,prompt):
    """Caption reward: CLIP text similarity between `prompt` and a BLIP-2 caption of `image`.

    The prompt and the generated caption are both encoded with the CLIP text
    encoder, L2-normalised, and compared with cosine similarity.

    Returns:
        float: the cosine similarity.
    """
    def _unit_text_features(text):
        # Encode a single string with CLIP and scale it to unit length.
        features = model.encode_text(clip.tokenize([text]).to(device))
        features /= features.norm(dim=-1, keepdim=True)
        return features

    with torch.no_grad():
        text_features = _unit_text_features(prompt).to(torch.float16)

        # Caption the image with BLIP-2.
        blip_inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
        generated_ids = caption.generate(**blip_inputs)
        decoded = processor.batch_decode(generated_ids, skip_special_tokens=True)
        generated_text = decoded[0].strip()

        caption_features = _unit_text_features(generated_text)
        cos_sim = torch.cosine_similarity(text_features, caption_features, dim=1) #cap_score

    return float(cos_sim)

def get_sam_reward(image,prompt):
    """Tag reward: tag the image, then have the chat model score the tags against `prompt`.

    Returns:
        float: the value of `get_sam_score` for the detected tags.
    """
    with torch.no_grad():
        tag_text, tag_count = get_tags(image)
        return float(get_sam_score(prompt, tag_text, tag_count))


def get_max_score(prompt_list,image_list,index,epoch=0,ARL=True):
    """Pick the best candidate image and record its score in `ftprompts`.

    `image_list[i]` is scored against `prompt_list[i]` for every
    ``i < len(prompt_list)``.

    Args:
        prompt_list: One prompt per candidate image.
        image_list: Candidate images, aligned with `prompt_list`.
        index: Row label of `ftprompts` that receives the score.
        epoch: Epoch number, used in the ``'Epoch{epoch} Scores'`` column name.
        ARL: If True, rank the candidates separately by caption reward and by
            tag reward (rank 0 = highest reward), add the two ranks per
            candidate and choose the smallest total. If False, choose the
            candidate with the highest `get_image_score`.

    Returns:
        ``[score, best_index]``. With ``ARL=True`` the score is the winning
        rank total (lower is better); with ``ARL=False`` it is the highest
        combined reward (higher is better).
        NOTE(review): the two modes use opposite orderings, so callers that sort
        scores across prompts should take the mode into account.

    Raises:
        ValueError: If `prompt_list` is empty.
    """
    def _ranks(scores):
        # Position of every candidate when sorted from best to worst.
        # sorted() is stable, so ties keep their original order.
        order = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True)
        ranks = [0] * len(scores)
        for position, candidate in enumerate(order):
            ranks[candidate] = position
        return ranks

    if ARL:
        cap_score_list = []
        sam_score_list = []
        for i in range(len(prompt_list)):
            cap_score_list.append(get_cap_reward(image_list[i],prompt_list[i]))
            sam_score_list.append(get_sam_reward(image_list[i],prompt_list[i]))

        # Convert the scores to per-candidate ranks before adding them. Adding the
        # raw argsort outputs position by position would combine unrelated
        # candidates and make the selection meaningless.
        cap_rankings = _ranks(cap_score_list)
        sam_rankings = _ranks(sam_score_list)
        total_rankings = [cap + sam for cap, sam in zip(cap_rankings, sam_rankings)]

        best_total = min(total_rankings)
        ftprompts.loc[index, f'Epoch{epoch} Scores'] = best_total
        return [best_total, total_rankings.index(best_total)]

    score_list = [
        get_image_score(image_list[i], prompt_list[i])
        for i in range(len(prompt_list))
    ]
    torch.cuda.empty_cache()

    best_score = max(score_list)
    ftprompts.loc[index, f'Epoch{epoch} Scores'] = best_score
    return [best_score, score_list.index(best_score)]


## Parameters

In [None]:
#@title Settings for the model

#@markdown All settings have been configured to achieve optimal output. Changing them is not advisable.

#@markdown Enter value for `resolution`.
resolution=256 #@param {type:"integer"}

#@markdown Enter value for `num_images_per_prompt`.
num_images_per_prompt=10 #@param {type:"integer"} 

#@markdown Enter value for `epochs`.
epochs=2 #@param {type:"integer"}

#@markdown The generator below is seeded with `seed` from the setup cell.
# NOTE(review): `generator` is not passed to any `pipe(...)` call visible in this notebook.
generator = torch.Generator(device=device).manual_seed(seed)

In [None]:
# @title Setting Stable Diffusion pipeline
# Base checkpoint: loaded by get_pipe() and exported as MODEL_NAME for the training script.
model_id = "runwayml/stable-diffusion-v1-5"
def get_pipe(amp=True):
    """Build the Stable Diffusion pipeline used for sampling.

    Args:
        amp: If True, load the weights with ``torch_dtype=torch.float16``;
            otherwise load them with the library's default dtype.

    Returns:
        The pipeline on `device`, with xformers memory-efficient attention
        enabled, the progress bar disabled, a DPMSolverMultistepScheduler built
        from the original scheduler's config, and the safety checker removed.
    """
    load_kwargs = {"torch_dtype": torch.float16} if amp else {}
    pipe = StableDiffusionPipeline.from_pretrained(model_id, **load_kwargs).to(device)

    pipe.enable_xformers_memory_efficient_attention()
    torch.cuda.empty_cache()

    #@markdown Check the `set_progress_bar_config` option if you would like to hide the progress bar for image generation
    set_progress_bar_config= True #@param {type:"boolean"}
    pipe.set_progress_bar_config(disable=set_progress_bar_config)

    # Swap in the DPM-Solver multistep scheduler, reusing the current scheduler's config.
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.safety_checker = None
    return pipe
# Shared pipeline (float16 by default) used by the training and sampling cells below.
pipe = get_pipe()

In [None]:
#@title Run training
os.environ['MODEL_NAME'] = model_id
os.environ['OUTPUT_DIR'] = f"./CustomModel/"
os.environ['TOKENIZERS_PARALLELISM'] = "false"
topk=length
training_steps_per_epoch=topk*10
os.environ['CHECKPOINTING_STEPS']=str(training_steps_per_epoch)
os.environ['RESOLUTION']=str(resolution)
os.environ['LEARNING_RATE']=str(9e-6)

# remove old account directory
try: 
    shutil.rmtree('./CustomModel')
except:
    pass
try: 
    shutil.rmtree('./trainingdataset/imagefolder/')
except:
    pass

total = 0
for epoch in range(epochs+1):
  print("Epoch: ",epoch)
  epoch=epoch
  training_steps=str(training_steps_per_epoch*(epoch+1))
  os.environ['TRAINING_STEPS']=training_steps
  os.environ['TRAINING_DIR'] = f'./trainingdataset/imagefolder/{epoch}'

  training_prompts=[]
  ftprompts[f'Epoch{epoch} Scores']=np.nan

  for step, prompt_list in enumerate(finetune_dataloader):
    tomesd.apply_patch(pipe, ratio=0.5)
    image=pipe(prompt_list,num_images_per_prompt=num_images_per_prompt,width=resolution,height=resolution).images
    image_list=[]

    for i in range(int(len(image)/num_images_per_prompt)):
      image_list.append(image[i*num_images_per_prompt:(i+1)*num_images_per_prompt])
    torch.cuda.empty_cache()
    
    with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
      step_list=[i for i in range(step*batch_size,(step+1)*batch_size)]
      prompts_list = [prompt_list] * len(step_list)
      score_index=executor.map(get_max_score,prompts_list,image_list,step_list,[epoch for i in range(len(step_list))])

    iterator=0
    for max_scores in score_index:
      training_prompts.append([max_scores[0],image_list[iterator][max_scores[1]],prompt_list[iterator]])
      iterator+=1

  training_prompts=[row[1:3] for row in sorted(training_prompts,key=lambda x: (x[0]),reverse=True)[:topk]]
  training_prompts=pd.DataFrame(training_prompts)

  if not os.path.exists(f"./trainingdataset/imagefolder/{epoch}/train/"):
    os.makedirs(f"./trainingdataset/imagefolder/{epoch}/train/")
  if not os.path.exists(f"./CustomModel/"):
    os.makedirs(f"./CustomModel/")
  for i in range(len(training_prompts)):
    training_prompts.iloc[i,0].save(f'./trainingdataset/imagefolder/{epoch}/train/{i}.png')

  training_prompts['file_name']=[f"{i}.png" for i in range(len(training_prompts))]
  training_prompts.columns = ['0','text','file_name']
  training_prompts.drop('0',axis=1,inplace=True)
  training_prompts.to_csv(f'./trainingdataset/imagefolder/{epoch}/train/metadata.csv',index=False)


  if epoch<epochs:
    !accelerate launch --num_processes=1 --mixed_precision='fp16' --dynamo_backend='no' --num_machines=1 train_text_to_image_lora.py \
        --pretrained_model_name_or_path=$MODEL_NAME \
        --train_data_dir=$TRAINING_DIR \
        --resolution=$RESOLUTION \
        --train_batch_size=8 \
        --gradient_accumulation_steps=1 \
        --gradient_checkpointing \
        --max_grad_norm=1 \
        --mixed_precision="fp16" \
        --max_train_steps=$TRAINING_STEPS \
        --learning_rate=$LEARNING_RATE \
        --lr_warmup_steps=0 \
        --enable_xformers_memory_efficient_attention \
        --dataloader_num_workers=5 \
        --output_dir=$OUTPUT_DIR \
        --lr_warmup_steps=0 \
        --seed=1234 \
        --checkpointing_steps=$CHECKPOINTING_STEPS \
        --resume_from_checkpoint="latest" \
        --lr_scheduler='constant' 
  pipe.unet.load_attn_procs(f'./CustomModel/')
  torch.cuda.empty_cache()

In [None]:
def _sample_all_prompts():
  """Generate num_images_per_prompt images for every fine-tuning prompt with `pipe`.

  Returns the images as one flat list, in dataloader order.
  """
  images=[]
  for prompt_list in finetune_dataloader:
      images+=pipe(prompt_list,num_images_per_prompt=num_images_per_prompt,width=resolution,height=resolution).images
      torch.cuda.empty_cache()
  return images

generate_pretrained_model_images= True
if generate_pretrained_model_images:
  # NOTE(review): the training cell loads ./CustomModel into `pipe.unet` at the end
  # of every epoch, so if that cell has already run these images are not from the
  # untouched base model. Run this cell before training, or build a fresh pipeline
  # with get_pipe(), for a true baseline.
  image_list=_sample_all_prompts()
  # One grid row per prompt, one column per sample.
  grid = image_grid(image_list, len(ftprompts),num_images_per_prompt)
  grid.save("pretrained.png")
  # A bare `grid` expression inside an `if` block is never rendered; display it explicitly.
  display(grid)
generate_finetuned_model_images= True #@param {type:"boolean"}

if generate_finetuned_model_images:
  pipe.unet.load_attn_procs('./CustomModel')
  image_list=_sample_all_prompts()
  grid = image_grid(image_list, len(ftprompts),num_images_per_prompt)
  grid.save("trained.png")
  display(grid)