In [None]:
# Imports
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
from diffusers import ScoreSdeVePipeline, ScoreSdeVeScheduler, UNet2DModel
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import ImagePipelineOutput
from datasets import load_dataset, Dataset, DatasetDict
from torchvision.models import mobilenet_v3_small
from torchvision.transforms import Resize, PILToTensor
from tqdm import tqdm
import torch
from typing import List, Optional, Tuple, Union

In [None]:
# Hyperparameters
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'     # Update this line if you want to use a different device such as TPU or Macbook's MPS
PRETRAINED = "google/ncsnpp-celebahq-256"
DATASET_SOURCE = "Ryan-sjtu/celebahq-caption"

N_INFERENCE_STEPS = 1000   # number of sampling steps used with the SDE scheduler
BATCH_SIZE = 32            # images per classifier training batch (before noise augmentation)
N_EPOCHS = 100
NOISE_BATCH = 10           # noise levels drawn per training batch; every image is noised at each of them

# Seed torch's generators so weight init, noise augmentation and sampling are reproducible.
SEED = 0
torch.manual_seed(SEED)

In [None]:
# Load
unconditional_pipeline = ScoreSdeVePipeline.from_pretrained(PRETRAINED).to(device=DEVICE)
unconditional_scheduler: ScoreSdeVeScheduler = unconditional_pipeline.scheduler
dataset: DatasetDict = load_dataset(DATASET_SOURCE)

# The training cell draws noise levels from `unconditional_scheduler.discrete_sigmas`.
# Configure the shared scheduler here explicitly rather than relying on the demo call
# below as a side effect (same calls, same order as the sampling pipeline uses).
unconditional_scheduler.set_timesteps(N_INFERENCE_STEPS)
unconditional_scheduler.set_sigmas(N_INFERENCE_STEPS)

# Demonstrate unconditional generation
example_image = unconditional_pipeline(num_inference_steps=N_INFERENCE_STEPS).images[0]
example_image

In [None]:
# Create classifier model: a MobileNetV3-Small feature extractor with a one-unit
# sigmoid head, so the output is read as the probability of label 1.
backbone = mobilenet_v3_small()

# Simple CNN alternative (kept for reference; not used):
# model = nn.Sequential(
#     nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5),
#     nn.LeakyReLU(),
#     nn.MaxPool2d(kernel_size=5),
#     nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5),
#     nn.LeakyReLU(),
#     nn.MaxPool2d(kernel_size=3),
#     nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5),
#     nn.LeakyReLU(),
#     nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3),
#     nn.LazyLinear(out_features=128),
#     nn.LeakyReLU(),
#     nn.Linear(in_features=128, out_features=1),
#     nn.Sigmoid()
# )

model = nn.Sequential(
    backbone.features,              # convolutional feature extractor
    backbone.avgpool,               # pooling; flattened features have 576 channels
    nn.Flatten(),
    nn.Dropout(0.2, inplace=True),
    nn.Linear(576, 1),              # single logit
    nn.Sigmoid(),                   # -> probability in (0, 1)
).to(device=DEVICE)

# Only the reused submodules are needed; drop the rest of the torchvision model.
del backbone

In [None]:
# Construct dataloader

# Utils
def process(img):
    """Convert channels-last pixel data into a float channels-first tensor in [0, 1].

    Accepts a single ``(H, W, C)`` image or a batch ``(..., H, W, C)`` and returns
    ``(..., C, H, W)`` as float32. Assumes 0-255 pixel values (e.g. uint8).
    """
    # 255, not 256: the brightest pixel must map to exactly 1.0 so inputs share the
    # [0, 1] range that the sampling pipeline clamps its outputs to.
    return torch.movedim(img.to(torch.float32), -1, -3) / 255
# Resize transform targeting the UNet's configured sample size, so classifier training
# inputs match the resolution of generated samples.
resize = Resize(size=unconditional_pipeline.unet.config.sample_size)

def compute_accuracy(dataloader, model, loss_fn, max_batches=10, preprocess=None):
    """Estimate accuracy and mean loss of a binary classifier on the first few batches.

    Args:
        dataloader: iterable yielding dicts with an ``'image'`` batch and a ``'label'`` batch.
        model: classifier whose output is thresholded at 0.5 (a probability of label 1).
        loss_fn: called as ``loss_fn(pred, label)``; must return a scalar tensor.
        max_batches: evaluate at most this many batches (``None`` = the whole loader).
        preprocess: transform applied to each image batch before ``model``; defaults to
            ``process``. Pass e.g. ``lambda x: resize(process(x))`` to match training inputs.

    Returns:
        ``(accuracy, mean_loss)`` over the evaluated batches, or ``(nan, nan)`` if the
        loader yielded nothing.

    The model's train/eval mode is restored on exit, so a training loop calling this
    keeps training with dropout and BatchNorm updates enabled.
    """
    from itertools import islice

    if preprocess is None:
        preprocess = process
    was_training = model.training
    model.eval()
    num_batches, size, test_loss, correct = 0, 0, 0.0, 0
    try:
        with torch.no_grad():
            # islice pulls exactly `max_batches` batches, and `num_batches` counts only
            # batches whose loss was accumulated (the old break left them off by one).
            for data in islice(dataloader, max_batches):
                images, y = data['image'], data['label']
                size += images.shape[0]
                pred = model(preprocess(images))
                test_loss += loss_fn(pred, y).item()
                correct += ((pred > 0.5) == y).sum().item()
                num_batches += 1
    finally:
        model.train(was_training)

    if num_batches == 0:
        return float('nan'), float('nan')
    return correct / size, test_loss / num_batches

# Dataset
# Toy label threshold: mean pixel intensity on process()'s [0, 1] scale.
BRIGHTNESS_THRESHOLD = 0.43

def _toy_label(example):
    """Return ``{'label': tensor([0. or 1.])}`` for one dataset example."""
    is_positive = process(example['image']).mean() > BRIGHTNESS_THRESHOLD   # Toy Problem
    # is_positive = 'woman' in example['text']                               # Gender Classification
    return {
        'label': torch.tensor([1 if is_positive else 0], dtype=torch.float32, device=DEVICE)
    }

train_ds: Dataset = dataset['train'].with_format('torch', device=DEVICE)
train_ds = train_ds.map(_toy_label)

# Shuffle so each epoch visits examples in a new order; without it, training always sees
# storage order and compute_accuracy's first batches are the same images every epoch.
# NOTE(review): accuracy is measured on this same training split; there is no held-out set.
dataloader = DataLoader(train_ds, BATCH_SIZE, shuffle=True)

# Method 1: Train Classifier

In [None]:
# Train Classifier
optim = torch.optim.Adam(model.parameters())
loss_fn = nn.BCELoss()
# Number of discrete noise levels on the shared scheduler; constant during training.
n_noise_levels = len(unconditional_scheduler.discrete_sigmas)

with tqdm(total=N_EPOCHS) as pbar:
    for epoch in range(N_EPOCHS):

        # Quick probe on a few batches before this epoch's updates.
        eval_accuracy, eval_loss = compute_accuracy(dataloader, model, loss_fn)
        pbar.set_description(f"accuracy {eval_accuracy} loss {eval_loss}")
        pbar.refresh()

        # compute_accuracy may leave the model in eval mode; switch back explicitly so
        # dropout is active and BatchNorm running statistics are updated while training.
        model.train()
        for data in dataloader:
            inputs, labels = data['image'], data['label']

            # Preprocess image, then add a noise-level axis at dim 1
            inputs = resize(process(inputs)).unsqueeze(1)
            timesteps = torch.randint(
                low=0,
                high=n_noise_levels,
                size=(NOISE_BATCH,),
            ).to(device=DEVICE)
            # add_noise broadcasts the NOISE_BATCH sigmas along dim 1; flattening yields
            # NOISE_BATCH noisy copies per image, image-major.
            inputs = unconditional_scheduler.add_noise(
                original_samples=inputs,
                noise=None,
                timesteps=timesteps,
            ).flatten(0, 1)
            # Repeat each (B, 1) label NOISE_BATCH times in the same image-major order.
            labels = labels.expand(labels.shape[0], NOISE_BATCH).flatten().unsqueeze(-1)

            # Training
            optim.zero_grad()
            predicted = model(inputs)
            train_loss = loss_fn(predicted, labels)
            train_loss.backward()
            optim.step()

        # Logging
        pbar.update(1)

# The classifier is used next for guidance/inference: disable dropout and batch statistics.
model.eval()

In [None]:
# Checkpoints
classifier_name = 'toy_classifier.pt'
SAVE_CLASSIFIER = False  # set True to overwrite the checkpoint with the model trained above

if SAVE_CLASSIFIER:
    torch.save(model, classifier_name)
else:
    # The checkpoint is a pickled nn.Module, so weights_only=False is required on recent
    # PyTorch (whose default is weights_only=True). Unpickling can run arbitrary code:
    # only load checkpoints you created or trust.
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine and vice versa.
    model = torch.load(classifier_name, map_location=DEVICE, weights_only=False)

# Guidance and scoring below need inference behaviour (no dropout, running BN stats).
model.eval()

In [None]:
class BinaryClassConditionalScoreSdeVePipeline(ScoreSdeVePipeline):
    """Score SDE (VE) sampler with classifier guidance from a binary classifier.

    At every corrector and predictor step the UNet's score estimate is augmented with
    ``guidance_scale * grad_x log p(y | x)``, where the classifier outputs p(y=1 | x)
    and ``target`` selects y (True -> y=1, False -> y=0).
    """

    def add_conditional_gradient(self, sample, output, classifier, target, debug=False, scale=1.0, eps=1e-6):
        """Add the classifier-guidance gradient to a score estimate.

        Args:
            sample: current noisy sample; it is not modified.
            output: score predicted by the UNet for ``sample``; updated in place and returned.
            classifier: callable mapping ``sample`` to p(y=1 | x) in (0, 1).
            target: True to steer towards y=1, False to steer towards y=0.
            debug: unused; kept for interface compatibility.
            scale: multiplier applied to the guidance gradient.
            eps: probabilities are clamped to [eps, 1 - eps] so the log stays finite.

        Returns:
            ``output`` with ``scale * grad_x log p(target | x)`` added.
        """
        with torch.enable_grad():
            # Differentiate w.r.t. a detached copy instead of flipping requires_grad on
            # the caller's tensor.
            x = sample.detach().requires_grad_(True)
            prob = classifier(x).clamp(eps, 1 - eps)
            # log p(y=0|x) = log(1 - p(y=1|x)); negating grad log p(y=1|x) is a different
            # (and much weaker when p is near 1) direction.
            log_prob = torch.log(prob if target else 1 - prob)
            # autograd.grad differentiates only w.r.t. x, so nothing is accumulated into
            # the classifier parameters' .grad buffers on every sampling step.
            (grad,) = torch.autograd.grad(log_prob.sum(), x)
        output += scale * grad
        return output

    @torch.no_grad()
    def __call__(
        self,
        classifier,
        target: bool = True,
        batch_size: int = 1,
        num_inference_steps: int = 2000,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        guidance_scale: float = 1.0,
        **kwargs,
    ) -> Union[ImagePipelineOutput, Tuple]:
        """Sample images steered towards (``target=True``) or away from the positive class.

        Args:
            classifier: callable returning p(y=1 | x) for a batch of samples.
            target: which class to steer towards.
            batch_size: number of images to generate.
            num_inference_steps: number of predictor steps.
            generator: optional torch generator(s) for reproducible sampling.
            output_type: ``"pil"`` for PIL images, otherwise a numpy array (N, H, W, C).
            return_dict: return an ``ImagePipelineOutput`` if True, else a 1-tuple.
            guidance_scale: multiplier on the classifier gradient (1.0 = plain guidance).

        Returns:
            ``ImagePipelineOutput(images=...)`` or ``(images,)``.
        """
        img_size = self.unet.config.sample_size
        shape = (batch_size, 3, img_size, img_size)

        sample = randn_tensor(shape, generator=generator) * self.scheduler.init_noise_sigma
        sample = sample.to(self.device)
        # Fallback result if the schedule contains no steps.
        sample_mean = sample

        self.scheduler.set_timesteps(num_inference_steps)
        self.scheduler.set_sigmas(num_inference_steps)

        # Guidance should use deterministic classifier gradients: switch a module in
        # training mode to eval for the duration of sampling, then restore it.
        was_training = getattr(classifier, "training", False)
        if was_training:
            classifier.eval()
        try:
            for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
                sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device)

                # correction step
                for _ in range(self.scheduler.config.correct_steps):
                    score = self.unet(sample, sigma_t).sample
                    score = self.add_conditional_gradient(sample, score, classifier, target, scale=guidance_scale)
                    sample = self.scheduler.step_correct(score, sample, generator=generator).prev_sample

                # prediction step
                score = self.unet(sample, sigma_t).sample
                score = self.add_conditional_gradient(sample, score, classifier, target, scale=guidance_scale)
                output = self.scheduler.step_pred(score, t, sample, generator=generator)

                sample, sample_mean = output.prev_sample, output.prev_sample_mean
        finally:
            if was_training:
                classifier.train()

        image = sample_mean.clamp(0, 1).cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)


In [None]:
# Example: sample an image steered away from the positive class, then re-score it.

conditional_pipeline = BinaryClassConditionalScoreSdeVePipeline(unconditional_pipeline.unet, unconditional_pipeline.scheduler)
images = conditional_pipeline(model, batch_size=1, target=False, num_inference_steps=N_INFERENCE_STEPS).images
to_tensor = PILToTensor()
# PILToTensor yields 0-255 channels-first pixels; dividing by 255 maps them back to [0, 1].
sample = (to_tensor(images[0]) / 255).unsqueeze(0).to(device=DEVICE)
model.eval()
with torch.no_grad():  # scoring only: don't build an autograd graph
    print(f"Predicted: {model(sample).item()}")
images[0]