In [None]:
!pip install DALLE2-pytorch==1.1.0 einops einops-exts kornia ftfy vector-quantize-pytorch resize-right clip-anytorch rotary-embedding-torch coca-pytorch pytorch-warmup ema-pytorch x-clip accelerate gradio
!git clone https://github.com/LAION-AI/dalle2-laion.git
!mv 'dalle2-laion/dalle2_laion' dalle2_laion
#!git clone https://github.com/lucidrains/DALLE2-pytorch.git
#!mv 'DALLE2-pytorch/dalle2_pytorch' dalle2_pytorch


In [1]:
from dataclasses import dataclass
from typing import Any, Tuple, Optional, TypeVar, Generic, List,Iterator,  Dict
from dalle2_laion.config import DecoderLoadConfig, SingleDecoderLoadConfig, PriorLoadConfig, ModelLoadConfig
from dalle2_pytorch import __version__ as Dalle2Version, Decoder, DiffusionPrior, Unet
from dalle2_pytorch.train_configs import TrainDecoderConfig, TrainDiffusionPriorConfig, DecoderConfig, UnetConfig, DiffusionPriorConfig
import torch
from torch import LongTensor, FloatTensor, BoolTensor,nn
from packaging import version

from accelerate import init_empty_weights


def get_keys_to_submodule(model: nn.Module) -> Dict[str, nn.Module]:
    keys_to_submodule = {}
    # iterate all submodules
    for submodule_name, submodule in model.named_modules():
        # iterate all paramters in each submobule
        for param_name, param in submodule.named_parameters():
            # param_name is organized as <name>.<subname>.<subsubname> ...
            splitted_param_name = param_name.split('.')
            # we cannot go inside it anymore. This is the actual parameter
            is_leaf_param = len(splitted_param_name) == 1
            if is_leaf_param:
                # we recreate the correct key
                key = f"{submodule_name}.{param_name}"
                # we associate this key with this submodule
                keys_to_submodule[key] = submodule
                
    return keys_to_submodule

def load_state_dict_with_low_memory(model: nn.Module, state_dict):
    print('======hacky load======')
    # free up memory by placing the model in the `meta` device
    keys_to_submodule = get_keys_to_submodule(model)
    mste=model.state_dict()
    for key, submodule in keys_to_submodule.items():
        # get the valye from the state_dict
        if key in state_dict:
          val = state_dict[key]
        else:
          print(key)
          val = torch.ones(mste[key].shape, dtype= torch.float32)
        # we need to substitute the parameter inside submodule, 
        # remember key is composed of <name>.<subname>.<subsubname>
        # the actual submodule's parameter is stored inside the 
        # last subname. If key is `in_proj.weight`, the correct field if `weight`
        param_name = key.split('.')[-1]
        #param_dtype = getattr(submodule, param_name).dtype
        #val = val.to(param_dtype)
        # create a new parameter
        new_val = torch.nn.Parameter(val)
        setattr(submodule, param_name, new_val)



def exists(obj: Any) -> bool:
    return obj is not None

@dataclass
class DataRequirements:
    image_embedding: bool
    text_encoding: bool
    image: bool
    text: bool
    can_generate_embedding: bool
    image_size: int

    def has_clip(self):
        self.can_generate_embedding = True

    def is_valid(
        self,
        has_image_emb: bool = False, has_text_encoding: bool = False,
        has_image: bool = False, has_text: bool = False,
        image_size: Optional[int] = None
    ):
        # The image size must be equal to or greater than the required size
        # Verify that the text input is valid
        errors = []
        is_valid = True
        if self.text_encoding:
            # Then we need to some way to get the text encoding
            if not (has_text_encoding or (self.can_generate_embedding and has_text)):
                errors.append('Text encoding is required, but no text encoding or text was provided')
                is_valid = False
        if self.text:
            # Then this requires text be passed in explicitly
            if not has_text:
                errors.append('Text is required, but no text was provided')
                is_valid = False

        # Verify that the image input is valid
        image_size_greater = exists(image_size) and image_size >= self.image_size
        if self.image_embedding:
            # Then we need to some way to get the image embedding
            # In this case, we also need to make sure that the image size is big enough to generate the embedding
            if not (has_image_emb or (self.can_generate_embedding and has_image and image_size_greater)):
                errors.append('Image embedding is required, but no image embedding or image was provided or the image was too small')
                is_valid = False
        if self.image:
            # Then this requires an image be passed in explicitly
            # In this case we also need to make sure the image is big enough to be used
            if not (has_image and image_size_greater):
                errors.append('Image is required, but no image was provided or the image was too small')
                is_valid = False
        return is_valid, errors

    def __add__(self, other: 'DataRequirements') -> 'DataRequirements':
        return DataRequirements(
            image_embedding=self.image_embedding or other.image_embedding,  # If either needs an image embedding, the combination needs one
            text_embedding=self.text_embedding or other.text_embedding,  # If either needs a text embedding, the combination needs one
            image=self.image or other.image,  # If either needs an image, the combination needs it  
            text=self.text or other.text,  # If either needs a text, the combination needs it
            can_generate_embedding=self.can_generate_embedding and other.can_generate_embedding,  # If either cannot generate an embedding, we know that trying to replace an embedding with raw data will not work
            image_size=max(self.image_size, other.image_size)  # We can downsample without loss of information, so we use the larger image size
        )

ModelType = TypeVar('ModelType', Decoder, DiffusionPrior)

@dataclass
class ModelInfo(Generic[ModelType]):
    model: ModelType
    model_version: Optional[version.Version]
    requires_clip: bool
    data_requirements: DataRequirements

class DalleModelManager:
    """
    Used to load priors and decoders and to provide a simple interface to run general scripts against
    """
    def __init__(self, model_load_config: ModelLoadConfig, check_updates: bool = True):
        """
        Downloads the models and loads them into memory.
        If check_updates is True, then the models will be re-downloaded if checksums do not match.
        """
        self.check_updates = check_updates
        self.model_config = model_load_config
        self.current_version = version.parse(Dalle2Version)
        self.single_device = isinstance(model_load_config.devices, str)
        self.devices = [torch.device(model_load_config.devices)] if self.single_device else [torch.device(d) for d in model_load_config.devices]
        self.load_device = torch.device('cpu') if model_load_config.load_on_cpu else self.devices[0]
        self.strict_loading = model_load_config.strict_loading

        if model_load_config.decoder is not None:
            self.decoder_info = self.load_decoder(model_load_config.decoder)
        else:
            self.decoder_info = None

        if model_load_config.prior is not None:
            self.prior_info = self.load_prior(model_load_config.prior)
        else:
            self.prior_info = None

        if (exists(self.decoder_info) and self.decoder_info.requires_clip) or (exists(self.prior_info) and self.prior_info.requires_clip):
            assert model_load_config.clip is not None, 'Your model requires clip to be loaded. Please provide a clip config.'
            self.clip = model_load_config.clip.create()
            # Update the data requirements to include the clip model
            if exists(self.decoder_info):
                self.decoder_info.data_requirements.has_clip()
            if exists(self.prior_info):
                self.prior_info.data_requirements.has_clip()
        else:
            if model_load_config.clip is not None:
                print(f'WARNING: Your model does not require clip, but you provided a clip config. This will be ignored.')

    def _get_decoder_data_requirements(self, decoder_config: DecoderConfig, min_unet_number: int = 1) -> DataRequirements:
        """
        Returns the data requirements for a decoder
        """
        return DataRequirements(
            image_embedding=True,
            text_encoding=any(unet_config.cond_on_text_encodings for unet_config in decoder_config.unets[min_unet_number - 1:]),
            image=min_unet_number > 1,  # If this is an upsampler we need an image
            text=False,  # Text is never required for anything
            can_generate_embedding=False,  # This might be added later if clip is being used
            image_size=decoder_config.image_sizes[min_unet_number - 1]  # The input image size is the input to the first unet we are using
        )

    def _load_single_decoder(self, load_config: SingleDecoderLoadConfig) -> Tuple[Decoder, DecoderConfig, Optional[version.Version], bool]:
        """
        Loads a single decoder from a model and a config file
        """
        unet_sample_timesteps = load_config.default_sample_timesteps
        def apply_default_config(config: DecoderConfig):
            if unet_sample_timesteps is not None:
                base_sample_timesteps = [None] * len(config.unets)
                for unet_number, timesteps in zip(load_config.unet_numbers, unet_sample_timesteps):
                    base_sample_timesteps[unet_number - 1] = timesteps
                config.sample_timesteps = base_sample_timesteps
        
        with load_config.load_model_from.as_local_file(check_update=self.check_updates) as model_file:
            model_state_dict = torch.load(model_file, map_location=self.load_device)
            if 'version' in model_state_dict:
                model_version = model_state_dict['version']
                if model_version != self.current_version:
                    print(f'WARNING: This decoder was trained on version {model_version} but the current version is {self.current_version}. This may result in the model failing to load.')
                    print(f'FIX: Switch to this version with `pip install DALLE2-pytorch=={model_version}`. If different models suggest different versions, you may just need to choose one.')
            else:
                print(f'WARNING: This decoder was trained on an old version of Dalle2. This may result in the model failing to load or it may lead to producing garbage results.')
                model_version = None  # No version info in the model
            
            requires_clip = False
            if 'config' in model_state_dict:
                # Then we define the decoder config from this object
                decoder_config = TrainDecoderConfig(**model_state_dict['config']).decoder
                apply_default_config(decoder_config)
                if decoder_config.clip is not None:
                    # We don't want to load clip with the model
                    requires_clip = True
                    decoder_config.clip = None
                with init_empty_weights():
                  decoder = decoder_config.create()
                load_state_dict_with_low_memory(decoder, model_state_dict['model'])
                #decoder.load_state_dict(model_state_dict['model'], strict=self.strict_loading)  # If the model has a config included, then we know the model_state_dict['model'] is the actual model
            else:
                # In this case, the state_dict is the model itself. This means we also must load the config from an external file
                assert load_config.load_config_from is not None
                with load_config.load_config_from.as_local_file(check_update=self.check_updates) as config_file:
                    decoder_config = TrainDecoderConfig.from_json_path(config_file).decoder
                    apply_default_config(decoder_config)
                    if decoder_config.clip is not None:
                        # We don't want to load clip with the model
                        requires_clip = True
                        decoder_config.clip = None
                with init_empty_weights():
                  decoder = decoder_config.create()
                load_state_dict_with_low_memory(decoder, model_state_dict)
                #decoder.load_state_dict(model_state_dict, strict=self.strict_loading)
            del model_state_dict
            return decoder.requires_grad_(False).eval(), decoder_config, model_version, requires_clip

    def load_decoder(self, load_config: DecoderLoadConfig) -> 'ModelInfo[Decoder]':
        """
        Loads a decoder from a model and a config file
        """
        if len(load_config.unet_sources) == 1:
            # Then we are loading only one model
            decoder, decoder_config, decoder_version, requires_clip = self._load_single_decoder(load_config.unet_sources[0])
            decoder_data_requirements = self._get_decoder_data_requirements(decoder_config)
            return ModelInfo(decoder.requires_grad_(False).to(torch.float32), decoder_version, requires_clip, decoder_data_requirements)
        else:
            true_unets: List[Unet] = [None] * load_config.final_unet_number  # Stores the unets that will replace the ones in the true decoder
            true_unet_configs: List[UnetConfig] = [None] * load_config.final_unet_number  # Stores the unet configs that will replace the ones in the true decoder config
            true_upsampling_sizes: List[Tuple[int, int]] = [None] * load_config.final_unet_number  # Stores the progression of upsampling sizes for each unet so that we can validate these unets actually work together
            true_train_timesteps: List[int] = [None] * load_config.final_unet_number  # Stores the number of timesteps that each unet trained with
            true_beta_schedules: List[str] = [None] * load_config.final_unet_number  # Stores the beta scheduler that each unet used
            true_uses_learned_variance: List[bool] = [None] * load_config.final_unet_number  # Stores whether each unet uses learned variance
            true_sample_timesteps: List[int] = [None] * load_config.final_unet_number  # Stores the number of timesteps that each unet used to sample

            requires_clip = False
            for source in load_config.unet_sources:
                decoder, decoder_config, decoder_version, unets_requires_clip = self._load_single_decoder(source)
                if unets_requires_clip:
                    requires_clip = True
                if source.default_sample_timesteps is not None:
                    assert len(source.default_sample_timesteps) == len(source.unet_numbers)
                for i, unet_number in enumerate(source.unet_numbers):
                    unet_index = unet_number - 1
                    # Now we need to insert the unet into the true unets and the unet config into the true config
                    if source.default_sample_timesteps is not None:
                        true_sample_timesteps[unet_index] = source.default_sample_timesteps[i]
                    true_unets[unet_index] = decoder.unets[unet_index]
                    true_unet_configs[unet_index] = decoder_config.unets[unet_index]
                    true_upsampling_sizes[unet_index] = None if unet_index == 0 else decoder_config.image_sizes[unet_index - 1], decoder_config.image_sizes[unet_index]
                    true_train_timesteps[unet_index] = decoder_config.timesteps
                    true_beta_schedules[unet_index] = decoder_config.beta_schedule[unet_index]
                    true_uses_learned_variance[unet_index] = decoder_config.learned_variance if isinstance(decoder_config.learned_variance, bool) else decoder_config.learned_variance[unet_index]

            true_decoder_config_obj = {}
            # Insert the true configs into the true decoder config
            true_decoder_config_obj['unets'] = true_unet_configs
            true_image_sizes = []
            for i in range(load_config.final_unet_number):
                if i == 0:
                    true_image_sizes.append(true_upsampling_sizes[i][1])
                else:
                    assert true_upsampling_sizes[i - 1][1] == true_upsampling_sizes[i][0], f"The upsampling sizes for unet {i} are not compatible with unet {i - 1}."
                    true_image_sizes.append(true_upsampling_sizes[i][1])
            true_decoder_config_obj['image_sizes'] = true_image_sizes
            # All unets must have been trained with the same number of sampling timesteps in order to be compatible
            assert all(true_train_timesteps[0] == t for t in true_train_timesteps), f"All unets must have been trained with the same number of sampling timesteps in order to be compatible."
            true_decoder_config_obj['timesteps'] = true_train_timesteps[0]
            true_decoder_config_obj['beta_schedule'] = true_beta_schedules
            true_decoder_config_obj['learned_variance'] = true_uses_learned_variance
            # If any of the sample_timesteps are not None, then we need to insert them into the true decoder config
            if any(true_sample_timesteps):
                true_decoder_config_obj['sample_timesteps'] = true_sample_timesteps

            # Now we can create the decoder and substitute the unets
            true_decoder_config = DecoderConfig(**true_decoder_config_obj)
            decoder_data_requirements = self._get_decoder_data_requirements(true_decoder_config)
            print('trudec shit???')
            with init_empty_weights():
              decoder = true_decoder_config.create()
            decoder.unets = nn.ModuleList(true_unets)
            return ModelInfo(decoder.requires_grad_(False).to(torch.float32).to(self.devices[0]).eval(), decoder_version, requires_clip, decoder_data_requirements)
            
    def _get_prior_data_requirements(self, config: DiffusionPriorConfig) -> DataRequirements:
        """
        Returns the data requirements for a diffusion prior
        """
        return DataRequirements(
            image_embedding=False,  # This is kinda the whole point
            text_encoding=True,  # This is also kinda the whole point
            image=False,  # The prior is never conditioned on the image
            text=False,  # Text is never required for anything
            can_generate_embedding=False,  # This might be added later if clip is being used
            image_size=[-1, -1]  # This is not used
        )

    def load_prior(self, load_config: PriorLoadConfig) -> 'ModelInfo[DiffusionPrior]':
        """
        Loads a prior from a model and a config file
        """
        sample_timesteps = load_config.default_sample_timesteps
        def apply_default_config(config: DiffusionPriorConfig) -> DiffusionPriorConfig:
            """
            Applies the default config to the given config
            """
            if sample_timesteps is not None:
                config.sample_timesteps = sample_timesteps

        with load_config.load_model_from.as_local_file(check_update=self.check_updates) as model_file:
            model_state_dict = torch.load(model_file, map_location=self.load_device)
            if 'version' in model_state_dict:
                model_version = model_state_dict['version']
                if model_version != self.current_version:
                    print(f'WARNING: This prior was trained on version {model_version} but the current version is {self.current_version}. This may result in the model failing to load.')
                    print(f'FIX: Switch to this version with `pip install DALLE2-pytorch=={model_version}`. If different models suggest different versions, you may just need to choose one.')
            else:
                print('WARNING: This prior was trained on an old version of Dalle2. This may result in the model failing to load or it may produce garbage results.')
                model_version = None

            requires_clip = False
            if 'config' in model_state_dict:
                # Then we define the prior config from this object
                prior_config = TrainDiffusionPriorConfig(**model_state_dict['config']).prior
                apply_default_config(prior_config)
                if prior_config.clip is not None:
                    # We don't want to load clip with the model
                    prior_config.clip = None
                    requires_clip = True
                with init_empty_weights():
                  prior = prior_config.create()
                load_state_dict_with_low_memory(prior, model_state_dict['model'])
                #prior.load_state_dict(model_state_dict['model'], strict=self.strict_loading)
            else:
                # In this case, the state_dict is the model itself. This means we also must load the config from an external file
                assert load_config.load_config_from is not None
                with load_config.load_config_from.as_local_file(check_update=self.check_updates) as config_file:
                    prior_config = TrainDiffusionPriorConfig.from_json_path(config_file).prior
                    apply_default_config(prior_config)
                    if prior_config.clip is not None:
                        # We don't want to load clip with the model
                        prior_config.clip = None
                        requires_clip = True
                with init_empty_weights():
                  prior = prior_config.create()
                load_state_dict_with_low_memory(prior, model_state_dict)
                #prior.load_state_dict(model_state_dict, strict=self.strict_loading)
            del model_state_dict
            data_requirements = self._get_prior_data_requirements(prior_config)
            return ModelInfo(prior.requires_grad_(False).to(torch.float32).to(self.devices[0]).eval(), model_version, requires_clip, data_requirements)


In [None]:
import gradio as gr
from pathlib import Path
from typing import Dict, List
from PIL import Image as PILImage
from dalle2_laion import ModelLoadConfig, utils
from dalle2_laion.scripts import BasicInference, ImageVariation, BasicInpainting

config_path = '/content/dalle2-laion/configs/gradio.example.json'
model_config = ModelLoadConfig.from_json_path(config_path)
model_manager = DalleModelManager(model_config)

output_path = Path('/content/dalle2-laion/output/gradio')
output_path.mkdir(parents=True, exist_ok=True)

cond_scale_sliders = [gr.Slider(minimum=0.5, maximum=5, step=0.05, label="Prior Cond Scale", value=1),]
for i in range(model_manager.model_config.decoder.final_unet_number):
    cond_scale_sliders.append(gr.Slider(minimum=0.5, maximum=5, step=0.05, label=f"Decoder {i+1} Cond Scale", value=1))

def dream(text: str, samples_per_prompt: int, prior_cond_scale: float, *decoder_cond_scales: List[float]):
    prompts = text.split('\n')[:8]

    script = BasicInference(model_manager, verbose=True)
    output = script.run(prompts, prior_sample_count=samples_per_prompt, decoder_batch_size=40, prior_cond_scale=prior_cond_scale, decoder_cond_scale=decoder_cond_scales)
    all_outputs = []
    for text, embedding_outputs in output.items():
        for index, embedding_output in embedding_outputs.items():
            all_outputs.extend(embedding_output)
    return all_outputs
dream_interface = gr.Interface(
    dream,
    inputs=[
        gr.Textbox(placeholder="A corgi wearing a top hat...", lines=8),
        gr.Slider(minimum=1, maximum=4, step=1, label="Samples per prompt", value=1),
        *cond_scale_sliders
    ],
    outputs=[
        gr.Gallery()
    ],
    title="Dalle2 Dream",
    description="Generate images from text. You can give a maximum of 8 prompts at a time. Any more will be ignored. Generation takes around 5 minutes so be patient.",
)

def variation(image: PILImage.Image, text: str, num_generations: int, *decoder_cond_scales: List[float]):
    print("Variation using text:", text)
    img = utils.center_crop_to_square(image)

    script = ImageVariation(model_manager, verbose=True)
    output = script.run([img], [text], sample_count=num_generations, cond_scale=decoder_cond_scales)
    all_outputs = []
    for index, embedding_output in output.items():
        all_outputs.extend(embedding_output)
    return all_outputs
variation_interface = gr.Interface(
    variation,
    inputs=[
        gr.Image(value="https://www.thefarmersdog.com/digest/wp-content/uploads/2021/12/corgi-top-1400x871.jpg", source="upload", interactive=True, type="pil"),
        gr.Text(),
        gr.Slider(minimum=1, maximum=6, label="Number to generate", value=2, step=1),
        *cond_scale_sliders[1:]
    ],
    outputs=[
        gr.Gallery()
    ],
    title="Dalle2 Variation",
    description="Generates images similar to the input image.\nGeneration takes around 5 minutes so be patient.",
)

def inpaint(image: Dict[str, PILImage.Image], text: str, num_generations: int, prior_cond_scale: float, *decoder_cond_scales: List[float]):
    print("Inpainting using text:", text)
    img, mask = image['image'], image['mask']
    # Remove alpha from img
    img = img.convert('RGB')
    img = utils.center_crop_to_square(img)
    mask = utils.center_crop_to_square(mask)

    script = BasicInpainting(model_manager, verbose=True)
    mask = ~utils.get_mask_from_image(mask)
    output = script.run(images=[img], masks=[mask], text=[text], sample_count=num_generations, prior_cond_scale=prior_cond_scale, decoder_cond_scale=decoder_cond_scales)
    all_outputs = []
    for index, embedding_output in output.items():
        all_outputs.extend(embedding_output)
    return all_outputs
inpaint_interface = gr.Interface(
    inpaint,
    inputs=[
        gr.Image(value="https://www.thefarmersdog.com/digest/wp-content/uploads/2021/12/corgi-top-1400x871.jpg", source="upload", tool="sketch", interactive=True, type="pil"),
        gr.Text(),
        gr.Slider(minimum=1, maximum=6, label="Number to generate", value=2, step=1),
        *cond_scale_sliders
    ],
    outputs=[
        gr.Gallery()
    ],
    title="Dalle2 Inpainting",
    description="Fills in the details of areas you mask out.\nGeneration takes around 5 minutes so be patient.",
)

demo = gr.TabbedInterface(interface_list=[dream_interface, variation_interface, inpaint_interface], tab_names=["Dream", "Variation", "Inpaint"])



In [None]:
demo.launch(share=True, enable_queue=True)

In [12]:
dream(text='a blue cat on fire', samples_per_prompt=1, prior_cond_scale=1)

Generating images for texts: ['a blue cat on fire']
Generating prior embeddings...
Sampling prior with cond_scale: 1
Prior batched inputs into 1 batches. Total number of samples: 1.


sampling loop time step:   0%|          | 0/64 [00:00<?, ?it/s]

Finished generating prior embeddings.
Grouped 1 texts into 1 embeddings.
Sampling from decoder...
Sampling decoder with cond_scale: ()
Decoder batched inputs into 1 batches. Total number of samples: 1.


AssertionError: ignored

In [11]:
!nvidia-smi -a



Timestamp                                 : Sat Aug  6 05:07:41 2022
Driver Version                            : 460.32.03
CUDA Version                              : 11.2

Attached GPUs                             : 1
GPU 00000000:00:04.0
    Product Name                          : Tesla T4
    Product Brand                         : Tesla
    Display Mode                          : Enabled
    Display Active                        : Disabled
    Persistence Mode                      : Disabled
    MIG Mode
        Current                           : N/A
        Pending                           : N/A
    Accounting Mode                       : Disabled
    Accounting Mode Buffer Size           : 4000
    Driver Model
        Current                           : N/A
        Pending                           : N/A
    Serial Number                         : 1561920024640
    GPU UUID                              : GPU-e160ebdd-6388-011c-1052-c9d0ce8bf0c8
    Minor Number              

In [4]:
#force prior to gpu?
model_manager.prior_info.model=model_manager.prior_info.model.to(model_manager.devices[0])

In [None]:
zk=model_manager.decoder_info.model.state_dict()
for ky in zk:
  print(zk[ky].device)


In [None]:
zk=model_manager.prior_info.model.state_dict()
for ky in zk:
  print(zk[ky].device)

In [7]:
model_manager.devices[0]

device(type='cuda', index=0)