# This notebook is adapted from the work of [Khalafi et al., 2024](https://arxiv.org/abs/2408.15094).

Original GitHub repo: https://github.com/shervinkhalafi/Constrained_Diffusion_Dual_Training


In [None]:
!pip install datasets

In [None]:
#Import Required Libraries
import os
import numpy as np

from dataclasses import dataclass
import wandb
from datasets import load_dataset
import matplotlib.pyplot as plt
from torchvision import transforms
import torchvision.datasets as datasets
import torch
from diffusers import UNet2DModel
from diffusers import AutoencoderKL
from diffusers import DDPMScheduler, DDIMScheduler
import torch.nn.functional as F
from diffusers.optimization import get_cosine_schedule_with_warmup
from diffusers import DDIMPipeline
from diffusers.utils import make_image_grid
from accelerate import Accelerator
from tqdm.auto import tqdm
from pathlib import Path
from accelerate import notebook_launcher
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torchvision
# from github import Github
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
from transformers import ViTFeatureExtractor, ViTForImageClassification, ViTImageProcessor
import shutil
from itertools import cycle
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer


#### Download Pretrained MNIST Diffusion Model

In [None]:
!gdown '1-c06A48YcgGqGy4Zq8hszmJiUQ9WWTPi'

In [None]:
%%bash

# -p creates missing parent directories (here /content/output) and does not fail when a
# directory already exists, so this cell can be re-run without erroring out.
mkdir -p /content/output/models /content/output/samples

In [None]:
#Configuration of Training hyperparameters
@dataclass
class TrainingConfig:
    """Hyperparameters, paths and switches for (constrained) diffusion training.

    Every field is type-annotated so that ``@dataclass`` registers it as a real field.
    ``TrainingConfig()`` still yields the defaults below, and any value can now be
    overridden at construction time, e.g. ``TrainingConfig(num_epochs=5)``. Without
    annotations the decorator generated an ``__init__`` that accepted no arguments.
    """

    adaptation: bool = False  # If True, the problem is treated as an adaptation/fine-tuning problem: adapt a pre-trained model to new data without overfitting.

    image_size: int = 32  # the generated image resolution (the training data is resized to image_size x image_size)
    latent_size: int = 32  # latent resolution (ignore if not using latent diffusion)
    diffusion_channels: int = 3  # 1 for b&w images, 3 for RGB, 4 or more for latent diffusion

    primal_batch_size: int = 128  # number of images in each mini-batch sampled for computing the primal loss
    dual_batch_size: int = 64  # number of images in each mini-batch sampled for computing each constraint loss
    eval_batch_size: int = 64  # how many samples to generate during each evaluation step

    num_epochs: int = 30  # total number of epochs
    batches_per_epoch: int = 4  # number of mini-batches per epoch
    primal_per_dual: int = 5  # number of primal descent steps after each update of the dual variables

    save_model_epochs: int = 30  # number of epochs between model checkpoints
    save_plot_epochs: int = 20  # number of epochs between saving plots of relevant variables
    save_image_epochs: int = 10  # number of epochs between saving a batch of images generated by the diffusion model
    running_average_length: int = 5  # length of the running average used for the average histograms of generated samples during training

    num_gpus: int = 2  # number of GPUs to split training across using the 'accelerate' library

    load_model_header: str = 'MODEL_NAME'  # header of the initial model to load from 'save_models_dir'; used when continuing training of a previously trained model or fine-tuning a pre-trained one
    save_model_header: str = 'MODEL_NAME'  # header to save the model with

    gradient_accumulation_steps: int = 1

    lr_primal: float = 1e-4  # the maximum primal learning rate
    lr_dual_to_primal: float = 1000  # ratio of dual learning rate to primal learning rate
    lr_warmup_steps: int = 500  # number of warmup steps used by the learning-rate scheduler

    evaluate: bool = True  # if True, sample images from the diffusion model every save_image_epochs epochs
    wandb_logging: bool = False  # if True, log relevant variables to wandb

    architecture_size: int = 128  # scales the size (channel width) of the denoising U-Net up or down

    dataset_name: str = 'mnist'  # dataset used for training; one of ['mnist', 'celeb-a', 'image-net']
    include_default_color: bool = True  # whether to include the default background color ("black")

    mixed_precision: str = "fp16"  # `no` for float32, `fp16` for automatic mixed precision

    upload_to_github: bool = False  # if True, plots and images are also uploaded to the GitHub repo, in addition to being saved to file
    github_token: str = " "
    github_repo: str = " "

    output_dir: str = "/content/output"  # the local directory to save everything in
    save_models_dir: str = "models"  # the local directory to save trained models in

# Global configuration instance, plus the output folder and the checkpoint file used below.
config = TrainingConfig()
gdrive_path, model_path = "/content/output/", "/content/trained_model_MODEL_NAME.pt"

In [None]:
#Classifier and VAE Setup

# Load a classifier matching config.dataset_name. MNIST and CelebA use a ViT image
# classifier from the Hugging Face hub; image-net uses a CLIP model as the classifier
# and additionally loads a pre-trained Variational AutoEncoder (for latent diffusion).
# Any other dataset name defines none of these globals.

_VIT_CLASSIFIER_CHECKPOINTS = {
    'mnist': 'farleyknight-org-username/vit-base-mnist',
    'celeb-a': 'cledoux42/GenderNew_v002',
}

if config.dataset_name in _VIT_CLASSIFIER_CHECKPOINTS:

    path = _VIT_CLASSIFIER_CHECKPOINTS[config.dataset_name]
    classifier = ViTForImageClassification.from_pretrained(path)
    classifier_processor = ViTImageProcessor.from_pretrained(path)
    classifier.eval()

elif config.dataset_name == 'image-net':

    # CLIP acts as the classifier; its processor and tokenizer are loaded alongside it.
    clipmodel = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    classifier_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    classifier = clipmodel
    classifier.eval()

    # VAE used to move between pixel space and latent space; its bound
    # encode/decode methods are kept as globals.
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
    model_encoder, model_decoder = vae.encode, vae.decode

In [None]:
#Github Function
def up_to_hub(image_path, content_path = 'images/uploaded_image.png', token = None, repo_name = None):
    """Upload a local file to a GitHub repository, creating it or overwriting it.

    Does nothing and returns 1 when ``config.upload_to_github`` is falsy, so callers can
    invoke it unconditionally.

    Args:
        image_path: local path of the file to upload.
        content_path: destination path inside the repository.
        token: GitHub access token; ``None`` means ``config.github_token``. The value is
            read at call time, not when this function is defined.
        repo_name: ``"owner/repo"`` name; ``None`` means ``config.github_repo``.

    Returns:
        1 if uploading is disabled, otherwise ``None``.
    """
    if not config.upload_to_github:
        return 1

    # PyGithub is only needed when uploading is enabled, so import it lazily. The
    # top-level import is commented out in this notebook, which used to make this
    # path fail with a NameError.
    from github import Github, UnknownObjectException

    if token is None:
        token = config.github_token
    if repo_name is None:
        repo_name = config.github_repo

    repo = Github(token).get_repo(repo_name)

    # Raw bytes are passed as-is; PyGithub performs the base64 encoding itself.
    with open(image_path, "rb") as image_file:
        data = image_file.read()

    commit_message = 'Upload image via PyGithub'

    try:
        contents = repo.get_contents(content_path)
    except UnknownObjectException:
        # 404: the file does not exist yet, so create it. Any other error (bad
        # credentials, network, failed update) now propagates instead of being
        # masked by a create attempt.
        repo.create_file(content_path, commit_message, data, branch="main")
        print(f"Created new file: {content_path}")
    else:
        repo.update_file(contents.path, commit_message, data, contents.sha, branch="main")
        print(f"Updated existing file: {content_path}")

In [None]:
#Evaluate Function
def _samples_dir(config):
    """Return ``<config.output_dir>/samples``, creating the directory if needed."""
    test_dir = os.path.join(config.output_dir, "samples")
    os.makedirs(test_dir, exist_ok=True)
    return test_dir


def _save_class_histogram(counts, title, filename, color=None, tick_labels=None):
    """Save a bar chart of per-class *counts* (bar k at x = k) to *filename*.

    ``color=None`` keeps matplotlib's default colour. When *tick_labels* is given, it
    replaces the numeric x tick labels (one label per class).
    """
    ticks = np.arange(len(counts))
    plt.bar(ticks, counts, align='center', color=color)
    if tick_labels is None:
        plt.gca().set_xticks(ticks)
    else:
        plt.gca().set_xticks(ticks, tick_labels)
    plt.title(title)
    plt.savefig(filename)
    plt.close()


def _save_image_grid(pil_images, rows, cols, image_path):
    """Tile *pil_images* into a rows x cols grid and save it to *image_path*."""
    make_image_grid(pil_images, rows, cols).save(image_path)


def evaluate(running_count, config, epoch, pipeline, mu, c, git_folder, accelerator, device, classifier_processor, classifier, classify = False, upload_images = False, final_epoch = False):
    """Sample a batch from *pipeline*, classify it, and save/upload histograms and image grids.

    Args:
        running_count: 1-D numpy array of accumulated per-class counts. It is updated in
            place with this batch's counts and also returned.
        config: a TrainingConfig. Reads dataset_name, eval_batch_size, output_dir,
            save_image_epochs and running_average_length.
        epoch: epoch index, used in file names and in the running-histogram schedule.
        pipeline: diffusers pipeline called to generate ``config.eval_batch_size`` samples.
        mu, c, classify: unused; kept for call compatibility.
        git_folder: destination folder inside the GitHub repo, passed to ``up_to_hub``.
        accelerator: in the 'image-net' branch, only its ``device`` is used.
        device: device for classifier inputs in the non-latent branches.
        classifier_processor, classifier: processor and model used to classify the
            samples (a ViT classifier, or CLIP for 'image-net').
        upload_images: non-latent datasets only; if True, save and upload class histograms.
        final_epoch: non-latent datasets only; if True, save and upload the final
            running-count histogram.

    Returns:
        *running_count*, after adding this batch's class counts.
    """
    if config.dataset_name != 'image-net':

        if config.dataset_name == 'celeb-a':
            # Sample directly as a float numpy batch (NHWC). Values are presumably
            # already in [0, 1], hence do_rescale=False. Permute to NCHW for the processor.
            imgs = pipeline(
                batch_size = config.eval_batch_size, output_type = 'nd.array', num_inference_steps = 50
                ).images
            imgs = torch.permute(torch.from_numpy(imgs), (0, 3, 1, 2)).to(device)
            inputs = classifier_processor.preprocess(imgs, do_rescale = False)
        else:
            # Sample in the pipeline's default PIL format. Previously this batch was
            # drawn for celeb-a too and then immediately discarded.
            imgs = pipeline(
                batch_size = config.eval_batch_size
                ).images
            if config.dataset_name == 'mnist':
                imgs_ = torch.from_numpy(np.stack([np.array(img) for img in imgs], axis = 0))
                inputs = classifier_processor.preprocess(imgs_, do_rescale = True)

        # np.asarray accepts either a list of per-image arrays or a single stacked array.
        pixel_values = torch.as_tensor(np.asarray(inputs['pixel_values'])).to(device)
        # Inference only: no autograd graph is needed to count predicted classes.
        with torch.no_grad():
            outputs = classifier(pixel_values)
        probs = outputs.logits.softmax(dim = 1)
        arr = torch.argmax(probs, dim = 1).cpu().numpy()

        nclasses = np.shape(running_count)[0]
        counts = np.bincount(arr, minlength = nclasses)
        running_count += counts

        if upload_images:
            _save_class_histogram(counts, 'histogram of generated sample classes', 'class histogram.png')
            up_to_hub('class histogram.png', content_path = git_folder + f"/{epoch:04d}_histogram.png")

        # Publish the running sum when int((epoch + 1) / save_image_epochs) is a
        # multiple of running_average_length.
        if upload_images and int((epoch + 1) / config.save_image_epochs) % config.running_average_length == 0:
            _save_class_histogram(running_count, 'running sum histogram of generated sample classes',
                                  'running sum class histogram.png', color = 'tab:red')
            up_to_hub('running sum class histogram.png', content_path = git_folder + f"/{epoch:04d}_rs_histogram.png")

        if final_epoch:
            _save_class_histogram(running_count, 'final_epoch histogram of generated sample classes',
                                  'running sum class histogram.png', color = 'olivedrab')
            up_to_hub('running sum class histogram.png', content_path = git_folder + '/final_epoch_histogram.png')

        rows = cols = int(np.sqrt(config.eval_batch_size))

        if config.dataset_name == 'celeb-a':
            # celeb-a samples are an NCHW tensor; convert each one to PIL for the grid.
            pil_images = [torchvision.transforms.functional.to_pil_image(img) for img in imgs]
        else:
            pil_images = imgs

        image_path = f"{_samples_dir(config)}/{epoch:04d}.png"
        _save_image_grid(pil_images, rows, cols, image_path)
        up_to_hub(image_path, content_path = git_folder + f"/{epoch:04d}.png")

    ###################################################
    #######################LATENT######################
    ###################################################

    elif config.dataset_name == 'image-net':

        device = accelerator.device

        # NOTE(review): the histogram index k corresponds to labels[k]. Confirm that this
        # ordering matches the class indices used when setting the constraints.
        labels = ['photo of a cassette player', 'photo of a tench fish', 'photo of a garbage truck', 'photo of a parachute', 'photo of a french horn', 'photo of a english springer dog', 'photo of a golf ball', 'photo of a church', 'photo of a gas pump', 'photo of a chainsaw']
        shorts = ['cass', 'fish', 'truck', 'parach', 'horn', 'dog', 'ball', 'church', 'pump', 'saw']

        with torch.no_grad():

            imgs = pipeline(
                batch_size = config.eval_batch_size, output_type = 'nd.array', num_inference_steps = 50
                ).images
            imgs = torch.permute(torch.from_numpy(imgs), (0, 3, 1, 2)).to(device)

            vae.to(device)
            model_decoder = vae.decode

            print(imgs.shape)

            # NOTE(review): a and b presumably undo the shift/scale that mapped VAE latents
            # into the diffusion model's range during training. Confirm against the latent
            # preprocessing before changing them.
            a = 0.431
            b = 36
            imgs_ = (model_decoder((imgs - a)*b).sample).detach()

            print(imgs_.shape)

            # Decoded images as NHWC, clipped to [0, 1] for the CLIP processor.
            imgs_ = torch.clip(imgs_.permute(0, 2, 3, 1), 0, 1).to(device)

            print(imgs_.device)

            # Reuse the CLIP model/processor passed in, instead of reloading them from the
            # hub on every evaluation.
            clipmodel = classifier.to(device)
            inputs = classifier_processor(text = labels, images = imgs_, do_rescale = False, do_convert_rgb = True, return_tensors = "np", padding = True)

            inputs['pixel_values'] = torch.from_numpy(inputs['pixel_values']).float().to(device)
            inputs['attention_mask'] = torch.from_numpy(inputs['attention_mask']).int().to(device)
            inputs['input_ids'] = torch.from_numpy(inputs['input_ids']).int().to(device)

            outputs = clipmodel(**inputs)
            # logits_per_image: image-text similarity scores, one column per label.
            probs = outputs.logits_per_image.softmax(dim = 1)
            arr = torch.argmax(probs, dim = 1).cpu().numpy()
            counts = np.bincount(arr, minlength = len(labels))

            running_count += counts

            _save_class_histogram(counts, 'histogram of generated sample classes', 'class histogram.png', tick_labels = shorts)
            up_to_hub('class histogram.png', content_path = git_folder + f"/{epoch:04d}_histogram.png")

            rows = cols = int(np.sqrt(config.eval_batch_size))

        samples_dir = _samples_dir(config)

        # Grid of the generated latents. It gets its own file name so that the decoded
        # grid below no longer overwrites it on disk.
        latent_images = [torchvision.transforms.functional.to_pil_image(img) for img in imgs]
        image_path = f"{samples_dir}/{epoch:04d}_latents.png"
        _save_image_grid(latent_images, rows, cols, image_path)
        up_to_hub(image_path, content_path = git_folder + f"/{epoch:04d}_latents.png")

        # Grid of the decoded images (back to NCHW for to_pil_image).
        decoded_images = [torchvision.transforms.functional.to_pil_image(img) for img in imgs_.permute(0, 3, 1, 2)]
        image_path = f"{samples_dir}/{epoch:04d}.png"
        _save_image_grid(decoded_images, rows, cols, image_path)
        up_to_hub(image_path, content_path = git_folder + f"/{epoch:04d}.png")

        # Histogram of the values of the first generated latent.
        lat = torch.flatten(imgs[0]).cpu().detach().numpy()
        plt.figure()
        plt.hist(lat, bins = 50)
        plt.savefig('lat.png')
        up_to_hub('lat.png', content_path = git_folder + f"/{epoch:04d}_latents_histogram.png")
        plt.close()

    return running_count


class CustomImageDataset(Dataset):
    """Map-style dataset over either an in-memory tensor or a folder of ``.pt`` files.

    Folder mode (``use_images=False``): ``<img_dir>/N`` holds a scalar tensor with the
    number of samples. For each index ``i`` there is ``sample_<i>.pt`` (the image) and
    ``label_<i>.pt`` (its label), as written by ``create_datasets_fast``.

    In-memory mode (``use_images=True``): item ``idx`` is ``images[idx]`` and every label is 1.

    ``transform`` / ``target_transform``, if set, are applied to the image / label.
    """

    def __init__(self, img_dir, transform = None, target_transform = None, use_images = False, images = None):
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform
        self.use_images = use_images
        self.images = images

    def __len__(self):
        if self.use_images:
            return self.images.shape[0]
        # The count is re-read from disk on each call.
        return int(torch.load(os.path.join(self.img_dir, 'N')).item())

    def __getitem__(self, idx):
        if self.use_images:
            image = self.images[idx]
            label = 1
        else:
            image = torch.load(os.path.join(self.img_dir, f'sample_{idx}.pt'))
            label = torch.load(os.path.join(self.img_dir, f'label_{idx}.pt'))

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

In [None]:
#Fast Dataset Creation Function
def _load_source_dataset(config):
    """Load the Hugging Face split that samples are drawn from for ``config.dataset_name``.

    For CelebA, the ``Male`` attribute is renamed to ``label``.

    Raises:
        ValueError: for an unsupported dataset name.
    """
    if config.dataset_name == 'image-net':
        return load_dataset("frgfm/imagenette", '320px', split="train")
    if config.dataset_name == 'mnist':
        return load_dataset("ylecun/mnist", split="train")
    if config.dataset_name == 'celeb-a':
        return load_dataset("tpremoli/CelebA-attrs", split="validation").rename_column('Male', 'label')
    raise ValueError(f"Unsupported dataset_name: {config.dataset_name!r}")


def _encode_dir_to_latents(src_dir, dst_dir, device = None):
    """VAE-encode every sample in *src_dir* and save the latent means into *dst_dir*.

    Samples are read in index order (no shuffling), so ``dst_dir/sample_<k>.pt`` and
    ``dst_dir/label_<k>.pt`` stay paired with ``src_dir``'s sample/label ``k``. Uses the
    module-level ``vae``. Returns the number of latents written.
    """
    if device is None:
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    vae.to(device)
    loader = torch.utils.data.DataLoader(CustomImageDataset(img_dir = src_dir), batch_size = 16, shuffle = False)

    n = 0
    with torch.no_grad():
        for images, labels in loader:
            # Raw samples were saved in the 0-255 range; scale by 1/255 before encoding.
            latent_dist = vae.encode((images / 255).to(device)).to_tuple()[0]
            latents = latent_dist.mean.detach()
            for latent, label in zip(latents, labels):
                # clone(): saving a slice would serialize the whole batch's storage.
                torch.save(latent.clone(), os.path.join(dst_dir, f'sample_{n}.pt'))
                torch.save(int(label), os.path.join(dst_dir, f'label_{n}.pt'))
                n += 1
    return n


def create_datasets_fast(N_spc, config, dataloaders_only = False, save_jpgs_only = False, generated = False, model = None, noise_scheduler = None):
    """Build per-dataset sample folders from a Hugging Face dataset and return a DataLoader per folder.

    Args:
        N_spc: 2-D array of shape (N_datasets, N_classes). ``N_spc[i, j]`` is the number
            of class-``j`` samples (Samples Per Class) put into dataset ``i``; the first
            ``ceil(N_spc[i, j])`` examples of that class are used.
        config: TrainingConfig. Reads dataset_name, image_size and primal_batch_size.
        dataloaders_only: if True, skip (re)creating the folders and only build
            DataLoaders over the existing ``dataset_<i>`` folders.
        save_jpgs_only: if True, write the samples as JPEGs into a single
            ``FID_baseline_dataset_<dataset_name>`` folder instead of ``.pt`` files.
        generated, model, noise_scheduler: unused; kept for call compatibility.

    Returns:
        A list with one shuffled DataLoader (batch size ``config.primal_batch_size``)
        for each of the ``N_datasets`` folders.

    Side effects: deletes and recreates the ``temp_data`` scratch folder, and (unless
    *dataloaders_only*) the dataset folders, in the current working directory.
    """
    N_datasets = np.shape(N_spc)[0]
    N_classes = np.shape(N_spc)[1]

    # Scratch folder for raw (pre-VAE) image-net samples; always recreated empty.
    if os.path.exists('temp_data'):
        shutil.rmtree('temp_data')
    os.mkdir('temp_data')

    if save_jpgs_only:
        dataset_names_list = ["FID_baseline_dataset_" + config.dataset_name]
    else:
        dataset_names_list = ["dataset_" + str(i) for i in range(N_datasets)]

    if not dataloaders_only:

        for folder in dataset_names_list:
            if os.path.exists(folder):
                shutil.rmtree(folder)
            os.mkdir(folder)

        dataset_orig = _load_source_dataset(config)
        resize = torchvision.transforms.Resize((config.image_size, config.image_size))

        for i in range(N_datasets):

            path = dataset_names_list[i]
            n = 0

            for j in range(N_classes):

                if N_spc[i, j] == 0:
                    continue

                # CelebA classes j = 0 / 1 map to attribute values -1 / +1 (presumably the
                # dataset's encoding); the other datasets use j directly.
                target = int((j - 0.5) * 2) if config.dataset_name == 'celeb-a' else j
                ds_filtered = dataset_orig.filter(lambda example: example["label"] == target)

                # Keep the first ceil(N_spc[i, j]) rows. These are the same rows the former
                # `idx < N_spc[i, j]` filter kept, without another full pass over the data.
                n_keep = min(max(int(np.ceil(N_spc[i, j])), 0), len(ds_filtered))
                ds_filtered = ds_filtered.select(range(n_keep)).with_format("torch")

                for batch in DataLoader(ds_filtered, batch_size=1):

                    img = resize(batch['image'][0, :, :, :])

                    if img.shape[0] == 1:
                        # Replicate single-channel images (e.g. MNIST) to 3 channels.
                        img = img.repeat(3, 1, 1)

                    if config.dataset_name == 'image-net':
                        # The raw image and its label go to the scratch folder; they are
                        # VAE-encoded into `path` once this dataset is complete.
                        torch.save(img, os.path.join('temp_data', f'sample_{n}.pt'))
                        torch.save(j, os.path.join('temp_data', f'label_{n}.pt'))
                    elif not save_jpgs_only:
                        torch.save(img / 255, os.path.join(path, f'sample_{n}.pt'))
                        torch.save(j, os.path.join(path, f'label_{n}.pt'))

                    if save_jpgs_only:
                        torchvision.utils.save_image((img / 255).double(), os.path.join(path, f'sample_{n}.jpg'))

                    n += 1

            # Sample counts read by CustomImageDataset.__len__.
            torch.save(torch.tensor(n), os.path.join('temp_data', 'N'))
            torch.save(torch.tensor(n), os.path.join(path, 'N'))

            if config.dataset_name == 'image-net' and not save_jpgs_only and n > 0:
                # Encode this dataset's raw samples once, in order and with labels.
                # Previously everything collected so far was re-encoded after every class,
                # in shuffled order, and without the label files CustomImageDataset needs.
                n = _encode_dir_to_latents('temp_data', path)
                torch.save(torch.tensor(n), os.path.join(path, 'N'))

    # Now we create the dataloaders. Pixel datasets (saved as img / 255, i.e. in [0, 1])
    # are normalized to [-1, 1]; latent (image-net) datasets are used as-is.
    if config.dataset_name != 'image-net':
        transform = transforms.Compose([transforms.Normalize([0.5], [0.5])])
    else:
        transform = None

    return [
        torch.utils.data.DataLoader(
            CustomImageDataset(img_dir = dataset_names_list[i], transform = transform),
            batch_size = config.primal_batch_size,
            shuffle = True,
        )
        for i in range(N_datasets)
    ]


In [None]:
#### MNIST CLASSIFIER

import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    """Basic residual block: two 3x3 conv + BatchNorm layers with a skip connection.

    The shortcut is the identity unless the stride or the channel count changes. In that
    case a 1x1 conv + BatchNorm projects the input to the output shape.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # First convolutional layer (applies the stride)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # Second convolutional layer
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Shortcut connection to match dimensions
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + self.shortcut(x)
        return F.relu(out)


class MNISTResNet(nn.Module):
    """Small ResNet classifier: three residual stages, each followed by 2x2 max-pooling.

    Every residual block uses stride 1; all spatial downsampling comes from the three
    pooling layers, which shrink the input to ``image_size // 8`` per side.

    Args:
        num_classes: number of output logits.
        in_channels: channels of the input images. The default 3 matches this notebook's
            datasets, which replicate grayscale digits to 3 channels.
        image_size: input height/width, used to size the first fully connected layer.
            The default 32 gives the original 64 * 4 * 4 = 1024 input features.
    """

    def __init__(self, num_classes=10, in_channels=3, image_size=32):
        super().__init__()

        pooled_size = image_size // 8
        if pooled_size < 1:
            raise ValueError(f"image_size must be at least 8, got {image_size}")

        # Initial input processing
        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)

        # Residual stages with increasing channels (spatial size is kept; see pooling below)
        self.layer1 = self._make_layer(16, 32, stride=1)
        self.layer2 = self._make_layer(32, 64, stride=1)
        self.layer3 = self._make_layer(64, 64, stride=1)

        # 2x2 max-pooling after each stage halves the spatial size (floor) three times
        self.pool1 = nn.MaxPool2d((2, 2))
        self.pool2 = nn.MaxPool2d((2, 2))
        self.pool3 = nn.MaxPool2d((2, 2))

        # Final classifier - 64 * (image_size // 8)^2 inputs (1024 for 32x32 images)
        self.fc1 = nn.Linear(64 * pooled_size * pooled_size, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def _make_layer(self, in_channels, out_channels, stride):
        """Stack two ResidualBlocks; only the first may change channels or stride."""
        return nn.Sequential(
            ResidualBlock(in_channels, out_channels, stride),
            ResidualBlock(out_channels, out_channels, 1),
        )

    def forward(self, x):
        # Initial processing
        out = F.relu(self.bn1(self.conv1(x)))

        # Residual stages, each followed by pooling
        out = self.pool1(self.layer1(out))
        out = self.pool2(self.layer2(out))
        out = self.pool3(self.layer3(out))

        # Flatten and classify
        out = torch.flatten(out, 1)
        out = F.relu(self.fc1(out))
        return self.fc2(out)

In [None]:
# MNIST CLASSIFIER TRAINING
def _repeat_forever(dataloader):
    """Yield batches from *dataloader* endlessly, starting a fresh pass each time it runs out.

    Unlike ``itertools.cycle``, which replays the batches cached during the first pass,
    every pass re-iterates the DataLoader, so a shuffling loader reshuffles. Stops if a
    pass yields nothing, to avoid spinning forever on an empty loader.
    """
    while True:
        empty = True
        for batch in dataloader:
            empty = False
            yield batch
        if empty:
            return


def train_mnist_classifer(config, model, optimizer, lr_scheduler, train_dataloader, epochs = 100):
    """Train *model* as a classifier with cross-entropy loss.

    Each of the ``epochs`` epochs consists of ``config.batches_per_epoch`` steps on
    ``(images, labels)`` batches from *train_dataloader*, which is re-iterated as needed.
    Training runs through an Accelerator configured from *config* (mixed precision,
    gradient accumulation, tensorboard logging). *model*'s parameters are updated in
    place, and gradients are clipped to norm 1.0.
    """
    # Initialize accelerator and tensorboard logging
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        project_dir=os.path.join("FOLDER_NAME", "logs"),
    )

    device = accelerator.device

    model, optimizer, dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    batches = _repeat_forever(dataloader)
    model.train()

    global_step = 0

    for epoch in range(epochs):

        progress_bar = tqdm(total = config.batches_per_epoch, disable = not accelerator.is_local_main_process, position = 0)
        progress_bar.set_description(f"Epoch {epoch}")

        for _ in range(config.batches_per_epoch):

            features, labels = next(batches)
            images = features.to(device)
            labels = labels.to(device)

            with accelerator.accumulate(model):
                logits = model(images)
                # Mean over the batch (equivalent to the former per-sample losses scaled by 1 and averaged).
                loss = F.cross_entropy(logits, labels)

                accelerator.backward(loss)

                # Clip only once the accumulated gradients are complete.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            progress_bar.update(1)
            logs = {"loss": loss.detach().item(), "step": global_step}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1

        progress_bar.close()

In [None]:
import numpy as np

def evaluate_classifier(dataloader, classifier, device = None):
    """Return *classifier*'s softmax class probabilities on one batch from *dataloader*.

    The classifier runs in eval mode (BatchNorm uses its running statistics and does not
    update them) with no autograd graph. Its previous train/eval mode is restored
    afterwards.

    Args:
        dataloader: yields ``(images, labels)`` batches; only the first batch is used.
        classifier: ``nn.Module`` mapping images to logits of shape (batch, classes).
        device: where to run. Defaults to the device of the classifier's parameters
            (previously hard-coded to "cuda:0"), or CPU if it has none.

    Returns:
        Tensor of shape (batch, classes) whose rows sum to 1.
    """
    images, _labels = next(iter(dataloader))

    param = next(classifier.parameters(), None)
    if device is None:
        device = param.device if param is not None else torch.device('cpu')
    # Match the parameters' dtype instead of forcing fp16 input (which breaks fp32 models
    # running without autocast).
    images = images.to(device = device, dtype = param.dtype if param is not None else None)

    was_training = classifier.training
    classifier.eval()
    try:
        with torch.no_grad():
            logits = classifier(images)
    finally:
        classifier.train(was_training)
    return torch.softmax(logits, dim = 1)


In [None]:

#### TRAIN MNIST CLASSIFIER ####
from torch.optim.lr_scheduler import LambdaLR

# Samples Per Class (spc) for each dataset: one dataset with 256 images of every digit 0-9.
classifier_dataset_spc = np.full((1, 10), 256.0)
classifer_dataloaders = create_datasets_fast(classifier_dataset_spc, config, dataloaders_only = False, save_jpgs_only = False)

mnist_classifier = MNISTResNet(num_classes = 10)
optimizer = torch.optim.AdamW(mnist_classifier.parameters(), lr = 0.003)
# Constant multiplier of 1.0: the learning rate stays fixed.
lr_scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)

train_mnist_classifer(config, mnist_classifier, optimizer, lr_scheduler, classifer_dataloaders[0], epochs = 100)

In [None]:
# Freeze the trained classifier so that no gradients are tracked for its weights.
for weight in mnist_classifier.parameters():
    weight.requires_grad_(False)

In [None]:
# Softmax predictions of the trained classifier on one batch of its training data.
result = evaluate_classifier(dataloader = classifer_dataloaders[0], classifier = mnist_classifier)

In [None]:
result  # notebook display of the softmax probabilities returned by evaluate_classifier

In [None]:
#Model Setup

# Trainable `model` and reference `frozen_model`: two 2D UNets (noisy sample + timestep
# -> sample-shaped output) built from one shared architecture description, so the two
# cannot drift apart.
_unet_kwargs = dict(
    sample_size = config.latent_size,  # the target image resolution
    in_channels = config.diffusion_channels,  # the number of input channels, 3 for RGB images
    out_channels = config.diffusion_channels,  # the number of output channels, 3 for RGB images
    layers_per_block = 2,  # how many ResNet layers to use per UNet block
    block_out_channels = (config.architecture_size, config.architecture_size, 2*config.architecture_size, 2*config.architecture_size, 4*config.architecture_size, 4*config.architecture_size),  # the number of output channels for each UNet block
    down_block_types=(
        "DownBlock2D",  # a regular ResNet downsampling block
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",  # a regular ResNet upsampling block
        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
)

model = UNet2DModel(**_unet_kwargs)
frozen_model = UNet2DModel(**_unet_kwargs)

#loading parameters of pre-trained model if needed
if config.load_model_header != 'None':
    device = torch.device('cpu')
    # The checkpoint is a plain state dict (it goes straight into load_state_dict), so
    # weights_only=True refuses arbitrary pickled objects in the downloaded file.
    state_dict = torch.load(model_path, map_location = device, weights_only = True)

    # load params
    model.load_state_dict(state_dict)

    # Frozen model: same weights, excluded from gradient computation
    frozen_model.load_state_dict(state_dict)
    for param in frozen_model.parameters():
        param.requires_grad = False

    print('loaded model ' + config.load_model_header)

In [None]:

# Dataset with Sampled Images
def sample_images_with_model(model, noise_scheduler, total_samples, accelerator = None, as_torch=True, batch_size = 128):
    """Generate ``total_samples`` images with a DDIM pipeline wrapped around *model*.

    If no *accelerator* is given, a new Accelerator is built from the global ``config``
    and *model* is prepared with it; otherwise *model* is used as-is. Images are drawn in
    chunks of at most *batch_size*.

    Returns:
        The list of PIL images, or, when *as_torch* is True, a tensor stacking
        ``np.array(img)`` for every image along a new first axis.
    """
    if not accelerator:
        accelerator = Accelerator(
            mixed_precision=config.mixed_precision,
            gradient_accumulation_steps=config.gradient_accumulation_steps,
            log_with="tensorboard",
            project_dir=os.path.join("FOLDER_NAME", "logs"),
        )
        model = accelerator.prepare(model)

    pipeline = DDIMPipeline(unet = accelerator.unwrap_model(model), scheduler = noise_scheduler)

    images = []
    for start in range(0, total_samples, batch_size):
        images += pipeline(batch_size = min(batch_size, total_samples - start)).images

    if not as_torch:
        return images
    return torch.from_numpy(np.stack([np.array(image) for image in images], axis = 0))

def create_dataset_sampled(model, noise_scheduler, total_samples, accelerator = None):
    """Sample ``total_samples`` images from *model* and wrap them in a shuffled DataLoader.

    Assumes the samples are (N, H, W, C) pixel arrays with 0-255 values, as produced from
    RGB PIL images. They are permuted to (N, C, H, W), divided by 255, and normalized with
    mean 0.5 / std 0.5. Labels come from CustomImageDataset's in-memory mode. The batch
    size is ``config.primal_batch_size``.
    """
    samples = sample_images_with_model(model, noise_scheduler, total_samples, accelerator)
    normalize = transforms.Compose([transforms.Normalize([0.5], [0.5])])
    images = samples.permute((0, 3, 1, 2)) / 255
    dataset = CustomImageDataset(
        img_dir = "",
        images = images,
        transform = normalize,
        use_images = True,
    )
    return torch.utils.data.DataLoader(dataset, batch_size = config.primal_batch_size, shuffle = True)

In [None]:
import matplotlib.pyplot as plt

#Training loop function
def finetune_loop(config, git_folder_path, model, classifier, mnist_classifier, noise_scheduler, optimizer, train_dataloaders, lr_scheduler, mu_init, b_init,
                  discriminative_class = 4,
                  delta = 0.1,
                  initial_penalty_lambda = 0.05, # Start small
                  lambda_increase_every = 10,   # Increase every n iterations
                  num_optimization_per_batch = 101,
                  lambda_growth_factor = 1.3,
                  early_stop_threshold = 0.004
                  ):

    # Initialize accelerator and tensorboard logging
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        project_dir=os.path.join("FOLDER_NAME", "logs"),
    )

    device = accelerator.device

    model, optimizer, lr_scheduler = accelerator.prepare(
        model, optimizer, lr_scheduler
    )

    # Prepare dataloader and iterator
    dataloaders = accelerator.prepare(train_dataloaders)
    iterators = cycle(dataloaders)

    global_step = 0

    running_count = np.zeros(10)

    ddim_scheduler = DDIMScheduler.from_config(noise_scheduler.config)
    ddim_scheduler.set_timesteps(50)

    loss_hist = torch.zeros(int(config.num_epochs))

    classifier = classifier.to(device)

    ##important
    pipeline = DDIMPipeline(unet = accelerator.unwrap_model(model), scheduler = noise_scheduler)

    evaluate(running_count, config, 0, pipeline, 0.0, 0.0, git_folder_path, accelerator, device, classifier_processor, classifier, classify = True)

    #Actually start training
    for epoch in range(config.num_epochs):

        progress_bar = tqdm(total = config.batches_per_epoch, disable = not accelerator.is_local_main_process, position = 0)
        progress_bar.set_description(f"Epoch {epoch}")

        total_loss = 0.0

        for step in range(config.batches_per_epoch):

            device = accelerator.device

            train_features, _ = next(iterators)
            Clean_images = train_features.to(device)
            Noises = torch.randn(Clean_images.shape, device = Clean_images.device)
            Batch_sizes = Clean_images.shape[0]
            Timesteps = torch.randint(
                        0, noise_scheduler.config.num_train_timesteps, (Batch_sizes,), device = Clean_images.device,
                        dtype = torch.int64
                    )
            Noisy_images = noise_scheduler.add_noise(Clean_images, Noises, Timesteps)
            Train_labels = torch.from_numpy(np.array([discriminative_class] * Batch_sizes)).to(device)

            penalty_lambda = initial_penalty_lambda

            batch_total_loss = 0.0

            for n in range(num_optimization_per_batch):

              with accelerator.accumulate(model):

                  #The Lagrangian

                  Noise_preds = model(Noisy_images, Timesteps, return_dict = False)[0]

                  timestep = Timesteps.cpu()

                  # compute alphas, betas
                  alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep][:, None, None, None].to(device)
                  beta_prod_t = 1. - alpha_prod_t
                  sample = Noisy_images

                  pred_original_sample = (sample - beta_prod_t ** (0.5) * Noise_preds) / alpha_prod_t ** (0.5)
                  pred_original_sample = pred_original_sample.clamp(
                      -ddim_scheduler.config.clip_sample_range, ddim_scheduler.config.clip_sample_range
                  )

                  # Calculate \grad log p(y | x)
                  pred_probs = F.softmax(mnist_classifier(pred_original_sample), dim = 1) # p(y | x)

                  one_hot_vec = F.one_hot(Train_labels, num_classes = 10)

                  # PENALTY TERM: p(y | x) for the specific labels
                  pred_prev_classification = (pred_probs * one_hot_vec).sum(dim = 1)
                  penalty = torch.mean((Timesteps < 800) * torch.relu(pred_prev_classification - delta)**2) ** 2

                  early_stop = penalty <= early_stop_threshold

                  diffusion_loss = F.mse_loss(Noise_preds, Noises)

                  loss = diffusion_loss + penalty_lambda * torch.norm(diffusion_loss) / (torch.norm(penalty) + 1e-3) * penalty

                  batch_total_loss += loss.detach()

                  accelerator.backward(loss)

                  accelerator.clip_grad_norm_(model.parameters(), 1.0)
                  optimizer.step()
                  lr_scheduler.step()
                  optimizer.zero_grad()

                  if early_stop:
                    break

              # Gradually increase lambda
              if n % lambda_increase_every == 0:
                  penalty_lambda *= lambda_growth_factor

            if not early_stop:
              print("Batch did not converge with penalty:")
              print(penalty)

            batch_average_loss = batch_total_loss / num_optimization_per_batch
            total_loss += batch_average_loss
            progress_bar.update(1)
            logs = {"loss": batch_average_loss }
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1

        loss_hist[epoch] = total_loss / config.batches_per_epoch

        # After each epoch you optionally sample some demo images with evaluate() and save the model

        if accelerator.is_main_process and config.evaluate == True:

            classifier = classifier.to(device)

            ##important
            pipeline = DDIMPipeline(unet = accelerator.unwrap_model(model), scheduler = noise_scheduler)


            if ((epoch + 1) % config.save_image_epochs == 0) and ((epoch + 1) % config.save_plot_epochs != 0):
                #just draw samples and show them in a grid
                running_count = evaluate(running_count, config, epoch, pipeline, 0.0, 0.0, git_folder_path, accelerator, device, classifier_processor, classifier, classify = True)

            if epoch == config.num_epochs - 1:
                #In the last epoch sample lots of images and check their classes
                running_count_final_epoch = np.zeros(10)

                for k in range(20):
                    running_count_final_epoch = evaluate(running_count_final_epoch, config, epoch, pipeline, 0.0, 0.0, git_folder_path, accelerator, device, classifier_processor, classifier, classify = True, upload_images = False)

                running_count_final_epoch = evaluate(running_count_final_epoch, config, epoch, pipeline, 0.0, 0.0, git_folder_path, accelerator, device, classifier_processor, classifier, classify = True, upload_images = False, final_epoch = True)

            if ((epoch + 1) % config.save_plot_epochs == 0 ) or (epoch == config.num_epochs - 1):
                #draw samples and show them in a grid. Also update plots and add a class histogram and save model
                running_count = evaluate(running_count, config, epoch, pipeline, mu_init, b_init, git_folder_path, accelerator, device, classifier_processor, classifier, classify = True)

                #Save model
                torch.save(model.state_dict(), gdrive_path + 'models/trained_model_' + config.save_model_header + '.pt')

                #######################plots of histories of important stuff#######################
                with torch.no_grad():

                    plt.plot(loss_hist)
                    plt.ylabel('loss')
                    plt.xlabel('# epochs')
                    plt.title('loss history')
                    plt.savefig(gdrive_path + 'loss_hist.png')
                    plt.close()

                    torch.save(loss_hist, gdrive_path + 'loss_hist.pt')


In [None]:
# Run

noise_scheduler = DDPMScheduler(num_train_timesteps = 1000)
total_samples = 512
discriminative_class = 4
# The fine-tuning data are images sampled from the frozen pre-trained model.
Train_dataloaders = [create_dataset_sampled(frozen_model, noise_scheduler, total_samples)]
################HYPERPARAMS##################
eta_resilience = 0.1

config.num_epochs = 100
# At least one batch per epoch, even when total_samples < primal_batch_size.
config.batches_per_epoch = max(1, int(total_samples / config.primal_batch_size))
config.primal_per_dual = 2
config.save_image_epochs = 10
config.save_plot_epochs = 500
config.running_average_length = 5
config.lr_primal = 0.0003
config.lr_dual_to_primal = 1000

# Penalty-method settings forwarded to finetune_loop.
delta = 0.1
initial_penalty_lambda = 0.05 # Start small
lambda_increase_every = 10   # Increase every n iterations
num_optimization_per_batch = 151
lambda_growth_factor = 1.2
early_stop_threshold = 0.005

mu_init_scalar = 0.0
b_init_scalar = 0.0

if config.adaptation == True:
    b_init = b_init_scalar*torch.ones(2)
    mu_init = torch.tensor([1, 0])
else:
    mu_init = mu_init_scalar*torch.zeros(1)
    mu_init[0] = 1
    b_init = b_init_scalar*torch.ones(1)

constrained = False
resilient = False

alpha = 0.085

#OPTIMIZER + LR_SCHEDULER
optimizer = torch.optim.AdamW(model.parameters(), lr = config.lr_primal)
# finetune_loop calls lr_scheduler.step() after every inner optimisation step
# (up to num_optimization_per_batch per batch), not once per batch, so the
# schedule horizon must count those inner steps. With only
# batches_per_epoch * num_epochs, training ran far past num_training_steps. The
# cosine factor of get_cosine_schedule_with_warmup is not clamped there, so the
# LR cycled back up instead of finishing its decay.
num_training_steps = config.batches_per_epoch * config.num_epochs * num_optimization_per_batch
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer = optimizer,
    num_warmup_steps = config.lr_warmup_steps,
    num_training_steps = num_training_steps,
    num_cycles = 0.25
)


############################################

git_folder_path = 'FOLDER_NAME'
config.save_model_header = 'MODEL_NAME'

if config.wandb_logging == True:
    wandb.init(
        project = 'PROJECT_NAME',
        config = config
    )

# Record the run configuration next to the outputs.
with open(gdrive_path + 'run_info.txt', 'w') as file:
    file.write('num_epochs = ' + str(config.num_epochs) + "\n")
    file.write('batches_per_epoch = ' + str(config.batches_per_epoch) + "\n")
    file.write('primal_batch_size = ' + str(config.primal_batch_size) + "\n")
    file.write('dataset_size = ' + str(total_samples)  + "\n")
    file.write('primal_per_dual = ' + str(config.primal_per_dual) + "\n")
    file.write('save_image_epochs = ' + str(config.save_image_epochs) + "\n")
    file.write('save_plot_epochs = ' + str(config.save_plot_epochs) + "\n")
    file.write('running_average_length_for_classification_of_generated_samples = ' + str(config.running_average_length) + "\n")
    file.write('lr_dual_to_primal = ' + str(config.lr_dual_to_primal) + "\n")
    file.write('eta_resilience = ' + str(eta_resilience) + "\n")
    file.write('eta_main = ' + str(config.lr_primal)  + "\n")
    file.write('alpha (const relaxation cost) = ' + str(alpha) + "\n")
    file.write('initial mu = ' + str(mu_init) + "\n")
    file.write('initial b = ' + str(b_init) + "\n")
    file.write('save_model_header = ' + str(config.save_model_header)  + "\n")
    file.write('load_model_header = ' + str(config.load_model_header)  + "\n")
    file.write('constrained = ' + str(constrained)  + "\n")
    file.write('resilient = ' + str(resilient)  + "\n")
    file.write(f'delta = {delta}\n')
    file.write(f'initial_penalty_lambda = {initial_penalty_lambda}\n')
    file.write(f'lambda_increase_every = {lambda_increase_every}\n')
    file.write(f'num_optimization_per_batch = {num_optimization_per_batch}\n')
    file.write(f'lambda_growth_factor = {lambda_growth_factor}\n')
    file.write(f'early_stop_threshold = {early_stop_threshold}\n')
    file.write(f'lr_num_training_steps = {num_training_steps}\n')


args = (config, git_folder_path, model, classifier, mnist_classifier,
        noise_scheduler, optimizer, Train_dataloaders[0], lr_scheduler, mu_init, b_init,
        discriminative_class, delta, initial_penalty_lambda, lambda_increase_every,
        num_optimization_per_batch, lambda_growth_factor, early_stop_threshold)

# To run in the current process instead (e.g. for debugging):
# result = finetune_loop(*args)


notebook_launcher(finetune_loop, args, num_processes = config.num_gpus)

# Only needed if the launcher's default port is already in use:
# portnum = 8000
# notebook_launcher(finetune_loop, args, num_processes = config.num_gpus, use_port= str(portnum))

if config.wandb_logging == True:
    wandb.finish()

Please see `output/samples` for generated samples.