In [None]:
# Install the third-party libraries used by this notebook.
# NOTE(review): versions are unpinned, so a fresh run pulls whatever release is current — consider pinning.
# NOTE(review): bitsandbytes is installed but never imported in the visible cells — confirm it is needed.
!pip install diffusers
!pip install transformers
!pip install accelerate
!pip install bitsandbytes

In [None]:
from typing import Any, Callable, Dict, List, Optional, Union
import gc
import inspect
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from PIL import Image
import accelerate
import diffusers
import transformers
from diffusers import DiffusionPipeline
from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D,CrossAttnUpBlock2D
from pathlib import Path

import requests
from io import BytesIO
import random

In [None]:
class DreamBoothDataset(Dataset):
    """
    Dataset that pairs instance (and optionally class) images with tokenized prompts.

    Every file found directly inside a concept's ``instance_data_dir`` (except
    ``.txt`` files) becomes an instance example; when prior preservation is
    enabled, every file inside ``class_data_dir`` becomes a class example.
    Images are converted to RGB, resized/cropped to ``size`` and normalized
    with mean 0.5 / std 0.5.

    Args:
        concepts_list: Iterable of dicts with the keys ``instance_prompt``,
            ``instance_data_dir`` and, when prior preservation is enabled,
            ``class_prompt`` and ``class_data_dir``.
        tokenizer: Callable tokenizer exposing ``model_max_length``; its
            result must expose ``input_ids``.
        with_prior_preservation: Whether class images/prompts are included.
        size: Target side length passed to the resize and crop transforms.
        center_crop: Use a center crop instead of a random crop.
        num_class_images: Maximum number of class images kept per concept
            (``None`` keeps all of them).
        hflip: Enable random horizontal flipping with probability 0.5.

    Raises:
        ValueError: If one of the required image lists is empty while the
            other is not. Previously this surfaced later as a
            ``ZeroDivisionError`` inside ``__getitem__``.
    """

    def __init__(
        self,
        concepts_list,
        tokenizer,
        with_prior_preservation=True,
        size=512,
        center_crop=False,
        num_class_images=None,
        hflip=False
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer
        self.with_prior_preservation = with_prior_preservation

        self.instance_images_path = []
        self.class_images_path = []

        for concept in concepts_list:
            inst_img_path = [
                (x, concept["instance_prompt"])
                for x in Path(concept["instance_data_dir"]).iterdir()
                if x.is_file() and not str(x).endswith(".txt")
            ]
            self.instance_images_path.extend(inst_img_path)

            if with_prior_preservation:
                class_img_path = [(x, concept["class_prompt"]) for x in Path(concept["class_data_dir"]).iterdir() if x.is_file()]
                self.class_images_path.extend(class_img_path[:num_class_images])

        random.shuffle(self.instance_images_path)
        self.num_instance_images = len(self.instance_images_path)
        self.num_class_images = len(self.class_images_path)
        self._length = max(self.num_class_images, self.num_instance_images)

        # __getitem__ indexes both lists modulo their length, so a single empty
        # list combined with a non-zero dataset length would divide by zero.
        # Fail here with an actionable message instead.
        if self._length > 0 and self.num_instance_images == 0:
            raise ValueError("No instance images were found in the given `instance_data_dir` directories.")
        if with_prior_preservation and self._length > 0 and self.num_class_images == 0:
            raise ValueError(
                "Prior preservation is enabled but no class images were found in the given `class_data_dir` directories."
            )

        self.image_transforms = transforms.Compose(
            [
                # hflip is a bool, so the flip probability is either 0.5 or 0.
                transforms.RandomHorizontalFlip(0.5 * hflip),
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        """Return the larger of the instance-image and class-image counts."""
        return self._length

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an in-memory RGB copy, closing the file handle."""
        with Image.open(path) as image:
            return image.convert("RGB")

    def _tokenize(self, prompt):
        """Tokenize ``prompt`` padded/truncated to the tokenizer's maximum length."""
        return self.tokenizer(
            prompt,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

    def __getitem__(self, index):
        """
        Return a dict with ``instance_images`` / ``instance_prompt_ids`` and,
        with prior preservation, ``class_images`` / ``class_prompt_ids``.
        Indices wrap around the shorter of the two lists.
        """
        example = {}
        instance_path, instance_prompt = self.instance_images_path[index % self.num_instance_images]
        example["instance_images"] = self.image_transforms(self._load_rgb(instance_path))
        example["instance_prompt_ids"] = self._tokenize(instance_prompt)

        if self.with_prior_preservation:
            class_path, class_prompt = self.class_images_path[index % self.num_class_images]
            example["class_images"] = self.image_transforms(self._load_rgb(class_path))
            example["class_prompt_ids"] = self._tokenize(class_prompt)

        return example

In [None]:
# The code below reuses code from the Hugging Face diffusers library, some of it verbatim and some of it modified.

class image_pipe(DiffusionPipeline):
  def __init__(self, model_link, device=torch.device('cpu'), dtype=torch.float16):
    """
    Load the Stable Diffusion components from the sub-folders of `model_link`.

    The UNet, VAE and text encoder are moved to `device` and switched to
    train mode; the scheduler and tokenizer are loaded as-is. `dtype` is
    forwarded to every `from_pretrained` call as `torch_dtype`.
    """
    super().__init__()
    load_kwargs = {'torch_dtype': dtype}

    unet = diffusers.UNet2DConditionModel.from_pretrained(model_link, subfolder='unet', **load_kwargs)
    self.unet = unet.to(device).train()

    vae = diffusers.AutoencoderKL.from_pretrained(model_link, subfolder='vae', **load_kwargs)
    self.vae = vae.to(device).train()

    self.scheduler = diffusers.LMSDiscreteScheduler.from_pretrained(model_link, subfolder='scheduler', **load_kwargs)

    text_encoder = transformers.CLIPTextModel.from_pretrained(model_link, subfolder='text_encoder', **load_kwargs)
    self.text_encoder = text_encoder.to(device).train()

    self.tokenizer = transformers.CLIPTokenizer.from_pretrained(model_link, subfolder='tokenizer', **load_kwargs)

    # Ratio between pixel resolution and latent resolution, derived from the VAE depth.
    num_vae_levels = len(self.vae.config.block_out_channels)
    self.vae_scale_factor = 2 ** (num_vae_levels - 1)



  def _execution_device(self):
    """Return the device the UNet currently lives on."""
    unet = self.unet
    return unet.device
  

  @staticmethod
  def numpy_to_pil(images):
    """
    Convert an image array (or a batch of them) into a list of PIL images.

    A 3-D input is treated as a single image and given a leading batch axis.
    Values are multiplied by 255, rounded and cast to uint8; a trailing
    channel axis of size 1 produces mode "L" images.
    """
    batch = images[None, ...] if images.ndim == 3 else images
    batch = (batch * 255).round().astype("uint8")

    single_channel = batch.shape[-1] == 1
    if single_channel:
      return [Image.fromarray(frame.squeeze(), mode="L") for frame in batch]
    return [Image.fromarray(frame) for frame in batch]
  
  # Find theta * del (L con)/ del (theta), and return, while setting new thetas
  def process_attn(self,Attention, rho):
    """
    Score and update the key/value projection weights of one attention layer.

    For each of `to_k` and `to_v` the element-wise product weight * weight.grad
    is recorded (detached), and the weight is then replaced by a fresh
    Parameter holding weight - rho * weight * weight * weight.grad.

    Returns a dict with the recorded products under the keys 'k' and 'v'.
    """
    scores = {}
    for tag, projection in (('k', Attention.to_k), ('v', Attention.to_v)):
      weight = projection.weight
      grad = weight.grad
      scores[tag] = (weight * grad).detach()
      # Re-wrapping in a Parameter drops the old gradient and autograd history.
      projection.weight = torch.nn.Parameter(weight - rho * weight * weight * grad, requires_grad=True)
    return scores



  def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    ):
        """
        Encode the prompt (and, for classifier-free guidance, the negative prompt).

        Args:
            prompt: A string, a list of strings, or None when `prompt_embeds` is given.
            device: Device the token ids and embeddings are moved to.
            num_images_per_prompt: Number of times each embedding is repeated.
            do_classifier_free_guidance: When True, unconditional embeddings are
                computed (or taken from `negative_prompt_embeds`) and prepended.
            negative_prompt: Optional string or list of strings; defaults to "".
            prompt_embeds: Optional pre-computed prompt embeddings.
            negative_prompt_embeds: Optional pre-computed negative embeddings.

        Returns:
            The prompt embeddings; with guidance enabled, the negative embeddings
            concatenated in front of the prompt embeddings along the batch axis.

        Raises:
            TypeError: If `prompt` and `negative_prompt` are both given with different types.
            ValueError: If a list `negative_prompt` does not match the batch size.
        """
        import warnings  # local import: only needed to report prompt truncation

        def attention_mask_for(tokenized):
            # Only pass an attention mask when the text encoder's config asks for one.
            config = self.text_encoder.config
            if hasattr(config, "use_attention_mask") and config.use_attention_mask:
                return tokenized.attention_mask.to(device)
            return None

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                # The truncated text used to be computed and then discarded; report it.
                warnings.warn(
                    "The following part of your input was truncated because the text encoder can only handle"
                    f" sequences up to {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            prompt_embeds = self.text_encoder(
                text_input_ids.to(device),
                attention_mask=attention_mask_for(text_inputs),
            )
            prompt_embeds = prompt_embeds[0]

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                # Only compare types when a prompt was actually passed; with
                # `prompt_embeds` alone, `prompt` is None and any negative prompt is fine.
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                # Repeat so the embeddings match the batch size (identical for a batch of 1).
                uncond_tokens = [negative_prompt] * batch_size
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask_for(uncond_input),
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        return prompt_embeds


  def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    """
    Create (or adopt) the initial latents and scale them for the scheduler.

    When `latents` is None, random latents of shape
    (batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)
    are drawn; otherwise the given latents are moved to `device`. In both
    cases the result is multiplied by the scheduler's `init_noise_sigma`.

    Raises:
        ValueError: If `generator` is a list whose length differs from `batch_size`.
    """
    shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        # `randn_tensor` lives in `diffusers.utils.torch_utils` in newer diffusers
        # releases; fall back to the old `diffusers.utils` location otherwise.
        try:
            from diffusers.utils.torch_utils import randn_tensor
        except ImportError:
            randn_tensor = diffusers.utils.randn_tensor
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    else:
        latents = latents.to(device)

    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * self.scheduler.init_noise_sigma
    return latents
  
  def prepare_extra_step_kwargs(self, generator, eta):
    """
    Build the optional keyword arguments for `self.scheduler.step`.

    Schedulers differ in their `step` signature, so `eta` and `generator`
    are only included when the scheduler's `step` declares a parameter of
    that name. `eta` corresponds to η in the DDIM paper
    (https://arxiv.org/abs/2010.02502) and should be in [0, 1].
    """
    accepted = set(inspect.signature(self.scheduler.step).parameters.keys())
    optional_args = (("eta", eta), ("generator", generator))
    return {name: value for name, value in optional_args if name in accepted}
  
  def decode_latents(self,latents):
    """
    Decode latents with the VAE and return the result as a NumPy array.

    The latents are divided by the VAE scaling factor, decoded, mapped via
    x / 2 + 0.5 and clamped to [0, 1], then moved to the CPU, permuted to
    channels-last order and converted to float32.
    """
    inverse_scale = 1/self.vae.config.scaling_factor
    decoded = self.vae.decode(inverse_scale * latents).sample
    pixels = (decoded / 2 + 0.5).clamp(0, 1)
    return pixels.cpu().permute(0, 2, 3, 1).float().numpy()
  
  def __call__(
        self,
        prompt,
        prompt_embeds: Optional[torch.FloatTensor]=None,
        guidance_scale: float=7.5,
        height: Optional[int] = None,
        width: Optional[int]  = None,
        num_inference_steps: int = 50,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        output_type: str= 'pil',
        return_dict: bool = True,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        eta: float = 0.0,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        init_latents: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        ):
    """
    Run the text-to-image denoising loop and return the generated images.

    Args:
        prompt: A string, a list of strings, or None when `prompt_embeds` is given.
        prompt_embeds: Optional pre-computed prompt embeddings.
        guidance_scale: Classifier-free guidance weight; guidance is enabled when > 1.0.
        height, width: Output size in pixels; default to the UNet sample size
            times the VAE scale factor.
        num_inference_steps: Number of scheduler steps.
        negative_prompt: Optional negative prompt(s).
        output_type, return_dict: Accepted for signature compatibility but not
            used; the method always returns a list of PIL images.
        generator: Optional torch generator(s) for the initial latents.
        eta: Forwarded to the scheduler step when it accepts `eta`.
        callback, callback_steps: `callback(i, t, latents)` is invoked on
            progress updates where `i % callback_steps == 0`.
        init_latents: Optional starting latents (used instead of random noise).
        cross_attention_kwargs: Forwarded to the UNet.
        negative_prompt_embeds: Optional pre-computed negative embeddings.

    Returns:
        A list of PIL images.

    Raises:
        ValueError: If neither `prompt` nor `prompt_embeds` is provided.
    """
    if prompt is None and prompt_embeds is None:
      raise ValueError("Provide either `prompt` or `prompt_embeds`.")

    if prompt is not None and isinstance(prompt, str):
      batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
      batch_size = len(prompt)
    else:
      batch_size = prompt_embeds.shape[0]

    device=self._execution_device()

    do_classifier_free_guidance = guidance_scale > 1.0
    prompt_embeds = self._encode_prompt(
            prompt,
            device,
            1,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    #Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device)
    timesteps = self.scheduler.timesteps

    #Prepare latents
    # Read the channel count from the config, like `sample_size` above.
    num_channels_latents = self.unet.config.in_channels

    # One latent per prompt: the latent batch must match the prompt batch,
    # otherwise the UNet receives mismatched batch sizes for list prompts.
    latents = self.prepare_latents(
        batch_size,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        init_latents,
    )

    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

    # Post-processing: decode the latents to a NumPy image batch
    image = self.decode_latents(latents)

    # Convert to PIL
    image = self.numpy_to_pil(image)

    return image

# Compute and return masks in the order of masks for downsampling blocks, middle blocks, and upsampling blocks.
  def cones_compute(self,train_data,steps, rho=2e-5,prior_loss_weight=1.0,activation_thresh=1e-3):
    """
    Accumulate weight * gradient scores for the k/v attention weights and turn them into masks.

    For every batch, the prior-preservation denoising loss is back-propagated
    through the UNet and `process_attn` is applied to `attn1`/`attn2` of the
    cross-attention up blocks, the first mid-block transformer block and the
    cross-attention down blocks. `process_attn` returns the element-wise
    weight * grad products (summed here over all batches and steps) and also
    updates the weights in place of the old parameters.

    Args:
        train_data: Iterable of batches with `pixel_values` and `input_ids`.
            Each batch is split in two halves along the batch axis: the first
            half is scored with the instance loss, the second with the prior loss.
        steps: Number of passes over `train_data`.
        rho: Step size forwarded to `process_attn`.
        prior_loss_weight: Weight of the prior loss term.
        activation_thresh: An entry of the mask is 1 where the accumulated
            score is below this threshold and 0 otherwise.

    Returns:
        (down_dict, mid_dict, up_dict): dicts keyed by
        str(block_index) + str(attention_index) + '1' | '2' (attn1 / attn2;
        the mid block uses '001' / '002'), each holding 'k' and 'v' masks.

    Raises:
        ValueError: If the scheduler's prediction type is neither
            "epsilon" nor "v_prediction".
    """
    up_dict={}
    down_dict={}
    mid_dict={}

    def accumulate(store, key, attn):
      # Score/update one attention layer and add its scores to the running sum.
      scores = self.process_attn(attn, rho)
      if key in store:
        store[key]['k'] += scores['k']
        store[key]['v'] += scores['v']
      else:
        store[key] = scores

    for _ in range(steps):
      torch.cuda.empty_cache()
      for batch in train_data:
        # Module.zero_grad() already covers every sub-module's parameters.
        self.unet.zero_grad()

        with torch.no_grad():
          latents = self.vae.encode(batch['pixel_values'].to(dtype=self.unet.dtype).to(self.unet.device)).latent_dist
          latents = latents.sample()* self.vae.config.scaling_factor
          # Get the text embedding for conditioning
          encoder_hidden_states = self.text_encoder(batch["input_ids"].to(self.text_encoder.device))[0]

        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        # Sample a random timestep for each image
        timesteps = torch.randint(0, self.scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
        timesteps = timesteps.long()
        # Add noise to the latents according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_latents = self.scheduler.add_noise(latents, noise, timesteps)

        # Predict the noise residual
        noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample

        # Get the target for loss depending on the prediction type
        if self.scheduler.config.prediction_type == "epsilon":
          target = noise
        elif self.scheduler.config.prediction_type == "v_prediction":
          target = self.scheduler.get_velocity(latents, noise, timesteps)
        else:
          raise ValueError(f"Unknown prediction type {self.scheduler.config.prediction_type}")

        # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
        noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
        target, target_prior = torch.chunk(target, 2, dim=0)

        # Compute instance loss
        loss = F.mse_loss(noise_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()

        # Compute prior loss
        prior_loss = F.mse_loss(noise_pred_prior.float(), target_prior.float(), reduction="mean")

        # Add the prior loss to the instance loss.
        loss = loss + prior_loss_weight * prior_loss

        loss.backward()

        # k/v layers in the upsampling blocks
        for i,block in enumerate(self.unet.up_blocks):
          if isinstance(block,CrossAttnUpBlock2D):
            for j,attention in enumerate(block.attentions):
              for transformer_block in attention.transformer_blocks:
                accumulate(up_dict, str(i)+str(j)+'1', transformer_block.attn1)
                accumulate(up_dict, str(i)+str(j)+'2', transformer_block.attn2)

        # k/v layers in the middle block (first transformer block only)
        mid_transformer = self.unet.mid_block.attentions[0].transformer_blocks[0]
        accumulate(mid_dict, '001', mid_transformer.attn1)
        accumulate(mid_dict, '002', mid_transformer.attn2)

        # k/v layers in the downsampling blocks
        for i,block in enumerate(self.unet.down_blocks):
          if isinstance(block,CrossAttnDownBlock2D):
            for j,attention in enumerate(block.attentions):
              for transformer_block in attention.transformer_blocks:
                accumulate(down_dict, str(i)+str(j)+'1', transformer_block.attn1)
                accumulate(down_dict, str(i)+str(j)+'2', transformer_block.attn2)

    #Find masks for k v layers: 1 where the accumulated score is below the threshold, else 0
    for dicti in [down_dict, mid_dict, up_dict]:
      for key in dicti:
        for char in ['k','v']:
          dicti[key][char]= activation_thresh * torch.ones_like(dicti[key][char])-dicti[key][char]
          dicti[key][char]=torch.sign(F.relu(dicti[key][char]))
    return down_dict,mid_dict,up_dict
  

  #Take masks and modify diffusion pipeline to generate concept images
  def cones_inference(self,down_dict,mid_dict,up_dict, prompt, negative_prompt: Optional [Union[str, List[str]]]=None,num_inference_steps: int = 50, prompt_embeds:Optional[torch.FloatTensor] = None,height=512, width=512):
    """
    Generate images with the k/v attention weights multiplied by the given masks.

    The masks are applied to `attn1`/`attn2` of the cross-attention down and
    up blocks and of the first mid-block transformer block, using the same
    key scheme as `cones_compute`. The original weights are put back once
    generation finishes (also on error), so the pipeline is left unmasked.

    Args:
        down_dict, mid_dict, up_dict: Masks as returned by `cones_compute`.
        prompt, negative_prompt, num_inference_steps, prompt_embeds, height,
            width: Forwarded to `__call__`.

    Returns:
        The list of PIL images produced by `__call__`.

    Raises:
        KeyError: If a mask for one of the visited attention layers is missing.
    """
    # Collect every (attention layer, mask entry) pair that has to be masked.
    targets = []
    for i,block in enumerate(self.unet.down_blocks):
      if isinstance(block,CrossAttnDownBlock2D):
        for j,attention in enumerate(block.attentions):
          for transformer_block in attention.transformer_blocks:
            targets.append((transformer_block.attn1, down_dict[str(i)+str(j)+'1']))
            targets.append((transformer_block.attn2, down_dict[str(i)+str(j)+'2']))

    for i,block in enumerate(self.unet.up_blocks):
      if isinstance(block,CrossAttnUpBlock2D):
        for j,attention in enumerate(block.attentions):
          for transformer_block in attention.transformer_blocks:
            targets.append((transformer_block.attn1, up_dict[str(i)+str(j)+'1']))
            targets.append((transformer_block.attn2, up_dict[str(i)+str(j)+'2']))

    mid_transformer = self.unet.mid_block.attentions[0].transformer_blocks[0]
    targets.append((mid_transformer.attn1, mid_dict['001']))
    targets.append((mid_transformer.attn2, mid_dict['002']))

    # (projection module, original Parameter) pairs, so nothing is overwritten
    # and every weight can be restored afterwards.
    originals = []
    try:
      # Inference only: no autograd graph is needed for masking or sampling.
      with torch.no_grad():
        for attn, mask in targets:
          for tag, projection in (('k', attn.to_k), ('v', attn.to_v)):
            weight = projection.weight
            originals.append((projection, weight))
            mask_tensor = mask[tag].to(device=weight.device, dtype=weight.dtype)
            projection.weight = nn.Parameter(weight * mask_tensor, requires_grad=False)

        img = self(prompt=prompt,negative_prompt=negative_prompt,num_inference_steps=num_inference_steps,prompt_embeds=prompt_embeds,height=height,width=width)
    finally:
      for projection, weight in reversed(originals):
        projection.weight = weight
    return img



In [None]:
# Load the Stable Diffusion v1.4 components in half precision on the GPU.
cones = image_pipe(
    'CompVis/stable-diffusion-v1-4',
    device='cuda',
    dtype=torch.float16,
)

### Configure the cell below

In [None]:
# One entry per concept:
#   instance_prompt   - prompt for the concept whose neuron masks are searched for
#   class_prompt      - prompt for the class the concept belongs to (person, dog, cat, ...)
#   instance_data_dir - directory containing the concept images
#   class_data_dir    - directory containing the class images
concepts_list = [
    dict(
        instance_prompt="photo of <V*> dog",
        class_prompt="photo of a dog",
        instance_data_dir="/content/drive/MyDrive/subject_cones",
        class_data_dir="/content/drive/MyDrive/dog_imgs",
    ),
]


In [None]:
# Build the training set: 512px random crops with horizontal flips, and at most
# 20 class images per concept for prior preservation.
train_dataset = DreamBoothDataset(
    concepts_list,
    cones.tokenizer,
    with_prior_preservation=True,
    size=512,
    center_crop=False,
    num_class_images=20,
    hflip=True,
)

In [None]:
def collate_fn(examples):
    """
    Merge dataset examples into one batch for prior-preservation training.

    Instance items come first and class items second, for both the token ids
    and the images, so the batch can later be split in two equal halves.
    Token ids are padded with the pipeline tokenizer (`cones.tokenizer`).

    Returns a dict with the keys "input_ids" and "pixel_values".
    """
    instance_ids = [item["instance_prompt_ids"] for item in examples]
    instance_pixels = [item["instance_images"] for item in examples]
    class_ids = [item["class_prompt_ids"] for item in examples]
    class_pixels = [item["class_images"] for item in examples]

    stacked = torch.stack(instance_pixels + class_pixels)
    pixel_values = stacked.to(memory_format=torch.contiguous_format).float()

    padded = cones.tokenizer.pad(
        {"input_ids": instance_ids + class_ids},
        padding=True,
        return_tensors="pt",
    )

    return {
        "input_ids": padded.input_ids,
        "pixel_values": pixel_values,
    }

# Each loader batch holds 2 dataset examples; collate_fn expands them to
# 2 instance items followed by 2 class items.
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=collate_fn,
    pin_memory=True,
)

In [None]:
# steps             - number of passes through the dataset
# rho               - learning rate
# activation_thresh - threshold used to build the neuron masks
# p, q, r are the masks for down_blocks, mid_block and up_blocks respectively.
p, q, r = cones.cones_compute(
    train_dataloader,
    steps=100,
    rho=2e-5,
    activation_thresh=7e-3,
)

In [None]:
# Drop the pipeline whose attention weights were changed during cones_compute,
# then force a garbage-collection pass and release cached CUDA memory so a
# fresh copy of the model can be loaded.
del cones
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Reload an unmodified copy of the pipeline for inference.
cones = image_pipe(
    'CompVis/stable-diffusion-v1-4',
    device='cuda',
    dtype=torch.float16,
)

In [None]:
# Generate with the concept masks applied; gradients are not needed here.
with torch.no_grad():
  img1 = cones.cones_inference(
      down_dict=p,
      mid_dict=q,
      up_dict=r,
      prompt='a photo of <V*> dog',
  )

In [None]:
# Display the first generated image (cones_inference returns a list of PIL images).
img1[0]