diff --git a/.gitignore b/.gitignore
index 96cffea..0baa399 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,20 @@
 # Notebooks can get filled with random files so we specify the exact ones we want to keep
 notebooks/
-!notebooks/dalle2_laion_alpha.ipynb
\ No newline at end of file
+!notebooks/dalle2_laion_alpha.ipynb
+
+# Configuration Files
+configs/*
+!configs/*.example.json
+!configs/README.md
+
+# Build Files
+dist/
+*.egg-info/
+
+# Model testing
+models*/
+output*/
+
+# Environment Files
+.env*/
+__pycache__/
\ No newline at end of file
diff --git a/configs/README.md b/configs/README.md
new file mode 100644
index 0000000..4127b04
--- /dev/null
+++ b/configs/README.md
@@ -0,0 +1,11 @@
+# Configuration
+The root configuration defines the global properties of how models will be loaded.
+
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `decoder` | No | `null` | Sources for the decoder unets (see the example configs). |
+| `prior` | No | `null` | Source for the diffusion prior. |
+| `clip` | No | `null` | Clip adapter config. Required if a loaded model needs clip. |
+| `devices` | No | `cuda:0` | The device(s) to use for model inference. If a list, the first device is used for loading. |
+| `load_on_cpu` | No | `true` | Whether to load state dicts on the cpu rather than the first device. |
+| `strict_loading` | No | `true` | Whether to error if the model is not compatible with the current version of the code. |
\ No newline at end of file
diff --git a/configs/load.example.json b/configs/load.example.json
new file mode 100644
index 0000000..8505fdb
--- /dev/null
+++ b/configs/load.example.json
@@ -0,0 +1,54 @@
+{
+    "decoder": {
+        "unet_sources": [
+            {
+                "unet_numbers": [1],
+                "load_model_from": {
+                    "load_type": "url",
+                    "path": "",
+                    "cache_dir": "./models",
+                    "filename_override": "first_decoder.pth"
+                },
+                "load_config_from": {
+                    "load_type": "url",
+                    "path": "",
+                    "cache_dir": "./models",
+                    "filename_override": "first_decoder_config.json"
+                }
+            },
+            {
+                "unet_numbers": [2],
+                "load_model_from": {
+                    "load_type": "url",
+                    "path": "",
+                    "cache_dir": "./models",
+                    "filename_override": "second_decoder.pth"
+                },
+                "load_config_from": {
+                    "load_type": "url",
+                    "path": "",
+                    "cache_dir": "./models",
+                    "filename_override": "second_decoder_config.json"
+                }
+            }
+        ]
+    },
+    "prior": {
+        "load_model_from": {
+            "load_type": "url",
+            "path": "",
+            "cache_dir": "./models"
+        },
+        "load_config_from": {
+            "load_type": "url",
+            "path": "",
+            "cache_dir": "./models"
+        }
+    },
+    "clip": {
+        "make": "openai",
+        "model": "ViT-L/14"
+    },
+
+    "devices": "cuda:0"
+}
\ No newline at end of file
diff --git a/configs/upsampler.example.json b/configs/upsampler.example.json
new file mode 100644
index 0000000..e1bdc45
--- /dev/null
+++ b/configs/upsampler.example.json
@@ -0,0 +1,56 @@
+{
+    "decoder": {
+        "unet_sources": [
+            {
+                "unet_numbers": [1],
+                "load_model_from": {
+                    "load_type": "url",
+                    "path": "https://huggingface.co/laion/DALLE2-PyTorch/resolve/main/decoder/1.5B_laion2B/latest.pth",
+                    "cache_dir": "./models",
+                    "filename_override": "first_decoder.pth"
+                },
+                "load_config_from": {
+                    "load_type": "url",
+                    "path": "https://huggingface.co/laion/DALLE2-PyTorch/raw/main/decoder/1.5B_laion2B/decoder_config.json",
+                    "cache_dir": "./models",
+                    "filename_override": "first_decoder_config.json"
+                }
+            },
+            {
+                "unet_numbers": [2],
+                "load_model_from": {
+                    "load_type": "url",
+                    "path": "https://huggingface.co/Veldrovive/upsamplers/resolve/main/working/latest.pth",
+                    "cache_dir": "./models",
+                    "filename_override": "second_decoder.pth"
+                },
+                "load_config_from": {
+                    "load_type": "url",
+                    "path": "https://huggingface.co/Veldrovive/upsamplers/raw/main/working/decoder_config.json",
+                    "cache_dir": "./models",
+                    "filename_override": "second_decoder_config.json"
+                }
+            }
+        ]
+    },
+    "prior": {
+        "load_model_from": {
+            "load_type": "url",
+            "path": "https://huggingface.co/laion/DALLE2-PyTorch/resolve/main/prior/latest.pth",
+            "cache_dir": "./models",
+            "filename_override": "prior.pth"
+        },
+        "load_config_from": {
+            "load_type": "url",
+            "path": 
"https://huggingface.co/laion/DALLE2-PyTorch/raw/main/prior/prior_config.json", + "cache_dir": "./models" + } + }, + "clip": { + "make": "openai", + "model": "ViT-L/14" + }, + + "devices": "cuda:0", + "strict_loading": false +} \ No newline at end of file diff --git a/dalle2_laion/__init__.py b/dalle2_laion/__init__.py new file mode 100644 index 0000000..b30bc69 --- /dev/null +++ b/dalle2_laion/__init__.py @@ -0,0 +1,3 @@ +from dalle2_laion.dalle2_laion import DalleModelManager +from dalle2_laion.config import ModelLoadConfig +import dalle2_laion.scripts \ No newline at end of file diff --git a/dalle2_laion/config.py b/dalle2_laion/config.py new file mode 100644 index 0000000..dc08b2c --- /dev/null +++ b/dalle2_laion/config.py @@ -0,0 +1,133 @@ +from json import decoder +from pathlib import Path +from dalle2_pytorch.train_configs import AdapterConfig as ClipConfig +from typing import List, Optional, Union +from enum import Enum +from pydantic import BaseModel, root_validator, ValidationError +from contextlib import contextmanager +import tempfile +import urllib.request +import json + +class LoadLocation(str, Enum): + """ + Enum for the possible locations of the data. + """ + local = "local" + url = "url" + +class File(BaseModel): + load_type: LoadLocation + path: str + cache_dir: Optional[Path] = None + filename_override: Optional[str] = None + + def download_to(self, path: str): + """ + Downloads the file to the given path + """ + assert self.load_type == LoadLocation.url + urllib.request.urlretrieve(self.path, path) + + @property + def filename(self): + if self.filename_override is not None: + return self.filename_override + # The filename is everything after the last '/' but before the '?' if it exists + filename = self.path.split('/')[-1] + if '?' in filename: + filename = filename.split('?')[0] + return filename + + @contextmanager + def as_local_file(self): + if self.load_type == LoadLocation.local: + yield self.path + elif self.cache_dir is not None: + # Then we are caching the data in a local directory + self.cache_dir.mkdir(parents=True, exist_ok=True) + file_path = self.cache_dir / self.filename + if not file_path.exists(): + print(f"Downloading {self.path} to {file_path}") + self.download_to(file_path) + else: + print(f'{file_path} already exists. Skipping download. If you think this file should be re-downloaded, delete it and try again.') + yield file_path + else: + # Then we are not caching and the file should be stored in a temporary directory + with tempfile.TemporaryDirectory() as tmpdir: + tmpfile = tmpdir + "/" + self.filename + self.download_to(tmpfile) + yield tmpfile + +class SingleDecoderLoadConfig(BaseModel): + """ + Configuration for the single decoder load. + """ + unet_numbers: List[int] + load_model_from: File + load_config_from: Optional[File] # The config may be defined within the model file if the version is high enough + +class DecoderLoadConfig(BaseModel): + """ + Configuration for the decoder load. + """ + unet_sources: List[SingleDecoderLoadConfig] + + final_unet_number: int + + @root_validator(pre=True) + def compute_num_unets(cls, values): + """ + Gets the final unet number + """ + unet_numbers = [] + assert "unet_sources" in values, "No unet sources defined. Make sure `unet_sources` is defined in the decoder config." 
+        for value in values["unet_sources"]:
+            unet_numbers.extend(value["unet_numbers"])
+        final_unet_number = max(unet_numbers)
+        values["final_unet_number"] = final_unet_number
+        return values
+
+    @root_validator
+    def verify_unet_numbers_valid(cls, values):
+        """
+        The unet numbers must start at 1 and increase consecutively, with no gaps and no repeats.
+        """
+        unet_numbers = []
+        for value in values["unet_sources"]:
+            unet_numbers.extend(value.unet_numbers)
+        unet_numbers.sort()
+        if len(unet_numbers) != len(set(unet_numbers)):
+            raise ValueError("The decoder unet numbers must not repeat.")
+        if unet_numbers[0] != 1:
+            raise ValueError("The decoder unet numbers must start from 1.")
+        differences = [unet_numbers[i] - unet_numbers[i - 1] for i in range(1, len(unet_numbers))]
+        if any(diff != 1 for diff in differences):
+            raise ValueError("The decoder unet numbers must not skip any.")
+        return values
+
+class PriorLoadConfig(BaseModel):
+    """
+    Configuration for the prior load.
+    """
+    load_model_from: File
+    load_config_from: Optional[File]  # The config may be defined within the model file if the version is high enough
+
+class ModelLoadConfig(BaseModel):
+    """
+    Configuration for the model load.
+    """
+    decoder: Optional[DecoderLoadConfig] = None
+    prior: Optional[PriorLoadConfig] = None
+    clip: Optional[ClipConfig] = None
+
+    devices: Union[List[str], str] = 'cuda:0'  # The device(s) to use for model inference. If a list, the first device is used for loading.
+    load_on_cpu: bool = True  # Whether to load the state_dict on the first device or on the cpu
+    strict_loading: bool = True  # Whether to error on loading if the model is not compatible with the current version of the code
+
+    @classmethod
+    def from_json_path(cls, json_path):
+        with open(json_path) as f:
+            config = json.load(f)
+        return cls(**config)
\ No newline at end of file
diff --git a/dalle2_laion/dalle2_laion.py b/dalle2_laion/dalle2_laion.py
new file mode 100644
index 0000000..e4177f8
--- /dev/null
+++ b/dalle2_laion/dalle2_laion.py
@@ -0,0 +1,268 @@
+from dataclasses import dataclass
+from typing import Any, Tuple, Optional, NamedTuple, TypeVar, Generic, List
+from dalle2_laion.config import DecoderLoadConfig, SingleDecoderLoadConfig, PriorLoadConfig, ModelLoadConfig
+from dalle2_pytorch import __version__ as Dalle2Version, Decoder, DiffusionPrior, Unet
+from dalle2_pytorch.train_configs import TrainDecoderConfig, TrainDiffusionPriorConfig, DecoderConfig, UnetConfig, DiffusionPriorConfig
+import torch
+import torch.nn as nn
+from packaging import version
+
+def exists(obj: Any) -> bool:
+    return obj is not None
+
+@dataclass
+class DataRequirements:
+    image_embedding: bool
+    text_encoding: bool
+    image: bool
+    text: bool
+    can_generate_embedding: bool
+    image_size: Tuple[int, int]
+
+    def has_clip(self):
+        self.can_generate_embedding = True
+
+    def is_valid(
+        self,
+        has_image_emb: bool = False, has_text_encoding: bool = False,
+        has_image: bool = False, has_text: bool = False,
+        image_size: Optional[Tuple[int, int]] = None
+    ):
+        # The image size must be equal to or greater than the required size
+        # Verify that the text input is valid
+        is_valid = True
+        if self.text_encoding:
+            # Then we need some way to get the text encoding
+            if not (has_text_encoding or (self.can_generate_embedding and has_text)):
+                is_valid = False
+        if self.text:
+            # Then this requires text be passed in explicitly
+            if not has_text:
+                is_valid = False
+
+        # Verify that the image input is valid
+        image_size_greater = exists(image_size) and image_size[0] >= self.image_size[0] and image_size[1] >= self.image_size[1]
+        if self.image_embedding:
+            # Then we need some way to get the image embedding
+            # In this case, we also need to make sure that the image size is big enough to generate the embedding
+            if not (has_image_emb or (self.can_generate_embedding and has_image and image_size_greater)):
+                is_valid = False
+        if self.image:
+            # Then this requires an image be passed in explicitly
+            # In this case we also need to make sure the image is big enough to be used
+            if not (has_image and image_size_greater):
+                is_valid = False
+        return is_valid
+
+    def __add__(self, other: 'DataRequirements') -> 'DataRequirements':
+        return DataRequirements(
+            image_embedding=self.image_embedding or other.image_embedding,  # If either needs an image embedding, the combination needs one
+            text_encoding=self.text_encoding or other.text_encoding,  # If either needs a text encoding, the combination needs one
+            image=self.image or other.image,  # If either needs an image, the combination needs it
+            text=self.text or other.text,  # If either needs a text, the combination needs it
+            can_generate_embedding=self.can_generate_embedding and other.can_generate_embedding,  # If either cannot generate an embedding, we know that trying to replace an embedding with raw data will not work
+            image_size=max(self.image_size, other.image_size)  # We can downsample without loss of information, so we use the larger image size
+        )
+
+ModelType = TypeVar('ModelType', Decoder, DiffusionPrior)
+class ModelInfo(NamedTuple, Generic[ModelType]):
+    model: ModelType
+    model_version: Optional[version.Version]
+    requires_clip: bool
+    data_requirements: DataRequirements
+
+class DalleModelManager:
+    """
+    Used to load priors and decoders and to provide a simple interface to run general scripts against
+    """
+    def __init__(self, model_load_config: ModelLoadConfig):
+        self.current_version = version.parse(Dalle2Version)
+        self.single_device = isinstance(model_load_config.devices, str)
+        self.devices = [torch.device(model_load_config.devices)] if self.single_device else [torch.device(d) for d in model_load_config.devices]
+        self.load_device = torch.device('cpu') if model_load_config.load_on_cpu else self.devices[0]
+        self.strict_loading = model_load_config.strict_loading
+
+        if model_load_config.decoder is not None:
+            self.decoder_info = self.load_decoder(model_load_config.decoder)
+        else:
+            self.decoder_info = None
+
+        if model_load_config.prior is not None:
+            self.prior_info = self.load_prior(model_load_config.prior)
+        else:
+            self.prior_info = None
+
+        if (exists(self.decoder_info) and self.decoder_info.requires_clip) or (exists(self.prior_info) and self.prior_info.requires_clip):
+            assert model_load_config.clip is not None, 'Your model requires clip to be loaded. Please provide a clip config.'
+            self.clip = model_load_config.clip.create().to(self.devices[0])
+            # Update the data requirements to include the clip model
+            if exists(self.decoder_info):
+                self.decoder_info.data_requirements.has_clip()
+            if exists(self.prior_info):
+                self.prior_info.data_requirements.has_clip()
+        else:
+            if model_load_config.clip is not None:
+                print(f'WARNING: Your model does not require clip, but you provided a clip config.
This will be ignored.') + + def _get_decoder_data_requirements(self, decoder_config: DecoderConfig, min_unet_number: int = 1) -> DataRequirements: + """ + Returns the data requirements for a decoder + """ + return DataRequirements( + image_embedding=True, + text_encoding=any(unet_config.cond_on_text_encodings for unet_config in decoder_config.unets[min_unet_number - 1:]), + image=min_unet_number > 1, # If this is an upsampler we need an image + text=False, # Text is never required for anything + can_generate_embedding=False, # This might be added later if clip is being used + image_size=decoder_config.image_sizes[min_unet_number - 1] # The input image size is the input to the first unet we are using + ) + + def _load_single_decoder(self, load_config: SingleDecoderLoadConfig) -> Tuple[Decoder, DecoderConfig, Optional[version.Version], bool]: + """ + Loads a single decoder from a model and a config file + """ + with load_config.load_model_from.as_local_file() as model_file: + model_state_dict = torch.load(model_file, map_location=self.load_device) + if 'version' in model_state_dict: + model_version = version.parse(model_state_dict['version']) + if model_version != self.current_version: + print(f'WARNING: This decoder was trained on version {model_version} but the current version is {self.current_version}. This may result in the model failing to load.') + else: + print(f'WARNING: This decoder was trained on an old version of Dalle2. This may result in the model failing to load.') + model_version = None # No version info in the model + + requires_clip = False + if 'config' in model_state_dict: + # Then we define the decoder config from this object + decoder_config = TrainDecoderConfig(**model_state_dict['config']).decoder + if decoder_config.clip is not None: + # We don't want to load clip with the model + requires_clip = True + decoder_config.clip = None + decoder = decoder_config.create().to(self.devices[0]).eval() + decoder.load_state_dict(model_state_dict['model'], strict=self.strict_loading) # If the model has a config included, then we know the model_state_dict['model'] is the actual model + else: + # In this case, the state_dict is the model itself. 
This means we also must load the config from an external file + assert load_config.load_config_from is not None + with load_config.load_config_from.as_local_file() as config_file: + decoder_config = TrainDecoderConfig.from_json_path(config_file).decoder + if decoder_config.clip is not None: + # We don't want to load clip with the model + requires_clip = True + decoder_config.clip = None + decoder = decoder_config.create().to(self.devices[0]).eval() + decoder.load_state_dict(model_state_dict, strict=self.strict_loading) + + return decoder, decoder_config, model_version, requires_clip + + def load_decoder(self, load_config: DecoderLoadConfig) -> 'ModelInfo[Decoder]': + """ + Loads a decoder from a model and a config file + """ + if len(load_config.unet_sources) == 1: + # Then we are loading only one model + decoder, decoder_config, decoder_version, requires_clip = self._load_single_decoder(load_config.unet_sources[0]) + decoder_data_requirements = self._get_decoder_data_requirements(decoder_config) + decoder.to(torch.float32) + return ModelInfo(decoder, decoder_version, requires_clip, decoder_data_requirements) + else: + true_unets: List[Unet] = [None] * load_config.final_unet_number # Stores the unets that will replace the ones in the true decoder + true_unet_configs: List[UnetConfig] = [None] * load_config.final_unet_number # Stores the unet configs that will replace the ones in the true decoder config + true_upsampling_sizes: List[Tuple[int, int]] = [None] * load_config.final_unet_number # Stores the progression of upsampling sizes for each unet so that we can validate these unets actually work together + true_train_timesteps: List[int] = [None] * load_config.final_unet_number # Stores the number of timesteps that each unet trained with + true_beta_schedules: List[str] = [None] * load_config.final_unet_number # Stores the beta scheduler that each unet used + true_uses_learned_variance: List[bool] = [None] * load_config.final_unet_number # Stores whether each unet uses learned variance + + requires_clip = False + for source in load_config.unet_sources: + decoder, decoder_config, decoder_version, unets_requires_clip = self._load_single_decoder(source) + if unets_requires_clip: + requires_clip = True + for unet_number in source.unet_numbers: + print(f"Loading unet {unet_number}") + unet_index = unet_number - 1 + # Now we need to insert the unet into the true unets and the unet config into the true config + true_unets[unet_index] = decoder.unets[unet_index] + true_unet_configs[unet_index] = decoder_config.unets[unet_index] + true_upsampling_sizes[unet_index] = None if unet_index == 0 else decoder_config.image_sizes[unet_index - 1], decoder_config.image_sizes[unet_index] + true_train_timesteps[unet_index] = decoder_config.timesteps + true_beta_schedules[unet_index] = decoder_config.beta_schedule[unet_index] + true_uses_learned_variance[unet_index] = decoder_config.learned_variance if isinstance(decoder_config.learned_variance, bool) else decoder_config.learned_variance[unet_index] + + true_decoder_config_obj = {} + # Insert the true configs into the true decoder config + true_decoder_config_obj['unets'] = true_unet_configs + true_image_sizes = [] + for i in range(load_config.final_unet_number): + if i == 0: + true_image_sizes.append(true_upsampling_sizes[i][1]) + else: + assert true_upsampling_sizes[i - 1][1] == true_upsampling_sizes[i][0], f"The upsampling sizes for unet {i} are not compatible with unet {i - 1}." 
+ true_image_sizes.append(true_upsampling_sizes[i][1]) + true_decoder_config_obj['image_sizes'] = true_image_sizes + # All unets must have been trained with the same number of sampling timesteps in order to be compatible + assert all(true_train_timesteps[0] == t for t in true_train_timesteps), f"All unets must have been trained with the same number of sampling timesteps in order to be compatible." + true_decoder_config_obj['timesteps'] = true_train_timesteps[0] + true_decoder_config_obj['beta_schedule'] = true_beta_schedules + true_decoder_config_obj['learned_variance'] = true_uses_learned_variance + + # Now we can create the decoder and substitute the unets + true_decoder_config = DecoderConfig(**true_decoder_config_obj) + decoder_data_requirements = self._get_decoder_data_requirements(true_decoder_config) + decoder = true_decoder_config.create().to(self.devices[0]).eval() + decoder.unets = nn.ModuleList(true_unets) + decoder.to(torch.float32) + return ModelInfo(decoder, decoder_version, requires_clip, decoder_data_requirements) + + def _get_prior_data_requirements(self, config: DiffusionPriorConfig) -> DataRequirements: + """ + Returns the data requirements for a diffusion prior + """ + return DataRequirements( + image_embedding=False, # This is kinda the whole point + text_encoding=True, # This is also kinda the whole point + image=False, # The prior is never conditioned on the image + text=False, # Text is never required for anything + can_generate_embedding=False, # This might be added later if clip is being used + image_size=[-1, -1] # This is not used + ) + + def load_prior(self, load_config: PriorLoadConfig) -> 'ModelInfo[DiffusionPrior]': + """ + Loads a prior from a model and a config file + """ + with load_config.load_model_from.as_local_file() as model_file: + model_state_dict = torch.load(model_file, map_location=self.load_device) + if 'version' in model_state_dict: + model_version = version.parse(model_state_dict['version']) + if model_version != self.current_version: + print(f'WARNING: This prior was trained on version {model_version} but the current version is {self.current_version}. This may result in the model failing to load.') + else: + print(f'WARNING: This prior was trained on an old version of Dalle2. This may result in the model failing to load.') + model_version = None + + requires_clip = False + if 'config' in model_state_dict: + # Then we define the prior config from this object + prior_config = TrainDiffusionPriorConfig(**model_state_dict['config']).prior + if prior_config.clip is not None: + # We don't want to load clip with the model + prior_config.clip = None + requires_clip = True + prior = prior_config.create().to(self.devices[0]).eval() + prior.load_state_dict(model_state_dict['model'], strict=self.strict_loading) + else: + # In this case, the state_dict is the model itself. 
This means we also must load the config from an external file + assert load_config.load_config_from is not None + with load_config.load_config_from.as_local_file() as config_file: + prior_config = TrainDiffusionPriorConfig.from_json_path(config_file).prior + if prior_config.clip is not None: + # We don't want to load clip with the model + prior_config.clip = None + requires_clip = True + prior = prior_config.create().to(self.devices[0]).eval() + prior.load_state_dict(model_state_dict, strict=self.strict_loading) + + data_requirements = self._get_prior_data_requirements(prior_config) + prior.to(torch.float32) + return ModelInfo(prior, model_version, requires_clip, data_requirements) \ No newline at end of file diff --git a/dalle2_laion/scripts/BasicInference.py b/dalle2_laion/scripts/BasicInference.py new file mode 100644 index 0000000..a07d26e --- /dev/null +++ b/dalle2_laion/scripts/BasicInference.py @@ -0,0 +1,119 @@ +""" +This inference script is used to do basic inference without any bells and whistles. +Pass in text, get out image. +""" + +from dalle2_laion.scripts import InferenceScript +from typing import Dict, List, Union +from PIL import Image as PILImage +import numpy as np +import torch + +class BasicInference(InferenceScript): + def sample_decoder_with_image_embedding( + self, + images: List[PILImage.Image] = None, image_embed: List[torch.Tensor] = None, + text: List[str] = None, text_encoding: List[torch.Tensor] = None, + cond_scale: float = 1.0, sample_count: int = 1, batch_size: int = 10, + ): + decoder_info = self.model_manager.decoder_info + assert decoder_info is not None, "No decoder loaded." + data_requirements = decoder_info.data_requirements + min_image_size = min(min(image.size) for image in images) if images is not None else None + assert data_requirements.is_valid( + has_image_emb=image_embed is not None, has_image=images is not None, + has_text_encoding=text_encoding is not None, has_text=text is not None, + image_size=min_image_size + ), "The data requirements for the decoder are not satisfied." + + # Prepare the data + image_embeddings = [] # The null case where nothing is done. This should never be used in actuality, but for stylistic consistency I'm keeping it. + if data_requirements.image_embedding: + if image_embed is None: + # Then we need to use clip to generate the image embedding + image_embed = self._embed_images(images) + # Then we need to group these tensors into batches of size batch_size such that the total number of samples is sample_count + image_embeddings, image_embeddings_map = self._repeat_tensors_with_batch_size(image_embed, repeat_num=sample_count, batch_size=batch_size) + print(f'Batched {torch.stack(image_embed).shape} to {torch.stack(image_embeddings).shape} with batch size {batch_size} and repeat num {sample_count}') + + if data_requirements.text_encoding: + if text_encoding is None: + text_encoding = self._encode_text(text) + text_encodings, text_encodings_map = self._repeat_tensors_with_batch_size(text_encoding, repeat_num=sample_count, batch_size=batch_size) + + assert len(image_embeddings) > 0, "No data provided for decoder inference." 
+ output_image_map: Dict[int, List[PILImage.Image]] = {} + for i in range(len(image_embeddings)): + args = {} + embeddings_map = [] + if data_requirements.image_embedding: + args["image_embed"] = image_embeddings[i].to(self.device) + embeddings_map = image_embeddings_map[i] + if data_requirements.text_encoding: + args["text_encodings"] = text_encodings[i].to(self.device) + embeddings_map = text_encodings_map[i] + output_images = decoder_info.model.sample(**args, cond_scale=cond_scale) + for output_image, input_embedding_number in zip(output_images, embeddings_map): + if input_embedding_number not in output_image_map: + output_image_map[input_embedding_number] = [] + output_image_map[input_embedding_number].append(self._torch_to_pil(output_image)) + return output_image_map + + + def sample_prior_with_text_encoding(self, text: List[str], cond_scale: float = 1.0, sample_count: int = 1, batch_size: int = 100, num_samples_per_batch: int = 2): + assert self.model_manager.prior_info is not None + data_requirements = self.model_manager.prior_info.data_requirements + assert data_requirements.is_valid( + has_text_encoding=False, has_text=text is not None, + has_image_emb=False, has_image=False, + image_size=None + ), "The data requirements for the prior are not satisfied." + text_tokens = self._tokenize_text(text) + text_batches, text_batches_map = self._repeat_tensors_with_batch_size(text_tokens, repeat_num=sample_count, batch_size=batch_size) + embedding_map: Dict[int, List[torch.Tensor]] = {} + # Weirdly the prior requires clip be part of itself to work so we insert it + with self._clip_in_prior() as prior: + for text_batch, batch_map in zip(text_batches, text_batches_map): + text_batch = text_batch.to(self.device) + embeddings = prior.sample(text_batch, cond_scale=cond_scale, num_samples_per_batch=num_samples_per_batch) + for embedding, embedding_number in zip(embeddings, batch_map): + if embedding_number not in embedding_map: + embedding_map[embedding_number] = [] + embedding_map[embedding_number].append(embedding) + return embedding_map + + def dream( + self, + text: Union[str, List[str]], + prior_cond_scale: float = 1.0, decoder_cond_scale: float = 1.0, + prior_sample_count: int = 1, decoder_sample_count: int = 1, + prior_batch_size: int = 100, decoder_batch_size: int = 10, + prior_num_samples_per_batch: int = 2 + ): + if isinstance(text, str): + text = [text] + image_embedding_map = self.sample_prior_with_text_encoding(text, cond_scale=prior_cond_scale, sample_count=prior_sample_count, batch_size=prior_batch_size, num_samples_per_batch=prior_num_samples_per_batch) + # This is a map between the text index and the generated image embeddings + # In order to + image_embeddings: List[torch.Tensor] = [] + for i in range(len(text)): + image_embeddings.extend(image_embedding_map[i]) + # In order to get the original text from the image embeddings, we need to reverse the map + image_embedding_index_reverse_map = {i: [] for i in range(len(text))} + current_count = 0 + texts = [] + for i in range(len(text)): + for _ in range(len(image_embedding_map[i])): + texts.append(text[i]) + image_embedding_index_reverse_map[i].append(current_count) + current_count += 1 + # Now we can use the image embeddings to generate the images + image_map = self.sample_decoder_with_image_embedding(text=texts, image_embed=image_embeddings, cond_scale=decoder_cond_scale, sample_count=decoder_sample_count, batch_size=decoder_batch_size) + # Now we will reconstruct a map from text to a map of img_embedding indices to list of 
images
+        output_map: Dict[str, Dict[int, List[PILImage.Image]]] = {}
+        for i, prompt in enumerate(text):
+            output_map[prompt] = {}
+            embedding_indices = image_embedding_index_reverse_map[i]
+            for embedding_index in embedding_indices:
+                output_map[prompt][embedding_index] = image_map[embedding_index]
+        return output_map
\ No newline at end of file
diff --git a/dalle2_laion/scripts/InferenceScript.py b/dalle2_laion/scripts/InferenceScript.py
new file mode 100644
index 0000000..78e5c39
--- /dev/null
+++ b/dalle2_laion/scripts/InferenceScript.py
@@ -0,0 +1,185 @@
+"""
+This module contains an abstract class for inference scripts.
+"""
+
+from typing import Any, List, Tuple, Union, TypeVar
+from dalle2_pytorch.tokenizer import tokenizer
+from dalle2_laion import DalleModelManager
+from torchvision.transforms import ToPILImage, ToTensor
+from PIL import Image as PILImage
+import torch
+import numpy as np
+from contextlib import contextmanager
+
+RepeatObject = TypeVar('RepeatObject')
+
+class InferenceScript:
+    def __init__(self, model_manager: DalleModelManager):
+        self.model_manager = model_manager
+        self.device = model_manager.devices[0]
+
+    @contextmanager
+    def _clip_in_decoder(self):
+        assert self.model_manager.decoder_info is not None, "Cannot use the decoder without a decoder model."
+        decoder = self.model_manager.decoder_info.model
+        clip = self.model_manager.clip
+        decoder.clip = clip
+        yield decoder
+        decoder.clip = None
+
+    @contextmanager
+    def _clip_in_prior(self):
+        assert self.model_manager.prior_info is not None, "Cannot use the prior without a prior model."
+        prior = self.model_manager.prior_info.model
+        clip = self.model_manager.clip
+        prior.clip = clip
+        yield prior
+        prior.clip = None
+
+    def _pil_to_torch(self, image: Union[PILImage.Image, List[PILImage.Image]]):
+        """
+        Convert a PIL image into a torch tensor.
+        Tensor is of dimension 3 if one image is passed in, and of dimension 4 if a list of images is passed in.
+        """
+        if isinstance(image, PILImage.Image):
+            return ToTensor()(image)
+        else:
+            return torch.stack([ToTensor()(image[i]) for i in range(len(image))])
+
+    def _torch_to_pil(self, image: torch.Tensor):
+        """
+        If the tensor is a batch of images, then we return a list of PIL images.
+        """
+
+        if len(image.shape) == 4:
+            return [ToPILImage()(image[i]) for i in range(image.shape[0])]
+        else:
+            return ToPILImage()(image)
+
+    def _repeat_tensors_with_batch_size(self, tensors: List[torch.Tensor], repeat_num: int, batch_size: int) -> Tuple[List[torch.Tensor], List[List[int]]]:
+        """
+        Takes a list of tensors and converts it to a list of tensors of shape (<=batch_size, ...) such that the total number of the original tensors is repeat_num * len(tensors)
+        Since there are multiple tensor inputs, we also return a list of indices that correspond to the original tensors.
+        """
+        assert repeat_num > 0
+        assert batch_size > 0
+        assert isinstance(tensors[0], torch.Tensor), f"Tensors must be torch tensors, not {type(tensors[0])}"
+        assert all(tensors[0].shape == tensor.shape for tensor in tensors), "All tensors must have the same shape to be repeated together."
+ num_dims = len(tensors[0].shape) + current_tensor_index = 0 + num_left = repeat_num + residual = 0 + result = [] + result_indices = [] + while current_tensor_index < len(tensors): + if residual > 0: + # Then we had some from the last tensor that we need to fill in before we start repeating the current tensor + residual_tensor = tensors[current_tensor_index - 1].repeat(residual, *[1] * num_dims) + num_to_add = min(num_left, batch_size - residual) + add_tensor = tensors[current_tensor_index].repeat(num_to_add, *[1] * num_dims) + result.append(torch.cat([residual_tensor, add_tensor], dim=0)) + result_indices.append([current_tensor_index - 1] * residual + [current_tensor_index] * num_to_add) + num_left -= num_to_add + # Expand the current tensor until we have too few to fill another batch + while num_left >= batch_size: + result.append(tensors[current_tensor_index].repeat(batch_size, *[1] * num_dims)) + result_indices.append([current_tensor_index] * batch_size) + num_left -= batch_size + # Now we need to add the remaining tensors to the next batch and then move on to the next tensor + residual = num_left + current_tensor_index += 1 + num_left = repeat_num + # Take care of the final residual + if residual > 0: + residual_tensor = tensors[current_tensor_index - 1].repeat(residual, *[1] * num_dims) + result.append(residual_tensor) + result_indices.append([current_tensor_index - 1] * residual) + return result, result_indices + + def _repeat_object_with_batch_size(self, objects: List[RepeatObject], repeat_num: int, batch_size: int) -> Tuple[List[List[RepeatObject]], List[List[int]]]: + """ + Takes a list of objects and converts it to a list of objects of shape (<=batch_size, ...) such that the total number of the original objects is repeat_num * len(objects) + Since there are multiple object inputs, we also return a list of indices that correspond to the original objects. + """ + assert repeat_num > 0 + assert batch_size > 0 + current_object_index = 0 + num_left = repeat_num + residual = 0 + result = [] + result_indices = [] + while current_object_index < len(objects): + if residual > 0: + # Then we had some from the last object that we need to fill in before we start repeating the current object + residual_objects = [objects[current_object_index - 1]] * residual + num_to_add = min(num_left, batch_size - residual) + add_objects = [objects[current_object_index]] * num_to_add + result.append(residual_objects + add_objects) + result_indices.append([current_object_index - 1] * residual + [current_object_index] * num_to_add) + num_left -= num_to_add + # Expand the current object until we have too few to fill another batch + while num_left >= batch_size: + result.append([objects[current_object_index]] * batch_size) + result_indices.append([current_object_index] * batch_size) + num_left -= batch_size + # Now we need to add the remaining objects to the next batch and then move on to the next object + residual = num_left + current_object_index += 1 + num_left = repeat_num + # Take care of the final residual + if residual > 0: + residual_objects = [objects[current_object_index - 1]] * residual + result.append(residual_objects) + result_indices.append([current_object_index - 1] * residual) + return result, result_indices + + def _embed_images(self, images: List[PILImage.Image]) -> torch.Tensor: + """ + Generates the clip embeddings for a list of images + """ + assert self.model_manager.decoder_info.data_requirements.can_generate_embedding, "Cannot generate embeddings for this model." 
+ clip = self.model_manager.clip + image_embed = clip.embed_image(self._pil_to_torch(images)) + return image_embed.image_embed + + def _encode_text(self, text: List[str]) -> torch.Tensor: + """ + Generates the clip embeddings for a list of text + """ + assert self.model_manager.prior_info.data_requirements.can_generate_embedding, "Cannot generate embeddings for this model." + text_tokens = self._tokenize_text(text) + clip = self.model_manager.clip + text_embed = clip.embed_text(text_tokens.to(self.device)) + return text_embed.text_encodings + + def _tokenize_text(self, text: List[str]) -> torch.Tensor: + """ + Tokenizes a list of text + """ + return tokenizer.tokenize(text) + +class CliInferenceScript(InferenceScript): + def __init__(self, model_manager: DalleModelManager): + super().__init__(model_manager) + raise NotImplementedError("CliInferenceScript is not implemented cause I have no idea how to do it yet.") + +if __name__ == "__main__": + i = InferenceScript(None) + # t = torch.randn(10) + # r = i._repeat_with_batch_size(t, 20, 15) + # print([tens.shape for tens in r]) + + # t1 = torch.tensor([1] * 4) + # t2 = torch.tensor([2] * 4) + # t1 = torch.randn(4, 4) + # t2 = torch.randn(4, 4) + # r = i._repeat_tensors_with_batch_size([t1, t2], 5, 7) + # print(r[0]) + # print([(tens.shape, tens.min(), tens.max()) for tens in r[0]]) + # print(r[1]) + + t1 = 1 + t2 = "asdf" + r = i._repeat_object_with_batch_size([t1, t2], 5, 7) + print(r[0]) + print(r[1]) \ No newline at end of file diff --git a/dalle2_laion/scripts/__init__.py b/dalle2_laion/scripts/__init__.py new file mode 100644 index 0000000..93853c4 --- /dev/null +++ b/dalle2_laion/scripts/__init__.py @@ -0,0 +1,2 @@ +from dalle2_laion.scripts.InferenceScript import InferenceScript, CliInferenceScript +from dalle2_laion.scripts.BasicInference import BasicInference \ No newline at end of file diff --git a/example_inference.py b/example_inference.py new file mode 100644 index 0000000..83a538c --- /dev/null +++ b/example_inference.py @@ -0,0 +1,32 @@ +from dalle2_laion import DalleModelManager, ModelLoadConfig +from dalle2_laion.scripts import BasicInference +import os +import click + +@click.command() +@click.option('--model-config', default='./configs/upsampler.example.json', help='Path to model config file') +@click.option('--output-path', default='./output/', help='Path to output directory') +def run_basic_inference(model_config: str, output_path: str): + prompts = [] + print("Enter your prompts one by one. 
Enter an empty prompt to finish.") + while True: + prompt = click.prompt(f'Prompt {len(prompts)+1} ', default='', type=str) + if prompt == '': + break + prompts.append(prompt) + num_prior_samples = click.prompt('How many samples would you like to generate for each prompt?', default=1, type=int) + + print(f"Generating image for prompts: {prompts}") + config = ModelLoadConfig.from_json_path(model_config) + model_manager = DalleModelManager(config) + dreamer = BasicInference(model_manager) + output_map = dreamer.dream(prompts, prior_sample_count=num_prior_samples) + os.makedirs(output_path, exist_ok=True) + for text in output_map: + for embedding_index in output_map[text]: + for image in output_map[text][embedding_index]: + # Save the image + image.save(os.path.join(output_path, f"{text}_{embedding_index}.png")) + +if __name__ == "__main__": + run_basic_inference() \ No newline at end of file diff --git a/notebooks/dalle2_laion_alpha.ipynb b/notebooks/dalle2_laion_alpha.ipynb index 890971e..f2fffcc 100644 --- a/notebooks/dalle2_laion_alpha.ipynb +++ b/notebooks/dalle2_laion_alpha.ipynb @@ -84,6 +84,25 @@ " \"dalle2_install_path\": \"dalle2_pytorch==0.15.4\",\n", " \"decoder_path\": \"https://huggingface.co/laion/DALLE2-PyTorch/resolve/main/decoder/1.5B_laion2B/latest.pth\",\n", " \"config_path\": \"https://huggingface.co/laion/DALLE2-PyTorch/raw/main/decoder/1.5B_laion2B/decoder_config.json\"\n", + "},{\n", + " \"name\": \"Upsampler\",\n", + " \"dalle2_install_path\": \"git+https://github.com/Veldrovive/DALLE2-pytorch@b2549a4d17244dab09e7a9496a9cb6330b7d3070\",\n", + " \"decoder\": [\n", + " {\n", + " \"unets\": [0],\n", + " \"model_path\": \"https://huggingface.co/laion/DALLE2-PyTorch/resolve/main/decoder/1.5B_laion2B/latest.pth\",\n", + " \"config_path\": \"https://huggingface.co/laion/DALLE2-PyTorch/raw/main/decoder/1.5B_laion2B/decoder_config.json\"\n", + " },\n", + " {\n", + " \"unets\": [1],\n", + " \"model_path\": \"https://huggingface.co/Veldrovive/upsamplers/resolve/main/working/latest.pth\",\n", + " \"config_path\": \"https://huggingface.co/Veldrovive/upsamplers/raw/main/working/decoder_config.json\"\n", + " }\n", + " ],\n", + " \"prior\": {\n", + " \"model_path\": \"https://huggingface.co/zenglishuci/conditioned-prior/resolve/main/vit-l-14/prior_aes_finetune.pth\",\n", + " \"config_path\": \"\"\n", + " }\n", "}]\n", "\n", "decoder_options = [version[\"name\"] for version in decoder_versions]\n", @@ -700,6 +719,7 @@ "#@title\n", "from IPython.display import display, clear_output\n", "from ipywidgets import interact\n", + "import torchvision.transforms as T\n", "try:\n", " from google.colab import files\n", " can_download = True\n", @@ -714,7 +734,28 @@ " layout={'width': 'auto'},\n", " rows=10\n", ")\n", - "textbox_box = widgets.VBox([text_input], layout={'border': '2px solid grey'})\n", + "\n", + "file_input = widgets.FileUpload(\n", + "# description=\"Image Variation Upload\",\n", + " multiple=False\n", + ")\n", + "file_input.observe(lambda: print(\"Uploaded\"), names=\"_test\")\n", + "file_reset = widgets.Button(\n", + " description=\"Clear Uploads\"\n", + ")\n", + "file_box = widgets.HBox([file_input, file_reset])\n", + "\n", + "def reset_file_input(b):\n", + " global file_input\n", + " new_file_input = widgets.FileUpload(\n", + " multiple=False\n", + " )\n", + " file_input = new_file_input\n", + " file_box.children = (new_file_input, file_box.children[1])\n", + " render_layout()\n", + "file_reset.on_click(reset_file_input)\n", + "\n", + "textbox_box = widgets.VBox([text_input, 
file_box], layout={'border': '2px solid grey'})\n", "\n", "prior_label = widgets.HTML(value=\"Prior Options: Set how many sample to take from the prior and what conditioning scale to use.\")\n", "text_repeat = widgets.IntSlider(\n", @@ -805,26 +846,40 @@ "main_layout = widgets.VBox([textbox_box, main_options_box, final_options_box])\n", "\n", "def get_prompts():\n", - " import json\n", - " text = text_input.value\n", - " try:\n", - " prompts_array = json.loads(text)\n", - " assert isinstance(prompts_array, list)\n", - " return prompts_array\n", - " except Exception as e:\n", - "# print(\"Failed to read as json\", e)\n", - " pass\n", - " \n", - " try:\n", - " return list(filter(lambda v: len(v) > 0, text.split(\"\\n\")))\n", - " except Exception as e:\n", - " print(\"Failed to read as text with newlines\", e)\n", - " \n", - " return []\n", + " import json\n", + " from itertools import zip_longest\n", + " import io\n", + " text = text_input.value\n", + " try:\n", + " prompts_array = json.loads(text)\n", + " assert isinstance(prompts_array, list)\n", + " text_prompts = prompts_array\n", + " except Exception as e:\n", + " pass\n", "\n", - "def f(text_input, text_repeat, prior_conditioning, img_repeat, decoder_conditioning, include_prompt_checkbox, upsample_checkbox):\n", - " prompts = get_prompts()\n", - " total_images = len(prompts) * text_repeat * img_repeat\n", + " try:\n", + " text_prompts = list(filter(lambda v: len(v) > 0, text.split(\"\\n\")))\n", + " except Exception as e:\n", + " print(\"Failed to read as text with newlines\", e)\n", + " return []\n", + "\n", + " files = file_input.value\n", + " file = None\n", + " if len(files) > 0:\n", + " file_name, file_info = list(files.items())[0]\n", + " image_pil = Image.open(io.BytesIO(file_info['content'])).convert('RGB')\n", + " transforms = T.Compose([\n", + " T.CenterCrop(min(image_pil.size)),\n", + " T.Resize(clip.image_size)\n", + " ])\n", + " image_pil = transforms(image_pil)\n", + " file = (file_name, image_pil)\n", + " \n", + " return (text_prompts, file)\n", + "\n", + "def f(text_input, text_repeat, prior_conditioning, img_repeat, decoder_conditioning, include_prompt_checkbox, upsample_checkbox, image):\n", + " text_prompts, image_prompt = get_prompts()\n", + " total_images = len(text_prompts) * text_repeat * img_repeat\n", "\n", " global current_state\n", " current_state = {\n", @@ -837,12 +892,19 @@ " \"include_prompt_checkbox\": include_prompt_checkbox,\n", " \"upsample_checkbox\": upsample_checkbox,\n", " }\n", + " \n", + " def get_prompt_text(index):\n", + " text_prompt = text_prompts[index]\n", + " return f\"Prompt {index}: \\\"{text_prompt}\\\"\"\n", "\n", + " if image_prompt is not None:\n", + " print(\"Taking variation of image\")\n", + " display(image_prompt[1])\n", " output_strings = []\n", " output_strings.append(f\"Using model: {current_state['decoder']['name']}\")\n", " output_strings.append(f\"Total output images: {total_images}\")\n", " output_strings.append(\"\")\n", - " output_strings.extend([f\"Prompt {index}: {prompt}\" for index, prompt in enumerate(prompts)])\n", + " output_strings.extend([get_prompt_text(index) for index in range(len(text_prompts))])\n", " output_strings.append(\"\")\n", " output_strings.append(\"Including prompt text in output image\" if include_prompt_checkbox else \"Not including prompt text in output image\")\n", " output_strings.append(f\"Prior Conditioning Scale: {prior_conditioning}\")\n", @@ -850,16 +912,31 @@ " print('\\n'.join(output_strings))\n", " save_state()\n", "\n", - "out = 
widgets.interactive_output(f, {'text_input': text_input, 'text_repeat': text_repeat, 'prior_conditioning': prior_conditioning, 'img_repeat': img_repeat, 'decoder_conditioning': decoder_conditioning, 'include_prompt_checkbox': include_prompt_checkbox, 'upsample_checkbox': upsample_checkbox })\n", - "\n", - "display(main_layout, out)\n", + "def render_layout():\n", + " clear_output()\n", + " out = widgets.interactive_output(f, {'text_input': text_input, 'text_repeat': text_repeat, 'prior_conditioning': prior_conditioning, 'img_repeat': img_repeat, 'decoder_conditioning': decoder_conditioning, 'include_prompt_checkbox': include_prompt_checkbox, 'upsample_checkbox': upsample_checkbox, 'image': file_input })\n", + " display(main_layout, out)\n", + " \n", + "def get_image_embeddings(prompt_tokens, prompt_image, text_rep: int, prior_cond_scale: float):\n", + " if prompt_image is None:\n", + " print(\"Computing embedings using prior\")\n", + " with torch.no_grad():\n", + " image_embed = diffusion_prior.sample(prompt_tokens, cond_scale = prior_cond_scale).cpu().numpy()\n", + " else:\n", + " print(\"Computing embeddings from example image\")\n", + " image_tensor = T.ToTensor()(prompt_image[1]).unsqueeze_(0).to(device)\n", + " unbatched_image_embed, _ = clip.embed_image(image_tensor)\n", + " image_embed = torch.zeros(len(prompt_tokens), unbatched_image_embed.shape[-1])\n", + " for i in range(len(prompt_tokens)):\n", + " image_embed[i] = unbatched_image_embed\n", + " image_embed = image_embed.cpu().numpy()\n", + " return image_embed\n", "\n", "def on_start(_, recall_embeddings=False, recall_images=False):\n", " if os.path.exists(\"./output\"):\n", " shutil.rmtree(\"./output\")\n", - " clear_output()\n", - " display(main_layout, out)\n", - " prompts = get_prompts()\n", + " render_layout()\n", + " prompts, prompt_image = get_prompts()\n", " prior_cond_scale = prior_conditioning.value\n", " decoder_cond_scale = decoder_conditioning.value\n", " text_rep = text_repeat.value\n", @@ -877,9 +954,7 @@ " print(\"Loading embeddings\")\n", " image_embed = np.load('img_emb_prior.npy')\n", " else:\n", - " print(\"Running prior\")\n", - " with torch.no_grad():\n", - " image_embed = diffusion_prior.sample(tokens, cond_scale = prior_cond_scale).cpu().numpy()\n", + " image_embed = get_image_embeddings(tokens, prompt_image, text_rep, prior_cond_scale)\n", " np.save('img_emb_prior.npy', image_embed)\n", "\n", " embeddings = np.repeat(image_embed, img_rep, axis=0)\n", @@ -919,7 +994,8 @@ " os.makedirs(\"./output\", exist_ok=True)\n", " img.save(f\"./output/example_{index}.png\")\n", "\n", - "button.on_click(on_start)" + "button.on_click(on_start)\n", + "render_layout()" ] }, { @@ -1052,7 +1128,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.10" + "version": "3.8.5" } }, "nbformat": 4, diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..2660de4 --- /dev/null +++ b/setup.py @@ -0,0 +1,17 @@ +from setuptools import setup, find_packages + +setup( + name = "dalle2-laion", + version = "0.0.1", + packages = find_packages(exclude=[]), + include_package_data = True, + install_requires = [ + "packaging>=21.0", + "pydantic>=1.9.0", + "torch>=1.10", + "Pillow>=9.0.0", + "numpy>=1.20.0", + "click>=8.0.0" + "dalle2-pytorch" + ] +)
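
Note on the new config schema (dalle2_laion/config.py): both example JSONs use `"load_type": "url"`, but `LoadLocation` also accepts `"local"` for checkpoints already on disk, and `ModelLoadConfig` is a pydantic model, so it can be built directly from a dict rather than a JSON file. A minimal sketch, assuming hypothetical local checkpoint paths (not part of this diff):

# Sketch only: building a ModelLoadConfig in code instead of loading a JSON file.
# The "./models/..." paths are placeholders for checkpoints you already have on disk.
from dalle2_laion import DalleModelManager, ModelLoadConfig

config = ModelLoadConfig(
    decoder={
        "unet_sources": [
            {
                "unet_numbers": [1],
                # "local" reads straight from disk; "url" would download into cache_dir instead
                "load_model_from": {"load_type": "local", "path": "./models/first_decoder.pth"},
                # Only needed if the checkpoint does not embed its own config
                "load_config_from": {"load_type": "local", "path": "./models/first_decoder_config.json"},
            }
        ]
    },
    clip={"make": "openai", "model": "ViT-L/14"},
    devices="cuda:0",
)
manager = DalleModelManager(config)  # loads the decoder and clip onto cuda:0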
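
For the inference scripts, a condensed, non-interactive variant of example_inference.py may help document the shape of `BasicInference.dream`'s return value: a mapping of prompt text to image-embedding index to a list of PIL images. The prompt, sample count, and output directory below are placeholder choices:

# Sketch only: programmatic use of the new package (assumes ./configs/upsampler.example.json
# and enough disk/VRAM to fetch and run the referenced checkpoints).
import os
from dalle2_laion import DalleModelManager, ModelLoadConfig
from dalle2_laion.scripts import BasicInference

config = ModelLoadConfig.from_json_path("./configs/upsampler.example.json")
dreamer = BasicInference(DalleModelManager(config))

# dream() returns {prompt_text: {embedding_index: [PIL images]}}
output_map = dreamer.dream(["a photograph of a sunflower"], prior_sample_count=2)

os.makedirs("./output", exist_ok=True)
for prompt, by_embedding in output_map.items():
    for emb_idx, images in by_embedding.items():
        for img_idx, image in enumerate(images):
            # Include the image index so multiple decoder samples per embedding are not overwritten
            image.save(os.path.join("./output", f"{prompt}_{emb_idx}_{img_idx}.png"))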