In [1]:
from typing import List, Union

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from torchvision.datasets import CIFAR10
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline



In [2]:
dataset = CIFAR10(root="~/.cache", download=True)

full_labels = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

Files already downloaded and verified


In [3]:
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    safety_checker=None,
    requires_safety_checker=False,
).to("cuda")
pipe.set_progress_bar_config(disable=True)

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

In [4]:
import statistics

import scipy  # .stats.ttest_ind
from tqdm import tqdm


preprocess = transforms.Compose(
    [
        transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
        # transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
        # transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

def generative_classification_clark(
    model: StableDiffusionImg2ImgPipeline,
    image: Image.Image,
    classes: List[str],
    num_inference_steps: int = 500,
    min_scores: int = 20,
    max_scores: int = 2000,
    cutoff_pval: float = 5e-3,
    g = None,
):
    image = preprocess(image)
    # scheduler = model.scheduler
    
    scores = {_class: [] for _class in classes}
    n = 0
    denoised_images = []
    min_mean = float("inf")
    min_mean_class = None
    pbar = tqdm(total=max_scores)
    while len(scores) > 1 and n < max_scores:
        n += 1
        pbar.update(1)
        
        # Noise the image by sampling t ~ U[0,1] then generating x_t ~ q(x_t|x)
        # t = torch.randint(200, 201, (1,))  # 1, scheduler.num_train_timesteps, (1,))
        # noise = torch.randn(image.shape, generator=g)
        # noisy_image = scheduler.add_noise(image, noise, t).half().unsqueeze(0)
        
        # for _class, score in scores.items():
        # TODO: Text conditioning
        # TODO: Weighting function
        # image_latent = model.vae.encode(noisy_image.cuda()).latent_dist.sample() * model.vae.config.scaling_factor
        # prompt = f"A photo of a {_class}"
        denoised_images = model(
            prompt=[f"A photo of a {c}" for c in classes],
            image=image,
            generator=g,
            strength=0.5,
            num_inference_steps=num_inference_steps,
            output_type="np.array",
        ).images

        for i, (_class, score) in enumerate(scores.items()):
            score.append(int((image.permute(1, 2, 0) - denoised_images[i]).norm()))
            if (class_mean := statistics.mean(score)) < min_mean:
                min_mean = class_mean
                min_mean_class = _class
            
        if n >= min_scores:
            pvals = []
            for _class in list(scores.keys()):
                if _class == min_mean_class:
                    continue
                score = scores[_class]
                x = scipy.stats.ttest_ind(scores[min_mean_class], scores[_class]).pvalue
                pvals.append(x)
                if x < cutoff_pval:
                    del scores[_class]

    pbar.close()

    statistical_significance = len(scores) == 1
    
    return min_mean_class, statistical_significance

In [5]:
label_idxs = [3, 5]
dataset_max_size = 10

idxs = [i for i, (img, _class) in enumerate(dataset) if _class in label_idxs]
idxs = idxs[:min(len(idxs), dataset_max_size)]

filtered_labels = [l for i, l in enumerate(full_labels) if i in label_idxs]
print(filtered_labels)

['cat', 'dog']


In [22]:
g = torch.Generator(device='cpu')
g.manual_seed(42)

preds = []
significance = []
for i in idxs:
    raw_image, _class = dataset[i]
    print()
    pred, significant = generative_classification_clark(
        pipe,
        raw_image,
        filtered_labels,
        num_inference_steps=500,
        g=g,
        min_scores=10,
        max_scores=40,
        cutoff_pval=0.1,
    )
    print(f"Actual class: {full_labels[_class]}, Predicted class: {pred}")
    preds.append(full_labels.index(pred))
    significance.append(significant)




  deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
  deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
 28%|███████▉                     | 11/40 [01:55<05:04, 10.50s/it]


Actual class: cat, Predicted class: cat



100%|█████████████████████████████| 40/40 [09:44<00:00, 14.61s/it]


Actual class: cat, Predicted class: dog



 25%|███████▎                     | 10/40 [02:06<06:18, 12.60s/it]


Actual class: cat, Predicted class: dog



 25%|███████▎                     | 10/40 [01:44<05:14, 10.47s/it]


Actual class: cat, Predicted class: dog



 62%|██████████████████▏          | 25/40 [04:21<02:37, 10.47s/it]


Actual class: dog, Predicted class: dog



100%|█████████████████████████████| 40/40 [06:58<00:00, 10.47s/it]


Actual class: cat, Predicted class: cat



100%|█████████████████████████████| 40/40 [06:59<00:00, 10.48s/it]


Actual class: cat, Predicted class: dog



100%|█████████████████████████████| 40/40 [06:58<00:00, 10.47s/it]


Actual class: cat, Predicted class: dog



100%|█████████████████████████████| 40/40 [06:58<00:00, 10.47s/it]


Actual class: cat, Predicted class: cat



100%|█████████████████████████████| 40/40 [06:58<00:00, 10.46s/it]

Actual class: dog, Predicted class: dog





In [23]:
from sklearn.metrics import accuracy_score, confusion_matrix

labels = [dataset[i][1] for i in idxs]
print("Confusion Matrix")
print(confusion_matrix(preds, labels))
print()
print("Accuracy:", accuracy_score(preds, labels), "\n")
print("Significance:", sum(significance) / len(significance))

Confusion Matrix
[[3 0]
 [5 2]]

Accuracy: 0.5 

Significance: 0.4


In [24]:
from collections import Counter
Counter(labels)

Counter({3: 8, 5: 2})