<a href="https://colab.research.google.com/github/afiaka87/deep-daze-ffm/blob/main/Text2Image_FFT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Text to Image

Based on [CLIP](https://github.com/openai/CLIP) + FFT from [Lucent](https://github.com/greentfrapp/lucent) // made by [eps696](https://github.com/eps696) [Vadim Epstein]  
thanks to [Ryan Murdock](https://rynmurdock.github.io/), [Jonathan Fly](https://twitter.com/jonathanfly), [@tg-bomze](https://github.com/tg-bomze) 

## Features 
* complex requests:
  * image and/or text as main prompts  
   (composition similarity controlled with [SSIM](https://github.com/Po-Hsun-Su/pytorch-ssim) loss)
  * additional text prompts for fine details and to subtract (avoid) things
  * criteria inversion (show "the opposite")

* generates an [FFT-encoded](https://github.com/greentfrapp/lucent/blob/master/lucent/optvis/param/spatial.py) image (massive, detailed textures, à la DeepDream)
* **fast convergence**
* **low RAM requirements** – handles Full HD/4K resolutions and above
* can use both CLIP models at once (ViT and RN50)


**Run this cell after each session restart**

In [None]:
#@title General setup

import subprocess
CUDA_version = [s for s in subprocess.check_output(["nvcc", "--version"]).decode("UTF-8").split(", ") if s.startswith("release")][0].split(" ")[-1]
print("CUDA version:", CUDA_version)

if CUDA_version == "10.0":
    torch_version_suffix = "+cu100"
elif CUDA_version == "10.1":
    torch_version_suffix = "+cu101"
elif CUDA_version == "10.2":
    torch_version_suffix = ""
else:
    torch_version_suffix = "+cu110"

!pip install torch==1.7.1{torch_version_suffix} torchvision==0.8.2{torch_version_suffix} -f https://download.pytorch.org/whl/torch_stable.html ftfy regex

try: 
  !pip3 install googletrans==3.1.0a0
  from googletrans import Translator, constants
  # from pprint import pprint
  translator = Translator()
except: pass
!pip install ftfy

import os
import time
from math import exp
import random
import imageio
import numpy as np
import PIL
from skimage import exposure
from base64 import b64encode

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.autograd import Variable

from IPython.display import HTML, Image, display, clear_output
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
import ipywidgets as ipy
# import glob
from google.colab import output, files

import warnings
warnings.filterwarnings("ignore")

!git clone https://github.com/openai/CLIP.git
%cd /content/CLIP/
import clip
perceptor, preprocess = clip.load('ViT-B/32')
model_vit, _ = clip.load('ViT-B/32')

workdir = '_out'
tempdir = os.path.join(workdir, 'ttt')
os.makedirs(tempdir, exist_ok=True)

clear_output()

###  FFT from Lucent library  https://github.com/greentfrapp/lucent

def pixel_image(shape, sd=2., device='cuda', smooth_col=None):
    """Create a plain pixel-space image parameterization.

    Args:
        shape: shape of the trainable image tensor (e.g. [batch, channels, h, w]).
        sd: standard deviation of the random normal initialization.
        device: device the trainable tensor is moved to (default 'cuda').
        smooth_col: unused; accepted so this function is call-compatible with
            fft_image, which the notebook invokes as param_f(shape, smooth_col=...).

    Returns:
        ([tensor], getter): the trainable leaf tensor to hand to an optimizer, and a
        zero-argument function returning that same tensor.
    """
    # Sample on CPU then move, so the random init is drawn from the CPU generator.
    tensor = (torch.randn(*shape) * sd).to(device).requires_grad_(True)
    return [tensor], lambda: tensor

# From https://github.com/tensorflow/lucid/blob/master/lucid/optvis/param/spatial.py
def rfft2d_freqs(h, w):
    """Return the 2-D spectrum frequency magnitudes for an (h, w) real FFT.

    Rows follow np.fft.fftfreq(h); columns keep the first w // 2 + 1 entries of
    np.fft.fftfreq(w), plus one extra column when w is odd (the caller later
    crops the reconstructed image back to width w).
    """
    keep = w // 2 + 2 if w % 2 == 1 else w // 2 + 1
    fy = np.fft.fftfreq(h)[:, None]
    fx = np.fft.fftfreq(w)[:keep]
    return np.sqrt(fx * fx + fy * fy)

def fft_image(shape, sd=0.1, decay_power=1., smooth_col=1., device='cuda'):
    """Create a Fourier-space image parameterization (Lucent/Lucid style).

    Args:
        shape: (batch, channels, h, w) of the image to produce.
        sd: standard deviation of the random init of the spectrum coefficients.
        decay_power: exponent of the 1/frequency amplitude scaling.
        smooth_col: divisor applied to the output (larger = flatter colors/contrast).
        device: device for the trainable spectrum and scale (default 'cuda').

    Returns:
        ([spectrum_real_imag_t], inner): the trainable spectrum tensor to hand to an
        optimizer, and a zero-argument function returning the current image of
        shape (batch, channels, h, w).
    """
    batch, channels, h, w = shape
    freqs = rfft2d_freqs(h, w)
    init_val_size = (batch, channels) + freqs.shape + (2,) # 2 for imaginary and real components
    spectrum_real_imag_t = (torch.randn(*init_val_size) * sd).to(device).requires_grad_(True)
    # Amplitude ~ 1 / f**decay_power; frequencies below 1/max(w, h) (including DC)
    # are clamped to that value so the division is finite.
    scale = 1.0 / np.maximum(freqs, 1.0 / max(w, h)) ** decay_power
    scale = torch.tensor(scale).float()[None, None, ..., None].to(device)

    def _irfft2(spectrum):
        # Inverse real 2-D FFT with orthonormal scaling. torch.irfft exists only up to
        # PyTorch 1.7; later versions use torch.fft.irfftn with norm='ortho', which is
        # the equivalent of the old normalized=True.
        if hasattr(torch, 'irfft'):
            return torch.irfft(spectrum, 2, normalized=True, signal_sizes=(h, w))
        return torch.fft.irfftn(torch.view_as_complex(spectrum), s=(h, w), norm='ortho')

    def inner():
        scaled_spectrum_t = scale * spectrum_real_imag_t
        image = _irfft2(scaled_spectrum_t)
        image = image[:batch, :channels, :h, :w]
        image = image / (1 + image.std())**1.3 # keep contrast
        image = image * 4. / smooth_col # more desaturation, smoothen colors & contrast
        return image
    return [spectrum_real_imag_t], inner

def to_valid_rgb(image_f, decorrelate=True):
    """Wrap an image-producing function so its output is squashed into (0, 1).

    The returned zero-argument function calls `image_f`, optionally mixes the
    channels via _linear_decorrelate_color, and applies a sigmoid.
    """
    def inner():
        raw = image_f()
        mixed = _linear_decorrelate_color(raw) if decorrelate else raw
        return torch.sigmoid(mixed)
    return inner
    
def _linear_decorrelate_color(tensor):
    """Mix image channels through the normalized color-correlation matrix.

    `tensor` must be rank 4 with 3 channels in dim 1 (the permute moves dim 1 last,
    where it is multiplied by the 3x3 `color_correlation_normalized.T`).
    Returns a tensor of the same shape.
    """
    # Build the matrix on the input's own device and dtype so CPU tensors, tensors
    # on any GPU, and float64 inputs all work.
    mix = torch.tensor(color_correlation_normalized.T, device=tensor.device, dtype=tensor.dtype)
    t_permute = tensor.permute(0, 2, 3, 1)
    t_permute = torch.matmul(t_permute, mix)
    return t_permute.permute(0, 3, 1, 2)

# Square root of the color-correlation matrix used by Lucid/Lucent (float32).
color_correlation_svd_sqrt = np.array(
    [[0.26, 0.09, 0.02],
     [0.27, 0.00, -0.05],
     [0.27, -0.09, 0.03]],
    dtype="float32",
)
# Largest column L2 norm; dividing by it gives the strongest column unit norm.
max_norm_svd_sqrt = np.linalg.norm(color_correlation_svd_sqrt, axis=0).max()
color_correlation_normalized = color_correlation_svd_sqrt / max_norm_svd_sqrt

### Libs

def slice_imgs(imgs, count, transform=None, uniform=False, micro=None):
  """Cut `count` random square crops from each image and resize each to 224x224.

  Args:
    imgs: list of image tensors, indexed as img[:, :, y, x] (presumably NCHW).
    count: number of crops per image.
    transform: optional callable applied to every resized crop (e.g. normalization).
    uniform: if True, crop positions are uniform and each image is first wrap-padded
      to twice its size (pad_up_to), so crops can straddle the borders; otherwise
      positions are normally distributed around the image center.
    micro: crop-size mode. True -> 64 .. max(224, 1/4 of the shorter side);
      False -> 0.5 .. 0.98 of the shorter side; None -> 64 .. 0.98 of the shorter side.
      Sizes are always clamped to fit inside the (possibly padded) image.

  Returns:
    list with one tensor per input image holding its `count` crops concatenated
    along dim 0.
  """
  def _lerp(t, a, b):
    # Map t in [0, 1] linearly onto [a, b].
    return t * (b-a) + a

  rnd_size = torch.rand(count)
  if uniform is True:
    rnd_offx = torch.rand(count)
    rnd_offy = torch.rand(count)
  else: # normal around center
    rnd_offx = torch.clip(torch.randn(count) * 0.2 + 0.5, 0, 1)
    rnd_offy = torch.clip(torch.randn(count) * 0.2 + 0.5, 0, 1)

  sz = [img.shape[2:] for img in imgs]
  sz_min = [np.min(s) for s in sz]  # crop sizes are relative to the unpadded image
  if uniform is True:
    sz = [[2*s[0], 2*s[1]] for s in sz]
    imgs = [pad_up_to(imgs[i], sz[i], type='centr') for i in range(len(imgs))]

  sliced = []
  for i, img in enumerate(imgs):
    height, width = int(sz[i][0]), int(sz[i][1])
    cuts = []
    for c in range(count):
      if micro is True: # both scales, micro mode
        csize = _lerp(rnd_size[c], 64, max(224, 0.25*sz_min[i])).int()
      elif micro is False: # both scales, macro mode
        csize = _lerp(rnd_size[c], 0.5*sz_min[i], 0.98*sz_min[i]).int()
      else: # single scale
        csize = _lerp(rnd_size[c], 64, 0.98*sz_min[i]).int()
      # The 64/224 lower bounds can exceed small images; clamp so the crop always
      # fits. Otherwise offsets go negative and the slice gives a distorted, non-square crop.
      csize = csize.clamp(1, min(height, width))
      offsetx = _lerp(rnd_offx[c], 0, width - csize).int()
      offsety = _lerp(rnd_offy[c], 0, height - csize).int()
      cut = img[:, :, offsety:offsety + csize, offsetx:offsetx + csize]
      cut = torch.nn.functional.interpolate(cut, (224,224), mode='bicubic')
      if transform is not None:
        cut = transform(cut)
      cuts.append(cut)
    sliced.append(torch.cat(cuts, 0))
  return sliced

def makevid(seq_dir, size=None):
  """Encode the frames seq_dir/000.jpg, 001.jpg, ... into seq_dir + '.mp4' with ffmpeg.

  Args:
    seq_dir: directory holding the numbered JPEG frames.
    size: optional width/height (in pixels) for the <video> element.

  Returns:
    an HTML <video> string with the MP4 embedded as a base64 data URL.

  Raises:
    subprocess.CalledProcessError if ffmpeg fails (instead of embedding a stale
    or missing video).
  """
  out_sequence = seq_dir + '/%03d.jpg'
  out_video = seq_dir + '.mp4'
  # Argument list, no shell: paths are passed verbatim.
  result = subprocess.run(['ffmpeg', '-y', '-v', 'warning', '-i', out_sequence, out_video],
                          capture_output=True, text=True)
  if result.stdout or result.stderr:
    print(result.stdout + result.stderr)  # surface ffmpeg warnings in the notebook
  result.check_returncode()
  with open(out_video, 'rb') as f:
    data_url = "data:video/mp4;base64," + b64encode(f.read()).decode()
  wh = '' if size is None else 'width=%d height=%d' % (size, size)
  return """<video %s controls><source src="%s" type="video/mp4"></video>""" % (wh, data_url)

# Tiles an array around two points, allowing for pad lengths greater than the input length
# adapted from https://discuss.pytorch.org/t/symmetric-padding/19866/3
def tile_pad(xt, padding):
  """Pad the last two dims of `xt` by periodic wrap-around.

  padding = (left, right, top, bottom). Pads may exceed the input size because
  indices wrap modulo the dimension. Uses advanced indexing with numpy index
  arrays, so it works on numpy arrays and on torch tensors.
  """
  h, w = xt.shape[-2:]
  left, right, top, bottom = padding

  def wrap(idx, n):
    # Shift by half a pixel, fold into [-0.5, n - 0.5), keep idx's integer dtype.
    lo = -0.5
    return np.array(np.remainder(idx - lo, n) + lo, dtype=idx.dtype)

  cols = wrap(np.arange(-left, w + right), w)
  rows = wrap(np.arange(-top, h + bottom), h)
  grid_x, grid_y = np.meshgrid(cols, rows)
  return xt[..., grid_y, grid_x]

def pad_up_to(x, size, type='centr'):
  """Wrap-pad the last two dims of `x` up to `size` = (height, width) using tile_pad.

  If `type` contains 'side' (case-insensitive), all padding goes after the content
  (right/bottom). Otherwise it is split around the content, with the odd pixel after.
  Returns `x` itself when it already has that size.
  """
  current = x.shape[2:][::-1]  # (width, height)
  if list(x.shape[2:]) == list(size):
    return x
  padding = []
  for i, want in enumerate(size[::-1]):
    extra = want - current[i]
    before = 0 if 'side' in type.lower() else extra // 2
    padding += [before, extra - before]
  return tile_pad(x, padding)

class ProgressBar(object):
  """ipywidgets progress bar plus a label with count, per-step rate, elapsed time and ETA.

  Usage: pbar = ProgressBar(n), then call pbar.upd() once per finished step.
  With task_num <= 0 the label shows an open-ended "completed N" counter instead.
  """
  def __init__(self, task_num=10):
    self.pbar = ipy.IntProgress(min=0, max=task_num, bar_style='') # (value=0, min=0, max=max, step=1, description=description, bar_style='')
    self.labl = ipy.Label()
    display(ipy.HBox([self.pbar, self.labl]))
    self.task_num = task_num
    self.completed = 0
    self.start()

  def start(self, task_num=None):
    """(Re)start the timer, optionally with a new task count.

    Also resets the completed counter, the bar value, the bar max and the style,
    so a restarted bar does not carry over progress from a previous run.
    """
    if task_num is not None:
      self.task_num = task_num
      self.pbar.max = max(task_num, 0)
    self.completed = 0
    self.pbar.value = 0
    self.pbar.bar_style = ''
    if self.task_num > 0:
      self.labl.value = '0/{}'.format(self.task_num)
    else:
      self.labl.value = 'completed: 0, elapsed: 0s'
    self.start_time = time.time()

  def upd(self, *p, **kw):
    """Record one completed step and refresh the label and bar (arguments are ignored)."""
    self.completed += 1
    # Tiny epsilon keeps the division below finite on an instantaneous first update.
    elapsed = time.time() - self.start_time + 0.0000000000001
    fps = self.completed / elapsed if elapsed>0 else 0
    if self.task_num > 0:
      # Projected finish clock time, extrapolated from the average time per step so far.
      finaltime = time.asctime(time.localtime(self.start_time + self.task_num * elapsed / float(self.completed)))
      fin = ' end %s' % finaltime[11:16]
      percentage = self.completed / float(self.task_num)
      eta = int(elapsed * (1 - percentage) / percentage + 0.5)
      self.labl.value = '{}/{}, rate {:.3g}s, time {}s, left {}s, {}'.format(self.completed, self.task_num, 1./fps, shortime(elapsed), shortime(eta), fin)
    else:
      self.labl.value = 'completed {}, time {}s, {:.1f} steps/s'.format(self.completed, int(elapsed + 0.5), fps)
    self.pbar.value += 1
    if self.completed == self.task_num: self.pbar.bar_style = 'success'

def time_days(sec):
  """Format a duration in seconds as 'Dd H:MM:SS' (each field truncated by %d)."""
  days = sec / 86400
  hours = (sec / 3600) % 24
  minutes = (sec / 60) % 60
  seconds = sec % 60
  return '%dd %d:%02d:%02d' % (days, hours, minutes, seconds)
def time_hrs(sec):
  """Format a duration in seconds as 'H:MM:SS' (hours are not wrapped at 24)."""
  hours = sec / 3600
  minutes = (sec / 60) % 60
  seconds = sec % 60
  return '%d:%02d:%02d' % (hours, minutes, seconds)
def shortime(sec):
  """Format seconds compactly.

  Returns 'S' under a minute and 'M:SS' under an hour. Longer durations use
  time_hrs ('H:MM:SS') below one day and time_days otherwise.
  """
  if sec >= 86400:
    return time_days(sec)
  if sec >= 3600:
    return time_hrs(sec)
  if sec >= 60:
    return '%d:%02d' % ((sec/60)%60, sec%60)
  return '%d' % (sec)

# from https://github.com/Po-Hsun-Su/pytorch-ssim

def gaussian(window_size, sigma):
    """Return a 1-D Gaussian kernel of length `window_size`, centered at window_size // 2 and normalized to sum 1."""
    center = window_size // 2
    denom = float(2 * sigma ** 2)
    weights = torch.Tensor([exp(-(k - center) ** 2 / denom) for k in range(window_size)])
    return weights / weights.sum()

def create_window(window_size, channel, sigma=1.5):
    """Build a depthwise 2-D Gaussian kernel for SSIM.

    Args:
        window_size: side length of the square kernel.
        channel: number of channels (one identical kernel per channel, for groups=channel conv).
        sigma: Gaussian standard deviation (default 1.5, as in pytorch-ssim).

    Returns:
        float tensor of shape (channel, 1, window_size, window_size).
    """
    _1D_window = gaussian(window_size, sigma).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    # A plain tensor suffices: autograd.Variable has been a no-op wrapper since PyTorch 0.4.
    return _2D_window.expand(channel, 1, window_size, window_size).contiguous()

def _ssim(img1, img2, window, window_size, channel, size_average = True):
  """Compute SSIM between two image batches using a depthwise `window` kernel.

  Local statistics are window-weighted averages via grouped conv2d (groups=channel),
  with constants C1 = 0.01**2 and C2 = 0.03**2 (tuned for values in [0, 1]).
  Returns the mean SSIM as a scalar, or a per-sample mean (averaged over dims 1-3)
  when size_average is False.
  """
  pad = window_size // 2

  def local_mean(t):
    return F.conv2d(t, window, padding = pad, groups = channel)

  mu1, mu2 = local_mean(img1), local_mean(img2)
  mu1_sq, mu2_sq, mu1_mu2 = mu1.pow(2), mu2.pow(2), mu1*mu2
  sigma1_sq = local_mean(img1*img1) - mu1_sq
  sigma2_sq = local_mean(img2*img2) - mu2_sq
  sigma12 = local_mean(img1*img2) - mu1_mu2
  C1, C2 = 0.01**2, 0.03**2
  numerator = (2*mu1_mu2 + C1)*(2*sigma12 + C2)
  denominator = (mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)
  ssim_map = numerator / denominator
  if not size_average:
    return ssim_map.mean(1).mean(1).mean(1)
  return ssim_map.mean()

class SSIM(torch.nn.Module):
  """SSIM between two image batches (port of pytorch-ssim).

  The Gaussian window is created for 1 channel up front and rebuilt (then cached)
  whenever the input's channel count, device or dtype differs from the cached one.
  """
  def __init__(self, window_size = 11, size_average = True):
    super().__init__()
    self.window_size = window_size
    self.size_average = size_average  # True -> scalar mean; False -> per-sample values
    self.channel = 1
    self.window = create_window(window_size, self.channel)

  def forward(self, img1, img2):
    """Return SSIM(img1, img2); both are rank-4 tensors with channels in dim 1."""
    (_, channel, _, _) = img1.size()
    window = self.window
    # Compare device (including GPU index) and dtype directly. A tensor type string
    # such as 'torch.cuda.FloatTensor' would match windows on a different GPU.
    if channel != self.channel or window.device != img1.device or window.dtype != img1.dtype:
      window = create_window(self.window_size, channel).to(device=img1.device, dtype=img1.dtype)
      self.window = window
      self.channel = channel
    return _ssim(img1, img2, window, self.window_size, channel, self.size_average)

# List the GPU(s) visible to this session so the user can see which accelerator was assigned.
!nvidia-smi -L
print('\nDone!')

Type some `text` and/or upload an image to start.  
The `fine_details` input adds micro-level details on that topic.  
Put the topics you would like to avoid in the result into `subtract`.  
*NB: more prompts = more memory! (This is handled by automatically decreasing the `samples` amount, so hopefully you won't need to act.)*  
Use `invert` to reverse all the criteria if you want to see "the total opposite".

In [None]:
#@title Input

# Colab form fields: main prompt, fine-detail prompt, prompt to steer away from, and option switches.
text = "a recursive forest" #@param {type:"string"}
fine_details = "the texture of a dead tree" #@param {type:"string"}
subtract = "sharpness" #@param {type:"string"}
translate = False #@param {type:"boolean"}
invert = False #@param {type:"boolean"}
upload_image = True #@param {type:"boolean"}

# NOTE(review): uses the global `translator` created in the setup cell, which is
# undefined if googletrans failed to install/import. The Generate cell translates
# `text` again when `translate` is on.
if translate:
  text = translator.translate(text, dest='en').text
# Opens Colab's upload dialog. The Generate cell uses the first entry of `uploaded`
# (presumably a dict of filename -> file bytes; confirm against google.colab docs).
if upload_image:
  uploaded = files.upload()

The `uniform` option produces a seamlessly tileable texture (when off, the composition is centered).  
The `sync` value adds an SSIM loss between the output and the input image (if there is one), allowing you to "redraw" it with controlled similarity.  
The `smooth_col` scaler desaturates the image. *Some empirical auto-tuning is already in place, so hopefully it's not really needed anymore.*  

Turn on `dual_model` to optimize with both CLIP models at once (uses more RAM!).  
Decrease `samples` if you run out of memory at higher resolutions (especially when several prompts are used with the dual model).  
Setting `steps` much higher (1000+) elaborates details much better, but starts scattering graffiti-like text everywhere.

In [None]:
#@title Generate

# from google.colab import drive
# drive.mount('/content/GDrive')
# clipsDir = '/content/GDrive/MyDrive/T2I ' + dtNow.strftime("%Y-%m-%d %H%M")

!rm -rf tempdir

sideX =  1920#@param {type:"integer"}
sideY =  1080#@param {type:"integer"}
#@markdown > Tweaks & tuning
dual_model = False #@param {type:"boolean"}
uniform = False #@param {type:"boolean"}
sync =  0 #@param {type:"number"}
smooth_col =  1.#@param {type:"number"}
#@markdown > Training
steps = 1000 #@param {type:"integer"}
samples = 128 #@param {type:"integer"}
learning_rate = .05 #@param {type:"number"}
#@markdown > Misc
save_freq =  50#@param {type:"integer"}
audio_notification = False #@param {type:"boolean"}

if dual_model is True:
  print(' using dual-model optimization')
  model_rn, _ = clip.load('RN50')
  samples = samples // 2
if len(fine_details) > 0:
  samples = int(samples * 0.9)
if len(subtract) > 0:
  samples = int(samples * 0.9)
print(' using %d samples' % samples)

# CLIP's input normalization (per-channel mean, std).
norm_in = torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
# -1 by default: the training loss is -similarity, so minimizing it maximizes similarity.
sign = 1. if invert is True else -1.

def _to_english(prompt):
  # Translate a prompt to English with googletrans (Translator comes from the setup cell).
  return Translator().translate(prompt, dest='en').text

def _encode_text(prompt):
  # Embed a prompt with ViT CLIP, concatenating RN50 features in dual-model mode.
  # Prompt embeddings are fixed targets, so no autograd graph is built.
  tokens = clip.tokenize(prompt).cuda()
  with torch.no_grad():
    enc = model_vit.encode_text(tokens)
    if dual_model is True:
      enc = torch.cat((enc, model_rn.encode_text(tokens)), 1)
  return enc.detach().clone()

if upload_image:
  in_img = list(uploaded.values())[0]
  print(' image:', list(uploaded)[0])
  img_np = imageio.imread(in_img)
  if img_np.ndim == 2:  # grayscale -> replicate into 3 channels
    img_np = np.stack([img_np] * 3, axis=-1)
  img_np = img_np[..., :3]  # drop an alpha channel (e.g. RGBA PNG) so CLIP gets RGB
  img_in = torch.from_numpy(img_np.astype(np.float32)/255.).unsqueeze(0).permute(0,3,1,2).cuda()
  in_sliced = slice_imgs([img_in], samples, transform=norm_in)[0]
  with torch.no_grad():
    img_enc = model_vit.encode_image(in_sliced)
    if dual_model is True:
      img_enc = torch.cat((img_enc, model_rn.encode_image(in_sliced)), 1)
  img_enc = img_enc.detach().clone()
  if sync > 0:
    ssim_loss = SSIM(window_size = 11)
    img_in = F.interpolate(img_in, (sideY, sideX)).float()
  else:
    del img_in
  del in_sliced; torch.cuda.empty_cache()

# Same threshold as train(), which uses txt_enc whenever text is non-empty.
if len(text) > 0:
  print(' macro:', text)
  if translate:
    text = _to_english(text)
    print(' translated to:', text)
  txt_enc = _encode_text(text)

if len(fine_details) > 0:
  print(' micro:', fine_details)
  if translate:
    fine_details = _to_english(fine_details)
    print(' translated to:', fine_details)
  txt_enc2 = _encode_text(fine_details)

if len(subtract) > 0:
  print(' without:', subtract)
  if translate:
    subtract = _to_english(subtract)
    print(' translated to:', subtract)
  txt_enc0 = _encode_text(subtract)

shape = [1, 3, sideY, sideX]
param_f = fft_image
# param_f = pixel_image
# learning_rate = 1.
params, image_f = param_f(shape, smooth_col=smooth_col)
image_f = to_valid_rgb(image_f)
optimizer = torch.optim.Adam(params, learning_rate)

def displ(img, fname=None):
  """Contrast-equalize an image array and optionally save it.

  Axis 0 of `img` is moved last (presumably CHW -> HWC). Values are clipped to [0, 1],
  passed through skimage's equalize_adapthist, then scaled to uint8. When `fname`
  is given, the result is written to `fname` and also to 'result.jpg'. Nothing is
  returned.
  """
  hwc = np.transpose(np.array(img)[:, :, :], (1, 2, 0))
  equalized = exposure.equalize_adapthist(np.clip(hwc, 0., 1.))
  as_uint8 = np.clip(equalized * 255, 0, 255).astype(np.uint8)
  if fname is None:
    return
  for path in (fname, 'result.jpg'):
    imageio.imsave(path, np.array(as_uint8))

def checkin(num):
  """Render the current image, save it as frame `num` in tempdir, and show it in `outpic`."""
  with torch.no_grad():
    frame = image_f().cpu().numpy()[0]
  frame_path = os.path.join(tempdir, '%03d.jpg' % num)
  displ(frame, frame_path)  # also refreshes result.jpg, which is displayed below
  outpic.clear_output()
  with outpic:
    display(Image('result.jpg'))

def train(i):
  # One optimization step of the image parameters; `i` is the step index.
  # The loss sums 100x-weighted CLIP cosine similarities between crops of the current
  # image and the prompt embeddings. With the default sign = -1 (invert off),
  # minimizing it pulls the image toward the image/text prompts and away from
  # `subtract`; `invert` flips every term.
  loss = 0
  img_out = image_f()

  # Macro-scale crops when a fine-details prompt exists (its micro crops are taken
  # separately below); otherwise a single mixed-scale set of crops.
  micro = False if len(fine_details) > 0 else None
  imgs_sliced = slice_imgs([img_out], samples, norm_in, uniform=uniform, micro=micro)
  out_enc = model_vit.encode_image(imgs_sliced[-1])
  if dual_model is True: # use both clip models
      out_enc = torch.cat((out_enc, model_rn.encode_image(imgs_sliced[-1])), 1)
  if upload_image:
      loss += sign * 100*torch.cosine_similarity(img_enc, out_enc, dim=-1).mean()
  # NOTE(review): needs txt_enc for any non-empty text; confirm it is built under the same condition.
  if len(text) > 0: # input text
      loss += sign * 100*torch.cosine_similarity(txt_enc, out_enc, dim=-1).mean()
  if len(subtract) > 0: # subtract text
      loss += -sign * 100*torch.cosine_similarity(txt_enc0, out_enc, dim=-1).mean()
  # Scales the loss accumulated so far by a factor involving SSIM to the input image.
  # The SSIM weight steps/(i+1) shrinks as training progresses. The fine-details term
  # below is added after this scaling and is not affected by it.
  if sync > 0 and upload_image: # image composition sync
      loss *= 1. + sync * (steps/(i+1) * ssim_loss(img_out, img_in) - 1)
  if len(fine_details) > 0: # input text for micro details
      imgs_sliced = slice_imgs([img_out], samples, norm_in, uniform=uniform, micro=True)
      out_enc2 = model_vit.encode_image(imgs_sliced[-1])
      if dual_model is True:
          out_enc2 = torch.cat((out_enc2, model_rn.encode_image(imgs_sliced[-1])), 1)
      loss += sign * 100*torch.cosine_similarity(txt_enc2, out_enc2, dim=-1).mean()
      del out_enc2; torch.cuda.empty_cache()
  # Drop Python references to intermediates and release cached CUDA blocks.
  del img_out, imgs_sliced, out_enc; torch.cuda.empty_cache()

  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  
  # Every save_freq steps, save and show a preview frame numbered i // save_freq.
  if i % save_freq == 0:
    checkin(i // save_freq)

# Live preview area; checkin() redraws it with the latest frame.
outpic = ipy.Output()
display(outpic)

pbar = ProgressBar(steps)
for i in range(steps):
  train(i)
  pbar.upd()

# makevid(tempdir) encodes the saved frames into tempdir + '.mp4'; download that same file.
display(HTML(makevid(tempdir)))
files.download(tempdir + '.mp4')
if audio_notification:
  output.eval_js('new Audio("https://freesound.org/data/previews/80/80921_1022651-lq.ogg").play()')
