# CLIP-guided StyleGAN2 image generation

This notebook is a direct conversion of `generate_images_from_prompt.py`.

It performs CLIP-guided latent optimization on a StyleGAN2 generator (.pkl)
to create an image matching a text prompt.

Usage notes:
- Provide a local StyleGAN2 `.pkl` (NVIDIA/stylegan2-ada format) and set `pkl_path` in the config cell.
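- Ensure `legacy.py` and the `dnnlib` package from https://github.com/NVlabs/stylegan2-ada-pytorch are importable (for example, add the cloned repository root to your PYTHONPATH).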


In [None]:
# Install (uncomment to run in notebook)
# !pip install -r ../src/requirements.txt

import os
import sys
from pathlib import Path
import torch
from torchvision import transforms
from PIL import Image
import clip
from tqdm.notebook import tqdm
from IPython.display import display

# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
print('Using device:', device)

In [None]:
# User-tunable settings for this notebook; edit these before running the cells below.
prompt = 'a fantasy castle on a cliff at sunset'  # text the generated image should match
pkl_path = '/path/to/stylegan2-ffhq.pkl'  # <- set to your local checkpoint
outdir = Path('outputs')  # directory that receives the saved result image
steps, lr, seed = 300, 0.1, 42  # optimizer iterations, Adam learning rate, RNG seed

# Seed PyTorch's RNG and make sure the output directory exists before anything is written.
torch.manual_seed(seed)
outdir.mkdir(parents=True, exist_ok=True)
print('Outdir:', outdir)

In [None]:
# Load the ViT-B/32 CLIP checkpoint onto the selected device. `clip.load`
# returns the model together with its matching image preprocessing transform.
print('Loading CLIP model...')
clip_model, clip_preprocess = clip.load('ViT-B/32', device=device)
# Switch to inference mode; `eval()` returns the module itself.
clip_model = clip_model.eval()
print('CLIP loaded on', device)

In [None]:
def load_stylegan2_g(pkl_path, device):
    """Load the `G_ema` generator from a StyleGAN2-ADA network pickle.

    Args:
        pkl_path: Path to a `.pkl` checkpoint that `legacy.load_network_pkl`
            can read and that contains a `G_ema` entry.
        device: Device the generator is moved to (anything accepted by `.to()`).

    Returns:
        The generator, moved to `device` and switched to eval mode.

    Raises:
        RuntimeError: If `dnnlib` or `legacy` cannot be imported. The original
            `ImportError` is attached as `__cause__`.
        OSError: If `pkl_path` cannot be opened.
    """
    try:
        # `dnnlib` is not referenced directly below; it is imported here so a
        # missing checkout is reported up front with an actionable message.
        import dnnlib  # noqa: F401
        import legacy
    except ImportError as err:
        # Only import failures mean "dependencies missing"; any other error is
        # left to propagate with its own message instead of being relabelled.
        raise RuntimeError(
            'Missing stylegan2-ada loader dependencies.\n'
            'Clone https://github.com/NVlabs/stylegan2-ada-pytorch and ensure legacy.py and dnnlib are on PYTHONPATH.'
        ) from err
    # NOTE(review): network .pkl files are Python pickles and can execute
    # arbitrary code when loaded; only open checkpoints from trusted sources.
    with open(pkl_path, 'rb') as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device)
    G.eval()
    return G

# Usage example (uncomment and set pkl_path):
# G = load_stylegan2_g(pkl_path, device)

In [None]:
def synth_image_from_w(G, w):
    """Synthesize an image from a W-space latent, keeping the autograd graph.

    The synthesis call is intentionally NOT wrapped in `torch.no_grad()`:
    the result must stay connected to `w` so that a loss computed on the image
    can be back-propagated to the latent. Callers that only need inference can
    wrap the call in `torch.no_grad()` themselves.

    Args:
        G: Generator exposing `num_ws` and `synthesis(ws, noise_mode=...)`.
        w: Latent tensor. A 2-D `(batch, w_dim)` tensor is broadcast to every
            synthesis layer as `(batch, G.num_ws, w_dim)`; any other rank is
            passed to `G.synthesis` unchanged.

    Returns:
        The synthesized image rescaled from `[-1, 1]` to `[0, 1]` (values
        outside `[-1, 1]` are clamped first).
    """
    if w.ndim == 2:
        # One latent per sample: replicate it across all synthesis layers.
        w_in = w.unsqueeze(1).repeat(1, G.num_ws, 1)
    else:
        w_in = w
    img = G.synthesis(w_in, noise_mode='const')
    # Map the generator's [-1, 1] output range to [0, 1].
    img = (img.clamp(-1, 1) + 1) / 2
    return img


def preprocess_for_clip(img_tensor):
    """Resize and normalize an image tensor for CLIP without leaving autograd.

    The resize is done with tensor ops instead of a PIL round-trip, so
    gradients can flow from CLIP's output back to `img_tensor` and no uint8
    quantization is introduced.

    Args:
        img_tensor: RGB image with values in `[0, 1]`, shaped `(3, H, W)` or
            `(N, 3, H, W)`.

    Returns:
        A float32 tensor of shape `(N, 3, 224, 224)` on the module-level
        `device`, normalized with CLIP's per-channel mean and std.
    """
    import torch.nn.functional as F  # local import keeps this cell self-contained

    batch = img_tensor if img_tensor.ndim == 4 else img_tensor.unsqueeze(0)
    batch = batch.to(device=device, dtype=torch.float32)
    batch = F.interpolate(batch, size=(224, 224), mode='bilinear', align_corners=False)

    # Per-channel statistics used by the official CLIP preprocessing transform.
    mean = batch.new_tensor((0.48145466, 0.4578275, 0.40821073)).view(1, -1, 1, 1)
    std = batch.new_tensor((0.26862954, 0.26130258, 0.27577711)).view(1, -1, 1, 1)
    return (batch - mean) / std


In [None]:
def _to_uint8_hwc(img):
    """Convert a CPU float image in [0, 1], shaped (1, C, H, W) or (C, H, W), to an HxWxC uint8 array."""
    return (img.squeeze(0) * 255).clamp(0, 255).permute(1, 2, 0).numpy().astype('uint8')


def run_optimization(prompt, pkl_path, outdir, steps=300, lr=0.1, seed=42, device=device, display_every=50):
    """Optimize a StyleGAN2 latent so the synthesized image matches `prompt` under CLIP.

    Args:
        prompt: Text the image should match.
        pkl_path: Path to the generator checkpoint (see `load_stylegan2_g`).
        outdir: Directory for the result image; created if it does not exist.
        steps: Number of Adam steps; must be at least 1.
        lr: Adam learning rate.
        seed: Seed for PyTorch's RNG; also used in the output file name.
        device: Device the generator and latent live on.
        display_every: Show progress every this many steps (and on the last
            step); must be at least 1.

    Returns:
        `(out_path, best_score)`: path of the saved PNG and the highest
        CLIP cosine similarity observed.

    Raises:
        ValueError: If `steps` or `display_every` is smaller than 1.
        RuntimeError: If no finite similarity score was ever produced, so
            there is no image to save.
    """
    import warnings

    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")
    if display_every < 1:
        # Guards the modulo below against division by zero.
        raise ValueError(f"display_every must be >= 1, got {display_every}")

    torch.manual_seed(seed)
    clip_model.eval()
    G = load_stylegan2_g(pkl_path, device)

    # initialize latent
    z = torch.randn(1, G.z_dim, device=device)
    with torch.no_grad():
        w = G.mapping(z, None)
    w_opt = w.detach().clone()
    w_opt.requires_grad = True

    # text feature (fixed target, so no gradient is needed)
    text_tokens = clip.tokenize([prompt]).to(device)
    with torch.no_grad():
        text_feat = clip_model.encode_text(text_tokens)
        text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)

    optimizer = torch.optim.Adam([w_opt], lr=lr)

    best_score = float('-inf')
    best_img = None
    warned_no_grad = False

    for i in range(steps):
        optimizer.zero_grad()
        img = synth_image_from_w(G, w_opt)
        clip_in = preprocess_for_clip(img)
        image_feat = clip_model.encode_image(clip_in)
        image_feat = image_feat / image_feat.norm(dim=-1, keepdim=True)

        # Cosine similarity between unit-norm features; maximize it.
        similarity = (image_feat @ text_feat.T).squeeze()
        loss = -similarity
        loss.backward()

        # Adam silently skips parameters whose grad is None, which would leave
        # the latent unchanged for the whole run; surface that once.
        if w_opt.grad is None and not warned_no_grad:
            warnings.warn(
                'No gradient reached the latent; it will not be updated. '
                'Check that image synthesis and CLIP preprocessing are differentiable.',
                RuntimeWarning,
                stacklevel=2,
            )
            warned_no_grad = True
        optimizer.step()

        # `img` was rendered before this step, so it corresponds to `score`.
        score = float(similarity.item())
        if score > best_score:
            best_score = score
            best_img = img.detach().cpu()

        if (i + 1) % display_every == 0 or i == steps - 1:
            print(f"step {i+1}/{steps} score={score:.4f} best={best_score:.4f}")
            if best_img is not None:
                display(Image.fromarray(_to_uint8_hwc(best_img)))

    if best_img is None:
        # Happens when every score was NaN: `score > best_score` is never true.
        raise RuntimeError('Optimization produced no valid similarity score; nothing to save.')

    out_dir = Path(outdir)
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / f"result_prompt_{seed}.png"
    Image.fromarray(_to_uint8_hwc(best_img)).save(out_path)
    print(f"Saved best image (score={best_score:.4f}) to {out_path}")
    return out_path, best_score


In [None]:
# End-to-end example: run the optimization only when the configured checkpoint
# file is present, then show the saved result inline.
if Path(pkl_path).exists():
    out_path, score = run_optimization(
        prompt,
        pkl_path,
        outdir,
        steps=steps,
        lr=lr,
        seed=seed,
        device=device,
        display_every=max(1, steps//10),  # roughly ten progress updates per run
    )
    print('Done:', out_path, 'score=', score)
    display(Image.open(out_path))
else:
    print('Please set `pkl_path` to a valid StyleGAN2 .pkl file and re-run this cell.')

**Notes**:

- This notebook requires a StyleGAN2 `.pkl` checkpoint in NVIDIA/stylegan2-ada format (contains `G_ema`).
- If you don't have `legacy.py` and `dnnlib` available, clone https://github.com/NVlabs/stylegan2-ada-pytorch and add it to `PYTHONPATH`.
- For faster runs use a GPU-enabled container and reduce `steps` for quick previews.
