<a href="https://colab.research.google.com/github/un1tz3r0/anythingdiffusion/blob/main/Anything_Diffusion_Fine_Tuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

![visitors](https://visitor-badge.glitch.me/badge?page_id=fine_tuning_your_own_diffusion_model_using_clip_retrieval_ipynb)

# [AnythingDiffusion](https://github.com/un1tz3r0/anythingdiffusion/)
#### by [un1tz3r0](https://linktr.ee/un1tz3r0), based on a notebook by [Alex Spirin](https://twitter.com/devdef).

A simple colab to fine-tune your very own diffusion models on images from CLIP-retrieval which are nearby a text prompt, and automatically resume training from the last checkpoint.


# Configure

Needs 16gb GPU RAM

Works in colab pro and on kaggle 

In [1]:
#@markdown This is the name of the subdirectory where your custom model snapshots and logs will be dumped during the training:

custom_model_name = "spiraldiffusion" #@param {type:"string"}

#@markdown Everything, including the dataset and the models and model progress and other training output will be on your drive in <tt>Disco_Diffusion/Fine_Tuning/<error>custom_model_name</error>/</tt>


# Setup

In [2]:
#@markdown Connect with google drive

from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
#@markdown Download and install guided diffusion

%cd /content
!git clone https://github.com/Sxela/guided-diffusion-sxela
%cd /content/guided-diffusion-sxela
!pip install -e .

/content
Cloning into 'guided-diffusion-sxela'...
remote: Enumerating objects: 154, done.[K
remote: Counting objects: 100% (90/90), done.[K
remote: Compressing objects: 100% (47/47), done.[K
remote: Total 154 (delta 60), reused 48 (delta 43), pack-reused 64[K
Receiving objects: 100% (154/154), 87.56 KiB | 6.74 MiB/s, done.
Resolving deltas: 100% (71/71), done.
/content/guided-diffusion-sxela
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Obtaining file:///content/guided-diffusion-sxela
Collecting blobfile>=1.0.5
  Downloading blobfile-1.3.1-py3-none-any.whl (70 kB)
[K     |████████████████████████████████| 70 kB 5.2 MB/s 
Collecting xmltodict~=0.12.0
  Downloading xmltodict-0.12.0-py2.py3-none-any.whl (9.2 kB)
Collecting urllib3~=1.25
  Downloading urllib3-1.26.11-py2.py3-none-any.whl (139 kB)
[K     |████████████████████████████████| 139 kB 27.2 MB/s 
[?25hCollecting pycryptodomex~=3.8
  Downloading pycryptodomex-3.15.0-cp35-a

In [4]:
#@markdown Define some helpers and create directories

import pathlib, subprocess, os, sys, ipykernel

try:
  import google.colab
  is_colab = True
except:
  is_colab = False

def createPath(filepath):
    os.makedirs(filepath, exist_ok=True)

def createParent(filepath):
    os.makedirs(os.path.dirname(os.path.abspath(filepath)), exist_ok=True)

def pipi(*modulestrs):
    res = subprocess.run(['pip', 'install', *modulestrs], stdout=subprocess.PIPE).stdout.decode('utf-8')
    print(res)

def wget(url, outputdir=None, filename=None):
    if outputdir != None:
      res = subprocess.run(['wget', url, '-P', f'{outputdir}'], stdout=subprocess.PIPE).stdout.decode('utf-8')
    elif filename != None:
      res = subprocess.run(['wget', url, '-O', f'{filename}'], stdout=subprocess.PIPE).stdout.decode('utf-8')
    else:
      res = subprocess.run(['wget', url], stdout=subprocess.PIPE).stdout.decode('utf-8')
    print(res)

google_drive = True


if is_colab:
    if google_drive is True:
        from google.colab import drive
        drive.mount('/content/drive')
        rootPath = '/content/drive/MyDrive/Disco_Diffusion'
    else:
        rootPath = '/content'
else:
    rootPath = os.getcwd()

def createPath(filepath):
    os.makedirs(filepath, exist_ok=True)

def createParent(filepath):
    os.makedirs(os.path.dirname(os.path.abspath(filepath)), exist_ok=True)

# set up some folders based on the custom_model_name in the form for this cell...

finetuningRoot = f"{rootPath}/Fine_Tuning/{custom_model_name}"
createPath(f"{finetuningRoot}")

datasetRoot = f"{finetuningRoot}/dataset"
createPath(f"{datasetRoot}")

trainingRoot = f"{finetuningRoot}/training"
createPath(f"{trainingRoot}")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Get images to train on using CLIP-retrieval

Generate a dataset from images retrieved by proximity to a text prompt in the CLIP latent space.



In [6]:
# create the dataset using clip retrieval and aiohttp/aiomultiprocessing to download in paralell
import shlex

dataset_text_prompt = "spiral" #@param {type: "string"}
dataset_fetch_size = 5000 #@param {type: "integer"}
dataset_crop_count = 6000 #@param {type: "integer"}
force_existing_dataset = True #@param {type: "boolean"}

datasetPath = f'{datasetRoot}'
if pathlib.Path(datasetPath).exists():
  print(f"Dataset directory: {datasetPath}\nDataset already exists, skipping clip-retrieval...")
else:
  createPath(datasetPath)

if (not pathlib.Path(datasetPath).exists()) or force_existing_dataset: # or len(list(pathlib.Path(datasetPath).iterdir())) < 3':
  print(f"Creating new dataset from clip-retrieval for the prompt: '{dataset_text_prompt}'")
  try:
    datasetOutPath=datasetPath+"/out"
    createPath(datasetOutPath)
    createPath(datasetPath+"/crop")

    pipi("click", "clip-retrieval", "img2dataset", "aiomultiprocess", "aiohttp", "aiofile")
    wget("https://gist.githubusercontent.com/un1tz3r0/a18ba5cf48228ca5cabc58d1d556ad0b/raw/2c5f96e27ee077c8e218925380829a72770b0dd2/clipfetch.py", filename="clipfetch.py")
    wget("https://raw.githubusercontent.com/un1tz3r0/pixelscapes-dataset/main/randomcrops.py", filename="randomcrops.py")

    dataset_text_prompt_q = shlex.quote(dataset_text_prompt)
    datasetOutPath_q = shlex.quote(datasetOutPath)
    !python3 clipfetch.py $dataset_text_prompt_q $datasetOutPath_q --count $dataset_fetch_size --timeout 5 --paralell 25

    print(f"Generating {dataset_crop_count} cropped squares from source images in {datasetOutPath}...")

    import sys
    sys.path.append(pathlib.Path("./").absolute())
    import randomcrops
    randomcrops.randomcrops(datasetPath+"/out", datasetPath+"/crop", dataset_crop_count, 256, weighting=0.0, withclasses=False, statusinterval=50)

    datasetPath = datasetRoot+"/crop"
    print(f"Done creating dataset from clip-retrieval in: {datasetPath}")
    #touch(datasetPath+"/.completed")
  except Exception as err:
    import traceback as tb
    tb.print_exc(err)
    raise err
    

Dataset directory: /content/drive/MyDrive/Disco_Diffusion/Fine_Tuning/spiraldiffusion/dataset
Dataset already exists, skipping clip-retrieval...
Creating new dataset from clip-retrieval for the prompt: 'spiral'
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting clip-retrieval
  Downloading clip_retrieval-2.34.2-py3-none-any.whl (338 kB)
Collecting img2dataset
  Downloading img2dataset-1.32.0-py3-none-any.whl (34 kB)
Collecting aiomultiprocess
  Downloading aiomultiprocess-0.9.0-py3-none-any.whl (17 kB)
Collecting aiofile
  Downloading aiofile-3.8.1.tar.gz (18 kB)
Collecting fsspec==2022.1.0
  Downloading fsspec-2022.1.0-py3-none-any.whl (133 kB)
Collecting flask-cors<4,>=3.0.10
  Downloading Flask_Cors-3.0.10-py2.py3-none-any.whl (14 kB)
Collecting requests<3,>=2.27.1
  Downloading requests-2.28.1-py3-none-any.whl (62 kB)
Collecting clip-anytorch<3,>=2.3.1
  Downloading clip_anytorch-2.4.0-py3-none-any.whl (1.4 MB)
Collecting f

TypeError: ignored

# <big><big>Fine Tune</big></big>

This will run almost forever, but you should start checking your results at around ~50k iterations. Good results begin to appear at 100-200k iterations, depending on your dataset.


In [None]:
#@markdown # Do the run...

import shlex

def latest_checkpoint(checkpoint_path, default_model, default_model_url):
  import pathlib, os
  try:
    def kf(f):
      return f.lstat().st_mtime
    f = str(list(sorted(list(pathlib.Path(checkpoint_path).glob("ema_0.9999_*.pt")), key=kf))[-1])
    print(f"Resuming from latest checkpoint found: {f}")
    return f
  except Exception as err:
    print(f"Error finding latest checkpoint in {checkpoint_path}: {err}")
    print(f"Resuming from default pretrained model: {default_model}")
    if not pathlib.Path(default_model).exists():
      print(f"Downloading default pretrained model from: {default_model_url}")
      wget(default_model_url, filename=default_model)
      print(f"Done!")
    else:
      print(f"Default pretrained model found at: {default_model}")
      print(f"Skipping model download.")
    return default_model


#!wget https://openaipublic.blob.core.windows.net/diffusion/march-2021/lsun_uncond_100M_1200K_bs128.pt
MODEL_FLAGS="--image_size 256 --num_channels 128 --num_res_blocks 2 --num_heads 1 --learn_sigma True --use_scale_shift_norm False --attention_resolutions 16"
DIFFUSION_FLAGS="--diffusion_steps 1000 --noise_schedule linear --rescale_learned_sigmas False --rescale_timesteps False --use_scale_shift_norm False"
RESUME_CHECKPOINT=latest_checkpoint(trainingRoot, f"{trainingRoot}/lsun_uncond_100M_1200K_bs128.pt", 'https://openaipublic.blob.core.windows.net/diffusion/march-2021/lsun_uncond_100M_1200K_bs128.pt')
TRAIN_FLAGS=f"--lr 2e-5 --batch_size 4 --save_interval 1000 --log_interval 50  --resume_checkpoint {shlex.quote(RESUME_CHECKPOINT)}"
DATASET_PATH=shlex.quote(datasetPath) #change to point to your dataset path 
%cd /content/guided-diffusion-sxela
!OPENAI_LOGDIR=$trainingRoot python scripts/image_train.py --data_dir $DATASET_PATH $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS

NameError: ignored