In [None]:
# Imports
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
from diffusers import ScoreSdeVePipeline, ScoreSdeVeScheduler, UNet2DModel
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import ImagePipelineOutput
from datasets import load_dataset, Dataset, DatasetDict
from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights
from torchvision.transforms import Resize, PILToTensor, ToPILImage
from tqdm import tqdm
import torch
from typing import List, Optional, Tuple, Union

In [None]:
# Hyperparameters
# Prefer the GPU whenever torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Hugging Face Hub identifiers: pretrained score model and captioned face dataset.
PRETRAINED = "google/ncsnpp-celebahq-256"
DATASET_SOURCE = "Ryan-sjtu/celebahq-caption"

N_INFERENCE_STEPS = 1000  # sampler steps; only the commented-out unconditional demo reads it
BATCH_SIZE = 32
N_EPOCHS = 20
NOISE_BATCH = 10  # noise draws per image; only the commented-out noise augmentation reads it
N_CLASSES = 2

In [None]:
# Load
unconditional_pipeline = ScoreSdeVePipeline.from_pretrained(PRETRAINED).to(device=DEVICE)
unconditional_scheduler: ScoreSdeVeScheduler = unconditional_pipeline.scheduler
# Reuse the pipeline's UNet instead of downloading/loading a second copy of the same
# weights. NOTE(review): this is now the same module as `unconditional_pipeline.unet`
# (on DEVICE), not an independent CPU copy.
unet: UNet2DModel = unconditional_pipeline.unet
dataset: DatasetDict = load_dataset(DATASET_SOURCE)

# Demonstrate unconditional generation (running it also sets the scheduler's inference steps):
# example_image = unconditional_pipeline(num_inference_steps=N_INFERENCE_STEPS).images[0]
# example_image

In [None]:
# Create classifier model

class ClassificationNet(nn.Module):
    """Image classifier on top of a pretrained MobileNetV3-Small feature extractor.

    The MobileNet convolutional trunk plus its global average pool produce a feature
    vector. Optionally, ``rff_dim`` extra features are concatenated to it. The result
    goes through a small MLP head that outputs ``n_classes`` logits.

    Args:
        n_classes: number of output logits.
        rff_dim: width of an optional extra feature vector (e.g. random Fourier
            features) concatenated to the CNN features; ``None`` or 0 disables it.
    """

    def __init__(self, n_classes: int = 2, rff_dim: Optional[int] = None):
        super().__init__()

        self.rff_dim = 0 if rff_dim is None else rff_dim
        # `weights` is keyword-only in current torchvision; passing it positionally is deprecated.
        mobile_net = mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.DEFAULT)
        # Width of the pooled CNN features (576 for MobileNetV3-Small), read from the
        # stock classifier head instead of being hard-coded.
        n_features = mobile_net.classifier[0].in_features
        self.cnn = nn.Sequential(
            mobile_net.features,
            mobile_net.avgpool,
            nn.Flatten(),
            nn.Dropout(0.2, inplace=True),
        )
        self.fc = nn.Sequential(
            nn.Linear(in_features=n_features + self.rff_dim, out_features=64),
            nn.LeakyReLU(),
            nn.Linear(in_features=64, out_features=n_classes),
        )

    def forward(self, x, rff: Optional[torch.Tensor] = None):
        """Return class logits for an image batch.

        Args:
            x: image batch in the layout produced by ``process`` (assumed (B, 3, H, W)).
            rff: extra per-example features of width ``rff_dim``; required when the
                network was built with ``rff_dim > 0``, ignored otherwise.
        Raises:
            ValueError: if the network expects extra features but ``rff`` is None.
        """
        # TODO: compute RFF internally instead of requiring callers to pass them.
        features = self.cnn(x)
        # getattr: modules pickled before `rff_dim` was stored have no such attribute.
        rff_dim = getattr(self, 'rff_dim', 0)
        if rff_dim:
            if rff is None:
                raise ValueError(f"this network was built with rff_dim={rff_dim}; pass `rff` features")
            features = torch.cat([features, rff], dim=1)
        return self.fc(features)


model = ClassificationNet(n_classes=N_CLASSES).to(device=DEVICE)

In [None]:
# Dataset
def gender_label(row):
    """Return a one-hot float label for a caption row.

    The label is ``[1, 0]`` if the caption mentions 'woman' (case-insensitive) and
    ``[0, 1]`` otherwise.
    """
    is_woman = 'woman' in row['text'].lower()
    return {'label': torch.tensor([float(is_woman), float(not is_woman)], dtype=torch.float32)}


# Label the raw (unformatted) split first, then apply the torch format. Mapping over a
# torch/DEVICE-formatted split would convert every image into a DEVICE tensor just so
# the captions could be read.
train_ds: Dataset = dataset['train'].map(gender_label).with_format('torch', device=DEVICE)
# NOTE(review): shuffle is off (DataLoader default), and the training/evaluation loops
# only read the first few batches, so they always see the same examples.
dataloader = DataLoader(train_ds, BATCH_SIZE)

# Class weightings: inverse class frequency, normalised so a balanced split gives 1.0.
# Count from a label-only view so images are not decoded just to read labels.
label_counts = torch.zeros((N_CLASSES,), device=DEVICE)
for label_batch in DataLoader(train_ds.with_format('torch', columns=['label'], device=DEVICE), BATCH_SIZE):
    label_counts += label_batch['label'].sum(dim=0)
# clamp: a class with no examples would otherwise get an infinite weight.
weights = (label_counts.sum() / N_CLASSES) / label_counts.clamp(min=1)

In [None]:
# Utils
resize = Resize(unconditional_pipeline.unet.config.sample_size)


def process(img):
    """Convert a channels-last uint8 image batch to float channels-first in [0, 1].

    Args:
        img: tensor whose last three axes are (height, width, channels). This is
            assumed to be the layout the torch-formatted dataset yields (TODO confirm
            for the installed `datasets` version).
    Returns:
        float32 tensor with axes (..., channels, height, width), scaled by 1/255 and
        resized to the diffusion UNet's ``sample_size``.
    """
    # (..., H, W, C) -> (..., C, H, W)
    chw = img.to(torch.float32).movedim(-1, -3)
    # 1/255 (not 1/256) so pixel value 255 maps to exactly 1.0. This matches the [0, 1]
    # range that ToPILImage and the pipeline's PIL conversion scale by 255.
    return resize(chw / 255)


tensor_to_PIL = ToPILImage()

# Explicit class dimension: without `dim`, nn.Softmax picks one implicitly and warns.
# For the 2-D (batch, classes) logits used in this notebook the result is unchanged.
softmax = torch.nn.Softmax(dim=-1)
def compute_accuracy(dataloader, model, loss_fn, max_batches: int = 10):
    """Estimate accuracy and mean loss on the leading batches of a loader.

    The model is evaluated in eval mode without gradients. Its previous
    training/eval mode is restored afterwards, so callers that train right after
    evaluating are not left with dropout and batch-norm updates disabled.

    Args:
        dataloader: yields dicts with 'image' (raw images accepted by ``process``) and
            'label' (one-hot float targets).
        model: classifier returning one logit per class.
        loss_fn: loss taking (logits, targets).
        max_batches: number of leading batches to evaluate.
    Returns:
        (accuracy, mean_loss): fraction of examples whose arg-max prediction matches
        the arg-max of the one-hot label, and the loss averaged over evaluated batches.
    Raises:
        ValueError: if the loader yields no batches.
    """
    was_training = model.training
    model.eval()
    n_batches, n_examples, total_loss, n_correct = 0, 0, 0.0, 0
    try:
        with torch.no_grad():
            # zip stops on the range first, so no extra batch is fetched and decoded.
            for _, data in zip(range(max_batches), dataloader):
                X, y = data['image'], data['label']
                logits = model(process(X))
                total_loss += loss_fn(logits, y).item()
                n_correct += (logits.argmax(dim=1) == y.argmax(dim=1)).sum().item()
                n_examples += X.shape[0]
                n_batches += 1
    finally:
        model.train(was_training)
    if n_batches == 0:
        raise ValueError("dataloader yielded no batches to evaluate")
    return n_correct / n_examples, total_loss / n_batches

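# Method 1: Train Noise-Conditional Classifier (note: the noise augmentation in the training cell is currently commented out, so the classifier is trained on clean images)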

In [None]:
# Train Classifier
optim = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
loss_fn = nn.CrossEntropyLoss(weight=weights)
# Only the leading batches are used per epoch to keep iterations quick.
N_TRAIN_BATCHES = 11

with tqdm(total=N_EPOCHS) as pbar:
    for epoch in range(N_EPOCHS):
        # Evaluation may leave the model in eval mode; updates need dropout and
        # batch-norm statistics in training mode.
        model.train()
        for _, data in zip(range(N_TRAIN_BATCHES), dataloader):
            inputs, labels = data['image'], data['label']
            # Preprocess image
            inputs = process(inputs)
            # Noise augmentation for noise-conditional training (currently disabled):
            # inputs = inputs.unsqueeze(1)
            # inputs = unconditional_scheduler.add_noise(
            #     original_samples=inputs,
            #     noise=None,
            #     timesteps=torch.randint(
            #         low=0,
            #         high=len(unconditional_scheduler.discrete_sigmas) // 5,
            #         size= (NOISE_BATCH, )
            #     ).to(device=DEVICE)
            # ).flatten(0,1)
            # labels = labels.unsqueeze(0).expand(NOISE_BATCH, *labels.shape).flatten(0,1)

            # Training
            optim.zero_grad()
            predicted = model(inputs)
            loss = loss_fn(predicted, labels)
            loss.backward()
            optim.step()

        # Evaluate after the epoch's updates so the last reported numbers describe the
        # final model (set_description refreshes the bar itself).
        accuracy, eval_loss = compute_accuracy(dataloader, model, loss_fn)
        pbar.set_description(f"accuracy {accuracy} loss {eval_loss}")
        pbar.update(1)

In [None]:
# Checkpoints
classifier_name = 'uncon_gender_classifier.pt'

# Reuse a saved classifier when present; otherwise save the one just trained, so a
# fresh "Run All" does not fail on a missing file.
# The file holds a whole pickled module, so loading it needs weights_only=False (recent
# torch releases default to weights_only=True). Unpickling can run arbitrary code:
# only load checkpoints you created yourself.
try:
    model = torch.load(classifier_name, map_location=DEVICE, weights_only=False)
except FileNotFoundError:
    torch.save(model, classifier_name)
# The remaining cells use the classifier for inference/guidance only.
model.eval()

In [None]:
# Sanity check: classify pure Gaussian noise shaped like one processed batch.
first_batch = next(iter(dataloader))
eval_inputs, eval_labels = first_batch['image'], first_batch['label']
# Process this freshly loaded batch. The leftover `inputs` from the training loop has
# already been processed and must not be processed twice.
eval_inputs = process(eval_inputs)
noise = torch.randn_like(eval_inputs)
with torch.no_grad():
    print(model(noise))
# ToPILImage takes a single (C, H, W) image; clamp so float values stay in [0, 1].
tensor_to_PIL(noise[0].clamp(0, 1))

In [None]:
def tweedie(x_t, score, sigma):
    """Tweedie denoising estimate: ``x_t + sigma**2 * score``.

    With ``score`` approximating the gradient of the log-density of the noised data
    at noise level ``sigma``, this estimates the clean sample behind ``x_t``.
    Works elementwise (with broadcasting) on tensors or plain numbers.
    """
    variance = sigma ** 2
    return x_t + variance * score

class NoiselessClassConditionalPipeline(ScoreSdeVePipeline):
    """Classifier-guided variant of ``ScoreSdeVePipeline``.

    At every corrector and predictor step, the unconditional UNet score is shifted by
    ``guidance_scale * grad_x log p(target | x_hat)``. Here ``x_hat`` is the Tweedie
    estimate of the clean image, so the classifier is only queried on (approximately)
    denoised images.
    """

    def add_conditional_gradient(self, sample, score, sigma, classifier, target, debug=False, guidance_scale=5.0):
        """Return ``score`` plus the scaled classifier log-probability gradient.

        Args:
            sample: current noisy batch x_t (left unmodified).
            score: unconditional score for ``sample``. It is treated as a constant, so
                the gradient only flows through the identity term of ``tweedie``.
            sigma: current noise level.
            classifier: module mapping images to per-class logits.
            target: class index (or index tensor) to steer towards; ``True``/``False``
                mean class 1/0.
            debug: print the target-class probabilities.
            guidance_scale: multiplier applied to the classifier gradient.
        Returns:
            A new tensor ``score + guidance_scale * grad``.
        """
        if isinstance(target, bool):
            # A Python bool used as an index adds a new axis instead of selecting a
            # class, so map True/False to class 1/0 explicitly.
            target = int(target)
        with torch.enable_grad():
            # Differentiate w.r.t. a detached copy so the caller's tensor keeps
            # requires_grad=False and has no .grad left on it.
            x = sample.detach().requires_grad_(True)
            x_hat = tweedie(x, score, sigma)
            # log_softmax is numerically stable where log(softmax(...)) can reach log(0).
            target_log_probs = torch.log_softmax(classifier(x_hat), dim=-1)[:, target]
            # Gradient w.r.t. x only: the classifier's parameter .grad buffers stay untouched.
            (grad,) = torch.autograd.grad(target_log_probs.sum(), x)
        if debug:
            print(target_log_probs.detach().exp())
        return score + guidance_scale * grad

    @torch.no_grad()
    def __call__(
        self,
        classifier,
        target: Union[int, bool] = True,
        batch_size: int = 1,
        num_inference_steps: int = 2000,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        guidance_scale: float = 5.0,
        **kwargs,
    ) -> Union[ImagePipelineOutput, Tuple]:
        """Sample images with predictor-corrector steps guided towards ``target``.

        Every 50 steps the target-class probabilities are printed, and every 100 steps
        the current Tweedie estimate of the first image is displayed.

        Args:
            classifier: module mapping images to per-class logits.
            target: class to steer towards (see ``add_conditional_gradient``).
            batch_size: number of images to generate.
            num_inference_steps: number of predictor steps; must be at least 1.
            generator: optional RNG(s) for reproducible sampling.
            output_type: "pil" for PIL images, otherwise a numpy array.
            return_dict: return an ``ImagePipelineOutput`` instead of a tuple.
            guidance_scale: multiplier on the classifier gradient.
        Raises:
            ValueError: if ``num_inference_steps`` is less than 1.
        """
        if num_inference_steps < 1:
            raise ValueError("num_inference_steps must be at least 1")
        img_size = self.unet.config.sample_size
        shape = (batch_size, 3, img_size, img_size)

        sample = randn_tensor(shape, generator=generator) * self.scheduler.init_noise_sigma
        sample = sample.to(self.device)

        self.scheduler.set_timesteps(num_inference_steps)
        self.scheduler.set_sigmas(num_inference_steps)

        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
            sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device)
            std = self.scheduler.sigmas[i]
            debug = i % 50 == 0

            # correction step
            for _ in range(self.scheduler.config.correct_steps):
                model_output = self.unet(sample, sigma_t).sample
                model_output = self.add_conditional_gradient(
                    sample, model_output, std, classifier, target, debug=debug, guidance_scale=guidance_scale
                )
                sample = self.scheduler.step_correct(model_output, sample, generator=generator).prev_sample

            # prediction step
            model_output = self.unet(sample, sigma_t).sample
            model_output = self.add_conditional_gradient(
                sample, model_output, std, classifier, target, debug=debug, guidance_scale=guidance_scale
            )

            if i % 100 == 0:
                # Preview the denoised estimate; clamp so ToPILImage receives [0, 1] values.
                x_hat = tweedie(sample, model_output, std)
                display(tensor_to_PIL(x_hat[0].clamp(0, 1)))

            output = self.scheduler.step_pred(model_output, t, sample, generator=generator)

            sample, sample_mean = output.prev_sample, output.prev_sample_mean

        sample = sample_mean.clamp(0, 1)
        sample = sample.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            sample = self.numpy_to_pil(sample)

        if not return_dict:
            return (sample,)

        return ImagePipelineOutput(images=sample)

In [None]:
# Example

conditional_pipeline = NoiselessClassConditionalPipeline(unconditional_pipeline.unet, unconditional_pipeline.scheduler)
images = conditional_pipeline(model, batch_size=9, target=1, num_inference_steps=500).images
t = PILToTensor()
# Re-encode the first generated image with `process`, the same preprocessing used for
# training. PILToTensor gives (C, H, W), while `process` expects channels-last input.
sample = process(t(images[0]).permute(1, 2, 0).unsqueeze(0).to(device=DEVICE))
with torch.no_grad():
    print(f"Predicted: {model(sample).argmax()}")
images[0]

In [None]:
# Tally the classifier's predicted class over all generated images.
results = [0] * N_CLASSES
with torch.no_grad():
    for img in images:
        # Same preprocessing as training: PILToTensor gives (C, H, W); `process` wants channels-last.
        sample = process(t(img).permute(1, 2, 0).unsqueeze(0).to(device=DEVICE))
        logits = model(sample)  # evaluate once and reuse for printing and counting
        print(logits)
        results[logits.argmax().item()] += 1
results

In [None]:
# Optimize a blank image to maximize the classifier's class-1 logit, to see what
# input the classifier considers most class-1-like.
blank = torch.zeros((3, 256, 256), device=DEVICE)
optim = torch.optim.Adam(params=[blank])

with torch.no_grad():
    print(model(blank.unsqueeze(0)))
blank.requires_grad_(True)
for _ in range(2000):
    optim.zero_grad()
    objective = -model(blank.unsqueeze(0))[0, 1]
    # Differentiate w.r.t. the image only. A plain backward() would also accumulate
    # into the classifier's parameter .grad buffers on every iteration.
    blank.grad = torch.autograd.grad(objective, blank)[0]
    optim.step()
blank.requires_grad_(False)
blank.grad = None
with torch.no_grad():
    print(model(blank.unsqueeze(0)))
# Clamp so ToPILImage receives values in [0, 1].
tensor_to_PIL(blank.clamp(0, 1))

In [None]:

# Noise the optimized image, denoise it with one Tweedie step, and classify the result.
# NOTE(review): index 400 is into whatever sigma schedule the scheduler last set
# (the example cell set 500 steps).
std = conditional_pipeline.scheduler.sigmas[400]
z = torch.randn_like(blank) * std
blank_noised = blank + z
print(std)
with torch.no_grad():  # inference only; avoids building an autograd graph through the UNet
    sigma_t = std * torch.ones(1, device=DEVICE)
    score = conditional_pipeline.unet(blank_noised.unsqueeze(0), sigma_t).sample
    x0 = tweedie(blank_noised, score, std)
    print(model(x0))
# Clamp so ToPILImage receives values in [0, 1].
tensor_to_PIL(x0.squeeze().clamp(0, 1))