In [None]:
# @title Check environment
# Print runtime diagnostics: environment variables, Python version, installed
# packages, GPU status (nvidia-smi) and the CUDA compiler version (nvcc).
# NOTE(review): `env` writes every environment variable into the cell output;
# check the saved output for tokens/secrets before sharing the notebook.
!env | sort
!python --version
!pip list
!nvidia-smi
!nvcc --version

In [None]:
# @title Clone
!git clone https://github.com/Jaid/kohya-gradio.git
%cd kohya-gradio

In [None]:
# Upgrade pip, then install the repository's Colab requirements. This assumes
# the working directory is the cloned repository (set by the clone cell's %cd).
# The extra indexes are the PyTorch cu121 (CUDA 12.1) wheel index and NVIDIA's
# package index.
# NOTE(review): `!python -m pip` uses the `python` on PATH while `%pip` targets
# the kernel's interpreter; presumably the same on Colab — confirm elsewhere.
!python -m pip --disable-pip-version-check install --upgrade pip
%pip install --disable-pip-version-check --requirement requirements_colab.txt --extra-index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://pypi.nvidia.com


In [None]:
# Create a `private` directory in the current working directory (the cloned
# repository, assuming the clone cell's %cd ran first).
# NOTE(review): no active code visible in this notebook reads or writes this
# directory — confirm it is still needed.
%mkdir --parents private
import requests
def download(url, target, chunk_size=1 << 20, timeout=60):
  """Download ``url`` into the local file ``target``, printing both first.

  The response body is streamed to disk in ``chunk_size``-byte pieces
  instead of being held in memory in full. This matters for multi-hundred-MB
  model files. Data is written to ``target + '.part'``, which is renamed over
  ``target`` only after the transfer finished, so a failed download never
  leaves a truncated file at ``target``.

  Args:
    url: HTTP(S) URL to fetch; redirects are followed.
    target: destination file path; overwritten if it exists.
    chunk_size: bytes per write while streaming.
    timeout: seconds passed to requests as connect/read timeout (a limit per
      wait, not on the total transfer time).

  Raises:
    requests.HTTPError: for a non-success status code.
    requests.RequestException: for connection failures or timeouts.
  """
  import os

  print(f"{url} → {target}")
  partial = f"{target}.part"
  with requests.get(url, allow_redirects=True, stream=True, timeout=timeout) as response:
    response.raise_for_status()
    with open(partial, 'wb') as output_file:
      for chunk in response.iter_content(chunk_size=chunk_size):
        output_file.write(chunk)
  # Atomic on the same filesystem: target is either the old file or complete.
  os.replace(partial, target)
# Colab form field: URL of the VAE weights. They are saved to
# /content/vae.safetensors, the path the training cell passes as --vae.
# Re-running the cell downloads the file again (there is no existing-file check).
vaeLocation = 'https://huggingface.co/stabilityai/sdxl-vae/resolve/main/sdxl_vae.safetensors' # @param {type:"string"}
download(vaeLocation, '/content/vae.safetensors')

In [None]:
# @title Check CUDA support in PyTorch
# Show the installed torch package and whether it can see a CUDA device.
!pip show torch
import torch
print(torch.cuda.is_available())
if torch.cuda.is_available():
    # NOTE(review): querying the device name likely initializes CUDA in this
    # kernel process, which may keep some VRAM reserved while training runs in
    # a separate `accelerate launch` process — confirm if VRAM is tight.
    print(torch.cuda.get_device_name(0))
# Despite the title, also report the GPUs TensorFlow can see.
# The import raises ImportError if TensorFlow is not installed.
!pip show tensorflow
import tensorflow as tf
print(tf.config.list_physical_devices('GPU'))

In [None]:
!accelerate config
!accelerate test
import accelerate
print(accelerate.utils.get_max_memory())
print(accelerate.utils.is_bf16_available())
print(accelerate.utils.is_xla_available())
print(accelerate.utils.is_deepspeed_available())
print(accelerate.utils.is_cuda_available())
print(accelerate.utils.is_safetensors_available())
print(accelerate.utils.is_tpu_available())

: 

In [None]:
import math

dim = 64 # @param {type:"integer"}
imageCount = 4 # @param {type:"integer"}
epochs = 20 # @param {type:"integer"}
# NOTE(review): equals the real step count only with batch size 1 and a single
# repeat per image — confirm against the dataset layout under img/.
steps = imageCount * epochs
alpha = math.trunc(dim / 2)
warmupStepsPercent = 0 # @param {type:"integer"}
learningRate = 0.0000004 # @param {type:"number"}
fullPrecision = False # @param {type:"boolean"}
saveVram = True # @param {type:"boolean"}
wandbKey = '' # @param {type:"string"}
outputFormat = 'safetensors' # @param ["safetensors", "checkpoint"]
alwaysSave = True # @param {type:"boolean"}
useLycoris = True # @param {type:"boolean"}
adaptiveOptimizer = False # @param {type:"boolean"}
sample = False # @param {type:"boolean"}

# Directory that receives the trained model, the logs and the sample prompts.
outputRoot = '/content/out'

# Warmup length as a percentage of the total step count (0 means no warmup).
warmupSteps = math.floor(steps * warmupStepsPercent / 100) if warmupStepsPercent else 0


def saveModelFormat(outputFormat):
  """Map the form's output format onto a value accepted by --save_model_as.

  sd-scripts accepts ckpt, pt or safetensors. The form's "checkpoint" choice
  is translated to "ckpt"; every other value is passed through unchanged.
  """
  return {'checkpoint': 'ckpt'}.get(outputFormat, outputFormat)


def optimizerArguments(adaptiveOptimizer):
  """Return the optimizer flags: Adafactor when adaptive, AdamW otherwise.

  Only --optimizer_type selects the optimizer. A bare --optimizer flag is not
  an sd-scripts option. Because it is a prefix of both --optimizer_type and
  --optimizer_args, argparse aborts with "ambiguous option".
  """
  if adaptiveOptimizer:
    return [
      '--optimizer_type',
      'adafactor',
      '--optimizer_args',
      'scale_parameter=False',
      'relative_step=False',
      'warmup_init=False',
    ]
  # saveVram does not change the optimizer: both settings use plain AdamW
  # ('adamw8bit' would be the lower-memory choice).
  return ['--optimizer_type', 'adamw']


def schedulerArguments(learningRate, warmupSteps, adaptiveOptimizer):
  """Return the learning-rate and LR-scheduler flags.

  A falsy learningRate selects the "adafactor" scheduler, which sd-scripts
  only accepts together with the Adafactor optimizer. That pairing is checked
  here, so a bad combination fails immediately instead of after the models
  are loaded.

  Any other learningRate is always passed as --learning_rate. It is combined
  with a constant schedule, preceded by ``warmupSteps`` warmup steps when
  ``warmupSteps`` is non-zero.

  Raises:
    ValueError: learningRate is 0 while adaptiveOptimizer is False.
  """
  if not learningRate:
    if not adaptiveOptimizer:
      raise ValueError(
        'learningRate = 0 selects the adafactor LR scheduler, '
        'which requires adaptiveOptimizer = True'
      )
    return ['--lr_scheduler', 'adafactor']
  if warmupSteps:
    schedule = ['--lr_scheduler', 'constant_with_warmup', '--lr_warmup_steps', warmupSteps]
  else:
    schedule = ['--lr_scheduler', 'constant']
  return ['--learning_rate', learningRate, *schedule]


# Extra flags for `accelerate launch` itself (none by default).
accelerateArguments = []

# Flags for sdxl_train_network.py, written as flag/value pairs.
launchArguments = [
  '--pretrained_model_name_or_path', 'stabilityai/stable-diffusion-xl-base-1.0',
  '--vae', '/content/vae.safetensors',
  # Relative to the current directory (assumed to be the cloned repository).
  '--train_data_dir', 'img',
  '--output_dir', f'{outputRoot}/model',
  '--logging_dir', f'{outputRoot}/log',
  '--resolution', '1024,1024',
  '--save_model_as', saveModelFormat(outputFormat),
  # NOTE(review): these fixed per-module rates are passed alongside the
  # learningRate form value — confirm which one sd-scripts applies.
  '--text_encoder_lr', '0.0001',
  '--unet_lr', '0.0001',
  '--network_dim', dim,
  '--network_alpha', alpha,
  '--output_name', 'trained',
  '--max_train_steps', steps,  # alternative: '--max_train_epochs', 10,
  '--no_half_vae',
  '--caption_extension', '.txt',
  '--cache_latents',
]
if useLycoris:
  launchArguments += [
    '--network_module', 'lycoris.kohya',
    '--network_args',
    'preset=full',
    'algo=full',
    'rank_dropout=0',
    'module_dropout=0',
    'use_tucker=False',
    'use_scalar=False',
    'rank_dropout_scale=True',
    'train_norm=True',
  ]
else:
  launchArguments += ['--network_module', 'networks.lora']
launchArguments += optimizerArguments(adaptiveOptimizer)
if alwaysSave:
  launchArguments += ['--save_every_n_epochs', 1]
launchArguments += schedulerArguments(learningRate, warmupSteps, adaptiveOptimizer)
if saveVram:
  launchArguments += ['--xformers']
if fullPrecision:
  launchArguments += ['--mixed_precision', 'no', '--save_precision', 'float']
else:
  # '--full_bf16' could be added here as well.
  launchArguments += ['--mixed_precision', 'bf16', '--save_precision', 'bf16']
if sample:
  # NOTE(review): nothing in this notebook writes the prompt file; it
  # presumably has to exist before training starts.
  launchArguments += [
    '--sample_sampler', 'euler_a',
    '--sample_prompts', f'{outputRoot}/model/sample/prompt.txt',
    '--sample_every_n_epochs', 1,
  ]
if wandbKey:
  launchArguments += ['--log_with', 'wandb', '--wandb_api_key', wandbKey]
launchArguments += ['--metadata_title', 'SinansWoche']

print(f"accelerate launch {accelerateArguments} sdxl_train_network.py {launchArguments}")
!accelerate launch {accelerateArguments} sdxl_train_network.py {launchArguments}