In [None]:
# Setup
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
from diffusers import ScoreSdeVePipeline, ScoreSdeVeScheduler, UNet2DModel
from diffusers.models.unet_2d import UNet2DOutput
from diffusers.pipelines.pipeline_utils import ImagePipelineOutput
from diffusers.utils import numpy_to_pil
from diffusers.utils.torch_utils import randn_tensor
from datasets import load_dataset, Dataset, DatasetDict
from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights
from torchvision.transforms import Resize
from matplotlib import pyplot as plt
from tqdm import tqdm
from typing import List, Optional, Tuple, Union
import torch

# Device: CUDA when available, otherwise CPU. Edit this line to use another backend (e.g. TPU or Apple MPS).
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
PRETRAINED = "google/ncsnpp-celebahq-256"      # Hub id of the pretrained NCSN++ VE-SDE pipeline
DATASET_SOURCE = "Ryan-sjtu/celebahq-caption"  # Hub dataset providing 'image' and 'text' (caption) columns

N_INFERENCE_STEPS = 2000  # number of sampling (predictor) steps used for generation

# Seed torch's global RNG so weight init, shuffling and noise draws are reproducible.
SEED = 0
torch.manual_seed(SEED)

In [None]:
# Load the pretrained VE-SDE pipeline (UNet + scheduler) and the captioned dataset.
unconditional_pipeline = ScoreSdeVePipeline.from_pretrained(PRETRAINED).to(device=DEVICE)
unconditional_scheduler: ScoreSdeVeScheduler = unconditional_pipeline.scheduler
# Reuse the pipeline's UNet rather than downloading/instantiating a second copy of the same weights.
unet: UNet2DModel = unconditional_pipeline.unet
dataset: DatasetDict = load_dataset(DATASET_SOURCE)

# Demonstrate unconditional generation.
# Running it also sets the shared scheduler's inference steps.
# example_image = unconditional_pipeline(num_inference_steps=N_INFERENCE_STEPS).images[0]
# example_image

In [None]:
# Setup classifier and training data

BATCH_SIZE = 64
N_EPOCHS = 10

backbone = mobilenet_v3_small()

# Binary classifier: the MobileNetV3-Small feature extractor with a single-output head
# followed by a sigmoid, so the model outputs P(label == 1 | image).
model = nn.Sequential(
    backbone.features,
    backbone.avgpool,
    nn.Flatten(),
    nn.Dropout(0.2, inplace=True),
    # Width of the pooled features (576 for the small variant), read from the stock
    # classifier's first Linear instead of hard-coding it.
    nn.Linear(in_features=backbone.classifier[0].in_features, out_features=1),
    nn.Sigmoid()
)
model = model.to(device=DEVICE)

# Dataset: label is 1.0 when the caption mentions "woman", else 0.0 (float32, shape (1,)).
train_ds: Dataset = dataset['train'].with_format('torch', device=DEVICE)
# input_columns='text' passes only the caption to the function, so images are not
# decoded and copied to DEVICE just to compute the label.
train_ds = train_ds.map(
    lambda text: {'label': torch.tensor([1.0 if 'woman' in text else 0.0])},
    input_columns='text',
)
# shuffle=True so every epoch sees the examples in a fresh random order rather than
# the stored order.
dataloader = DataLoader(train_ds, BATCH_SIZE, shuffle=True)

In [None]:
# Train the classifier on noised images so it gives useful gradients at the noise
# levels the sampler visits.
NOISE_BATCH = 10  # noisy copies per training image, each at its own random noise level

optim = torch.optim.Adam(model.parameters())
loss_fn = nn.BCELoss()  # the model already ends in nn.Sigmoid, so it outputs probabilities
resize = Resize(unconditional_pipeline.unet.config.sample_size)


def to_unit_chw(images: torch.Tensor) -> torch.Tensor:
    """Return a float32, channels-first copy of a uint8 image batch scaled to [0, 1].

    Channels-last batches (B, H, W, C) are permuted to (B, C, H, W); batches that
    already look channels-first are left in that layout.
    """
    images = images.to(torch.float32)
    if images.shape[-1] in (1, 3) and images.shape[1] not in (1, 3):
        # permute keeps H before W; transpose(-1, -3) would also swap H and W,
        # mirroring every image across its diagonal.
        images = images.permute(0, 3, 1, 2)
    # Divide by 255 (not 256) so white maps to exactly 1.0, matching the [0, 1]
    # range the sampler clamps its output to.
    return images / 255


model.train()  # the guidance cell may have left the classifier in eval mode
with tqdm(total=N_EPOCHS) as pbar:
    for epoch in range(N_EPOCHS):
        running_loss = 0.0
        n_batches = 0
        for data in dataloader:
            inputs, labels = data['image'], data['label']

            # Preprocess: float CHW in [0, 1] at the UNet's resolution, plus a copy axis.
            inputs = resize(to_unit_chw(inputs)).unsqueeze(1)  # (B, 1, C, H, W)
            timesteps = torch.randint(
                low=0,
                high=len(unconditional_scheduler.discrete_sigmas),
                size=(NOISE_BATCH,),
                device=DEVICE,
            )
            # Independent noise for every copy. NOTE(review): with noise=None the
            # scheduler presumably uses randn_like(inputs), whose size-1 copy axis would
            # reuse one noise pattern for all NOISE_BATCH copies.
            noise = torch.randn(
                (inputs.shape[0], NOISE_BATCH, *inputs.shape[2:]),
                device=inputs.device,
                dtype=inputs.dtype,
            )
            inputs = unconditional_scheduler.add_noise(
                original_samples=inputs,
                noise=noise,
                timesteps=timesteps,
            ).flatten(0, 1)  # (B * NOISE_BATCH, C, H, W)
            # Repeat each label NOISE_BATCH times, in the same order as the flattened copies.
            labels = labels.expand(labels.shape[0], NOISE_BATCH).flatten().unsqueeze(-1)

            # Training step
            optim.zero_grad()
            predicted = model(inputs)
            loss = loss_fn(predicted, labels)
            loss.backward()
            optim.step()

            running_loss += loss.item()
            n_batches += 1

        # Logging: mean batch loss over the epoch
        pbar.set_description(f"loss {running_loss / max(n_batches, 1):.4f}")
        pbar.update(1)

In [None]:
# Persist the trained classifier. The whole nn.Module is pickled (not only its
# state_dict), so reloading needs the same class definitions importable.
classifier_path = 'gender_classifier.pt'
torch.save(model, classifier_path)
# To reload (newer PyTorch releases need weights_only=False for a pickled module;
# only do this for files you trust):
# model = torch.load(classifier_path, weights_only=False)

In [None]:
# Inspect one raw training example.
# NOTE(review): this cell was unfinished (`img = ` with no value, a SyntaxError that
# aborts Restart & Run All); showing a dataset image is an assumed intent.
sample_idx = 0  # avoids `id`, which shadowed the built-in
img = dataset['train'][sample_idx]['image']
img

In [None]:
from typing import Any, List, Optional, Tuple, Union

import torch
from diffusers.pipelines.pipeline_utils import ImagePipelineOutput

class BinaryClassConditionalScoreSdeVePipeline(ScoreSdeVePipeline):
    """Classifier-guided variant of the VE-SDE predictor-corrector sampler.

    At every corrector and predictor step the UNet score estimate is augmented with
    ``guidance_scale * grad_x log p(y | x)``, where ``p`` comes from a binary
    classifier that outputs the probability of the positive class and ``y`` is the
    requested ``target`` class.
    """

    def add_conditional_gradient(self, sample, output, classifier, target, guidance_scale=1.0):
        """Add the classifier-guidance term to a score estimate.

        Args:
            sample: Noisy sample the score was evaluated at. It is not modified.
            output: Score estimate for ``sample``; updated in place.
            classifier: Callable mapping ``sample`` to positive-class probabilities
                in [0, 1] (assumed one per batch element — TODO confirm for other classifiers).
            target: Truthy to guide toward the positive class, falsy for the negative one.
            guidance_scale: Multiplier on the guidance gradient (1.0 = plain classifier guidance).

        Returns:
            ``output`` with ``guidance_scale * grad_x log p(target | sample)`` added.
        """
        with torch.enable_grad():
            # Differentiate w.r.t. a detached leaf copy so the caller's tensor is untouched.
            x = sample.detach().requires_grad_(True)
            prob = classifier(x)
            class_prob = prob if target else 1 - prob
            # Guidance needs grad log p, not grad p: grad p carries a p * (1 - p) factor
            # that vanishes exactly when the classifier is confident the sample is in the
            # wrong class. The clamp keeps log finite if a probability is exactly 0.
            log_prob = torch.log(class_prob.clamp(min=1e-12))
            # autograd.grad returns only d/dx and does not accumulate .grad on the
            # classifier's parameters (unlike .backward()).
            (grad,) = torch.autograd.grad(log_prob.sum(), x)
        output += guidance_scale * grad
        return output

    @torch.no_grad()
    def __call__(
        self,
        classifier,
        target: bool = True,
        batch_size: int = 1,
        num_inference_steps: int = 2000,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        guidance_scale: float = 1.0,
        **kwargs,
    ) -> Union[ImagePipelineOutput, Tuple]:
        """Generate images guided toward class ``target`` of ``classifier``.

        Args:
            classifier: Model returning positive-class probabilities for a sample
                batch. An ``nn.Module`` in train mode is switched to eval mode for
                sampling and restored to train mode afterwards.
            target: ``True`` guides toward the positive class, ``False`` toward the negative one.
            batch_size: Number of images to generate.
            num_inference_steps: Number of predictor steps.
            generator: Optional RNG(s) for the initial noise and the stochastic steps.
            output_type: ``"pil"`` for PIL images; otherwise a NumPy array (B, H, W, C).
            return_dict: Return ``ImagePipelineOutput`` if ``True``, else a 1-tuple.
            guidance_scale: Multiplier on the classifier gradient.
            **kwargs: Accepted and ignored.

        Returns:
            ``ImagePipelineOutput(images=...)`` or ``(images,)``.
        """
        img_size = self.unet.config.sample_size
        shape = (batch_size, 3, img_size, img_size)

        sample = randn_tensor(shape, generator=generator) * self.scheduler.init_noise_sigma
        sample = sample.to(self.device)

        self.scheduler.set_timesteps(num_inference_steps)
        self.scheduler.set_sigmas(num_inference_steps)

        # In train mode, BatchNorm would normalise with (and update running stats from)
        # the batch being generated, and Dropout would randomise the guidance gradient.
        was_training = isinstance(classifier, torch.nn.Module) and classifier.training
        if was_training:
            classifier.eval()
        try:
            sample_mean = sample  # defined even if there are no timesteps
            for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
                sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device)

                # correction step(s)
                for _ in range(self.scheduler.config.correct_steps):
                    model_output = self.unet(sample, sigma_t).sample
                    model_output = self.add_conditional_gradient(
                        sample, model_output, classifier, target, guidance_scale
                    )
                    sample = self.scheduler.step_correct(model_output, sample, generator=generator).prev_sample

                # prediction step
                model_output = self.unet(sample, sigma_t).sample
                model_output = self.add_conditional_gradient(
                    sample, model_output, classifier, target, guidance_scale
                )
                output = self.scheduler.step_pred(model_output, t, sample, generator=generator)

                sample, sample_mean = output.prev_sample, output.prev_sample_mean
        finally:
            if was_training:
                classifier.train()

        sample = sample_mean.clamp(0, 1)
        sample = sample.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            sample = self.numpy_to_pil(sample)

        if not return_dict:
            return (sample,)

        return ImagePipelineOutput(images=sample)

    
# Guide sampling with the trained classifier. Switch it to eval mode first so its
# BatchNorm layers use their running statistics (not statistics of the batch being
# generated) and its Dropout layer is disabled.
model.eval()
conditional_pipeline = BinaryClassConditionalScoreSdeVePipeline(
    unconditional_pipeline.unet, unconditional_pipeline.scheduler
)
# target=False steers toward label 0, i.e. captions that do not mention "woman".
images = conditional_pipeline(
    model, batch_size=9, target=False, num_inference_steps=N_INFERENCE_STEPS
).images

In [None]:
# Display the image at index 8 (the ninth image of the generated batch).
images[8]