<a href="https://colab.research.google.com/github/entmike/disco-diffusion-1/blob/main/Simplified_Disco_Diffusion_YAML.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🖼️ Simplified Disco Diffusion (YAML)

[![Discord](https://badgen.net/badge/icon/discord?icon=discord&label)](https://discord.gg/Pmy3wFKbna) [![Maintenance](https://img.shields.io/badge/Maintained%3F-yes-green.svg)](https://github.com/entmike/disco-diffusion-1) ![Terminal](https://badgen.net/badge/icon/terminal?icon=terminal&label) [![Issues](https://img.shields.io/github/issues/entmike/disco-diffusion-1)](https://github.com/entmike/disco-diffusion-1/issues) ![Stars](https://img.shields.io/github/stars/entmike/disco-diffusion-1) ![Commits](https://img.shields.io/github/commit-activity/w/entmike/disco-diffusion-1)

## 💡 About

**Running this cell is optional.**

Run this cell to view the credits, help, change log, and more.

In [None]:
import requests
from IPython.display import Markdown as md
md(requests.get('https://raw.githubusercontent.com/entmike/disco-diffusion-1/main/NOTEBOOK-README.md').text)

## 🌲 Set Up Environment

**Running this cell is required.**

Expand this cell to set where Disco Diffusion is installed and which Git repository and branch it is pulled from.
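
When `check_for_updates` is enabled, the setup cell syncs the checkout with the remote branch by running four Git commands in a fixed order. The sketch below builds the same commands as explicit argument lists. It runs them with `cwd=` instead of `os.chdir`, and it uses `check=True` so that a failing step stops the cell instead of only printing output. The helper names `git_sync_commands` and `sync_repo` are hypothetical and are not part of the repository.

```python
import subprocess
from typing import List


def git_sync_commands(branch: str) -> List[List[str]]:
    """Return the setup cell's update sequence as argv lists (hypothetical helper)."""
    return [
        ["git", "clean", "-df"],            # delete untracked, non-ignored files and folders
        ["git", "checkout", branch],        # switch to the requested branch
        ["git", "reset", "--hard"],         # discard local edits to tracked files
        ["git", "pull", "origin", branch],  # fetch and merge the remote branch
    ]


def sync_repo(repo_dir: str, branch: str) -> None:
    """Run the sync sequence inside repo_dir, stopping at the first failing command."""
    for cmd in git_sync_commands(branch):
        result = subprocess.run(cmd, cwd=repo_dir, capture_output=True, text=True, check=True)
        print(result.stdout)
```

For example, `sync_repo(dd_root, branch)` could be called once `dd_root` and `branch` are set as in the setup cell. Note that, like the original sequence, this deletes untracked files under the repository folder.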

In [None]:
import os, sys
import subprocess, torch

#@markdown Leave these as defaults unless you know what you are doing.

use_google_drive = True #@param {type:"boolean"}
save_models = True #@param {type:"boolean"}
check_for_updates = True #@param {type:"boolean"}
content_root = '/content'
repo = "https://github.com/entmike/disco-diffusion-1" #@param {type:"string"}
branch = "main" #@param {type:"string"}
cwd = os.path.abspath('.')
is_local=True

if is_local:
  content_root = cwd
print(f'Current directory: {cwd}')

# Detect Google Colab before using any Colab-only modules
try:
  from google.colab import drive
  print("Google Colab detected.")
  is_colab = True
except ImportError:
  print("Google Colab not detected.")
  is_colab = False

# Google Drive can only be mounted from within Colab
if use_google_drive and is_colab:
  if not os.path.isdir('/content/gdrive'):
    print('📁 Mounting Google Drive.  Please accept any confirmation screens.')
    drive.mount('/content/gdrive/')
  else:
    print('📁 Google Drive already mounted.')
  content_root = '/content/gdrive/MyDrive'
elif use_google_drive:
  print('⚠️ use_google_drive is set, but Google Drive can only be mounted in Colab; skipping.')

dd_root = f'{content_root}/disco-diffusion-1'

print(f'✅ Disco Diffusion root path will be "{dd_root}"')

root_path = dd_root

# Clone the repo (on the chosen branch) directly into dd_root if it is not there yet
if not os.path.isdir(dd_root):
  print(f"Cloning repo '{repo}' (branch '{branch}') into '{dd_root}'...")
  subprocess.run(['git', 'clone', '--branch', branch, repo, dd_root], check=True)

os.chdir(dd_root)
if check_for_updates:
  # Sync with the remote branch. Warning: `git clean -df` deletes untracked
  # (non-ignored) files and folders under dd_root, and `git reset --hard`
  # discards local edits to tracked files.
  print('📄 Pulling updates from GitHub...')
  for cmd in ['git clean -df', f'git checkout {branch}', 'git reset --hard', f'git pull origin {branch}']:
    gitresults = subprocess.run(cmd.split(' '), stdout=subprocess.PIPE).stdout.decode("utf-8")
    print(gitresults)
else:
  print("⚠️ Skipping checking for Git updates")
# In Colab, upgrade pyyaml and install the pip requirements (plus pytorch3d if missing)
if is_colab:
    print('📦 Upgrading pyyaml...')
    subprocess.run('pip install --upgrade pyyaml --quiet'.split(' '), stdout=subprocess.PIPE).stdout.decode("utf-8")
    print('📦 Installing pip requirements...')
    subprocess.run('pip install -r colab-requirements.txt --quiet'.split(' '), stdout=subprocess.PIPE).stdout.decode("utf-8")
    need_pytorch3d = False
    try:
        import pytorch3d
    except ModuleNotFoundError:
        need_pytorch3d = True
    
    if need_pytorch3d:
        print(f'📦 Installing pytorch3d requirements...')
        if torch.__version__.startswith("1.11.") and sys.platform.startswith("linux"):
            # We try to install PyTorch3D via a released wheel.
            pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
            version_str="".join([
                f"py3{sys.version_info.minor}_cu",
                torch.version.cuda.replace(".",""),
                f"_pyt{pyt_version_str}"
            ])
            subprocess.run(f'pip install fvcore iopath'.split(' '), stdout=subprocess.PIPE).stdout.decode("utf-8")
            subprocess.run(f'pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html'.split(' '), stdout=subprocess.PIPE).stdout.decode("utf-8")
        # else:
        #     # We try to install PyTorch3D from source.
        #     !curl -LO https://github.com/NVIDIA/cub/archive/1.10.0.tar.gz
        #     !tar xzf 1.10.0.tar.gz
        #     os.environ["CUB_HOME"] = os.getcwd() + "/cub-1.10.0"
        #     !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'


# Use the Disco Diffusion root (already the working directory) as the project directory
PROJECT_DIR = dd_root

# Import DD helper modules
sys.path.append(PROJECT_DIR)
import dd, dd_args

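# Let duplicate OpenMP runtimes load instead of aborting with "OMP: Error #15"
# (a common, if not strictly safe, workaround)
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"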
# import warnings
# warnings.filterwarnings("ignore", category=UserWarning)

# print(dd.is_in_notebook())
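
The pytorch3d step above only runs for PyTorch 1.11 on Linux, because it downloads a prebuilt wheel from a folder whose name encodes the Python minor version, the CUDA version, and the PyTorch version. The sketch below reproduces that naming as a standalone function. `pytorch3d_wheel_tag` is a hypothetical helper that takes the values as arguments instead of reading `sys.version_info`, `torch.version.cuda`, and `torch.__version__`, and the inputs shown are example values only.

```python
def pytorch3d_wheel_tag(py_minor: int, cuda_version: str, torch_version: str) -> str:
    """Build the wheel folder name the same way the setup cell does (hypothetical helper).

    py_minor      -- Python minor version, e.g. 7 for Python 3.7
    cuda_version  -- CUDA version string, e.g. "11.3"
    torch_version -- PyTorch version string, e.g. "1.11.0+cu113"
    """
    # Drop any local build suffix ("+cu113") and the dots: "1.11.0" -> "1110"
    pyt_version_str = torch_version.split("+")[0].replace(".", "")
    return f"py3{py_minor}_cu{cuda_version.replace('.', '')}_pyt{pyt_version_str}"


tag = pytorch3d_wheel_tag(7, "11.3", "1.11.0+cu113")
url = f"https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{tag}/download.html"
print(url)  # URL ending in .../py37_cu113_pyt1110/download.html
```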

## 🏃‍♂️ Start Job

**Press Run to Start Job!**

🪄 **TIP:** If your `multipliers` or `modifiers` produce multiple jobs, you can press the notebook's 'Stop' button to skip the current job and move on to the next. To cancel all of the remaining jobs, it is usually quicker to reset the runtime than to press 'Stop' once per job.
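
In Jupyter and Colab, the 'Stop' button interrupts the kernel, which raises `KeyboardInterrupt` in the running code. Skipping only the current job fits a common pattern where the job loop catches that exception once per job. The minimal sketch below assumes the loop works this way. `run_jobs` and `fake_job` are hypothetical names, not Disco Diffusion's own functions.

```python
from typing import Callable, Iterable, List


def run_jobs(jobs: Iterable[dict], run_job: Callable[[dict], None]) -> List[dict]:
    """Run each job in turn; an interrupt skips only the current job.

    Returns the list of jobs that were skipped.
    """
    skipped = []
    for job in jobs:
        try:
            run_job(job)
        except KeyboardInterrupt:
            # The notebook's Stop button raises KeyboardInterrupt in running code
            print(f"⏭️ Skipping job {job}")
            skipped.append(job)
    return skipped


def fake_job(job: dict) -> None:
    if job["seed"] == 2:
        raise KeyboardInterrupt  # simulate pressing Stop during this job
    print(f"Finished job {job}")


skipped = run_jobs([{"seed": 1}, {"seed": 2}, {"seed": 3}], fake_job)
```

With this design, each press of 'Stop' only ends one job, which is why resetting the runtime is the quicker way to cancel a long queue.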

In [None]:
#@title  { display-mode: "form" }
from pydotted import pydot
args = pydot({})
#@markdown ## Specify your YAML config file
args.config_file = "/content/gdrive/MyDrive/disco-diffusion-1/examples/configs/explore.yaml" #@param {type: "string"}
#@markdown ## Parameter overriding
#@markdown 💡 You can also override any other parameters here manually, such as these examples:
# args.db = "/content/gdrive/MyDrive/disco-diffusion-1/disco.db" #@param {type: "string"}
args.set_seed = 8675309 #@param {type: "number"}

# Load the defaults and the YAML config, together with the overrides set above
pargs = dd_args.arg_configuration_loader(args)

# Setup folders
folders = dd.setupFolders(is_colab=dd.detectColab(), PROJECT_DIR=PROJECT_DIR, pargs=pargs)

# Load Models
dd.loadModels(folders)

# Report System Details
dd.systemDetails(pargs)

# Get CUDA Device
device = dd.getDevice(pargs)

dd.start_run(pargs=pargs, folders=folders, device=device, is_colab=dd.detectColab())
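
The override fields in this cell are set on `args` before it is passed to `dd_args.arg_configuration_loader`. The sketch below shows the layering that the form implies, assuming that values set explicitly in the notebook take precedence over the YAML file, which in turn takes precedence over built-in defaults. `resolve_settings` is hypothetical. Plain dictionaries stand in for the parsed YAML and the `pydot` object, and the keys and values are illustrative.

```python
from typing import Any, Dict


def resolve_settings(defaults: Dict[str, Any], yaml_config: Dict[str, Any],
                     overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Layer settings so that later sources win: defaults < YAML config < overrides."""
    resolved = dict(defaults)     # start from a copy of the defaults
    resolved.update(yaml_config)  # YAML values replace defaults
    resolved.update(overrides)    # values set in the notebook form replace both
    return resolved


settings = resolve_settings(
    defaults={"set_seed": None, "steps": 250},
    yaml_config={"steps": 150},
    overrides={"set_seed": 8675309},
)
print(settings)  # {'set_seed': 8675309, 'steps': 150}
```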