In [None]:
# %pip (not !pip) installs into the environment of the running kernel.
# torchvision is imported below (ResNet50 weights, image-variation transforms), so list it explicitly.
%pip install -q diffusers transformers scipy ftfy accelerate torch torchvision open_clip_torch

In [None]:
from diffusers import AutoPipelineForText2Image
import torch
from torchvision.models import resnet50, ResNet50_Weights

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# float16 halves GPU memory, but many CPU kernels lack half-precision support,
# so use float32 when falling back to the CPU.
sd_dtype = torch.float16 if device.type == "cuda" else torch.float32

# Stable Diffusion v1.5 text-to-image pipeline. The original "runwayml/..." Hub repo
# has been removed; "stable-diffusion-v1-5/stable-diffusion-v1-5" is its mirror.
# variant="fp16" fetches the smaller half-precision weight files; the pipeline is
# loaded with torch_dtype=sd_dtype.
sd_pipe = AutoPipelineForText2Image.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=sd_dtype,
    variant="fp16",
).to(device)

# ResNet50 classification model used to score generated images
weights = ResNet50_Weights.IMAGENET1K_V2
preprocess = weights.transforms()  # preprocessing transformation matching these weights
# assign the result so the cell does not dump the full model repr as output
classifier = resnet50(weights=weights).eval().to(device)

In [None]:
# evaluation by ResNet50
def evaluate(tti, labels, n=1):
    """Score a text-to-image generator by how confidently ResNet50 classifies its images.

    For each prompt in ``labels``, ``tti`` is called ``n`` times. Each resulting image
    is classified by the module-level ResNet50 ``classifier``, and the top-1 softmax
    probability is recorded. The function prints, per prompt, the most frequent top-1
    category and the mean top-1 score. It then prints the mean over all images.

    Note: the score is the confidence of whatever class ResNet50 predicts, not
    necessarily the class named in the prompt.

    Args:
        tti: callable taking a prompt string and returning an image accepted by
            ``preprocess`` (the generators in this notebook return PIL images).
        labels: sized sequence of prompt strings.
        n: number of images generated per prompt; must be >= 1.

    Raises:
        ValueError: if ``n`` < 1 or ``labels`` is empty (the averages would be undefined).
    """
    from collections import Counter

    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    if len(labels) == 0:
        raise ValueError("labels must be non-empty")

    categories = weights.meta["categories"]
    overall_agg = 0.0
    for label in labels:
        class_agg = 0.0
        predicted = Counter()  # top-1 category -> number of samples predicting it
        for _ in range(n):
            # generate image from prompt
            img = tti(label)

            # preprocess and classify; no_grad avoids building an unused autograd graph
            batch = preprocess(img).unsqueeze(0).to(device)
            with torch.no_grad():
                prediction = classifier(batch).squeeze(0).softmax(0)

            # obtain top category and corresponding score
            class_id = prediction.argmax().item()
            score = prediction[class_id].item()
            predicted[categories[class_id]] += 1

            class_agg += score  # update class aggregate score
            overall_agg += score  # update overall aggregate score

        # report the most frequent top-1 category over the n samples rather than
        # whichever one happened to be predicted last
        category_name = predicted.most_common(1)[0][0]
        print(f"Prompt: {label} | Predicted category: {category_name} | Average score: {100 * class_agg / n:.1f}%")

    print("\n")
    print(f"Iterations per class: {n} | Average score: {100 * overall_agg / (len(labels) * n):.1f}%")

In [None]:
# Ten ImageNet-1K classes, each given as its full comma-separated synonym string.
# These strings are the prompt text for Method 1 (verbatim) and Method 2 (templated).
imagenet_classes = [
    # fish
    "tench, Tinca tinca",
    "goldfish, Carassius auratus",
    # sharks
    "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
    "tiger shark, Galeocerdo cuvieri",
    "hammerhead, hammerhead shark",
    # rays
    "electric ray, crampfish, numbfish, torpedo",
    "stingray",
    # birds
    "cock",
    "hen",
    "ostrich, Struthio camelus",
]

In [None]:
# Method 1: use class name as prompt
def classNameGeneration(label):
    """Method 1: pass ``label`` to ``sd_pipe`` unchanged as the prompt.

    Returns the first image in the pipeline's ``images`` output.
    """
    pipeline_output = sd_pipe(label)
    generated_images = pipeline_output.images
    return generated_images[0]

In [None]:
evaluate(tti=classNameGeneration, labels=imagenet_classes)

In [None]:
# Method 2: use template "A photo of a <class name>"
def classNameTemplateGeneration(label):
    """Method 2: embed ``label`` in the prompt template "A photo of a <label>".

    Returns the first image in ``sd_pipe``'s ``images`` output.
    """
    prompt = "A photo of a {}".format(label)
    return sd_pipe(prompt).images[0]

In [None]:
evaluate(tti=classNameTemplateGeneration, labels=imagenet_classes)

In [None]:
import os

# Download one sample photo per class from the EliSchwartz/imagenet-sample-images
# repo into ./imagenet_samples/<name>.jpg. The CoCa-captioning and image-variation
# cells below read these files.
import urllib.request

SAMPLE_DIR = "imagenet_samples"
SAMPLE_URL = (
    "https://raw.githubusercontent.com/EliSchwartz/imagenet-sample-images/"
    "master/{wnid}_{name}.JPEG"
)

# (ImageNet id prefix used in the repo's file names, local short name),
# in the same order as imagenet_classes
sample_ids = [
    ("n01440764", "tench"),
    ("n01443537", "goldfish"),
    ("n01484850", "great_white_shark"),
    ("n01491361", "tiger_shark"),
    ("n01494475", "hammerhead"),
    ("n01496331", "electric_ray"),
    ("n01498041", "stingray"),
    ("n01514668", "cock"),
    ("n01514859", "hen"),
    ("n01518878", "ostrich"),
]

# list of image names
img_names = [name for _, name in sample_ids]

# create directory to store images
os.makedirs(SAMPLE_DIR, exist_ok=True)

for wnid, name in sample_ids:
    dest = os.path.join(SAMPLE_DIR, f"{name}.jpg")
    # skip files already on disk so re-running the cell is cheap
    if os.path.exists(dest) and os.path.getsize(dest) > 0:
        continue
    try:
        # urlretrieve raises on HTTP errors; `wget -q -O` would silently leave an
        # empty file that only fails later when PIL tries to open it
        urllib.request.urlretrieve(SAMPLE_URL.format(wnid=wnid, name=name), dest)
    except Exception:
        # don't leave a partial file behind that the existence check would accept
        if os.path.exists(dest):
            os.remove(dest)
        raise

In [None]:
# Method 3: Use contrastive captioning (CoCa) to generate prompts from ImageNet images
import open_clip
from PIL import Image

# CoCa model and its matching image preprocessing transformation
CoCa, _, transform = open_clip.create_model_and_transforms(
    model_name="coca_ViT-L-14",
    pretrained="mscoco_finetuned_laion2B-s13B-b90k"
)
# ensure eval mode (open_clip's README also calls .eval() after creation) so
# dropout-style layers are inactive while generating captions
CoCa.eval().to(device)


def generate_caption(path):
    """Return the CoCa caption for the image file at ``path``, with special tokens removed."""
    # `with` closes the file handle once the RGB copy has been transformed
    with Image.open(path) as raw:
        pixels = transform(raw.convert("RGB")).unsqueeze(0).to(device)

    # mixed precision only on GPU; torch.autocast supersedes the deprecated
    # torch.cuda.amp.autocast and does nothing when enabled=False
    with torch.no_grad(), torch.autocast(device_type=device.type, enabled=device.type == "cuda"):
        generated = CoCa.generate(pixels)

    text = open_clip.decode(generated[0])
    # keep the text before <end_of_text>, drop <start_of_text>, trim padding spaces
    return text.split("<end_of_text>")[0].replace("<start_of_text>", "").strip()


# one caption per sample image (in img_names order), read from the same relative
# directory the download cell writes to instead of a hardcoded Colab path
captions = [generate_caption(f"imagenet_samples/{name}.jpg") for name in img_names]

In [None]:
# Method 3: feed the CoCa captions to Stable Diffusion verbatim as prompts
evaluate(tti=classNameGeneration, labels=captions)

In [None]:
import torchvision.transforms as transforms
from diffusers import StableDiffusionImageVariationPipeline

# Image-variation pipeline (lambdalabs/sd-image-variations-diffusers, revision v2.0),
# moved to the same device as the other models
sdiv_pipe = StableDiffusionImageVariationPipeline.from_pretrained(
    "lambdalabs/sd-image-variations-diffusers",
    revision="v2.0",
).to(device)

# Conditioning-image preprocessing: PIL -> tensor, bicubic resize to 224x224
# without antialiasing, then per-channel normalization with the mean/std below
tform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(
            size=(224, 224),
            interpolation=transforms.InterpolationMode.BICUBIC,
            antialias=False,
        ),
        transforms.Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073],
            std=[0.26862954, 0.26130258, 0.27577711],
        ),
    ]
)

In [None]:
# Method 4: Image Variation
# For each sample image, generate N variations and score each with ResNet50
# (top-1 softmax probability), mirroring evaluate() for the text-based methods.
from collections import Counter

from PIL import Image

N = 1 # iterations per image
assert N >= 1, "N must be at least 1 (averages divide by N)"

agg = 0
for name in img_names:
    # load and preprocess the source image once: it is identical for every
    # iteration; relative path matches the directory the download cell writes to
    with Image.open(f"imagenet_samples/{name}.jpg") as raw:
        cond = tform(raw.convert("RGB")).unsqueeze(0).to(device)

    img_agg = 0
    predicted = Counter()  # top-1 category -> number of variations predicting it
    for _ in range(N):
        # apply image variation
        img = sdiv_pipe(cond, guidance_scale=3).images[0]

        # preprocess varied image and classify it; no_grad avoids an unused autograd graph
        batch = preprocess(img).unsqueeze(0).to(device)
        with torch.no_grad():
            prediction = classifier(batch).squeeze(0).softmax(0)

        # obtain top category and corresponding score
        class_id = prediction.argmax().item()
        score = prediction[class_id].item()
        predicted[weights.meta["categories"][class_id]] += 1

        img_agg += score # update image-specific aggregate score
        agg += score # update overall aggregate score

    # most frequent top-1 category across the N variations, not just the last one
    category_name = predicted.most_common(1)[0][0]
    print(f"Pre-varied image: {name} | Predicted category: {category_name} | Score: {100 * img_agg / N:.1f}%")

print("\n")
print(f"Iterations per class: {N} | Average score: {100 * agg / (len(img_names) * N):.1f}%")

In [None]:
# Removes NSFW filter from sd_pipe.
# diffusers skips the safety check entirely when `safety_checker` is None. The
# previous stub returned a one-element flag list regardless of how many images
# were generated, which does not match batches larger than one.
# NOTE(review): this only affects generations made after this cell runs; the
# evaluation cells above used the default checker unless this runs first.
sd_pipe.safety_checker = None