# CLIP+DIP for pixel art
Original Author: Daniel Russell ([@danielrussruss](https://twitter.com/danielrussruss))

This notebook uses [OpenAI's CLIP](https://github.com/openai/CLIP) model to optimize the weights of the [Deep Image Prior](https://github.com/DmitryUlyanov/deep-image-prior) skip network to output an image that matches a text prompt. 

This is a somewhat minimal implementation intended for research and artistic exploration. Do not use this project or a derivative for commercial work.

This notebook would not be possible without the foundational work established by [Ryan Murdock](https://twitter.com/advadnoun) and [Katherine Crowson](https://twitter.com/rivershavewings).

Small modifications, general tidying, a neat GUI, SLIP support (as an alternative to CLIP), and initial parameters tuned for pixel art generation were added by Philipuss#4066.

SLIP implementation stolen from [this notebook](https://colab.research.google.com/drive/1bItz4NdhAPHg5-u87KcH-MmJZjK-XqHN).

## Citations and Licenses

```
@misc{radford2021learning,
      title={Learning Transferable Visual Models From Natural Language Supervision}, 
      author={Alec Radford and Jong Wook Kim and Chris Hallacy and Aditya Ramesh and Gabriel Goh and Sandhini Agarwal and Girish Sastry and Amanda Askell and Pamela Mishkin and Jack Clark and Gretchen Krueger and Ilya Sutskever},
      year={2021},
      eprint={2103.00020},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

@article{UlyanovVL17,
    author    = {Ulyanov, Dmitry and Vedaldi, Andrea and Lempitsky, Victor},
    title     = {Deep Image Prior},
    journal   = {arXiv:1711.10925},
    year      = {2017}
}

@article{wright2021ranger21,
      title={Ranger21: a synergistic deep learning optimizer}, 
      author={Wright, Less and Demeure, Nestor},
      year={2021},
      journal={arXiv preprint arXiv:2106.13731},
}
```

### CLIP
```
MIT License
Copyright (c) 2021 OpenAI

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
```

### Deep Image Prior
```
Copyright 2018 Dmitry Ulyanov
"Please contact me if you want to use this software in a commercial application." - Dmitry Ulyanov

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
```

# Install and import libraries

In [None]:
!git clone https://github.com/DmitryUlyanov/deep-image-prior
!pip install kornia einops git+https://github.com/openai/clip madgrad

!git clone https://github.com/lessw2020/Ranger21.git
%cd Ranger21
!python -m pip install -e .
%cd ../

!git clone https://github.com/facebookresearch/SLIP.git
!pip install timm

In [None]:
import sys
sys.path.append('./deep-image-prior')
sys.path.append('./SLIP')
from models import *
from utils.sr_utils import *
import clip
import time
import numpy as np
import torch
import torch.optim
from IPython import display
import cv2
from torch.nn import functional as F
import torchvision.transforms.functional as TF
import torchvision.transforms as T
import kornia.augmentation as K
from einops import rearrange
from madgrad import MADGRAD
from Ranger21.ranger21.ranger21 import Ranger21
import random
import math
from tqdm.notebook import tqdm

import os
import torchvision  # torchvision.utils.make_grid is used when display_augs is enabled
from google.colab.patches import cv2_imshow
from IPython.display import clear_output 
from SLIP.models import SLIP_VITB16, SLIP, SLIP_VITL16

device = torch.device('cuda')

In [None]:
# View GPU details:
!nvidia-smi

# Load and Configure CLIP
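
Besides loading the model, the cell below defines a `resample` helper that low-pass filters an image with a Lanczos kernel before shrinking it with bicubic interpolation, which should reduce aliasing in the cutouts. The following is a minimal pure-Python sketch of how that kernel is built, mirroring the `sinc`, `lanczos` and `ramp` functions. The name `lanczos_kernel` is introduced here for illustration only and does not appear in the notebook.

```python
import math


def sinc(x):
    # Normalized sinc, with the removable singularity at 0 filled in.
    return 1.0 if x == 0 else math.sin(math.pi * x) / (math.pi * x)


def lanczos_kernel(ratio, a=2):
    """Normalized Lanczos-a taps sampled every `ratio` units.

    `ratio` is destination size / source size and is expected to be < 1.
    """
    # Sample positions 0, ratio, 2*ratio, ... as in the notebook's `ramp`.
    n = math.ceil(a / ratio + 1)
    positive = [i * ratio for i in range(n)]
    # Mirror to negative positions, then drop the two outermost samples.
    xs = ([-x for x in reversed(positive[1:])] + positive)[1:-1]
    # Windowed sinc, zero outside (-a, a).
    taps = [sinc(x) * sinc(x / a) if -a < x < a else 0.0 for x in xs]
    total = sum(taps)
    return [t / total for t in taps]


# Halving an image (ratio 0.5) with a Lanczos-2 window.
kernel = lanczos_kernel(0.5)
print(len(kernel), [round(t, 4) for t in kernel])
```

The taps sum to 1, so filtering preserves overall brightness, and the small negative outer taps are what give Lanczos its mild sharpening compared with a plain box blur.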

In [None]:
# CLIP works best. SLIP gives odd results, possibly because the settings are not tuned for it.
model_type = 'CLIP'


if model_type == 'CLIP':
  clip_model = clip.load('ViT-B/16', device=device)[0]
  clip_model = clip_model.eval().requires_grad_(False)
  clip_size = clip_model.visual.input_resolution
elif model_type == 'SLIP':
  model_path = "/content/"
  clip_model = SLIP_VITB16(ssl_mlp_dim=4096, ssl_emb_dim=256)
  if not os.path.exists(f'{model_path}/slip_base_100ep.pt'):
    !wget https://dl.fbaipublicfiles.com/slip/slip_base_100ep.pt -P {model_path}
  sd = torch.load(f'{model_path}/slip_base_100ep.pt')
  real_sd = {}
  for k, v in sd['state_dict'].items():
    real_sd['.'.join(k.split('.')[1:])] = v
  del sd
  clip_model.load_state_dict(real_sd)
  clip_model.requires_grad_(False).eval().to(device)
  clip_size = 224

clip_normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])


def sinc(x):
    return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))


def lanczos(x, a):
    cond = torch.logical_and(-a < x, x < a)
    out = torch.where(cond, sinc(x) * sinc(x/a), x.new_zeros([]))
    return out / out.sum()


def ramp(ratio, width):
    n = math.ceil(width / ratio + 1)
    out = torch.empty([n])
    cur = 0
    for i in range(out.shape[0]):
        out[i] = cur
        cur += ratio
    return torch.cat([-out[1:].flip([0]), out])[1:-1]


def resample(input, size, align_corners=True):
    n, c, h, w = input.shape
    dh, dw = size

    input = input.view([n * c, 1, h, w])

    if dh < h:
        kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)
        pad_h = (kernel_h.shape[0] - 1) // 2
        input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')
        input = F.conv2d(input, kernel_h[None, None, :, None])

    if dw < w:
        kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)
        pad_w = (kernel_w.shape[0] - 1) // 2
        input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')
        input = F.conv2d(input, kernel_w[None, None, None, :])

    input = input.view([n, c, h, w])
    return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)

def lerp(a, b, f):
    return (a * (1.0 - f)) + (b * f)

class ReplaceGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x_forward, x_backward):
        ctx.shape = x_backward.shape
        return x_forward

    @staticmethod
    def backward(ctx, grad_in):
        return None, grad_in.sum_to_size(ctx.shape)


replace_grad = ReplaceGrad.apply


class ClampWithGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, min, max):
        ctx.min = min
        ctx.max = max
        ctx.save_for_backward(input)
        return input.clamp(min, max)

    @staticmethod
    def backward(ctx, grad_in):
        input, = ctx.saved_tensors
        return grad_in * (grad_in * (input - input.clamp(ctx.min, ctx.max)) >= 0), None, None


clamp_with_grad = ClampWithGrad.apply


class MakeCutoutsPhong(torch.nn.Module):
    def __init__(self, cut_size, cutn, cut_pow, augs):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow
        self.augs = T.Compose([
            K.RandomHorizontalFlip(p=0.5),
            K.RandomAffine(degrees=15, translate=0.1, p=0.8, padding_mode='border', resample='bilinear'),
            K.RandomPerspective(0.4, p=0.7, resample='bilinear'),
            K.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1, p=0.7),
            K.RandomGrayscale(p=0.15),
        ])

    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        if sideY != sideX:
            input = K.RandomAffine(degrees=0, shear=10, p=0.5)(input)

        max_size = min(sideX, sideY)
        cutouts = []
        for cn in range(self.cutn):
            if cn > self.cutn - self.cutn//4:
                cutout = input
            else:
                size = int(max_size * torch.zeros(1,).normal_(mean=.8, std=.3).clip(float(self.cut_size/max_size), 1.))
                offsetx = torch.randint(0, sideX - size + 1, ())
                offsety = torch.randint(0, sideY - size + 1, ())
                cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
        cutouts = torch.cat(cutouts)
        cutouts = self.augs(cutouts)
        return cutouts


class MakeCutoutsJuu(nn.Module):
    def __init__(self, cut_size, cutn, cut_pow, augs):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow
        self.augs = nn.Sequential(
            #K.RandomGaussianNoise(mean=0.0, std=0.5, p=0.1),
            K.RandomHorizontalFlip(p=0.5),
            K.RandomSharpness(0.3,p=0.4),
            K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode='border'),
            K.RandomPerspective(0.2,p=0.4),
            K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
            K.RandomGrayscale(p=0.1),
        )
        self.noise_fac = 0.1 

    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
        batch = self.augs(torch.cat(cutouts, dim=0))
        if self.noise_fac:
            facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac)
            batch = batch + facs * torch.randn_like(batch)
        return batch

class MakeCutoutsMoth(nn.Module):
    def __init__(self, cut_size, cutn, cut_pow, augs, skip_augs=False):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow
        self.skip_augs = skip_augs
        self.augs = T.Compose([
            T.RandomHorizontalFlip(p=0.5),
            T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
            T.RandomAffine(degrees=15, translate=(0.1, 0.1)),
            T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
            T.RandomPerspective(distortion_scale=0.4, p=0.7),
            T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
            T.RandomGrayscale(p=0.15),
            T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
            # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
        ])

    def forward(self, input):
        input = T.Pad(input.shape[2]//4, fill=0)(input)
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)

        cutouts = []
        for ch in range(self.cutn):
            if ch > self.cutn - self.cutn//4:
                cutout = input.clone()
            else:
                size = int(max_size * torch.zeros(1,).normal_(mean=.8, std=.3).clip(float(self.cut_size/max_size), 1.))
                offsetx = torch.randint(0, abs(sideX - size + 1), ())
                offsety = torch.randint(0, abs(sideY - size + 1), ())
                cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]

            if not self.skip_augs:
                cutout = self.augs(cutout)
            cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
            del cutout

        cutouts = torch.cat(cutouts, dim=0)
        return cutouts

class MakeCutoutsAaron(nn.Module):
    def __init__(self, cut_size, cutn, cut_pow, augs):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow
        self.augs = augs
        self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
        self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))

    def set_cut_pow(self, cut_pow):
        self.cut_pow = cut_pow

    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        cutouts_full = []
        
        min_size_width = min(sideX, sideY)
        lower_bound = float(self.cut_size/min_size_width)
        
        for ii in range(self.cutn):
            size = int(min_size_width*torch.zeros(1,).normal_(mean=.8, std=.3).clip(lower_bound, 1.)) # replace .5 with a result for 224 the default large size is .95
          
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))

        cutouts = torch.cat(cutouts, dim=0)

        return clamp_with_grad(cutouts, 0, 1)

class MakeCutoutsCumin(nn.Module):
    #from https://colab.research.google.com/drive/1ZAus_gn2RhTZWzOWUpPERNC0Q8OhZRTZ
    def __init__(self, cut_size, cutn, cut_pow, augs):
        super().__init__()
        self.cut_size = cut_size
        tqdm.write(f'cut size: {self.cut_size}')
        self.cutn = cutn
        self.cut_pow = cut_pow
        self.noise_fac = 0.1
        self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
        self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
        self.augs = nn.Sequential(
          #K.RandomHorizontalFlip(p=0.5),
          #K.RandomSharpness(0.3,p=0.4),
          #K.RandomGaussianBlur((3,3),(10.5,10.5),p=0.2),
          #K.RandomGaussianNoise(p=0.5),
          #K.RandomElasticTransform(kernel_size=(33, 33), sigma=(7,7), p=0.2),
          K.RandomAffine(degrees=15, translate=0.1, p=0.7, padding_mode='border'),
          K.RandomPerspective(0.7,p=0.7),
          K.ColorJitter(hue=0.1, saturation=0.1, p=0.7),
          K.RandomErasing((.1, .4), (.3, 1/.3), same_on_batch=True, p=0.7),)
            
    def set_cut_pow(self, cut_pow):
      self.cut_pow = cut_pow
    
    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        cutouts_full = []
        noise_fac = 0.1
        
        
        min_size_width = min(sideX, sideY)
        lower_bound = float(self.cut_size/min_size_width)
        
        for ii in range(self.cutn):
            
            
          # size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
          randsize = torch.zeros(1,).normal_(mean=.8, std=.3).clip(lower_bound,1.)
          size_mult = randsize ** self.cut_pow
          size = int(min_size_width * (size_mult.clip(lower_bound, 1.))) # replace .5 with a result for 224 the default large size is .95
          # size = int(min_size_width*torch.zeros(1,).normal_(mean=.9, std=.3).clip(lower_bound, .95)) # replace .5 with a result for 224 the default large size is .95

          offsetx = torch.randint(0, sideX - size + 1, ())
          offsety = torch.randint(0, sideY - size + 1, ())
          cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
          cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
        
        
        cutouts = torch.cat(cutouts, dim=0)
        cutouts = clamp_with_grad(cutouts, 0, 1)

        #if args.use_augs:
        cutouts = self.augs(cutouts)
        if self.noise_fac:
          facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_(0, self.noise_fac)
          cutouts = cutouts + facs * torch.randn_like(cutouts)
        return cutouts


class MakeCutoutsHolywater(nn.Module):
  def __init__(self, cut_size, cutn, cut_pow, augs):
    super().__init__()
    self.cut_size = cut_size
    tqdm.write(f'cut size: {self.cut_size}')
    self.cutn = cutn
    self.cut_pow = cut_pow
    self.noise_fac = 0.1
    self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
    self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
    self.augs = nn.Sequential(
            #K.RandomGaussianNoise(mean=0.0, std=0.5, p=0.1),
            K.RandomHorizontalFlip(p=0.5),
            K.RandomSharpness(0.3,p=0.4),
            K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode='border'),
            K.RandomPerspective(0.2,p=0.4),
            K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
            K.RandomGrayscale(p=0.1),
        )

  def set_cut_pow(self, cut_pow):
    self.cut_pow = cut_pow

  def forward(self, input):
      sideY, sideX = input.shape[2:4]
      max_size = min(sideX, sideY)
      min_size = min(sideX, sideY, self.cut_size)
      cutouts = []
      cutouts_full = []
      noise_fac = 0.1
      min_size_width = min(sideX, sideY)
      lower_bound = float(self.cut_size/min_size_width)
      
      for ii in range(self.cutn):
        size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
        randsize = torch.zeros(1,).normal_(mean=.8, std=.3).clip(lower_bound,1.)
        size_mult = randsize ** self.cut_pow * ii + size
        size1 = int((min_size_width) * (size_mult.clip(lower_bound, 1.))) # replace .5 with a result for 224 the default large size is .95
        size2 = int((min_size_width) * torch.zeros(1,).normal_(mean=.9, std=.3).clip(lower_bound, .95)) # replace .5 with a result for 224 the default large size is .95
        offsetx = torch.randint(0, sideX - size1 + 1, ())
        offsety = torch.randint(0, sideY - size2 + 1, ())
        cutout = input[:, :, offsety:offsety + size2 + ii, offsetx:offsetx + size1 + ii]
        cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
      
      cutouts = torch.cat(cutouts, dim=0)
      cutouts = clamp_with_grad(cutouts, 0, 1)
      cutouts = self.augs(cutouts)
      facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_(0, self.noise_fac)
      cutouts = cutouts + facs * torch.randn_like(cutouts)
      return cutouts

class MakeCutoutsOldHolywater(nn.Module):
    def __init__(self, cut_size, cutn, cut_pow, augs):
        super().__init__()
        self.cut_size = cut_size
        tqdm.write(f'cut size: {self.cut_size}')
        self.cutn = cutn
        self.cut_pow = cut_pow
        self.noise_fac = 0.1
        self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
        self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
        self.augs = nn.Sequential(
          #K.RandomHorizontalFlip(p=0.5),
          #K.RandomSharpness(0.3,p=0.4),
          #K.RandomGaussianBlur((3,3),(10.5,10.5),p=0.2),
          #K.RandomGaussianNoise(p=0.5),
          #K.RandomElasticTransform(kernel_size=(33, 33), sigma=(7,7), p=0.2),
          K.RandomAffine(degrees=180, translate=0.5, p=0.2, padding_mode='border'),
          K.RandomPerspective(0.6,p=0.9),
          K.ColorJitter(hue=0.03, saturation=0.01, p=0.1),
          K.RandomErasing((.1, .7), (.3, 1/.4), same_on_batch=True, p=0.2),)

    def set_cut_pow(self, cut_pow):
      self.cut_pow = cut_pow
    
    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        cutouts_full = []
        noise_fac = 0.1
        
        
        min_size_width = min(sideX, sideY)
        lower_bound = float(self.cut_size/min_size_width)
        
        for ii in range(self.cutn):
            
            
          # size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
          randsize = torch.zeros(1,).normal_(mean=.8, std=.3).clip(lower_bound,1.)
          size_mult = randsize ** self.cut_pow
          size = int(min_size_width * (size_mult.clip(lower_bound, 1.))) # replace .5 with a result for 224 the default large size is .95
          # size = int(min_size_width*torch.zeros(1,).normal_(mean=.9, std=.3).clip(lower_bound, .95)) # replace .5 with a result for 224 the default large size is .95

          offsetx = torch.randint(0, sideX - size + 1, ())
          offsety = torch.randint(0, sideY - size + 1, ())
          cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
          cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
        
        
        cutouts = torch.cat(cutouts, dim=0)
        cutouts = clamp_with_grad(cutouts, 0, 1)

        #if args.use_augs:
        cutouts = self.augs(cutouts)
        if self.noise_fac:
          facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_(0, self.noise_fac)
          cutouts = cutouts + facs * torch.randn_like(cutouts)
        return cutouts


class MakeCutoutsGinger(nn.Module):
    def __init__(self, cut_size, cutn, cut_pow, augs):
        super().__init__()
        self.cut_size = cut_size
        tqdm.write(f'cut size: {self.cut_size}')
        self.cutn = cutn
        self.cut_pow = cut_pow
        self.noise_fac = 0.1
        self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
        self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
        self.augs = augs
        '''
        nn.Sequential(
          K.RandomHorizontalFlip(p=0.5),
          K.RandomSharpness(0.3,p=0.4),
          K.RandomGaussianBlur((3,3),(10.5,10.5),p=0.2),
          K.RandomGaussianNoise(p=0.5),
          K.RandomElasticTransform(kernel_size=(33, 33), sigma=(7,7), p=0.2),
          K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode='border'), # padding_mode=2
          K.RandomPerspective(0.2,p=0.4, ),
          K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),)
'''

    def set_cut_pow(self, cut_pow):
      self.cut_pow = cut_pow

    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        cutouts_full = []
        noise_fac = 0.1
        
        
        min_size_width = min(sideX, sideY)
        lower_bound = float(self.cut_size/min_size_width)
        
        for ii in range(self.cutn):
            
            
          # size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
          randsize = torch.zeros(1,).normal_(mean=.8, std=.3).clip(lower_bound,1.)
          size_mult = randsize ** self.cut_pow
          size = int(min_size_width * (size_mult.clip(lower_bound, 1.))) # replace .5 with a result for 224 the default large size is .95
          # size = int(min_size_width*torch.zeros(1,).normal_(mean=.9, std=.3).clip(lower_bound, .95)) # replace .5 with a result for 224 the default large size is .95

          offsetx = torch.randint(0, sideX - size + 1, ())
          offsety = torch.randint(0, sideY - size + 1, ())
          cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
          cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
        
        
        cutouts = torch.cat(cutouts, dim=0)
        cutouts = clamp_with_grad(cutouts, 0, 1)

        #if args.use_augs:
        cutouts = self.augs(cutouts)
        if self.noise_fac:
          facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_(0, self.noise_fac)
          cutouts = cutouts + facs * torch.randn_like(cutouts)
        return cutouts

class MakeCutoutsZynth(nn.Module):
    def __init__(self, cut_size, cutn, cut_pow, augs):
        super().__init__()
        self.cut_size = cut_size
        tqdm.write(f'cut size: {self.cut_size}')
        self.cutn = cutn
        self.cut_pow = cut_pow
        self.noise_fac = 0.1
        self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
        self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
        self.augs = nn.Sequential(
        K.RandomHorizontalFlip(p=0.5),
        # K.RandomSolarize(0.01, 0.01, p=0.7),
        K.RandomSharpness(0.3,p=0.4),
        K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode='border'),
        K.RandomPerspective(0.2,p=0.4),
        K.ColorJitter(hue=0.01, saturation=0.01, p=0.7))


    def set_cut_pow(self, cut_pow):
      self.cut_pow = cut_pow

    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        cutouts_full = []
        noise_fac = 0.1
        
        
        min_size_width = min(sideX, sideY)
        lower_bound = float(self.cut_size/min_size_width)
        
        for ii in range(self.cutn):
            
            
          # size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
          randsize = torch.zeros(1,).normal_(mean=.8, std=.3).clip(lower_bound,1.)
          size_mult = randsize ** self.cut_pow
          size = int(min_size_width * (size_mult.clip(lower_bound, 1.))) # replace .5 with a result for 224 the default large size is .95
          # size = int(min_size_width*torch.zeros(1,).normal_(mean=.9, std=.3).clip(lower_bound, .95)) # replace .5 with a result for 224 the default large size is .95

          offsetx = torch.randint(0, sideX - size + 1, ())
          offsety = torch.randint(0, sideY - size + 1, ())
          cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
          cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
        
        
        cutouts = torch.cat(cutouts, dim=0)
        cutouts = clamp_with_grad(cutouts, 0, 1)

        #if args.use_augs:
        cutouts = self.augs(cutouts)
        if self.noise_fac:
          facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_(0, self.noise_fac)
          cutouts = cutouts + facs * torch.randn_like(cutouts)
        return cutouts

class MakeCutoutsWyvern(nn.Module):
    def __init__(self, cut_size, cutn, cut_pow, augs):
        super().__init__()
        self.cut_size = cut_size
        tqdm.write(f'cut size: {self.cut_size}')
        self.cutn = cutn
        self.cut_pow = cut_pow
        self.noise_fac = 0.1
        self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
        self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
        self.augs = augs

    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
        return clamp_with_grad(torch.cat(cutouts, dim=0), 0, 1)

flavors = {
    "phong": MakeCutoutsPhong,
    "cumin": MakeCutoutsCumin,
    "holywater": MakeCutoutsHolywater,
    "old_holywater": MakeCutoutsOldHolywater,
    "ginger": MakeCutoutsGinger,
    "zynth": MakeCutoutsZynth,
    "wyvern": MakeCutoutsWyvern,
    "aaron": MakeCutoutsAaron,
    "moth": MakeCutoutsMoth,
    "juu": MakeCutoutsJuu,
}

# Optimization loop
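
The loss used in the cell below, `spherical_dist_loss`, normalizes both embeddings and converts the straight-line (chord) distance between them into an angle. For unit vectors separated by an angle θ, the chord length is 2·sin(θ/2), so the expression reduces to θ²/2. The following pure-Python sketch checks that identity on small vectors; it is an illustration only and uses plain lists instead of tensors.

```python
import math


def normalize(v):
    length = math.sqrt(sum(c * c for c in v))
    return [c / length for c in v]


def spherical_dist_loss(x, y):
    # Same formula as the notebook: 2 * arcsin(||x - y|| / 2) ** 2
    x, y = normalize(x), normalize(y)
    chord = math.sqrt(sum((a - b) ** 2 for a, b in zip(x, y)))
    return 2 * math.asin(min(1.0, chord / 2)) ** 2


def angle_between(x, y):
    x, y = normalize(x), normalize(y)
    dot = sum(a * b for a, b in zip(x, y))
    return math.acos(max(-1.0, min(1.0, dot)))


x, y = [1.0, 0.0], [0.0, 3.0]  # orthogonal, so theta = pi / 2
loss = spherical_dist_loss(x, y)
print(loss, angle_between(x, y) ** 2 / 2)  # both equal pi**2 / 8
```

Because the loss depends only on the angle, the magnitude of either embedding has no effect, which is why the notebook can compare raw CLIP outputs directly.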

In [None]:
def spherical_dist_loss(x, y):
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)

def optimize_network(num_iterations, optimizer_type, lr, flavor, cutn, cut_pow, upsample_mode):
    global itt
    itt = 0

    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)
        random.seed(seed)
    
    augs = None
    make_cutouts = flavors[flavor](clip_size, cutn, cut_pow, augs)

    # Initialize DIP skip network
    input_depth = 32
    net = get_net(
        input_depth, 'skip',
        pad='reflection',
        skip_n33d=128, skip_n33u=128,
        skip_n11=4, num_scales=5,
        upsample_mode=upsample_mode,
        # stride | avg | max | lanczos2
        # Stride, Avg and Max are pretty similar
        # Completely removing downsample_mode or using lanczos2 works best
        downsample_mode='lanczos2',
    ).to(device)

    # Initialize input noise
    net_input = torch.zeros([1, input_depth, sideY, sideX], device=device).normal_().div(10).detach()

    # Encode text prompt with CLIP
    target_embed = clip_model.encode_text(clip.tokenize(prompt).to(device)).float()

    if optimizer_type == 'Ranger21':
      optimizer = Ranger21(net.parameters(), lr, weight_decay=0.01, num_epochs=200, num_batches_per_epoch=1)
    elif optimizer_type == 'Adam':
      optimizer = torch.optim.Adam(net.parameters(), lr)
    elif optimizer_type == 'MadGrad':
      optimizer = MADGRAD(net.parameters(), lr, weight_decay=0.01, momentum=0.9)
        
    try:
        for _ in range(num_iterations):
            optimizer.zero_grad(set_to_none=True)
    
            out = net(net_input)
            cutouts = make_cutouts(out)
            image_embeds = clip_model.encode_image(clip_normalize(cutouts))
            loss = spherical_dist_loss(image_embeds, target_embed).mean()

            loss.backward()
            optimizer.step()

            itt += 1
            save_progress_video = False

            if itt % display_rate == 0 or save_progress_video:
                with torch.inference_mode():
                    image = TF.to_pil_image(out[0].clamp(0, 1))
                    if itt % display_rate == 0:
                        display.clear_output(wait=True)
                        display.display(image)
                        if display_augs:
                            aug_grid = torchvision.utils.make_grid(cutouts, nrow=math.ceil(math.sqrt(cutn)))
                            display.display(TF.to_pil_image(aug_grid.clamp(0, 1)))
                    if save_progress_video and itt > 15:
                        video_writer.append_data(np.asarray(image))

            if anneal_lr:
                optimizer.param_groups[0]['lr'] = max(0.00001, .99 * optimizer.param_groups[0]['lr'])

            print(f'Iteration:  {itt} / {num_iterations}')
    
    except KeyboardInterrupt:
        pass
    finally:
        return TF.to_pil_image(net(net_input)[0])
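
When `anneal_lr` is enabled, the loop above multiplies the learning rate by 0.99 after every iteration and never lets it drop below 0.00001. The sketch below reproduces that schedule in plain Python to show how quickly it decays. The helper names `annealed_lr` and `steps_to_floor` are hypothetical and not part of the notebook, and any scheduling an optimizer applies internally is ignored here.

```python
import math


def annealed_lr(lr0, step, decay=0.99, floor=1e-5):
    """Learning rate after `step` updates of lr = max(floor, decay * lr)."""
    lr = lr0
    for _ in range(step):
        lr = max(floor, decay * lr)
    return lr


def steps_to_floor(lr0, decay=0.99, floor=1e-5):
    """Smallest number of updates after which the rate sits at the floor."""
    return math.ceil(math.log(floor / lr0) / math.log(decay))


start = 0.0025  # the notebook's default step_size
print(annealed_lr(start, 250) / start)  # fraction of the starting rate left after 250 iterations
print(steps_to_floor(start))
```

With the default 250 iterations the rate ends at roughly 8% of its starting value, so the floor only comes into play for much longer runs.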

# Settings / Generate
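
The `cut_pow` setting below controls how large the random crops shown to CLIP tend to be. In the `juu` and `wyvern` flavors the crop size is `int(u ** cut_pow * (max_size - min_size) + min_size)` with `u` drawn uniformly from [0, 1), so a smaller `cut_pow` yields larger crops on average. Other flavors such as `cumin` raise a clipped normal sample to `cut_pow` instead, which pushes sizes in the same direction. The sketch below simulates the uniform variant in plain Python. The sizes 384 and 224 assume the widescreen aspect and a 224-pixel CLIP input, and the helper names are illustrative only.

```python
import random


def sample_cut_size(cut_pow, max_size=384, min_size=224, rng=random):
    # Same expression as the juu / wyvern cutout classes.
    u = rng.random()
    return int(u ** cut_pow * (max_size - min_size) + min_size)


def mean_cut_size(cut_pow, n=20000, seed=0):
    rng = random.Random(seed)
    return sum(sample_cut_size(cut_pow, rng=rng) for _ in range(n)) / n


def expected_cut_size(cut_pow, max_size=384, min_size=224):
    # E[u ** p] = 1 / (p + 1) for u ~ Uniform(0, 1); ignores int() truncation.
    return min_size + (max_size - min_size) / (cut_pow + 1)


for p in (0.5, 0.85, 2.0):
    print(p, round(mean_cut_size(p), 1), round(expected_cut_size(p), 1))
```

Larger crops cover more of the image at once, which is consistent with values below 1 favoring overall structure, while smaller crops concentrate the loss on local detail.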

In [None]:
#@title **Configure and Run**
#@markdown Make sure to add `#pixelart` at the end of your prompt
prompt = 'A slime monster. #pixelart' #@param{type:'string'}
aspect = "widescreen" #@param ["widescreen", "square"]
#@markdown A seed of 0 picks a new random seed on every run
seed = 0 #@param{type:'number'}
num_iterations = 250 #@param{type:'number'}
#@markdown Number of image crops shown to CLIP per iteration; this can affect quality
cutn = 60 #@param{type:'number'}
#@markdown Values of cut_pow below 1 prioritize structure over detail; values above 1 do the opposite
cut_pow = 0.85 #@param{type:'number'}
step_size = 0.0025 #@param{type:'number'}
display_rate = 50 #@param{type:'number'}

if aspect == 'widescreen':
  sideX = 640
  sideY = 384
  k = 5
elif aspect == 'square':
  sideX = 512
  sideY = 512
  k = 6
if seed == 0:
  seed = random.randint(0, 2**32 - 1)  # np.random.seed only accepts values below 2**32
  print("Seed: " + str(seed))
anneal_lr = True # True == lower the learning rate over time
lr = step_size

# Adam is pretty bad for pixelart
# MadGrad is a bit better
# Ranger21 is by far the best
opt_type = 'Ranger21' # Adam, MadGrad, Ranger21

# Most flavors give a neat effect, but not really pixelart-y
# Cumin is the best of them all, though moth is kind of neat as well
# All flavors are imported from hypertron V2
flavor = 'cumin' # ["phong", "ginger", "cumin", "holywater", "old_holywater", "zynth", "wyvern", "aaron", "moth", "juu"]

# Nearest is the best for pixel art
# Bicubic and bilinear are mostly the same, they make organic and smooth shapes
scaling_mode = "nearest" # nearest | bilinear | bicubic
display_augs = False # Display grid of augmented image, for debugging

# Begin optimization / generation
out = optimize_network(num_iterations, opt_type, lr, flavor, cutn, cut_pow, scaling_mode)

clear_output()

out.save('progress.png')  # PNG is lossless, so no quality setting is needed
input = cv2.imread('progress.png')
height, width = input.shape[:2]
w, h = (int(sideX/k), int(sideY/k))
temp = cv2.resize(input, (w, h), interpolation=cv2.INTER_LINEAR)
output = cv2.resize(temp, (width, height), interpolation=cv2.INTER_NEAREST)
cv2_imshow(output)
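
The final step above produces the blocky look by shrinking the image by a factor of `k` and then enlarging it back with nearest-neighbor interpolation, so each low-resolution pixel becomes a solid square. With the defaults this gives a 128 x 76 grid for widescreen and 85 x 85 for square. The sketch below shows the same idea on a tiny grayscale grid in plain Python. It assumes the image dimensions are divisible by `k` and uses block averaging as a simple stand-in for `cv2.INTER_LINEAR`, so it is not a drop-in replacement for the OpenCV calls.

```python
def downscale_mean(img, k):
    """Average each k x k block (stand-in for linear downscaling)."""
    h, w = len(img) // k, len(img[0]) // k
    return [
        [
            sum(img[y * k + dy][x * k + dx] for dy in range(k) for dx in range(k)) / (k * k)
            for x in range(w)
        ]
        for y in range(h)
    ]


def upscale_nearest(img, k):
    """Repeat every pixel k times along both axes."""
    return [
        [img[y // k][x // k] for x in range(len(img[0]) * k)]
        for y in range(len(img) * k)
    ]


def pixelate(img, k):
    return upscale_nearest(downscale_mean(img, k), k)


image = [[(x + y) % 7 for x in range(8)] for y in range(4)]
blocky = pixelate(image, 2)
for row in blocky:
    print(row)
```

The output has the same dimensions as the input, but every 2 x 2 block holds a single value, which is the effect the notebook applies at full resolution.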