In [None]:
# @title Check environment
# Print runtime diagnostics: environment variables (sorted), Python version,
# installed pip packages, GPU/driver status and the CUDA compiler version.
# NOTE(review): `env` output can include tokens or other secrets set in the
# runtime — clear this cell's output before sharing or committing the notebook.
!env | sort
!python --version
!pip list
!nvidia-smi
!nvcc --version

In [None]:
# @title Clone
import os
import subprocess
import tempfile

# Fixed checkout location shared by later cells (requirements_colab.txt, the
# training images under img/, and sdxl_train_network.py are read from here).
repoFolder = os.path.join(tempfile.gettempdir(), "kohya-gradio")
repoUrl = "https://github.com/Jaid/kohya-gradio.git"

# `git clone` refuses to write into an existing non-empty directory, so skip
# the clone when a checkout is already present; this keeps the cell re-runnable.
if not os.path.isdir(os.path.join(repoFolder, ".git")):
  cloneResult = subprocess.run(
    ["git", "clone", repoUrl, repoFolder],
    capture_output=True,
    text=True,
  )
  # Surface git's own messages in the cell output, then fail loudly instead of
  # letting later cells build on a missing checkout.
  print(cloneResult.stdout, cloneResult.stderr, sep="")
  cloneResult.check_returncode()

In [None]:
!python -m pip --disable-pip-version-check install --upgrade pip
requirementsFile = os.path.join(repoFolder, "requirements_colab.txt")
%pip install --requirement {requirementsFile} --extra-index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://pypi.nvidia.com
%pip install shlex

In [None]:
# @title Check CUDA support in PyTorch
# Show the installed torch package, whether PyTorch can use CUDA and, if so,
# the name of GPU 0; then show the tensorflow package and the GPUs TensorFlow
# lists as physical devices.
!pip show torch
import torch
print(torch.cuda.is_available())
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
!pip show tensorflow
import tensorflow as tf
print(tf.config.list_physical_devices('GPU'))

In [None]:
import os

import requests


def download(url, target, chunkSize=1024 * 1024, timeout=60):
  """Download ``url`` to the local file ``target``.

  The response body is streamed to disk in ``chunkSize`` pieces instead of
  being held in memory as a whole. Data is first written to
  ``target + '.part'`` and renamed to ``target`` only after the transfer
  finishes, so an interrupted download never leaves a truncated file at
  ``target``.

  Args:
    url: Address to fetch; redirects are followed.
    target: Destination file path.
    chunkSize: Number of bytes read per iteration while streaming.
    timeout: Seconds passed to requests as the connect/read timeout.

  Raises:
    requests.HTTPError: If the server answers with an error status.
  """
  print(f"{url} → {target}")
  partialTarget = target + '.part'
  with requests.get(url, allow_redirects=True, stream=True, timeout=timeout) as response:
    response.raise_for_status()
    with open(partialTarget, 'wb') as output_file:
      for chunk in response.iter_content(chunk_size=chunkSize):
        output_file.write(chunk)
  # Atomic on the same filesystem: target is either absent/old or complete.
  os.replace(partialTarget, target)


vaeLocation = 'https://huggingface.co/stabilityai/sdxl-vae/resolve/main/sdxl_vae.safetensors' # @param {type:"string"}
# exist_ok keeps the cell re-runnable once the folder exists.
os.makedirs('inputModel', exist_ok=True)
download(vaeLocation, 'inputModel/vae.safetensors')

In [None]:
import locale
# Make in-kernel callers of locale.getpreferredencoding() always get "UTF-8".
# NOTE(review): presumably a workaround for shell magics failing under a
# non-UTF-8 locale on Colab — confirm it is still needed.
locale.getpreferredencoding = lambda: "UTF-8"
# Write accelerate's default config and run its self-test before training.
!accelerate config default
!accelerate test
# `accelerate` is also used by the Generate Command cell (bf16/CUDA checks).
import accelerate
# Report memory limits and hardware/feature support as accelerate sees them.
print(accelerate.utils.get_max_memory())
print(accelerate.utils.is_bf16_available())
print(accelerate.utils.is_deepspeed_available())
print(accelerate.utils.is_cuda_available())
# NOTE(review): is_tpu_available is deprecated in newer accelerate releases
# (replaced by is_torch_xla_available) — confirm against the installed version.
print(accelerate.utils.is_tpu_available())

In [None]:
# @title Generate Command { run: "auto" }
# Build the `accelerate launch sdxl_train_network.py ...` command from the form
# fields below and write the sample-prompt file used during training.
# Later cells use `command` (executed by the next cell) and `modelFolder`
# (where trained weights are written).
import math
import os
import shlex

import accelerate

dim = 64 # @param {type:"integer"}
imageCount = 4 # @param {type:"integer"}
epochs = 3 # @param {type:"integer"}
warmupStepsPercent = 0 # @param {type:"integer"}
learningRate = 0.0000004 # @param {type:"number"}
vramSaving = 'lossless' # @param ["lossless", "lossy", "none"]
# NOTE(review): a key typed here is saved in the notebook source; clear it before sharing.
wandbKey = '' # @param {type:"string"}
outputFormat = 'safetensors' # @param ["safetensors", "ckpt"]
alwaysSave = True # @param {type:"boolean"}
useLycoris = False # @param {type:"boolean"}
adaptiveOptimizer = False # @param {type:"boolean"}
samplePrompt = '' # @param {type:"string"}
pruneOutput = False # @param {type:"boolean"}

modelFolder = 'model'
# One step per image per epoch — assumes batch size 1 and no dataset repeats
# (TODO confirm against the dataset layout under img/).
steps = imageCount * epochs
alpha = math.trunc(dim / 2)
# A percentage of 0 yields 0 warmup steps.
warmupSteps = math.floor(steps * warmupStepsPercent / 100)

# '|' separates several sample prompts. The prompt file holds one prompt per
# line, each followed by its own generation options.
samplePrompts = [part.strip() for part in samplePrompt.split('|') if part.strip()]
if samplePrompts:
  sampleOptions = ' --w 1024 --h 1024 --l 6 --s 50 --d 1'
  with open('prompt.txt', 'w', encoding='utf-8') as file:
    file.write(''.join(prompt + sampleOptions + '\n' for prompt in samplePrompts))

hasBf16 = accelerate.utils.is_bf16_available()
hasCuda = accelerate.utils.is_cuda_available()

accelerateArguments = []

launchArguments = [
  # Alternative base model: a local checkpoint, e.g. "$rootMixed/private/checkpoint.safetensors"
  '--pretrained_model_name_or_path', 'stabilityai/stable-diffusion-xl-base-1.0',
  '--vae', 'inputModel/vae.safetensors',
  '--train_data_dir', os.path.join(repoFolder, 'img'),
  '--output_dir', modelFolder,
  '--logging_dir', 'log',
  '--resolution', '1024,1024',
  '--save_model_as', outputFormat,
  # NOTE(review): with explicit text-encoder and U-Net rates, the learningRate
  # form field may have little effect on LoRA training — confirm in sd-scripts.
  '--text_encoder_lr', '0.0001',
  '--unet_lr', '0.0001',
  '--network_dim', dim,
  '--network_alpha', alpha,
  '--output_name', 'trained',
  '--max_train_steps', steps,
  '--no_half_vae',
  '--caption_extension', '.txt',
  '--cache_latents',
]

if useLycoris:
  launchArguments += [
    '--network_module', 'lycoris.kohya',
    '--network_args',
    'preset=full',
    'algo=full',
    'rank_dropout=0',
    'module_dropout=0',
    'use_tucker=False',
    'use_scalar=False',
    'rank_dropout_scale=True',
    'train_norm=True',
  ]
else:
  launchArguments += ['--network_module', 'networks.lora']

if adaptiveOptimizer:
  # The optimizer is selected by --optimizer_type alone. A bare `--optimizer`
  # flag is not a separate option and argparse would reject it as an ambiguous
  # prefix of --optimizer_type / --optimizer_args (NOTE(review): verify in the fork).
  launchArguments += [
    '--optimizer_type', 'adafactor',
    '--optimizer_args', 'scale_parameter=False', 'relative_step=False', 'warmup_init=False',
  ]
else:
  # The 8-bit AdamW variant is only used when lossy VRAM savings are accepted.
  launchArguments += ['--optimizer_type', 'adamw8bit' if vramSaving == 'lossy' else 'adamw']

if alwaysSave:
  launchArguments += ['--save_every_n_epochs', 1]

if not learningRate:
  # NOTE(review): the 'adafactor' scheduler likely requires the Adafactor
  # optimizer; with adaptiveOptimizer off, sd-scripts may reject this.
  launchArguments += ['--lr_scheduler', 'adafactor']
elif warmupSteps:
  launchArguments += [
    '--lr_scheduler', 'constant_with_warmup',
    '--lr_warmup_steps', warmupSteps,
    '--learning_rate', learningRate,
  ]
else:
  # The configured rate is forwarded with either constant scheduler.
  launchArguments += [
    '--lr_scheduler', 'constant',
    '--learning_rate', learningRate,
  ]

if vramSaving != 'none':
  launchArguments += ['--lowram']
  if hasCuda:
    launchArguments += ['--xformers']

# Reduced precision follows the hardware: bf16 when available, else fp16.
# (--full_bf16 / --full_fp16 are intentionally left disabled.)
reducedPrecision = 'bf16' if hasBf16 else 'fp16'
launchArguments += [
  '--mixed_precision', reducedPrecision if vramSaving == 'lossy' else 'no',
]
launchArguments += [
  '--save_precision', reducedPrecision if vramSaving == 'lossy' or pruneOutput else 'float',
]

if samplePrompts:
  launchArguments += [
    '--sample_sampler', 'euler_a',
    '--sample_prompts', 'prompt.txt',
    '--sample_every_n_epochs', 1,
  ]

if wandbKey:
  launchArguments += [
    '--log_with', 'wandb',
    '--wandb_api_key', wandbKey,
  ]

launchArguments += ['--metadata_title', 'SinansWoche']

segments = [
  'accelerate',
  'launch',
  *accelerateArguments,
  os.path.join(repoFolder, 'sdxl_train_network.py'),
  *launchArguments,
]
# Every segment is shell-quoted so paths/prompts with spaces survive `!{command}`.
command = " ".join(shlex.quote(str(segment)) for segment in segments)

# Echo the command with the W&B key masked so the secret is not kept in the
# saved cell output.
printableSegments = [
  '***' if index > 0 and segments[index - 1] == '--wandb_api_key' else segment
  for index, segment in enumerate(segments)
]
print(" ".join(shlex.quote(str(segment)) for segment in printableSegments))

In [None]:
# Run the training command built by the Generate Command cell. IPython expands
# {command} before passing the line to the shell; each segment was already
# quoted with shlex.quote when the command was assembled.
!{command}

In [None]:
# @title Clean memory
import gc

import torch

# First collect unreachable Python objects so any tensors they referenced are
# released, then return PyTorch's cached but unused CUDA memory.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# @title Collect states
import os
import re

# Output folder for images rendered by the sampling cell; exist_ok keeps the
# cell re-runnable.
os.makedirs('samples', exist_ok=True)

# Intermediate saves are recognised by "-" plus six digits at the end of the
# file stem (e.g. "trained-000001"); sorting the names orders them.
stepModels = sorted(
  filename
  for filename in os.listdir(modelFolder)
  if os.path.isfile(os.path.join(modelFolder, filename))
  and re.match(r'.+-[0-9]{6}$', os.path.splitext(filename)[0])
)

# The final save is named after --output_name ('trained') plus the format
# chosen in the Generate Command cell; list it only if training produced it.
finalModel = f'trained.{outputFormat}'
if os.path.isfile(os.path.join(modelFolder, finalModel)):
  stepModels.append(finalModel)
else:
  print(f'warning: {finalModel} not found in {modelFolder}')
print(stepModels)

In [None]:
# @title Render sample
import os

import torch
from diffusers import StableDiffusionXLPipeline

prompt = 'SinansWoche sitting on a chair' # @param {type:"string"}
seed = 1 # @param {type:"integer"}

pipe = StableDiffusionXLPipeline.from_pretrained(
  'stabilityai/stable-diffusion-xl-base-1.0',
)
# Apply the trained weights. Without them the sample shows only the base model,
# which does not reflect the training run (the prompt's trigger word matches
# the --metadata_title used for training).
# NOTE(review): assumes a networks.lora output saved as safetensors; LyCORIS
# or ckpt outputs may need a different loading path — confirm.
pipe.load_lora_weights(modelFolder, weight_name='trained.safetensors')
pipe.to('cuda')
pipe.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
image = pipe(
  prompt=prompt,
  guidance_scale=6,
  num_inference_steps=100,
  output_type='pil',
  # Fixed seed so samples from different training runs are comparable.
  generator=torch.Generator(device='cuda').manual_seed(seed),
).images[0]
os.makedirs('samples', exist_ok=True)
image.save('samples/sample.png')

In [None]:
# @title Display image
# Show the sample saved by the rendering cell inline; as the cell's last
# expression, the Image object is rendered by IPython's rich display.
from IPython.display import Image
Image('samples/sample.png')