In [None]:
# Imports
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
from diffusers import ScoreSdeVePipeline, ScoreSdeVeScheduler, UNet2DModel
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import ImagePipelineOutput
from datasets import load_dataset, DatasetDict
from torchvision.models import mobilenet_v3_small
from torchvision.transforms import Resize, PILToTensor
from tqdm import tqdm
import torch
from typing import List, Optional, Tuple, Union
import math
import requests
from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
import numpy as np
from torchvision.transforms import functional as F

In [None]:
# Autograd anomaly detection is explicitly disabled here; pass True instead
# when hunting for the operation that produced a NaN/inf gradient.
torch.autograd.set_detect_anomaly(False)

### Data
- [ ] Even out male and female samples
- [ ] Try a simpler task -- "color" of images

In [None]:
# Hyperparameters
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'  # use the GPU when torch can see one
PRETRAINED = "google/ncsnpp-celebahq-256"  # id passed to from_pretrained for the pipeline / UNet
DATASET_SOURCE = "Ryan-sjtu/celebahq-caption"  # id passed to load_dataset

N_INFERENCE_STEPS = 1000  # only referenced by the commented-out unconditional demo below
BATCH_SIZE = 16  # batch size of the training DataLoader
TEST_BATCH_SIZE = 2  # batch size of the test DataLoader
N_EPOCHS = 100  # number of training epochs
NOISE_BATCH = 32  # noise levels drawn per training batch (every image is replicated once per level)

N = 200 # num of data of each gender
RFF_DIM = 1024 # Random Fourier Features Dimensions

In [None]:
# Load
# Pretrained unconditional pipeline, moved to DEVICE; its scheduler and UNet are reused by the cells below.
unconditional_pipeline = ScoreSdeVePipeline.from_pretrained(PRETRAINED).to(device=DEVICE)
unconditional_scheduler: ScoreSdeVeScheduler = unconditional_pipeline.scheduler
# Second, standalone load of the UNet weights. NOTE(review): the visible cells only use
# `unconditional_pipeline.unet`, so this copy looks unused — confirm before removing.
unet = UNet2DModel.from_pretrained(PRETRAINED)
# streaming=True: the dataset is consumed as an iterable instead of being materialised up front.
dataset: DatasetDict = load_dataset(DATASET_SOURCE, streaming=True)

# Demonstrate unconditional generation
    # Also helps to set inference steps
# example_image = unconditional_pipeline(num_inference_steps=N_INFERENCE_STEPS).images[0]
# example_image

#### Utils

In [None]:
from datasets import concatenate_datasets

def balance(dataset: Dataset):
    """Return a shuffled dataset holding equally many class-0 and class-1 rows.

    A row's class is the first element of its 'label' field (0.0 or 1.0).
    Whichever class is larger is shuffled and cut down to the size of the
    smaller one before the two are concatenated and shuffled together.
    """
    def _rows_labelled(value):
        # Keep only the rows whose leading label entry equals `value`.
        return dataset.filter(lambda x: x['label'][0] == value)

    def _cap(subset, limit):
        # Randomly down-sample `subset` when it holds more than `limit` rows.
        if len(subset) > limit:
            return subset.shuffle().select(range(limit))
        return subset

    class_0 = _rows_labelled(0.0)
    class_1 = _rows_labelled(1.0)

    min_class_size = min(len(class_0), len(class_1))
    print(f"min class size: {min_class_size}")

    evened = [_cap(class_0, min_class_size), _cap(class_1, min_class_size)]
    return concatenate_datasets(evened).shuffle()

In [None]:
# Utils
def process(img):
    """Convert channel-last image data to a channel-first float32 tensor.

    Parameters:
    - img (torch.Tensor): tensor whose last three axes are (H, W, C).
      Assumes 8-bit pixel values in 0..255 — TODO confirm against callers.

    Returns:
    - torch.Tensor: float32 tensor whose last three axes are (C, H, W).

    The divisor is 255 (not 256) so that a full-intensity pixel maps to exactly
    1.0, matching `F.to_tensor` and the `* 255` used by `tensor_to_image`.
    """
    # (H, W, C) -> (C, W, H) -> (C, H, W)
    return img.to(torch.float32).transpose(-1, -3).transpose(-1, -2) / 255
# Resize transform targeting the sample size stored in the pretrained UNet's config.
resize = Resize(unconditional_pipeline.unet.config.sample_size)



In [None]:
def show_img(img: Image, size = 3, title=None) -> None:
    """Display one image in a square matplotlib figure with the axes hidden.

    Parameters:
    - img: anything `plt.imshow` accepts (e.g. a PIL image or an array).
    - size: side length of the figure, in inches.
    - title: optional title drawn above the image; omitted when None.
    """
    plt.figure(figsize=(size, size))
    plt.imshow(img)
    plt.axis('off')
    # Identity check: `!= None` would dispatch to an overloaded comparison on
    # array-like titles instead of simply testing for "no title given".
    if title is not None:
        plt.title(title)
    plt.show()

In [None]:
def tensor_to_image(ts: torch.Tensor) -> Image:
    """Convert a channel-first image tensor with values in [0, 1] to a PIL image.

    Parameters:
    - ts (torch.Tensor): tensor that is (C, H, W) once size-1 axes are squeezed.

    Returns:
    - PIL image built from the corresponding (H, W, C) uint8 array.

    Values are clamped to [0, 1] and rounded before the uint8 cast, so
    out-of-range inputs (e.g. noisy samples) saturate instead of wrapping
    around, and a `to_tensor` round trip reproduces the original pixel values.
    """
    # detach(): .numpy() refuses tensors that still require grad.
    pixels = ts.detach().float().squeeze().permute(1, 2, 0)
    pixels = (pixels.clamp(0, 1) * 255).round().to(torch.uint8)
    return Image.fromarray(pixels.cpu().numpy())

def image_to_tensor(img: Image) -> torch.Tensor:
    """Turn an image into a tensor via `F.to_tensor` and prepend a batch axis of size 1."""
    return F.to_tensor(img)[None]

In [None]:
def masked_image(data, mask):
    """Multiply `data` elementwise by `mask` and render the product with `tensor_to_image`."""
    visible = data * mask
    return tensor_to_image(visible)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

def plot_tensor_grid(tensor, size=2, columns=5):
    """
    Show a batch of channel-first image tensors as a grid of subplots.

    Parameters:
    - tensor (torch.Tensor): A 4D tensor of shape [n, C, H, W]; image i is tensor[i].
    - size (number): Width and height, in inches, given to each grid cell.
    - columns (int): The number of columns in the image grid.

    The number of rows is the smallest count that fits all n images.
    """
    # Work on a CPU copy that is cut loose from the autograd graph.
    batch = tensor.detach().cpu()
    n_images = batch.shape[0]
    rows = np.ceil(n_images / columns).astype(int)

    fig, axes = plt.subplots(rows, columns, figsize=(size*columns, size*rows), squeeze=False)

    for slot, ax in enumerate(axes.flatten()):
        if slot < n_images:
            # (C, H, W) -> (H, W, C), the layout imshow expects
            ax.imshow(batch[slot].numpy().transpose(1, 2, 0))
        # Every cell, used or not, is drawn without axes.
        ax.axis('off')

    plt.tight_layout()
    plt.show()


def plot_image_grid(images, size=2, columns=5):
    """
    Show a sequence of images as a grid of subplots.

    Parameters:
    - images: Indexable sequence of objects `imshow` can draw (e.g. PIL images).
    - size (number): Width and height, in inches, given to each grid cell.
    - columns (int): The number of columns in the image grid.

    The number of rows is the smallest count that fits all images.
    """
    n_images = len(images)
    rows = np.ceil(n_images / columns).astype(int)

    fig, axes = plt.subplots(rows, columns, figsize=(size*columns, size*rows), squeeze=False)

    for slot, ax in enumerate(axes.flatten()):
        if slot < n_images:
            ax.imshow(images[slot])
        # Every cell, used or not, is drawn without axes.
        ax.axis('off')

    plt.tight_layout()
    plt.show()

#### Preprocess

In [None]:
class CustomDataset(Dataset):
    """Map-style dataset pairing `images[i]` with `labels[i]`."""

    def __init__(self, images, labels):
        self.images, self.labels = images, labels

    def __len__(self):
        # The image container defines the dataset size.
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]

    def split(self, test=0.2):
        """Split by position into (train, test) datasets.

        The first `int(len(self) * test)` items form the test part and the
        remaining items form the train part; nothing is shuffled.
        """
        cut = int(len(self) * test)
        train_part = CustomDataset(self.images[cut:], self.labels[cut:])
        test_part = CustomDataset(self.images[:cut], self.labels[:cut])
        return train_part, test_part

In [None]:
# Dataset
def get_custom_dataset(dataset, N, start_index=0):
    """Collect up to N samples per class from an iterable of captioned images.

    A sample is labelled 1 when the substring 'woman' occurs in its 'text'
    field and 0 otherwise. Its 'image' is resized with `resize` and converted
    to a tensor with `image_to_tensor`.

    Parameters:
    - dataset: iterable of dicts with 'text' and 'image' entries.
    - N (int): number of samples wanted for each of the two classes.
    - start_index (int): samples at positions below this index are skipped.

    Returns:
    - (CustomDataset, int): the collected data and the position of the first
      sample that was not examined. Passing that position as `start_index` of
      a later call yields a disjoint set of samples.

    Raises:
    - AssertionError: if the two classes ended up with different sizes
      (e.g. the stream ran out before both quotas were filled).
    """
    n_woman, n_man = 0, 0
    images, labels = [], []
    # Starts at start_index and advances with every examined sample, so it
    # stays correct even if the stream ends before the early exit below fires.
    # (Previously it stayed 0 in that case, making a follow-up call restart
    # from the beginning and overlap with this split.)
    end_index = start_index
    for i, sample in enumerate(dataset):
        if i < start_index:
            continue
        if n_woman >= N and n_man >= N:
            break
        end_index = i + 1
        is_woman = 'woman' in sample['text']
        # Skip samples of a class whose quota is already filled.
        if (n_woman if is_woman else n_man) >= N:
            continue
        images.append(image_to_tensor(resize(sample['image'])).squeeze())
        labels.append(1 if is_woman else 0)
        if is_woman:
            n_woman += 1
        else:
            n_man += 1
    # Compare the counts directly: `sum(labels) == len(labels) // 2` also
    # passes when there is one more man than women.
    assert n_woman == n_man, f"unbalanced classes: {n_woman} woman vs {n_man} man samples"
    return CustomDataset(images, labels), end_index

# Training data: N samples per class taken from the start of the 'train' stream.
train_ds, end_index = get_custom_dataset(dataset['train'], N)
# Test data: N//10 samples per class from the same stream, skipping everything before end_index.
test_ds, _ = get_custom_dataset(dataset['train'], N//10, end_index)

In [None]:
# For each split print its size, then the sum of its labels (= number of label-1 samples).
for split_ds in (train_ds, test_ds):
    print(len(split_ds))
    print(sum(label for _, label in split_ds))

In [None]:
# Shuffle the training data every epoch. get_custom_dataset stores samples in
# stream order and skips a class once its quota is full, so without shuffling
# the trailing batches can contain a single class only.
loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation order does not matter, so the test loader stays sequential.
testloader = DataLoader(test_ds, batch_size=TEST_BATCH_SIZE)

In [None]:
# Print the shape of the first training image (nothing is printed for an empty dataset).
if len(train_ds) > 0:
    i = 0
    print(train_ds[i][0].shape)

In [None]:
# Print the labels of the first training batch (nothing is printed for an empty loader).
first_batch = next(iter(loader), None)
if first_batch is not None:
    image, label = first_batch
    print(label)

In [None]:
# preview the dataset: a 3x3 grid of randomly drawn training images, titled with their labels
cols, rows = 3, 3
figure = plt.figure(figsize=(8, 8))
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(train_ds), size=(1,)).item()
    img, label = train_ds[sample_idx]
    ax = figure.add_subplot(rows, cols, i)
    ax.set_title(label)
    ax.axis("off")
    ax.imshow(tensor_to_image(img))
plt.show()

### Model
- [ ] Try 32 by 32

#### Other Models

In [None]:
# Create classifier model

# model = mobilenet_v3_small()
# model = nn.Sequential(
#     model.features,
#     model.avgpool,
#     nn.Flatten(),
#     nn.Dropout(0.2, inplace=True),
#     nn.Linear(in_features=576, out_features=1),
#     nn.Sigmoid()
# )

# Simple CNN
# model = nn.Sequential(
#     nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5),
#     nn.LeakyReLU(),
#     nn.MaxPool2d(kernel_size=5),
#     nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5),
#     nn.LeakyReLU(),
#     nn.MaxPool2d(kernel_size=3),
#     nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5),
#     nn.LeakyReLU(),
#     nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3),
#     nn.LazyLinear(out_features=128),
#     nn.LeakyReLU(),
#     nn.Linear(in_features=128, out_features=1),
#     nn.Sigmoid()
# )

# model = nn.Sequential(
#             nn.Conv2d(3, 16, kernel_size=3, padding=1),
#             nn.ReLU(),
#             nn.MaxPool2d(2),
#             nn.Conv2d(16, 32, kernel_size=3, padding=1),
#             nn.ReLU(),
#             nn.MaxPool2d(2),
#             nn.Conv2d(32, 64, kernel_size=3, padding=1),
#             nn.ReLU(),
#             nn.MaxPool2d(2),
#             nn.Flatten(),
#             nn.Linear(64 * 32 * 32, 512),
#             nn.ReLU(),
#             nn.Linear(512, 1),
#             nn.Sigmoid()
#         )



#### CNN with RFF

In [None]:
def generate_rff(log_sigma, dimension=100, random_matrix=None):  # Set your desired dimension for RFF
    """Embed scalar noise levels as cosine random Fourier features.

    Parameters:
    - log_sigma (torch.Tensor): tensor of (log) noise levels; a trailing axis
      of size 1 is appended before the projection.
    - dimension (int): number of features, used only when `random_matrix` is
      not supplied.
    - random_matrix (torch.Tensor, optional): projection of shape
      (1, n_features). Pass the same matrix on every call to obtain an
      embedding that is a fixed function of `log_sigma`. When omitted, a fresh
      standard-normal matrix is drawn on each call (the previous behaviour),
      which makes repeated calls with the same input return different features.

    Returns:
    - torch.Tensor: cos(log_sigma[..., None] @ random_matrix), i.e. shape
      log_sigma.shape + (n_features,).
    """
    if random_matrix is None:
        random_matrix = torch.randn((1, dimension), device=log_sigma.device)
    else:
        random_matrix = random_matrix.to(device=log_sigma.device)
    transformed = torch.matmul(log_sigma.unsqueeze(-1), random_matrix)
    return torch.cos(transformed)

class CustomCNNWithRFF(nn.Module):
    """Small CNN binary classifier conditioned on the noise level sigma.

    The image goes through three conv / LeakyReLU / 2x max-pool stages and is
    flattened; a cosine random-Fourier-feature embedding of log(sigma) is
    concatenated to it before the fully connected head, which ends in a
    sigmoid (one probability per sample).

    The first linear layer expects 64 * 32 * 32 image features, i.e. a
    3 x 256 x 256 input after the three 2x poolings.

    The random projection behind the embedding is drawn once in `__init__` and
    stored as the buffer `rff_matrix`. It is therefore the same in every
    forward pass, follows the module across devices, and is part of the state
    dict. Resampling it per call would make the embedding unrelated to sigma.
    """

    def __init__(self, rff_dim=100):
        super(CustomCNNWithRFF, self).__init__()
        self.rff_dim = rff_dim
        # Fixed (non-trainable) projection for the sigma embedding.
        self.register_buffer("rff_matrix", torch.randn(1, rff_dim))
        self.cnn_layers = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.MaxPool2d(2),
            nn.Flatten()
        )
        self.fc_layers = nn.Sequential(
            nn.Linear(64 * 32 * 32 + self.rff_dim, 512),  # Adjusted for RFF
            nn.LeakyReLU(),
            nn.Linear(512, 1),
            nn.Sigmoid()
        )

    def forward(self, x, sigma):
        """Return the predicted probability, shape (batch, 1).

        Parameters:
        - x (torch.Tensor): image batch of shape (batch, 3, 256, 256).
        - sigma (torch.Tensor): one noise level per sample, shape (batch,).
        """
        log_sigma = torch.log(sigma + 1e-8)  # To ensure numerical stability
        # (batch, 1) @ (1, rff_dim) -> (batch, rff_dim)
        rff = torch.cos(log_sigma.unsqueeze(-1) @ self.rff_matrix)

        x = self.cnn_layers(x)  # already flattened by the trailing nn.Flatten
        x = torch.cat([x, rff], dim=1)  # Concatenate RFF in front of the FC head
        return self.fc_layers(x)


#### Mobile Net with RFF

In [None]:
class CustomMobileNetWithRFF(nn.Module):
    """MobileNetV3-small binary classifier conditioned on the noise level sigma.

    The image goes through the feature extractor and average pooling of a
    `mobilenet_v3_small()` instance (built without arguments), is flattened
    and passed through dropout; a cosine random-Fourier-feature embedding of
    log(sigma) is concatenated to the 576 resulting features before the fully
    connected head, which ends in a sigmoid.

    The random projection behind the embedding is drawn once in `__init__` and
    stored as the buffer `rff_matrix`. It is therefore the same in every
    forward pass, follows the module across devices, and is part of the state
    dict. Resampling it per call would make the embedding unrelated to sigma.
    """

    def __init__(self, rff_dim=100):
        super(CustomMobileNetWithRFF, self).__init__()
        self.rff_dim = rff_dim
        # Fixed (non-trainable) projection for the sigma embedding.
        self.register_buffer("rff_matrix", torch.randn(1, rff_dim))
        self.mobile_net = mobilenet_v3_small()
        self.mobile_net_layers = nn.Sequential(
            self.mobile_net.features,
            self.mobile_net.avgpool,
            nn.Flatten(),
            nn.Dropout(0.2, inplace=True)
        )
        self.fc_layers = nn.Sequential(
            nn.Linear(576 + self.rff_dim, 512),  # Adjusted for RFF
            nn.LeakyReLU(),
            nn.Linear(512, 1),
            nn.Sigmoid()
        )

    def forward(self, x, sigma):
        """Return the predicted probability, shape (batch, 1).

        Parameters:
        - x (torch.Tensor): image batch, channel-first.
        - sigma (torch.Tensor): one noise level per sample, shape (batch,).
        """
        log_sigma = torch.log(sigma + 1e-8)  # To ensure numerical stability
        # (batch, 1) @ (1, rff_dim) -> (batch, rff_dim)
        rff = torch.cos(log_sigma.unsqueeze(-1) @ self.rff_matrix)

        x = self.mobile_net_layers(x)  # already flattened by nn.Flatten
        x = torch.cat([x, rff], dim=1)  # Concatenate RFF in front of the FC head
        return self.fc_layers(x)


In [None]:
# Classifier used by every cell below; the MobileNet variant is kept as a commented-out alternative.
model = CustomCNNWithRFF(rff_dim=RFF_DIM).to(DEVICE)
# model = CustomMobileNetWithRFF(rff_dim=RFF_DIM).to(DEVICE)

In [None]:
def get_timesteps_and_sigmas(size, high):
    """Draw `size` random timestep indices in [0, high) and look up their sigmas.

    Returns a pair (timesteps, sigmas), both on DEVICE, where `sigmas` are the
    entries of the scheduler's `discrete_sigmas` at the drawn indices.
    """
    drawn = torch.randint(low=0, high=high, size=(size,))
    timesteps = drawn.to(device=DEVICE)
    sigma_table = unconditional_scheduler.discrete_sigmas.to(DEVICE)
    return timesteps, sigma_table[timesteps]

In [None]:
def compute_accuracy(dataloader, model, loss_fn, sigma=0):
    """Evaluate `model` on `dataloader` with Gaussian noise of scale `sigma` added.

    Parameters:
    - dataloader: yields (X, y) batches; `dataloader.dataset` must support len().
    - model: callable as model(X, sigmas) returning predictions of shape (batch, 1).
    - loss_fn: callable as loss_fn(pred, y) returning a scalar tensor.
    - sigma (number): standard deviation of the noise added to every input;
      the same value is handed to the model as its noise level.

    Returns:
    - (accuracy, loss): fraction of samples whose thresholded prediction
      (pred > 0.5) equals the label, and the loss averaged over batches.

    The model is put in eval mode for the evaluation and then returned to the
    mode it was in before. Callers that evaluate between training epochs would
    otherwise keep training in eval mode.
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    was_training = model.training
    model.eval()
    test_loss, correct = 0, 0
    try:
        with torch.no_grad():
            for X, y in dataloader:
                X, y = X.float().to(DEVICE), y.float().to(DEVICE)
                sigmas = torch.tensor(sigma).float().repeat(X.shape[0]).to(DEVICE)
                # add noise of sigma to X
                X = X + torch.randn_like(X) * sigma
                pred = model(X, sigmas).squeeze(1)
                test_loss += loss_fn(pred, y).item()
                correct += ((pred > 0.5) == y).type(torch.float).sum().item()
    finally:
        # Restore the caller's train/eval mode even if evaluation raised.
        model.train(was_training)
    test_loss /= num_batches
    correct /= size
    return correct, test_loss

In [None]:
# Sanity check: (accuracy, loss) of the current model on the training loader at sigma=10.
loss_fn = nn.BCELoss()
compute_accuracy(loader, model, loss_fn, 10)

# Method 1: Train Classifier

In [None]:
# Train Classifier
# optim = torch.optim.Adam(model.parameters(), lr=0.001)
optim = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

loss_fn = nn.BCELoss()
sig = nn.Sigmoid().to(device=DEVICE)  # not used in this cell
eval_sigmas = np.linspace(0,10,10) # images with too high a sigma are not classifiable anyway

# shape [epoch, num_sigmas, 2]: (accuracy, loss) on the training loader for each evaluation sigma
progress = []

with tqdm(total=N_EPOCHS) as pbar:
    for epoch in range(N_EPOCHS):

        # Evaluate before this epoch's updates, once per evaluation sigma.
        epoch_progress = [compute_accuracy(loader, model, loss_fn, sigma) for sigma in eval_sigmas]
        pbar.set_description(f"accuracy {epoch_progress[0][0]}, {epoch_progress[-1][0]} loss {epoch_progress[0][1]}, {epoch_progress[-1][1]}")
        pbar.refresh()
        progress.append(epoch_progress)

        # Alternative curriculum (disabled): grow the noise batch and the sigma range with the epoch.
        # noise_size = int(np.interp(epoch,[0,N_EPOCHS-1],[1,NOISE_BATCH]))
        # noise_high = int(np.interp(epoch,[0,N_EPOCHS-1],[1,len(unconditional_scheduler.discrete_sigmas)]))
        noise_size = NOISE_BATCH
        noise_high = 20  # timesteps are drawn from [0, 20), i.e. the first 20 entries of discrete_sigmas

        # compute_accuracy switches the model to eval mode; make sure the
        # updates below run in training mode (matters for dropout / batch norm).
        model.train()
        for i, data in enumerate(loader):
            inputs, labels = data
            inputs, labels = inputs.float().to(DEVICE), labels.float().to(DEVICE)

            # Insert a singleton "noise level" axis: (B, C, H, W) -> (B, 1, C, H, W).
            inputs = inputs.unsqueeze(1)
            timesteps, sigmas = get_timesteps_and_sigmas(noise_size, noise_high)
            # One sigma per (image, noise level) pair, image-major like the flattened inputs below.
            sigmas = sigmas.repeat(inputs.shape[0])
            # NOTE(review): relies on add_noise broadcasting the noise_size sigmas
            # against the singleton axis so the result is (B, noise_size, C, H, W) — confirm.
            inputs = unconditional_scheduler.add_noise(
                original_samples=inputs,
                noise=None,
                timesteps=timesteps
            ).flatten(0,1)
            # Repeat every label once per noise level, in the same image-major order.
            labels = labels.unsqueeze(-1)
            labels = labels.expand(labels.shape[0],noise_size).flatten()

            # Training
            optim.zero_grad()
            predicted = model(inputs, sigmas).squeeze(1)
            loss = loss_fn(predicted, labels)
            loss.backward()
            optim.step()

        # Logging
        pbar.update(1)

#### Quick Eval

In [None]:
# Left panel: accuracy per epoch; right panel: loss per epoch. One curve per evaluation sigma.
fig, axes = plt.subplots(1,2,figsize=(10,8))

cmap = plt.cm.viridis
colors = cmap(np.linspace(0, 1, len(progress[0])))

# progress[epoch][i] is (accuracy, loss) for eval_sigmas[i]; `metric` selects which of the two to draw.
for metric, ax in enumerate(axes):
    for i, color in enumerate(colors):
        ax.plot([x[i][metric] for x in progress], marker='', linestyle='-', color=color, label=f"sigma-{eval_sigmas[i]:.2f}")

accuracy_ax, loss_ax = axes

accuracy_ax.set_title("Training Accuracies")
accuracy_ax.legend(loc="upper left")
accuracy_ax.set_xlabel("Epochs")
accuracy_ax.set_ylabel("Accuracy")
accuracy_ax.set_ylim(0.0,1.0)

loss_ax.set_title("Training Losses")
loss_ax.set_xlabel("Epochs")
loss_ax.set_ylabel("Loss")

# Display the plot
plt.tight_layout()
plt.show()

In [None]:
n = 1
# Print (accuracy, loss) on the test loader for every evaluation sigma.
for sigma in eval_sigmas:
    test_metrics = compute_accuracy(testloader, model, loss_fn, sigma)
    print(test_metrics)
# with torch.no_grad():
#     for i, data in enumerate(testloader):
#         if i > n:
#             break
#         inputs, labels = data
#         inputs, labels = inputs.float().to(DEVICE), labels.float().to(DEVICE)

#         inputs = inputs.unsqueeze(1)
#         sigmas = torch.tensor(sigma).float().repeat(inputs.shape[0]).to(DEVICE)
#         print(sigmas)
#         # print(sigmas.shape)
#         inputs = unconditional_scheduler.add_noise(
#             original_samples=inputs,
#             noise=None,
#             timesteps=timesteps
#         ).flatten(0,1)
#         # print(inputs.shape)
#         labels = labels.unsqueeze(-1)
#         labels = labels.expand(labels.shape[0],NOISE_BATCH).flatten()

#         # sigmas = torch.ones(inputs.shape[0]).to(DEVICE)
#         predicted = model(inputs, sigmas).squeeze(1)
#         print(['woman' if x else 'man' for x in (predicted>0.5)])
#         print(f"GT: {['woman' if x==1 else 'man' for x in labels]}")
#             plot_tensor_grid(inputs)

In [None]:
# Inspect the gradient of the summed log-prediction w.r.t. a heavily noised test batch (first batch only).
sigma = 2000
with torch.enable_grad():
    for X, y in testloader:
        X = X.float().to(DEVICE)
        y = y.float().to(DEVICE)
        sigmas = torch.tensor(sigma).float().repeat(X.shape[0]).to(DEVICE)
        # add noise of sigma to X
        X = X + sigma * torch.randn_like(X)
        X.requires_grad_(True)
        plot_tensor_grid(X)
        predicted = model(X, sigmas).log()
        predicted.sum().backward()
        print(X.grad)
        break


### Conditional Generation

In [None]:
# Checkpoints
# classifier_name = 'toy_classifier.pt'

# torch.save(model, classifier_name)
# model = torch.load(classifier_name)

In [None]:
class BinaryClassConditionalScoreSdeVePipeline(ScoreSdeVePipeline):
    """Score-SDE (VE) sampling pipeline steered by a binary classifier.

    At every corrector and predictor step the UNet output is shifted by
    `guidance_scale` times the gradient, with respect to the current sample,
    of log p(class | sample, sigma), where p is the classifier's output
    (`target=True`) or one minus it (`target=False`).
    """

    def add_conditional_gradient(self, sample, sigma, output, classifier, target, debug=False, guidance_scale=32.0):
        """Add the scaled classifier log-probability gradient to `output` in place.

        Parameters:
        - sample (torch.Tensor): current samples, shape (batch, C, H, W).
        - sigma (torch.Tensor): current noise level (a single value shared by
          the batch, or one value per sample).
        - output (torch.Tensor): UNet output to be shifted; modified in place.
        - classifier: callable as classifier(sample, sigmas) returning
          probabilities in [0, 1] with one row per sample.
        - target (bool): steer towards the classifier's positive class when
          True, towards the negative class otherwise.
        - debug: unused; kept for interface compatibility.
        - guidance_scale (float): multiplier applied to the gradient.

        Returns:
        - torch.Tensor: `output` after the in-place update.
        """
        with torch.enable_grad():
            # Differentiate w.r.t. a detached copy so the pipeline's own sample
            # tensor is not switched to requires_grad as a side effect.
            guided = sample.detach().requires_grad_(True)
            sigma_batch = sigma.reshape(-1).expand(guided.shape[0]).to(guided.device)
            prob = classifier(guided, sigma_batch)
            if not target:
                prob = 1 - prob
            # Clamp before the log: a saturated probability of exactly 0 would
            # give -inf / inf gradients, which nan_to_num maps to huge values.
            log_prob = torch.log(prob.clamp_min(1e-12))
            # autograd.grad returns only the input gradient and leaves the
            # classifier parameters' .grad untouched (backward() would
            # accumulate into them).
            (grad,) = torch.autograd.grad(log_prob.sum(), guided)
        output += guidance_scale * torch.nan_to_num(grad)
        return output

    @torch.no_grad()
    def __call__(
        self,
        classifier,
        target: bool = True,
        batch_size: int = 1,
        num_inference_steps: int = 2000,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        guidance_scale: float = 32.0,
        **kwargs,
    ) -> Union[ImagePipelineOutput, Tuple]:
        """Generate `batch_size` images guided towards one classifier class.

        Parameters:
        - classifier: callable as classifier(sample, sigmas) returning probabilities.
        - target: True steers towards the classifier's positive class, False
          towards the negative one.
        - batch_size: number of images to generate.
        - num_inference_steps: number of timesteps / sigmas set on the scheduler.
        - generator: optional torch generator(s) for the initial noise and the scheduler steps.
        - output_type: "pil" converts the result to PIL images; any other
          value returns the (batch, H, W, C) numpy array.
        - return_dict: return an ImagePipelineOutput when True, else a 1-tuple.
        - guidance_scale: multiplier of the classifier gradient (default 32,
          the value that used to be hard-coded).

        Returns:
        - ImagePipelineOutput with `.images`, or `(images,)` when `return_dict` is False.
        """
        img_size = self.unet.config.sample_size
        shape = (batch_size, 3, img_size, img_size)

        model = self.unet

        sample = randn_tensor(shape, generator=generator) * self.scheduler.init_noise_sigma
        sample = sample.to(self.device)
        # Fallback so the final clamp is defined even if the scheduler yields no timesteps.
        sample_mean = sample

        self.scheduler.set_timesteps(num_inference_steps)
        self.scheduler.set_sigmas(num_inference_steps)

        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
            sigma = self.scheduler.sigmas[i]
            sigma_t = sigma * torch.ones(shape[0], device=self.device)

            # correction step
            for _ in range(self.scheduler.config.correct_steps):
                model_output = model(sample, sigma_t).sample
                model_output = self.add_conditional_gradient(
                    sample, sigma, model_output, classifier, target, guidance_scale=guidance_scale
                )
                sample = self.scheduler.step_correct(model_output, sample, generator=generator).prev_sample

            # prediction step
            model_output = model(sample, sigma_t).sample
            model_output = self.add_conditional_gradient(
                sample, sigma, model_output, classifier, target, guidance_scale=guidance_scale
            )
            output = self.scheduler.step_pred(model_output, t, sample, generator=generator)

            sample, sample_mean = output.prev_sample, output.prev_sample_mean

        # The result is built from the last step's prev_sample_mean, clamped to [0, 1].
        sample = sample_mean.clamp(0, 1)
        sample = sample.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            sample = self.numpy_to_pil(sample)

        if not return_dict:
            return (sample,)

        return ImagePipelineOutput(images=sample)


In [None]:
# Example: one sample guided towards the classifier's negative class (label 0 in get_custom_dataset)
conditional_pipeline = BinaryClassConditionalScoreSdeVePipeline(
    unconditional_pipeline.unet,
    unconditional_pipeline.scheduler,
)
images = conditional_pipeline(
    model,
    target=False,
    batch_size=1,
    num_inference_steps=800,
).images
t = PILToTensor()
images[0]

In [None]:
# Classifier output for each generated image, evaluated at sigma = 0.
with torch.no_grad():
    for image in images:
        # Scale 8-bit pixels by 1/255 (not 1/256) so inputs match the
        # F.to_tensor scaling the classifier saw during training.
        sample = (t(image) / 255).unsqueeze(0).to(device=DEVICE)
        print(f"Predicted: {model(sample, torch.zeros(1).to(DEVICE)).item()}")

In [None]:
# Show every image generated by the preceding pipeline call.
plot_image_grid(images)

In [None]:
# Ten samples guided towards the classifier's positive class (label 1 in get_custom_dataset)
conditional_pipeline = BinaryClassConditionalScoreSdeVePipeline(
    unconditional_pipeline.unet,
    unconditional_pipeline.scheduler,
)
images = conditional_pipeline(
    model,
    target=True,
    batch_size=10,
    num_inference_steps=800,
).images
t = PILToTensor()
images[0]

In [None]:
# Classifier output for each generated image, evaluated at sigma = 0.
with torch.no_grad():
    for image in images:
        # Scale 8-bit pixels by 1/255 (not 1/256) so inputs match the
        # F.to_tensor scaling the classifier saw during training.
        sample = (t(image) / 255).unsqueeze(0).to(device=DEVICE)
        print(f"Predicted: {model(sample, torch.zeros(1).to(DEVICE)).item()}")

In [None]:
# Show every image generated by the preceding pipeline call.
plot_image_grid(images)