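Notebook Description: trains a DDPM/DDIM diffusion U-Net on MNIST, CelebA or Imagenette (Imagenette runs in the latent space of a pretrained VAE). Training can be unconstrained or constrained. In the constrained case, each data subset contributes its own loss, weighted by a Lagrange multiplier that is updated by dual ascent. Generated samples are periodically classified with a pretrained classifier (ViT, or CLIP for Imagenette) to track their class distribution.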

In [None]:
#Import Required Libraries
import os
import numpy as np
#Uncomment the next line to make only GPU 0 visible to CUDA (every other GPU becomes invisible)
# os.environ["CUDA_VISIBLE_DEVICES"] = str(0)
from dataclasses import dataclass
import wandb
from datasets import load_dataset
import matplotlib.pyplot as plt
from torchvision import transforms
import torchvision.datasets as datasets
import torch
from diffusers import UNet2DModel
from diffusers import AutoencoderKL
from diffusers import DDPMScheduler
import torch.nn.functional as F
from diffusers.optimization import get_cosine_schedule_with_warmup
from diffusers import DDIMPipeline
from diffusers.utils import make_image_grid
from accelerate import Accelerator
from tqdm.auto import tqdm
from pathlib import Path
from accelerate import notebook_launcher
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torchvision
from github import Github
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
from transformers import ViTFeatureExtractor, ViTForImageClassification, ViTImageProcessor
import shutil
from itertools import cycle
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer


In [None]:
#Configuration of Training hyperparameters
@dataclass
class TrainingConfig:
    """Training hyperparameters and I/O settings used throughout the notebook.

    Every field has a default, so `TrainingConfig()` reproduces the settings below.
    Individual values can be overridden with keyword arguments, e.g.
    `TrainingConfig(num_epochs = 5)`.

    The GitHub credentials are read from the GITHUB_TOKEN / GITHUB_REPO environment
    variables when this class is defined, so secrets never need to be written into
    the notebook.
    """

    adaptation: bool = False # If True, the problem is treated as an adaptation/fine-tuning problem: adapt a pre-trained model to new data without overfitting.

    image_size: int = 32 # the generated image resolution (the training data will be resized to image_size*image_size)
    latent_size: int = 32 # Latent resolution (ignore if not using latent diffusion)
    diffusion_channels: int = 1 # 1 for b&w images, 3 for RGB, 4 or more for latent diffusion

    primal_batch_size: int = 128 # number of images in each mini-batch sampled for computing the primal loss
    dual_batch_size: int = 64 # number of images in each mini-batch sampled for computing each constraint loss
    eval_batch_size: int = 64 # how many samples to draw during an evaluation step

    num_epochs: int = 30 # number of total epochs
    batches_per_epoch: int = 4 # number of mini-batches per epoch
    primal_per_dual: int = 5 # number of primal descent steps after each update of the dual variables

    save_model_epochs: int = 30 # number of epochs between each time the model is saved
    save_plot_epochs: int = 20 # number of epochs between each time plots of relevant variables are saved
    save_image_epochs: int = 10 # number of epochs between each time a batch of images generated by the diffusion model is saved
    running_average_length: int = 5 # length of the running average for plotting average histograms of generated samples during training

    num_gpus: int = 2 # number of gpus to split the training on using the 'accelerate' library

    load_model_header: str = 'MODEL_NAME' # header of the initial model to load from 'save_models_dir'; used when continuing training of a previously trained model or fine-tuning a pre-trained model.
    save_model_header: str = 'saved' # header to save the model with

    gradient_accumulation_steps: int = 1

    lr_primal: float = 1e-4 # the maximum primal learning rate
    lr_dual_to_primal: float = 1000 # ratio of dual learning rate to primal learning rate
    lr_warmup_steps: int = 500 # number of warmup steps to use in the learning rate scheduler

    evaluate: bool = True # set to True to sample images from the diffusion model every save_image_epochs epochs
    wandb_logging: bool = False # set to True to log relevant variables to wandb

    architecture_size: int = 128 # the size of the denoising U-net model can be scaled up or down using this parameter

    dataset_name: str = 'mnist' # name of the dataset to use for training: one of ['mnist', 'celeb-a', 'image-net']

    mixed_precision: str = "fp16"  # `no` for float32, `fp16` for automatic mixed precision

    upload_to_github: bool = True # If True, plots and images are also uploaded to the GitHub repo in addition to being saved to file
    github_token: str = os.environ.get("GITHUB_TOKEN", " ") # set the GITHUB_TOKEN environment variable; never hard-code tokens here
    github_repo: str = os.environ.get("GITHUB_REPO", " ") # "owner/repo", read from the GITHUB_REPO environment variable

    output_dir: str = "FOLDER_NAME"  # the local directory name to save everything
    save_models_dir: str = "models" # the local directory to save trained models

config = TrainingConfig()


In [None]:
#Classifier and VAE Setup

# Load a pretrained classifier matching `config.dataset_name`. `evaluate` uses it to
# histogram the classes of generated samples. For image-net, a CLIP model (prompted
# with text labels) is the classifier. A pretrained VAE is also loaded, because
# image-net samples are stored and generated in the VAE's latent space.

# Hugging Face checkpoints of the ViT classifiers used for the pixel-space datasets.
VIT_CLASSIFIER_PATHS = {
    'mnist': 'farleyknight-org-username/vit-base-mnist',
    'celeb-a': 'cledoux42/GenderNew_v002',
}

if config.dataset_name in VIT_CLASSIFIER_PATHS:

    path = VIT_CLASSIFIER_PATHS[config.dataset_name]
    classifier = ViTForImageClassification.from_pretrained(path)
    classifier_processor = ViTImageProcessor.from_pretrained(path)
    classifier.eval()

elif config.dataset_name == 'image-net':

    #load CLIP model as classifier
    clip_checkpoint = "openai/clip-vit-base-patch32"
    clipmodel = CLIPModel.from_pretrained(clip_checkpoint)
    classifier_processor = CLIPProcessor.from_pretrained(clip_checkpoint)
    tokenizer = CLIPTokenizer.from_pretrained(clip_checkpoint)
    classifier = clipmodel
    classifier.eval()

    #load VAE (used to encode training images and decode generated latents)
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
    model_encoder = vae.encode
    model_decoder = vae.decode

else:
    # fail here rather than with a NameError on `classifier` much later in training
    raise ValueError(f"Unsupported dataset_name {config.dataset_name!r}; expected one of 'mnist', 'celeb-a', 'image-net'")
    

In [None]:
#Github Function
def up_to_hub(image_path, content_path = 'images/uploaded_image.png', token = config.github_token, repo_name = config.github_repo, branch = "main"):
    """Upload a local file to a GitHub repository, creating it or overwriting the existing copy.

    Args:
        image_path: path of the local file to upload (read as raw bytes).
        content_path: destination path of the file inside the repository.
        token: GitHub access token (defaults to config.github_token at definition time).
        repo_name: "owner/repo" identifier (defaults to config.github_repo at definition time).
        branch: branch to read the existing file from and to commit to.

    Returns:
        1 when config.upload_to_github is False (nothing is uploaded), otherwise None.

    Raises:
        Any PyGithub error other than "file not found" (e.g. bad credentials or a
        conflicting update) propagates to the caller instead of being hidden.
    """
    if not config.upload_to_github:
        return 1

    # raised by get_contents when content_path does not exist yet on `branch`
    from github import UnknownObjectException

    repo = Github(token).get_repo(repo_name)

    # PyGithub accepts raw bytes and base64-encodes them itself
    with open(image_path, "rb") as image_file:
        data = image_file.read()

    commit_message = 'Upload image via PyGithub'

    try:
        contents = repo.get_contents(content_path, ref = branch)
    except UnknownObjectException:
        repo.create_file(content_path, commit_message, data, branch = branch)
        print(f"Created new file: {content_path}")
    else:
        repo.update_file(contents.path, commit_message, data, contents.sha, branch = branch)
        print(f"Updated existing file: {content_path}")

In [None]:
#Evaluate Function

# CLIP text prompts used to classify decoded image-net samples, and the short names used
# as histogram tick labels (same order, so histogram bin k corresponds to prompt k).
# NOTE(review): confirm this order matches the imagenette label indices if these bins are
# compared against the per-class training counts.
IMAGENET_PROMPTS = ['photo of a cassette player', 'photo of a tench fish', 'photo of a garbage truck', 'photo of a parachute', 'photo of a french horn', 'photo of a english springer dog', 'photo of a golf ball', 'photo of a church', 'photo of a gas pump', 'photo of a chainsaw']
IMAGENET_SHORT_NAMES = ['cass', 'fish', 'truck', 'parach', 'horn', 'dog', 'ball', 'church', 'pump', 'saw']

# Affine map applied to sampled latents before VAE decoding: decode((x - SHIFT) * SCALE).
# NOTE(review): hand-tuned; appears to roughly invert the (x*gamma + beta) latent
# normalisation in train_loop combined with the pipeline's output rescaling — confirm
# if either changes.
LATENT_SHIFT = 0.431
LATENT_SCALE = 36


def _save_class_histogram(values, nclasses, title, filename, content_path, color = None, tick_labels = None):
    """Draw a bar chart of per-class `values`, save it to `filename` and upload it to `content_path`."""
    fig, ax = plt.subplots()
    ticks = np.arange(0, nclasses, 1)
    bar_kwargs = {} if color is None else {'color': color}
    ax.bar(ticks, values, align = 'center', **bar_kwargs)
    if tick_labels is None:
        ax.set_xticks(ticks)
    else:
        ax.set_xticks(ticks, tick_labels)
    ax.set_title(title)
    fig.savefig(filename)
    plt.close(fig)
    up_to_hub(filename, content_path = content_path)


def _save_image_grid(pil_images, rows, cols, image_dir, filename, content_path):
    """Tile `pil_images` into a rows x cols grid, save it as image_dir/filename and upload it."""
    os.makedirs(image_dir, exist_ok = True)
    image_path = f"{image_dir}/{filename}"
    make_image_grid(pil_images, rows, cols).save(image_path)
    up_to_hub(image_path, content_path = content_path)


def evaluate(running_count, config, epoch, pipeline, mu, c, git_folder, accelerator, device, classifier_processor, classifier, classify = False, upload_images = True, final_epoch = False):
    """Sample one batch from the diffusion pipeline, classify it, and save/upload plots and image grids.

    Args:
        running_count: per-class count array; this batch's counts are added to it
            (in place when it is a numpy array).
        config: TrainingConfig. Reads dataset_name, eval_batch_size, output_dir,
            save_image_epochs and running_average_length.
        epoch: current epoch, used in the output file names.
        pipeline: diffusers pipeline used to draw config.eval_batch_size samples.
        mu, c, classify: unused; kept so existing call sites keep working.
        git_folder: folder inside the GitHub repo that every uploaded file is written under.
        accelerator: only its `.device` is used, on the image-net (latent) path.
        device: device for the classifier inputs on the pixel-space path.
        classifier_processor, classifier: ViT processor/model (mnist, celeb-a) or CLIP
            processor/model (image-net).
        upload_images: on the pixel-space path, when False the per-batch and running-sum
            class histograms are skipped (the image grid is still saved and uploaded).
        final_epoch: on the pixel-space path, also plots running_count as
            final_epoch_histogram.png.

    Returns:
        running_count after adding this batch's class counts.

    Raises:
        ValueError: if config.dataset_name is not 'mnist', 'celeb-a' or 'image-net'.
    """
    rows = int(np.sqrt(config.eval_batch_size))
    cols = int(np.sqrt(config.eval_batch_size))
    test_dir = os.path.join(config.output_dir, "samples")

    if config.dataset_name != 'image-net':

        nclasses = np.shape(running_count)[0]

        if config.dataset_name == 'mnist':
            # default PIL output; grayscale frames are repeated to 3 channels for the ViT
            grid_images = pipeline(batch_size = config.eval_batch_size).images
            gray = np.stack([np.array(img) for img in grid_images], axis = 0)
            rgb = np.repeat(np.expand_dims(gray, axis = 1), 3, axis = 1)
            inputs = classifier_processor.preprocess(torch.from_numpy(rgb), do_rescale = True)

        elif config.dataset_name == 'celeb-a':
            # sample once as an array and reuse it for both classification and the grid
            imgs = pipeline(
                batch_size = config.eval_batch_size, output_type = 'nd.array', num_inference_steps = 50
                ).images
            imgs = torch.permute(torch.from_numpy(imgs), (0, 3, 1, 2)).to(device)
            inputs = classifier_processor.preprocess(imgs, do_rescale = False)
            grid_images = [torchvision.transforms.functional.to_pil_image(imgs[k, :, :, :]) for k in range(imgs.shape[0])]

        else:
            raise ValueError(f"Unsupported dataset_name {config.dataset_name!r}")

        pixel_values = inputs['pixel_values']
        if not torch.is_tensor(pixel_values):
            pixel_values = torch.from_numpy(np.stack(pixel_values))

        # inference only: no autograd graph needed
        with torch.no_grad():
            logits = classifier(pixel_values.to(device)).logits

        # argmax of the logits is the argmax of their softmax
        counts = np.bincount(torch.argmax(logits, dim = 1).cpu().numpy(), minlength = nclasses)
        running_count += counts

        if upload_images:
            _save_class_histogram(counts, nclasses, 'histogram of generated sample classes',
                                  'class histogram.png', git_folder + f"/{epoch:04d}_histogram.png")

        if upload_images and int((epoch + 1) / config.save_image_epochs) % config.running_average_length == 0:
            _save_class_histogram(running_count, nclasses, 'running sum histogram of generated sample classes',
                                  'running sum class histogram.png', git_folder + f"/{epoch:04d}_rs_histogram.png",
                                  color = 'tab:red')

        if final_epoch:
            _save_class_histogram(running_count, nclasses, 'final_epoch histogram of generated sample classes',
                                  'running sum class histogram.png', git_folder + '/final_epoch_histogram.png',
                                  color = 'olivedrab')

        _save_image_grid(grid_images, rows, cols, test_dir, f"{epoch:04d}.png", git_folder + f"/{epoch:04d}.png")

        return running_count

    ###################################################
    #######################LATENT######################
    ###################################################

    device = accelerator.device

    with torch.no_grad():

        latents = pipeline(
            batch_size = config.eval_batch_size, output_type = 'nd.array', num_inference_steps = 50
            ).images
        latents = torch.permute(torch.from_numpy(latents), (0, 3, 1, 2)).to(device)

        vae.to(device)
        decoded = vae.decode((latents - LATENT_SHIFT) * LATENT_SCALE).sample.detach()
        decoded = torch.clip(decoded.permute(0, 2, 3, 1), 0, 1)  # channels-last, clipped to [0, 1]

        # reuse the already-loaded CLIP model/processor instead of reloading them every call
        classifier = classifier.to(device)
        inputs = classifier_processor(text = IMAGENET_PROMPTS, images = decoded, do_rescale = False, do_convert_rgb = True, return_tensors = "np", padding = True)
        inputs['pixel_values'] = torch.from_numpy(inputs['pixel_values']).float().to(device)
        inputs['attention_mask'] = torch.from_numpy(inputs['attention_mask']).int().to(device)
        inputs['input_ids'] = torch.from_numpy(inputs['input_ids']).int().to(device)

        logits_per_image = classifier(**inputs).logits_per_image  # image-text similarity scores
        counts = np.bincount(torch.argmax(logits_per_image, dim = 1).cpu().numpy(), minlength = len(IMAGENET_PROMPTS))

    running_count += counts

    _save_class_histogram(counts, len(IMAGENET_SHORT_NAMES), 'histogram of generated sample classes',
                          'class histogram.png', git_folder + f"/{epoch:04d}_histogram.png",
                          tick_labels = IMAGENET_SHORT_NAMES)

    # grid of the raw generated latents (saved locally under its own name so it is not
    # overwritten by the decoded grid below)
    latent_images = [torchvision.transforms.functional.to_pil_image(latents[k, :, :, :]) for k in range(latents.shape[0])]
    _save_image_grid(latent_images, rows, cols, test_dir, f"{epoch:04d}_latents.png", git_folder + f"/{epoch:04d}_latents.png")

    # grid of the decoded images
    decoded_chw = decoded.permute(0, 3, 1, 2)
    decoded_images = [torchvision.transforms.functional.to_pil_image(decoded_chw[k, :, :, :]) for k in range(decoded_chw.shape[0])]
    _save_image_grid(decoded_images, rows, cols, test_dir, f"{epoch:04d}.png", git_folder + f"/{epoch:04d}.png")

    # histogram of the values of the first generated latent
    latent_values = torch.flatten(latents[0, :, :, :]).cpu().detach().numpy()
    fig, ax = plt.subplots()
    ax.hist(latent_values, bins = 50)
    fig.savefig('lat.png')
    plt.close(fig)
    up_to_hub('lat.png', content_path = git_folder + f"/{epoch:04d}_latents_histogram.png")

    return running_count


class CustomImageDataset(Dataset):
    """Map-style dataset over a folder of tensors saved as `sample_<idx>.pt`.

    The folder must also contain a file `N` holding a scalar tensor with the number
    of samples. Every item is returned with the constant label 1.

    Args:
        img_dir: folder containing `N` and the `sample_<idx>.pt` files.
        transform: optional callable applied to each loaded tensor.
        target_transform: optional callable applied to the label.
    """
    def __init__(self, img_dir, transform = None, target_transform = None):
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        # number of samples, read from disk on every call
        return int(torch.load(os.path.join(self.img_dir, 'N')).item())

    def __getitem__(self, idx):
        # join the file name without a leading '/', which would discard img_dir
        image = torch.load(os.path.join(self.img_dir, 'sample_' + str(idx) + '.pt'))

        label = 1
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

In [None]:
#Fast Dataset Creation Function

def _load_source_dataset(config):
    """Load the raw Hugging Face training split for config.dataset_name, with a 'label' column."""
    if config.dataset_name == 'image-net':
        return load_dataset("frgfm/imagenette", '320px', split="train")
    if config.dataset_name == 'mnist':
        return load_dataset("ylecun/mnist", split="train")
    if config.dataset_name == 'celeb-a':
        return load_dataset("tpremoli/CelebA-attrs", split="validation").rename_column('Male', 'label')
    raise ValueError(f"Unsupported dataset_name {config.dataset_name!r}")


def _filter_class(dataset_orig, config, class_index):
    """Return every example of class `class_index` from `dataset_orig`."""
    # celeb-a classes 0/1 are matched against label values -1/1, i.e. int((j - 0.5) * 2)
    if config.dataset_name == 'celeb-a':
        target = int((class_index - 0.5) * 2)
    else:
        target = class_index
    return dataset_orig.filter(lambda example: example["label"] == target)


def _encode_temp_data_into(path, device = None):
    """VAE-encode the raw images in 'temp_data' and save their latent means as path/sample_<k>.pt.

    Raw images are stored as 0-255 tensors and are divided by 255 before encoding.
    NOTE(review): the VAE is therefore fed inputs in [0, 1]; evaluate() likewise clips
    decoded images to [0, 1] — confirm before changing either side.

    Returns:
        the number of latents written.
    """
    if device is None:
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    loader = torch.utils.data.DataLoader(CustomImageDataset(img_dir = 'temp_data'), batch_size = 16, shuffle = True)
    vae.to(device)

    n = 0
    with torch.no_grad():
        for images, _ in loader:
            lat = vae.encode((images / 255).to(device)).latent_dist.mean
            for k in range(lat.shape[0]):
                torch.save(lat[k, :, :, :].detach(), path + '/sample_' + str(n) + '.pt')
                n += 1
    return n


def create_datasets_fast(N_spc, config, dataloaders_only = False, save_jpgs_only = False):
    """Write per-class subsets of the training set to disk and return DataLoaders over them.

    Args:
        N_spc: 2-D array indexed [dataset, class]. N_spc[i, j] is how many examples of
            class j (the first ones, in dataset order) go into dataset i.
        config: TrainingConfig. Reads dataset_name, image_size and primal_batch_size.
        dataloaders_only: if True, nothing is written; DataLoaders are built over the
            existing 'dataset_<i>' folders.
        save_jpgs_only: if True, samples are written as .jpg files into a single
            'FID_baseline_dataset_<dataset_name>' folder instead of as .pt tensors.

    The 'temp_data' scratch folder is deleted and recreated on every call. For image-net,
    the .pt files hold VAE latent means; for the other datasets they hold resized images
    scaled to [0, 1].

    Returns:
        a list with one shuffling DataLoader per dataset (batch size config.primal_batch_size).
    """
    N_datasets = np.shape(N_spc)[0]
    N_classes = np.shape(N_spc)[1]

    # scratch folder for raw image-net images waiting to be VAE-encoded
    if os.path.exists('temp_data'):
        shutil.rmtree('temp_data')
    os.mkdir('temp_data')

    if save_jpgs_only:
        dataset_names_list = ["FID_baseline_dataset_" + config.dataset_name]
    else:
        dataset_names_list = ["dataset_" + str(i) for i in range(N_datasets)]

    if not dataloaders_only:

        for folder in dataset_names_list:
            if os.path.exists(folder):
                shutil.rmtree(folder)
            os.mkdir(folder)

        dataset_orig = _load_source_dataset(config)
        resize = torchvision.transforms.Resize((config.image_size, config.image_size))
        class_subsets = {}  # class index -> all examples of that class (filtered once)

        for i in range(N_datasets):

            path = dataset_names_list[i]
            n = 0  # running sample index within dataset i, across all of its classes

            for j in range(N_classes):

                if N_spc[i, j] == 0:
                    continue

                if j not in class_subsets:
                    class_subsets[j] = _filter_class(dataset_orig, config, j)
                subset = class_subsets[j]
                # the same rows as keeping indices idx < N_spc[i, j], without scanning the subset
                n_keep = min(len(subset), int(np.ceil(N_spc[i, j])))
                ds_class = subset.select(range(n_keep)).with_format("torch")

                for batch in DataLoader(ds_class, batch_size = 1):

                    img = resize(batch['image'][0, :, :, :])

                    if img.shape[0] == 1 and config.dataset_name != 'mnist':
                        img = img.repeat(3, 1, 1)

                    sample_stem = path + '/sample_' + str(n)
                    if config.dataset_name == 'image-net':
                        torch.save(img, 'temp_data/' + 'sample_' + str(n) + '.pt')
                    elif not save_jpgs_only:
                        torch.save(img / 255, sample_stem + '.pt')

                    if save_jpgs_only:
                        torchvision.utils.save_image((img / 255).double(), sample_stem + '.jpg')

                    n += 1

                torch.save(torch.tensor(n), 'temp_data/N')
                torch.save(torch.tensor(n), path + '/N')

            # encode all raw image-net samples of this dataset once, after every class is written
            if config.dataset_name == 'image-net' and not save_jpgs_only and n > 0:
                n = _encode_temp_data_into(path)

            torch.save(torch.tensor(n), path + '/N')

    #Now we create the dataloaders
    preprocess = transforms.Compose(
    [
        transforms.Normalize([0.5], [0.5]),  # [0, 1] pixel tensors -> [-1, 1]
    ]
    )

    dataset_list = []
    for i in range(N_datasets):
        if config.dataset_name != 'image-net':
            dataset_list.append(CustomImageDataset(img_dir = dataset_names_list[i], transform = preprocess))
        else:
            # latents are used as stored; train_loop applies their normalisation
            dataset_list.append(CustomImageDataset(img_dir = dataset_names_list[i]))

    return [torch.utils.data.DataLoader(ds, batch_size = config.primal_batch_size, shuffle = True) for ds in dataset_list]


In [None]:
#Model Setup

from collections import OrderedDict


def build_unet(config):
    """Return a freshly initialised 2D U-Net mapping (noisy sample, timestep) to a sample-shaped output.

    The resolution and channel count come from config.latent_size and
    config.diffusion_channels. config.architecture_size scales the width of every block.
    """
    width = config.architecture_size
    return UNet2DModel(
        sample_size = config.latent_size,  # the target image resolution
        in_channels = config.diffusion_channels,  # the number of input channels, 3 for RGB images
        out_channels = config.diffusion_channels,  # the number of output channels, 3 for RGB images
        layers_per_block = 2,  # how many ResNet layers to use per UNet block
        block_out_channels = (width, width, 2*width, 2*width, 4*width, 4*width),  # the number of output channels for each UNet block
        down_block_types=(
            "DownBlock2D",  # a regular ResNet downsampling block
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",  # a regular ResNet upsampling block
            "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    )


def strip_module_prefix(state_dict):
    """Return a copy of `state_dict` with a leading 'module.' removed from each key.

    Keys without that prefix are kept unchanged.
    """
    prefix = 'module.'
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in state_dict.items()
    )


model = build_unet(config)

if config.adaptation == True:
    # reference model with the same architecture; receives the same checkpoint as `model` below when one is loaded
    model_pretrained = build_unet(config)


#loading parameters of pre-trained model if needed ('None' or None skips loading)
if config.load_model_header not in (None, 'None'):
    device = torch.device('cpu')
    checkpoint_path = config.save_models_dir + '/trained_model_' + config.load_model_header + '.pt'
    # torch.load unpickles data: only load checkpoints from trusted sources
    state_dict = torch.load(checkpoint_path, map_location = device)
    new_state_dict = strip_module_prefix(state_dict)

    if config.adaptation == True:
        model_pretrained.load_state_dict(new_state_dict)
        model_pretrained.eval()
    model.load_state_dict(new_state_dict)

    print('loaded model ' + config.load_model_header)

In [None]:
#Training loop function
def train_loop(config, git_folder_path, model, classifier, noise_scheduler, optimizer, train_dataloaders, lr_scheduler, mu_init, b_init, constrained = False, resilient = False, alpha = 0,  model_pretrained = None):
    
    # Initialize accelerator and tensorboard logging
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        project_dir=os.path.join(config.output_dir, "logs"),
    )

    device = accelerator.device

    #initialize lagrange multipliers and constraint thresholds
    print(mu_init)
    print(b_init)
    mu = mu_init.to(device)
    b = b_init.to(device)

    model, optimizer, lr_scheduler = accelerator.prepare(
        model, optimizer, lr_scheduler
    )

    #Nc is the number of constraints here
    Nc = len(train_dataloaders)

    if constrained == False:
        Nc = 1

    mu[0] = 1
    
    dataloaders = []
    iterators = []

    #prepare all the dataloaders
    for k in range(Nc):
        dataloaders += [accelerator.prepare(train_dataloaders[k])]
        iterators += [cycle(dataloaders[k])]

    
    if config.adaptation == True:
        Nc = 2

    global_step = 0


    #initialize history tensors
    mu_hist = torch.zeros((Nc, int(config.num_epochs/config.primal_per_dual)))
    b_hist = torch.zeros((Nc, int(config.num_epochs/config.primal_per_dual)))
    constraint_hist = torch.zeros((Nc, int(config.num_epochs/config.primal_per_dual)))
    lagrangian_hist = torch.zeros(int(config.num_epochs/config.primal_per_dual))
    running_count_hist = np.zeros((int(int(config.num_epochs / config.save_image_epochs)/config.running_average_length), 10))

    
    i_hist = 0
    j = 0

    mu_hist[:, i_hist] = mu
    b_hist[:, i_hist] = b

    running_count = np.zeros(10)


    #normalizing constants used for image-net dataset since the autoencoder used is not properly normalized

    if config.dataset_name == 'image-net':
        a = -15
        b = 20
        gamma = 2/(b - a)
        beta = -(a + b)/(b - a)
        a = 0.431
        b = 36
    else:
        gamma = 1
        beta = 0

    if config.adaptation == True:
        with torch.no_grad():
            model_pretrained = model_pretrained.to(device)
    
    

    #Actually start training
    for epoch in range(config.num_epochs):

        progress_bar = tqdm(total = config.batches_per_epoch, disable = not accelerator.is_local_main_process, position = 0)
        progress_bar.set_description(f"Epoch {epoch}")


        for step in range(config.batches_per_epoch):

            Clean_images = []
            Noises = []
            Batch_sizes = []
            Timesteps = []
            Noisy_images = []
            Noise_preds = []

            for k in range(Nc):

                device = accelerator.device

                train_features = next(iterators[k])[0]
                Clean_images += [((train_features*gamma) + beta).to(device)]
                Noises += [torch.randn(Clean_images[k].shape, device = Clean_images[k].device)]
                Batch_sizes += [Clean_images[k].shape[0]]
                Timesteps += [torch.randint(
                            0, noise_scheduler.config.num_train_timesteps, (Batch_sizes[k],), device = Clean_images[k].device,
                            dtype = torch.int64
                        )]
                Noisy_images += [noise_scheduler.add_noise(Clean_images[k], Noises[k], Timesteps[k])]

                if config.adaptation == True and constrained == True:
                    break



            with accelerator.accumulate(model):

                Losses = []

                #The Lagrangian
                lag = torch.zeros(1, requires_grad = False).to(device)

                for l in range(Nc):

                    Noise_preds += model(Noisy_images[l], Timesteps[l], return_dict = False)

                    Losses += [F.mse_loss(Noise_preds[l], Noises[l])]

                    #The Lagrangian
                    lag += mu[l]*Losses[l]

                    if constrained == False:
                        if mu[l] != 1:
                            print('ERROR: Unconstrained Lagrange multiplier should be equal to 1')
                        break

                    if config.adaptation == True:
                        
                        Noise_preds += model_pretrained(Noisy_images[0], Timesteps[0], return_dict = False)

                        Losses += [F.mse_loss(Noise_preds[1], Noises[0])]

                        lag += mu[1]*Losses[1]

                        break


                accelerator.backward(lag)
                
                if ((epoch + 1)%config.primal_per_dual == 0):
                    lagrangian_hist[i_hist] = lag.detach()
                
                ############### DUAL STEP START #############
                current_lr = lr_scheduler.get_last_lr()[0]
                lr_dual = config.lr_dual_to_primal*current_lr

                if (epoch + 1)%config.primal_per_dual == 0 and constrained == True:

                    l = torch.tensor(Losses, requires_grad = False).to(device)

                    constraint_hist[:, i_hist] = (l - b).detach()

                    if resilient == False:

                        mu = (mu + lr_dual*(l - b)).detach()

                    elif resilient == True:
                        
                        mu = (mu + lr_dual*(l - 2*alpha*mu)).detach()

                    mu = (torch.nn.functional.relu(mu, inplace = True)).detach()

                    mu[0] = 1

                    mu = accelerator.reduce(mu, reduction = 'mean')

                    mu_hist[:, i_hist] = mu.detach()
                ################ DUAL STEP END ###############

                
                    
                if accelerator.is_main_process and (config.wandb_logging == True):
                
                    wandb.log({'Lagrangian': lag.detach().item()})
                    wandb.log({'lr_dual': lr_dual})
                    wandb.log({'lr_primal': current_lr})

                    for k in range(Nc):
                        wandb.log({('mu_' + str(k)) : mu[k].detach().item()})
                        wandb.log({('loss_' + str(k)) : Losses[k].detach().item()})

                accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                

            progress_bar.update(1)
            logs = {"loss": lag.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1

        # After each epoch you optionally sample some demo images with evaluate() and save the model

        if accelerator.is_main_process and config.evaluate == True:
            
            classifier = classifier.to(device)

            ##important
            pipeline = DDIMPipeline(unet = accelerator.unwrap_model(model), scheduler = noise_scheduler)
            
            
            if ((epoch + 1) % config.save_image_epochs == 0) and ((epoch + 1) % config.save_plot_epochs != 0):
                #just draw samples and show them in a grid
                running_count = evaluate(running_count, config, epoch, pipeline, mu_init, b_init, git_folder_path, accelerator, device, classifier_processor, classifier, classify = True)

            if epoch == config.num_epochs - 1:
                #In the last epoch sample lots of images and check their classes
                running_count_final_epoch = np.zeros(10)

                for k in range(20):
                    running_count_final_epoch = evaluate(running_count_final_epoch, config, epoch, pipeline, mu_init, b_init, git_folder_path, accelerator, device, classifier_processor, classifier, classify = True, upload_images = False)

                running_count_final_epoch = evaluate(running_count_final_epoch, config, epoch, pipeline, mu_init, b_init, git_folder_path, accelerator, device, classifier_processor, classifier, classify = True, upload_images = False, final_epoch = True)

            if ((epoch + 1) % config.save_plot_epochs == 0 ) or (epoch == config.num_epochs - 1):
                #draw samples and show them in a grid. Also update plots and add a class histogram and save model
                running_count = evaluate(running_count, config, epoch, pipeline, mu_init, b_init, git_folder_path, accelerator, device, classifier_processor, classifier, classify = True)

                #Save model
                torch.save(model.state_dict(), 'models/trained_model_' + config.save_model_header + '.pt')

                #######################plots of histories of important stuff#######################
                with torch.no_grad():

                    legend = []

                    for k in range(Nc):
                        plt.plot(mu_hist[k, :])
                        legend += [('mu_' + str(k + 1))]

                    plt.legend(legend)
                    plt.ylabel('mu')
                    plt.xlabel('# dual update step')
                    plt.title('dual variables history')
                    plt.savefig('mu_hist.png')
                    plt.close()

                    torch.save(mu_hist, 'mu_hist.pt')

                    up_to_hub('mu_hist.png', content_path = git_folder_path + '/mu_hist.png')
                    up_to_hub('mu_hist.pt', content_path = git_folder_path + '/mu_hist.pt')

                    #####################################################################################

                    legend = []

                    for k in range(Nc):
                        plt.plot(b_hist[k, :])
                        legend += [('b_' + str(k + 1))]
    
                    plt.ylabel('b')
                    plt.xlabel('# dual update step')
                    plt.title('constraint relaxations history')
                    plt.savefig('b_hist.png')
                    plt.close()

                    torch.save(b_hist, 'b_hist.pt')

                    up_to_hub('b_hist.png', content_path = git_folder_path + '/b_hist.png')
                    up_to_hub('b_hist.pt', content_path = git_folder_path + '/b_hist.pt')

                    #####################################################################################

                    legend = []

                    for k in range(Nc):
                        plt.plot(constraint_hist[k, :])
                        legend += [('(l-b)_' + str(k + 1))]

                    plt.legend(legend)
                    plt.ylabel('l-b')
                    plt.xlabel('# dual update step')
                    plt.title('constraint violations history')
                    plt.savefig('constraint_hist.png')
                    plt.close()

                    torch.save(constraint_hist, 'constraint_hist.pt')

                    up_to_hub('constraint_hist.png', content_path = git_folder_path + '/constraint_hist.png')
                    up_to_hub('constraint_hist.pt', content_path = git_folder_path + '/constraint_hist.pt')

                    #####################################################################################
                    plt.plot(lagrangian_hist)
                    plt.ylabel('lagrangian')
                    plt.xlabel('# dual update step')
                    plt.title('lagrangian history')
                    plt.savefig('lagrangian_hist.png')
                    plt.close()

                    torch.save(lagrangian_hist, 'lagrangian_hist.pt')

                    up_to_hub('lagrangian_hist.png', content_path = git_folder_path + '/lagrangian_hist.png')
                    up_to_hub('lagrangian_hist.pt', content_path = git_folder_path + '/lagrangian_hist.pt')

                    #####################################################################################
                    
                    for c in range(10):
                        
                        if c != 4:
                            plt.plot(running_count_hist[:, c], alpha = 0.25)
                        elif c == 4:
                            plt.plot(running_count_hist[:, c], alpha = 1)

                    plt.legend(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'])
                    plt.ylabel('number of Generations')
                    plt.xlabel('# dual update step')
                    plt.title('number of Generations')
                    plt.savefig('running_count_hist.png')
                    plt.close()

                    torch.save(running_count_hist, 'running_count_hist.pt')

                    up_to_hub('running_count_hist.png', content_path = git_folder_path + '/running_count_hist.png')
                    up_to_hub('running_count_hist.pt', content_path = git_folder_path + '/running_count_hist.pt')



            if ((epoch + 1) % (config.save_image_epochs*config.running_average_length) == 0):

                running_count_hist[j, :] = running_count
                j += 1
                running_count[:] = 0


        if ((epoch + 1)%config.primal_per_dual == 0):
            i_hist += 1

In [None]:
# Run

######################################################################
# An example of dataset creation for an adaptation task
# N_datasets = 1 # a single dataset for the model to adapt to
# N_spc = np.zeros((N_datasets, 10)) # a tensor specifying the number of Samples Per Class (spc) for each dataset
# N_spc[0, 3] = 256
######################################################################

######################################################################
# Minority-class constraint setting.
# Dataset 0 feeds the main objective; datasets 1..3 each feed one constraint.
N_datasets = 4
# N_spc[d, c] = number of Samples Per Class (spc) of class c drawn into dataset d
N_spc = np.zeros((N_datasets, 10))

# Main dataset: 256 samples of every class, except the under-represented
# classes 4, 5 and 7, which only get 64 training samples each.
N_spc[0] = 256
N_spc[0, [4, 5, 7]] = 64

# Constraint datasets: each holds 64 samples of a single minority class.
# Paired advanced indexing sets (1, 4), (2, 5) and (3, 7).
N_spc[[1, 2, 3], [4, 5, 7]] = 64
######################################################################

Train_dataloaders = create_datasets_fast(N_spc, config, False, False)
################HYPERPARAMS##################
# NOTE(review): eta_resilience is only written to run_info.txt; the resilient dual
# update in train_loop uses `alpha` instead. Confirm this value is still meaningful.
eta_resilience = 0.1

config.num_epochs = 500
# Primal mini-batches per epoch: half of what is needed to sweep dataset 0 once,
# but never fewer than one.
config.batches_per_epoch = max(1, int((np.sum(N_spc[0, :])/config.primal_batch_size)/2))
# The dual step only runs in epochs where (epoch + 1) is a multiple of this value.
config.primal_per_dual = 2
config.save_image_epochs = 100      # sample and classify an image grid every N epochs
config.save_plot_epochs = 500       # save the model and the history plots every N epochs
config.running_average_length = 5   # class counts are reset every save_image_epochs * N epochs
config.lr_primal = 0.0001

mu_init_scalar = 0.0   # initial value of the constraint multipliers mu[1:]
b_init_scalar = 0.0    # initial value of the constraint relaxations b

# One multiplier / relaxation per loss term: two terms (objective + one constraint)
# for adaptation, otherwise one per dataset. Both branches build float tensors so
# that mu_init_scalar actually takes effect. (Previously torch.zeros silently ignored
# it, and the adaptation branch produced an int64 tensor.)
if config.adaptation == True:
    mu_init = mu_init_scalar*torch.ones(2)
    b_init = b_init_scalar*torch.ones(2)
else:
    mu_init = mu_init_scalar*torch.ones(N_datasets)
    b_init = b_init_scalar*torch.ones(N_datasets)
# The objective's multiplier is pinned to 1 (the dual step also resets mu[0] = 1).
mu_init[0] = 1

constrained = True   # run dual (multiplier) updates during training
resilient = True     # use mu += lr*(l - 2*alpha*mu) instead of mu += lr*(l - b)

alpha = 0.085        # cost of relaxing the constraints in the resilient update

#OPTIMIZER + NOISE SCHEDULER + LR_SCHEDULER
noise_scheduler = DDPMScheduler(num_train_timesteps = 1000)
optimizer = torch.optim.AdamW(model.parameters(), lr = config.lr_primal)
# num_cycles = 0.25 traces a quarter cosine period, so after warmup the learning
# rate decays from lr_primal to half of lr_primal by the final step.
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer = optimizer,
    num_warmup_steps = config.lr_warmup_steps,
    num_training_steps = (config.batches_per_epoch * config.num_epochs),
    num_cycles = 0.25
)


############################################

git_folder_path = 'FOLDER_NAME'
config.save_model_header = 'MODEL_NAME'

if config.wandb_logging == True:
    wandb.init(
        project = 'PROJECT_NAME',
        config = config
    )

# Record the run's settings as "label = value" lines and upload them next to the
# other run artifacts, so every run folder is self-describing.
run_info = [
    ('num_epochs', config.num_epochs),
    ('batches_per_epoch', config.batches_per_epoch),
    ('primal_batch_size', config.primal_batch_size),
    ('dataset_size', np.sum(N_spc[0, :])),
    ('primal_per_dual', config.primal_per_dual),
    ('save_image_epochs', config.save_image_epochs),
    ('save_plot_epochs', config.save_plot_epochs),
    ('running_average_length_for_classification_of_generated_samples', config.running_average_length),
    ('lr_dual_to_primal', config.lr_dual_to_primal),
    ('eta_resilience', eta_resilience),
    ('eta_main', config.lr_primal),
    ('alpha (const relaxation cost)', alpha),
    ('initial mu', mu_init),
    ('initial b', b_init),
    ('save_model_header', config.save_model_header),
    ('load_model_header', config.load_model_header),
    ('constrained', constrained),
    ('resilient', resilient),
]
with open('run_info.txt', 'w') as file:
    for label, value in run_info:
        file.write(label + ' = ' + str(value) + "\n")

up_to_hub('run_info.txt', content_path = git_folder_path + '/run_info.txt')

# Positional arguments of
# train_loop(config, git_folder_path, model, classifier, noise_scheduler, optimizer,
#            train_dataloaders, lr_scheduler, mu_init, b_init, constrained=False,
#            resilient=False, alpha=0, model_pretrained=None).
# The frozen pre-trained model is only passed for adaptation runs.
args = (config, git_folder_path, model, classifier, noise_scheduler, optimizer, Train_dataloaders, lr_scheduler, mu_init, b_init, constrained, resilient, alpha)
if config.adaptation == True:
    args += (model_pretrained,)

#If you run into server-port issues, launch with an explicit port instead, e.g.
# notebook_launcher(train_loop, args, num_processes = config.num_gpus, use_port = str(8000))
try:
    notebook_launcher(train_loop, args, num_processes = config.num_gpus)
finally:
    # Close the W&B run even if training raises, so re-running this cell starts
    # a fresh run instead of appending to a dangling one.
    if config.wandb_logging == True:
        wandb.finish()