In [1]:
import torch
import clip

In [2]:
# Load the CLIP ViT-B/32 checkpoint, placing it on the GPU when CUDA is usable
# and on the CPU otherwise. `preprocess` is the second value returned by clip.load.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

100%|████████████████████████████████████████| 338M/338M [06:36<00:00, 892kiB/s]


In [3]:
# Map each sampler name to a comma-separated keyword description of the content
# it is matched against. Built from ordered (name, description) pairs, so the
# insertion order of the resulting dict is the order listed here.
sampler_descriptions = dict(
    [
        (
            "DPM++ 2M Karras",
            "Portraits, faces, skin details, cinematic, soft lighting, expressions",
        ),
        (
            "DPM++ SDE Karras",
            "Photorealism, ultra-HD, realistic lighting, hyper-detailed textures",
        ),
        (
            "DPM++ 2M SDE Karras",
            "Extreme photorealism, hyper-detailed skin, deep shadows, cinematic depth",
        ),
        (
            "Euler",
            "Illustrations, sketches, balanced style, general rendering",
        ),
        (
            "Euler a",
            "Anime, manga, vibrant colors, fantasy, cel-shading, expressive characters",
        ),
        (
            "DDIM",
            "Dreamy, abstract, concept art, surreal, artistic effects",
        ),
    ]
)

In [4]:
# Function to predict the best sampler using CLIP
def suggest_sampler_clip(prompt):
    """Return the sampler whose description is closest to ``prompt`` in CLIP text space.

    The prompt and every description in ``sampler_descriptions`` are tokenized
    and encoded in a single batch by the CLIP text encoder. The key whose
    description has the highest cosine similarity to the prompt embedding is
    returned.

    Args:
        prompt: Text prompt to compare against the sampler descriptions.

    Returns:
        The key of ``sampler_descriptions`` with the highest-scoring description.
    """
    # Take names and descriptions once so the index found below maps back to
    # the matching key.
    sampler_names = list(sampler_descriptions.keys())
    descriptions = list(sampler_descriptions.values())

    # truncate=True: without it clip.tokenize raises RuntimeError when a text is
    # longer than the model's context length, which long tag-style prompts
    # easily exceed. The trade-off is that the tail of an over-long prompt is
    # dropped and does not influence the match.
    text_inputs = clip.tokenize([prompt] + descriptions, truncate=True).to(device)
    with torch.no_grad():
        text_features = model.encode_text(text_inputs)

    prompt_embedding = text_features[0]  # Row 0 is the input prompt
    sampler_embeddings = text_features[1:]  # Remaining rows are the descriptions

    # Cosine similarity of the prompt against each description; unsqueeze(0)
    # makes the prompt a single-row batch that broadcasts over the descriptions.
    similarities = torch.nn.functional.cosine_similarity(
        prompt_embedding.unsqueeze(0), sampler_embeddings
    )
    best_match_idx = similarities.argmax().item()

    return sampler_names[best_match_idx]

In [7]:
# Example usage
# Raw string literal: the prompt contains literal backslashes ("\(" and "\)"),
# which are invalid escape sequences in a normal string literal and produce a
# DeprecationWarning/SyntaxWarning. The resulting string value is unchanged.
prompt_text = r"kimi no na wa., building, cityscape, cloud, cloudy sky, gradient sky, lens flare, no humans, outdoors, power lines, scenery, shooting star, sky, sparkle, star \(sky\), starry sky, sunset, tree, utility pole"
best_sampler = suggest_sampler_clip(prompt_text)

In [8]:
# Show the sampler name that suggest_sampler_clip returned for the example prompt.
print("Suggested Sampler:", best_sampler)

Suggested Sampler: Euler a
