In [None]:
#@markdown #**Check GPU**
# Show the GPU (if any) attached to this runtime; generate_pic falls back to CPU when CUDA is unavailable.
!nvidia-smi

In [None]:
#@markdown #**Install module**
# The CLIP checkout is imported directly by the main code cell (`from CLIP import clip`).
!git clone https://github.com/openai/CLIP
# The taming-transformers checkout is appended to sys.path by the main code cell and installed in editable mode below.
!git clone https://github.com/CompVis/taming-transformers
!pip install ftfy regex tqdm omegaconf pytorch-lightning einops transformers
!pip install -e ./taming-transformers
# NOTE(review): ftfy, regex and tqdm are already covered by the first pip install above; this line looks redundant.
!pip install ftfy regex tqdm
# Installs CLIP as the `clip` package, which predict_noun imports with `import clip`.
!pip install git+https://github.com/openai/CLIP.git

In [None]:
#@markdown #**Curl VQGAN**
# Config + checkpoint saved as vqgan_imagenet_f16_1024.* -- the files AI_painter.paint passes to generate_pic.
!curl -L 'https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1' > vqgan_imagenet_f16_1024.yaml
!curl -L 'https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/files/?p=%2Fckpts%2Flast.ckpt&dl=1' > vqgan_imagenet_f16_1024.ckpt
# Config + checkpoint saved as vqgan_imagenet_f16_16384.* -- not referenced by the code cells visible in this notebook.
!curl -L 'https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1' > vqgan_imagenet_f16_16384.yaml
!curl -L 'https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/files/?p=%2Fckpts%2Flast.ckpt&dl=1' > vqgan_imagenet_f16_16384.ckpt

In [None]:
#@markdown #**One Big Class**
import argparse
import io
import math
from pathlib import Path
import sys

sys.path.append('./taming-transformers')

import matplotlib.pyplot as plt
from IPython import display
from omegaconf import OmegaConf
from PIL import Image
import requests
from taming.models import cond_transformer, vqgan
import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm

from CLIP import clip


def sinc(x):
    """Evaluate sin(pi * x) / (pi * x) elementwise on a tensor.

    Entries of ``x`` that are exactly zero yield 1 instead of the 0/0
    the quotient would otherwise produce.
    """
    scaled = math.pi * x
    quotient = torch.sin(scaled) / scaled
    return torch.where(x != 0, quotient, x.new_ones([]))

def lanczos(x, a):
    """Build a Lanczos kernel from sample positions ``x`` with window ``a``.

    Positions strictly inside (-a, a) get ``sinc(x) * sinc(x / a)``; all
    other positions get zero. The result is divided by its own sum.
    """
    inside_window = torch.logical_and(-a < x, x < a)
    windowed = sinc(x) * sinc(x / a)
    taps = torch.where(inside_window, windowed, x.new_zeros([]))
    total = taps.sum()
    return taps / total

def ramp(ratio, width):
    """Return symmetric sample positions spaced ``ratio`` apart around zero.

    ``ceil(width / ratio + 1)`` non-negative positions (0, ratio, 2*ratio, ...)
    are generated by repeated addition, mirrored onto the negative side, and
    the outermost position on each end is dropped.
    """
    count = math.ceil(width / ratio + 1)

    # Accumulate by repeated addition (not i * ratio) to keep the exact same values.
    steps = []
    position = 0
    for _ in range(count):
        steps.append(position)
        position += ratio

    positive = torch.tensor(steps, dtype=torch.get_default_dtype())
    negative = -positive[1:].flip([0])
    return torch.cat([negative, positive])[1:-1]

def resample(input, size, align_corners=True):
    """Resize a 4-D tensor to ``size``, low-pass filtering first when shrinking.

    Args:
        input: Tensor of shape (n, c, h, w).
        size: Target ``(height, width)`` pair.
        align_corners: Forwarded to ``F.interpolate``.

    Returns:
        Tensor of shape (n, c, size[0], size[1]) produced by bicubic
        interpolation. Along each axis that is being reduced, the input is
        first reflect-padded and convolved with a Lanczos kernel
        (``lanczos(ramp(target / source, 2), 2)``); the padding keeps the
        spatial size unchanged through that convolution.
    """
    n, c, h, w = input.shape
    dh, dw = size

    # reshape (rather than view) so that non-contiguous inputs whose batch and
    # channel dimensions cannot be merged in place are copied instead of raising.
    input = input.reshape([n * c, 1, h, w])

    if dh < h:
        kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)
        pad_h = (kernel_h.shape[0] - 1) // 2
        input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')
        # Kernel laid out along the height axis only.
        input = F.conv2d(input, kernel_h[None, None, :, None])

    if dw < w:
        kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)
        pad_w = (kernel_w.shape[0] - 1) // 2
        input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')
        # Kernel laid out along the width axis only.
        input = F.conv2d(input, kernel_w[None, None, None, :])

    input = input.reshape([n, c, h, w])
    return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)

class ReplaceGrad(torch.autograd.Function):
    """Autograd function that outputs one tensor but sends gradients to another.

    The forward pass returns ``x_forward`` unchanged. The backward pass gives
    ``x_forward`` no gradient and hands the incoming gradient, summed down to
    ``x_backward``'s shape, to ``x_backward``.
    """

    @staticmethod
    def forward(ctx, x_forward, x_backward):
        # Only the shape of the gradient recipient is needed in backward.
        ctx.grad_target_shape = x_backward.shape
        return x_forward

    @staticmethod
    def backward(ctx, grad_in):
        routed = grad_in.sum_to_size(ctx.grad_target_shape)
        return None, routed

replace_grad = ReplaceGrad.apply

class ClampWithGrad(torch.autograd.Function):
    """Clamp whose backward pass keeps gradients that act on out-of-range values.

    Forward is a plain ``input.clamp(min, max)``. In backward, an element's
    gradient is passed through when ``grad * (input - clamped_input) >= 0``
    and zeroed otherwise; in-range elements always pass because that
    product is zero for them.
    """

    @staticmethod
    def forward(ctx, input, min, max):
        ctx.save_for_backward(input)
        ctx.min, ctx.max = min, max
        return input.clamp(min, max)

    @staticmethod
    def backward(ctx, grad_in):
        (saved_input,) = ctx.saved_tensors
        overshoot = saved_input - saved_input.clamp(ctx.min, ctx.max)
        keep = grad_in * overshoot >= 0
        # No gradients for the min / max bounds.
        return grad_in * keep, None, None

clamp_with_grad = ClampWithGrad.apply

def vector_quantize(x, codebook):
    """Snap each vector in ``x`` (last dim) to its nearest row of ``codebook``.

    Nearness is squared Euclidean distance, expanded as
    ``|x|^2 + |c|^2 - 2 x.c``. The returned tensor holds the selected
    codebook rows in the forward pass, while gradients are routed to ``x``
    through ``replace_grad``.
    """
    x_sq = x.pow(2).sum(dim=-1, keepdim=True)
    code_sq = codebook.pow(2).sum(dim=1)
    cross = 2 * x @ codebook.T
    distances = x_sq + code_sq - cross

    nearest = distances.argmin(-1)
    selector = F.one_hot(nearest, codebook.shape[0]).to(distances.dtype)
    x_q = selector @ codebook
    return replace_grad(x_q, x)

class Prompt(nn.Module):
    """Loss term measuring how far input embeddings are from a target embedding.

    ``embed``, ``weight`` and ``stop`` are stored as buffers, so they follow
    the module across devices.
    """

    def __init__(self, embed, weight=1., stop=float('-inf')):
        super().__init__()
        weight_tensor = torch.as_tensor(weight)
        stop_tensor = torch.as_tensor(stop)
        self.register_buffer('embed', embed)
        self.register_buffer('weight', weight_tensor)
        self.register_buffer('stop', stop_tensor)

    def forward(self, input):
        """Return ``|weight|`` times the mean signed distance to ``embed``.

        Both sides are L2-normalised; the distance for each (input, embed)
        pair is ``2 * arcsin(||a - b|| / 2) ** 2``, multiplied by the sign of
        ``weight``. The forward value is the plain mean; the gradient is taken
        from ``maximum(dists, stop)`` instead, via ``replace_grad``.
        """
        input_dirs = F.normalize(input.unsqueeze(1), dim=2)
        embed_dirs = F.normalize(self.embed.unsqueeze(0), dim=2)
        chord = input_dirs.sub(embed_dirs).norm(dim=2)
        signed = chord.div(2).arcsin().pow(2).mul(2) * self.weight.sign()
        floored = torch.maximum(signed, self.stop)
        return self.weight.abs() * replace_grad(signed, floored).mean()

def fetch(url_or_path):
    """Open a URL or a local file for binary reading.

    Args:
        url_or_path: Either an ``http://`` / ``https://`` URL or a local
            filesystem path (anything ``open`` accepts).

    Returns:
        A binary file-like object positioned at the start: an in-memory
        ``io.BytesIO`` holding the downloaded body for URLs, or an open
        file handle for local paths. The caller should close it.

    Raises:
        requests.HTTPError: The server answered with an error status.
        requests.Timeout: The server did not connect or respond in time.
        OSError: The local path could not be opened.
    """
    if str(url_or_path).startswith(('http://', 'https://')):
        # requests has no default timeout; without one an unresponsive
        # server would block the notebook forever. (connect, read) seconds.
        r = requests.get(url_or_path, timeout=(10, 60))
        r.raise_for_status()
        return io.BytesIO(r.content)
    return open(url_or_path, 'rb')

def parse_prompt(prompt):
    """Split a ``'text[:weight[:stop]]'`` prompt into its three parts.

    Only trailing fields that actually parse as floats are treated as
    ``weight`` / ``stop``. This keeps colons that belong to the text itself
    (URL schemes, ``host:port``, ``"a portrait: oil on canvas"``) intact
    instead of failing in ``float()``.

    Args:
        prompt: Prompt text or image URL/path, optionally followed by
            ``:weight`` and ``:stop``.

    Returns:
        ``(text, weight, stop)`` where ``weight`` defaults to ``1.0`` and
        ``stop`` defaults to ``-inf``.
    """
    def _is_number(token):
        try:
            float(token)
        except ValueError:
            return False
        return True

    # Try the longest numeric suffix first: text:weight:stop, then text:weight.
    for n_fields in (2, 1):
        parts = prompt.rsplit(':', n_fields)
        if len(parts) == n_fields + 1 and all(_is_number(tok) for tok in parts[1:]):
            numbers = [float(tok) for tok in parts[1:]]
            weight = numbers[0]
            stop = numbers[1] if n_fields == 2 else float('-inf')
            return parts[0], weight, stop

    return prompt, 1.0, float('-inf')

class MakeCutouts(nn.Module):
    """Sample ``cutn`` random square crops and resize each to ``cut_size``."""

    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size, self.cutn, self.cut_pow = cut_size, cutn, cut_pow

    def forward(self, input):
        """Return the crops stacked along dim 0 and clamped to [0, 1].

        Each crop's side is drawn between ``min(sideX, sideY, cut_size)`` and
        ``min(sideX, sideY)`` using ``rand ** cut_pow``; its position is
        uniform over the placements that fit inside the input.
        """
        height, width = input.shape[2:4]
        largest = min(width, height)
        smallest = min(width, height, self.cut_size)
        target = (self.cut_size, self.cut_size)

        pieces = []
        for _ in range(self.cutn):
            # Draw order (size, x offset, y offset) fixes the RNG stream.
            fraction = torch.rand([])**self.cut_pow
            side = int(fraction * (largest - smallest) + smallest)
            left = torch.randint(0, width - side + 1, ())
            top = torch.randint(0, height - side + 1, ())
            window = input[:, :, top:top + side, left:left + side]
            pieces.append(resample(window, target))

        stacked = torch.cat(pieces, dim=0)
        return clamp_with_grad(stacked, 0, 1)

def load_vqgan_model(config_path, checkpoint_path):
    """Build the VQGAN named in a config file and load its checkpoint.

    Supports ``taming.models.vqgan.VQModel`` directly, and
    ``taming.models.cond_transformer.Net2NetTransformer``, in which case the
    transformer's ``first_stage_model`` is returned. The loaded model is put
    in eval mode with gradients disabled, and the returned model has its
    ``loss`` attribute deleted.

    Raises:
        ValueError: ``config.model.target`` is neither supported type.
    """
    config = OmegaConf.load(config_path)
    target = config.model.target

    is_vqmodel = target == 'taming.models.vqgan.VQModel'
    is_transformer = target == 'taming.models.cond_transformer.Net2NetTransformer'
    if not (is_vqmodel or is_transformer):
        raise ValueError(f'unknown model type: {config.model.target}')

    builder = vqgan.VQModel if is_vqmodel else cond_transformer.Net2NetTransformer
    loaded = builder(**config.model.params)
    loaded.eval().requires_grad_(False)
    loaded.init_from_ckpt(checkpoint_path)

    model = loaded if is_vqmodel else loaded.first_stage_model
    del model.loss
    return model

def resize_image(image, out_size):
    """Resize ``image`` with Lanczos filtering, keeping its aspect ratio.

    The result's pixel area is the smaller of the image's own area and the
    area of ``out_size`` (a ``(width, height)`` pair), so the image is never
    enlarged.
    """
    width, height = image.size[0], image.size[1]
    ratio = width / height
    area = min(width * height, out_size[0] * out_size[1])
    new_width = round((area * ratio)**0.5)
    new_height = round((area / ratio)**0.5)
    return image.resize((new_width, new_height), Image.LANCZOS)

#------------Main Function--------------#

# Generate Picture
def generate_pic(args,first):
  """Optimise a VQGAN latent against CLIP prompts, saving periodic snapshots.

  Args:
      args: Namespace-like object. Attributes read here: vqgan_config,
          vqgan_checkpoint, clip_model, cutn, cut_pow, size, seed,
          init_image, step_size, prompts, image_prompts, noise_prompt_seeds,
          noise_prompt_weights, init_weight, display_freq.
      first: Prefix of the output folder; snapshots are written to
          ``<first>folder/<i>_pic.png`` relative to the current working
          directory. The folder is not created here.

  Returns:
      None. Results are the saved/displayed snapshot images.
  """
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
  print('Using device:', device)

  # Both networks are frozen; only the latent z below is optimised.
  model = load_vqgan_model(args.vqgan_config, args.vqgan_checkpoint).to(device)
  perceptor = clip.load(args.clip_model, jit=False)[0].eval().requires_grad_(False).to(device)


  cut_size = perceptor.visual.input_resolution
  e_dim = model.quantize.e_dim
  # presumably the pixels-per-latent-token scale factor of the decoder -- TODO confirm
  f = 2**(model.decoder.num_resolutions - 1)
  make_cutouts = MakeCutouts(cut_size, args.cutn, cut_pow=args.cut_pow)
  n_toks = model.quantize.n_e
  # Latent grid size, and the requested size rounded down to a multiple of f.
  toksX, toksY = args.size[0] // f, args.size[1] // f
  sideX, sideY = toksX * f, toksY * f
  # Per-dimension min/max over all codebook entries; used in train() to clamp z.
  z_min = model.quantize.embedding.weight.min(dim=0).values[None, :, None, None]
  z_max = model.quantize.embedding.weight.max(dim=0).values[None, :, None, None]

  if args.seed is not None:
      torch.manual_seed(args.seed)

  # Initial latent: encode the init image if given, otherwise pick random codebook entries.
  if args.init_image:
      pil_image = Image.open(fetch(args.init_image)).convert('RGB')
      pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS)
      # Map pixel values from [0, 1] to [-1, 1] before encoding.
      z, *_ = model.encode(TF.to_tensor(pil_image).to(device).unsqueeze(0) * 2 - 1)
  else:
      one_hot = F.one_hot(torch.randint(n_toks, [toksY * toksX], device=device), n_toks).float()
      z = one_hot @ model.quantize.embedding.weight
      z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)
  # Copy of the starting latent, used by the init_weight MSE term in ascend_txt().
  z_orig = z.clone()
  z.requires_grad_(True)
  opt = optim.Adam([z], lr=args.step_size)

  # Channel statistics applied to images before they are passed to perceptor.encode_image.
  normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                  std=[0.26862954, 0.26130258, 0.27577711])

  # One Prompt loss module per text prompt, image prompt and noise prompt.
  pMs = []

  for prompt in args.prompts:
      txt, weight, stop = parse_prompt(prompt)
      embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()
      pMs.append(Prompt(embed, weight, stop).to(device))

  for prompt in args.image_prompts:
      path, weight, stop = parse_prompt(prompt)
      img = resize_image(Image.open(fetch(path)).convert('RGB'), (sideX, sideY))
      batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
      embed = perceptor.encode_image(normalize(batch)).float()
      pMs.append(Prompt(embed, weight, stop).to(device))

  for seed, weight in zip(args.noise_prompt_seeds, args.noise_prompt_weights):
      # Dedicated generator so noise prompts do not consume the global RNG stream.
      gen = torch.Generator().manual_seed(seed)
      embed = torch.empty([1, perceptor.visual.output_dim]).normal_(generator=gen)
      pMs.append(Prompt(embed, weight).to(device))

  def synth(z):
      """Quantise z against the codebook, decode it, and map the output from [-1, 1] to [0, 1]."""
      z_q = vector_quantize(z.movedim(1, 3), model.quantize.embedding.weight).movedim(3, 1)
      return clamp_with_grad(model.decode(z_q).add(1).div(2), 0, 1)

  @torch.no_grad()
  def checkin(i, losses):
      """Log the losses for step i, then save and display the current image."""
      losses_str = ', '.join(f'{loss.item():g}' for loss in losses)
      tqdm.write(f'i: {i}, loss: {sum(losses).item():g}, losses: {losses_str}')
      out = synth(z)
      # Save the snapshot; the path is relative to the current working directory.
      pic_name = str(first) + 'folder/' + str(i) + '_' + 'pic.png'
      TF.to_pil_image(out[0].cpu()).save(pic_name)
      print(pic_name)
      display.display(display.Image(pic_name))

  def ascend_txt():
      """Return the list of loss terms for the current latent."""
      out = synth(z)
      # CLIP embeddings of random cutouts of the current image.
      iii = perceptor.encode_image(normalize(make_cutouts(out))).float()

      result = []

      # Optional penalty for drifting away from the starting latent.
      if args.init_weight:
          result.append(F.mse_loss(z, z_orig) * args.init_weight / 2)

      for prompt in pMs:
          result.append(prompt(iii))

      return result

  def train(i):
      """Run one optimisation step; snapshot every display_freq steps."""
      opt.zero_grad()
      lossAll = ascend_txt()
      if i % args.display_freq == 0:
          checkin(i, lossAll)
      loss = sum(lossAll)
      loss.backward()
      opt.step()
      # Keep every latent dimension within the range spanned by the codebook.
      with torch.no_grad():
          z.copy_(z.maximum(z_min).minimum(z_max))


  i = 0
  # Number of optimisation steps to run (Colab form field).
  stop_iteration =  510#@param {type:"number"}
  try:
      with tqdm() as pbar:
          while True:
              if i==stop_iteration :
                break
              train(i)
              i += 1
              pbar.update()
  except KeyboardInterrupt:
      # Interrupting the cell ends the optimisation early without raising.
      pass

# Predict Noun
def predict_noun(image):
  """Print and return CLIP's top-5 CIFAR-100 class guesses for an image.

  Loads CLIP ``ViT-B/32`` and the CIFAR-100 test split (downloaded to
  ``~/.cache`` if needed) on every call, then scores the image against the
  prompt ``"a picture of a <class>"`` for each CIFAR-100 class name.

  Args:
      image: Image accepted by CLIP's ``preprocess`` transform
          (AI_painter.paint passes a PIL image).

  Returns:
      List of ``(label, probability)`` tuples for the five most similar
      class names, best first. Probabilities are softmax outputs in [0, 1].
  """
  import os
  import clip
  import torch
  from torchvision.datasets import CIFAR100

  # Load the model
  device = "cuda" if torch.cuda.is_available() else "cpu"
  model, preprocess = clip.load('ViT-B/32', device)

  # Download the dataset (only its class names are used)
  cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)

  # Prepare the inputs
  image_input = preprocess(image).unsqueeze(0).to(device)
  text_inputs = torch.cat([clip.tokenize(f"a picture of a {c}") for c in cifar100.classes]).to(device)

  with torch.no_grad():
      # Calculate features
      image_features = model.encode_image(image_input)
      text_features = model.encode_text(text_inputs)

      # Cosine similarity between the image and every class prompt, as a distribution
      image_features /= image_features.norm(dim=-1, keepdim=True)
      text_features /= text_features.norm(dim=-1, keepdim=True)
      similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
      values, indices = similarity[0].topk(5)

  predictions = [
      (cifar100.classes[index], value)
      for index, value in zip(indices.tolist(), values.tolist())
  ]

  # Print the result
  print("\nTop predictions:\n")
  for label, value in predictions:
      print(f"{label:>16s}: {100 * value:.2f}%")

  # Returned so callers can use the labels instead of only reading the printout.
  return predictions

# Generate video
def generate_video(freq,first):
  """Encode the PNG snapshots of one run into an H.264 video with ffmpeg.

  Reads ``/content/<first>folder/<i*freq>_pic.png`` for ``i = 0 .. N-1``,
  where ``N`` is the number of ``.png`` files in that folder, pipes them to
  ffmpeg, offers ``video.mp4`` for download through Colab, and finally moves
  it into the run's folder.

  Args:
      freq: Step interval between saved snapshots (the ``display_freq``
          used by generate_pic).
      first: Prefix of the run's folder name.

  Raises:
      FileNotFoundError: An expected frame file is missing.
      RuntimeError: ffmpeg exited with a non-zero status.
  """
  import os
  import numpy as np
  from subprocess import Popen, PIPE

  folder = '/content/'+str(first)+'folder'

  # Number of PNG frames generated for this run
  n_frames = sum(1 for filename in os.listdir(folder) if filename.endswith(".png"))

  init_frame = 0 #This is the frame where the video will start
  last_frame = n_frames #Index after the last frame to include; must not exceed the number of saved frames.

  min_fps = 10
  max_fps = 30

  total_frames = last_frame-init_frame

  #Desired video time in seconds
  video_length = 20

  tqdm.write('Generating video...')
  frame_paths = [folder + f'/{i*freq}_pic.png' for i in range(init_frame,last_frame)]

  # Fail before starting ffmpeg if a frame is missing (e.g. freq does not
  # match the interval the frames were saved with).
  for path in frame_paths:
      if not os.path.exists(path):
          raise FileNotFoundError(path)

  # Frame rate that fits the clip into video_length seconds, kept within [min_fps, max_fps].
  fps = float(np.clip(total_frames/video_length,min_fps,max_fps))

  p = Popen(['ffmpeg', '-y', '-f', 'image2pipe', '-vcodec', 'png', '-r', str(fps), '-i', '-', '-vcodec', 'libx264', '-r', str(fps), '-pix_fmt', 'yuv420p', '-crf', '17', '-preset', 'veryslow', 'video.mp4'], stdin=PIPE)
  try:
      for path in tqdm(frame_paths):
          # Open one frame at a time so file handles are released promptly.
          with Image.open(path) as im:
              im.save(p.stdin, 'PNG')
  finally:
      # Always signal end-of-input so ffmpeg does not wait forever.
      p.stdin.close()

  print("The video is now being compressed, wait...")
  if p.wait() != 0:
      raise RuntimeError(f'ffmpeg exited with status {p.returncode}')
  print("The video is ready")

  # Download
  from google.colab import files
  files.download("video.mp4")

  # Rename To move To folder
  os.rename("/content/video.mp4", '/content/'+str(first)+'folder/'+'video.mp4')

# Class AI_painter
class AI_painter:
  """Runs the full pipeline: preview reference paintings, generate images, make videos."""

  def paint(self):
    """Generate one image sequence and one video per reference painting.

    For every painting URL: show it, print CLIP's CIFAR-100 guesses, then
    run generate_pic with the text prompt plus that painting as image
    prompt, writing into ``/content/<index>folder``, and finally build the
    video with generate_video.
    """
    import os

    pic1 = 'https://raw.githubusercontent.com/ChiuYenHua/style-clip-draw/master/paintings/01picasso1-superJumbo.jpg'  #@param {type:"string"}
    pic2 = 'https://raw.githubusercontent.com/ChiuYenHua/style-clip-draw/master/paintings/2021_HGK_20415_0001_000(jean-michel_basquiat_warrior090115).jpg'#@param {type:"string"}
    pic3 = 'https://raw.githubusercontent.com/ChiuYenHua/style-clip-draw/master/paintings/Claude-Monet-Waterlilies-and-Japanese-Bridge-1899_HIGH-RES.jpg'#@param {type:"string"}
    pic4 = 'https://raw.githubusercontent.com/ChiuYenHua/style-clip-draw/master/paintings/GettyImages-1151386026.jpg'#@param {type:"string"}
    pic5 = 'https://raw.githubusercontent.com/ChiuYenHua/style-clip-draw/master/paintings/The-Starry-Night.jpg'#@param {type:"string"}
    pic6 = 'https://raw.githubusercontent.com/ChiuYenHua/style-clip-draw/master/paintings/Botticelli_-_Portrait_of_a_young_man_holding_a_medallion.jpg'#@param {type:"string"}
    pic7 = 'https://raw.githubusercontent.com/ChiuYenHua/style-clip-draw/master/paintings/Peter_Paul_Rubens_-_The_Feast_of_Venus_-_Google_Art_Project.jpg'#@param {type:"string"}
    pic8 = 'https://raw.githubusercontent.com/ChiuYenHua/style-clip-draw/master/paintings/pablo-picasso-painting.jpg'#@param {type:"string"}
    pic9 = 'https://raw.githubusercontent.com/ChiuYenHua/style-clip-draw/master/paintings/10067%20Lot%208%20-%20Claude%20Monet%2C%20Meules.jpg'#@param {type:"string"}
    picture_list = [pic1,pic2,pic3,pic4,pic5,pic6,pic7,pic8,pic9]

    # Show each reference painting and print CLIP's class guesses for it
    for path in picture_list:
      # convert() loads the pixel data, so the source can be closed right after.
      with fetch(path) as source:
        img = resize_image(Image.open(source).convert('RGB'), (500, 500))

      plt.imshow(img)
      plt.show()

      predict_noun(img)

    picture_size = 600 #@param {type:"number"}
    display_frequency =  20#@param {type:"number"}
    sentenceToGenerate = 'pug in ghibli studios' #@param {type:"string"}

    for first_name, pic in enumerate(picture_list):
        # setting
        args = argparse.Namespace(
            prompts=[sentenceToGenerate],
            image_prompts=[pic],
            noise_prompt_seeds=[],
            noise_prompt_weights=[],
            size=[picture_size, picture_size],
            init_image=None,
            init_weight=0.,
            clip_model='ViT-B/32',
            vqgan_config='vqgan_imagenet_f16_1024.yaml',
            vqgan_checkpoint='vqgan_imagenet_f16_1024.ckpt',
            step_size=0.05,
            cutn=64,
            cut_pow=1.,
            display_freq=display_frequency,
            seed=0,
        )

        # create folder; tolerate an existing one so the cell can be re-run
        folder = '/content/'+str(first_name)+'folder'
        os.makedirs(folder, exist_ok=True)

        # generate_video counts every .png in the folder, so frames left over
        # from an earlier run must not survive into this one.
        for leftover in os.listdir(folder):
            if leftover.endswith('.png'):
                os.remove(os.path.join(folder, leftover))

        # main_function
        generate_pic(args,first_name)

        # generate video
        generate_video(display_frequency,first_name)



In [None]:
# Run the whole pipeline; AI_painter and its helpers are defined in the previous code cell.
test = AI_painter()
test.paint()