# Using CLIP and VQGAN Models to Generate Images from Text Prompts

In this notebook, we will use the CLIP and VQGAN models to generate paintings from text prompts. We will then prepare this notebook to be used as an API for generating art on demand.

## Acknowledgements
The original notebook was made by [Katherine Crowson](https://github.com/crowsonkb).

With further modifications by [Justin John](https://colab.research.google.com/github/justinjohn0306/VQGAN-CLIP/blob/main/VQGAN%2BCLIP_%28z%2Bquantize_method_with_augmentations%2C_user_friendly_interface%29.ipynb#scrollTo=c3d7a8be-73ce-4cee-be70-e21c1210a7a6).


First, double-check which GPU the notebook is using, since this notebook requires significant computational power. <br> 
Depending on your Colab subscription, you will likely get a different GPU:

(slowest, not recommended) **P4 << K80 << T4 << P100 << V100** (fastest)


In [None]:
# Print the attached GPU model and its memory usage (IPython shell escape).
!nvidia-smi

## Initialize the System

Download the required dependencies. We will also download VQGAN weights pretrained on the WikiArt dataset.

In [None]:
# Initialize the System
# Installs the libraries and downloads the model files the later cells use.
# NOTE(review): most commands below discard their output with `&> /dev/null`,
# so install failures are silent; drop the redirects when debugging.
# NOTE(review): gdown is installed here but not used by any visible cell.
!pip install --upgrade --no-cache-dir gdown
!nvidia-smi
print("Downloading CLIP...")
# Cloned rather than pip-installed; later cells import it as `from CLIP import clip`.
!git clone https://github.com/openai/CLIP                 &> /dev/null
 
print("Installing Python Libraries for AI")
# VQGAN model code; the next cell appends ./taming-transformers to sys.path.
!git clone https://github.com/CompVis/taming-transformers &> /dev/null
!pip install ftfy regex tqdm omegaconf pytorch-lightning  &> /dev/null
!pip install kornia                                       &> /dev/null
!pip install einops                                       &> /dev/null
print("Installing transformers library...")
!pip install transformers   
print("Installing taming.models...")   
# NOTE(review): "taming.models" looks like a module path rather than a PyPI
# distribution name; confirm this install succeeds. The import in the next cell
# may be resolved via the sys.path entry for the cloned repo instead.
!pip install taming.models                           &> /dev/null

print("Installing Python Libraries for API Dev")
# python-multipart is what FastAPI needs to accept the UploadFile endpoint.
!pip install pycurl fastapi uvicorn nest-asyncio pyngrok python-multipart py_eureka_client

%reload_ext autoreload
# NOTE(review): `&> /dev/null` is shell syntax; it is not interpreted as a
# redirect by the %autoreload line magic. Confirm this line does what is intended.
%autoreload                  &> /dev/null

# WikiArt VQGAN config + checkpoint; these filenames are what prep_args()
# passes as vqgan_config / vqgan_checkpoint.
# NOTE(review): fetched over plain http; verify the host still serves them.
!curl -L  'http://eaidata.bmk.sh/data/Wikiart_16384/wikiart_f16_16384_8145600.yaml' > wikiart_16384.yaml
!curl -L  'http://eaidata.bmk.sh/data/Wikiart_16384/wikiart_f16_16384_8145600.ckpt' > wikiart_16384.ckpt
print("Installation finished.")

Set up needed libraries and methods:

In [None]:
import argparse
import math
from pathlib import Path
import sys
 
sys.path.append('./taming-transformers')
from IPython import display
from base64 import b64encode
from omegaconf import OmegaConf
from PIL import Image
from taming.models import cond_transformer, vqgan
import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm
 
from CLIP import clip
import kornia.augmentation as K
import numpy as np
import imageio
from PIL import ImageFile, Image
import json
ImageFile.LOAD_TRUNCATED_IMAGES = True
 
def sinc(x):
    """Normalized sinc, sin(pi*x) / (pi*x), with the removable point sinc(0) = 1.

    Args:
        x: tensor of sample positions.

    Returns:
        Tensor of the same shape as `x`.
    """
    scaled = math.pi * x
    ones = x.new_ones([])
    # The division yields nan at x == 0; torch.where replaces those entries with 1.
    return torch.where(x != 0, torch.sin(scaled) / scaled, ones)
 
def lanczos(x, a):
    """Evaluate a Lanczos window of radius `a` at positions `x`, normalized to sum to 1.

    Taps with |x| >= a are zero; the rest are sinc(x) * sinc(x / a).
    """
    inside = (x > -a) & (x < a)
    taps = torch.where(inside, sinc(x) * sinc(x / a), x.new_zeros([]))
    return taps / taps.sum()
 
def ramp(ratio, width):
    """Return symmetric sample offsets spaced by `ratio`, reaching out to about `width`.

    The offsets 0, ratio, 2*ratio, ... are mirrored around zero, and the two
    outermost entries are trimmed. The result is a 1-D tensor of the default
    floating dtype with an odd number of elements.
    """
    count = math.ceil(width / ratio + 1)
    positions = []
    offset = 0
    # Accumulate (rather than multiply) so values match step-by-step summation.
    for _ in range(count):
        positions.append(offset)
        offset += ratio
    half = torch.tensor(positions, dtype=torch.get_default_dtype())
    mirrored = torch.cat([-half[1:].flip([0]), half])
    return mirrored[1:-1]
 
def resample(input, size, align_corners=True):
    """Resize a 4-D image batch to `size`, low-pass filtering any axis being shrunk.

    For each axis whose target is smaller than the source, the batch is first
    convolved along that axis with a normalized Lanczos (a=2) kernel built from
    `ramp`. The final resize is a bicubic `F.interpolate`.

    Args:
        input: tensor unpacked as (n, c, h, w).
        size: (target_h, target_w).
        align_corners: forwarded to F.interpolate.

    Returns:
        Tensor of shape (n, c, target_h, target_w).
    """
    n, c, h, w = input.shape
    dh, dw = size

    # Fold channels into the batch dim so one single-channel kernel filters all of them.
    input = input.view([n * c, 1, h, w])

    if dh < h:
        kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)
        # ramp() yields an odd-length kernel, so symmetric reflect padding of
        # (len - 1) // 2 keeps the height at h after the unpadded convolution.
        pad_h = (kernel_h.shape[0] - 1) // 2
        input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')
        input = F.conv2d(input, kernel_h[None, None, :, None])

    if dw < w:
        kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)
        # Same as above, along the width axis.
        pad_w = (kernel_w.shape[0] - 1) // 2
        input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')
        input = F.conv2d(input, kernel_w[None, None, None, :])

    # Restore the (n, c, h, w) layout before the bicubic resize.
    input = input.view([n, c, h, w])
    return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)
 
class ReplaceGrad(torch.autograd.Function):
    """Straight-through helper.

    The forward pass yields `x_forward` unchanged. The backward pass sends the
    incoming gradient to `x_backward` instead, summed down to its shape.
    """

    @staticmethod
    def forward(ctx, x_forward, x_backward):
        # Only the target's shape is needed to reduce broadcast gradients later.
        ctx.shape = x_backward.shape
        return x_forward

    @staticmethod
    def backward(ctx, grad_in):
        routed = grad_in.sum_to_size(ctx.shape)
        # No gradient for x_forward; everything flows to x_backward.
        return None, routed


replace_grad = ReplaceGrad.apply
 
class ClampWithGrad(torch.autograd.Function):
    """Clamp to [min, max] in the forward pass, with a selective backward pass.

    The gradient is kept only where it would not push an out-of-range value
    further outside the range.
    """

    @staticmethod
    def forward(ctx, input, min, max):
        ctx.min, ctx.max = min, max
        ctx.save_for_backward(input)
        return input.clamp(min, max)

    @staticmethod
    def backward(ctx, grad_in):
        (saved,) = ctx.saved_tensors
        overshoot = saved - saved.clamp(ctx.min, ctx.max)
        # Keep the gradient where the value is in range (overshoot == 0) or where
        # the gradient's sign matches the overshoot, so a descent step moves the
        # value back toward the range. Elsewhere it is zeroed.
        keep = grad_in * overshoot >= 0
        return grad_in * keep, None, None


clamp_with_grad = ClampWithGrad.apply

def vector_quantize(x, codebook):
    """Replace each vector in `x` (last dim) with its nearest codebook row.

    Distances are squared Euclidean. The result carries a straight-through
    gradient back to `x` via replace_grad.

    Args:
        x: tensor whose last dimension matches the codebook width.
        codebook: 2-D tensor, one entry per row.
    """
    x_sq = x.pow(2).sum(dim=-1, keepdim=True)
    code_sq = codebook.pow(2).sum(dim=1)
    # |x - c|^2 = |x|^2 + |c|^2 - 2 x.c
    dist = x_sq + code_sq - 2 * x @ codebook.T
    nearest = dist.argmin(-1)
    quantized = F.one_hot(nearest, codebook.shape[0]).to(dist.dtype) @ codebook
    return replace_grad(quantized, x)
 
class Prompt(nn.Module):
    """Loss module comparing embeddings against a fixed target embedding.

    Args:
        embed: target embedding(s); normalized along the feature dim in forward.
        weight: loss scale. Its sign flips the distance, so a negative weight
            rewards moving away from the target.
        stop: lower bound applied (straight-through) to the signed distances.
    """

    def __init__(self, embed, weight=1., stop=float('-inf')):
        super().__init__()
        self.register_buffer('embed', embed)
        for buffer_name, value in (('weight', weight), ('stop', stop)):
            self.register_buffer(buffer_name, torch.as_tensor(value))

    def forward(self, input):
        # presumably input is (N, D) image embeddings and embed is (M, D) -- confirm;
        # the unsqueezes broadcast them to an (N, M) distance matrix.
        a = F.normalize(input.unsqueeze(1), dim=2)
        b = F.normalize(self.embed.unsqueeze(0), dim=2)
        chord = a.sub(b).norm(dim=2)
        # 2 * arcsin(chord / 2)^2: half the squared great-circle distance between unit vectors.
        dists = chord.div(2).arcsin().pow(2).mul(2)
        dists = dists * self.weight.sign()
        clipped = replace_grad(dists, torch.maximum(dists, self.stop))
        return self.weight.abs() * clipped.mean()
 
def parse_prompt(prompt):
    """Split a prompt written as ``text[:weight[:stop]]``.

    Returns:
        ``(text, weight, stop)`` as ``(str, float, float)``. Defaults are
        ``weight=1.0`` and ``stop=-inf``.

    Trailing ``:``-separated fields are read as weight/stop only when they
    parse as floats. Text or paths that themselves contain ':' (for example
    ``"a painting: impressionist"`` or ``"http://host/img.png:2"``) are
    therefore kept intact instead of raising ``ValueError``. Every input that
    parsed before gives the same result as before.
    """
    parts = prompt.rsplit(':', 2)
    # Prefer the longest numeric suffix (at most two fields), as before; fall
    # back to fewer numeric fields when a trailing field is not a number.
    for n_numeric in range(len(parts) - 1, 0, -1):
        split_at = len(parts) - n_numeric
        fields = parts[split_at:] + ['1', '-inf'][n_numeric:]
        try:
            weight, stop = map(float, fields)
        except ValueError:
            continue
        return ':'.join(parts[:split_at]), weight, stop
    return prompt, 1., float('-inf')
 
class MakeCutouts(nn.Module):
    """Cut random square crops from an image batch for CLIP scoring.

    Each forward pass:
      1. takes `cutn` random square crops;
      2. resamples each crop to `cut_size` x `cut_size`;
      3. runs the batch of crops through a random kornia augmentation pipeline;
      4. adds per-crop Gaussian noise.

    Args:
        cut_size: side length of every output crop.
        cutn: number of crops per forward pass.
        cut_pow: exponent applied to the uniform draw for crop size; values
            above 1 bias toward smaller crops.
    """
    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow
        self.augs = nn.Sequential(
            K.RandomHorizontalFlip(p=0.5),
            # K.RandomSolarize(0.01, 0.01, p=0.7),
            K.RandomSharpness(0.3,p=0.4),
            K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode='border'),
            K.RandomPerspective(0.2,p=0.4),
            K.ColorJitter(hue=0.01, saturation=0.01, p=0.7))
        # Upper bound of the per-crop noise standard deviation.
        self.noise_fac = 0.1

    def forward(self, input):
        # input.shape[2:4] is read as (height, width).
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            # Crop side in [min_size, max_size], skewed by cut_pow.
            size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
        batch = self.augs(torch.cat(cutouts, dim=0))
        if self.noise_fac:
            # One noise scale per crop, drawn uniformly from [0, noise_fac).
            # NOTE(review): the leading dim is cutn, which matches `batch` only
            # when `input` has batch size 1 -- confirm callers never pass more.
            facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac)
            batch = batch + facs * torch.randn_like(batch)
        return batch
 
def load_vqgan_model(config_path, checkpoint_path):
    """Build a frozen VQGAN from a taming-transformers config and checkpoint.

    Two config targets are accepted:
      * a bare VQModel;
      * a Net2NetTransformer, whose `first_stage_model` is returned.
    The model is put in eval mode with gradients disabled, and its `loss`
    submodule is deleted before returning.

    Raises:
        ValueError: if config.model.target is neither supported class.
    """
    config = OmegaConf.load(config_path)
    target = config.model.target

    if target == 'taming.models.vqgan.VQModel':
        vq_model = vqgan.VQModel(**config.model.params)
        vq_model.eval().requires_grad_(False)
        vq_model.init_from_ckpt(checkpoint_path)
    elif target == 'taming.models.cond_transformer.Net2NetTransformer':
        transformer = cond_transformer.Net2NetTransformer(**config.model.params)
        transformer.eval().requires_grad_(False)
        transformer.init_from_ckpt(checkpoint_path)
        vq_model = transformer.first_stage_model
    else:
        raise ValueError(f'unknown model type: {config.model.target}')

    # Nothing in this notebook uses the `loss` submodule, so it is dropped.
    del vq_model.loss
    return vq_model
 
def resize_image(image, out_size):
    """Resize a PIL image, keeping its aspect ratio and capping its area.

    The area is capped at out_size[0] * out_size[1]. Images already smaller
    than that keep their area (they are never enlarged).
    """
    src_w, src_h = image.size
    aspect = src_w / src_h
    target_area = min(src_w * src_h, out_size[0] * out_size[1])
    new_w = round((target_area * aspect) ** 0.5)
    new_h = round((target_area / aspect) ** 0.5)
    return image.resize((new_w, new_h), Image.LANCZOS)

!mkdir steps
!mkdir output
!mkdir /content/current

##  Prepare the Arguments
This is the main method to prepare the arguments needed to generate art.  

In [None]:
import argparse
def prep_args(prompts="rainbow sunflowers",
              width=64,
              height=64,
              display_frequency=25,
              initial_image="",
              target_images="",
              learning_rate=0.1,
              max_iterations=100,
              input_images="",
              seed=None):
    """Build the argparse.Namespace consumed by Text2ImageGenerator.

    Args:
        prompts: text prompts separated by "|". Each may carry ":weight[:stop]"
            (see parse_prompt). Blank segments are ignored.
        width, height: requested output size in pixels. Text2ImageGenerator
            rounds them down to a multiple of the VQGAN token size.
        display_frequency: log and show a preview every this many steps.
        initial_image: path or file object to start from. "" or "None" means
            start from random tokens.
        target_images: image-prompt paths separated by "|". "" or "None"
            means none. Blank segments are ignored.
        learning_rate: Adam step size for the latent.
        max_iterations: index of the last optimization step.
        input_images: unused; kept so existing keyword callers keep working.
        seed: torch seed, or None to let Text2ImageGenerator draw a fresh one.

    Returns:
        argparse.Namespace with every field Text2ImageGenerator reads.
    """
    if initial_image == "None":
        initial_image = None

    if not target_images or target_images == "None":
        target_images = []
    else:
        # Dropping blank segments avoids Image.open("") on e.g. a trailing "|".
        target_images = [path.strip() for path in target_images.split("|") if path.strip()]

    # Blank segments would otherwise become empty CLIP text prompts.
    prompts = [phrase.strip() for phrase in (prompts or "").split("|") if phrase.strip()]

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)
    if prompts:
        print('Using text prompt:', prompts)
    if target_images:
        print('Using image prompts:', target_images)

    return argparse.Namespace(
        prompts=prompts,
        image_prompts=target_images,
        noise_prompt_seeds=[],
        noise_prompt_weights=[],
        size=[width, height],
        init_image=initial_image,
        init_weight=0.,
        clip_model='ViT-B/32',
        step_size=learning_rate,
        cutn=64,
        cut_pow=1.,
        display_freq=display_frequency,
        seed=seed,
        vqgan_config='wikiart_16384.yaml',
        vqgan_checkpoint='wikiart_16384.ckpt',
        device=device,
        max_iterations=max_iterations)

## Create the Text-to-Image Generator
This is the main code for the art generator.

In [None]:
class Text2ImageGenerator():
  """Optimize a VQGAN latent so its decoded image matches CLIP prompts.

  Construction loads both models and builds one loss module per prompt.
  generate() then runs the optimization loop. Each step writes a PNG under
  steps/ and a running snapshot under /content/current/.
  """
  def __init__(self,args):
    """Load the models, initialize the latent `z`, and build the prompt losses.

    Args:
      args: namespace as returned by prep_args. Fields read: seed,
        max_iterations, init_weight, display_freq, vqgan_config,
        vqgan_checkpoint, device, clip_model, cutn, cut_pow, size,
        init_image, step_size, prompts, image_prompts, noise_prompt_seeds,
        noise_prompt_weights.
    """
    torch.cuda.empty_cache()
    # Without an explicit seed a fresh one is drawn; it is printed so the same
    # torch RNG state can be recreated by passing it back as args.seed.
    if args.seed is None:
        self.seed = torch.seed()
    else:
        self.seed = args.seed
    torch.manual_seed(self.seed)
    print('Using seed:', self.seed)
    self.max_iterations=args.max_iterations
    self.init_weight=args.init_weight
    self.display_freq=args.display_freq
    # NOTE(review): both models are loaded from disk on every construction,
    # i.e. once per API request.
    self.model = load_vqgan_model(args.vqgan_config, args.vqgan_checkpoint).to(args.device)
    self.perceptor = clip.load(args.clip_model, jit=False)[0].eval().requires_grad_(False).to(args.device)

    self.cut_size = self.perceptor.visual.input_resolution
    self.e_dim = self.model.quantize.e_dim
    # Pixels per latent token along each axis.
    self.f = 2**(self.model.decoder.num_resolutions - 1)
    self.make_cutouts = MakeCutouts(self.cut_size, args.cutn, cut_pow=args.cut_pow)
    self.n_toks = self.model.quantize.n_e
    # Latent grid size. The pixel size is rounded down to a multiple of f.
    self.toksX, self.toksY = args.size[0] // self.f, args.size[1] // self.f
    self.sideX, self.sideY = self.toksX * self.f, self.toksY * self.f
    # Per-channel codebook bounds, used by train() to clamp z after each step.
    self.z_min = self.model.quantize.embedding.weight.min(dim=0).values[None, :, None, None]
    self.z_max = self.model.quantize.embedding.weight.max(dim=0).values[None, :, None, None]


    if args.init_image:
        # Start from the given image: to_tensor yields [0, 1], rescaled to [-1, 1] for encoding.
        self.pil_image = Image.open(args.init_image).convert('RGB')
        self.pil_image = self.pil_image.resize((self.sideX, self.sideY), Image.LANCZOS)
        self.z, *_ = self.model.encode(TF.to_tensor(self.pil_image).to(args.device).unsqueeze(0) * 2 - 1)
    else:
        # Start from uniformly random codebook entries, laid out as (1, e_dim, toksY, toksX).
        self.one_hot = F.one_hot(torch.randint(self.n_toks, [self.toksY * self.toksX], device=args.device), self.n_toks).float()
        self.z = self.one_hot @ self.model.quantize.embedding.weight
        self.z = self.z.view([-1, self.toksY, self.toksX, self.e_dim]).permute(0, 3, 1, 2)
    # Reference copy for the optional init_weight MSE term in ascend_txt().
    self.z_orig = self.z.clone()
    self.z.requires_grad_(True)
    self.opt = optim.Adam([self.z], lr=args.step_size)

    # Per-channel normalization applied to images before CLIP's encode_image.
    self.normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                  std=[0.26862954, 0.26130258, 0.27577711])

    # One Prompt loss module per text prompt, image prompt and noise prompt.
    self.pMs = []

    for prompt in args.prompts:
        txt, weight, stop = parse_prompt(prompt)
        embed = self.perceptor.encode_text(clip.tokenize(txt).to(args.device)).float()
        self.pMs.append(Prompt(embed, weight, stop).to(args.device))

    for prompt in args.image_prompts:
        path, weight, stop = parse_prompt(prompt)
        img = resize_image(Image.open(path).convert('RGB'), (self.sideX, self.sideY))
        batch = self.make_cutouts(TF.to_tensor(img).unsqueeze(0).to(args.device))
        embed = self.perceptor.encode_image(self.normalize(batch)).float()
        self.pMs.append(Prompt(embed, weight, stop).to(args.device))

    for seed, weight in zip(args.noise_prompt_seeds, args.noise_prompt_weights):
        gen = torch.Generator().manual_seed(seed)
        embed = torch.empty([1, self.perceptor.visual.output_dim]).normal_(generator=gen)
        self.pMs.append(Prompt(embed, weight).to(args.device))

  def synth(self):
      """Quantize z to the codebook, decode it, and return an image batch clamped to [0, 1]."""
      # Channels are moved last for quantization and back afterwards.
      self.z_q = vector_quantize(self.z.movedim(1, 3), self.model.quantize.embedding.weight).movedim(3, 1)
      return clamp_with_grad(self.model.decode(self.z_q).add(1).div(2), 0, 1)

  @torch.no_grad()
  def checkin(self):
      """Log the current losses and display the decoded image inline (saved as progress.png)."""
      losses_str = ', '.join(f'{loss.item():g}' for loss in self.losses)
      tqdm.write(f'i: {self.i}, loss: {sum(self.losses).item():g}, losses: {losses_str}')
      out = self.synth()
      TF.to_pil_image(out[0].cpu()).save('progress.png')
      display.display(display.Image('progress.png'))

  def ascend_txt(self):
      """Return the list of loss terms for the current latent.

      Side effects: writes steps/{i:04}.png, /content/current/current.png and
      /content/current/current.txt (the step index). The progress endpoints
      read the last two files. The saved frame is the image before this step's
      optimizer update.
      """
      out = self.synth()
      iii = self.perceptor.encode_image(self.normalize(self.make_cutouts(out))).float()
      result = []
      if self.init_weight:
          result.append(F.mse_loss(self.z, self.z_orig) * self.init_weight / 2)
      for prompt in self.pMs:
          result.append(prompt(iii))
      # (C, H, W) float in [0, 1] -> (H, W, C) uint8 for imageio.
      img = np.array(out.mul(255).clamp(0, 255)[0].cpu().detach().numpy().astype(np.uint8))[:,:,:]
      img = np.transpose(img, (1, 2, 0))
      filename = f"steps/{self.i:04}.png"
      imageio.imwrite(filename, np.array(img))
      imageio.imwrite("/content/current/current.png", np.array(img))
      with open("/content/current/current.txt","w") as f:
        f.write(str(self.i))

      return result

  def train(self):
      """Run one step: compute losses, maybe display, backprop, update z, clamp z."""
      self.opt.zero_grad()
      self.losses = self.ascend_txt()
      if self.i % self.display_freq == 0:
          self.checkin()
      loss = sum(self.losses)
      loss.backward()
      self.opt.step()
      with torch.no_grad():
          self.z.copy_(self.z.maximum(self.z_min).minimum(self.z_max))
  
  def generate(self):
    """Run steps i = 0..max_iterations inclusive (max_iterations + 1 steps).

    A KeyboardInterrupt stops the loop early without raising.
    """
    self.i = 0
    try:
        with tqdm() as pbar:
            while True:
                self.train()
                if self.i == self.max_iterations:
                    break
                self.i += 1
                # NOTE(review): `status` is never read, and print() runs every step.
                status=self.i
                print(self.i)
                pbar.update()
                #yield from open(f"steps/{self.i:04}.png",'rb')
    except KeyboardInterrupt:
        pass

## Preparing the API
In order to run the API from Colab, you will need to authenticate pyngrok with your own ngrok auth token. Create an ngrok account and paste your token below:

In [None]:
# Register your ngrok auth token: replace the placeholder, and do not share
# copies of the notebook that contain a real token.
# NOTE(review): confirm the installed pyngrok provides a `pyngrok` console
# command; its documented CLI entry point is `ngrok`.
!pyngrok  authtoken "ADD_YOUR_TOKEN_HERE"

Now let's build the FastAPI application:

In [None]:
from fastapi import FastAPI
import nest_asyncio
import uvicorn
from pyngrok import ngrok
from fastapi.middleware.cors import CORSMiddleware
from fastapi import FastAPI, File, UploadFile, Form, BackgroundTasks
from fastapi.responses import StreamingResponse
from io import BytesIO
import asyncio
from http import HTTPStatus
import os
from py_eureka_client import eureka_client

# Create the FastAPI app and register this service with a Eureka discovery server.
app = FastAPI()
rest_server_port=8000
# NOTE(review): registers with the local port that uvicorn binds below (8000),
# but external clients reach this notebook through the ngrok URL. Confirm what
# address Eureka advertises, and that the Heroku-hosted server is still reachable.
eureka_client.init(eureka_server="https://artist-block-discovery-service.herokuapp.com/eureka",
                  app_name="GAN-Model",
                   instance_port=rest_server_port)


# CORS: every origin, with credentials, all methods and headers.
# NOTE(review): fully open; restrict `origins` to the known front-end hosts
# if this service is exposed beyond a demo.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

Now let's define the endpoints we will expose in our API:

In [None]:
from starlette.concurrency import run_in_threadpool

async def _generate_and_stream(args, width, height):
  """Run one generation off the event loop and stream its final frame as JPEG.

  Args:
    args: namespace from prep_args.
    width, height: requested size, used for the progress placeholder image.
  """
  # Seed the progress endpoint with a noise image. NumPy image arrays are
  # (rows, cols, channels), i.e. (height, width, 3) for a width x height image.
  noise = np.random.rand(height, width, 3) * 255
  Image.fromarray(noise.astype('uint8')).convert('RGB').save('/content/current/current.png')
  # Loading the models and optimizing both block, so run them in the thread
  # pool; the event loop stays free to serve the progress endpoints.
  gen_obj = await run_in_threadpool(Text2ImageGenerator, args)
  await run_in_threadpool(gen_obj.generate)
  # generate() saves steps/{i:04}.png for i = 0..max_iterations, so the last
  # frame is named after max_iterations.
  sent_image = BytesIO()
  with Image.open(f"/content/steps/{args.max_iterations:04}.png") as output_image:
    output_image.convert('RGB').save(sent_image, "JPEG")
  sent_image.seek(0)
  return StreamingResponse(sent_image, media_type="image/jpeg")


@app.post('/v1/generate_text2img')
async def generate_text2img(prompt:str,width:int,height:int):
  """Generate an image from a text prompt and return it as a JPEG."""
  args = prep_args(prompts=prompt, width=width, height=height, max_iterations=100)
  return await _generate_and_stream(args, width, height)


# NOTE(review): this handler reuses the name `generate_text2img` and shadows the
# text-only handler above at module level; both routes are still registered.
# The name is kept because FastAPI uses it in the generated OpenAPI operationId.
@app.post('/v1/generate_textimg2img')
async def generate_text2img(prompt:str,width:int,height:int,image: UploadFile=File(...)):
  """Generate an image from a text prompt, starting from the uploaded image."""
  args = prep_args(prompts=prompt, initial_image=image.file, width=width, height=height, max_iterations=100)
  return await _generate_and_stream(args, width, height)


@app.post('/v1/generate_text2img_progressimg')
async def generate_text2img_progressimgs():
  """Return the latest in-progress frame as a JPEG.

  Reads /content/current/current.png, which the generator overwrites on every
  optimization step.
  """
  sent_image = BytesIO()
  # Close the file handle right away instead of leaking it until garbage collection.
  with Image.open("/content/current/current.png") as output_image:
    output_image.convert("RGB").save(sent_image, "JPEG")
  sent_image.seek(0)
  return StreamingResponse(sent_image, media_type="image/jpeg")

@app.post('/v1/generate_text2img_progressbar')
async def generate_text2img_progressbar():
  """Return the most recent optimization step index, as a string.

  The value is whatever the generator last wrote to /content/current/current.txt.
  """
  return Path("/content/current/current.txt").read_text()


Now we are ready to set up the API:

In [None]:
# Expose local port 8000 through an ngrok tunnel and print the public URLs.
ngrok_tunnel = ngrok.connect(8000)
print('Public URL:', ngrok_tunnel.public_url)
# os.path.join appends "/docs" (FastAPI's interactive API docs page) to the URL.
print('Public URL docs:', os.path.join(ngrok_tunnel.public_url,"docs"))
# nest_asyncio lets uvicorn run its event loop inside the notebook's already-running loop.
nest_asyncio.apply()
# Blocks this cell until the server is stopped.
uvicorn.run(app, host="0.0.0.0", port=8000)

## Testing The Model
We can also test the model directly in this notebook:

In [None]:
#@title Generate an Image
import argparse
prompts = "flowering magenta orchids in a rainy day impressionist" #@param {type:"string"}
width =  32#@param {type:"number"}
height =  32#@param {type:"number"}
display_frequency =  25#@param {type:"number"}
initial_image = ""#@param {type:"string"}
target_images = ""#@param {type:"string"}
learning_rate = 0.1 #@param {type:"slider", min:0.01, max:1.0, step:0.01}
max_iterations = 200#@param {type:"number"}
# NOTE(review): `input_images` is assigned in this cell but never read.
input_images = ""

# Parse the form values the same way prep_args() does. Unlike prep_args, the
# Namespace built here has no `device` or `max_iterations` field; the loop at
# the end of this cell reads the global `max_iterations` directly.
seed = None
if initial_image == "None":
    initial_image = None
# "|" separates multiple image-prompt paths.
if target_images == "None" or not target_images:
    target_images = []
else:
    target_images = target_images.split("|")
    target_images = [image.strip() for image in target_images]

if initial_image or target_images != []:
    input_images = True

prompts = [frase.strip() for frase in prompts.split("|")]
if prompts == ['']:
    prompts = []

args = argparse.Namespace(
    prompts=prompts,
    image_prompts=target_images,
    noise_prompt_seeds=[],
    noise_prompt_weights=[],
    size=[width, height],
    init_image=initial_image,
    init_weight=0.,
    clip_model='ViT-B/32',
    step_size=learning_rate,
    cutn=64,
    cut_pow=1.,
    display_freq=display_frequency,
    seed=seed,
  vqgan_config='wikiart_16384.yaml',         
    vqgan_checkpoint='wikiart_16384.ckpt'
    )

import torch
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
if prompts:
    print('Using text prompt:', prompts)
if target_images:
    print('Using image prompts:', target_images)
# Draw (and print) a fresh seed when none is set, so the run can be repeated.
if args.seed is None:
    seed = torch.seed()
else:
    seed = args.seed
torch.manual_seed(seed)
print('Using seed:', seed)

# Load the models and set up the latent and prompt losses. These are the same
# steps as Text2ImageGenerator.__init__, but stored in module-level globals
# that the helper functions below read.
model = load_vqgan_model(args.vqgan_config, args.vqgan_checkpoint).to(device)
perceptor = clip.load(args.clip_model, jit=False)[0].eval().requires_grad_(False).to(device)

cut_size = perceptor.visual.input_resolution
e_dim = model.quantize.e_dim
# Pixels per latent token along each axis.
f = 2**(model.decoder.num_resolutions - 1)
make_cutouts = MakeCutouts(cut_size, args.cutn, cut_pow=args.cut_pow)
n_toks = model.quantize.n_e
# Latent grid size. The pixel size is rounded down to a multiple of f.
toksX, toksY = args.size[0] // f, args.size[1] // f
sideX, sideY = toksX * f, toksY * f
# Per-channel codebook bounds, used by train() to clamp z.
z_min = model.quantize.embedding.weight.min(dim=0).values[None, :, None, None]
z_max = model.quantize.embedding.weight.max(dim=0).values[None, :, None, None]


if args.init_image:
    # Start from the given image, rescaled from [0, 1] to [-1, 1] for encoding.
    pil_image = Image.open(args.init_image).convert('RGB')
    pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS)
    z, *_ = model.encode(TF.to_tensor(pil_image).to(device).unsqueeze(0) * 2 - 1)
else:
    # Start from uniformly random codebook entries, laid out as (1, e_dim, toksY, toksX).
    one_hot = F.one_hot(torch.randint(n_toks, [toksY * toksX], device=device), n_toks).float()
    z = one_hot @ model.quantize.embedding.weight
    z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)
z_orig = z.clone()
z.requires_grad_(True)
opt = optim.Adam([z], lr=args.step_size)

# Per-channel normalization applied to images before CLIP's encode_image.
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                 std=[0.26862954, 0.26130258, 0.27577711])

# One Prompt loss module per text prompt, image prompt and noise prompt.
pMs = []

for prompt in args.prompts:
    txt, weight, stop = parse_prompt(prompt)
    embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()
    pMs.append(Prompt(embed, weight, stop).to(device))

for prompt in args.image_prompts:
    path, weight, stop = parse_prompt(prompt)
    img = resize_image(Image.open(path).convert('RGB'), (sideX, sideY))
    batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
    embed = perceptor.encode_image(normalize(batch)).float()
    pMs.append(Prompt(embed, weight, stop).to(device))

# NOTE(review): this loop rebinds the global `seed` when noise prompts are
# given; the lists are empty in this cell, so it currently never runs.
for seed, weight in zip(args.noise_prompt_seeds, args.noise_prompt_weights):
    gen = torch.Generator().manual_seed(seed)
    embed = torch.empty([1, perceptor.visual.output_dim]).normal_(generator=gen)
    pMs.append(Prompt(embed, weight).to(device))

def synth(z):
    """Quantize latent `z` to the VQGAN codebook, decode it, and clamp the result to [0, 1]."""
    codebook = model.quantize.embedding.weight
    # Channels go last for quantization, then back to channel-first.
    z_q = vector_quantize(z.movedim(1, 3), codebook).movedim(3, 1)
    decoded = model.decode(z_q)
    return clamp_with_grad(decoded.add(1).div(2), 0, 1)

@torch.no_grad()
def checkin(i, losses):
    """Log the loss terms for step `i` and display the current image inline."""
    per_term = ', '.join(f'{loss.item():g}' for loss in losses)
    total = sum(losses).item()
    tqdm.write(f'i: {i}, loss: {total:g}, losses: {per_term}')
    image = synth(z)
    TF.to_pil_image(image[0].cpu()).save('progress.png')
    display.display(display.Image('progress.png'))

def ascend_txt():
    """Return the loss terms for the current global latent `z`.

    Side effect: saves the decoded frame to steps/{i:04}.png, using the
    module-level step counter `i`.
    """
    out = synth(z)
    image_embeds = perceptor.encode_image(normalize(make_cutouts(out))).float()

    losses = []
    if args.init_weight:
        losses.append(F.mse_loss(z, z_orig) * args.init_weight / 2)
    losses.extend(prompt(image_embeds) for prompt in pMs)

    # (C, H, W) float in [0, 1] -> (H, W, C) uint8 for imageio.
    frame = out.mul(255).clamp(0, 255)[0].cpu().detach().numpy().astype(np.uint8)
    frame = np.transpose(frame, (1, 2, 0))
    imageio.imwrite(f"steps/{i:04}.png", np.array(frame))
    return losses

def train(i):
    """Run one Adam step on the global latent `z`, logging every display_freq steps."""
    opt.zero_grad()
    step_losses = ascend_txt()
    if not i % args.display_freq:
        checkin(i, step_losses)
    total = sum(step_losses)
    total.backward()
    opt.step()
    with torch.no_grad():
        # Keep every latent channel inside the codebook's per-channel range.
        z.copy_(z.maximum(z_min).minimum(z_max))

# Optimization loop: steps i = 0..max_iterations inclusive (max_iterations + 1
# calls to train). ascend_txt() reads this module-level `i` to name each saved
# frame. Ctrl-C (KeyboardInterrupt) stops early without an error.
i = 0
try:
    with tqdm() as pbar:
        while True:
            train(i)
            if i == max_iterations:
                break
            i += 1
            pbar.update()
except KeyboardInterrupt:
    pass

print("Not for sale.")