<a href="https://colab.research.google.com/github/AIManifest/deforum-stable-diffusion/blob/main/Deforum_Stable_Diffusion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# $ \color{blue} {\large \textsf{Deforum Stable Diffusion v0.7}}$
[Stable Diffusion](https://github.com/CompVis/stable-diffusion) by Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Björn Ommer and the [Stability.ai](https://stability.ai/) Team. [K Diffusion](https://github.com/crowsonkb/k-diffusion) by [Katherine Crowson](https://twitter.com/RiversHaveWings). Notebook by [deforum](https://discord.gg/upmXXsrwZc)

[Quick Guide](https://docs.google.com/document/d/1RrQv7FntzOuLg4ohjRZPVL7iptIyBhwwbcEYEW2OfcI/edit?usp=sharing) to Deforum v0.7

In [None]:
#@title $ \color{blue} {\large \textsf{Environment Setup / NVIDIA GPU}}$
import subprocess, os, sys
from tqdm import tqdm
from IPython import display
# Ask `nvidia-smi` for the GPU name plus total/free memory (CSV, no header) and capture stdout.
# NOTE(review): subprocess.run raises FileNotFoundError when `nvidia-smi` is not on PATH —
# confirm a GPU runtime is selected before running this cell.
sub_p_res = subprocess.run(['nvidia-smi', '--query-gpu=name,memory.total,memory.free', '--format=csv,noheader'], stdout=subprocess.PIPE).stdout.decode('utf-8')
# "\033[92m" switches the terminal colour to green; [:-1] drops the final character of the
# captured output (presumably the trailing newline).
print(f"\033[92m{sub_p_res[:-1]}")
import subprocess, time, gc, os, sys
!pip install -qq -U kora
def setup_environment():
    """Prepare the runtime: install everything on Colab, or only extend ``sys.path`` elsewhere.

    On Google Colab (detected from the string form of ``get_ipython()``) this
    pip-installs the Python dependencies, clones the deforum / Practical-RIFE /
    Real-ESRGAN repositories, truncates ``k_diffusion/__init__.py`` to an empty
    file, adds the cloned repo to ``sys.path`` and, when ``use_xformers_for_colab``
    is set, installs triton and xformers. Anywhere else it only appends ``'src'``
    to ``sys.path``.

    Returns:
        None. Progress and the elapsed time are printed.
    """
    start_time = time.time()
    print_subprocess = False
    use_xformers_for_colab = True

    def run_command(process):
        # Run one command, echo its stdout on request and report a non-zero exit
        # status instead of ignoring it silently.
        completed = subprocess.run(process, stdout=subprocess.PIPE)
        if print_subprocess:
            print(completed.stdout.decode('utf-8'))
        if completed.returncode != 0:
            print(f"\033[91m..command failed with exit code {completed.returncode}: {' '.join(process)}\033[92m")

    try:
        ipy = get_ipython()
    except NameError:
        # get_ipython only exists inside IPython/Jupyter; a plain interpreter raises NameError.
        ipy = 'could not get_ipython'
    if 'google.colab' in str(ipy):
        print("\033[92m..setting up environment")

        # weird hack
        import torch

        all_process = [
            ['pip', 'install', 'facexlib>=0.2.5', 'gfpgan>=1.3.5', 'basicsr>=1.4.2', 'sk-video>=1.1.10', 'omegaconf', 'einops==0.4.1', 'pytorch-lightning==1.7.7', 'torchmetrics', 'transformers', 'safetensors', 'kornia'],
            ['git', 'clone', 'https://github.com/AIManifest/deforum-stable-diffusion.git'],
            ['git', 'clone', 'https://github.com/AIManifest/Practical-RIFE.git'],
            ['git', 'clone', 'https://github.com/AIManifest/Real-ESRGAN.git'],
            ['pip', 'install', '-qq', '--user', 'ffmpeg-python'],
            ['pip', 'install', 'accelerate', 'ftfy', 'jsonmerge', 'matplotlib', 'resize-right', 'timm', 'torchdiffeq','scikit-learn','torchsde','open-clip-torch','numpngw'],
        ]
        for process in tqdm(all_process):
            run_command(process)
        # Opening in 'w' mode and writing '' leaves k_diffusion/__init__.py empty.
        with open('deforum-stable-diffusion/src/k_diffusion/__init__.py', 'w') as f:
            f.write('')
        sys.path.extend([
            'deforum-stable-diffusion/',
            'deforum-stable-diffusion/src',
        ])
        if use_xformers_for_colab:

            print("\033[92m..installing triton and xformers")

            all_process = [['pip', 'install', 'triton==2.0.0.dev20221202', 'xformers==0.0.16']]
            for process in all_process:
                run_command(process)
    else:
        sys.path.extend([
            'src'
        ])
    end_time = time.time()
    print(f"\033[92m..environment set up in {end_time-start_time:.0f} seconds")

# Run the environment setup defined above.
setup_environment()
# NOTE(review): the two statements below repeat what setup_environment() already does on its
# Colab branch. On the non-Colab branch nothing in this cell creates
# 'deforum-stable-diffusion/', so this open() presumably raises FileNotFoundError unless that
# checkout already exists — confirm before running locally.
# Opening in 'w' mode and writing '' leaves k_diffusion/__init__.py empty.
with open('deforum-stable-diffusion/src/k_diffusion/__init__.py', 'w') as f:
  f.write('')
# Put the cloned repo and its src directory on the import path (presumably what makes the
# `helpers.*` imports further down resolvable).
sys.path.extend([
  'deforum-stable-diffusion/',
  'deforum-stable-diffusion/src',
])
import torch
import random
import clip
import yaml
from IPython.display import display
from ipywidgets import widgets
from IPython import display
import glob
from types import SimpleNamespace
from helpers.save_images import get_output_folder
from helpers.settings import load_args
from helpers.render import render_animation, render_input_video, render_image_batch, render_interpolation
from helpers.model_load import make_linear_decode, load_model, get_model_output_paths
from helpers.aesthetics import load_aesthetics_model




[92mTesla T4, 15360 MiB, 15109 MiB


In [None]:
#@title $\color{blue} {\large \textsf{Path Setup}}$

def Root() -> dict:
    """Collect the Colab form settings into a dict and mirror them to ``output.yaml``.

    Every assignment below is a setting; the trailing ``#@param`` comments are the
    Colab form annotations for that value. The function snapshots its local
    variables with ``locals()``, dumps that mapping to ``output.yaml`` in the
    current working directory and returns it.

    Returns:
        dict mapping each setting name defined below to its value.
    """
    models_path = "models" #@param {type:"string"}
    configs_path = "configs" #@param {type:"string"}
    output_path = "outputs" #@param {type:"string"}
    mount_google_drive = True #@param {type:"boolean"}
    models_path_gdrive = "/content/drive/MyDrive/sd/stable-diffusion-webui/models/Stable-diffusion" #@param {type:"string"}
    output_path_gdrive = "/content/drive/MyDrive/AI/StableDiffusion" #@param {type:"string"}

    #@markdown **Model Setup**
    map_location = "cuda" #@param ["cpu", "cuda"]
    model_config = "custom" #@param ["custom","v2-inference.yaml","v2-inference-v.yaml","v1-inference.yaml"]
    model_checkpoint =  "custom" #@param ["custom","v2-1_768-ema-pruned.ckpt","v2-1_512-ema-pruned.ckpt","768-v-ema.ckpt","512-base-ema.ckpt","Protogen_V2.2.ckpt","v1-5-pruned.ckpt","v1-5-pruned-emaonly.ckpt","sd-v1-4-full-ema.ckpt","sd-v1-4.ckpt","sd-v1-3-full-ema.ckpt","sd-v1-3.ckpt","sd-v1-2-full-ema.ckpt","sd-v1-2.ckpt","sd-v1-1-full-ema.ckpt","sd-v1-1.ckpt", "robo-diffusion-v1.ckpt","wd-v1-3-float16.ckpt"]
    custom_config_path = "" #@param {type:"string"}
    custom_checkpoint_path = "" #@param {type:"string"}
    embeddings_dir = "/content/drive/MyDrive/AI/embeddings" #@param{type:'string'}
    hypernetwork_dir = "/content/drive/MyDrive/AI/hypernetworks" #@param{type:'string'}
    data_dir = "/content/drive/MyDrive/sd/stable-diffusion-webui" #@param{type:'string'}
    #@markdown **Don't use yet, still working on it**
    use_xformers = True #@param{type:'boolean'}
    use_sub_quad_attention = False #@param{type:'boolean'}
    use_split_attention_v1 = False #@param{type:'boolean'}
    use_split_cross_attention_forward_invokeAI = False #@param{type:'boolean'}
    use_cross_attention_attnblock_forward = False #@param{type:'boolean'}

    # Must stay after every setting and before any helper local is created: locals() captures
    # whatever local names exist at this point, so renaming or adding a local above changes
    # the keys of the returned dict.
    variables = locals()
    # Persist the same mapping next to the notebook for later inspection.
    with open("output.yaml", "w") as file:
      yaml.dump(variables, file)
    return variables

# Wrap the form settings in an attribute-style namespace, then let the helpers
# resolve the model/output directories and load the model onto its device.
root = SimpleNamespace(**Root())

(root.models_path, root.output_path) = get_model_output_paths(root)
(root.model, root.device) = load_model(root, check_sha256=True, load_on_run_all=True)

In [None]:
#@title $\color{blue} {\large \textsf{Load Textual Inversion!}}$
model = root.model
import os
import sys
import traceback
import inspect
from collections import namedtuple

import torch
import tqdm
import html
import datetime
import csv
import safetensors.torch

import numpy as np
from PIL import Image, PngImagePlugin
from torch.utils.tensorboard import SummaryWriter
# NOTE(review): this hard-coded Google Drive path overwrites whatever root.model_checkpoint
# held before this cell ran — confirm that is intended.
root.model_checkpoint = "/content/drive/MyDrive/sd/stable-diffusion-webui/models/Stable-diffusion/Deliberate_Nova_Dream.ckpt"
# os.path.splitext returns a (stem, extension) pair, so checkpoint_name is a 2-tuple.
checkpoint_name = os.path.splitext(os.path.basename(root.model_checkpoint))

import hashlib
print("..checking sha256")
# Hash the checkpoint in 1 MiB chunks: checkpoints are multi-gigabyte files, so reading
# them in one call can exhaust RAM. This also avoids rebinding the builtin name `bytes`.
_ckpt_sha256 = hashlib.sha256()
with open(root.model_checkpoint, "rb") as f:
  for _chunk in iter(lambda: f.read(1 << 20), b""):
    _ckpt_sha256.update(_chunk)
ckpt_hash = _ckpt_sha256.hexdigest()

#learn_rate_scheduler

import tqdm

def autocast(disable=False):
    """Return the context manager to wrap model calls in.

    Args:
        disable: when true, always return a no-op context.

    Returns:
        ``contextlib.nullcontext()`` when disabled, when the global ``dtype`` is
        ``torch.float32`` or when the global ``precision`` is ``"full"``;
        otherwise ``torch.autocast("cuda")``.
    """
    # Imported here because no cell visible in this notebook imports contextlib at module
    # level, so the bare name would raise NameError.
    import contextlib

    if disable:
        return contextlib.nullcontext()

    # NOTE(review): `dtype` and `precision` are read as globals and are not defined in the
    # visible part of the notebook — confirm they are set before autocast() is called.
    if dtype == torch.float32 or precision == "full":
        return contextlib.nullcontext()

    return torch.autocast("cuda")
    
class LearnScheduleIterator:
    """Iterator over the ``(learn_rate, end_step)`` pairs of a learning-rate schedule."""

    def __init__(self, learn_rate, max_steps, cur_step=0):
        """
        specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000

        Args:
            learn_rate: schedule string as above, or a bare number / numeric string,
                which applies that single rate until ``max_steps``.
            max_steps: every end step is clamped to this value; parsing stops at the
                first entry that reaches or exceeds it.
            cur_step: entries whose end step is not greater than this are skipped.

        Raises:
            Exception: when the schedule cannot be parsed or yields no usable entry.
        """
        # str() lets callers pass a plain number (e.g. 0.005), which the error message
        # below already advertises as valid input.
        pairs = str(learn_rate).split(',')
        self.rates = []
        self.it = 0
        self.maxit = 0
        try:
            for pair in pairs:
                if not pair.strip():
                    continue
                tmp = pair.split(':')
                if len(tmp) == 2:
                    step = int(tmp[1])
                    if step > cur_step:
                        self.rates.append((float(tmp[0]), min(step, max_steps)))
                        self.maxit += 1
                        if step > max_steps:
                            return
                    elif step == -1:
                        # -1 means "until the end of training".
                        self.rates.append((float(tmp[0]), max_steps))
                        self.maxit += 1
                        return
                else:
                    # No "rate:step" form: the rate applies for the rest of training.
                    self.rates.append((float(tmp[0]), max_steps))
                    self.maxit += 1
                    return
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if not self.rates:
                raise ValueError("schedule contains no usable entries")
        except ValueError as err:
            raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.') from err


    def __iter__(self):
        return self

    def __next__(self):
        if self.it < self.maxit:
            self.it += 1
            return self.rates[self.it - 1]
        else:
            raise StopIteration


class LearnRateScheduler:
    """Tracks the active learning rate of a schedule and pushes it into an optimizer."""

    def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):
        """Parse the schedule and activate its first ``(rate, end_step)`` entry."""
        self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)
        first_rate, first_end = next(self.schedules)
        self.learn_rate = first_rate
        self.end_step = first_end
        self.verbose = verbose

        if verbose:
            print(f'Training at rate of {self.learn_rate} until step {self.end_step}')

        self.finished = False

    def step(self, step_number):
        """Advance to the next entry once ``step_number`` reaches the current end step.

        Returns True only when a new entry became active; when the schedule is
        exhausted, ``finished`` is set and False is returned.
        """
        if step_number >= self.end_step:
            try:
                upcoming = next(self.schedules)
            except StopIteration:
                self.finished = True
            else:
                self.learn_rate, self.end_step = upcoming
                return True
        return False

    def apply(self, optimizer, step_number):
        """Write the new rate into every param group when the schedule advanced."""
        if self.step(step_number):
            if self.verbose:
                tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}')

            for group in optimizer.param_groups:
                group['lr'] = self.learn_rate

#dataset
import os
import numpy as np
import PIL
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader, Sampler
from torchvision import transforms
from collections import defaultdict
from random import shuffle, choices

import random
import tqdm
import re

from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

re_numbers_at_start = re.compile(r"^[-\d]+\s*")


class DatasetEntry:
    """Plain record describing one training sample; every field defaults to None."""

    def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None):
        # Source file and the caption text derived for it.
        self.filename, self.filename_text = filename, filename_text
        # Encoded image: the distribution and/or a sample drawn from it.
        self.latent_dist, self.latent_sample = latent_dist, latent_sample
        # Conditioning tensor and the prompt text it was built from.
        self.cond, self.cond_text = cond, cond_text
        self.pixel_values = pixel_values

dataset_filename_join_string= " "
dataset_filename_word_regex = ""
class PersonalizedBase(Dataset):
    """Dataset of images pre-encoded to latents for textual-inversion training.

    Every readable file in ``data_root`` is opened as an RGB image, resized to
    ``(width, height)`` unless ``varsize`` is set, scaled to the range [-1, 1] and
    passed through ``model.encode_first_stage``. The caption comes from a sibling
    ``.txt`` file when one exists, otherwise from the file name. Entries are
    grouped by image size in ``self.groups``.

    Args:
        data_root: directory holding the training images.
        width, height: resize target when ``varsize`` is false.
        repeats: accepted but not used by the code in this class.
        flip_p: probability handed to ``transforms.RandomHorizontalFlip``.
        placeholder_token: replacement for ``[name]`` in template lines.
        model: object providing ``encode_first_stage`` / ``get_first_stage_encoding``.
        cond_model: callable producing the conditioning for a list of prompts
            (only used when ``include_cond`` is set).
        device: device the pixel tensor is moved to before encoding.
        template_file: text file with one prompt template per line.
        include_cond: precompute ``entry.cond`` when prompts are static.
        batch_size, gradient_step: clamped against the dataset length.
        shuffle_tags, tag_drop_out: per-access prompt randomisation settings.
        latent_sampling_method: ``'once'``, ``'deterministic'`` or ``'random'``.
        varsize: keep each image at its original size.

    Raises:
        AssertionError: when the dataset directory is missing/empty or no image loads.
        ValueError: when ``latent_sampling_method`` is not one of the three known values.
    """

    def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False):
        re_word = re.compile(dataset_filename_word_regex) if len(dataset_filename_word_regex) > 0 else None

        self.placeholder_token = placeholder_token

        # Keep the encoder so __getitem__ re-samples with the model this dataset was built
        # from instead of whatever the notebook-global `model` happens to be.
        self.model = model

        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

        self.dataset = []

        with open(template_file, "r") as file:
            lines = [x.strip() for x in file.readlines()]

        self.lines = lines

        assert data_root, 'dataset directory not specified'
        assert os.path.isdir(data_root), "Dataset directory doesn't exist"
        assert os.listdir(data_root), "Dataset directory is empty"

        self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]

        self.shuffle_tags = shuffle_tags
        self.tag_drop_out = tag_drop_out
        groups = defaultdict(list)

        print("Preparing dataset...")
        for path in tqdm.tqdm(self.image_paths):
            # if shared.state.interrupted:
            #     raise Exception("interrupted")
            try:
                image = Image.open(path).convert('RGB')
                if not varsize:
                    image = image.resize((width, height), PIL.Image.BICUBIC)
            except Exception:
                # Deliberate best effort: anything that cannot be opened as an image
                # (e.g. the caption .txt files in the same directory) is skipped.
                continue

            text_filename = os.path.splitext(path)[0] + ".txt"
            filename = os.path.basename(path)

            if os.path.exists(text_filename):
                with open(text_filename, "r", encoding="utf8") as file:
                    filename_text = file.read()
            else:
                # No caption file: derive the caption from the file name, minus leading
                # digits/dashes and optionally filtered through the word regex.
                filename_text = os.path.splitext(filename)[0]
                filename_text = re.sub(re_numbers_at_start, '', filename_text)
                if re_word:
                    tokens = re_word.findall(filename_text)
                    filename_text = (dataset_filename_join_string or "").join(tokens)

            # uint8 0..255 -> float32 -1..1
            npimage = np.array(image).astype(np.uint8)
            npimage = (npimage / 127.5 - 1.0).astype(np.float32)

            # HWC -> CHW
            torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
            latent_sample = None

            with torch.autocast("cuda"):
                latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))

            if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)):
                latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to("cuda")
                latent_sampling_method = "once"
                entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
            elif latent_sampling_method == "deterministic":
                # Works only for DiagonalGaussianDistribution
                latent_dist.std = 0
                latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to("cuda")
                entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
            elif latent_sampling_method == "random":
                entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist)
            else:
                # Previously an unknown value left `entry` unbound and failed a few lines
                # later with an UnboundLocalError.
                raise ValueError(f"Unknown latent_sampling_method: {latent_sampling_method!r}")

            # Prompts are static only when neither tag dropout nor tag shuffling is active;
            # otherwise they are regenerated on every __getitem__.
            if not (self.tag_drop_out != 0 or self.shuffle_tags):
                entry.cond_text = self.create_text(filename_text)

            if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
                with autocast():
                    entry.cond = cond_model([entry.cond_text]).to("cuda").squeeze(0)
            groups[image.size].append(len(self.dataset))
            self.dataset.append(entry)
            del torchdata
            del latent_dist
            del latent_sample

        self.length = len(self.dataset)
        self.groups = list(groups.values())
        assert self.length > 0, "No images have been found in the dataset."
        self.batch_size = min(batch_size, self.length)
        self.gradient_step = min(gradient_step, self.length // self.batch_size)
        self.latent_sampling_method = latent_sampling_method

        if len(groups) > 1:
            print("Buckets:")
            for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]):
                print(f"  {w}x{h}: {len(ids)}")
            print()

    def create_text(self, filename_text):
        """Fill a randomly chosen template line with the (optionally dropped/shuffled) tags."""
        text = random.choice(self.lines)
        tags = filename_text.split(',')
        if self.tag_drop_out != 0:
            tags = [t for t in tags if random.random() > self.tag_drop_out]
        if self.shuffle_tags:
            random.shuffle(tags)
        text = text.replace("[filewords]", ','.join(tags))
        text = text.replace("[name]", self.placeholder_token)
        return text

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        """Return entry ``i``, refreshing its prompt and/or latent sample when configured."""
        entry = self.dataset[i]
        if self.tag_drop_out != 0 or self.shuffle_tags:
            entry.cond_text = self.create_text(entry.filename_text)
        if self.latent_sampling_method == "random":
            entry.latent_sample = self.model.get_first_stage_encoding(entry.latent_dist).to("cuda")
        return entry


class GroupedBatchSampler(Sampler):
    """Batch sampler that only mixes indices belonging to the same group.

    Each group contributes its full batches; the remaining ``n_rand_batches``
    are drawn (with replacement) from a group picked according to ``probs``.
    """

    def __init__(self, data_source: PersonalizedBase, batch_size: int):
        super().__init__(data_source)

        total = len(data_source)
        self.groups = data_source.groups
        n_batch = total // batch_size
        self.len = n_batch
        # Share of the batched samples each group is expected to supply.
        expected = [len(group) / total * n_batch * batch_size for group in data_source.groups]
        self.base = [int(share) // batch_size for share in expected]
        nrb = n_batch - sum(self.base)
        self.n_rand_batches = nrb
        if nrb > 0:
            self.probs = [share % batch_size / nrb / batch_size for share in expected]
        else:
            self.probs = [0 for _ in expected]
        self.batch_size = batch_size

    def __len__(self):
        return self.len

    def __iter__(self):
        size = self.batch_size

        for group in self.groups:
            shuffle(group)

        batches = []
        for group in self.groups:
            for start in range(0, (len(group) // size) * size, size):
                batches.append(group[start:start + size])
        for _ in range(self.n_rand_batches):
            (picked_group,) = choices(self.groups, self.probs)
            batches.append(choices(picked_group, k=size))

        shuffle(batches)

        for batch in batches:
            yield batch


class PersonalizedDataLoader(DataLoader):
    """DataLoader batching through GroupedBatchSampler and collating into BatchLoader objects."""

    def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
        grouped_sampler = GroupedBatchSampler(dataset, batch_size)
        super().__init__(dataset, batch_sampler=grouped_sampler, pin_memory=pin_memory)
        # The "random" sampling method gets the collate variant whose pin_memory is a no-op.
        wants_random = latent_sampling_method == "random"
        self.collate_fn = collate_wrapper_random if wants_random else collate_wrapper


class BatchLoader:
    """Collated batch: prompt texts, conditionings and the stacked latent samples."""

    def __init__(self, data):
        texts = [entry.cond_text for entry in data]
        conds = [entry.cond for entry in data]
        latents = [entry.latent_sample for entry in data]
        self.cond_text = texts
        self.cond = conds
        # Stack along a new leading axis, then drop axis 1 when it has size one.
        self.latent_sample = torch.stack(latents).squeeze(1)

    def pin_memory(self):
        """Pin the latent tensor and return this batch."""
        self.latent_sample = self.latent_sample.pin_memory()
        return self


def collate_wrapper(batch):
    """Collate function producing a BatchLoader."""
    return BatchLoader(batch)


class BatchLoaderRandom(BatchLoader):
    """BatchLoader variant whose pin_memory leaves the tensors untouched."""

    def pin_memory(self):
        return self


def collate_wrapper_random(batch):
    """Collate function producing a BatchLoaderRandom."""
    return BatchLoaderRandom(batch)

# Record type pairing a prompt template's name with the path of its file.
TextualInversionTemplate = namedtuple("TextualInversionTemplate", ["name", "path"])
# Registry of templates; starts empty (nothing in the visible code populates it).
textual_inversion_templates = {}

class Embedding:
    """A textual-inversion embedding: the vector tensor plus bookkeeping metadata."""

    def __init__(self, vec, name, step=None):
        self.vec, self.name, self.step = vec, name, step
        self.shape = None
        self.vectors = 0
        self.cached_checksum = None
        self.sd_checkpoint = self.sd_checkpoint_name = None
        self.optimizer_state_dict = None
        self.filename = None

    def save(self, filename):
        """Write the embedding to ``filename`` and, if present, the optimizer state to ``filename + '.optim'``."""
        payload = dict(
            string_to_token={"*": 265},
            string_to_param={"*": self.vec},
            name=self.name,
            step=self.step,
            sd_checkpoint=self.sd_checkpoint,
            sd_checkpoint_name=self.sd_checkpoint_name,
        )
        torch.save(payload, filename)

        save_optimizer_state = True
        if not save_optimizer_state or self.optimizer_state_dict is None:
            return

        torch.save(
            {
                'hash': self.checksum(),
                'optimizer_state_dict': self.optimizer_state_dict,
            },
            filename + '.optim',
        )

    def checksum(self):
        """Return a 4-hex-digit checksum of the vector, computed once and then cached."""
        if self.cached_checksum is None:
            acc = 0
            # Fold every element (scaled by 100, truncated to int) into a 32-bit rolling hash.
            for value in self.vec.reshape(-1) * 100:
                acc = ((acc * 281) ^ (int(value) * 997)) & 0xFFFFFFFF
            self.cached_checksum = f'{acc & 0xffff:04x}'
        return self.cached_checksum


class DirWithTextualInversionEmbeddings:
    """Remembers a directory's modification time so changes can be detected."""

    def __init__(self, path):
        self.path = path
        # mtime recorded by the last update(); None until update() has run.
        self.mtime = None

    def has_changed(self):
        """Return True when the directory is new to us or was modified since ``update()``.

        Always returns a bool (the unchanged case used to fall through and return
        None); a missing directory counts as unchanged.
        """
        if not os.path.isdir(self.path):
            return False

        mt = os.path.getmtime(self.path)
        return self.mtime is None or mt > self.mtime

    def update(self):
        """Record the directory's current mtime; does nothing when the directory is missing."""
        if not os.path.isdir(self.path):
            return

        self.mtime = os.path.getmtime(self.path)


class EmbeddingDatabase:
    """Registry of textual-inversion embeddings loaded from one or more directories.

    ``ids_lookup`` maps the first token id of an embedding's name to a list of
    ``(token_ids, embedding)`` pairs, longest id sequence first.
    ``word_embeddings`` maps name -> registered embedding and
    ``skipped_embeddings`` holds those whose vector width did not match
    ``expected_shape``.
    """

    def __init__(self):
        self.ids_lookup = {}
        self.word_embeddings = {}
        self.skipped_embeddings = {}
        # -1 means "not determined yet": every embedding is accepted until a reload sets it.
        self.expected_shape = -1
        self.embedding_dirs = {}

    def add_embedding_dir(self, path):
        """Start watching ``path`` for embedding files."""
        self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)

    def clear_embedding_dirs(self):
        """Forget every watched directory."""
        self.embedding_dirs.clear()

    def register_embedding(self, embedding, model):
        """Index ``embedding`` under the token ids of its name and return it."""
        self.word_embeddings[embedding.name] = embedding

        ids = model.cond_stage_model.tokenize([embedding.name])[0]

        first_id = ids[0]
        if first_id not in self.ids_lookup:
            self.ids_lookup[first_id] = []

        # Longest id sequence first, so find_embedding_at_position tries the most
        # specific match before shorter ones.
        self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True)

        return embedding

    def get_expected_shape(self):
        """Return the vector width of the currently loaded model's text encoder."""
        vec = root.model.cond_stage_model.encode_embedding_init_text(",", 1)
        return vec.shape[1]

    def load_from_file(self, path, filename):
        """Load one embedding file and register it (or record it as skipped).

        Handles image-embedded data (.png/.webp/.jxl/.avif), ``.bin``/``.pt``
        checkpoints and ``.safetensors``; any other extension and ``*.preview.*``
        images are ignored.

        Raises:
            AssertionError: when the file holds more than one term.
            Exception: when the contents are neither a textual-inversion
                embedding nor a diffusers concept.
        """
        name, ext = os.path.splitext(filename)
        ext = ext.upper()

        if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
            _, second_ext = os.path.splitext(name)
            if second_ext.upper() == '.PREVIEW':
                return

            embed_image = Image.open(path)
            if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
                data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
                name = data.get('name', name)
            else:
                data = extract_image_data_embed(embed_image)
                name = data.get('name', name)
        elif ext in ['.BIN', '.PT']:
            # SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary code.
            # Only load embeddings from sources you trust; prefer .safetensors.
            data = torch.load(path, map_location="cpu")
        elif ext in ['.SAFETENSORS']:
            data = safetensors.torch.load_file(path, device="cpu")
        else:
            return

        # textual inversion embeddings
        if 'string_to_param' in data:
            param_dict = data['string_to_param']
            if hasattr(param_dict, '_parameters'):
                param_dict = getattr(param_dict, '_parameters')  # fix for torch 1.12.1 loading saved file from torch 1.11
            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
            emb = next(iter(param_dict.items()))[1]
        # diffuser concepts
        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'

            emb = next(iter(data.values()))
            if len(emb.shape) == 1:
                emb = emb.unsqueeze(0)
        else:
            raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")

        vec = emb.detach().to(root.device, dtype=torch.float32)
        embedding = Embedding(vec, name)
        embedding.step = data.get('step', None)
        embedding.sd_checkpoint = data.get('sd_checkpoint', None)
        embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
        embedding.vectors = vec.shape[0]
        embedding.shape = vec.shape[-1]
        embedding.filename = path

        if self.expected_shape == -1 or self.expected_shape == embedding.shape:
            self.register_embedding(embedding, root.model)
        else:
            self.skipped_embeddings[name] = embedding

    def load_from_dir(self, embdir):
        """Recursively load every non-empty file under ``embdir.path``.

        A failure on one file is reported on stderr and does not stop the scan.
        """
        if not os.path.isdir(embdir.path):
            return

        # The walk variable is named `dirpath` (not `root`) so it cannot be confused with
        # the notebook-global `root` namespace that other methods of this class read.
        for dirpath, _dirnames, fns in os.walk(embdir.path):
            for fn in fns:
                try:
                    fullfn = os.path.join(dirpath, fn)

                    if os.stat(fullfn).st_size == 0:
                        continue

                    self.load_from_file(fullfn, fn)
                except Exception:
                    print(f"Error loading embedding {fn}:", file=sys.stderr)
                    print(traceback.format_exc(), file=sys.stderr)
                    continue

    def load_textual_inversion_embeddings(self, force_reload=False):
        """Reload all embeddings when forced or when any watched directory changed."""
        if not force_reload:
            if not any(embdir.has_changed() for embdir in self.embedding_dirs.values()):
                return

        self.ids_lookup.clear()
        self.word_embeddings.clear()
        self.skipped_embeddings.clear()
        self.expected_shape = self.get_expected_shape()

        for embdir in self.embedding_dirs.values():
            self.load_from_dir(embdir)
            embdir.update()

        print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
        if len(self.skipped_embeddings) > 0:
            print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")

    def find_embedding_at_position(self, tokens, offset):
        """Return ``(embedding, id_count)`` for the embedding whose ids start at ``tokens[offset]``.

        Returns ``(None, None)`` when nothing registered matches.
        """
        token = tokens[offset]
        possible_matches = self.ids_lookup.get(token, None)

        if possible_matches is None:
            return None, None

        for ids, embedding in possible_matches:
            if tokens[offset:offset + len(ids)] == ids:
                return embedding, len(ids)

        return None, None

#train embeddings
def autocast(disable=False):
    """Return the context manager to wrap model calls in.

    NOTE(review): a function with this name and the same body is already defined
    earlier in this cell; this later definition is the one that stays bound.

    Args:
        disable: when true, always return a no-op context.

    Returns:
        ``contextlib.nullcontext()`` when disabled, when the global ``dtype`` is
        ``torch.float32`` or when the global ``precision`` is ``"full"``;
        otherwise ``torch.autocast("cuda")``.
    """
    # Imported here because no cell visible in this notebook imports contextlib at module
    # level, so the bare name would raise NameError.
    import contextlib

    if disable:
        return contextlib.nullcontext()

    # NOTE(review): `dtype` and `precision` are read as globals and are not defined in the
    # visible part of the notebook — confirm they are set before autocast() is called.
    if dtype == torch.float32 or precision == "full":
        return contextlib.nullcontext()

    return torch.autocast("cuda")

##markdown $\color{blue} {\textsf{These Params are FOR TEXTUAL INVERSION TRAINING ONLY!}}$

def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
    """Create a new embedding file in ``root.embeddings_dir`` and return its path.

    The vectors are initialised from the encoding of ``init_text`` (all zeros when
    ``init_text`` is empty). Characters other than alphanumerics and ``._- `` are
    stripped from ``name``. Unless ``overwrite_old`` is set, an existing file
    triggers an AssertionError.
    """
    text_encoder = model.cond_stage_model

    with autocast():
        text_encoder([""])  # will send cond model to GPU if lowvram/medvram is active

    # The encoder needs some text, so '*' stands in when init_text is empty.
    init_vectors = text_encoder.encode_embedding_init_text(init_text or '*', num_vectors_per_token)
    vec = torch.zeros((num_vectors_per_token, init_vectors.shape[1]), device=root.device)

    # Copy only when an init_text was given; otherwise the vectors stay zero.
    if init_text:
        source_rows = int(init_vectors.shape[0])
        for row in range(num_vectors_per_token):
            vec[row] = init_vectors[row * source_rows // num_vectors_per_token]

    # Remove illegal characters from name.
    allowed_punctuation = "._- "
    name = "".join(filter(lambda ch: ch.isalnum() or ch in allowed_punctuation, name))
    fn = os.path.join(root.embeddings_dir, f"{name}.pt")
    if not overwrite_old:
        assert not os.path.exists(fn), f"file {fn} already exists"

    embedding = Embedding(vec, name, step=0)
    embedding.save(fn)

    return fn

# Write a CSV row every N steps; 0 disables CSV logging entirely.
training_write_csv_every = 0
def write_loss(log_directory, filename, step, epoch_len, values):
    """Append one row of training metrics to ``log_directory/filename``.

    Nothing is written when ``training_write_csv_every`` is 0 or ``step`` is not
    a multiple of it. The header is emitted only when the file does not exist
    yet. ``epoch`` and ``epoch_step`` are derived from the zero-based step.
    """
    interval = training_write_csv_every
    if interval == 0 or step % interval != 0:
        return

    csv_path = os.path.join(log_directory, filename)
    needs_header = not os.path.exists(csv_path)

    with open(csv_path, "a+", newline='') as fout:
        writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())])

        if needs_header:
            writer.writeheader()

        zero_based_step = step - 1
        row = {
            "step": step,
            "epoch": zero_based_step // epoch_len,
            "epoch_step": zero_based_step % epoch_len,
            **values,
        }
        writer.writerow(row)

def tensorboard_setup(log_directory):
    """Create ``log_directory/tensorboard`` and return a SummaryWriter logging into it."""
    tensorboard_dir = os.path.join(log_directory, "tensorboard")
    os.makedirs(tensorboard_dir, exist_ok=True)
    # Passed to SummaryWriter as flush_secs.
    training_tensorboard_flush_every = 120
    return SummaryWriter(log_dir=tensorboard_dir, flush_secs=training_tensorboard_flush_every)

def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epoch_num):
    """Log loss and learning rate, each against the global step and the per-epoch step."""
    series = (
        ("Loss/train", loss, global_step),
        (f"Loss/train/epoch-{epoch_num}", loss, step),
        ("Learn rate/train", learn_rate, global_step),
        (f"Learn rate/train/epoch-{epoch_num}", learn_rate, step),
    )
    for tag, value, at_step in series:
        tensorboard_add_scaler(tensorboard_writer, tag, value, at_step)

def tensorboard_add_scaler(tensorboard_writer, tag, value, step):
    """Forward a single scalar to the writer's ``add_scalar``."""
    tensorboard_writer.add_scalar(global_step=step, scalar_value=value, tag=tag)

def tensorboard_add_image(tensorboard_writer, tag, pil_image, step):
    """Log ``pil_image`` to the writer as a channels-first tensor."""
    # Copy the pixel data out of the PIL image into a tensor.
    pixels = torch.as_tensor(np.array(pil_image, copy=True))
    height = pil_image.size[1]
    width = pil_image.size[0]
    channels = len(pil_image.getbands())
    # (height, width, channels) -> (channels, height, width)
    channels_first = pixels.view(height, width, channels).permute((2, 0, 1))

    tensorboard_writer.add_image(tag, channels_first, global_step=step)

def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"):
    """Check the training parameters and raise AssertionError on the first invalid one.

    Args:
        model_name: name of the thing being trained; must be non-empty.
        learn_rate: learning rate / schedule; must be truthy.
        batch_size, gradient_step: positive integers.
        data_root: existing, non-empty dataset directory.
        template_file: path of the prompt template file; must exist.
        template_filename: template name, used in the error messages.
        steps: positive integer number of training steps.
        save_model_every, create_image_every: non-negative integers.
        log_directory: required when either of the two intervals above is non-zero.
        name: label substituted into the messages (default "embedding").

    Raises:
        AssertionError: with a message naming the offending parameter.
    """
    assert model_name, f"{name} not selected"
    assert learn_rate, "Learning rate is empty or 0"
    assert isinstance(batch_size, int), "Batch size must be integer"
    assert batch_size > 0, "Batch size must be positive"
    assert isinstance(gradient_step, int), "Gradient accumulation step must be integer"
    assert gradient_step > 0, "Gradient accumulation step must be positive"
    assert data_root, "Dataset directory is empty"
    assert os.path.isdir(data_root), "Dataset directory doesn't exist"
    assert os.listdir(data_root), "Dataset directory is empty"
    assert template_filename, "Prompt template file not selected"
    assert template_file, f"Prompt template file {template_filename} not found"
    assert os.path.isfile(template_file), f"Prompt template file {template_filename} doesn't exist"
    assert steps, "Max steps is empty or 0"
    assert isinstance(steps, int), "Max steps must be integer"
    assert steps > 0, "Max steps must be positive"
    # These two messages were missing the f-prefix and printed a literal "{name}".
    assert isinstance(save_model_every, int), f"Save {name} must be integer"
    assert save_model_every >= 0, f"Save {name} must be positive or 0"
    assert isinstance(create_image_every, int), "Create image must be integer"
    assert create_image_every >= 0, "Create image must be positive or 0"
    if save_model_every or create_image_every:
        assert log_directory, "Log directory is empty"

import os
def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
    """Train the textual-inversion embedding ``embedding_name`` on the images under ``data_root``.

    The embedding is taken from ``model_hijack.embedding_db.word_embeddings`` and its
    vector is optimised with AdamW under CUDA autocast, using gradient accumulation.
    While the loop runs, the ldm transformer / res / attention block ``forward``
    methods are swapped for gradient-checkpointed versions; they are restored in
    ``finally``.

    Args:
        id_task: Not used by this function.
        embedding_name: Key of the embedding to train; also used in output file names.
        learn_rate: Learning-rate spec handed to ``LearnRateScheduler``.
        batch_size: Positive int; the effective value is re-read from the dataset.
        gradient_step: Positive int gradient-accumulation factor; re-read from the dataset.
        data_root: Existing, non-empty dataset directory.
        log_directory: Base log directory; ``<date>/<embedding_name>`` is appended to it.
        training_width, training_height, varsize, shuffle_tags, tag_drop_out,
        latent_sampling_method: Forwarded to ``PersonalizedBase``.
        steps: Training stops once ``embedding.step`` reaches this value.
        clip_grad_mode: ``"value"`` or ``"norm"`` selects the torch clipping function;
            any other value disables clipping.
        clip_grad_value: Clipping schedule spec handed to ``LearnRateScheduler``.
        create_image_every: Interval in steps for the (currently disabled) preview
            stage; ``0``/``None`` disables it.
        save_embedding_every: Interval in steps for intermediate embedding saves;
            ``0``/``None`` disables them.
        template_filename: Path of the prompt template file.
        save_image_with_stored_embedding: Only causes the ``image_embeddings``
            directory to be created; the code that used it is commented out.
        preview_from_txt2img, preview_*: Preview settings. Preview rendering is
            commented out, so these only end up in the saved settings JSON.

    Returns:
        Tuple ``(embedding, filename)`` where ``filename`` is the final ``.pt`` path
        under ``root.embeddings_dir``.

    Raises:
        AssertionError: From ``validate_train_inputs`` when an argument is invalid.

    Exceptions raised inside the training loop are printed to stderr and swallowed;
    the function still returns ``(embedding, filename)`` in that case.
    """
    import datetime
    import json
    import os

    save_embedding_every = save_embedding_every or 0
    create_image_every = create_image_every or 0
    template_file = template_filename
    validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_embedding_every, create_image_every, log_directory, name="embedding")

    job = "train-embedding"
    textinfo = "Initializing textual inversion training..."
    job_count = steps
    filename = os.path.join(root.embeddings_dir, f'{embedding_name}.pt')

    log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
    unload_models_when_training=False
    unload = unload_models_when_training

    if save_embedding_every > 0:
        embedding_dir = os.path.join(log_directory, "embeddings")
        os.makedirs(embedding_dir, exist_ok=True)
    else:
        embedding_dir = None

    if create_image_every > 0:
        images_dir = os.path.join(log_directory, "images")
        os.makedirs(images_dir, exist_ok=True)
    else:
        images_dir = None

    if create_image_every > 0 and save_image_with_stored_embedding:
        images_embeds_dir = os.path.join(log_directory, "image_embeddings")
        os.makedirs(images_embeds_dir, exist_ok=True)
    else:
        images_embeds_dir = None

    hijack = model_hijack

    embedding = hijack.embedding_db.word_embeddings[embedding_name]
    # NOTE(review): save_embedding() reads `.shorthash` and `.model_name` from this
    # object; confirm that root.model actually provides both attributes.
    checkpoint = root.model

    initial_step = embedding.step or 0
    if initial_step >= steps:
        textinfo = "Model has already been trained beyond specified max steps"
        return embedding, filename

    scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
    clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
        torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
        None
    if clip_grad:
        clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
    # dataset loading may take a while, so input validations and early returns should be done before this
    textinfo = f"Preparing dataset from {html.escape(data_root)}..."
    parallel_processing_allowed=False
    old_parallel_processing_allowed = parallel_processing_allowed

    training_enable_tensorboard=True
    if training_enable_tensorboard:
        tensorboard_writer = tensorboard_setup(log_directory)

    pin_memory = False

    ds = PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=1, placeholder_token=embedding_name, model=model, cond_model=model.cond_stage_model, device=root.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize)

    saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file", "gradient_step", "latent_sampling_method"}
    saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"}
    saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"}
    saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
    saved_params_previews = {"preview_prompt", "preview_negative_prompt", "preview_steps", "preview_sampler_index", "preview_cfg_scale", "preview_seed", "preview_width", "preview_height"}


    def save_settings_to_file(log_directory, all_params):
        """Dump the whitelisted entries of ``all_params`` to a timestamped JSON file."""
        now = datetime.datetime.now()
        params = {"datetime": now.strftime("%Y-%m-%d %H:%M:%S")}

        keys = saved_params_all
        if all_params.get('preview_from_txt2img'):
            keys = keys | saved_params_previews

        params.update({k: v for k, v in all_params.items() if k in keys})

        # The directory is otherwise only created as a side effect of the optional
        # embeddings/images sub-directories above.
        os.makedirs(log_directory, exist_ok=True)
        filename = f'settings-{now.strftime("%Y-%m-%d-%H-%M-%S")}.json'
        with open(os.path.join(log_directory, filename), "w") as file:
            json.dump(params, file, indent=4)

    save_training_settings_to_txt=True
    if save_training_settings_to_txt:
        save_settings_to_file(log_directory, {**dict(model_name=root.model_checkpoint, model_hash=ckpt_hash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})

    latent_sampling_method = ds.latent_sampling_method

    dl = PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)

    if unload:
        parallel_processing_allowed = False
        model.first_stage_model.to("cuda")

    embedding.vec.requires_grad = True
    optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
    save_optimizer_state=True
    if save_optimizer_state:
        optimizer_state_dict = None
        if os.path.exists(filename + '.optim'):
            optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu')
            # Only reuse the optimizer state when it belongs to this exact embedding.
            if embedding.checksum() == optimizer_saved_dict.get('hash', None):
                optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)

        if optimizer_state_dict is not None:
            optimizer.load_state_dict(optimizer_state_dict)
            print("Loaded existing optimizer from checkpoint")
        else:
            print("No saved optimizer exists in checkpoint")

    scaler = torch.cuda.amp.GradScaler()

    batch_size = ds.batch_size
    gradient_step = ds.gradient_step
    # n steps = batch_size * gradient_step * n image processed
    steps_per_epoch = len(ds) // batch_size // gradient_step
    max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
    loss_step = 0
    _loss_step = 0 #internal

    last_saved_file = "<none>"
    last_saved_image = "<none>"
    forced_filename = "<none>"
    embedding_yet_to_be_embedded = False

    # is_training_inpainting_model = root.model.conditioning_key in {'hybrid', 'concat'}
    img_c = None

    ## hijack_checkpoint

    # Imported under a distinct name: the plain name `checkpoint` already holds the
    # model object that is passed to save_embedding() and must not be rebound.
    from torch.utils.checkpoint import checkpoint as grad_checkpoint

    import ldm.modules.attention
    import ldm.modules.diffusionmodules.openaimodel


    def BasicTransformerBlock_forward(self, x, context=None):
        return grad_checkpoint(self._forward, x, context)


    def AttentionBlock_forward(self, x):
        return grad_checkpoint(self._forward, x)


    def ResBlock_forward(self, x, emb):
        return grad_checkpoint(self._forward, x, emb)


    # Original forward methods, kept so hijack_checkpoint_remove() can restore them.
    stored = []


    def hijack_checkpoint_add():
        """Replace the three ldm ``forward`` methods with checkpointed versions (once)."""
        if len(stored) != 0:
            return

        stored.extend([
            ldm.modules.attention.BasicTransformerBlock.forward,
            ldm.modules.diffusionmodules.openaimodel.ResBlock.forward,
            ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward
        ])

        ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward
        ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward
        ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward


    def hijack_checkpoint_remove():
        """Restore the ``forward`` methods saved by ``hijack_checkpoint_add``."""
        if len(stored) == 0:
            return

        ldm.modules.attention.BasicTransformerBlock.forward = stored[0]
        ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = stored[1]
        ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = stored[2]

        stored.clear()

    pbar = tqdm.tqdm(total=steps - initial_step)
    try:
        hijack_checkpoint_add()

        for i in range((steps-initial_step) * gradient_step):
            # Stop once the requested number of steps has been reached; without this
            # check every outer iteration would run another full pass over the data.
            # (`or 0` mirrors how initial_step treats an unset step.)
            if (embedding.step or 0) >= steps:
                break
            # if scheduler.finished:
            #     break
            # if interrupted:
            #     break
            for j, batch in enumerate(dl):
                # works as a drop_last=True for gradient accumulation
                if j == max_steps_per_epoch:
                    break
                if (embedding.step or 0) >= steps:
                    break
                scheduler.apply(optimizer, embedding.step)
                # if scheduler.finished:
                #     break
                # if shared.state.interrupted:
                #     break

                if clip_grad:
                    clip_grad_sched.step(embedding.step)
                with torch.autocast("cuda"):
                    x = batch.latent_sample.to("cuda", non_blocking=pin_memory)
                    c = model.cond_stage_model(batch.cond_text)
                    is_training_inpainting_model=False
                    if is_training_inpainting_model:
                        if img_c is None:
                            img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)

                        cond = {"c_concat": [img_c], "c_crossattn": [c]}
                    else:
                        cond = c
                    cond = cond.to("cuda")
                    x = x.to("cuda")
                    loss = model(x, cond)[0] / gradient_step
                    del x

                    _loss_step += loss.item()
                scaler.scale(loss).backward()

                # go back until we reach gradient accumulation steps
                if (j + 1) % gradient_step != 0:
                    continue

                if clip_grad:
                    clip_grad(embedding.vec, clip_grad_sched.learn_rate)

                scaler.step(optimizer)
                scaler.update()
                embedding.step += 1
                pbar.update()
                optimizer.zero_grad(set_to_none=True)
                loss_step = _loss_step
                _loss_step = 0

                steps_done = embedding.step + 1

                epoch_num = embedding.step // steps_per_epoch
                epoch_step = embedding.step % steps_per_epoch

                description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}] loss: {loss_step:.7f}"
                pbar.set_description(description)
                if embedding_dir is not None and steps_done % save_embedding_every == 0:
                    # Before saving, change name to match current checkpoint.
                    embedding_name_every = f'{embedding_name}-{steps_done}'
                    last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
                    save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
                    embedding_yet_to_be_embedded = True

                write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
                    "loss": f"{loss_step:.7f}",
                    "learn_rate": scheduler.learn_rate
                })

                training_tensorboard_save_images=False

                if images_dir is not None and steps_done % create_image_every == 0:
                    forced_filename = f'{embedding_name}-{steps_done}'
                    last_saved_image = os.path.join(images_dir, forced_filename)

                    model.first_stage_model.to(root.device)

                    # p = processing.StableDiffusionProcessingTxt2Img(
                    #     sd_model=shared.sd_model,
                    #     do_not_save_grid=True,
                    #     do_not_save_samples=True,
                    #     do_not_reload_embeddings=True,
                    # )

                    # Preview rendering is disabled (the processing object `p` above is
                    # never built), so only the prompt that would be used is recorded.
                    # Writing the preview settings onto `p` here raised NameError, and
                    # assigning `steps = 20` overwrote the max-steps parameter.
                    preview_text = preview_prompt if preview_from_txt2img else batch.cond_text[0]

                    # processed = processing.process_images(p)
                    # image = processed.images[0] if len(processed.images) > 0 else None

                    # def assign_current_image(self, image):
                    #     self.current_image = image
                    #     self.id_live_preview += 1
                    # if unload:
                    #     model.first_stage_model.to("cuda")

                    # if image is not None:
                    #     assign_current_image(image)

                    #     last_saved_image, last_text_info = images.save_image(image, images_dir, "", seed, prompt, samples_format="png", processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
                    #     last_saved_image += f", prompt: {preview_text}"

                    #     if training_enable_tensorboard and training_tensorboard_save_images:
                    #         tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)

                    # if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:

                    #     last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')

                    #     info = PngImagePlugin.PngInfo()
                    #     data = torch.load(last_saved_file)
                    #     info.add_text("sd-ti-embedding", embedding_to_b64(data))

                    #     title = "<{}>".format(data.get('name', '???'))

                    #     try:
                    #         vectorSize = list(data['string_to_param'].values())[0].shape[0]
                    #     except Exception as e:
                    #         vectorSize = '?'

                    #     checkpoint = root.model_checkpoint
                    #     footer_left = checkpoint.model_name
                    #     footer_mid = '[{}]'.format(checkpoint.shorthash)
                    #     footer_right = '{}v {}s'.format(vectorSize, steps_done)

                    #     captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
                    #     captioned_image = insert_image_data_embed(captioned_image, data)

                    #     captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
                    #     embedding_yet_to_be_embedded = False

                    # last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
                    # last_saved_image += f", prompt: {preview_text}"

                job_no = embedding.step

                textinfo = f"""
# <p>
# Loss: {loss_step:.7f}<br/>
# Step: {steps_done}<br/>
# Last prompt: {html.escape(batch.cond_text[0])}<br/>
# Last saved embedding: {html.escape(last_saved_file)}<br/>
# Last saved image: {html.escape(last_saved_image)}<br/>
# </p>
# """
        filename = os.path.join(root.embeddings_dir, f'{embedding_name}.pt')
        save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
    except Exception:
        # Best effort: report the failure but still restore model state below and
        # return what has been trained so far.
        print(traceback.format_exc(), file=sys.stderr)
    finally:
        pbar.leave = False
        pbar.close()
        model.first_stage_model.to(root.device)
        parallel_processing_allowed = old_parallel_processing_allowed
        hijack_checkpoint_remove()

    return embedding, filename


def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
    """Stamp ``embedding`` with checkpoint info and the optimizer state, then save it.

    Args:
        embedding: Object with a ``name`` attribute and a ``save(filename)`` method.
        optimizer: Object whose ``state_dict()`` is stored on the embedding.
        checkpoint: Object providing ``shorthash`` and ``model_name``.
        embedding_name: Name written to ``embedding.name`` before saving.
        filename: Destination passed to ``embedding.save``.
        remove_cached_checksum: When true, ``embedding.cached_checksum`` is reset to
            ``None`` before saving.

    Raises:
        Whatever the attribute access or ``embedding.save`` raises. In that case
        ``name``, ``sd_checkpoint``, ``sd_checkpoint_name`` and ``cached_checksum``
        are put back to their previous values (``None`` if previously absent).
    """
    prev_name = embedding.name
    prev_sd_checkpoint = getattr(embedding, "sd_checkpoint", None)
    prev_sd_checkpoint_name = getattr(embedding, "sd_checkpoint_name", None)
    prev_cached_checksum = getattr(embedding, "cached_checksum", None)

    try:
        embedding.sd_checkpoint = checkpoint.shorthash
        embedding.sd_checkpoint_name = checkpoint.model_name
        if remove_cached_checksum:
            embedding.cached_checksum = None
        embedding.name = embedding_name
        embedding.optimizer_state_dict = optimizer.state_dict()
        embedding.save(filename)
    except BaseException:
        # Roll back the metadata fields, then let the error propagate unchanged.
        rollback = (
            ("sd_checkpoint", prev_sd_checkpoint),
            ("sd_checkpoint_name", prev_sd_checkpoint_name),
            ("name", prev_name),
            ("cached_checksum", prev_cached_checksum),
        )
        for attr, previous in rollback:
            setattr(embedding, attr, previous)
        raise

#cond_func
import importlib

class CondFunc:
    """Conditionally route calls of a function to a substitute.

    ``CondFunc(orig_func, sub_func, cond_func)`` returns a plain callable (not a
    ``CondFunc`` instance). Calling it invokes ``sub_func(orig_func, *args, **kwargs)``
    when ``cond_func`` is falsy or ``cond_func(orig_func, *args, **kwargs)`` is truthy,
    and ``orig_func(*args, **kwargs)`` otherwise.

    When ``orig_func`` is a dotted string, the longest importable module prefix
    (excluding the final component) is imported, the remaining components are
    resolved with ``getattr``, and the final attribute is replaced in place by a
    wrapper that dispatches through this object.
    """

    def __new__(cls, orig_func, sub_func, cond_func):
        instance = super(CondFunc, cls).__new__(cls)
        if isinstance(orig_func, str):
            parts = orig_func.split('.')
            # Try the longest module path first, dropping one component per attempt.
            for split_at in reversed(range(len(parts))):
                try:
                    owner = importlib.import_module('.'.join(parts[:split_at]))
                    break
                except ImportError:
                    pass
            # Whatever was not part of the module path is reached via attributes.
            for attr_name in parts[split_at:-1]:
                owner = getattr(owner, attr_name)
            leaf = parts[-1]
            orig_func = getattr(owner, leaf)
            setattr(owner, leaf, lambda *args, **kwargs: instance(*args, **kwargs))
        # __new__ returns a lambda rather than an instance, so Python will not call
        # __init__ automatically; do it by hand.
        instance.__init__(orig_func, sub_func, cond_func)
        return lambda *args, **kwargs: instance(*args, **kwargs)

    def __init__(self, orig_func, sub_func, cond_func):
        self.__orig_func = orig_func
        self.__sub_func = sub_func
        self.__cond_func = cond_func

    def __call__(self, *args, **kwargs):
        if self.__cond_func and not self.__cond_func(self.__orig_func, *args, **kwargs):
            return self.__orig_func(*args, **kwargs)
        return self.__sub_func(self.__orig_func, *args, **kwargs)

#errors
import sys
import traceback


def print_error_explanation(message):
    """Print ``message`` to stderr between two rules of ``=`` characters.

    The message is stripped and split on newlines; both rules are as wide as the
    longest resulting line.
    """
    rows = message.strip().split("\n")
    rule = '=' * max(map(len, rows))

    for text in (rule, *rows, rule):
        print(text, file=sys.stderr)


def display(e: Exception, task):
    """Report exception ``e`` on stderr, labelled with ``task`` (or ``'error'``).

    Prints ``"<label>: <exception class name>"`` followed by
    ``traceback.format_exc()``. If the exception text contains the known
    640x1024 vs 640x768 shape-mismatch message, an extra explanation is printed
    through ``print_error_explanation``.
    """
    label = task or 'error'
    print(f"{label}: {type(e).__name__}", file=sys.stderr)
    print(traceback.format_exc(), file=sys.stderr)

    sd2_shape_mismatch = "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])"
    if sd2_shape_mismatch not in str(e):
        return

    print_error_explanation("""
The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its config file.
See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
        """)


# Tasks whose error has already been reported by display_once(); keys are task labels.
already_displayed = {}


def display_once(e: Exception, task):
    """Report ``e`` through ``display`` only the first time ``task`` is seen.

    The task is recorded in ``already_displayed`` after ``display`` returns.
    """
    if task not in already_displayed:
        display(e, task)
        already_displayed[task] = 1


def run(code, task):
    """Call ``code()`` and report any ``Exception`` via ``display`` instead of raising.

    Args:
        code: Zero-argument callable to execute.
        task: Label passed to ``display`` when ``code`` fails.
    """
    try:
        code()
    except Exception as e:
        # display() takes the exception first and the task label second.
        display(e, task)

#sub_quad_attn
from functools import partial
import torch
from torch import Tensor
from torch.utils.checkpoint import checkpoint
import math
from typing import Optional, NamedTuple, List

def without_autocast(disable=False):
    """Return a context manager that switches CUDA autocast off for its body.

    A no-op ``contextlib.nullcontext()`` is returned instead when autocast is not
    currently enabled or when ``disable`` is true.
    """
    if disable or not torch.is_autocast_enabled():
        return contextlib.nullcontext()
    return torch.autocast("cuda", enabled=False)

def narrow_trunc(
    input: Tensor,
    dim: int,
    start: int,
    length: int
) -> Tensor:
    """Like ``torch.narrow``, but shortens ``length`` so the slice ends at the tensor edge.

    If ``start + length`` would run past ``input.shape[dim]``, the slice covers
    ``start`` up to the end of that dimension instead.
    """
    extent = input.shape[dim]
    if extent < start + length:
        length = extent - start
    return torch.narrow(input, dim, start, length)


class AttnChunk(NamedTuple):
    """Partial attention result for one key/value chunk, as built by ``_summarize_chunk``."""
    # exp(scores - max_score) multiplied with the chunk's values
    exp_values: Tensor
    # exp(scores - max_score) summed over the last (key) dimension
    exp_weights_sum: Tensor
    # maximum raw score over the last dimension, detached and squeezed
    max_score: Tensor


class SummarizeChunk:
    """Typing placeholder for a callable ``(query, key, value) -> AttnChunk``.

    The ``__call__`` body is an ellipsis; in this section the class is only used
    as an annotation for the ``summarize_chunk`` callables.
    """
    @staticmethod
    def __call__(
        query: Tensor,
        key: Tensor,
        value: Tensor,
    ) -> AttnChunk: ...


class ComputeQueryChunkAttn:
    """Typing placeholder for a callable ``(query, key, value) -> Tensor``.

    The ``__call__`` body is an ellipsis; in this section the class is only used
    as an annotation for the per-query-chunk attention callable.
    """
    @staticmethod
    def __call__(
        query: Tensor,
        key: Tensor,
        value: Tensor,
    ) -> Tensor: ...


def _summarize_chunk(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    scale: float,
) -> AttnChunk:
    """Compute the un-normalised softmax pieces for one key/value chunk.

    Returns an ``AttnChunk`` holding ``exp(scores - max) @ value``, the row sums of
    ``exp(scores - max)`` and the (detached) per-row maximum of
    ``scores = scale * query @ key^T``.
    """
    # With beta=0 the first argument's contents are ignored, so an uninitialised
    # 1x1x1 tensor is enough to supply device and dtype.
    placeholder = torch.empty(1, 1, 1, device=query.device, dtype=query.dtype)
    scores = torch.baddbmm(
        placeholder,
        query,
        key.transpose(1,2),
        alpha=scale,
        beta=0,
    )
    row_max, _ = torch.max(scores, -1, keepdim=True)
    row_max = row_max.detach()
    weights = torch.exp(scores - row_max)
    if query.device.type == 'mps':
        weighted_values = torch.bmm(weights, value)
    else:
        # Multiply in the weights' dtype, then return to the values' dtype.
        weighted_values = torch.bmm(weights, value.to(weights.dtype)).to(value.dtype)
    return AttnChunk(weighted_values, weights.sum(dim=-1), row_max.squeeze(-1))


def _query_chunk_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    summarize_chunk: SummarizeChunk,
    kv_chunk_size: int,
) -> Tensor:
    """Attend ``query`` to ``key``/``value`` by processing keys in chunks.

    Keys and values are sliced along dimension 1 in steps of ``kv_chunk_size``;
    each slice is summarised with ``summarize_chunk`` and the partial results are
    rescaled to a common maximum before being combined.
    """
    _, k_tokens, _ = key.shape
    _, _, _ = value.shape  # like the key unpacking, this requires a 3-D tensor

    partials: List[AttnChunk] = []
    for offset in torch.arange(0, k_tokens, kv_chunk_size):
        key_slice = narrow_trunc(key, 1, offset, kv_chunk_size)
        value_slice = narrow_trunc(value, 1, offset, kv_chunk_size)
        partials.append(summarize_chunk(query, key_slice, value_slice))

    # Stack each field across chunks: one leading "chunk" dimension per field.
    stacked = AttnChunk(*map(torch.stack, zip(*partials)))
    chunk_values, chunk_weights, chunk_max = stacked

    # Bring every chunk onto the scale of the overall maximum score.
    overall_max, _ = torch.max(chunk_max, 0, keepdim=True)
    rescale = torch.exp(chunk_max - overall_max)
    chunk_values *= rescale.unsqueeze(-1)
    chunk_weights *= rescale

    numerator = chunk_values.sum(dim=0)
    denominator = chunk_weights.unsqueeze(-1).sum(dim=0)
    return numerator / denominator


# TODO: refactor CrossAttention#get_attention_scores to share code with this
def _get_attention_scores_no_kv_chunking(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    scale: float,
) -> Tensor:
    attn_scores = torch.baddbmm(
        torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
        query,
        key.transpose(1,2),
        alpha=scale,
        beta=0,
    )
    attn_probs = attn_scores.softmax(dim=-1)
    del attn_scores
    hidden_states_slice = torch.bmm(attn_probs, value) if query.device.type == 'mps' else torch.bmm(attn_probs, value.to(attn_probs.dtype)).to(value.dtype)
    return hidden_states_slice

#sub_quadratic_attention
class ScannedChunk(NamedTuple):
    """Pairs a chunk index with its ``AttnChunk``; not referenced elsewhere in this section."""
    chunk_idx: int
    attn_chunk: AttnChunk

def efficient_dot_product_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    query_chunk_size=1024,
    kv_chunk_size: Optional[int] = None,
    kv_chunk_size_min: Optional[int] = None,
    use_checkpoint=True,
):
    """Memory-efficient dot-product attention computed chunk by chunk.

    Follows the approach of https://arxiv.org/abs/2112.05682v2, which the paper
    presents as needing O(sqrt(n)) memory.

    Args:
        query: queries of shape `[batch * num_heads, tokens, channels_per_head]`.
        key: keys of shape `[batch * num_heads, tokens, channels_per_head]`.
        value: values of shape `[batch * num_heads, tokens, channels_per_head]`.
        query_chunk_size: number of query tokens handled per query chunk.
        kv_chunk_size: key/value chunk size; when falsy, `int(sqrt(key_tokens))`
            is used. The result is capped at the number of key tokens.
        kv_chunk_size_min: when not None, the key/value chunk size is raised to at
            least this value (applied after the cap above, whether or not
            `kv_chunk_size` was given).
        use_checkpoint: wrap the per-chunk summary in `checkpoint` (recommended
            True for training, False for inference).

    Returns:
        Tensor of shape `[batch * num_heads, query_tokens, channels_per_head]`.
    """
    _, q_tokens, q_channels_per_head = query.shape
    _, k_tokens, _ = key.shape
    scale = q_channels_per_head ** -0.5

    if not kv_chunk_size:
        kv_chunk_size = int(math.sqrt(k_tokens))
    kv_chunk_size = min(kv_chunk_size, k_tokens)
    if kv_chunk_size_min is not None:
        kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)

    if k_tokens <= kv_chunk_size:
        # Only one key/value chunk per query chunk: ordinary (sliced) attention.
        attend_chunk: ComputeQueryChunkAttn = partial(
            _get_attention_scores_no_kv_chunking,
            scale=scale,
        )
    else:
        summarize: SummarizeChunk = partial(_summarize_chunk, scale=scale)
        if use_checkpoint:
            summarize = partial(checkpoint, summarize)
        attend_chunk: ComputeQueryChunkAttn = partial(
            _query_chunk_attention,
            kv_chunk_size=kv_chunk_size,
            summarize_chunk=summarize,
        )

    if q_tokens <= query_chunk_size:
        # Only one query chunk: no slicing or concatenation needed.
        return attend_chunk(query=query, key=key, value=value)

    # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
    # and pass slices to be mutated, instead of torch.cat()ing the returned slices
    pieces = []
    for offset in range(0, q_tokens, query_chunk_size):
        query_slice = narrow_trunc(query, 1, offset, query_chunk_size)
        pieces.append(attend_chunk(query=query_slice, key=key, value=value))
    return torch.cat(pieces, dim=1)

#cache
import hashlib
import json
import os.path

import filelock

# Path of the JSON cache file, located inside root.data_dir.
cache_filename = os.path.join(root.data_dir, "cache.json")
# In-memory copy of the cache; stays None until cache() loads or creates it.
cache_data = None


def dump_cache():
    """Write ``cache_data`` to ``cache_filename`` as indented JSON.

    The write happens while holding the ``<cache_filename>.lock`` file lock.
    """
    lock = filelock.FileLock(cache_filename+".lock")
    with lock, open(cache_filename, "w", encoding="utf8") as file:
        json.dump(cache_data, file, indent=4)


def cache(subsection):
    """Return the dict stored under ``subsection`` in the cache, creating it if missing.

    On first use the cache is read from ``cache_filename`` (or started empty when
    the file does not exist) while holding the ``.lock`` file lock. The returned
    dict is the live object inside ``cache_data``, so edits to it are kept.
    """
    global cache_data

    if cache_data is None:
        with filelock.FileLock(cache_filename+".lock"):
            if os.path.isfile(cache_filename):
                with open(cache_filename, "r", encoding="utf8") as file:
                    cache_data = json.load(file)
            else:
                cache_data = {}

    return cache_data.setdefault(subsection, {})


def calculate_sha256(filename):
    """Return the SHA-256 hex digest of the file at ``filename``.

    The file is read in 1 MiB blocks so it never has to fit in memory at once.
    """
    block_bytes = 1024 * 1024
    digest = hashlib.sha256()

    with open(filename, "rb") as stream:
        while True:
            block = stream.read(block_bytes)
            if block == b"":
                break
            digest.update(block)

    return digest.hexdigest()


def sha256_from_cache(filename, title):
    """Return the cached SHA-256 for ``title``, or ``None`` when it cannot be used.

    The cached value is rejected when there is no entry for ``title``, when the
    entry has no ``sha256``, or when the file's modification time on disk is newer
    than the entry's ``mtime``.
    """
    hashes = cache("hashes")
    ondisk_mtime = os.path.getmtime(filename)

    try:
        entry = hashes[title]
    except KeyError:
        return None

    cached_sha256 = entry.get("sha256", None)
    cached_mtime = entry.get("mtime", 0)

    stale = ondisk_mtime > cached_mtime or cached_sha256 is None
    return None if stale else cached_sha256


def sha256(filename, title):
    """Return the SHA-256 of ``filename``, using the cache entry ``title`` when valid.

    On a cache miss the digest is computed (progress is printed), stored under
    ``title`` together with the file's current modification time, and the cache is
    written to disk with ``dump_cache``.
    """
    hashes = cache("hashes")

    cached = sha256_from_cache(filename, title)
    if cached is not None:
        return cached

    print(f"Calculating sha256 for {filename}: ", end='')
    digest = calculate_sha256(filename)
    print(f"{digest}")

    hashes[title] = dict(mtime=os.path.getmtime(filename), sha256=digest)
    dump_cache()

    return digest

#hypernetwork
# Module-level list, empty at definition; its users are outside this section
# (presumably the currently loaded hypernetworks; TODO confirm).
loaded_hypernetworks=[]
import csv
import datetime
import glob
import html
import os
import sys
import traceback
import inspect

import torch
import tqdm
from einops import rearrange, repeat
from ldm.util import default
from torch import einsum
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_

from collections import defaultdict, deque
from statistics import stdev, mean


# Maps the name of every class found on torch.optim (except "Optimizer" itself) to the class object.
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}

class HypernetworkModule(torch.nn.Module):
    """Small fully-connected network added residually to its input.

    ``forward`` returns ``x + linear(x) * multiplier`` (the multiplier is
    ignored while the module is in training mode). The network is built from
    ``layer_structure``: a sequence of width ratios relative to ``dim`` that
    must start and end with 1.
    """

    # Short aliases first; afterwards every class defined in
    # torch.nn.modules.activation is registered under its lower-cased name.
    activation_dict = {
        "linear": torch.nn.Identity,
        "relu": torch.nn.ReLU,
        "leakyrelu": torch.nn.LeakyReLU,
        "elu": torch.nn.ELU,
        "swish": torch.nn.Hardswish,
        "tanh": torch.nn.Tanh,
        "sigmoid": torch.nn.Sigmoid,
    }
    activation_dict.update({
        member_name.lower(): member
        for member_name, member in inspect.getmembers(torch.nn.modules.activation, inspect.isclass)
        if member.__module__ == 'torch.nn.modules.activation'
    })

    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
                 add_layer_norm=False, activate_output=False, dropout_structure=None):
        """Build the layer stack, then load ``state_dict`` or initialise weights.

        dim: base width; layer ``i`` maps ``dim * layer_structure[i]`` to
            ``dim * layer_structure[i+1]`` features.
        state_dict: weights to load; when None the layers are initialised
            according to ``weight_init``.
        activation_func: key into ``activation_dict``; "linear"/None means no
            activation. The last layer is only activated if ``activate_output``.
        add_layer_norm: append a LayerNorm after every layer.
        dropout_structure: per-position dropout probabilities; entry ``i+1``
            applies after layer ``i``.

        Raises RuntimeError for an unknown activation and KeyError for an
        unknown ``weight_init``.
        """
        super().__init__()

        self.multiplier = 1.0

        assert layer_structure is not None, "layer_structure must not be None"
        assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
        assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"

        no_activation = activation_func == "linear" or activation_func is None
        last_layer_index = len(layer_structure) - 2

        stack = []
        for idx in range(len(layer_structure) - 1):
            in_width = int(dim * layer_structure[idx])
            out_width = int(dim * layer_structure[idx + 1])

            # Fully-connected layer.
            stack.append(torch.nn.Linear(in_width, out_width))

            # Activation: skipped for "linear"/None, and for the final layer
            # unless activate_output was requested.
            if not no_activation and (idx < last_layer_index or activate_output):
                if activation_func not in self.activation_dict:
                    raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
                stack.append(self.activation_dict[activation_func]())

            # Optional layer normalisation.
            if add_layer_norm:
                stack.append(torch.nn.LayerNorm(out_width))

            # Dropout after this layer, taken from the already-parsed structure
            # (which is expected to start and end with 0).
            if dropout_structure is not None:
                drop_p = dropout_structure[idx + 1]
                if drop_p > 0:
                    assert 0 < drop_p < 1, "Dropout probability should be 0 or float between 0 and 1!"
                    stack.append(torch.nn.Dropout(p=drop_p))

        self.linear = torch.nn.Sequential(*stack)

        if state_dict is not None:
            self.fix_old_state_dict(state_dict)
            self.load_state_dict(state_dict)
        else:
            kaiming_mode = 'leaky_relu' if 'leakyrelu' == activation_func else 'relu'
            for layer in self.linear:
                if type(layer) not in (torch.nn.Linear, torch.nn.LayerNorm):
                    continue
                w, b = layer.weight.data, layer.bias.data
                # LayerNorm always gets the "Normal" treatment regardless of weight_init.
                if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
                    normal_(w, mean=0.0, std=0.01)
                    normal_(b, mean=0.0, std=0)
                    continue
                if weight_init == 'XavierUniform':
                    xavier_uniform_(w)
                elif weight_init == 'XavierNormal':
                    xavier_normal_(w)
                elif weight_init == 'KaimingUniform':
                    kaiming_uniform_(w, nonlinearity=kaiming_mode)
                elif weight_init == 'KaimingNormal':
                    kaiming_normal_(w, nonlinearity=kaiming_mode)
                else:
                    raise KeyError(f"Key {weight_init} is not defined as initialization!")
                zeros_(b)
        self.to(root.device)

    def fix_old_state_dict(self, state_dict):
        """Rename legacy ``linear1``/``linear2`` keys in ``state_dict`` in place."""
        renames = (
            ('linear1.bias', 'linear.0.bias'),
            ('linear1.weight', 'linear.0.weight'),
            ('linear2.bias', 'linear.1.bias'),
            ('linear2.weight', 'linear.1.weight'),
        )

        for old_key, new_key in renames:
            tensor = state_dict.get(old_key, None)
            if tensor is not None:
                del state_dict[old_key]
                state_dict[new_key] = tensor

    def forward(self, x):
        """Return ``x`` plus the network output, scaled by ``multiplier`` outside training."""
        strength = 1 if self.training else self.multiplier
        return x + self.linear(x) * strength

    def trainables(self):
        """Return weight and bias of every Linear/LayerNorm layer, in layer order."""
        params = []
        for layer in self.linear:
            if type(layer) in (torch.nn.Linear, torch.nn.LayerNorm):
                params.extend((layer.weight, layer.bias))
        return params


#param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check.
def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
    """Build the list of dropout probabilities for a layer structure.

    layer_structure: only its length is used; None is treated as [1, 2, 1].
    use_dropout: when falsy, the result is all zeros (one per entry).
    last_layer_dropout: whether the second-to-last position gets 0.3 or 0.

    With dropout enabled the result starts with 0, ends with 0, and uses 0.3
    for the positions in between (subject to ``last_layer_dropout``).
    """
    structure = [1, 2, 1] if layer_structure is None else layer_structure
    depth = len(structure)
    if not use_dropout:
        return [0] * depth

    before_last = 0.3 if last_layer_dropout else 0
    return [0] + [0.3] * (depth - 3) + [before_last, 0]

def shorthash(self):
    """Return the first 10 characters of the SHA-256 registered for a hypernetwork.

    ``self`` is any object exposing ``filename`` and ``name``. The digest is
    obtained from the module-level ``sha256`` helper under the cache title
    ``hypernet/<name>``.
    """
    # The result must not be bound to a local named ``sha256``: that would make
    # the name local to this function and the call would raise UnboundLocalError.
    digest = sha256(self.filename, f'hypernet/{self.name}')

    return digest[0:10]

class Hypernetwork:
    """A named collection of hypernetwork modules plus the metadata saved with them.

    ``layers`` maps an integer size to a pair of ``HypernetworkModule`` objects.
    ``apply_single_hypernetwork`` applies element 0 of the pair to the key
    context and element 1 to the value context.
    """
    filename = None
    name = None

    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
        """Create a hypernetwork with one freshly initialised module pair per size.

        enable_sizes: iterable of sizes to create modules for (None -> none).
        kwargs: ``last_layer_dropout`` (default True) and ``dropout_structure``
            (default None, in which case it is derived with
            ``parse_dropout_structure``).
        The remaining arguments are stored and forwarded to ``HypernetworkModule``.
        The instance is left in eval mode.
        """
        self.filename = None
        self.name = name
        self.layers = {}
        self.step = 0
        self.sd_checkpoint = None
        self.sd_checkpoint_name = None
        self.layer_structure = layer_structure
        self.activation_func = activation_func
        self.weight_init = weight_init
        self.add_layer_norm = add_layer_norm
        self.use_dropout = use_dropout
        self.activate_output = activate_output
        self.last_layer_dropout = kwargs.get('last_layer_dropout', True)
        self.dropout_structure = kwargs.get('dropout_structure', None)
        if self.dropout_structure is None:
            self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
        self.optimizer_name = None
        self.optimizer_state_dict = None
        self.optional_info = None

        for size in enable_sizes or []:
            self.layers[size] = (
                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
                                   self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
                                   self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
            )
        self.eval()

    def weights(self):
        """Return a flat list of the parameters of every module."""
        res = []
        for k, layers in self.layers.items():
            for layer in layers:
                res += layer.parameters()
        return res

    def train(self, mode=True):
        """Set train/eval mode on every module and set ``requires_grad`` to ``mode``."""
        for k, layers in self.layers.items():
            for layer in layers:
                layer.train(mode=mode)
                for param in layer.parameters():
                    param.requires_grad = mode

    def to(self, device):
        """Move every module to ``device``; returns ``self``."""
        for k, layers in self.layers.items():
            for layer in layers:
                layer.to(device)

        return self

    def set_multiplier(self, multiplier):
        """Set the ``multiplier`` attribute on every module; returns ``self``."""
        for k, layers in self.layers.items():
            for layer in layers:
                layer.multiplier = multiplier

        return self

    def eval(self):
        """Put every module in eval mode and disable gradients on its parameters."""
        for k, layers in self.layers.items():
            for layer in layers:
                layer.eval()
                for param in layer.parameters():
                    param.requires_grad = False

    def save(self, filename):
        """Write module weights and metadata to ``filename`` with ``torch.save``.

        When an optimizer state is present it is written to ``filename + '.optim'``
        together with ``shorthash()`` and, if set, the optimizer name.
        """
        state_dict = {}
        optimizer_saved_dict = {}

        for k, v in self.layers.items():
            state_dict[k] = (v[0].state_dict(), v[1].state_dict())

        state_dict['step'] = self.step
        state_dict['name'] = self.name
        state_dict['layer_structure'] = self.layer_structure
        state_dict['activation_func'] = self.activation_func
        state_dict['is_layer_norm'] = self.add_layer_norm
        state_dict['weight_initialization'] = self.weight_init
        state_dict['sd_checkpoint'] = self.sd_checkpoint
        state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
        state_dict['activate_output'] = self.activate_output
        state_dict['use_dropout'] = self.use_dropout
        state_dict['dropout_structure'] = self.dropout_structure
        # Derived from the actual structure when there is one: non-zero second-to-last entry.
        state_dict['last_layer_dropout'] = (self.dropout_structure[-2] != 0) if self.dropout_structure is not None else self.last_layer_dropout
        state_dict['optional_info'] = self.optional_info if self.optional_info else None

        if self.optimizer_name is not None:
            optimizer_saved_dict['optimizer_name'] = self.optimizer_name

        torch.save(state_dict, filename)
        save_optimizer_state=True
        if save_optimizer_state and self.optimizer_state_dict:
            optimizer_saved_dict['hash'] = self.shorthash()
            optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
            torch.save(optimizer_saved_dict, filename + '.optim')

    @staticmethod
    def sha256(filename, title):
        """Return the SHA-256 of ``filename``, cached under ``title`` in the "hashes" cache.

        Declared as a static method: the signature takes no instance, so without
        the decorator ``instance.sha256(filename, title)`` would raise TypeError.
        """
        hashes = cache("hashes")

        sha256_value = sha256_from_cache(filename, title)
        if sha256_value is not None:
            return sha256_value

        print(f"Calculating sha256 for {filename}: ", end='')
        sha256_value = calculate_sha256(filename)
        print(f"{sha256_value}")

        hashes[title] = {
            "mtime": os.path.getmtime(filename),
            "sha256": sha256_value,
        }

        dump_cache()

        return sha256_value

    def load(self, filename):
        """Load weights and metadata from ``filename`` and rebuild the module pairs.

        Also restores the optimizer state from ``filename + '.optim'`` when that
        file exists and its stored hash equals ``shorthash()``. The instance is
        left in eval mode.
        """
        self.filename = filename
        if self.name is None:
            self.name = os.path.splitext(os.path.basename(filename))[0]

        # NOTE(security): torch.load unpickles the file; only load hypernetwork
        # files from sources you trust.
        state_dict = torch.load(filename, map_location='cpu')

        self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
        self.optional_info = state_dict.get('optional_info', None)
        self.activation_func = state_dict.get('activation_func', None)
        self.weight_init = state_dict.get('weight_initialization', 'Normal')
        self.add_layer_norm = state_dict.get('is_layer_norm', False)
        self.dropout_structure = state_dict.get('dropout_structure', None)
        self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False)
        self.activate_output = state_dict.get('activate_output', True)
        self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
        # Dropout structure should have same length as layer structure, Every digits should be in [0,1), and last digit must be 0.
        if self.dropout_structure is None:
            self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)

        print_hypernet_extra=True
        if print_hypernet_extra:
            if self.optional_info is not None:
                print(f"  INFO:\n {self.optional_info}\n")

            print(f"  Layer structure: {self.layer_structure}")
            print(f"  Activation function: {self.activation_func}")
            print(f"  Weight initialization: {self.weight_init}")
            print(f"  Layer norm: {self.add_layer_norm}")
            print(f"  Dropout usage: {self.use_dropout}" )
            print(f"  Activate last layer: {self.activate_output}")
            print(f"  Dropout structure: {self.dropout_structure}")

        optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {}

        # Only reuse the optimizer state if it was saved for this filename/name pair.
        if self.shorthash() == optimizer_saved_dict.get('hash', None):
            self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
        else:
            self.optimizer_state_dict = None
        if self.optimizer_state_dict:
            self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
            if print_hypernet_extra:
                print("Loaded existing optimizer from checkpoint")
                print(f"Optimizer name is {self.optimizer_name}")
        else:
            self.optimizer_name = "AdamW"
            if print_hypernet_extra:
                print("No saved optimizer exists in checkpoint")

        # Integer keys hold the per-size (module 0, module 1) state dicts; other keys are metadata.
        for size, sd in state_dict.items():
            if type(size) == int:
                self.layers[size] = (
                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
                                       self.add_layer_norm, self.activate_output, self.dropout_structure),
                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
                                       self.add_layer_norm, self.activate_output, self.dropout_structure),
                )

        self.name = state_dict.get('name', self.name)
        self.step = state_dict.get('step', 0)
        self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
        self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
        self.eval()

    def shorthash(self):
        """Return a 10-character hex identifier derived from ``filename`` and ``name``.

        The SHA-256 is taken over the string ``f'{filename}hypernet/{name}'``,
        i.e. over the path and name, not over the file contents. The value is
        also printed.
        """
        # Imported locally: an ``import`` in the class body only creates a class
        # attribute and is not visible as a bare name inside methods.
        import hashlib

        digest = hashlib.sha256(f'{self.filename}hypernet/{self.name}'.encode()).hexdigest()[0:10]
        print(f"sha256= ", digest)
        return digest


def list_hypernetworks(path):
    """Return ``{name: filename}`` for every ``.pt`` file found recursively under ``path``.

    ``name`` is the file's base name without extension. Files are visited in
    sorted path order, so for duplicate names the later path wins. A file
    named ``None.pt`` is never listed.
    """
    pattern = os.path.join(path, '**/*.pt')
    named_files = (
        (os.path.splitext(os.path.basename(filename))[0], filename)
        for filename in sorted(glob.iglob(pattern, recursive=True))
    )
    # Prevent a hypothetical "None.pt" from being listed.
    return {name: filename for name, filename in named_files if name != "None"}


def load_hypernetwork(name):
    """Load the hypernetwork registered under ``name`` in ``hypernetworks``.

    Returns the loaded ``Hypernetwork``, or None when the name is unknown or
    loading raises (the error and traceback are printed to stderr).
    """
    path = hypernetworks.get(name, None)
    if path is None:
        return None

    candidate = Hypernetwork()
    try:
        candidate.load(path)
    except Exception:
        print(f"Error loading hypernetwork {path}", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        return None
    else:
        return candidate


def load_hypernetworks(names, multipliers=None):
    """Make ``loaded_hypernetworks`` hold exactly the hypernetworks in ``names``.

    Hypernetworks already in the list are reused by name; the rest are loaded
    with ``load_hypernetwork`` and silently skipped if that returns None.
    ``multipliers[i]`` is applied to ``names[i]`` (1.0 when no multipliers are
    given). The list is modified in place.
    """
    reusable = {
        existing.name: existing
        for existing in loaded_hypernetworks
        if existing.name in names
    }

    loaded_hypernetworks.clear()

    for position, name in enumerate(names):
        current = reusable.get(name, None)
        if current is None:
            current = load_hypernetwork(name)

        if current is not None:
            current.set_multiplier(multipliers[position] if multipliers else 1.0)
            loaded_hypernetworks.append(current)
    print(f"Hypernetworks Loaded", names)

def find_closest_hypernetwork_name(search: str):
    """Return the shortest name in ``hypernetworks`` containing ``search`` (case-insensitive).

    Returns None for an empty search string or when nothing matches. Among
    names of equal length the one that comes first in ``hypernetworks`` wins.
    """
    if not search:
        return None
    needle = search.lower()
    matches = [name for name in hypernetworks if needle in name.lower()]
    return min(matches, key=len) if matches else None


def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
    """Run one hypernetwork's module pair over the key and value contexts.

    The pair is looked up in ``hypernetwork.layers`` by ``context_k.shape[2]``.
    If ``hypernetwork`` is None or has no pair for that size, the contexts are
    returned unchanged. When ``layer`` is given, the two modules are also
    stored on it as ``hyper_k`` and ``hyper_v``.
    """
    available = {} if hypernetwork is None else hypernetwork.layers
    pair = available.get(context_k.shape[2], None)

    if pair is not None:
        if layer is not None:
            layer.hyper_k = pair[0]
            layer.hyper_v = pair[1]

        context_k = pair[0](context_k)
        context_v = pair[1](context_v)

    return context_k, context_v


def apply_hypernetworks(hypernetworks, context, layer=None):
    """Chain ``apply_single_hypernetwork`` over ``hypernetworks``.

    Both the key and the value context start out as ``context``; each
    hypernetwork receives the output of the previous one. Returns the final
    ``(context_k, context_v)`` pair.
    """
    keys_and_values = (context, context)
    for hypernetwork in hypernetworks:
        next_k, next_v = apply_single_hypernetwork(hypernetwork, keys_and_values[0], keys_and_values[1], layer)
        keys_and_values = (next_k, next_v)

    return keys_and_values


def attention_CrossAttention_forward(self, x, context=None, mask=None):
    """Cross-attention forward pass with the loaded hypernetworks applied to the context.

    ``context`` defaults to ``x``. The hypernetworks in ``loaded_hypernetworks``
    transform the context separately for keys and values before the
    ``to_k``/``to_v`` projections; ``self`` is passed along so the applied
    modules are recorded on it. Where ``mask`` is given, positions whose mask
    is False are filled with the most negative finite value before the softmax.
    """
    heads = self.heads

    query = self.to_q(x)
    context = default(context, x)

    hyper_k, hyper_v = apply_hypernetworks(loaded_hypernetworks, context, self)
    key = self.to_k(hyper_k)
    value = self.to_v(hyper_v)

    def split_heads(t):
        # Fold the head axis into the batch axis.
        return rearrange(t, 'b n (h d) -> (b h) n d', h=heads)

    query, key, value = split_heads(query), split_heads(key), split_heads(value)

    scores = einsum('b i d, b j d -> b i j', query, key) * self.scale

    if mask is not None:
        flat_mask = rearrange(mask, 'b ... -> b (...)')
        fill_value = -torch.finfo(scores.dtype).max
        flat_mask = repeat(flat_mask, 'b j -> (b h) () j', h=heads)
        scores.masked_fill_(~flat_mask, fill_value)

    # attention, what we cannot get enough of
    weights = scores.softmax(dim=-1)

    merged = einsum('b i j, b j d -> b i d', weights, value)
    merged = rearrange(merged, '(b h) n d -> b n (h d)', h=heads)
    return self.to_out(merged)


def stack_conds(conds):
    """Stack a list of 2-D conditioning tensors, padding shorter ones first.

    Tensors with fewer rows than the longest are extended by repeating their
    last row. The padded tensors are written back into ``conds`` (the list is
    modified in place) before stacking.
    """
    if len(conds) == 1:
        return torch.stack(conds)

    # same as in reconstruct_multicond_batch
    longest = max([cond.shape[0] for cond in conds])
    for idx, cond in enumerate(conds):
        if cond.shape[0] == longest:
            continue
        filler = cond[-1:].repeat([longest - cond.shape[0], 1])
        conds[idx] = torch.vstack([cond, filler])

    return torch.stack(conds)


def statistics(data):
    """Summarise a loss history as two strings: overall and the last 32 entries.

    Each string has the form ``<label><mean>±(<standard error>)`` with three
    decimals, where the standard error is ``stdev / sqrt(n)`` and is reported
    as 0 when fewer than two values are available.
    """
    def summarise(label, values):
        spread = stdev(values) if len(values) >= 2 else 0
        return f"{label}{mean(values):.3f}" + u"\u00B1" + f"({spread / (len(values) ** 0.5):.3f})"

    total_information = summarise("loss:", data)
    recent_information = summarise("recent 32 loss:", data[-32:])
    return total_information, recent_information


def report_statistics(loss_info:dict):
    """Print loss statistics per file, ordered by ascending average loss.

    ``loss_info`` maps a file name to a collection of loss values. Errors
    raised while reporting a single file are printed and do not stop the loop.
    """
    def average_loss(key):
        return sum(loss_info[key]) / len(loss_info[key])

    for key in sorted(loss_info.keys(), key=average_loss):
        try:
            print("Loss statistics for file " + key)
            overall, recent = statistics(list(loss_info[key]))
            for line in (overall, recent):
                print(line)
        except Exception as e:
            print(e)

def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
    """Create a new hypernetwork, save it under ``root.hypernetwork_dir`` and refresh the list.

    name: file stem; characters other than alphanumerics and ``._- `` are dropped.
    enable_sizes: sizes to create module pairs for (each converted with ``int``).
    overwrite_old: when falsy, an existing target file triggers an assertion.
    layer_structure: list of ratios, a comma-separated string, or None for the
        default ``[1, 2, 1]``.
    dropout_structure: honoured only when ``use_dropout`` is set and it is a
        non-empty comma-separated string; otherwise all-zero dropout is used.
    """
    # Remove illegal characters from name.
    name = "".join( x for x in name if (x.isalnum() or x in "._- "))
    assert name, "Name cannot be empty!"

    fn = os.path.join(root.hypernetwork_dir, f"{name}.pt")
    if not overwrite_old:
        assert not os.path.exists(fn), f"file {fn} already exists"

    if layer_structure is None:
        # Same default that Hypernetwork.load and parse_dropout_structure fall back to.
        # Without it the len() call below fails on the function's own default argument.
        layer_structure = [1, 2, 1]
    elif type(layer_structure) == str:
        layer_structure = [float(x.strip()) for x in layer_structure.split(",")]

    if use_dropout and dropout_structure and type(dropout_structure) == str:
        dropout_structure = [float(x.strip()) for x in dropout_structure.split(",")]
    else:
        dropout_structure = [0] * len(layer_structure)

    hypernet = Hypernetwork(
        name=name,
        enable_sizes=[int(x) for x in enable_sizes],
        layer_structure=layer_structure,
        activation_func=activation_func,
        weight_init=weight_init,
        add_layer_norm=add_layer_norm,
        use_dropout=use_dropout,
        dropout_structure=dropout_structure
    )
    hypernet.save(fn)

    reload_hypernetworks()

#apply_optimizations
# Default-constructed Hypernetwork: no name and no enable_sizes, so its layers dict stays empty.
hypernetwork=Hypernetwork()
import math
import sys
import traceback
import psutil
import contextlib
import torch
from torch import einsum

from ldm.util import default
from einops import rearrange

# Start from a known value so `xformers_available` is always bound, including
# when xformers is disabled or its import fails.
xformers_available = False
if root.use_xformers:
    try:
        import xformers.ops
        xformers_available = True
    except Exception:
        print("Cannot import xformers", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)


def get_available_vram():
    """Return an estimate, in bytes, of memory available on ``root.device``.

    For a CUDA device this is the free device memory reported by
    ``torch.cuda.mem_get_info`` plus memory PyTorch has reserved but is not
    actively using. For any other device the available system RAM reported by
    psutil is returned.
    """
    # root.device may be a string or a torch.device. Comparing a torch.device
    # with the string 'cuda' is always False, so normalise and check the type.
    device = torch.device(root.device)
    if device.type == 'cuda':
        stats = torch.cuda.memory_stats(device)
        mem_active = stats['active_bytes.all.current']
        mem_reserved = stats['reserved_bytes.all.current']
        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
        mem_free_torch = mem_reserved - mem_active
        mem_free_total = mem_free_cuda + mem_free_torch
        return mem_free_total
    else:
        return psutil.virtual_memory().available

# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
    """Cross-attention forward that processes two (batch*head) rows at a time.

    ``context`` defaults to ``x`` and is first passed through the loaded
    hypernetworks. Attention is computed in float32 and the result is cast
    back to the query's original dtype. ``mask`` is accepted but not used.
    Intermediate tensors are deleted as soon as possible to lower peak memory.
    """
    h = self.heads

    q_in = self.to_q(x)
    context = default(context, x)

    # Use the module-level apply_hypernetworks(): the global `hypernetwork` is a
    # Hypernetwork instance and that class defines no apply_hypernetworks method.
    context_k, context_v = apply_hypernetworks(loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)
    del context, context_k, context_v, x

    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
    del q_in, k_in, v_in

    dtype = q.dtype
    upcast_attn=True
    if upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    with without_autocast(disable=not upcast_attn):
        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
        for i in range(0, q.shape[0], 2):
            end = i + 2
            s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
            s1 *= self.scale

            s2 = s1.softmax(dim=-1)
            del s1

            r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
            del s2
        del q, k, v

    r1 = r1.to(dtype)

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)


# taken from https://github.com/Doggettx/stable-diffusion and modified
def split_cross_attention_forward(self, x, context=None, mask=None):
    """Cross-attention forward that slices the query axis to fit available memory.

    ``context`` defaults to ``x`` and is first passed through the loaded
    hypernetworks. The number of slices is a power of two chosen from the
    estimated size of the attention matrix and ``get_available_vram()``;
    slicing is only applied when the query length is divisible by it.
    ``mask`` is accepted but not used.

    Raises RuntimeError when more than 64 slices would be required.
    """
    h = self.heads

    q_in = self.to_q(x)
    context = default(context, x)

    # Use the module-level apply_hypernetworks(): the global `hypernetwork` is a
    # Hypernetwork instance and that class defines no apply_hypernetworks method.
    context_k, context_v = apply_hypernetworks(loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    dtype = q_in.dtype
    upcast_attn=True
    if upcast_attn:
        # v is left in its own dtype on MPS devices.
        q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()
    with without_autocast(disable=not upcast_attn):
        # Scale is folded into k once instead of into every score slice.
        k_in = k_in * self.scale

        del context, x

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
        del q_in, k_in, v_in

        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

        mem_free_total = get_available_vram()

        gb = 1024 ** 3
        tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
        modifier = 3 if q.element_size() == 2 else 2.5
        mem_required = tensor_size * modifier
        steps = 1

        if mem_required > mem_free_total:
            steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
            # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
            #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")

        if steps > 64:
            max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                               f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')

        # Fall back to a single slice when the query length is not divisible by steps.
        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
        for i in range(0, q.shape[1], slice_size):
            end = i + slice_size
            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)

            s2 = s1.softmax(dim=-1, dtype=q.dtype)
            del s1

            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
            del s2

        del q, k, v

    r1 = r1.to(dtype)

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)


# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
# Total system RAM in whole GiB (integer division by 2**30); read by the MPS code paths below.
mem_total_gb = psutil.virtual_memory().total // (1 << 30)

def einsum_op_compvis(q, k, v):
    """Unsliced attention: softmax over ``q·kᵀ`` along the last axis, applied to ``v``."""
    attn = einsum('b i d, b j d -> b i j', q, k)
    # Rebinding the same name lets the raw scores be released before the second einsum.
    attn = attn.softmax(dim=-1, dtype=attn.dtype)
    return einsum('b i j, b j d -> b i d', attn, v)

def einsum_op_slice_0(q, k, v, slice_size):
    """Attention computed in chunks of ``slice_size`` along axis 0 of q, k and v."""
    out = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for start in range(0, q.shape[0], slice_size):
        window = slice(start, start + slice_size)
        out[window] = einsum_op_compvis(q[window], k[window], v[window])
    return out

def einsum_op_slice_1(q, k, v, slice_size):
    """Attention computed in chunks of ``slice_size`` along axis 1 of q; k and v are used whole."""
    out = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for start in range(0, q.shape[1], slice_size):
        rows = slice(start, start + slice_size)
        out[:, rows] = einsum_op_compvis(q[:, rows], k, v)
    return out

def einsum_op_mps_v1(q, k, v):
    """Attention for MPS: unsliced up to 2**16 query rows, otherwise sliced along axis 1.

    The slice size is ``floor(2**30 / rows)``, reduced by one when it is a
    multiple of 4096.
    """
    rows = q.shape[0] * q.shape[1]
    if rows > 2**16: # (512x512) max q.shape[1]: 4096
        slice_size = math.floor(2**30 / rows)
        if slice_size % 4096 == 0:
            slice_size -= 1
        return einsum_op_slice_1(q, k, v, slice_size)
    return einsum_op_compvis(q, k, v)

def einsum_op_mps_v2(q, k, v):
    """Attention for MPS: unsliced only with more than 8 GiB RAM and at most 2**16 query rows.

    In every other case the computation is sliced along axis 0, one row at a time.
    """
    fits_unsliced = mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16
    if not fits_unsliced:
        return einsum_op_slice_0(q, k, v, 1)
    return einsum_op_compvis(q, k, v)

def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
    size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
    if size_mb <= max_tensor_mb:
        return einsum_op_compvis(q, k, v)
    div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
    if div <= q.shape[0]:
        return einsum_op_slice_0(q, k, v, q.shape[0] // div)
    return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))

def einsum_op_cuda(q, k, v):
    """Attention on CUDA, sliced according to the memory currently free on ``q.device``.

    Free memory is the device's free bytes plus the bytes PyTorch has reserved
    but is not actively using.
    """
    stats = torch.cuda.memory_stats(q.device)
    active_bytes = stats['active_bytes.all.current']
    reserved_bytes = stats['reserved_bytes.all.current']
    device_free_bytes, _ = torch.cuda.mem_get_info(q.device)
    budget_bytes = device_free_bytes + (reserved_bytes - active_bytes)
    # Divide factor of safety as there's copying and fragmentation
    return einsum_op_tensor_mem(q, k, v, budget_bytes / 3.3 / (1 << 20))

def einsum_op(q, k, v):
    """Dispatch attention to the implementation matching ``q.device.type``.

    CUDA uses ``einsum_op_cuda``; MPS chooses between the two MPS variants
    based on total RAM and the query shape; everything else is sliced to
    32 MiB score tensors.
    """
    kind = q.device.type
    if kind == 'cuda':
        return einsum_op_cuda(q, k, v)

    if kind == 'mps':
        use_v1 = mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18
        mps_impl = einsum_op_mps_v1 if use_v1 else einsum_op_mps_v2
        return mps_impl(q, k, v)

    # Smaller slices are faster due to L2/L3/SLC caches.
    # Tested on i7 with 8MB L3 cache.
    return einsum_op_tensor_mem(q, k, v, 32)

def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
    """Cross-attention forward that delegates the attention product to ``einsum_op``.

    ``context`` defaults to ``x`` and is first passed through the loaded
    hypernetworks. q and k are upcast to float32 (v too, except on MPS), the
    scale is folded into k, and the result is cast back to the query's
    original dtype. ``mask`` is accepted but not used.
    """
    h = self.heads

    q = self.to_q(x)
    context = default(context, x)

    # Use the module-level apply_hypernetworks(): the global `hypernetwork` is a
    # Hypernetwork instance and that class defines no apply_hypernetworks method.
    context_k, context_v = apply_hypernetworks(loaded_hypernetworks, context)
    k = self.to_k(context_k)
    v = self.to_v(context_v)
    del context, context_k, context_v, x

    dtype = q.dtype
    upcast_attn=True
    if upcast_attn:
        q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float()

    with without_autocast(disable=not upcast_attn):
        k = k * self.scale

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
        r = einsum_op(q, k, v)
    r = r.to(dtype)
    return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))

# -- End of code from https://github.com/invoke-ai/InvokeAI --


# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
def sub_quad_attention_forward(self, x, context=None, mask=None):
    """Cross-attention forward built on ``sub_quad_attention``.

    ``context`` defaults to ``x`` and is first passed through the loaded
    hypernetworks. q and k are upcast to float32 and the result is cast back
    to the query's original dtype. ``self.to_out`` is unpacked into an output
    projection followed by a dropout module.

    Raises AssertionError when ``mask`` is given; masks are not supported.
    """
    assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."

    h = self.heads

    q = self.to_q(x)
    context = default(context, x)

    # Use the module-level apply_hypernetworks(): the global `hypernetwork` is a
    # Hypernetwork instance and that class defines no apply_hypernetworks method.
    context_k, context_v = apply_hypernetworks(loaded_hypernetworks, context)
    k = self.to_k(context_k)
    v = self.to_v(context_v)
    del context, context_k, context_v, x

    # (batch, tokens, heads*dim) -> (batch*heads, tokens, dim)
    q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
    k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
    v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)

    dtype = q.dtype
    upcast_attn=True
    if upcast_attn:
        q, k = q.float(), k.float()
    sub_quad_q_chunk_size=1024
    # NOTE(review): `sub_quad_kv_chunk_size` and `shared` are not defined in this
    # part of the file - confirm they exist at module level before using this path.
    x = sub_quad_attention(q, k, v, q_chunk_size=sub_quad_q_chunk_size, kv_chunk_size=sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)

    x = x.to(dtype)

    # (batch*heads, tokens, dim) -> (batch, tokens, heads*dim)
    x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)

    out_proj, dropout = self.to_out
    x = out_proj(x)
    x = dropout(x)

    return x

def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
    """Choose chunking parameters and run ``efficient_dot_product_attention``.

    chunk_threshold: None derives a byte budget from ``get_available_vram()``
        (90% on MPS, 70% otherwise); 0 disables the budget; any other value
        is a percentage of the available memory.
    kv_chunk_size_min: None derives a minimum from the byte budget (when there
        is one); 0 means no minimum.

    When the full q·k score matrix fits in the budget, ``kv_chunk_size`` is set
    to the number of key tokens so keys/values are processed in one chunk.
    ``q_chunk_size`` is always passed through unchanged.
    """
    bytes_per_token = torch.finfo(q.dtype).bits//8
    batch_x_heads, q_tokens, _ = q.shape
    _, k_tokens, _ = k.shape
    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

    if chunk_threshold is None:
        chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
    elif chunk_threshold == 0:
        chunk_threshold_bytes = None
    else:
        chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())

    if kv_chunk_size_min is None and chunk_threshold_bytes is not None:
        kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
    elif kv_chunk_size_min == 0:
        kv_chunk_size_min = None

    if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
        # the big matmul fits into our memory limit; process keys/values in 1 chunk.
        # (An unused `query_chunk_size` local was assigned here before; the call
        # below has always passed q_chunk_size, so it was dropped.)
        kv_chunk_size = k_tokens

    # Autocast is only switched off when q and v ended up with different dtypes.
    with without_autocast(disable=q.dtype == v.dtype):
        return efficient_dot_product_attention(
            q,
            k,
            v,
            query_chunk_size=q_chunk_size,
            kv_chunk_size=kv_chunk_size,
            kv_chunk_size_min = kv_chunk_size_min,
            use_checkpoint=use_checkpoint,
        )

# Switch read by get_xformers_flash_attention_op(); when falsy that function returns None.
xformers_flash_attention=True
def get_xformers_flash_attention_op(q, k, v):
    """Return the xformers flash-attention op if it is enabled and supports these inputs.

    Returns None when ``xformers_flash_attention`` is falsy, when the forward
    op reports the inputs as unsupported, or when probing raises (the error is
    handed to ``display_once``).
    """
    if xformers_flash_attention:
        try:
            flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
            fw, bw = flash_attention_op
            probe = xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)
            if fw.supports(probe):
                return flash_attention_op
        except Exception as e:
            display_once(e, "enabling flash attention")

    return None


def xformers_attention_forward(self, x, context=None, mask=None):
    """Replacement ``CrossAttention.forward`` that delegates to xformers.

    Args:
        x: input passed through ``self.to_q``.
        context: source for keys/values; falls back to ``x`` via ``default``.
            It is first passed through ``hypernetwork.apply_hypernetworks``.
        mask: accepted for signature compatibility only; it is not used, the
            xformers call is always made with ``attn_bias=None``.

    Returns:
        ``self.to_out`` applied to the attention result, cast back to the dtype
        the projected query had before upcasting.
    """
    h = self.heads
    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
    del q_in, k_in, v_in

    dtype = q.dtype
    upcast_attn=True
    if upcast_attn:
        # Upcast all three operands together. Upcasting only q and k leaves v in
        # its original (possibly half) dtype, so the attention call would be
        # handed mixed-dtype inputs.
        q, k, v = q.float(), k.float(), v.float()

    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))

    # restore the dtype the rest of the network expects
    out = out.to(dtype)

    out = rearrange(out, 'b n h d -> b n (h d)', h=h)
    return self.to_out(out)

def cross_attention_attnblock_forward(self, x):
        """Self-attention forward for an AttnBlock, computed in slices of the query axis.

        Normalises ``x``, projects it to q/k/v, runs scaled dot-product attention
        over the flattened spatial positions, applies ``self.proj_out`` and adds
        ``x`` back as a residual. The query axis is processed in several slices
        when the estimated size of the attention matrix exceeds
        ``get_available_vram()``.

        Args:
            x: input whose q projection is unpacked as ``(b, c, h, w)``.

        Returns:
            ``self.proj_out(attention_output)`` with ``x`` added in place.
        """
        h_ = x
        h_ = self.norm(h_)
        q1 = self.q(h_)
        k1 = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q1.shape

        # Intermediates are deleted as soon as they are no longer referenced,
        # presumably to keep peak memory low.
        q2 = q1.reshape(b, c, h*w)
        del q1

        q = q2.permute(0, 2, 1)   # b,hw,c
        del q2

        k = k1.reshape(b, c, h*w) # b,c,hw
        del k1

        # output buffer with k's shape (b,c,hw), filled slice by slice below
        h_ = torch.zeros_like(k, device=q.device)

        mem_free_total = get_available_vram()

        # bytes for one full (b, hw, hw) score matrix; 2.5x is the headroom factor used for the estimate
        tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
        mem_required = tensor_size * 2.5
        steps = 1

        if mem_required > mem_free_total:
            # smallest power of two that is >= mem_required / mem_free_total
            steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))

        # if the query length is not divisible by steps, the whole axis is processed as one slice
        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
        for i in range(0, q.shape[1], slice_size):
            end = i + slice_size

            w1 = torch.bmm(q[:, i:end], k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
            w2 = w1 * (int(c)**(-0.5))
            del w1
            w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
            del w2

            # attend to values
            v1 = v.reshape(b, c, h*w)
            w4 = w3.permute(0, 2, 1)   # b,hw,hw (first hw of k, second of q)
            del w3

            h_[:, :, i:end] = torch.bmm(v1, w4)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
            del v1, w4

        h2 = h_.reshape(b, c, h, w)
        del h_

        h3 = self.proj_out(h2)
        del h2

        # residual connection, added in place on the projection output
        h3 += x

        return h3
    
def xformers_attnblock_forward(self, x):
    """Replacement ``AttnBlock.forward`` that delegates to xformers.

    Normalises ``x``, projects it to q/k/v, flattens the spatial axes, runs
    ``xformers.ops.memory_efficient_attention`` and adds the projected result to
    ``x``. If any step raises ``NotImplementedError`` the call is retried with
    ``cross_attention_attnblock_forward``.

    Args:
        x: input whose q projection is unpacked as ``(b, c, h, w)``.

    Returns:
        ``x + self.proj_out(attention_output)``.
    """
    try:
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        b, c, h, w = q.shape
        q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
        dtype = q.dtype
        upcast_attn=True
        if upcast_attn:
            # Upcast all three operands together. Upcasting only q and k leaves v
            # in its original (possibly half) dtype, so the attention call would
            # be handed mixed-dtype inputs.
            q, k, v = q.float(), k.float(), v.float()
        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
        # restore the dtype the rest of the network expects
        out = out.to(dtype)
        out = rearrange(out, 'b (h w) c -> b c h w', h=h)
        out = self.proj_out(out)
        return x + out
    except NotImplementedError:
        return cross_attention_attnblock_forward(self, x)

def sub_quad_attnblock_forward(self, x):
    """Replacement ``AttnBlock.forward`` built on ``sub_quad_attention``.

    Normalises ``x``, projects it to q/k/v, flattens the spatial axes into one
    token axis, runs sub-quadratic attention with the fixed settings below and
    returns ``x`` plus the projected attention output.
    """
    normed = self.norm(x)
    q = self.q(normed)
    k = self.k(normed)
    v = self.v(normed)
    b, c, h, w = q.shape

    # flatten (h, w) into a single token axis and make each tensor contiguous
    q, k, v = (rearrange(t, 'b c h w -> b (h w) c').contiguous() for t in (q, k, v))

    attended = sub_quad_attention(
        q,
        k,
        v,
        q_chunk_size=1024,
        kv_chunk_size=None,
        chunk_threshold=None,
        use_checkpoint=self.training,
    )
    attended = rearrange(attended, 'b (h w) c -> b c h w', h=h)
    return x + self.proj_out(attended)

#sd_hijack_unet
import torch
from packaging import version
# Handle for the CPU device.
cpu = torch.device("cpu")
# Per-purpose device handles; all start out unset (None).
device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
# Global compute dtype; autocast() below returns a no-op context while this is float32.
dtype = torch.float32
# dtype that inputs are cast to before the first-stage (VAE) encode/decode patches run.
dtype_vae = torch.float16
# dtype that UNet inputs and conditioning tensors are cast to by the patches below.
dtype_unet = torch.float16
# Initial value of the upcast flag; it is reassigned later in this section.
unet_needs_upcast = False
# autocast() below also returns a no-op context while this is "full".
precision = "full"

def autocast(disable=False):
    """Return the context manager to run model code under.

    A no-op ``contextlib.nullcontext()`` is returned when ``disable`` is truthy,
    when the module-level ``dtype`` is ``torch.float32`` or when the module-level
    ``precision`` is ``"full"``; otherwise ``torch.autocast("cuda")``.
    """
    if disable or dtype == torch.float32 or precision == "full":
        return contextlib.nullcontext()

    return torch.autocast("cuda")

class TorchHijackForUnet:
    """
    Stand-in for the torch module whose cat() resizes tensors to matching spatial
    dimensions when they differ; this makes it possible to create pictures with
    dimensions that are multiples of 8 rather than 64. Every other attribute is
    forwarded to torch.
    """

    def __getattr__(self, item):
        # __getattr__ only runs when normal attribute lookup has failed
        if item == 'cat':
            return self.cat

        if not hasattr(torch, item):
            raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))

        return getattr(torch, item)

    def cat(self, tensors, *args, **kwargs):
        """torch.cat, but a pair of tensors is first brought to the same trailing two dims."""
        if len(tensors) == 2:
            first, second = tensors
            target_size = second.shape[-2:]
            if first.shape[-2:] != target_size:
                # resize the first tensor to the second one's spatial size
                first = torch.nn.functional.interpolate(first, target_size, mode="nearest")

            tensors = (first, second)

        return torch.cat(tensors, *args, **kwargs)


#th = TorchHijackForUnet()


# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
    """Wrapper for ``LatentDiffusion.apply_model`` that casts inputs to ``dtype_unet``.

    When ``cond`` is a dict, every tensor inside its list values is cast to
    ``dtype_unet`` (the dict is updated in place). ``x_noisy`` and ``t`` are cast
    as well, ``orig_func`` runs under ``autocast()`` and its result is returned
    as float32.
    """
    if isinstance(cond, dict):
        for key in cond.keys():
            cond[key] = [
                entry.to(dtype_unet) if isinstance(entry, torch.Tensor) else entry
                for entry in cond[key]
            ]

    with autocast():
        result = orig_func(self, x_noisy.to(dtype_unet), t.to(dtype_unet), cond, **kwargs)
        return result.float()
unet_needs_upcast = True
class GELUHijack(torch.nn.GELU, torch.nn.Module):
    """GELU whose forward runs in float32 and casts back to ``dtype_unet`` when upcasting is on."""

    def __init__(self, *args, **kwargs):
        torch.nn.GELU.__init__(self, *args, **kwargs)

    def forward(self, x):
        if not unet_needs_upcast:
            return torch.nn.GELU.forward(self, x)

        # run the activation in float32, then return to the UNet dtype
        upcast_result = torch.nn.GELU.forward(self.float(), x.float())
        return upcast_result.to(dtype_unet)

# Predicate handed to CondFunc (defined elsewhere; from these call sites it takes a
# dotted target path, a replacement whose first argument is the original function,
# and a predicate). It reads the module-level `unet_needs_upcast` flag at call time.
# It must not be bound to the name `unet_needs_upcast` itself: a lambda stored under
# that name would return itself (always truthy) and replace the boolean flag.
unet_needs_upcast_cond = lambda *args, **kwargs: unet_needs_upcast
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast_cond)
# timestep embeddings stay float32 for int64 timesteps, otherwise they are cast to dtype_unet
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else dtype_unet), unet_needs_upcast_cond)
if version.parse(torch.__version__) <= version.parse("1.13.1"):
    CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast_cond)
    CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(dtype_unet), unet_needs_upcast_cond)
    # kwargs.update(...) returns None, so "None and False or orig_func(...)" always calls orig_func after injecting act_layer
    CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)

# first-stage (VAE) patches apply only when upcasting is on and the UNet itself is float16
first_stage_cond = lambda _, self, *args, **kwargs: unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16
first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(dtype_vae), **kwargs)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)

#@title sd_hijack
import torch
from torch.nn.functional import silu
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
import ldm.modules.diffusionmodules.openaimodel
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
import ldm.modules.encoders.modules

# Originals captured before any patching so that undo_optimizations() can restore them.
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward

# The new memory-efficient cross attention blocks do not support hypernets, and we already
# have memory-efficient cross attention anyway, so this disables SD2.0's implementation by
# pointing both of its entry points at the plain CrossAttention class.
ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention

# Silence new console spam from SD2 by giving these modules a print that does nothing.
ldm.modules.attention.print = lambda *args: None
ldm.modules.diffusionmodules.model.print = lambda *args: None
# Instance that apply_optimizations() installs as openaimodel.th.
th = TorchHijackForUnet()
def apply_optimizations():
    """Install one attention optimisation and return its label.

    Any earlier patch is undone first, then the nonlinearity and ``th`` hooks are
    installed. The first matching option wins, in this order: xformers,
    sub-quadratic, split attention v1, InvokeAI, Doggettx.

    Returns:
        ``'xformers'``, ``'sub-quadratic'``, ``'V1'``, ``'InvokeAI'``,
        ``'Doggettx'``, or ``None`` when no option matched.
    """
    undo_optimizations()

    ldm.modules.diffusionmodules.model.nonlinearity = silu
    ldm.modules.diffusionmodules.openaimodel.th = th

    cross_attention_cls = ldm.modules.attention.CrossAttention
    attn_block_cls = ldm.modules.diffusionmodules.model.AttnBlock

    if root.use_xformers and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(root.device) <= (9, 0):
        print("Applying xformers cross attention optimization.")
        cross_attention_cls.forward = xformers_attention_forward
        attn_block_cls.forward = xformers_attnblock_forward
        return 'xformers'

    if root.use_sub_quad_attention and not root.use_xformers:
        print("Applying sub-quadratic cross attention optimization.")
        cross_attention_cls.forward = sub_quad_attention_forward
        attn_block_cls.forward = sub_quad_attnblock_forward
        return 'sub-quadratic'

    if root.use_split_attention_v1 and not root.use_sub_quad_attention and not root.use_xformers:
        print("Applying v1 cross attention optimization.")
        cross_attention_cls.forward = split_cross_attention_forward_v1
        return 'V1'

    if not disable_opt_split_attention and (root.use_split_cross_attention_forward_invokeAI or not opt_split_attention and not torch.cuda.is_available()):
        print("Applying cross attention optimization (InvokeAI).")
        cross_attention_cls.forward = split_cross_attention_forward_invokeAI
        return 'InvokeAI'

    if not disable_opt_split_attention and (opt_split_attention or torch.cuda.is_available()):
        print("Applying cross attention optimization (Doggettx).")
        cross_attention_cls.forward = split_cross_attention_forward
        attn_block_cls.forward = cross_attention_attnblock_forward
        return 'Doggettx'

    return None


def undo_optimizations():
    """Put back the attention forwards and nonlinearity captured at import time."""
    model_module = ldm.modules.diffusionmodules.model

    ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward
    model_module.nonlinearity = diffusionmodules_model_nonlinearity
    model_module.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward


def fix_checkpoint():
    """Intentionally does nothing.

    Checkpoints are now added and removed in the embedding/hypernet code, since
    torch doesn't want checkpoints to be added when not training (there's a warning).
    """
    return None

from transformers import BertPreTrainedModel,BertModel,BertConfig
import torch.nn as nn
import torch
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
from transformers import XLMRobertaModel,XLMRobertaTokenizer
from typing import Optional

class BertSeriesConfig(BertConfig):
    """BertConfig extended with ``project_dim``, ``pooler_fn`` and ``learn_encoder``.

    ``model_type`` is accepted but not used by this constructor.
    """

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        position_embedding_type="absolute",
        use_cache=True,
        classifier_dropout=None,
        project_dim=512,
        pooler_fn="average",
        learn_encoder=False,
        model_type='bert',
        **kwargs,
    ):
        # the standard BERT settings are forwarded positionally, in the order received
        super().__init__(
            vocab_size,
            hidden_size,
            num_hidden_layers,
            num_attention_heads,
            intermediate_size,
            hidden_act,
            hidden_dropout_prob,
            attention_probs_dropout_prob,
            max_position_embeddings,
            type_vocab_size,
            initializer_range,
            layer_norm_eps,
            pad_token_id,
            position_embedding_type,
            use_cache,
            classifier_dropout,
            **kwargs,
        )
        # extra fields specific to this config
        self.project_dim = project_dim
        self.pooler_fn = pooler_fn
        self.learn_encoder = learn_encoder

class RobertaSeriesConfig(XLMRobertaConfig):
    """XLMRobertaConfig extended with ``project_dim``, ``pooler_fn`` and ``learn_encoder``."""

    def __init__(
        self,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        project_dim=512,
        pooler_fn='cls',
        learn_encoder=False,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )
        # extra fields specific to this config
        self.project_dim = project_dim
        self.pooler_fn = pooler_fn
        self.learn_encoder = learn_encoder


class BertSeriesModelWithTransformation(BertPreTrainedModel):
    """XLM-RoBERTa text encoder followed by a linear projection to ``config.project_dim``.

    The model wraps an ``XLMRobertaModel`` (``self.roberta``), a pre-projection
    LayerNorm (``self.pre_LN``), a linear layer (``self.transformation``) and the
    ``xlm-roberta-large`` tokenizer.
    """

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
    config_class = BertSeriesConfig

    def __init__(self, config=None, **kargs):
        """Build the encoder; when ``config`` is None a built-in default config is used.

        ``**kargs`` is accepted but not used.
        """
        # modify initialization for autoloading
        if config is None:
            config = XLMRobertaConfig()
            config.attention_probs_dropout_prob= 0.1
            config.bos_token_id=0
            config.eos_token_id=2
            config.hidden_act='gelu'
            config.hidden_dropout_prob=0.1
            config.hidden_size=1024
            config.initializer_range=0.02
            config.intermediate_size=4096
            config.layer_norm_eps=1e-05
            config.max_position_embeddings=514

            config.num_attention_heads=16
            config.num_hidden_layers=24
            config.output_past=True
            config.pad_token_id=1
            config.position_embedding_type= "absolute"

            config.type_vocab_size= 1
            config.use_cache=True
            config.vocab_size= 250002
            config.project_dim = 768
            config.learn_encoder = False
        super().__init__(config)
        self.roberta = XLMRobertaModel(config)
        self.transformation = nn.Linear(config.hidden_size,config.project_dim)
        self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
        # pooling takes the first token of every sequence
        self.pooler = lambda x: x[:,0]
        self.post_init()

    def encode(self,c):
        """Tokenize ``c`` (padded/truncated to 77 tokens) and return its ``projection_state``."""
        device = next(self.parameters()).device
        text = self.tokenizer(c,
                        truncation=True,
                        max_length=77,
                        return_length=False,
                        return_overflowing_tokens=False,
                        padding="max_length",
                        return_tensors="pt")
        # return_tensors="pt" already yields tensors, so they are moved directly;
        # wrapping them in torch.tensor() again made a needless copy and raised a
        # copy-construct UserWarning.
        text["input_ids"] = text["input_ids"].to(device)
        text["attention_mask"] = text["attention_mask"].to(device)
        features = self(**text)
        return features['projection_state']

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) :
        r"""Run the encoder and project its output.

        ``return_dict`` and ``output_hidden_states`` are accepted for signature
        compatibility; the encoder is always run with hidden states enabled and
        with attribute-style outputs, because the result is read through
        ``outputs.last_hidden_state`` / ``.hidden_states`` / ``.attentions``.

        Returns:
            A dict with the keys ``pooler_output``, ``last_hidden_state``,
            ``hidden_states``, ``attentions``, ``projection_state`` and
            ``sequence_out``.
        """

        outputs = self.roberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=True,
            # a tuple output has no .last_hidden_state, so a False value here used to crash below
            return_dict=True,
        )

        # last module outputs
        sequence_output = outputs[0]


        # project every module
        sequence_output_ln = self.pre_LN(sequence_output)

        # pooler
        pooler_output = self.pooler(sequence_output_ln)
        pooler_output = self.transformation(pooler_output)
        projection_state = self.transformation(outputs.last_hidden_state)

        return {
            'pooler_output':pooler_output,
            'last_hidden_state':outputs.last_hidden_state,
            'hidden_states':outputs.hidden_states,
            'attentions':outputs.attentions,
            'projection_state':projection_state,
            'sequence_out': sequence_output
        }


class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
    """Same model, bound to RobertaSeriesConfig and the 'roberta' base-model prefix."""

    config_class = RobertaSeriesConfig
    base_model_prefix = 'roberta'

class StableDiffusionModelHijack:
    """Installs and removes the prompt/embedding hooks and attention optimisations on a model."""

    # textual-inversion fixes for the batch currently being encoded; consumed by EmbeddingsWithFixes.forward
    fixes = None
    # free-form notes collected while encoding (e.g. which embeddings were used)
    comments = []
    # flat list of the hijacked model and all of its sub-modules; set by hijack()
    layers = None
    circular_enabled = False
    # the (wrapped) conditioning model of the hijacked model; set by hijack()
    clip = None
    # label returned by apply_optimizations(); set by hijack()
    optimization_method = None

    embedding_db = EmbeddingDatabase()

    def __init__(self):
        self.embedding_db.add_embedding_dir(root.embeddings_dir)

    def hijack(self, m):
        """Wrap ``m.cond_stage_model`` and apply the attention optimisations.

        Supported conditioning models are ``FrozenCLIPEmbedder`` and
        ``FrozenOpenCLIPEmbedder``; any other type is left unwrapped.
        """

        # if type(m.cond_stage_model) == BertSeriesModelWithTransformation:
        #     model_embeddings = m.cond_stage_model.roberta.embeddings
        #     model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
        #     m.cond_stage_model = FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)

        if type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
            m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
            m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
            m.cond_stage_model = FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

        # Record the chosen optimisation on the instance; a plain local would be
        # discarded and leave the attribute at None.
        self.optimization_method = apply_optimizations()

        self.clip = m.cond_stage_model

        def flatten(el):
            """Return ``el`` followed by all of its descendants, depth first."""
            flattened = [flatten(children) for children in el.children()]
            res = [el]
            for c in flattened:
                res += c
            return res

        self.layers = flatten(m)

    def undo_hijack(self, m):
        """Unwrap ``m.cond_stage_model`` and clear the state recorded by hijack()."""

        if type(m.cond_stage_model) == FrozenCLIPEmbedderWithCustomWords:
            m.cond_stage_model = m.cond_stage_model.wrapped

            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
            if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
                model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
        elif type(m.cond_stage_model) == FrozenOpenCLIPEmbedderWithCustomWords:
            m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
            m.cond_stage_model = m.cond_stage_model.wrapped

        self.apply_circular(False)
        self.layers = None
        self.clip = None

    def apply_circular(self, enable):
        """Switch every plain ``torch.nn.Conv2d`` in ``self.layers`` between circular and zero padding."""
        if self.circular_enabled == enable:
            return

        self.circular_enabled = enable

        for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
            layer.padding_mode = 'circular' if enable else 'zeros'

    def clear_comments(self):
        """Start a fresh comment list on this instance."""
        self.comments = []

    def get_prompt_lengths(self, text):
        """Return ``(token_count, target_token_count)`` for a single prompt, as computed by ``self.clip``."""
        _, token_count = self.clip.process_texts([text])

        return token_count, self.clip.get_target_prompt_token_count(token_count)


class EmbeddingsWithFixes(torch.nn.Module):
    """Embedding layer wrapper that splices textual-inversion vectors into its output.

    ``wrapped`` is the original embedding layer; ``embeddings`` is the object
    whose ``fixes`` attribute carries, per batch row, the ``(offset, embedding)``
    pairs to apply.
    """

    def __init__(self, wrapped, embeddings):
        super().__init__()
        self.wrapped = wrapped
        self.embeddings = embeddings

    def forward(self, input_ids):
        # take the pending fixes and clear them so they are applied only once
        pending_fixes = self.embeddings.fixes
        self.embeddings.fixes = None

        inputs_embeds = self.wrapped(input_ids)

        if pending_fixes is None or not any(len(row_fixes) for row_fixes in pending_fixes):
            return inputs_embeds

        patched_rows = []
        for row_fixes, row in zip(pending_fixes, inputs_embeds):
            for offset, embedding in row_fixes:
                vectors = embedding.vec
                start = offset + 1
                # never write past the end of the row
                count = min(row.shape[0] - start, vectors.shape[0])
                row = torch.cat([row[:start], vectors[:count], row[start + count:]])

            patched_rows.append(row)

        return torch.stack(patched_rows)


def add_circular_option_to_conv_2d():
    """Patch ``torch.nn.Conv2d.__init__`` so every new layer is built with circular padding."""
    original_init = torch.nn.Conv2d.__init__

    def circular_init(self, *args, **kwargs):
        # padding_mode is forced here; a caller passing it too gets the usual duplicate-keyword error
        return original_init(self, *args, padding_mode='circular', **kwargs)

    torch.nn.Conv2d.__init__ = circular_init


# Module-level hijack instance; constructing it registers root.embeddings_dir with the class's embedding database.
model_hijack = StableDiffusionModelHijack()


def register_buffer(self, name, attr):
    """
    Fix register buffer bug for Mac OS.

    Plain ``torch.Tensor`` values that are not on ``root.device`` are moved there
    (as float32 when that device is ``mps``); the value is then stored on ``self``
    under ``name``.
    """
    if type(attr) == torch.Tensor and attr.device != root.device:
        target_dtype = torch.float32 if root.device.type == 'mps' else None
        attr = attr.to(device=root.device, dtype=target_dtype)

    setattr(self, name, attr)


# Replace the DDIM and PLMS samplers' register_buffer with the device-aware version defined above.
ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer

#@title hijack_checkpoint

from torch.utils.checkpoint import checkpoint

import ldm.modules.attention
import ldm.modules.diffusionmodules.openaimodel


def BasicTransformerBlock_forward(self, x, context=None):
    """Run ``self._forward(x, context)`` through ``torch.utils.checkpoint.checkpoint``."""
    inner_forward = self._forward
    return checkpoint(inner_forward, x, context)


def AttentionBlock_forward(self, x):
    """Run ``self._forward(x)`` through ``torch.utils.checkpoint.checkpoint``."""
    inner_forward = self._forward
    return checkpoint(inner_forward, x)


def ResBlock_forward(self, x, emb):
    """Run ``self._forward(x, emb)`` through ``torch.utils.checkpoint.checkpoint``."""
    inner_forward = self._forward
    return checkpoint(inner_forward, x, emb)


# Original forward methods saved by hijack_checkpoint_add(); empty while no hijack is active.
stored = []


def hijack_checkpoint_add():
    """Swap in the checkpointing forwards, remembering the originals in ``stored``.

    Does nothing when ``stored`` already holds saved originals.
    """
    if stored:
        return

    transformer_block = ldm.modules.attention.BasicTransformerBlock
    res_block = ldm.modules.diffusionmodules.openaimodel.ResBlock
    attention_block = ldm.modules.diffusionmodules.openaimodel.AttentionBlock

    # save order: BasicTransformerBlock, ResBlock, AttentionBlock
    stored.extend([transformer_block.forward, res_block.forward, attention_block.forward])

    transformer_block.forward = BasicTransformerBlock_forward
    res_block.forward = ResBlock_forward
    attention_block.forward = AttentionBlock_forward


def hijack_checkpoint_remove():
    """Restore the forwards saved in ``stored`` and empty it; a no-op when nothing is saved."""
    if not stored:
        return

    # same order in which hijack_checkpoint_add() saved them
    transformer_forward, res_forward, attention_forward = stored[0], stored[1], stored[2]

    ldm.modules.attention.BasicTransformerBlock.forward = transformer_forward
    ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = res_forward
    ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = attention_forward

    stored.clear()


#@title hijack_clip

import math
from collections import namedtuple

import torch

class PromptChunk:
    """
    Holds the token ids, the per-token weights (multipliers, e.g. 1.4) and the textual inversion
    embedding info for one chunk of a prompt.
    A short prompt is represented by a single PromptChunk; longer prompts need several.
    Each PromptChunk contains an exact amount of tokens - 77, which includes one for the start and one
    for the end token, so just 75 tokens come from the prompt.
    """

    def __init__(self):
        # three independent, initially empty lists
        self.tokens, self.multipliers, self.fixes = [], [], []


# Pairs an offset inside a prompt chunk with the embedding to insert there.
PromptChunkFix = namedtuple(typename='PromptChunkFix', field_names=['offset', 'embedding'])
"""An object of this type is a marker showing that a textual inversion embedding's vectors have to be placed at offset in the prompt
chunk. Those objects are found in PromptChunk.fixes, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally
are applied by sd_hijack.EmbeddingsWithFixes's forward function."""


class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
    """A pytorch module that is a wrapper for the FrozenCLIPEmbedder module. It enhances FrozenCLIPEmbedder, making it possible to
    have unlimited prompt length and assign weights to tokens in the prompt.

    Subclasses are expected to provide ``id_start``, ``id_end``, ``id_pad`` and ``comma_token`` (all read by the
    methods below) and to implement ``tokenize``, ``encode_with_transformers`` and ``encode_embedding_init_text``.
    """

    def __init__(self, wrapped, hijack):
        super().__init__()

        self.wrapped = wrapped
        """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
        depending on model."""

        self.hijack: StableDiffusionModelHijack = hijack
        # number of prompt tokens per chunk, not counting the start and end tokens
        self.chunk_length = 75

    def empty_chunk(self):
        """creates an empty PromptChunk and returns it"""

        chunk = PromptChunk()
        # one start token followed by end tokens, chunk_length + 2 entries in total
        chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
        chunk.multipliers = [1.0] * (self.chunk_length + 2)
        return chunk

    def get_target_prompt_token_count(self, token_count):
        """returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented"""

        # round up to the next multiple of chunk_length; a count of 0 is treated as 1
        return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length

    def tokenize(self, texts):
        """Converts a batch of texts into a batch of token ids"""

        raise NotImplementedError

    def encode_with_transformers(self, tokens):
        """
        converts a batch of token ids (in python lists) into a single tensor with numeric representation of those tokens;
        All python lists with tokens are assumed to have same length, usually 77.
        if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on
        model - can be 768 and 1024.
        Among other things, this call will read self.hijack.fixes, apply it to its inputs, and clear it (setting it to None).
        """

        raise NotImplementedError

    def encode_embedding_init_text(self, init_text, nvpt):
        """Converts text into a tensor with this text's tokens' embeddings. Note that those are embeddings before they are passed through
        transformers. nvpt is used as a maximum length in tokens. If text produces fewer tokens than nvpt, only this many are returned."""

        raise NotImplementedError

    def tokenize_line(self, line):
        """
        this transforms a single prompt into a list of PromptChunk objects - as many as needed to
        represent the prompt.
        Returns the list and the total number of tokens in the prompt.
        """


        # the whole line is treated as one segment with weight 1.0
        parsed = [[line, 1.0]]

        tokenized = self.tokenize([text for text, _ in parsed])

        chunks = []
        chunk = PromptChunk()
        token_count = 0
        last_comma = -1

        def next_chunk(is_last=False):
            """puts current chunk into the list of results and produces the next one - empty;
            if is_last is true, the <end-of-text> padding tokens at the end won't add to token_count"""
            nonlocal token_count
            nonlocal last_comma
            nonlocal chunk

            if is_last:
                token_count += len(chunk.tokens)
            else:
                token_count += self.chunk_length

            # pad the chunk up to chunk_length with end tokens of weight 1.0
            to_add = self.chunk_length - len(chunk.tokens)
            if to_add > 0:
                chunk.tokens += [self.id_end] * to_add
                chunk.multipliers += [1.0] * to_add

            # surround the chunk with the start and end tokens
            chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end]
            chunk.multipliers = [1.0] + chunk.multipliers + [1.0]

            last_comma = -1
            chunks.append(chunk)
            chunk = PromptChunk()

        for tokens, (text, weight) in zip(tokenized, parsed):
            # a segment 'BREAK' with weight -1 forces a new chunk (parsed above never contains such a segment)
            if text == 'BREAK' and weight == -1:
                next_chunk()
                continue

            position = 0
            while position < len(tokens):
                token = tokens[position]

                if token == self.comma_token:
                    last_comma = len(chunk.tokens)
                # this is when we are at the end of the allotted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
                # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
                # (the backtracking logic is disabled here and kept below only as commented-out code)
                # comma_padding_backtrack=74
                # elif comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= comma_padding_backtrack:
                #     break_location = last_comma + 1

                #     reloc_tokens = chunk.tokens[break_location:]
                #     reloc_mults = chunk.multipliers[break_location:]

                #     chunk.tokens = chunk.tokens[:break_location]
                #     chunk.multipliers = chunk.multipliers[:break_location]

                #     next_chunk()
                #     chunk.tokens = reloc_tokens
                #     chunk.multipliers = reloc_mults

                # the current chunk is full: close it and continue in a fresh one
                if len(chunk.tokens) == self.chunk_length:
                    next_chunk()

                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position)
                if embedding is None:
                    # ordinary token
                    chunk.tokens.append(token)
                    chunk.multipliers.append(weight)
                    position += 1
                    continue

                # an embedding is never split across chunks: start a new one if it does not fit
                emb_len = int(embedding.vec.shape[0])
                if len(chunk.tokens) + emb_len > self.chunk_length:
                    next_chunk()

                chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding))

                # reserve emb_len placeholder tokens (id 0) for the embedding's vectors
                chunk.tokens += [0] * emb_len
                chunk.multipliers += [weight] * emb_len
                position += embedding_length_in_tokens

        # always emit at least one chunk, even for an empty prompt
        if len(chunk.tokens) > 0 or len(chunks) == 0:
            next_chunk(is_last=True)

        return chunks, token_count

    def process_texts(self, texts):
        """
        Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and maximum
        length, in tokens, of all texts.
        """

        token_count = 0

        # per-call cache: repeated lines within texts are tokenized only once
        cache = {}
        batch_chunks = []
        for line in texts:
            if line in cache:
                chunks = cache[line]
            else:
                chunks, current_token_count = self.tokenize_line(line)
                token_count = max(current_token_count, token_count)

                cache[line] = chunks

            batch_chunks.append(chunks)

        return batch_chunks, token_count

    def forward(self, texts):
        """
        Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
        Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
        be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
        An example shape returned by this function can be: (2, 77, 768).
        Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one element
        is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
        """

        # if opts.use_old_emphasis_implementation:
        #     import modules.sd_hijack_clip_old
        #     return modules.sd_hijack_clip_old.forward_old(self, texts)

        batch_chunks, token_count = self.process_texts(texts)

        used_embeddings = {}
        chunk_count = max([len(x) for x in batch_chunks])

        zs = []
        for i in range(chunk_count):
            # texts with fewer chunks than the longest one are padded with empty chunks
            batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks]

            tokens = [x.tokens for x in batch_chunk]
            multipliers = [x.multipliers for x in batch_chunk]
            # hand the fixes for this chunk index to the hijack object before encoding
            self.hijack.fixes = [x.fixes for x in batch_chunk]

            for fixes in self.hijack.fixes:
                for position, embedding in fixes:
                    used_embeddings[embedding.name] = embedding

            z = self.process_tokens(tokens, multipliers)
            zs.append(z)

        if len(used_embeddings) > 0:
            embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()])
            self.hijack.comments.append(f"Used embeddings: {embeddings_list}")

        # join the per-chunk encodings with torch.hstack
        return torch.hstack(zs)

    def process_tokens(self, remade_batch_tokens, batch_multipliers):
        """
        sends one single prompt chunk to be encoded by transformers neural network.
        remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually
        there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens.
        Multipliers are used to give more or less weight to the outputs of transformers network. Each multiplier
        corresponds to one token.
        """
        tokens = torch.asarray(remade_batch_tokens).to(root.device)

        # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones.
        if self.id_end != self.id_pad:
            for batch_pos in range(len(remade_batch_tokens)):
                # everything after the first end token becomes padding
                index = remade_batch_tokens[batch_pos].index(self.id_end)
                tokens[batch_pos, index+1:tokens.shape[1]] = self.id_pad

        z = self.encode_with_transformers(tokens)

        # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
        batch_multipliers = torch.asarray(batch_multipliers).to(root.device)
        original_mean = z.mean()
        # scale every token's vector by its multiplier, then rescale so the overall mean is unchanged
        z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
        new_mean = z.mean()
        z = z * (original_mean / new_mean)

        return z


class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
    """
    Prompt-weighting embedder built on a wrapped CLIP text encoder that exposes
    `wrapped.tokenizer` and `wrapped.transformer`.

    End-of-text and padding share one token id here (`id_pad` is set to `id_end`).
    """

    # "CLIP skip": which hidden layer, counted from the end, supplies the embedding.
    # 1 selects the final layer's normalised output (`last_hidden_state`), which is what
    # this class has always returned. Values above 1 pick an earlier layer.
    CLIP_stop_at_last_layers = 1

    def __init__(self, wrapped, hijack):
        """Cache the tokenizer, the comma token id, special token ids and per-token emphasis factors."""
        super().__init__(wrapped, hijack)
        self.tokenizer = wrapped.tokenizer

        vocab = self.tokenizer.get_vocab()

        # None when the vocabulary has no ',</w>' entry.
        self.comma_token = vocab.get(',</w>', None)

        # Some vocabulary entries contain bracket characters themselves. For each such token,
        # pre-compute the net emphasis change it carries: '(' and ']' multiply by 1.1,
        # ')' and '[' divide by 1.1. Tokens whose brackets cancel out are not recorded.
        self.token_mults = {}
        tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
        for text, ident in tokens_with_parens:
            mult = 1.0
            for c in text:
                if c == '[':
                    mult /= 1.1
                if c == ']':
                    mult *= 1.1
                if c == '(':
                    mult *= 1.1
                if c == ')':
                    mult /= 1.1

            if mult != 1.0:
                self.token_mults[ident] = mult

        self.id_start = self.wrapped.tokenizer.bos_token_id
        self.id_end = self.wrapped.tokenizer.eos_token_id
        self.id_pad = self.id_end

    def tokenize(self, texts):
        """Return raw token ids for `texts`, without truncation and without start/end tokens."""
        tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]

        return tokenized

    def encode_with_transformers(self, tokens):
        """
        Run the wrapped transformer on a tensor of token ids and return the hidden states
        selected by `CLIP_stop_at_last_layers`.
        """
        skip = self.CLIP_stop_at_last_layers
        # Only ask for the per-layer hidden states when an earlier layer is actually needed;
        # otherwise they would be computed, kept in memory and never read.
        outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=skip > 1)

        if skip > 1:
            # Earlier layers are taken before the final layer norm, so apply it here.
            z = outputs.hidden_states[-skip]
            z = self.wrapped.transformer.text_model.final_layer_norm(z)
        else:
            z = outputs.last_hidden_state

        return z

    def encode_embedding_init_text(self, init_text, nvpt):
        """
        Look up the token embeddings for `init_text` using the un-hijacked embedding layer.

        `nvpt` is forwarded to the tokenizer as `max_length`.
        """
        embedding_layer = self.wrapped.transformer.text_model.embeddings
        ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
        # Move the ids to wherever the embedding weights live before the lookup.
        embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)

        return embedded

#@title hijack_clip_old
# from drive.MyDrive import sd_hijack_clip
def process_text_old(self: FrozenCLIPEmbedderWithCustomWordsBase, texts):
    """
    Legacy prompt pre-processing: turn each text into one fixed-length token window.

    Emphasis tokens found in `self.token_mults` are removed and folded into a running
    weight; textual-inversion embeddings are replaced by zero placeholders and recorded
    as (offset, embedding) fixes. Prompts longer than the window are truncated and a
    warning is added to the returned comments.

    Returns a 6-tuple:
        batch_multipliers, remade_batch_tokens, used_custom_terms,
        hijack_comments, hijack_fixes, token_count
    where token_count is the untruncated length of the most recently processed
    (non-cached) prompt.
    """
    start_id, end_id = self.id_start, self.id_end
    body_len = self.wrapped.max_length - 2  # window size minus the start and end markers
    used_custom_terms, hijack_comments, hijack_fixes = [], [], []
    remade_batch_tokens, batch_multipliers = [], []
    token_count = 0

    seen = {}  # identical prompts within the batch are only processed once
    for tokens in self.tokenize(texts):
        key = tuple(tokens)

        if key in seen:
            remade_tokens, fixes, multipliers = seen[key]
        else:
            fixes, remade_tokens, multipliers = [], [], []
            weight = 1.0

            pos = 0
            while pos < len(tokens):
                token = tokens[pos]

                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, pos)

                weight_change = self.token_mults.get(token)
                if weight_change is not None:
                    # Emphasis token: adjust the running weight and drop the token itself.
                    weight *= weight_change
                    pos += 1
                    continue

                if embedding is None:
                    # Ordinary token: keep it with the current weight.
                    remade_tokens.append(token)
                    multipliers.append(weight)
                    pos += 1
                    continue

                # Custom embedding: reserve one zero placeholder per embedding vector and
                # remember where the vectors have to be patched in later.
                vector_count = int(embedding.vec.shape[0])
                fixes.append((len(remade_tokens), embedding))
                remade_tokens.extend([0] * vector_count)
                multipliers.extend([weight] * vector_count)
                used_custom_terms.append((embedding.name, embedding.checksum()))
                pos += embedding_length_in_tokens

            if len(remade_tokens) > body_len:
                id_to_text = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
                overflowing_words = [id_to_text.get(int(x), "") for x in remade_tokens[body_len:]]
                overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
                hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")

            token_count = len(remade_tokens)
            filled = remade_tokens + [end_id] * (body_len - len(remade_tokens))
            remade_tokens = [start_id] + filled[0:body_len] + [end_id]
            # The cache keeps the unpadded multipliers; padding happens below on every pass.
            seen[key] = (remade_tokens, fixes, multipliers)

        filled_weights = multipliers + [1.0] * (body_len - len(multipliers))
        multipliers = [1.0] + filled_weights[0:body_len] + [1.0]

        remade_batch_tokens.append(remade_tokens)
        hijack_fixes.append(fixes)
        batch_multipliers.append(multipliers)

    return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count


def forward_old(self: FrozenCLIPEmbedderWithCustomWordsBase, texts):
    """
    Legacy forward pass: pre-process `texts` with process_text_old, publish the resulting
    comments and embedding fixes on `self.hijack`, then encode the tokens.
    """
    (batch_multipliers, remade_batch_tokens, used_custom_terms,
     hijack_comments, hijack_fixes, token_count) = process_text_old(self, texts)

    self.hijack.comments += hijack_comments

    if used_custom_terms:
        embedding_summary = ", ".join(f'{word} [{checksum}]' for word, checksum in used_custom_terms)
        self.hijack.comments.append("Used embeddings: " + embedding_summary)

    # The fixes must be in place before encoding so the placeholders can be patched.
    self.hijack.fixes = hijack_fixes
    return self.process_tokens(remade_batch_tokens, batch_multipliers)

#@title hijack_open_clip

import open_clip.tokenizer
import torch

# Module-level tokenizer instance shipped with open_clip; shared by the class below.
tokenizer = open_clip.tokenizer._tokenizer


class FrozenOpenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
    """
    Prompt-weighting embedder for an open_clip based text encoder.

    Padding uses id 0, which differs from the end-of-text id, so the base class
    re-pads everything after end-of-text.
    """

    def __init__(self, wrapped, hijack):
        """Record the comma, start, end and padding token ids from the open_clip vocabulary."""
        super().__init__(wrapped, hijack)

        # Direct dictionary lookup instead of scanning the whole vocabulary for one key.
        self.comma_token = tokenizer.encoder[',</w>']
        self.id_start = tokenizer.encoder["<start_of_text>"]
        self.id_end = tokenizer.encoder["<end_of_text>"]
        self.id_pad = 0

    def tokenize(self, texts):
        """Return one list of token ids per text, as produced by `tokenizer.encode`."""
        tokenized = [tokenizer.encode(text) for text in texts]

        return tokenized

    def encode_with_transformers(self, tokens):
        """Encode a tensor of token ids with the wrapped model's `encode_with_transformer`."""
        # set self.wrapped.layer_idx here according to opts.CLIP_stop_at_last_layers
        z = self.wrapped.encode_with_transformer(tokens)

        return z

    def encode_embedding_init_text(self, init_text, nvpt):
        """
        Look up the token embeddings for `init_text` using the un-hijacked embedding layer.

        `nvpt` is accepted for signature compatibility with the sibling embedders and is
        not used here.
        """
        token_embedding = self.wrapped.model.token_embedding.wrapped
        ids = tokenizer.encode(init_text)
        # Build the ids on the same device as the embedding weights rather than assuming
        # "cuda", matching what the CLIP embedder does.
        ids = torch.asarray([ids], device=token_embedding.weight.device, dtype=torch.int)
        embedded = token_embedding(ids).squeeze(0)

        return embedded

import open_clip.tokenizer
import torch

class FrozenXLMREmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWords):
    """
    Variant of the CLIP embedder for an XLM-R based text encoder: special token ids come
    from `wrapped.config`, and the comma token has no `</w>` suffix.
    """

    def __init__(self, wrapped, hijack):
        super().__init__(wrapped, hijack)

        # Replace the ids chosen by the parent class with the ones in the model config.
        self.id_start = wrapped.config.bos_token_id
        self.id_end = wrapped.config.eos_token_id
        self.id_pad = wrapped.config.pad_token_id

        # alt diffusion doesn't have </w> bits for comma
        self.comma_token = self.tokenizer.get_vocab().get(',', None)

    def encode_with_transformers(self, tokens):
        """Encode token ids with padding masked out and return the 'projection_state' output."""
        # No CLIP skip here: every hidden layer is 1024 wide and only the last one is followed
        # by the trained projection down to 768 for the unet, so the last layer must be used.
        not_padding = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64)
        features = self.wrapped(input_ids=tokens, attention_mask=not_padding)

        return features['projection_state']

    def encode_embedding_init_text(self, init_text, nvpt):
        """Look up the token embeddings for `init_text`; `nvpt` is passed to the tokenizer as max_length."""
        roberta_embeddings = self.wrapped.roberta.embeddings
        token_ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]

        return roberta_embeddings.token_embedding.wrapped(token_ids.to(root.device)).squeeze(0)

def reload_hypernetworks():
    """Refresh the module-level `hypernetworks` with whatever list_hypernetworks finds in root.hypernetwork_dir."""
    
    global hypernetworks

    hypernetworks = list_hypernetworks(root.hypernetwork_dir)

# Install the prompt/embedding hijack on the loaded model, then load textual-inversion
# embeddings and every hypernetwork found in root.hypernetwork_dir.
# The classes and helpers used here are defined earlier in the notebook.
model_hijack = StableDiffusionModelHijack()
# NOTE(review): the embeddings below are loaded into model_hijack.embedding_db, not into this
# object — confirm whether this separate instance is still needed.
embedding_db = EmbeddingDatabase()
hypernetwork = Hypernetwork()
model_hijack.hijack(model)
model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
reload_hypernetworks()
names = []
# names_list1 is not referenced again in this cell.
names_list1 = list_hypernetworks(root.hypernetwork_dir)
# .keys() implies list_hypernetworks returns a mapping; names_list holds its keys.
names_list = tuple(list_hypernetworks(root.hypernetwork_dir).keys())
# After this append, `names` is a list with ONE element: the whole tuple of names.
names.append(names_list)
# NOTE(review): `multipliers` is a single float and `names` is a list holding one tuple —
# verify that load_hypernetworks (defined elsewhere) really expects these shapes.
multipliers=3.0
load_hypernetworks(names, multipliers)
# Additionally load each hypernetwork individually by name.
for name in names_list:
    load_hypernetwork(name)

In [None]:
#@title $\color{blue} {\large \textsf{Animation Settings}}$
def DeforumAnimArgs():
    """
    Collect the animation settings of this form cell and return them as a dict.

    The function returns `locals()`, so every local variable name below becomes a key
    of the returned dict: renaming a variable renames the setting. The `#@param`
    comments are Colab form annotations and must stay on the same line as their value.
    """

    #@markdown ####**Animation:**
    animation_mode = '3D' #@param ['None', '2D', '3D', 'Video Input', 'Interpolation'] {type:'string'}
    max_frames = 6000 #@param {type:"number"}
    border = 'wrap' #@param ['wrap', 'replicate'] {type:'string'}

    # The schedules below are strings of "frame:(expression)" pairs separated by commas;
    # the expressions use `t` and functions such as sin/cos and are parsed elsewhere.
    #@markdown ####**Motion Parameters:**
    angle = "0:(0)"#@param {type:"string"}
    zoom = "0:(1.02+0.02*sin(2*3.14*t/300))"#@param {type:"string"}
    translation_x = "0:(3*sin(2*3.14*t/270))"#@param {type:"string"}
    translation_y = "0:((3*sin(2*3.14*t/270)**20)-3)"#@param {type:"string"}
    translation_z = "0: (2), 30: ((10*(sin(3.141*t/300)**900)+2) + (+0.1 *(sin(3.141*t/900)**50)+2))"#@param {type:"string"}
    rotation_3d_x = "0:(0)"#@param {type:"string"}
    rotation_3d_y = "0:(-0.2*sin(2*3.14*t/270))"#@param {type:"string"}
    rotation_3d_z = "0: (0), 30: ((10*(sin(3.141*t/600)**5000)+0.5) + (+0.1 *(sin(3.141*t/600)**50)-0.5))"#@param {type:"string"}
    # No #@param marker: this value is only editable in code, not in the form.
    rotation_3d_w = "0: (0.4)"
    flip_2d_perspective = False #@param {type:"boolean"}
    perspective_flip_theta = "0:(0)"#@param {type:"string"}
    perspective_flip_phi = "0:(t%15)"#@param {type:"string"}
    perspective_flip_gamma = "0:(0)"#@param {type:"string"}
    perspective_flip_fv = "0:(53)"#@param {type:"string"}
    noise_schedule = "0:(.02)"#@param {type:"string"}
    strength_schedule = "0:(.60), 30: (-4*(cos(3.141*t/90)**1000)+0.60)"#@param {type:"string"}
    contrast_schedule = "0: (1.0)"#@param {type:"string"}
    hybrid_comp_alpha_schedule = "0:(1)" #@param {type:"string"}
    hybrid_comp_mask_blend_alpha_schedule = "0:(0.5)" #@param {type:"string"}
    hybrid_comp_mask_contrast_schedule = "0:(1)" #@param {type:"string"}
    hybrid_comp_mask_auto_contrast_cutoff_high_schedule =  "0:(100)" #@param {type:"string"}
    hybrid_comp_mask_auto_contrast_cutoff_low_schedule =  "0:(0)" #@param {type:"string"}

    #@markdown ####**Cadence Scheduling:**
    enable_cadence_schedule = True #@param {type:"boolean"}
    cadence_schedule = "0:(4),30:(6), 60:(8), 90:(4)" #@param{type:'string'}

    #@markdown ####**Sampler Scheduling:**
    enable_schedule_samplers = False #@param {type:"boolean"}
    sampler_schedule = "0:('euler'),10:('dpm2'),20:('dpm2_ancestral'),30:('heun'),40:('euler'),50:('euler_ancestral'),60:('dpm_fast'),70:('dpm_adaptive'),80:('dpmpp_2s_a'),90:('dpmpp_2m')" #@param {type:"string"}

    #@markdown ####**Unsharp mask (anti-blur) Parameters:**
    kernel_schedule = "0: (5)"#@param {type:"string"}
    sigma_schedule = "0: (1.0)"#@param {type:"string"}
    amount_schedule = "0: (0.2)"#@param {type:"string"}
    threshold_schedule = "0: (0.0)"#@param {type:"string"}

    #@markdown ####**Coherence:**
    color_coherence = 'Match Frame 0 RGB' #@param ['None', 'Match Frame 0 HSV', 'Match Frame 0 LAB', 'Match Frame 0 RGB', 'Video Input'] {type:'string'}
    color_coherence_video_every_N_frames = 1 #@param {type:"integer"}
    color_force_grayscale = False #@param {type:"boolean"}
    # Note: stored as a string ('4'), not an int.
    diffusion_cadence = '4' #@param ['1','2','3','4','5','6','7','8'] {type:'string'}

    #@markdown ####**3D Depth Warping:**
    use_depth_warping = True #@param {type:"boolean"}
    midas_weight = 0.3#@param {type:"number"}
    # near_plane / far_plane carry no #@param marker and are only editable in code.
    near_plane = 200
    far_plane = 10000
    fov = 40#@param {type:"number"}
    padding_mode = 'border'#@param ['border', 'reflection', 'zeros'] {type:'string'}
    sampling_mode = 'bicubic'#@param ['bicubic', 'bilinear', 'nearest'] {type:'string'}
    save_depth_maps = False #@param {type:"boolean"}

    #@markdown ####**Video Input:**
    video_init_path ='/content/video_in.mp4'#@param {type:"string"}
    extract_nth_frame = 1#@param {type:"number"}
    overwrite_extracted_frames = True #@param {type:"boolean"}
    use_mask_video = False #@param {type:"boolean"}
    video_mask_path ='/content/video_in.mp4'#@param {type:"string"}

    #@markdown ####**Hybrid Video for 2D/3D Animation Mode:**
    hybrid_generate_inputframes = False #@param {type:"boolean"}
    hybrid_use_first_frame_as_init_image = True #@param {type:"boolean"}
    hybrid_motion = "None" #@param ['None','Optical Flow','Perspective','Affine']
    hybrid_motion_use_prev_img = False #@param {type:"boolean"}
    hybrid_flow_method = "DIS Medium" #@param ['DenseRLOF','DIS Medium','Farneback','SF']
    hybrid_composite = False #@param {type:"boolean"}
    hybrid_comp_mask_type = "None" #@param ['None', 'Depth', 'Video Depth', 'Blend', 'Difference']
    hybrid_comp_mask_inverse = False #@param {type:"boolean"}
    hybrid_comp_mask_equalize = "None" #@param  ['None','Before','After','Both']
    hybrid_comp_mask_auto_contrast = False #@param {type:"boolean"}
    hybrid_comp_save_extra_frames = False #@param {type:"boolean"}
    hybrid_use_video_as_mse_image = False #@param {type:"boolean"}

    #@markdown ####**Interpolation:**
    interpolate_key_frames = False #@param {type:"boolean"}
    interpolate_x_frames = 4 #@param {type:"number"}
    
    #@markdown ####**Resume Animation:**
    resume_from_timestring = False #@param {type:"boolean"}
    resume_timestring = "20230203050018" #@param {type:"string"}

    # Every local defined above is returned as one settings dict.
    return locals()

In [None]:
#@title $\color{BLUE} {\large \textsf{Simple Prompt Modifier Generator}}$

import random
from ipywidgets import widgets, Label
from IPython import display
# Word pools the buttons below draw from.
#artists
psychedelic_artists = ["Pablo Amaringo", "Chuck Arnett", "Chris Dyer", "Doug Binder",  "Brummbaer", "Mark Boyle", "Joan Hills",]
surrealist_artists = ["Genevieve Leavold", "Sam Wilde", "Adam Lawrence", "Violet Polsangi",  "Kazuhiro Higashi", "Kim Marra", "Ricardo Harris-Fuentes",]
anime_artists = ["George Morikawa", "Keisuke Itagaki", "Yoichi Takahashi", "Hirohiko Araki", "Masashi Kishimoto", "Yoshihiro Togashi", "Hajime Isayama", "Gosho Aoyama", "Akira Toriyama", "Eiichiro Oda",]

#high definition
high_definition = ["4K","High-resolution","High-quality","Sharp","Detailed","Vibrant","Clear","Realistic","Immersive","Stunning","Crisp","Bright","Colorful","Smooth","Life-like","Rich","Lifelike","Vivid","Gorgeous","Beautiful","Striking","Sharpness","Depth","Intense","Brilliant","Extraordinary","Dynamic","Elegant","Expansive","Fantastic","Fine","Outstanding","Magnificent","Mesmerizing","Incredible","Impressive","Sensational","Spectacular","Superb","Stupendous","Tremendous","Unbelievable","Unreal"]

#art style
art_style = ["Renaissance","Baroque","Rococo","Gothic","Impressionism","Post-Impressionism","Expressionism","Art Nouveau","Futurism","Cubism","Surrealism","Abstract Expressionism","Pop Art","Minimalism","Conceptual Art","Realism","Romanticism","Neoclassicism","Pre-Raphaelitism","Fauvism","Orphism","De Stijl","Suprematism","Constructivism","Dadaism","Expressionist Architecture","Bauhaus","Art Deco","Futurist Architecture","Concrete Art","Arte Povera","New Objectivity","Op Art","Hyperrealism","Photorealism","Magic Realism","Social Realism","Early Netherlandish Painting","High Renaissance","Mannerism","Northern Renaissance","Baroque Classicism","Caravaggism","Dutch Golden Age Painting","Romanticism in Art","American Regionalism","Abstract Illusionism","Abstract Impressionism","Neo-Expressionism","Hyperrealistic Sculpture","Environmental Art","New Media Art","Contemporary Realism","Neo-Pop Art"]


def _make_sample_handler(label, pool, count):
    """Return a button click handler that prints `count` distinct random entries of `pool` after `label`."""
    def handler(b):
        print(label + ": " + ", ".join(random.sample(pool, count)))
    return handler


def on_button_all_clicked(b):
    """Print one random pick from each artist pool plus one art style and one quality word."""
    first_artist = random.choice(psychedelic_artists)
    second_artist = random.choice(surrealist_artists)
    third_artist = random.choice(anime_artists)
    art_styles = random.choice(art_style)
    hd_word = random.choice(high_definition)
    print("Remix of Prompt Words: {}, {}, {}, {}, {}".format(first_artist, second_artist, third_artist, art_styles, hd_word))


# One handler per pool; the names match the previous hand-written callbacks.
on_button1_clicked = _make_sample_handler("Psychedelic Artists", psychedelic_artists, 3)
on_button2_clicked = _make_sample_handler("Surrealist Artists", surrealist_artists, 3)
on_button3_clicked = _make_sample_handler("Anime Artists", anime_artists, 3)
on_button4_clicked = _make_sample_handler("High Definition Words", high_definition, 4)
on_button5_clicked = _make_sample_handler("Art Styles", art_style, 4)

button1 = widgets.Button(description="Psychedelic Artists", button_style="danger")
button1.on_click(on_button1_clicked)

button2 = widgets.Button(description="Surrealist Artists", button_style="success")
button2.on_click(on_button2_clicked)

button3 = widgets.Button(description="Anime Artists", button_style="info")
button3.on_click(on_button3_clicked)

button_all = widgets.Button(description="Remix", button_style="warning")
button_all.on_click(on_button_all_clicked)

button4 = widgets.Button(description="High Definition Words", button_style="success")
button4.on_click(on_button4_clicked)

button5 = widgets.Button(description="Art Styles", button_style="info")
button5.on_click(on_button5_clicked)

display.display(button1, button2, button3, button4, button5, button_all)

Button(button_style='danger', description='Psychedelic Artists', style=ButtonStyle())

Button(button_style='success', description='Surrealist Artists', style=ButtonStyle())

Button(button_style='info', description='Anime Artists', style=ButtonStyle())

Button(button_style='success', description='High Definition Words', style=ButtonStyle())

Button(button_style='info', description='Art Styles', style=ButtonStyle())



In [None]:
# conditional (positive) prompts, keyed by the frame number at which each one takes effect
cond_prompts = {
    0: "dreamlikeart portrait of (LuisapDarkage_pinguinstyle5Darkage:3) hulking herculean extra terrestrial aquatic grey alien horns, dark, gore, creepy, luisap",
    200: "dreamlikeart portrait of (mjv4Hypernetwork_v1:3) melting psychedelic android robot, portrait, octane render, highly detailed",
    400: "dreamlikeart portrait of (LuisapDarkage_pinguinstyle5Darkage:3) hulking herculean extra terrestrial aquatic grey alien horns muted colors, dark, gore, creepy, luisap",
    600: "dreamlikeart portrait of (mjv4Hypernetwork_v1:3) melting psychedelic zombie, portrait, octane render, highly detailed",
    
}

# unconditional (negative) prompts, keyed the same way
uncond_prompts = {
    0: "realistic, bad drawing, incoherent, messy, low resolution, pixelated",
}

In [None]:
#@title $\color{BLUE} {\large \textsf{Load Settings}}$
# When override_settings_with_file is True, the settings built by this cell are passed to
# load_args together with settings_file and custom_settings_file.
override_settings_with_file = False #@param {type:"boolean"}
settings_file = "custom" #@param ["custom", "512x512_aesthetic_0.json","512x512_aesthetic_1.json","512x512_colormatch_0.json","512x512_colormatch_1.json","512x512_colormatch_2.json","512x512_colormatch_3.json"]
custom_settings_file = "/content/drive/MyDrive/Settings.txt"#@param {type:"string"}

def DeforumArgs():
    """
    Collect the image/sampling settings of this form cell and return them as a dict.

    The function returns `locals()`, so every local variable name below becomes a key
    of the returned dict: renaming a variable renames the setting. The `#@param`
    comments are Colab form annotations and must stay on the same line as their value.
    Reads `root.output_path` and calls `get_output_folder`, both defined elsewhere.
    """
    #@markdown **Image Settings**
    W = 512 #@param
    H = 512 #@param
    W, H = map(lambda x: x - x % 64, (W, H))  # resize to integer multiple of 64
    bit_depth_output = 8 #@param [8, 16, 32] {type:"raw"}

    #@markdown **Sampling Settings**
    # A seed of -1 is replaced by a random seed by the code that follows this function.
    seed = 724456508 #@param
    sampler = 'dpm2_ancestral' #@param ["klms","dpm2","dpm2_ancestral","heun","euler","euler_ancestral","plms", "ddim", "dpm_fast", "dpm_adaptive", "dpmpp_2s_a", "dpmpp_2m"]
    steps = 50 #@param
    scale = 7 #@param
    ddim_eta = 0.0 #@param
    dynamic_threshold = None
    static_threshold = None   

    #@markdown **Save & Display Settings**
    save_samples = True #@param {type:"boolean"}
    save_settings = True #@param {type:"boolean"}
    display_samples = True #@param {type:"boolean"}
    save_sample_per_step = False #@param {type:"boolean"}
    show_sample_per_step = False #@param {type:"boolean"}

    #@markdown **Batch Settings**
    n_batch = 1 #@param
    n_samples = 1 #@param
    batch_name = "upscale" #@param {type:"string"}
    filename_format = "{timestring}_{index}_{prompt}.png" #@param ["{timestring}_{index}_{seed}.png","{timestring}_{index}_{prompt}.png"]
    seed_behavior = "iter" #@param ["iter","fixed","random","ladder","alternate"]
    seed_iter_N = 1 #@param {type:'integer'}
    make_grid = False #@param {type:"boolean"}
    grid_rows = 2 #@param 
    # Output directory derived from the notebook-wide output path and the batch name.
    outdir = get_output_folder(root.output_path, batch_name)

    #@markdown **Init Settings**
    use_init = False #@param {type:"boolean"}
    strength = 0.65 #@param {type:"number"}
    strength_0_no_init = True # Set the strength to 0 automatically when no init image is used
    init_image = "https://cdn.pixabay.com/photo/2022/07/30/13/10/green-longhorn-beetle-7353749_1280.jpg" #@param {type:"string"}
    add_init_noise = True #@param {type:"boolean"}
    init_noise = 0.01 #@param
    # Whiter areas of the mask are areas that change more
    use_mask = False #@param {type:"boolean"}
    use_alpha_as_mask = False # use the alpha channel of the init image as the mask
    mask_file = "https://www.filterforge.com/wiki/images/archive/b/b7/20080927223728%21Polygonal_gradient_thumb.jpg" #@param {type:"string"}
    invert_mask = False #@param {type:"boolean"}
    # Adjust mask image, 1.0 is no adjustment. Should be positive numbers.
    mask_brightness_adjust = 1.0  #@param {type:"number"}
    mask_contrast_adjust = 1.0  #@param {type:"number"}
    # Overlay the masked image at the end of the generation so it does not get degraded by encoding and decoding
    overlay_mask = True  # {type:"boolean"}
    # Blur edges of final overlay mask, if used. Minimum = 0 (no blur)
    mask_overlay_blur = 5 # {type:"number"}

    #@markdown **Exposure/Contrast Conditional Settings**
    mean_scale = 0 #@param {type:"number"}
    var_scale = 0 #@param {type:"number"}
    exposure_scale = 0 #@param {type:"number"}
    exposure_target = 0.5 #@param {type:"number"}

    #@markdown **Color Match Conditional Settings**
    colormatch_scale = 0 #@param {type:"number"}
    colormatch_image = "https://www.saasdesign.io/wp-content/uploads/2021/02/palette-3-min-980x588.png" #@param {type:"string"}
    colormatch_n_colors = 4 #@param {type:"number"}
    ignore_sat_weight = 0 #@param {type:"number"}

    #@markdown **CLIP\Aesthetics Conditional Settings**
    # The CLIP / aesthetics models are only loaded when clip_scale or aesthetics_scale is > 0.
    clip_name = 'ViT-L/14' #@param ['ViT-L/14', 'ViT-L/14@336px', 'ViT-B/16', 'ViT-B/32']
    clip_scale = 0 #@param {type:"number"}
    aesthetics_scale = 0 #@param {type:"number"}
    cutn = 1 #@param {type:"number"}
    cut_pow = 0.0001 #@param {type:"number"}

    #@markdown **Other Conditional Settings**
    init_mse_scale = 0 #@param {type:"number"}
    init_mse_image = "https://cdn.pixabay.com/photo/2022/07/30/13/10/green-longhorn-beetle-7353749_1280.jpg" #@param {type:"string"}
    blue_scale = 0 #@param {type:"number"}
    
    #@markdown **Conditional Gradient Settings**
    gradient_wrt = 'x0_pred' #@param ["x", "x0_pred"]
    gradient_add_to = 'both' #@param ["cond", "uncond", "both"]
    decode_method = 'linear' #@param ["autoencoder","linear"]
    grad_threshold_type = 'dynamic' #@param ["dynamic", "static", "mean", "schedule"]
    clamp_grad_threshold = 0.2 #@param {type:"number"}
    clamp_start = 0.2 #@param
    clamp_stop = 0.01 #@param
    grad_inject_timing = list(range(1,10)) #@param

    #@markdown **Speed vs VRAM Settings**
    cond_uncond_sync = True #@param {type:"boolean"}
    precision = 'autocast' 
    C = 4
    f = 8

    # Placeholders that appear in the returned dict with empty/None values.
    # NOTE(review): presumably filled in later by the render code — confirm.
    cond_prompt = ""
    cond_prompts = ""
    uncond_prompt = ""
    uncond_prompts = ""
    timestring = ""
    init_latent = None
    init_sample = None
    init_sample_raw = None
    mask_sample = None
    init_c = None
    seed_internal = 0

    # Every local defined above is returned as one settings dict.
    return locals()

# Build the settings, normalise them, and start the renderer that matches the animation mode.
# Relies on names from earlier cells: SimpleNamespace, time, random, gc, torch, clip, root,
# load_args, load_aesthetics_model and the render_* functions.
args_dict = DeforumArgs()
anim_args_dict = DeforumAnimArgs()

# Optionally replace the form values with those stored in a settings file.
if override_settings_with_file:
    load_args(args_dict, anim_args_dict, settings_file, custom_settings_file, verbose=False)

# Attribute-style access (args.W instead of args_dict["W"]).
args = SimpleNamespace(**args_dict)
anim_args = SimpleNamespace(**anim_args_dict)

# The timestring identifies this run; the video cell uses it to build frame and mp4 file names.
args.timestring = time.strftime('%Y%m%d%H%M%S')
# Clamp strength into [0, 1].
args.strength = max(0.0, min(1.0, args.strength))

# Load clip model if using clip guidance
if (args.clip_scale > 0) or (args.aesthetics_scale > 0):
    root.clip_model = clip.load(args.clip_name, jit=False)[0].eval().requires_grad_(False).to(root.device)
    if (args.aesthetics_scale > 0):
        root.aesthetics_model = load_aesthetics_model(args, root)

# A seed of -1 means "pick a random 32-bit seed".
if args.seed == -1:
    args.seed = random.randint(0, 2**32 - 1)
if not args.use_init:
    args.init_image = None
# PLMS is swapped for KLMS whenever an init image or any animation mode is in use.
if args.sampler == 'plms' and (args.use_init or anim_args.animation_mode != 'None'):
    print(f"Init images aren't supported with PLMS yet, switching to KLMS")
    args.sampler = 'klms'
# ddim_eta is reset to 0 for every sampler other than ddim.
if args.sampler != 'ddim':
    args.ddim_eta = 0

if anim_args.animation_mode == 'None':
    anim_args.max_frames = 1
elif anim_args.animation_mode == 'Video Input':
    # NOTE(review): use_init is forced on here, after init_image was already cleared above
    # when use_init was False — presumably the video renderer supplies its own init
    # frames; confirm.
    args.use_init = True

# clean up unused memory
gc.collect()
torch.cuda.empty_cache()

# dispatch to appropriate renderer
if anim_args.animation_mode == '2D' or anim_args.animation_mode == '3D':
    render_animation(root, anim_args, args, cond_prompts, uncond_prompts)
elif anim_args.animation_mode == 'Video Input':
    render_input_video(root, anim_args, args, cond_prompts, uncond_prompts)
elif anim_args.animation_mode == 'Interpolation':
    render_interpolation(root, anim_args, args, cond_prompts, uncond_prompts)
else:
    render_image_batch(root, args, cond_prompts, uncond_prompts)
# Printed after every mode, including the plain image-batch path.
print("\033[92m..Your Animation Has Completed, Congratulations..")

 $\color{blue} {\large \textsf{Create Video From Frames}}$

In [None]:
# Assemble the rendered frames into an mp4 with ffmpeg, show it inline, and optionally
# convert it to a GIF. Uses `args` / `anim_args` from the settings cell and `display`
# (the IPython.display module) imported earlier in the notebook.
skip_video_for_run_all = False #@param {type: 'boolean'}
fps = 30 #@param {type:"number"}
#@markdown **Manual Settings**
use_manual_settings = False #@param {type:"boolean"}
image_path = "/content/drive/MyDrive/AI/StableDiffusion/2023-02/hypernet/20230101212135_%05d.png" #@param {type:"string"}
mp4_path = "/content/drive/MyDrive/AI/StableDiffusion/2023-02/hypernet/20230101212135.mp4" #@param {type:"string"}
render_steps = False  #@param {type: 'boolean'}
path_name_modifier = "x0_pred" #@param ["x0_pred","x"]
make_gif = False #@param{type:'boolean'}
# 32-bit output is written as EXR frames; everything else as PNG.
bitdepth_extension = "exr" if args.bit_depth_output == 32 else "png"

if skip_video_for_run_all:
    print('Skipping video creation, uncheck skip_video_for_run_all if you want to run it')
else:
    import os
    import subprocess
    from base64 import b64encode

    if use_manual_settings:
        # Manual mode: image_path / mp4_path from the form above are used as-is.
        max_frames = "200" #@param {type:"string"}
    else:
        if render_steps: # render steps from a single image
            fname = f"{path_name_modifier}_%05d.png"
            all_step_dirs = [os.path.join(args.outdir, d) for d in os.listdir(args.outdir) if os.path.isdir(os.path.join(args.outdir,d))]
            # The most recently modified sub-directory holds the latest per-step images.
            newest_dir = max(all_step_dirs, key=os.path.getmtime)
            image_path = os.path.join(newest_dir, fname)
            print(f"Reading images from {image_path}")
            mp4_path = os.path.join(newest_dir, f"{args.timestring}_{path_name_modifier}.mp4")
            max_frames = str(args.steps)
        else: # render images for a video
            image_path = os.path.join(args.outdir, f"{args.timestring}_%05d.{bitdepth_extension}")
            mp4_path = os.path.join(args.outdir, f"{args.timestring}.mp4")
            max_frames = str(anim_args.max_frames)

    # Report the paths only once they are final; in automatic mode the form values above
    # have been replaced by now.
    print(f"{image_path} -> {mp4_path}")

    # make video
    cmd = [
        'ffmpeg',
        '-y',
        '-vcodec', bitdepth_extension,
        '-r', str(fps),
        '-start_number', str(0),
        '-i', image_path,
        '-frames:v', max_frames,
        '-c:v', 'libx264',
        '-vf',
        f'fps={fps}',
        '-pix_fmt', 'yuv420p',
        '-crf', '17',
        '-preset', 'veryfast',
        '-pattern_type', 'sequence',
        mp4_path
    ]
    process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    stdout, stderr = process.communicate()
    if process.returncode != 0:
        print(stderr)
        raise RuntimeError(stderr)

    # Read the file through a context manager so the handle is closed right away.
    with open(mp4_path, 'rb') as mp4_file:
        mp4 = mp4_file.read()
    data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
    display.display(display.HTML(f'<video controls loop><source src="{data_url}" type="video/mp4"></video>') )

    if make_gif:
        gif_path = os.path.splitext(mp4_path)[0]+'.gif'
        cmd_gif = [
            'ffmpeg',
            '-y',
            '-i', mp4_path,
            '-r', str(fps),
            gif_path
        ]
        process_gif = subprocess.Popen(cmd_gif, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        # Drain the pipes and wait for ffmpeg: with piped output and nobody reading, the
        # conversion can stall once the pipe buffer fills, and failures would go unnoticed.
        stdout_gif, stderr_gif = process_gif.communicate()
        if process_gif.returncode != 0:
            print(stderr_gif)
            raise RuntimeError(stderr_gif)

 $\color{blue} {\large \textsf{Create Video From Frames (Alternative Version)}}$

In [None]:
#@markdown **Alternative Version**
# Alternative video cell: delegates path resolution and encoding to helpers.ffmpeg_helpers.
from helpers.ffmpeg_helpers import make_mp4_ffmpeg
skip_video_for_run_all = False #@param {type: 'boolean'}

if skip_video_for_run_all == True:
    print('Skipping video creation, uncheck skip_video_for_run_all if you want to run it')
else:

    # make_mp4_ffmpeg is imported a second time here; the import above already covers it.
    from helpers.ffmpeg_helpers import get_extension_maxframes, get_auto_outdir_timestring, get_ffmpeg_path, make_mp4_ffmpeg, make_gif_ffmpeg, patrol_cycle

    def ffmpegArgs():
        """
        Collect the ffmpeg form settings and return them as a dict via `locals()`
        (each local variable name is a key of the result).

        In "auto" mode the output directory and timestring are taken from `args`;
        in "auto" and "timestring" modes the extension, frame count and the image/mp4/gif
        paths are derived by the helper functions. In "manual" mode the form values
        are returned unchanged.
        """
        ffmpeg_mode = "auto" #@param ["auto","manual","timestring"]
        ffmpeg_outdir = "/content/drive/MyDrive/" #@param {type:"string"}
        ffmpeg_timestring = "20230203050018" #@param {type:"string"}
        ffmpeg_image_path = "/content/drive/MyDrive/AI/StableDiffusion/2023-02/hypernet/20230203050018_%05d.png" #@param {type:"string"}
        ffmpeg_mp4_path = "/content/drive/MyDrive/20230203050018.mp4" #@param {type:"string"}
        ffmpeg_gif_path = "" #@param {type:"string"}
        ffmpeg_extension = "png" #@param {type:"string"}
        ffmpeg_maxframes = 1120 #@param
        ffmpeg_fps = 30 #@param

        # determine auto paths
        if ffmpeg_mode == 'auto':
            ffmpeg_outdir, ffmpeg_timestring = get_auto_outdir_timestring(args,ffmpeg_mode)
        if ffmpeg_mode in ["auto","timestring"]:
            ffmpeg_extension, ffmpeg_maxframes = get_extension_maxframes(args,ffmpeg_outdir,ffmpeg_timestring)
            ffmpeg_image_path, ffmpeg_mp4_path, ffmpeg_gif_path = get_ffmpeg_path(ffmpeg_outdir, ffmpeg_timestring, ffmpeg_extension)
        return locals()

    ffmpeg_args_dict = ffmpegArgs()
    # SimpleNamespace must already be in scope from an earlier cell.
    ffmpeg_args = SimpleNamespace(**ffmpeg_args_dict)
    make_mp4_ffmpeg(ffmpeg_args, display_ffmpeg=True, debug=True)
    # make_gif_ffmpeg(ffmpeg_args, debug=True)
    # patrol_cycle(args,ffmpeg_args)

In [None]:
#@title $\color{blue} {\textsf{Display The Video Using kora(may error out, try uninstalling with "pip uninstall kora" if it doesn't work)!}}$

# Publish an mp4 from Drive with kora's upload_public and embed it in the cell output.
from kora.drive import upload_public
from IPython.display import HTML
# `url` starts out as the Drive path entered in the form ...
url = '/content/drive/MyDrive/20230202065924.mp4' #@param{type:'string'}
# ... and is then overwritten with whatever upload_public returns for that file.
url = upload_public(url)
# mp4_path = "/content/drive/MyDrive/AI/StableDiffusion/2023-01/kong/20230123004448.mp4"
# mp4 = open(mp4_path,'rb').read()
# data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
# then display it
# NOTE(review): the src attribute is written without quotes, so a URL containing spaces or
# quote characters would break the tag — consider quoting it.
HTML(f"""<video src={url} width=500 controls/>""")

In [None]:
# Release the Colab runtime unless the skip box is ticked (useful at the end of "Run all").
skip_disconnect_for_run_all = True #@param {type: 'boolean'}

# Plain truthiness test instead of comparing against True.
if skip_disconnect_for_run_all:
    print('Skipping disconnect, uncheck skip_disconnect_for_run_all if you want to run it')
else:
    from google.colab import runtime
    runtime.unassign()

# $ \color{cyan} {\large \textsf{Model Merger}}$

In [None]:
#@title $ \color{cyan} {\large \textsf{Merge Model!}}$
import os
import re
import shutil

import torch
import tqdm
from types import SimpleNamespace
import safetensors.torch

def Load_checkpoints_to_be_merged():
  """Collect the Model Merger form values and return them as a dict.

  Every assignment below is a Colab form field (``#@param``); the function
  returns ``locals()``, so each local variable name becomes a key of the
  result. Do not introduce helper locals here - they would leak into the
  returned settings.

  Fields as they are consumed elsewhere in this cell:
    ckpt_dir: folder the merged checkpoint is written to.
    primary_model_name / secondary_model_name / tertiary_model_name:
      checkpoint paths for models A, B and C.
    output_modelname: passed to run_modelmerger as ``custom_name`` and used
      as the stem of the copied ``.yaml`` config.
    config_source: YAML file copied next to the primary model.
    map_location: device used when loading safetensors files and VAE weights.
    interp_method, multiplier, save_as_half: merge controls.
    checkpoint_format: appended to the output filename as its extension.
    bake_in_vae: VAE checkpoint whose weights overwrite ``first_stage_model.*``.

  NOTE(review): ``custom_name`` offers 'cuda'/'cpu', which looks copied from
  ``map_location``; the run_modelmerger call in this cell passes
  ``output_modelname`` as the custom name instead - confirm and clean up.
  NOTE(review): ``checkpoint_format`` offers 'safetensor', while the save code
  only recognises a '.safetensors' extension - confirm the intended spelling.
  NOTE(review): ``discard_weights`` is only referenced by commented-out code
  in run_modelmerger.

  Returns:
    dict mapping each form field name to its current value.
  """
  ckpt_dir = "/content/drive/MyDrive/sd/stable-diffusion-webui/models/Stable-diffusion" #@param{type:'string'}
  primary_model_name = "/content/drive/MyDrive/sd/stable-diffusion-webui/models/Stable-diffusion/protogenNovaExperimental_protogenX80.ckpt" #@param{type:'string'}
  secondary_model_name = "/content/drive/MyDrive/sd/stable-diffusion-webui/models/Stable-diffusion/DeliberateDream.ckpt" #@param{type:'string'}
  tertiary_model_name = "/content/drive/MyDrive/sd/stable-diffusion-webui/models/Stable-diffusion/DreamShaper.ckpt" #@param{type:'string'}
  output_modelname = "Deliberate_Nova_Dream" #@param{type:'string'}
  config_source = "/content/drive/MyDrive/sd/stable-diffusion-webui/models/Stable-diffusion/Dreamlike-Gen1.0.yaml" #@param{type:'string'}
  map_location = "cuda" #@param['cuda', 'cpu']
  interp_method = "Weighted sum" #@param["No interpolation", "Weighted sum", "Add difference"]
  multiplier = 0.6 #@param{type:'slider', min:0.1, max:1, step:0.1}
  save_as_half = False #@param{type:'boolean'}
  custom_name = "cuda" #@param['cuda', 'cpu']
  checkpoint_format = "ckpt" #@param['ckpt', 'safetensor']
  bake_in_vae = "/content/drive/MyDrive/AI/models/vae-ft-mse-840000-ema-pruned.ckpt" #@param{type:'string'}
  discard_weights = True #@param{type:'boolean'}
  return locals()

# Expose the form values as attributes (settings.primary_model_name, ...) for the
# helpers below, without first binding the name to the intermediate dict.
load_checkpoints_to_be_merged = SimpleNamespace(**Load_checkpoints_to_be_merged())

# Key-prefix rewrites applied to every state-dict key on load: each listed
# sub-module of `cond_stage_model.transformer` is moved under `text_model`.
chckpoint_dict_replacements = {
    f'cond_stage_model.transformer.{submodule}.': f'cond_stage_model.transformer.text_model.{submodule}.'
    for submodule in ('embeddings', 'encoder', 'final_layer_norm')
}

def _load_vae_dict(model, vae_dict_1):
    model.first_stage_model.load_state_dict(vae_dict_1)
    model.first_stage_model.to("cuda")

def load_vae_dict(filename, map_location):
    """Read a VAE checkpoint and return its weights without bookkeeping keys.

    Keys that start with ``"loss"`` and keys listed in ``vae_ignore_keys`` are
    dropped.

    Args:
        filename: path of the VAE checkpoint.
        map_location: forwarded to ``read_state_dict``.

    Returns:
        dict of the remaining key -> value pairs.
    """
    state = read_state_dict(filename, map_location=map_location)
    kept = {}
    for name, weight in state.items():
        if name.startswith("loss") or name in vae_ignore_keys:
            continue
        kept[name] = weight
    return kept


def load_vae(model, vae_file=None, vae_source="from unknown source"):
    """Load VAE weights from ``vae_file`` into ``model.first_stage_model``.

    The previous implementation had its branches inverted: with a path it
    passed the *path string* to ``load_state_dict``, and without one it called
    ``os.path.isfile(None)``. Either way it raised before loading anything.

    Args:
        model: object exposing ``first_stage_model`` (see ``_load_vae_dict``).
        vae_file: path of the VAE checkpoint; when empty/None nothing is loaded.
        vae_source: short description of where the path came from, used only
            in the log and error messages.

    Raises:
        FileNotFoundError: if ``vae_file`` is given but is not an existing file.

    Side effects:
        Sets the module-level ``loaded_vae_file`` to ``vae_file`` (None when no
        VAE was requested).
    """
    global loaded_vae_file

    if vae_file:
        if not os.path.isfile(vae_file):
            raise FileNotFoundError(f"VAE {vae_source} doesn't exist: {vae_file}")
        print(f"Loading VAE weights {vae_source}: {vae_file}")

        # Weights are read onto the device chosen in the merger form.
        vae_dict_1 = load_vae_dict(vae_file, load_checkpoints_to_be_merged.map_location)
        _load_vae_dict(model, vae_dict_1)

    loaded_vae_file = vae_file



def find_checkpoint_config(info):
    """Return the path of the ``.yaml`` file that sits next to the primary model.

    The path is the primary checkpoint path with its extension replaced by
    ``.yaml``. The previous code concatenated the ``(root, ext)`` tuple returned
    by ``os.path.splitext`` with a string and therefore always raised
    ``TypeError``; it also returned the same value whether or not the file
    existed, so the existence check was dead code and has been removed.

    Args:
        info: unused; kept so existing call sites keep working.

    Returns:
        The candidate config path. The file is not guaranteed to exist.
    """
    root, _ = os.path.splitext(load_checkpoints_to_be_merged.primary_model_name)
    return root + ".yaml"

def transform_checkpoint_dict_key(k):
    """Rewrite a state-dict key according to ``chckpoint_dict_replacements``.

    Each (old prefix, new prefix) pair is tried in turn against the current
    value of the key, so a key rewritten by one rule is still checked against
    the remaining rules.

    Returns:
        The possibly rewritten key.
    """
    for old_prefix, new_prefix in chckpoint_dict_replacements.items():
        if not k.startswith(old_prefix):
            continue
        remainder = k[len(old_prefix):]
        k = new_prefix + remainder

    return k

def get_state_dict_from_checkpoint(pl_sd):
    """Unwrap a loaded checkpoint and normalise its keys in place.

    If the checkpoint has a ``"state_dict"`` entry, that inner dict is used
    (and removed from the outer one); otherwise the checkpoint itself is
    treated as the state dict. Every key is passed through
    ``transform_checkpoint_dict_key``; entries whose transformed key is None
    are dropped.

    Returns:
        The same dict object that held the weights, now with renamed keys.
    """
    state = pl_sd.pop("state_dict", pl_sd)
    state.pop("state_dict", None)

    renamed_pairs = ((transform_checkpoint_dict_key(old_key), value) for old_key, value in state.items())
    renamed = {new_key: value for new_key, value in renamed_pairs if new_key is not None}

    # Mutate rather than replace so the returned object is the original dict.
    state.clear()
    state.update(renamed)

    return state

def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
    """Load a checkpoint file and return its normalised state dict.

    ``.safetensors`` files are read with ``safetensors.torch.load_file``;
    everything else goes through ``torch.load``.

    Args:
        checkpoint_file: path of the checkpoint.
        print_global_state: when True and the checkpoint has a ``global_step``
            entry, print it.
        map_location: device to load onto. For safetensors files the previous
            code ignored this argument and always used the device from the
            merger form, so ``map_location='cpu'`` callers could still end up
            loading onto the GPU. It is now honoured, falling back to the form
            setting and finally to CUDA-if-available.

    Returns:
        The state dict after ``get_state_dict_from_checkpoint``.
    """
    _, extension = os.path.splitext(checkpoint_file)
    if extension.lower() == ".safetensors":
        device = map_location
        if device is None:
            device = load_checkpoints_to_be_merged.map_location
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
    else:
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code - only load checkpoints from sources you trust.
        pl_sd = torch.load(checkpoint_file, map_location=map_location)

    if print_global_state and "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")

    return get_state_dict_from_checkpoint(pl_sd)

def create_config(ckpt_result, config_source, a, b, c):
    """Copy the model YAML config next to the primary model under a new name.

    The destination is ``<directory of the primary model>/<ckpt_result>.yaml``.

    Args:
        ckpt_result: stem of the destination file name.
        config_source: path of the YAML to copy; when empty, the
            ``config_source`` value from the merger form is used. (Previously
            this argument was ignored and the form value was always used.)
        a, b, c: unused; kept so existing call sites keep working.

    Returns:
        The destination path. Previously nothing was returned even though the
        caller stores the result.
    """
    cfg = config_source or load_checkpoints_to_be_merged.config_source

    checkpoint_filename = os.path.dirname(load_checkpoints_to_be_merged.primary_model_name) + "/" + ckpt_result + ".yaml"
    print("Copying config:")
    print("   from:", cfg)
    print("     to:", checkpoint_filename)
    shutil.copyfile(cfg, checkpoint_filename)

    return checkpoint_filename


# State-dict keys that both merge loops in run_modelmerger leave untouched.
checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]


def to_half(tensor, enable):
    """Return ``tensor`` as float16 when ``enable`` is set and it is float32.

    Tensors of any other dtype, and every tensor when ``enable`` is falsy, are
    returned unchanged (the same object).
    """
    should_convert = enable and tensor.dtype == torch.float
    return tensor.half() if should_convert else tensor

# Keys that load_vae_dict strips from a VAE checkpoint before returning it.
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
# Module-level placeholder named in load_vae's `global` statement; run_modelmerger
# uses a *local* variable of the same name and never writes to this one.
vae_dict = {}
def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights):
    """Merge up to three checkpoints, optionally bake in a VAE, and save the result.

    Interpolation methods:
        "No interpolation": A is saved as-is (optionally halved / with VAE).
        "Weighted sum":     (1 - multiplier) * A + multiplier * B
        "Add difference":   A + multiplier * (B - C)

    Only keys containing ``'model'`` that exist in both A and B are merged;
    keys in ``checkpoint_dict_skip_on_merge`` are left untouched.

    Args:
        primary_model_name: path of checkpoint A (required).
        secondary_model_name: path of checkpoint B (needed unless "No interpolation").
        tertiary_model_name: path of checkpoint C (needed for "Add difference").
        interp_method: one of the three method names above.
        multiplier: interpolation weight.
        save_as_half: store float32 tensors as float16.
        custom_name: output file stem; when empty a name is generated from the
            model file names and the multiplier.
        checkpoint_format: output extension. ``"safetensor"`` (the spelling
            offered by the form) is treated as ``"safetensors"``.
        config_source: YAML config forwarded to ``create_config``.
        bake_in_vae: path of a VAE checkpoint whose weights replace the
            matching ``first_stage_model.*`` entries; empty/None skips this.
        discard_weights: accepted for interface compatibility, currently unused.

    Still read from the module-level ``load_checkpoints_to_be_merged``:
        ``ckpt_dir`` (output folder) and ``output_modelname`` (first element of
        the return value).

    Returns:
        On success ``(load_checkpoints_to_be_merged.output_modelname, ckpt_cfg)``
        where ``ckpt_cfg`` is whatever ``create_config`` returned. On a
        validation failure the failure message (a ``str``), which is also printed.
    """

    def fail(message):
        # Print as well as return: the notebook call site discards the return
        # value, so a returned-only message would never be seen.
        print(message)
        return message

    def weighted_sum(theta0, theta1, alpha):
        return ((1 - alpha) * theta0) + (alpha * theta1)

    def get_difference(theta1, theta2):
        return theta1 - theta2

    def add_difference(theta0, theta1_2_diff, alpha):
        return theta0 + (alpha * theta1_2_diff)

    def model_stem(path):
        # File name without directory or extension, so generated output names
        # never contain path separators.
        return os.path.splitext(os.path.basename(path))[0]

    def filename_weighted_sum():
        a = model_stem(primary_model_name)
        b = model_stem(secondary_model_name)
        Ma = round(1 - multiplier, 2)
        Mb = round(multiplier, 2)

        return f"{Ma}({a}) + {Mb}({b})"

    def filename_add_difference():
        a = model_stem(primary_model_name)
        b = model_stem(secondary_model_name)
        c = model_stem(tertiary_model_name)
        M = round(multiplier, 2)

        return f"{a} + {M}({b} - {c})"

    def filename_nothing():
        return model_stem(primary_model_name)

    # method name -> (default filename, B/C pre-combination, A/B combination)
    theta_funcs = {
        "Weighted sum": (filename_weighted_sum, None, weighted_sum),
        "Add difference": (filename_add_difference, get_difference, add_difference),
        "No interpolation": (filename_nothing, None, None),
    }
    if interp_method not in theta_funcs:
        return fail(f"Failed: Unknown interpolation method ({interp_method}).")
    filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]

    if not primary_model_name:
        return fail("Failed: Merging requires a primary model.")

    primary_model_info = primary_model_name

    if theta_func2 and not secondary_model_name:
        return fail("Failed: Merging requires a secondary model.")

    secondary_model_info = secondary_model_name if theta_func2 else None

    if theta_func1 and not tertiary_model_name:
        return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.")

    tertiary_model_info = tertiary_model_name if theta_func1 else None

    result_is_inpainting_model = False

    if theta_func2:
        print(f"Loading {secondary_model_info}...")
        theta_1 = read_state_dict(secondary_model_info, map_location='cpu')
    else:
        theta_1 = None

    if theta_func1:
        print(f"Loading {tertiary_model_info}...")
        theta_2 = read_state_dict(tertiary_model_info, map_location='cpu')

        # Turn B into (B - C) in place; keys missing from C contribute zero.
        for key in tqdm.tqdm(theta_1.keys()):
            if key in checkpoint_dict_skip_on_merge:
                continue

            if 'model' in key:
                if key in theta_2:
                    theta_1[key] = theta_func1(theta_1[key], theta_2[key])
                else:
                    theta_1[key] = torch.zeros_like(theta_1[key])

        del theta_2

    print(f"Loading {primary_model_info}...")
    theta_0 = read_state_dict(primary_model_info, map_location='cpu')

    print("Merging...")
    for key in tqdm.tqdm(theta_0.keys()):
        if theta_1 and 'model' in key and key in theta_1:

            if key in checkpoint_dict_skip_on_merge:
                continue

            a = theta_0[key]
            b = theta_1[key]

            # this enables merging an inpainting model (A) with another one (B);
            # where a normal model would have 4 channels, for latent space, an inpainting model would
            # have another 4 channels for the unmasked picture's latent space, plus one channel for mask, for a total of 9
            if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]:
                if a.shape[1] == 4 and b.shape[1] == 9:
                    raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.")

                assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"

                # Only the first four input channels are shared with B; the
                # rest of A's tensor is kept as it is.
                theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
                result_is_inpainting_model = True
            else:
                theta_0[key] = theta_func2(a, b, multiplier)

            theta_0[key] = to_half(theta_0[key], save_as_half)

    del theta_1

    # An empty path means "do not bake a VAE"; previously only None skipped this.
    if bake_in_vae:
        print(f"Baking in VAE from {bake_in_vae}")
        baked_vae = load_vae_dict(bake_in_vae, map_location='cpu')

        for key in baked_vae.keys():
            theta_0_key = 'first_stage_model.' + key
            if theta_0_key in theta_0:
                theta_0[theta_0_key] = to_half(baked_vae[key], save_as_half)

        del baked_vae

    # With no A/B merge the loop above never halved anything, so do it here.
    if save_as_half and not theta_func2:
        for key in theta_0.keys():
            theta_0[key] = to_half(theta_0[key], save_as_half)

    # The form offers "safetensor", but the writer below (and every loader)
    # keys off the ".safetensors" extension; without this the file would be
    # written as a pickle under a misleading name.
    extension_name = "safetensors" if checkpoint_format == "safetensor" else checkpoint_format

    ckpt_dir = load_checkpoints_to_be_merged.ckpt_dir
    filename = custom_name if custom_name else filename_generator()
    filename += ".inpainting" if result_is_inpainting_model else ""
    config_stem = filename
    filename += "." + extension_name

    output_modelname = os.path.join(ckpt_dir, filename)

    print(f"Saving to {output_modelname}...")

    _, extension = os.path.splitext(output_modelname)
    if extension.lower() == ".safetensors":
        safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"})
    else:
        torch.save(theta_0, output_modelname)

    # Name the config after the checkpoint that was actually written so the
    # two files stay paired.
    ckpt_cfg = create_config(config_stem,
                             config_source,
                             primary_model_info,
                             secondary_model_info,
                             tertiary_model_info)

    print(f"Checkpoint saved to {output_modelname}.")

    return load_checkpoints_to_be_merged.output_modelname, ckpt_cfg

# Run the merge with the form values. Keyword arguments make the one
# non-obvious mapping explicit: the form's `output_modelname` is what
# run_modelmerger receives as `custom_name`.
run_modelmerger(
    primary_model_name=load_checkpoints_to_be_merged.primary_model_name,
    secondary_model_name=load_checkpoints_to_be_merged.secondary_model_name,
    tertiary_model_name=load_checkpoints_to_be_merged.tertiary_model_name,
    interp_method=load_checkpoints_to_be_merged.interp_method,
    multiplier=load_checkpoints_to_be_merged.multiplier,
    save_as_half=load_checkpoints_to_be_merged.save_as_half,
    custom_name=load_checkpoints_to_be_merged.output_modelname,
    checkpoint_format=load_checkpoints_to_be_merged.checkpoint_format,
    config_source=load_checkpoints_to_be_merged.config_source,
    bake_in_vae=load_checkpoints_to_be_merged.bake_in_vae,
    discard_weights=load_checkpoints_to_be_merged.discard_weights,
)

print("\033[92m..Model Merging Has Completed, Congratulations..")

# $\color{salmon} {\large \textsf{REAL-ESRGAN Video Upscaling!}}$

Please be sure to run the Environment Setup cell first!


In [None]:
import os
import cv2
#@title $\color{salmon} {\large \textsf { 📹 RUN REAL-ESRGAN VIDEO UPSCALE PHASE (from video) 📹}}$ 
model_name = 'realesr-general-x4v3' #@param ['realesr-general-x4v3', 'realesr-general-wdn-x4v3', 'RealESRGAN_x4plus','RealESRGAN_x4plus_anime_6B','RealESRGAN_x2plus'] {type:"string"}
lowres_vid = '/content/drive/MyDrive/LINUM/lowres/lowres_input_video/cell.mp4' #@param{type:'string'}
#@markdown $\color{salmon} {\textsf {input your lowres video path here to be upscaled}}$
highres_vid_dir = '/content/drive/MyDrive/upscaled_storage' #@param{type:'string'}
#@markdown $\color{salmon} {\textsf {input your high res video output path here}}$
if not os.path.exists(highres_vid_dir):
  os.makedirs(highres_vid_dir)
outscale = "4" #@param ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13']
denoising_strength = 1 #@param{type:'slider', step:0.01, min:0.01, max:1}
target_fps = 30 #@param{type:'number'}
fp32 = 'fp32'
gpu_id = 0 #@param ['0', '1', '2', '3', '4']

%cd '/content/Real-ESRGAN/'
!python3 /content/Real-ESRGAN/inference_realesrgan_video.py --model_name $model_name --outscale $outscale -dn $denoising_strength --fp32 --input $lowres_vid --fps $target_fps --output $highres_vid_dir

%cd '/content/'
print("\033[92m...Upscaling Phase Has Completed, You May Proceed...")

In [None]:
import os
import cv2
#@title $\color{salmon} {\large \textsf { 📹 !BATCH! RUN REAL-ESRGAN VIDEO UPSCALE PHASE (from video) 📹}}$ 
model_name = 'realesr-general-x4v3' #@param ['realesr-general-x4v3', 'realesr-general-wdn-x4v3', 'RealESRGAN_x4plus','RealESRGAN_x4plus_anime_6B','RealESRGAN_x2plus'] {type:"string"}
batch_viddir = "/content/drive/MyDrive/LINUM/lowres/lowres_input_video/" #@param{type:'string'}
#@markdown $\color{salmon} {\textsf {input folder path with videos to upscale **WILL UPSCALE ALL VIDEOS IN THE FOLDER**}}$
highres_vid_dir = '/content/drive/MyDrive/LINUM/4k' #@param{type:'string'}
#@markdown $\color{salmon} {\textsf {input your high res video output path here}}$
if not os.path.exists(highres_vid_dir):
  os.makedirs(highres_vid_dir)
outscale = "4" #@param ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13']
denoising_strength = 1 #@param{type:'slider', step:0.01, min:0.01, max:1}
target_fps = 30 #@param{type:'number'}
fp32 = 'fp32'
gpu_id = 0 #@param ['0', '1', '2', '3', '4']
updir = os.path.split(highres_vid_dir)[0] + "/"

for vid in os.listdir(batch_viddir):
  new_vid = os.path.join(batch_viddir, vid)
  print(new_vid)
  %cd '/content/Real-ESRGAN'
  !python3 /content/Real-ESRGAN/inference_realesrgan_video.py --model_name $model_name --outscale $outscale -dn $denoising_strength --fp32 --input $new_vid --fps $target_fps --output $highres_vid_dir
  !mv $new_vid $updir --verbose
  print(f"\033[93mMoving {new_vid} to {updir}.")
%cd '/content/'
print("\033[92m...Upscaling Phase Has Completed, You May Proceed...")

# $\color{orange} {\large \textsf{RIFE 4.6 Interpolation!}}$

Please be sure to run the Environment Setup cell first!

In [None]:
#@title $\color{orange} {\large \textsf{Define Params and Run RIFE!}}$

# Video that will be interpolated; it is also probed with ffprobe further down
# to measure its frame rate and duration.
input_RIFE = '/content/drive/MyDrive/upscaled_storage/cell_out.mp4' # @param {type: 'string'}
#@markdown $\color{orange} {\textsf{THIS IS THE FILE NAME FOR YOUR COMPRESSED VID}}$

# Destination file handed to inference_video.py via --output.
output_RIFE = '/content/drive/MyDrive/LINUM/4k/cell_out_RIFE.mp4' # @param {type: 'string'}
#@markdown $\color{orange} {\textsf{INPUT YOUR OUTPUT NAME FOR THE INTERPOLATED VIDEO}}$
# fps_RIFE * target_length_RIFE is the frame budget Detection.exp_calc compares
# against the measured fps * duration of the input.
fps_RIFE = 60 # @param {type: 'number'}
target_length_RIFE = 35 # @param {type: 'number'}
target_scale_RIFE = 0.5 #@param ['0.25', '0.5', '1.0', '2.0', '4.0']
# Set to True to bypass the RIFE run at the end of this cell.
_skip_RIFEsmoothing = False
#@markdown $\color{orange} {\textsf{Additional Params}}$
print("\n ------------------------------------\n")
print("\n Beginning RIFE motion smoothing phase... \n")

# Work from the Practical-RIFE checkout cloned by the setup cell.
%cd /content/Practical-RIFE/

class Detection:
  """ffprobe-based helpers used to pick RIFE's ``--exp`` value.

  All members are static and are called on the class itself, e.g.
  ``Detection.detect_fps(path)``.
  """

  @staticmethod
  def _ffprobe_value(input, stream_selector, entry):
    """Return the first non-empty line ffprobe prints for one stream entry.

    ffprobe is invoked with an argument list (no shell), so paths containing
    spaces or shell metacharacters are passed through unchanged.

    Raises:
      ValueError: if ffprobe printed nothing for the requested entry.
    """
    import subprocess
    result = subprocess.run(
        ["ffprobe", "-v", "error",
         "-select_streams", stream_selector,
         "-show_entries", "stream=" + entry,
         "-of", "default=noprint_wrappers=1:nokey=1",
         input],
        stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True)
    for line in result.stdout.splitlines():
      line = line.strip()
      if line:
        return line
    raise ValueError(f"ffprobe returned no {entry} for {input!r}: {result.stderr.strip()}")

  @staticmethod
  def detect_fps(input): #needs portable
    """Return the average frame rate of the first video stream, rounded to an int.

    ffprobe reports ``avg_frame_rate`` as a rational such as ``30/1`` or
    ``30000/1001``. The previous implementation kept only the numerator, so
    an NTSC-style ``30000/1001`` was reported as 30000 fps.

    Raises:
      ValueError: if the rate is missing, malformed, or not positive.
    """
    from fractions import Fraction
    raw_rate = Detection._ffprobe_value(input, "v", "avg_frame_rate")
    try:
      rate = Fraction(raw_rate)
    except (ValueError, ZeroDivisionError):
      raise ValueError(f"Could not parse frame rate {raw_rate!r} for {input!r}") from None
    fps = int(round(rate))
    if fps <= 0:
      raise ValueError(f"Non-positive frame rate {raw_rate!r} for {input!r}")
    return fps

  @staticmethod
  def detect_duration(input):  #needs portable
    """Return the duration of the first video stream in whole seconds, as a float.

    The fractional part is dropped on purpose; the calling cell reports the
    value as an integer number of seconds and clamps anything below 1.

    Raises:
      ValueError: if ffprobe reports no numeric duration (for example ``N/A``).
    """
    raw_duration = Detection._ffprobe_value(input, "v:0", "duration")
    try:
      seconds = float(raw_duration)
    except ValueError:
      raise ValueError(f"Could not parse duration {raw_duration!r} for {input!r}") from None
    return float(int(seconds))

  @staticmethod
  def exp_calc(measured_duration,target_length_RIFE, source_fps=None, target_fps=None): #needs portable
    """Pick the ``--exp`` value: log2 of (target frames / source frames), rounded.

    Args:
      measured_duration: duration of the input in seconds.
      target_length_RIFE: desired duration of the output in seconds.
      source_fps: frame rate of the input; defaults to the notebook global
        ``measured_fps``.
      target_fps: frame rate of the output; defaults to the notebook global
        ``fps_RIFE``.

    Returns:
      The rounded exponent as an int, never less than 1.

    Raises:
      ValueError: if either frame count is not positive.
    """
    import math
    if source_fps is None:
      source_fps = measured_fps
    if target_fps is None:
      target_fps = fps_RIFE
    source_frames = source_fps * measured_duration
    target_frames = target_fps * target_length_RIFE
    if source_frames <= 0 or target_frames <= 0:
      raise ValueError("Frame rates and durations must be positive to compute --exp")
    l = math.log2(target_frames / source_frames)
    print("Un-rounded --exp is",l)
    x = round(l)
    if x < 1:
      x = 1
    print("Rounding up to an --exp of ",x)
    return x

#----------------------------
_import_mp4_file = False

if _import_mp4_file:
  measured_fps = Detection.detect_fps(input_RIFE)
  print("\n NOTICE: Detected average FPS of ",input_RIFE," is ",measured_fps)
  measured_duration = Detection.detect_duration(input_RIFE)
else: 
  measured_fps = Detection.detect_fps(input_RIFE)
  print("\n NOTICE: Detected average FPS of ",{input_RIFE}," is ",measured_fps)
  measured_duration = Detection.detect_duration(input_RIFE)

print("\n NOTICE: Detected duration INTEGER (in seconds) is ",measured_duration)

if measured_duration < 1: #failsafe
  print("\n NOTICE: Your input appears to be less than one second... \n")
  measured_duration = 1

exp_value = Detection.exp_calc(measured_duration,target_length_RIFE)


print("\n NOTICE: Target duration currently rounds to the closest --exp RIFE can handle. \n")

if (exp_value < 0.5):
  _skip_RIFEsmoothing = False
  print("\n NOTICE: Your fps_RIFE doesn't necessitate RIFE motion smoothing. Skipping RIFE...\n")

if _skip_RIFEsmoothing:
  print("\n NOTICE: Skipping RIFE motion smoothing...\n")
else:
  #---RUN RIFE------------------------------------
  print("\n NOTICE: Running RIFE... \n")
  %cd /content/Practical-RIFE/
  !python3 /content/Practical-RIFE/inference_video.py --fps=$fps_RIFE --exp=$exp_value  --video=$input_RIFE --scale=$target_scale_RIFE --output=$output_RIFE
  #--exp=$exp_value
  #---END--------------------------------------
#!ffmpeg -y -i $input $visual_effects -c:v hevc_nvenc -rc vbr -cq $_constant_quality -qmin $_constant_quality -qmax $_constant_quality -b:v 0 $output_fullpathname
#END OF MOTION SMOOTHING PHASE
print("\n \033[92mEnd of RIFE interpolation phase.\n")

In [None]:
#@title $\color{orange} {\large \textsf{!BATCH! Define Params and Run RIFE!}}$
# NOTE(review): this cell uses `os` without importing it, so it relies on an
# earlier cell having run `import os`.
rife_viddir = "/content/drive/MyDrive/RIFE_Batch" #@param{type:'string'}
#@markdown $\color{orange} {\textsf{THIS IS THE PATH TO THE FOLDER WITH YOUR VIDEOS TO BE INTERPOLATED}}$
# Lists the folder contents. After the loop `new_vid` holds the LAST entry
# returned by os.listdir, and code further down in this cell reads it.
for rife_vid in os.listdir(rife_viddir):
    new_vid = os.path.join(rife_viddir, rife_vid)
    print(new_vid)
# fps_RIFE * target_length_RIFE is the frame budget Detection.exp_calc compares
# against the measured fps * duration of the input.
fps_RIFE = 60 # @param {type: 'number'}
target_length_RIFE = 35 # @param {type: 'number'}
target_scale_RIFE = 0.5 #@param ['0.25', '0.5', '1.0', '2.0', '4.0']
# Set to True to bypass the RIFE run at the end of this cell.
_skip_RIFEsmoothing = False
#@markdown $\color{orange} {\textsf{Additional Params}}$
print("\n ------------------------------------\n")
print("\n Beginning RIFE motion smoothing phase... \n")


# Work from the Practical-RIFE checkout cloned by the setup cell.
%cd /content/Practical-RIFE/

class Detection:
  """ffprobe-based helpers used to pick RIFE's ``--exp`` value.

  All members are static and are called on the class itself, e.g.
  ``Detection.detect_fps(path)``.
  """

  @staticmethod
  def _ffprobe_value(input, stream_selector, entry):
    """Return the first non-empty line ffprobe prints for one stream entry.

    ffprobe is invoked with an argument list (no shell), so paths containing
    spaces or shell metacharacters are passed through unchanged.

    Raises:
      ValueError: if ffprobe printed nothing for the requested entry.
    """
    import subprocess
    result = subprocess.run(
        ["ffprobe", "-v", "error",
         "-select_streams", stream_selector,
         "-show_entries", "stream=" + entry,
         "-of", "default=noprint_wrappers=1:nokey=1",
         input],
        stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True)
    for line in result.stdout.splitlines():
      line = line.strip()
      if line:
        return line
    raise ValueError(f"ffprobe returned no {entry} for {input!r}: {result.stderr.strip()}")

  @staticmethod
  def detect_fps(input): #needs portable
    """Return the average frame rate of the first video stream, rounded to an int.

    ffprobe reports ``avg_frame_rate`` as a rational such as ``30/1`` or
    ``30000/1001``. The previous implementation kept only the numerator, so
    an NTSC-style ``30000/1001`` was reported as 30000 fps.

    Raises:
      ValueError: if the rate is missing, malformed, or not positive.
    """
    from fractions import Fraction
    raw_rate = Detection._ffprobe_value(input, "v", "avg_frame_rate")
    try:
      rate = Fraction(raw_rate)
    except (ValueError, ZeroDivisionError):
      raise ValueError(f"Could not parse frame rate {raw_rate!r} for {input!r}") from None
    fps = int(round(rate))
    if fps <= 0:
      raise ValueError(f"Non-positive frame rate {raw_rate!r} for {input!r}")
    return fps

  @staticmethod
  def detect_duration(input):  #needs portable
    """Return the duration of the first video stream in whole seconds, as a float.

    The fractional part is dropped on purpose; the calling cell reports the
    value as an integer number of seconds and clamps anything below 1.

    Raises:
      ValueError: if ffprobe reports no numeric duration (for example ``N/A``).
    """
    raw_duration = Detection._ffprobe_value(input, "v:0", "duration")
    try:
      seconds = float(raw_duration)
    except ValueError:
      raise ValueError(f"Could not parse duration {raw_duration!r} for {input!r}") from None
    return float(int(seconds))

  @staticmethod
  def exp_calc(measured_duration,target_length_RIFE, source_fps=None, target_fps=None): #needs portable
    """Pick the ``--exp`` value: log2 of (target frames / source frames), rounded.

    Args:
      measured_duration: duration of the input in seconds.
      target_length_RIFE: desired duration of the output in seconds.
      source_fps: frame rate of the input; defaults to the notebook global
        ``measured_fps``.
      target_fps: frame rate of the output; defaults to the notebook global
        ``fps_RIFE``.

    Returns:
      The rounded exponent as an int, never less than 1.

    Raises:
      ValueError: if either frame count is not positive.
    """
    import math
    if source_fps is None:
      source_fps = measured_fps
    if target_fps is None:
      target_fps = fps_RIFE
    source_frames = source_fps * measured_duration
    target_frames = target_fps * target_length_RIFE
    if source_frames <= 0 or target_frames <= 0:
      raise ValueError("Frame rates and durations must be positive to compute --exp")
    l = math.log2(target_frames / source_frames)
    print("Un-rounded --exp is",l)
    x = round(l)
    if x < 1:
      x = 1
    print("Rounding up to an --exp of ",x)
    return x

#----------------------------
_import_mp4_file = False

if _import_mp4_file:
  measured_fps = Detection.detect_fps(new_vid)
  print("\n NOTICE: Detected average FPS of ",new_vid," is ",measured_fps)
  measured_duration = Detection.detect_duration(new_vid)
else: 
  measured_fps = Detection.detect_fps(new_vid)
  print("\n NOTICE: Detected average FPS of ",new_vid," is ",measured_fps)
  measured_duration = Detection.detect_duration(new_vid)

print("\n NOTICE: Detected duration INTEGER (in seconds) is ",measured_duration)

if measured_duration < 1: #failsafe
  print("\n NOTICE: Your input appears to be less than one second... \n")
  measured_duration = 1

exp_value = Detection.exp_calc(measured_duration,target_length_RIFE)


print("\n NOTICE: Target duration currently rounds to the closest --exp RIFE can handle. \n")

if (exp_value < 0.5):
  _skip_RIFEsmoothing = False
  print("\n NOTICE: Your fps_RIFE doesn't necessitate RIFE motion smoothing. Skipping RIFE...\n")

if _skip_RIFEsmoothing:
  print("\n NOTICE: Skipping RIFE motion smoothing...\n")
else:
  #---RUN RIFE------------------------------------
  print("\n ------------------------------------\n")
  print("\n NOTICE: Running RIFE... \n")
  for rife_vid in os.listdir(rife_viddir):
    new_vid = os.path.join(rife_viddir, rife_vid)
    output_RIFE = f"{os.path.splitext(new_vid)[0]}_RIFE{os.path.splitext(new_vid)[1]}"
    print(f"Saving Video to: ", output_RIFE)
    %cd /content/Practical-RIFE/
    !python3 /content/Practical-RIFE/inference_video.py --fps=$fps_RIFE --exp=$exp_value  --video=$new_vid --scale=$target_scale_RIFE --output=$output_RIFE
  
  #---END--------------------------------------
#!ffmpeg -y -i $input $visual_effects -c:v hevc_nvenc -rc vbr -cq $_constant_quality -qmin $_constant_quality -qmax $_constant_quality -b:v 0 $output_fullpathname
#END OF MOTION SMOOTHING PHASE
print("\n \033[92mEnd of RIFE interpolation phase.\n")

In [None]:
#@title $\color{orange} {\large \textsf {Analyze Lowres Vid Dimensions and FPS}}$ 
# Reports the properties of `output_RIFE`, i.e. the file produced by whichever
# RIFE cell ran last; `cv2` comes from an earlier cell's import.
vcap = cv2.VideoCapture(output_RIFE)
try:
    if not vcap.isOpened():
        # Without this the cell would silently print zeros for a missing file.
        print(f"\033[91mCould not open {output_RIFE}; the values below are not meaningful.\033[0m")
    width = vcap.get(cv2.CAP_PROP_FRAME_WIDTH )
    height = vcap.get(cv2.CAP_PROP_FRAME_HEIGHT )
    fps =  vcap.get(cv2.CAP_PROP_FPS)
finally:
    # Release the capture handle instead of leaving it open in the kernel.
    vcap.release()
print(f"Width is: \033[92m{width}.")
print(f"Height is: \033[92m{height}.")
print(f"FPS is: \033[92m{fps}.")

# $\color{blue} {\large \textsf{REAL-ESRGAN Single Image or Frame Folder Upscaling!}}$


In [None]:
#@title $\color{blue} {\large \textsf { 🚀 RUN REAL-ESRGAN UPSCALE PHASE (from frames) 🚀}}$
#@markdown $\color{blue} {\large \textsf {Input_dir is either a single image or a folder of images}}$
model_name = 'realesr-general-x4v3' #@param ['realesr-general-x4v3', 'realesr-general-wdn-x4v3', 'RealESRGAN_x4plus','RealESRGAN_x4plus_anime_6B','RealESRGAN_x2plus'] {type:"string"}
input_dir = '/content/drive/MyDrive/AI/StableDiffusion/2023-02/upscale/20230205235356_00002_dreamlikeart_portrait_of_LuisapDarkagepinguinstyleDarkage_hulking_herculean_extra_terrestrial_aquatic_grey_alien_horns_muted_colors_dark_gore_creepy_luisap.png' #@param{type:'string'}
output_dir = '/content/drive/MyDrive/' #@param{type:'string'}
outscale = "12" #@param ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13']
denoising_strength = 1 #@param{type:'slider', step:0.01, min:0.01, max:1}
fp32 = 'fp32'
gpu_id = 0 #@param ['0', '1', '2', '3', '4']
%cd '/content/'
%cd '/content/Real-ESRGAN/'
!python3 /content/Real-ESRGAN/inference_realesrgan.py --model_name $model_name --outscale $outscale -dn $denoising_strength --fp32 --gpu-id $gpu_id --input $input_dir --output $output_dir
%cd '/content/'
print("\033[92m...Upscaling Phase Has Completed, You May Proceed...")

In [None]:
from PIL import Image
# When unchecked the image is not rendered; its dimensions are still printed
# by the analyze_image_dimensions call at the end of this cell.
display_the_image = False #@param{type:'boolean'}
display_image = "/content/drive/MyDrive/20230205235356_00002_dreamlikeart_portrait_of_LuisapDarkagepinguinstyleDarkage_hulking_herculean_extra_terrestrial_aquatic_grey_alien_horns_muted_colors_dark_gore_creepy_luisap_out.png" #@param{type:'string'}
#@markdown Path to the Image you want to Display
if display_the_image:
    im = Image.open(display_image)
    # `display` is the IPython module imported in the environment setup cell
    # (`from IPython import display`), so that cell must have been run.
    display.display(im)

def analyze_image_dimensions(file_path):
    """Print the pixel width and height of the image stored at ``file_path``."""
    with Image.open(file_path) as picture:
        dimensions = picture.size
    # The size tuple has been copied out, so the file can already be closed here.
    width, height = dimensions
    print(f"Image Width: \033[92m{width}, \033[0mImage Height: \033[92m{height}")

analyze_image_dimensions(display_image)

Image Width: [92m6144, [0mImage Height: [92m12288


# $\color{red} {\large \textsf{Disconnect Runtime}}$

In [None]:
#@title $\color{red} {\large \textsf{Disconnect Runtime}}$
from google.colab import runtime
# Unlike the guarded cell earlier in the notebook, this one is unconditional:
# running it (including via "Run all") calls runtime.unassign() immediately.
# Per the cell title this disconnects the Colab runtime.
runtime.unassign()