<a href="https://colab.research.google.com/github/RavinduPabasara/Dall-3-Tests/blob/main/dall_3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Make sure a GPU runtime is enabled via Edit -> Notebook settings.
# nvidia-smi lists the GPU(s) visible to this runtime.
!nvidia-smi

In [None]:
# Install dependencies
# Clone the three source repositories into the working directory, then
# pip-install each checkout in editable mode (-e).
# NOTE(review): the clones follow each repository's default branch with no
# pinned commit, so a later run may pick up different code — consider pinning.

!git clone https://github.com/openai/CLIP
!git clone https://github.com/Jack000/DALLE-pytorch
!git clone https://github.com/Jack000/guided-diffusion
!pip install -e ./CLIP
!pip install -e ./DALLE-pytorch
!pip install -e ./guided-diffusion

In [None]:
# Download dalle files

!curl -OL --http1.1 'https://dall-3.com/models/dalle/bpe.model'
!curl -OL --http1.1 'https://dall-3.com/models/dalle/dalle-latest.pt'

# Download vqgan files
!curl -L -o vqgan.yaml --http1.1 'https://heibox.uni-heidelberg.de/f/b24d14998a8d4f19a34f/?dl=1'
!curl -L -o vqgan.pt --http1.1 'https://heibox.uni-heidelberg.de/f/34a747d5765840b5a99d/?dl=1'

# Download diffusion model
!curl -OL --http1.1 'https://dall-3.com/models/guided-diffusion/256/model-latest.pt'

In [None]:
# imports

# torch

import torch

from einops import repeat

# vision imports

from PIL import Image
from torchvision.utils import make_grid, save_image
from torchvision.transforms import functional as TF

import sys
sys.path.append('./CLIP')
sys.path.append('./DALLE-pytorch')
sys.path.append('./guided-diffusion')

# dalle related classes and utils

from dalle_pytorch import DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE, DALLE
from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, YttmTokenizer, ChineseTokenizer

from einops import rearrange
import math

# diffusion
import gc
import clip
from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults
from torch import nn
from torch.nn import functional as F
from torchvision import transforms

import numpy as np

In [None]:
# load models

# BPE tokenizer built from the downloaded vocabulary file
tokenizer = YttmTokenizer('bpe.model')

# The DALLE checkpoint is read on the CPU and split into its three parts
load_obj = torch.load('dalle-latest.pt', map_location='cpu')
dalle_params = load_obj.pop('hparams')
vae_params = load_obj.pop('vae_params')
weights = load_obj.pop('weights')

# Drop any stored 'vae' entry; the VAE instance is passed explicitly below
dalle_params.pop('vae', None) # cleanup later
vae = VQGanVAE('vqgan.pt', 'vqgan.yaml')

# First CUDA device when available, otherwise the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Build DALLE from the stored hyperparameters, move it to the device and load its weights
dalle = DALLE(vae = vae, **dalle_params).to(device)
dalle.load_state_dict(weights)


# DALLE-pytorch generation
Generate images using DALLE. This creates a number of low-quality images, one of which we will then refine using CLIP-guided diffusion.

In [None]:
# Prompt (lower-cased before tokenisation)
text = 'a girl smiling at the camera'.lower()

# Sampling settings
top_p = 0.85
temperature = 0.9
batch_size = 1
num_batches = 16

# Tokenise the prompt once and repeat it along the batch dimension
text_tokens = repeat(
  tokenizer.tokenize([text], dalle.text_seq_len).to(device),
  '() n -> b n',
  b = batch_size,
)

outputs = []       # image batch returned by each generate_images call
image_tokens = []  # token sequence of every generated image, in display order

for i in range(num_batches):
  out, tok = dalle.generate_images(
    text_tokens,
    temperature=temperature,
    top_p_thresh = top_p,
    return_tokens = True,
  )
  outputs.append(out)
  for j in range(batch_size):
    pimg = TF.to_pil_image(out[j])

    # The printed number is the index this image will have in image_tokens
    print(len(image_tokens))
    display(pimg)
    image_tokens.append(tok[j])

# CLIP guided diffusion

In [None]:
# setup clip guided diffusion

def fetch(url_or_path):
  """Open an image source for binary reading.

  Args:
    url_or_path: an http(s) URL or a local filesystem path.

  Returns:
    A binary file-like object positioned at offset 0: an in-memory buffer
    holding the downloaded body for URLs, or the opened file for local paths.
    The caller is responsible for closing it.

  Raises:
    requests.HTTPError: if the server answers with an error status.
    OSError: if the local path cannot be opened.
  """
  if str(url_or_path).startswith(('http://', 'https://')):
    # Imported locally: the notebook's import cell provides neither module,
    # so the URL branch previously failed with NameError. Keeping the imports
    # in this branch also means local paths work without `requests` installed.
    import io
    import requests

    response = requests.get(url_or_path)
    response.raise_for_status()
    return io.BytesIO(response.content)
  return open(url_or_path, 'rb')

def parse_prompt(prompt):
  """Split a prompt of the form 'content[:weight]' into (content, weight).

  The weight is whatever follows the last ':' and is converted with float();
  it defaults to 1.0 when no ':' separator is present. For http(s) URLs the
  ':' that belongs to the scheme is kept as part of the content.

  Raises:
    ValueError: if the text after the separator is not a valid float.
  """
  if prompt.startswith(('http://', 'https://')):
    # Split off at most two fields from the right, then glue the scheme back
    # onto the rest of the URL.
    pieces = prompt.rsplit(':', 2)
    pieces = [pieces[0] + ':' + pieces[1]] + pieces[2:]
  else:
    pieces = prompt.rsplit(':', 1)

  # No explicit weight given -> default weight of 1
  if len(pieces) < 2:
    pieces.append('1')

  content, weight = pieces[0], pieces[1]
  return content, float(weight)

class MakeCutouts(nn.Module):
  """Take random square crops of an image batch and resize each to a fixed size.

  Args:
    cut_size: side length (pixels) every crop is resized to.
    cutn: number of random crops taken per forward call.
    cut_pow: exponent applied to the uniform random draw that picks each crop
      size; 1.0 leaves the draw uniform.
  """

  def __init__(self, cut_size, cutn, cut_pow=1.):
    super().__init__()
    # The leftover debug print of cut_size was removed; it fired on every instantiation.
    self.cut_size = cut_size
    self.cutn = cutn
    self.cut_pow = cut_pow

  def forward(self, input):
    """Return `cutn` random crops of `input`, concatenated along dim 0.

    Args:
      input: 4-D tensor indexed as (batch, channels, height, width).

    Returns:
      Tensor of shape (cutn * batch, channels, cut_size, cut_size); the crops
      of one draw (all batch items) are contiguous.
    """
    sideY, sideX = input.shape[2:4]
    # Largest possible crop is the shorter image side; the smallest is
    # cut_size, capped by the image itself when the image is smaller.
    max_size = min(sideX, sideY)
    min_size = min(sideX, sideY, self.cut_size)
    cutouts = []
    for _ in range(self.cutn):
      # Random side length in [min_size, max_size], skewed by cut_pow
      size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
      # Random top-left corner such that the crop stays inside the image
      offsetx = torch.randint(0, sideX - size + 1, ())
      offsety = torch.randint(0, sideY - size + 1, ())
      cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
      cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
    return torch.cat(cutouts)


def spherical_dist_loss(x, y):
  """Angular distance loss between two sets of vectors (broadcast over leading dims).

  Both inputs are L2-normalised along the last dimension. With theta the angle
  between the unit vectors, their chord length is 2*sin(theta/2), so the
  returned value 2*arcsin(chord/2)**2 equals theta**2 / 2.
  """
  x_unit = F.normalize(x, dim=-1)
  y_unit = F.normalize(y, dim=-1)
  half_chord = (x_unit - y_unit).norm(dim=-1).div(2)
  half_angle = half_chord.arcsin()
  return half_angle.pow(2).mul(2)


def tv_loss(input):
  """L2 total variation loss, as in Mahendran et al.

  The last two dimensions are extended by one replicated row/column so every
  original position has a right and a lower neighbour. The squared horizontal
  and vertical forward differences are summed and averaged over dims 1-3,
  giving one value per batch element.
  """
  padded = F.pad(input, (0, 1, 0, 1), 'replicate')
  anchor = padded[..., :-1, :-1]
  right_delta = padded[..., :-1, 1:] - anchor
  down_delta = padded[..., 1:, :-1] - anchor
  squared = right_delta**2 + down_delta**2
  return squared.mean([1, 2, 3])


def range_loss(input):
  """Penalise values outside [-1, 1].

  Returns the mean (over dims 1-3) of the squared distance between each value
  and its clamped version, i.e. one value per batch element; values already
  inside the range contribute zero.
  """
  overshoot = input - input.clamp(-1, 1)
  return overshoot.pow(2).mean([1, 2, 3])

# Hyperparameters for the diffusion model; they are merged over the library
# defaults below and passed to create_model_and_diffusion.
# NOTE(review): these presumably have to match the architecture stored in
# model-latest.pt — confirm before changing anything except timestep_respacing.
model_params = {
  'attention_resolutions': '32, 16, 8',
  'class_cond': False,
  'diffusion_steps': 1000,
  'rescale_timesteps': True,
  'timestep_respacing': '1000',  # Modify this value to decrease the number of timesteps
  'image_size': 256,
  'learn_sigma': True,
  'noise_schedule': 'linear',
  'num_channels': 256,
  'num_head_channels': 64,
  'num_res_blocks': 2,
  'resblock_updown': True,
  'use_fp16': True,
  'use_scale_shift_norm': True,
  'emb_condition': True
}

# Start from the library defaults and override them with the values above.
model_config = model_and_diffusion_defaults()
model_config.update(model_params)

# Build the model and diffusion process, load the checkpoint on the CPU, then
# freeze the model, switch it to eval mode and move it to `device`.
model, diffusion = create_model_and_diffusion(**model_config)
model.load_state_dict(torch.load('model-latest.pt', map_location='cpu'))
model.requires_grad_(False).eval().to(device)

# Re-enable gradients for parameters whose name contains 'qkv', 'norm' or 'proj'.
# NOTE(review): presumably required so gradients can flow back to the input
# through checkpointed blocks when cond_fn differentiates the model — confirm.
for name, param in model.named_parameters():
  if 'qkv' in name or 'norm' in name or 'proj' in name:
    param.requires_grad_()

# Convert to half precision only after the weights have been loaded.
if model_config['use_fp16']:
  model.convert_to_fp16()

def set_requires_grad(model, value):
  """Set `requires_grad` to `value` on every parameter of `model` (in place)."""
  for parameter in model.parameters():
    parameter.requires_grad = value

# Load CLIP ViT-B/16 with jit=False, put it in eval mode, freeze its
# parameters and move it to `device`.
clip_model, clip_preprocess = clip.load('ViT-B/16', jit=False)
clip_model.eval().requires_grad_(False).to(device)
# Input resolution of CLIP's visual encoder; do_run uses it as the cutout size.
clip_size = clip_model.visual.input_resolution

# Per-channel normalisation applied to image batches before clip_model.encode_image.
# NOTE(review): the constants look like CLIP's published preprocessing mean/std — confirm.
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])

# Edit settings and run

In [None]:
# Settings read by do_run() as module-level globals.
image_index = 0             # choose the image you want to refine
clip_guidance = False       # clip guidance is not actually necessary. Turning on clip guidance helps improve image quality but will be much slower
diffusion_batch_size = 1    # number of samples refined in parallel per batch
diffusion_num_batches = 1   # number of times the sampling loop is run
seed = 0                    # passed to torch.manual_seed; set to None to skip seeding
cutn = 16                   # number of random cutouts fed to CLIP per evaluation
prompts = [text]            # you can specify a different prompt for the diffusion process
image_prompts = []          # image paths or http(s) URLs, each optionally followed by ':weight'
clip_guidance_scale = 1000  # multiplier on the CLIP distance loss
tv_scale = 0                # multiplier on the total variation loss
range_scale = 0             # multiplier on the out-of-range loss
stop_at = 1000              # sampling stops once the step index exceeds this value


def do_run():
  """Refine one DALLE sample with the diffusion model, optionally guided by CLIP.

  Reads its configuration from module-level globals (image_index,
  clip_guidance, diffusion_batch_size, diffusion_num_batches, seed, cutn,
  prompts, image_prompts, clip_guidance_scale, tv_scale, range_scale, stop_at)
  and the models set up earlier (model, diffusion, clip_model, vae). The token
  sequence image_tokens[image_index] is turned into VQGAN codebook embeddings,
  which condition the diffusion model through the 'image_embeds' model kwarg.
  Intermediate and final predictions are shown with display().

  Raises:
    RuntimeError: if the prompt weights sum to (almost) zero.
  """
  if seed is not None:
    torch.manual_seed(seed)

  make_cutouts = MakeCutouts(clip_size, cutn)
  side_x = side_y = model_config['image_size']

  # CLIP targets: one embedding per text prompt, cutn embeddings per image prompt
  target_embeds, weights = [], []

  for prompt in prompts:
    txt, weight = parse_prompt(prompt)
    # Encode only the text part. Tokenising the raw prompt would feed an
    # optional ':weight' suffix to CLIP as if it were part of the description.
    target_embeds.append(clip_model.encode_text(clip.tokenize(txt).to(device)).float())
    weights.append(weight)

  for prompt in image_prompts:
    path, weight = parse_prompt(prompt)
    img = Image.open(fetch(path)).convert('RGB')
    img = TF.resize(img, min(side_x, side_y, *img.size), transforms.InterpolationMode.LANCZOS)
    batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
    embed = clip_model.encode_image(normalize(batch)).float()
    target_embeds.append(embed)
    # Spread the prompt weight evenly over its cutn cutout embeddings
    weights.extend([weight / cutn] * cutn)

  target_embeds = torch.cat(target_embeds)
  weights = torch.tensor(weights, device=device)
  if weights.sum().abs() < 1e-3:
    raise RuntimeError('The weights must not sum to 0.')
  weights /= weights.sum().abs()

  # Look up the codebook vector of every image token (one-hot @ codebook) and
  # reshape the flat sequence into a square feature map.
  img_seq = image_tokens[image_index].unsqueeze(0)

  _, n = img_seq.shape
  one_hot_indices = F.one_hot(img_seq, num_classes = vae.num_tokens).float()
  embeds = one_hot_indices @ vae.model.quantize.embed.weight

  embeds = rearrange(embeds, 'b (h w) c -> b c h w', h = int(math.sqrt(n)))

  embeds = embeds.repeat(diffusion_batch_size, 1, 1, 1)

  # Current timestep, shared with cond_fn through the closure and updated in
  # the sampling loop below.
  cur_t = None

  def cond_fn(x, t, image_embeds=None):
    """Return the negated gradient of the guidance loss with respect to x."""
    with torch.enable_grad():

      x = x.detach().requires_grad_()
      n = x.shape[0]

      my_t = torch.ones([n], device=device, dtype=torch.long) * cur_t

      out = diffusion.p_mean_variance(model, x, my_t, clip_denoised=False, model_kwargs={'image_embeds': image_embeds})
      # Blend the predicted clean image with the noisy input before scoring it
      fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t]
      x_in = out['pred_xstart'] * fac + x * (1 - fac)
      # Map from [-1, 1] to [0, 1] before cutting and CLIP normalisation
      clip_in = normalize(make_cutouts(x_in.add(1).div(2)))
      clip_embeds = clip_model.encode_image(clip_in).float()
      dists = spherical_dist_loss(clip_embeds.unsqueeze(1), target_embeds.unsqueeze(0))
      dists = dists.view([cutn, n, -1])
      losses = dists.mul(weights).sum(2).mean(0)
      tv_losses = tv_loss(x_in)
      range_losses = range_loss(out['pred_xstart'])
      loss = losses.sum() * clip_guidance_scale + tv_losses.sum() * tv_scale + range_losses.sum() * range_scale
      return -torch.autograd.grad(loss, x)[0]

  if model_config['timestep_respacing'].startswith('ddim'):
    sample_fn = diffusion.ddim_sample_loop_progressive
  else:
    sample_fn = diffusion.p_sample_loop_progressive

  for i in range(diffusion_num_batches):
    cur_t = diffusion.num_timesteps - 1

    samples = sample_fn(
      model,
      (diffusion_batch_size, 3, side_y, side_x),
      clip_denoised=False,
      model_kwargs={'image_embeds': embeds},
      cond_fn=cond_fn if clip_guidance else None,
      progress=True,
    )

    for j, sample in enumerate(samples):
      cur_t -= 1

      # Show progress every 100 steps, at the final step and when stopping early
      if j % 100 == 0 or cur_t == -1 or j == 999 or j > stop_at:
        for image in sample['pred_xstart']:
          pimg = TF.to_pil_image(image.add(1).div(2).clamp(0, 1))
          display(pimg)
      if j > stop_at:
        break

# Run a garbage-collection pass, then start the refinement with the settings above.
gc.collect()
do_run()