# **StyleGANXL + CLIP 🖼️**

## Generate images from text prompts using StyleGANXL with CLIP guidance.

(Modified by Katherine Crowson to optimize in W+ space)

This notebook is a work in progress, head over [here](https://github.com/CasualGANPapers/unconditional-StyleGAN-CLIP) if you want to be up to date with its changes.

Largely based on code by  [Katherine Crowson](https://github.com/crowsonkb) and [nshepperd](https://github.com/nshepperd).

Mostly made possible because of [StyleGAN-XL](https://github.com/autonomousvision/stylegan_xl) and [CLIP](https://github.com/openai/CLIP).

Created by [Eugenio Herrera](https://github.com/ouhenio) and [Rodrigo Mello](https://github.com/ryudrigo).


In [1]:
doClone = False
doPip = False
doCPU = False # << We can do this on CPU but it's a lot slower... (there's a final check for CUDA as well)

In [2]:
if doClone:
    !git clone https://github.com/autonomousvision/stylegan_xl
    !git clone https://github.com/openai/CLIP
    !git clone https://github.com/crowsonkb/esgd.git

In [3]:
if doPip:
    !pip install -e ./CLIP
    !pip install einops ninja
    !pip install timm==0.5.4
    !pip install dill

In [4]:
import sys
sys.path.append('./CLIP')
sys.path.append('./stylegan_xl')
sys.path.append('./esgd')

import io
import os, time, glob
import pickle
import shutil
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
import requests
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import clip
import unicodedata
import re
from esgd import ESGD
from tqdm.notebook import tqdm
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from IPython.display import display
from einops import rearrange
#from google.colab import files << No need for this local version
import dnnlib
import legacy

In [5]:
# Functions 

def fetch(url_or_path):
    if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
        r = requests.get(url_or_path)
        r.raise_for_status()
        fd = io.BytesIO()
        fd.write(r.content)
        fd.seek(0)
        return fd
    return open(url_or_path, 'rb')

def fetch_model(url_or_path):
    !wget -c '{url_or_path}'

def slugify(value, allow_unicode=False):
    """
    Taken from https://github.com/django/django/blob/master/django/utils/text.py
    Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
    dashes to single dashes. Remove characters that aren't alphanumerics,
    underscores, or hyphens. Convert to lowercase. Also strip leading and
    trailing whitespace, dashes, and underscores.
    """
    value = str(value)
    if allow_unicode:
        value = unicodedata.normalize('NFKC', value)
    else:
        value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii')
    value = re.sub(r'[^\w\s-]', '', value.lower())
    return re.sub(r'[-\s]+', '-', value).strip('-_')

def norm1(prompt):
    "Normalize to the unit sphere."
    return prompt / prompt.square().sum(dim=-1,keepdim=True).sqrt()

def spherical_dist_loss(x, y):
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)

def prompts_dist_loss(x, targets, loss):
    if len(targets) == 1: # Keeps consitent results vs previous method for single objective guidance
      return loss(x, targets[0])
    distances = [loss(x, target) for target in targets]
    return torch.stack(distances, dim=-1).sum(dim=-1)

class MakeCutouts(torch.nn.Module):
    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow

    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
        return torch.cat(cutouts)

make_cutouts = MakeCutouts(224, 32, 0.5)

def embed_image(image):
  n = image.shape[0]
  cutouts = make_cutouts(image)
  embeds = clip_model.embed_cutout(cutouts)
  embeds = rearrange(embeds, '(cc n) c -> cc n c', n=n)
  return embeds

def embed_url(url):
  image = Image.open(fetch(url)).convert('RGB')
  return embed_image(TF.to_tensor(image).to(device).unsqueeze(0)).mean(0).squeeze(0)

#class CLIP(object):
 # def __init__(self):
  #  clip_model = "ViT-B/16"
   # self.model, _ = clip.load(clip_model)
    #self.model = self.model.requires_grad_(False)
    #self.normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                          #std=[0.26862954, 0.26130258, 0.27577711])

 # @torch.no_grad()
  #def embed_text(self, prompt):
   #   "Normalized clip text embedding."
    #  return norm1(self.model.encode_text(clip.tokenize(prompt).to(device)).float())
#
 # def embed_cutout(self, image):
  #    "Normalized clip image embedding."
   #   return norm1(self.model.encode_image(self.normalize(image)))

#clip_model = CLIP()

In [6]:
# Select model
Model = 'Imagenet-1024' #@param ["Imagenet-1024", "Imagenet-512", "Imagenet-256", "Imagenet-128", "Pokemon", "FFHQ"]

network_url = {
    "Imagenet-1024": "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet1024.pkl",
    "Imagenet-512": "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet512.pkl",
    "Imagenet-256": "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet256.pkl",
    "Imagenet-128": "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet128.pkl",
    "Pokemon-1024": "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/pokemon1024.pkl",
    "Pokemon-512": "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/pokemon512.pkl",
    "Pokemon-256": "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/pokemon256.pkl",
    "FFHQ-256": "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/ffhq256.pkl"
}

network_name = network_url[Model].split("/")[-1]
fetch_model(network_url[Model])

--2024-07-22 14:29:26--  https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet1024.pkl
Resolving s3.eu-central-1.amazonaws.com (s3.eu-central-1.amazonaws.com)... 3.5.135.211, 52.219.46.95, 52.219.170.93, ...
Connecting to s3.eu-central-1.amazonaws.com (s3.eu-central-1.amazonaws.com)|3.5.135.211|:443... connected.
HTTP request sent, awaiting response... 416 Requested Range Not Satisfiable

    The file is already fully retrieved; nothing to do.



In [7]:
# Set device to GPU or CPU
if doCPU:
    device = torch.device('cpu')
    print("Running on CPU")
else:
    device = torch.device('cuda:0')
    print('Using device:', device, file=sys.stderr)

  
# Load network
with dnnlib.util.open_url(network_name) as f:
    G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore

# Generate latent vectors and conditioning vectors on CPU
zs = torch.randn([10000, G.mapping.z_dim], device=device)
cs = torch.zeros([10000, G.mapping.c_dim], device=device)
for i in range(cs.shape[0]):
    cs[i, i // 10] = 1

w_stds = G.mapping(zs, cs)
w_stds = w_stds.reshape(10, 1000, G.num_ws, -1)
w_stds=w_stds.std(0).mean(0)[0]
w_all_classes_avg = G.mapping.w_avg.mean(0)

Using device: cuda:0


Setting up PyTorch plugin "bias_act_plugin"... 

If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


Done.


In [8]:
if doCPU:
    print('Only doing this on GPU')
else:
    tf = Compose([
      # Resize(224),
      lambda x: torch.clamp((x+1)/2,min=0,max=1),
    ])

    t = torch.cuda.get_device_properties(0).total_memory
    r = torch.cuda.memory_reserved(0)
    a = torch.cuda.memory_allocated(0)
    t,r,a

In [31]:
# Set up a loop for image collection
# CLASSES
class_idx1 = np.arange(0,1000)
class_idx2 = np.zeros(class_idx1.shape)

#class_idx1 = []
#class_idx2 = []

# STEPS
#steps = np.arange(0,1.1,0.1) # all steps
steps = [0] # EXAMPLAR ONLY

# SHOW IMAGES?
doDisplayImages = False

In [32]:
for idx, x in np.ndenumerate(class_idx1):
    
    # Assume G.mapping.z_dim and G.mapping.c_dim are defined, as well as tf function
    # Double-check if cuda is available, otherwise use CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Define the initial and final class indices
    class_idx_1 = int(class_idx1[idx])  # Replace with the starting class index
    class_idx_2 = int(class_idx2[idx])  # Replace with the target class index
    if len(steps) == 1 and steps[0] == 0:
        img_lable = 'c' + str(class_idx_1).zfill(4)
    else:
        img_lable = 'c' + str(class_idx_1).zfill(4) + '-' + str(class_idx_2).zfill(4) 
    
    # Generate the initial latent vector and condition vector
    zs = torch.randn([1, G.mapping.z_dim], device=device)
    cs1 = torch.zeros([1, G.mapping.c_dim], device=device)
    cs1[0][class_idx_1] = 1
    
    # Map to the intermediate latent space
    w_stds1 = G.mapping(zs, cs1)
    w_avg1 = G.mapping.w_avg
    w_avg1 = w_avg1[cs1[0].bool()]
    w_avg1 = w_avg1.unsqueeze(1).repeat(1, G.mapping.num_ws, 1)
    w_stds1 = w_avg1
    
    # Generate the final latent vector and condition vector
    zs = torch.randn([1, G.mapping.z_dim], device=device)
    cs2 = torch.zeros([1, G.mapping.c_dim], device=device)
    cs2[0][class_idx_2] = 1
    
    # Map to the intermediate latent space
    w_stds2 = G.mapping(zs, cs2)
    w_avg2 = G.mapping.w_avg
    w_avg2 = w_avg2[cs2[0].bool()]
    w_avg2 = w_avg2.unsqueeze(1).repeat(1, G.mapping.num_ws, 1)
    w_stds2 = w_avg2
    
    # Generate and display images morphing from class 1 to class 2 in 10 steps
    #steps = np.arange(0,1.1,0.1)
    #steps = [0,1]
    for alpha in steps:
        w_interp = (1 - alpha) * w_stds1 + alpha * w_stds2
        caio = G.synthesis(w_interp, noise_mode='const')
        if doDisplayImages:
            display(TF.to_pil_image(tf(caio)[0]))
        img_lableidx = img_lable + '_i' + str(round(10*alpha)).zfill(2)
        img = TF.to_pil_image(tf(caio)[0])
        if len(steps) == 1 and steps[0] == 0:
            os.makedirs(os.path.join('imgout', 'examplar'),exist_ok=True)
            img.save(os.path.join('imgout', 'examplar', img_lableidx + '.png'))
        else:
            os.makedirs(os.path.join('imgout', img_lable),exist_ok=True)
            img.save(os.path.join('imgout', img_lable, img_lableidx + '.png'))
            

IndexError: index 1000 is out of bounds for dimension 0 with size 1000