In [None]:
# @title Script

from torch import nn
from pathlib import Path
from PIL import Image
from unittest.mock import patch
from IPython.display import clear_output,display, HTML
from itertools import islice
from openai import OpenAI
from google import genai
from transformers.dynamic_module_utils import get_imports

import io, base64, json, yaml, toml
import numpy as np
import requests, copy, os, torch, gc, re
import torch.amp.autocast_mode

gpu_name = torch.cuda.get_device_name()
print(gpu_name)
# Set TORCH_CUDA_ARCH_LIST from a substring of the detected GPU name.
# Every marker is tested in order with no early exit, so a later match
# overrides an earlier one, just like a chain of independent if-statements.
for _gpu_marker, _cuda_arch in (('A100', '8.0'), ('L4', '8.9'), ('T4', '7.5')):
    if _gpu_marker in gpu_name:
        os.environ['TORCH_CUDA_ARCH_LIST'] = _cuda_arch

#API
# Display name -> provider model id for the hosted captioning backends.
# api_caption selects the client from the "Gemini"/"OpenAI" substring of the key.
model_list = {
    "APIGemini | 2.5 Pro": "gemini-2.5-pro",
    "APIGemini | 2.5 Flash": "gemini-2.5-flash",
    "APIGemini | 2.0 Flash" : "gemini-2.0-flash",
    "APIGemini | 2.0 Flash Lite": "gemini-2.0-flash-lite",
    "APIOpenAI | GPT 4-o mini": "gpt-4o-mini",
    "APIOpenAI | GPT o4-mini": "o4-mini",
}

# Catalogue of downloadable base models: file name -> download link
# (looked up by download_lib and passed to aria_down).
url = "https://raw.githubusercontent.com/StableDiffusionVN/SDVN-WebUI/refs/heads/main/model_lib.json"
# A timeout keeps the cell from hanging forever on a stalled connection, and
# raise_for_status() reports HTTP errors (e.g. 404) directly instead of a
# confusing JSON decode error on the error page.
response = requests.get(url, timeout=60)
response.raise_for_status()
model_train_list = json.loads(response.text)

# Training script per model family (presumably keyed by TrainType — confirm).
lora_train_py = {
    "Flux": "flux_train_network.py",
    "SDXL": "sdxl_train_network.py",
    "SD15": "train_network.py"
}
db_train_py = {
    "Flux": "flux_train.py",
    "SDXL": "sdxl_train.py",
    "SD15": "train_db.py"
}
def encode_image(image):
    """Serialise *image* as PNG and return it base64-encoded as a str.

    *image* only needs a ``save(fp, format=...)`` method (a PIL image here).
    """
    with io.BytesIO() as sink:
        image.save(sink, format="PNG")
        png_bytes = sink.getvalue()
    return base64.b64encode(png_bytes).decode('utf-8')

def api_check():
    """Load the saved API keys, or return None when the key file is absent.

    The file is ``<data_dir>/Setting/API_key_for_sdvn_comfy_node.json``;
    ``data_dir`` is a notebook global defined elsewhere.
    """
    key_file = os.path.join(data_dir, "Setting/API_key_for_sdvn_comfy_node.json")
    if not os.path.exists(key_file):
        return None
    with open(key_file, 'r', encoding='utf-8') as handle:
        return json.load(handle)

def api_caption(image, length:int, APIkey, Caption, prompt):
    """Describe *image* with a hosted model (Gemini or OpenAI) and return the text.

    Args:
        image: Image to caption. Gemini receives it as-is; OpenAI receives it
            PNG-encoded via ``encode_image``. ``None`` sends a text-only
            request to OpenAI.
        length: Word limit written into the instruction.
        APIkey: Provider API key. When empty, the key saved in the settings
            file (read by ``api_check``) is used if one is stored.
        Caption: Key of ``model_list``; its "Gemini"/"OpenAI" substring
            selects the provider.
        prompt: User instruction; the built-in caption request is appended.

    Returns:
        The model's reply with surrounding whitespace stripped.

    Raises:
        KeyError: *Caption* is not a key of ``model_list``.
        ValueError: *Caption* names neither Gemini nor OpenAI.
    """
    if APIkey == "":
        # Read the settings file once; fall back only when it exists.
        api_list = api_check()
        if api_list is not None:
            # .get(): a missing provider entry keeps the empty key instead of
            # raising KeyError here.
            if "Gemini" in Caption:
                APIkey = api_list.get("Gemini", APIkey)
            if "OpenAI" in Caption:
                APIkey = api_list.get("OpenAI", APIkey)
    model_name = model_list[Caption]
    prompt += f"Picture description, Send the description on demand, limit {length} words, only send me the answer, Always return English. "
    if 'Gemini' in Caption:
        client = genai.Client(api_key=APIkey)
        response = client.models.generate_content(
                    model=model_name,
                    contents=[prompt, image])
        # NOTE(review): the SDK's `text` property may be None (e.g. when no
        # candidate text is returned); treat that as an empty caption.
        answer = response.text or ""
    elif "OpenAI" in Caption:
        client = OpenAI(api_key=APIkey)
        content = prompt
        if image is not None:
            encoded = encode_image(image)
            content = [
                {"type": "text", "text": prompt},
                # encode_image produces PNG bytes, so label the data URL as PNG.
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded}"}},
            ]
        stream = client.chat.completions.create(
            model=model_name,
            messages=[{"role": "user", "content": content}],
            stream=True,
        )
        parts = []
        for chunk in stream:
            # A stream chunk may carry no choices; skip it instead of IndexError.
            if chunk.choices and chunk.choices[0].delta.content is not None:
                parts.append(chunk.choices[0].delta.content)
        answer = "".join(parts)
        if image is not None:
            # NOTE(review): keeps only the text after the last "return True";
            # the reason for this marker is not visible here.
            answer = answer.split('return True')[-1]
    else:
        raise ValueError(f"Unsupported caption backend: {Caption}")
    return answer.strip()

#Florence

# Florence-2 checkpoint size; load_model appends it to "microsoft/Florence-2-".
version = "large"
# Run Florence on the currently selected CUDA device (explicit "cuda" type).
device = torch.device("cuda", torch.cuda.current_device())

def clean_directory(directory):
    """Recursively delete every file that is not an image or a .safetensors file.

    Kept extensions: .png, .jpg, .jpeg, .webp, .bmp, .safetensors, compared
    case-insensitively. Everything else (including .txt captions) is removed.
    Sub-directories are descended into, never deleted.

    Args:
        directory: Root directory to clean.
    """
    supported_types = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".safetensors"]
    for item in os.listdir(directory):
        file_path = os.path.join(directory, item)
        if os.path.isfile(file_path):
            # Lower-case the extension: "IMG_01.JPG" is a valid image (the
            # notebook's image filters accept .JPG/.PNG/.JPEG) and must not be
            # deleted just because of its letter case.
            file_ext = os.path.splitext(item)[1].lower()
            if file_ext not in supported_types:
                print(f"Deleting file {item} from {directory}")
                os.remove(file_path)
        elif os.path.isdir(file_path):
            clean_directory(file_path)

def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    """Workaround for FlashAttention when loading Florence-2 remote code.

    Drop-in replacement for ``transformers.dynamic_module_utils.get_imports``
    (patched in by ``load_model``). For ``modeling_florence2.py`` the
    ``flash_attn`` entry is removed from the returned import list, presumably
    so transformers does not demand FlashAttention be installed. Every other
    file gets the unmodified result of ``get_imports``.
    """
    imports = get_imports(filename)
    # Guard the removal: newer versions of the model file may not list
    # flash_attn, and list.remove would then raise ValueError.
    if os.path.basename(filename) == "modeling_florence2.py" and "flash_attn" in imports:
        imports.remove("flash_attn")
    return imports

def load_model(version, device, model_dir="/content/Model"):
    """Download (if needed) and load Florence-2 and its processor.

    Args:
        version: Checkpoint suffix, e.g. "large" -> "microsoft/Florence-2-large".
        device: Device the model is moved to.
        model_dir: Hugging Face cache directory; created if missing.

    Returns:
        ``(model, processor)``.
    """
    from transformers import AutoProcessor, AutoModelForCausalLM
    # makedirs(exist_ok=True) also creates missing parents and tolerates an
    # already-existing directory (os.mkdir failed on both).
    os.makedirs(model_dir, exist_ok=True)

    identifier = "microsoft/Florence-2-" + version

    # Swap in fixed_get_imports while transformers inspects the remote code.
    with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
        model = AutoModelForCausalLM.from_pretrained(identifier, cache_dir=model_dir, trust_remote_code=True)
        processor = AutoProcessor.from_pretrained(identifier, cache_dir=model_dir, trust_remote_code=True)

    model = model.to(device)
    return (model, processor)

def load(version, device):
    """Load Florence-2 into the globals ``model``/``processor`` once.

    Does nothing if a global ``processor`` already exists, so repeated calls
    (even with a different *version*) reuse the first loaded model.
    """
    global model, processor
    if 'processor' in globals():
        return
    model, processor = load_model(version, device)

def run_example(task_prompt, image, max_new_tokens, num_beams, do_sample, text_input=None):
    """Run one Florence-2 task on *image* and return the post-processed result.

    Uses the globals ``model``, ``processor`` and ``device``. *text_input*, when
    given, is appended directly to *task_prompt*. Returns the dict produced by
    ``processor.post_process_generation`` (keyed by the task token).
    """
    full_prompt = task_prompt if text_input is None else task_prompt + text_input
    batch = processor(text=full_prompt, images=image, return_tensors="pt").to(device)
    output_ids = model.generate(
        input_ids=batch["input_ids"],
        pixel_values=batch["pixel_values"],
        max_new_tokens=max_new_tokens,
        early_stopping=False,
        do_sample=do_sample,
        num_beams=num_beams,
    )
    # Keep special tokens: post_process_generation needs them to parse the output.
    decoded = processor.batch_decode(output_ids, skip_special_tokens=False)[0]
    return processor.post_process_generation(
        decoded,
        task=task_prompt,
        image_size=(image.width, image.height),
    )

def florence_caption(task_prompt, image, max_new_tokens = 1024, num_beams = 3, do_sample = False, fill_mask = False, text_input=None):
    """Return a single-line Florence-2 caption for *image*.

    Args:
        task_prompt: One of '<CAPTION>', '<DETAILED_CAPTION>',
            '<MORE_DETAILED_CAPTION>'.
        image: Image passed to ``run_example``.
        max_new_tokens, num_beams, do_sample: Generation settings.
        fill_mask: Accepted for backward compatibility; unused.
        text_input: Optional text appended to the task token.

    Returns:
        The caption with newlines removed.

    Raises:
        ValueError: *task_prompt* is not a supported caption task (previously
            the function silently returned None).
    """
    caption_tasks = ('<CAPTION>', '<DETAILED_CAPTION>', '<MORE_DETAILED_CAPTION>')
    if task_prompt not in caption_tasks:
        raise ValueError(f"Unsupported Florence caption task: {task_prompt!r}")
    result = run_example(task_prompt, image, max_new_tokens, num_beams, do_sample, text_input)
    return result[task_prompt].replace("\n", "")

#Caption

def caption_dir(image_dir,prompt):
  if Caption == 'Florence':
    load(version, device)
  if Caption == 'WD14':
    print(f'Tạo caption WD14: {image_dir}')
    run = f"python /content/sd-scripts/finetune/tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-eva02-large-tagger-v3 --thresh {prompt[1]} --batch_size 4 '{image_dir}'"
    !{run}
  for img_file in os.listdir(image_dir):
      file_path = os.path.join(image_dir, img_file)
      if os.path.isdir(file_path) :
          caption_dir(file_path,prompt)
      if Caption != 'WD14':
        if img_file.lower().endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG")):
            img_path = os.path.join(image_dir, img_file)
            image = Image.open(img_path).convert("RGB")
            if Caption == 'Florence':
              cap = florence_caption(prompt[0],image).replace('The image shows','')
            else:
              cap = api_caption(image, prompt[2], APIkey, Caption, API_Prompt)
            txt_path = os.path.join(image_dir, f"{os.path.splitext(img_file)[0]}{extension}")
            with open(txt_path, "w") as f:
                f.write(cap)
            print(f"Miêu tả của ảnh {img_file}: {cap}")

def read_file(filename):
    """Return the whole text content of *filename* (platform default encoding)."""
    with open(filename) as handle:
        return handle.read()

def write_file(filename, contents):
    """Overwrite *filename* with the string *contents* (platform default encoding)."""
    with open(filename, mode="w") as handle:
        handle.write(contents)

def _merge_tags(contents, custom_tag, append, remove_tag):
    """Return caption text *contents* with *custom_tag* added or removed."""
    if remove_tag:
        # Removal is a plain substring replacement (so it also works on
        # sentence-style captions); surrounding separators are left as-is.
        return contents.replace(custom_tag, "")
    # Drop empty entries so an empty file or a trailing comma does not turn
    # into stray ", " separators in the output.
    tags = [tag.strip() for tag in contents.split(',') if tag.strip()]
    new_tags = []
    for raw in custom_tag.split(','):
        tag = raw.strip().replace("_", " ")
        if tag and tag not in tags and tag not in new_tags:
            new_tags.append(tag)
    # Prepend as one group: inserting one by one at index 0 reversed the
    # order of several prepended tags (e.g. "trigger, style").
    tags = tags + new_tags if append else new_tags + tags
    return ', '.join(tags)

def process_tags(filename, custom_tag, append, remove_tag):
    """Add or remove comma-separated tags in the caption file *filename*.

    Args:
        filename: Caption text file, rewritten in place.
        custom_tag: One or more comma-separated tags. When adding,
            underscores become spaces and tags already present are skipped.
        append: When adding, put new tags at the end (True) or front (False).
        remove_tag: Remove *custom_tag* (as a raw substring) instead of adding.
    """
    write_file(filename, _merge_tags(read_file(filename), custom_tag, append, remove_tag))

def check_dir(image_dir):
    """Create an empty caption file for each image when *image_dir* has none.

    Nothing is done if any entry of *image_dir* already ends with the global
    ``extension`` (caption suffix, defined elsewhere). Caption names use
    ``os.path.splitext`` -- the same rule caption_dir uses -- so
    "my.photo.jpg" pairs with "my.photo<extension>", not "my<extension>".
    Only the top level of *image_dir* is examined.
    """
    entries = os.listdir(image_dir)
    if any(name.endswith(extension) for name in entries):
        return
    for name in entries:
        if name.lower().endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp")):
            stem = os.path.splitext(name)[0]
            open(os.path.join(image_dir, stem + extension), "w").close()

def process_dir(image_dir, tag, append, remove_tag):
    """Apply process_tags to every caption file under *image_dir*, recursively.

    check_dir runs on each directory first. Sub-directory names are printed
    as they are entered; caption files are those ending with the global
    ``extension``.
    """
    check_dir(image_dir)
    for entry in os.listdir(image_dir):
        entry_path = os.path.join(image_dir, entry)
        if os.path.isdir(entry_path):
            print(entry)
            process_dir(entry_path, tag, append, remove_tag)
            continue
        if entry.endswith(extension):
            process_tags(entry_path, tag, append, remove_tag)

def add_forder_name(folder):
    """Prepend each sub-folder's name as a tag to its caption files, recursively.

    A numeric "<n>_" prefix (e.g. "10_cat") is stripped, so the tag is "cat";
    otherwise the whole folder name is used.
    """
    for entry in os.listdir(folder):
        sub_path = os.path.join(folder, entry)
        if not os.path.isdir(sub_path):
            continue
        folder_name = os.path.basename(sub_path)
        try:
            prefix, label = folder_name.split('_', 1)
            int(prefix)  # only a numeric prefix counts as a repeat count
        except ValueError:
            label = folder_name
        label = label.replace("/", ", ")
        process_dir(sub_path, label, False, False)
        add_forder_name(sub_path)

def get_steps(folder):
    """Split a "<steps>_<name>" folder basename into ``(steps, name)``.

    When the basename has no "_" or its prefix is not an integer, returns
    ``(Steps, basename)`` where ``Steps`` is a notebook global.
    """
    base = os.path.basename(folder)
    pieces = base.split('_', 1)
    if len(pieces) == 2:
        try:
            return int(pieces[0]), pieces[1]
        except ValueError:
            pass
    return Steps, base

def check_txt(image_dir):
    """Return paths of all entries ending in ".txt" under *image_dir*, recursively.

    Entries of *image_dir* come first, followed by those of each
    sub-directory in listing order.
    """
    entries = os.listdir(image_dir)
    found = [os.path.join(image_dir, e) for e in entries if e.endswith('.txt')]
    for entry in entries:
        sub_path = os.path.join(image_dir, entry)
        if os.path.isdir(sub_path):
            found.extend(check_txt(sub_path))
    return found

def random_sample(folder):
    """Return a random caption from *folder* with double quotes backslash-escaped.

    Falls back to "girl portrait, smile" when *folder* has no .txt files.
    """
    import random
    candidates = check_txt(folder)
    if not candidates:
        return "girl portrait, smile"
    return read_file(random.choice(candidates)).replace('"', r'\"')

def get_supported_images(folder):
    """Return paths of supported images in *folder* and all of its sub-directories.

    Supported extensions: .png, .jpg, .jpeg, .webp, .bmp, .PNG, .JPG, .JPEG.
    Results are grouped by extension, then followed by each sub-directory's
    results.
    """
    import glob
    supported_extensions = (".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG")
    # Escape the folder part so names like "10_[style]" are matched literally
    # instead of being read as glob character classes (which matched nothing).
    base = glob.escape(folder)
    list_img = [file for ext in supported_extensions for file in glob.glob(f"{base}/*{ext}")]
    for img_file in os.listdir(folder):
        file_path = os.path.join(folder, img_file)
        if os.path.isdir(file_path):
            list_img = list_img + get_supported_images(file_path)
    return list_img

def check_folder_train(folder):
    """Print the image count of a training folder, or a warning if it has none."""
    images = get_supported_images(folder)
    if not images:
        print(f"Thư mục [ {folder} ] có thể không chứa ảnh được hỗ trợ, hãy kiểm tra lại (.png, .jpg, .jpeg, .webp, .bmp, .JPG, .JPEG, .PNG)")
        return
    print('=====================')
    print(f'Thư mục train: {folder}')
    print(f'  Số lượng ảnh: {len(images)}')
    print('=====================')

def check_dir_image(image_dir):
    """Return True if some entry directly in *image_dir* has a supported image extension."""
    image_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG")
    return any(name.endswith(image_exts) for name in os.listdir(image_dir))

def check_sub_dir(image_dir):
    """Return every directory under *image_dir* (itself included) that holds images.

    Sub-directories are listed before their parent.
    """
    found = []
    for entry in os.listdir(image_dir):
        sub_path = os.path.join(image_dir, entry)
        if os.path.isdir(sub_path):
            found.extend(check_sub_dir(sub_path))
    if check_dir_image(image_dir):
        found.append(image_dir)
    return found

def repeat_dir(dir,num_repeats):
    """Return the repeat count encoded as a "<n>_" prefix of a folder name.

    Args:
        dir: Folder path ("/"-separated); only its last component is used.
        num_repeats: Value returned when that name has no integer prefix.
    """
    dir_name = dir.split('/')[-1]
    try:
        return int(dir_name.split('_')[0])
    except ValueError:
        # Only int() can fail here; a bare except would also have swallowed
        # KeyboardInterrupt/SystemExit.
        return num_repeats

def dic2arg(config:dict):
    """Render *config* as a command-line argument string.

    Each key/value pair contributes "<key> <value> ":
      * the key is omitted when ``str(value) == "False"``;
      * the value is omitted when it is a ``bool`` (so True becomes a bare flag).
    Spacing is preserved exactly, including the extra blanks left by omissions.
    """
    pieces = []
    for key, val in config.items():
        flag = "" if str(val) == "False" else key
        arg_value = "" if type(val) is bool else val
        pieces.append(f'{flag} {arg_value} ')
    return ''.join(pieces)

def civit_downlink(link):
  !wget {link} -q -O model.html
  try:
      # Mở tệp và đọc nội dung
      with open('model.html', 'r', encoding='utf-8') as file:
          html_content = file.read()
      pattern = r'"modelVersionId":(\d+),'
      model_id = re.findall(pattern, html_content)
      if model_id:
        api_link = f'https://civitai.com/api/download/models/{model_id[0]}'
        print(f'Download model id_link: {api_link}')
        return api_link
      else:
          return "Không tìm thấy đoạn nội dung phù hợp."
  except requests.RequestException as e:
      return f"Lỗi khi tải trang: {e}"

def check_link(link):
    """Normalise a download URL before handing it to aria2c.

    * Hugging Face "/blob/" viewer URLs become "/resolve/" download URLs.
    * Civitai model pages are resolved via civit_downlink, and an API token
      is appended as a query parameter.
    Other links are returned unchanged.
    """
    if 'huggingface.co' in link:
        # Replace only the path segment so repo/file names that contain
        # "blob" (e.g. "blobfish") are left intact.
        link = link.replace('/blob/', '/resolve/', 1)
    if 'civitai.com' in link:
        if 'civitai.com/models' in link:
            link = civit_downlink(link)
        # NOTE(security): a personal Civitai API token is hard-coded below and
        # is exposed to anyone who can read this notebook. Prefer setting the
        # CIVITAI_TOKEN environment variable; the literal is kept only as a
        # fallback so existing runs keep working. Consider revoking it.
        token = os.environ.get('CIVITAI_TOKEN', '8c7337ac0c39fe4133ae19a3d65b806f')
        # A link that already has a query string needs "&", not a second "?".
        # aria_down passes links to the shell unquoted, so "&" is escaped
        # the same way download_lib escapes it.
        separator = '\\&' if '?' in link else '?'
        link = f"{link}{separator}token={token}"
    return link

def aria_down(link,path,name, over = False):
  print(link)
  link = check_link(link)
  !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M {'--allow-overwrite=true' if over else ''} {link} -d  {path} -o {name}

def download_lib(model):
    """Resolve *model* to a local checkpoint path, downloading it if needed.

    * An http(s) URL is downloaded to ``<model_folder>/model.safetensors``
      (overwriting any previous file).
    * A path containing "/content/" is returned unchanged.
    * Anything else is treated as a name in ``model_train_list``
      (".safetensors" is appended if no model extension is present). Unknown
      names fall back to RealisticVision51 (SD15) or SDXL-Base.

    Relies on the notebook globals ``model_folder``, ``model_train_list`` and
    ``TrainType``.
    """
    if 'https:' in model or 'http:' in model:
        # aria_down passes the link to the shell unquoted, so "&" must be
        # escaped. The raw string avoids the invalid-escape warning '\&' raised.
        model = model.replace('&', r'\&')
        aria_down(model,model_folder,"model.safetensors", True)
        model_path = f"{model_folder}/model.safetensors"
    elif '/content/' in model:
        model_path = model
    else:
        if not any(ext in model for ext in ['.ckpt', '.gguf', '.safetensors']):
            model += '.safetensors'
        if model not in model_train_list:
            model = "RealisticVision51.safetensors" if TrainType == 'SD15' else "SDXL-Base.safetensors"
        aria_down(model_train_list[model],model_folder,model)
        model_path = f"{model_folder}/{model}"
    return model_path

def hug_down(link,path):
    """Download a Hugging Face file *link* to the local file *path* with aria2c.

    "/blob/" viewer URLs are converted to "/resolve/" download URLs. The
    directory part of *path* becomes aria2c's ``-d`` and the basename ``-o``.
    """
    # os.path.split splits on the last separator; the old path.split(name)
    # truncated the folder whenever the file name also occurred earlier in
    # the path (e.g. "/a/vae/vae").
    folder, name = os.path.split(path)
    # An empty "-d" would make aria2c consume the next option as its value.
    folder = folder or "."
    # Replace only the path segment so names containing "blob" stay intact.
    link = link.replace("/blob/", "/resolve/", 1)
    !aria2c --console-log-level=error --summary-interval=10 -c -x 16 -s 16 -k 1M {link} -d {folder} -o {name}

#Model download

# Flux components from the StableDiffusionVN/Flux repository ("blob" viewer
# URLs; hug_down turns them into download URLs).
_flux_repo = "https://huggingface.co/StableDiffusionVN/Flux/blob/main"
flux = f"{_flux_repo}/Unet/flux1-dev.safetensors"
clip_l = f"{_flux_repo}/Clip/clip_l.safetensors"
t5xxl_fp16 = f"{_flux_repo}/Clip/t5xxl_fp16.safetensors"
vae = f"{_flux_repo}/Vae/flux_vae.safetensors"

# Local destinations inside the notebook's model folder.
flux_path = f"{model_folder}/flux1-dev.safetensors"
clip_l_path = f"{model_folder}/clip_l.safetensors"
t5xxl_fp16_path = f"{model_folder}/t5xxl_fp16.safetensors"
vae_path = f"{model_folder}/flux_vae.safetensors"

# Fetch all four components (UNet, CLIP-L, T5-XXL, VAE) only for Flux training.
if TrainType == "Flux":
    for _src, _dst in ((flux, flux_path), (clip_l, clip_l_path), (t5xxl_fp16, t5xxl_fp16_path), (vae, vae_path)):
        hug_down(_src, _dst)
